File data_flow.py changed (mode: 100644) (index c985103..f4bb8bc) |
... |
... |
def load_data_shanghaitech_more_rnd(img_path, train=True): |
188 |
188 |
target = np.fliplr(target) |
target = np.fliplr(target) |
189 |
189 |
img = img.transpose(Image.FLIP_LEFT_RIGHT) |
img = img.transpose(Image.FLIP_LEFT_RIGHT) |
190 |
190 |
|
|
|
191 |
|
if not train: |
|
192 |
|
# get correct people head count from head annotation |
|
193 |
|
mat_path = img_path.replace('.jpg', '.mat').replace('images', 'ground-truth').replace('IMG', 'GT_IMG') |
|
194 |
|
gt_count = count_gt_annotation_sha(mat_path) |
|
195 |
|
return img, gt_count |
|
196 |
|
|
191 |
197 |
target1 = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)), |
target1 = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)), |
192 |
198 |
interpolation=cv2.INTER_CUBIC) * 64 |
interpolation=cv2.INTER_CUBIC) * 64 |
193 |
199 |
# target1 = target1.unsqueeze(0) # make dim (batch size, channel size, x, y) to make model output |
# target1 = target1.unsqueeze(0) # make dim (batch size, channel size, x, y) to make model output |
File debug/explore_shb.py added (mode: 100644) (index 0000000..d6814c7) |
|
1 |
|
from args_util import meow_parse |
|
2 |
|
from data_flow import get_dataloader, create_image_list |
|
3 |
|
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
4 |
|
from ignite.metrics import Loss |
|
5 |
|
from ignite.handlers import Checkpoint, DiskSaver, Timer |
|
6 |
|
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError, CrowdCountingMeanAbsoluteErrorWithCount, CrowdCountingMeanSquaredErrorWithCount |
|
7 |
|
from visualize_util import get_readable_time |
|
8 |
|
|
|
9 |
|
import torch |
|
10 |
|
from torch import nn |
|
11 |
|
from models.meow_experiment.kitten_meow_1 import M1, M2, M3, M4 |
|
12 |
|
from models.meow_experiment.ccnn_tail import BigTailM1, BigTailM2, BigTail3, BigTail4 |
|
13 |
|
from models.meow_experiment.ccnn_head import H1, H2 |
|
14 |
|
from models.meow_experiment.kitten_meow_1 import H1_Bigtail3 |
|
15 |
|
from models import CustomCNNv2, CompactCNNV7 |
|
16 |
|
import os |
|
17 |
|
from model_util import get_lr, BestMetrics |
|
18 |
|
""" |
|
19 |
|
shanghaitech_more_random |
|
20 |
|
""" |
|
21 |
|
|
|
22 |
|
|
|
23 |
|
if __name__ == "__main__": |
|
24 |
|
DATA_PATH = "/data/ShanghaiTech_fixed_sigma/part_B/" |
|
25 |
|
TRAIN_PATH = os.path.join(DATA_PATH, "train_data_train_split") |
|
26 |
|
VAL_PATH = os.path.join(DATA_PATH, "train_data_validate_split") |
|
27 |
|
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
|
28 |
|
|
|
29 |
|
# create list |
|
30 |
|
train_list = create_image_list(TRAIN_PATH) |
|
31 |
|
val_list = create_image_list(VAL_PATH) |
|
32 |
|
test_list = create_image_list(TEST_PATH) |
|
33 |
|
|
|
34 |
|
train_loader, train_loader_eval, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, |
|
35 |
|
dataset_name="shanghaitech_more_random" |
|
36 |
|
, batch_size=20, |
|
37 |
|
train_loader_for_eval_check=True) |
|
38 |
|
print(len(train_loader)) |
|
39 |
|
print(len(val_loader)) |
|
40 |
|
|
|
41 |
|
for data, label in val_loader: |
|
42 |
|
print(label) |