File args_util.py changed (mode: 100644) (index 2e74c07..fa2f9fc) |
... |
... |
contain dummy args with config |
3 |
3 |
helpful for copy-pasting into Kaggle |
helpful for copy-pasting into Kaggle |
4 |
4 |
""" |
""" |
5 |
5 |
import argparse |
import argparse |
6 |
|
|
|
|
6 |
|
from hard_code_variable import HardCodeVariable |
7 |
7 |
|
|
8 |
8 |
def make_args(gpu="0", task="task_one_"): |
def make_args(gpu="0", task="task_one_"): |
9 |
9 |
""" |
""" |
|
... |
... |
def real_args_parse(): |
62 |
62 |
parser.add_argument("--task_id", action="store", default="dev") |
parser.add_argument("--task_id", action="store", default="dev") |
63 |
63 |
parser.add_argument('-a', action="store_true", default=False) |
parser.add_argument('-a', action="store_true", default=False) |
64 |
64 |
|
|
65 |
|
parser.add_argument('--input', action="store", type=str) |
|
66 |
|
parser.add_argument('--output', action="store", type=str) |
|
|
65 |
|
parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A) |
|
66 |
|
parser.add_argument('--output', action="store", type=str, default="saved_model") |
67 |
67 |
parser.add_argument('--model', action="store", default="pacnn") |
parser.add_argument('--model', action="store", default="pacnn") |
68 |
68 |
|
|
69 |
69 |
# args with default value |
# args with default value |
File data_flow.py changed (mode: 100644) (index a3f3d3c..ac150b0) |
... |
... |
def get_train_val_list(data_path, test_size=0.1): |
46 |
46 |
|
|
47 |
47 |
|
|
48 |
48 |
def load_data(img_path, train=True): |
def load_data(img_path, train=True): |
|
49 |
|
""" |
|
50 |
|
get a sample |
|
51 |
|
:deprecated: use load_data_shanghaitech instead |
|
52 |
|
:param img_path: |
|
53 |
|
:param train: |
|
54 |
|
:return: |
|
55 |
|
""" |
49 |
56 |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
50 |
57 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
51 |
58 |
gt_file = h5py.File(gt_path, 'r') |
gt_file = h5py.File(gt_path, 'r') |
|
... |
... |
def load_data(img_path, train=True): |
57 |
64 |
return img, target |
return img, target |
58 |
65 |
|
|
59 |
66 |
|
|
|
67 |
|
def load_data_shanghaitech(img_path, train=True):
    """
    Load one ShanghaiTech sample: the RGB image and its down-scaled density map.

    The ground-truth path is derived from ``img_path`` by replacing '.jpg'
    with '.h5' and 'images' with 'ground-truth-h5'; the density map is read
    from the 'density' dataset of that HDF5 file.

    When ``train`` is true, a crop of half the image width and half the image
    height is taken at a random offset from both the image and the density
    map, and the pair is flipped left-right when ``random.random() > 0.8``.

    The density map is then resized to 1/8 of its width and height with
    bicubic interpolation and multiplied by 64 (= 8 * 8; presumably to keep
    the total count roughly unchanged after down-scaling - confirm).

    :param img_path: path to the '.jpg' image
    :param train: apply the random crop / flip augmentation when True
    :return: tuple (img, target1) - the PIL image and a numpy array with a
             leading axis of size 1 in front of the resized density map
    """
    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5')
    img = Image.open(img_path).convert('RGB')
    # np.asarray reads the dataset into memory, so the HDF5 file can be closed
    # immediately instead of leaking one open handle per loaded sample.
    with h5py.File(gt_path, 'r') as gt_file:
        target = np.asarray(gt_file['density'])

    if train:
        crop_size = (int(img.size[0] / 2), int(img.size[1] / 2))
        if random.randint(0, 9) <= -1:
            # NOTE: randint(0, 9) is never <= -1, so this quadrant-aligned crop
            # is effectively disabled. The call is kept so the sequence of
            # random numbers drawn per sample stays the same.
            dx = int(random.randint(0, 1) * img.size[0] * 1. / 2)
            dy = int(random.randint(0, 1) * img.size[1] * 1. / 2)
        else:
            dx = int(random.random() * img.size[0] * 1. / 2)
            dy = int(random.random() * img.size[1] * 1. / 2)

        # PIL crops with (left, upper, right, lower); the density map is
        # indexed [row, column], i.e. [y, x].
        img = img.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy))
        target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx]

        if random.random() > 0.8:
            target = np.fliplr(target)
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

    target1 = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)),
                         interpolation=cv2.INTER_CUBIC) * 64
    # cv2.resize returns a numpy array, which has no .unsqueeze(); add the
    # leading (channel) axis with numpy instead.
    target1 = np.expand_dims(target1, axis=0)
    return img, target1
|
94 |
|
|
|
95 |
|
|
|
96 |
|
def load_data_shanghaitech_keepfull(img_path, train=True):
    """
    Load one ShanghaiTech sample without cropping (the full image is kept).

    The ground-truth path is derived from ``img_path`` by replacing '.jpg'
    with '.h5' and 'images' with 'ground-truth-h5'; the density map is read
    from the 'density' dataset of that HDF5 file.

    When ``train`` is true, the image and density map are flipped left-right
    when ``random.random() > 0.8``. No other augmentation is applied.

    The density map is resized to 1/8 of its width and height with bicubic
    interpolation and multiplied by 64 (= 8 * 8; presumably to keep the total
    count roughly unchanged after down-scaling - confirm).

    :param img_path: path to the '.jpg' image
    :param train: apply the random horizontal flip when True
    :return: tuple (img, target1) - the PIL image and a numpy array with a
             leading axis of size 1 in front of the resized density map
    """
    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5')
    img = Image.open(img_path).convert('RGB')
    # np.asarray reads the dataset into memory, so the HDF5 file can be closed
    # immediately instead of leaking one open handle per loaded sample.
    with h5py.File(gt_path, 'r') as gt_file:
        target = np.asarray(gt_file['density'])

    if train:
        if random.random() > 0.8:
            target = np.fliplr(target)
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

    target1 = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)),
                         interpolation=cv2.INTER_CUBIC) * 64

    # Add a single leading axis in front of the 2-D map; only one axis is
    # added here, no batch dimension.
    target1 = np.expand_dims(target1, axis=0)
    return img, target1
|
113 |
|
|
|
114 |
|
|
60 |
115 |
def load_data_ucf_cc50(img_path, train=True): |
def load_data_ucf_cc50(img_path, train=True): |
61 |
116 |
gt_path = img_path.replace('.jpg', '.h5') |
gt_path = img_path.replace('.jpg', '.h5') |
62 |
117 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
|
... |
... |
class ListDataset(Dataset): |
254 |
309 |
self.dataset_name = dataset_name |
self.dataset_name = dataset_name |
255 |
310 |
# load data fn |
# load data fn |
256 |
311 |
if dataset_name is "shanghaitech": |
if dataset_name is "shanghaitech": |
257 |
|
self.load_data_fn = load_data |
|
|
312 |
|
self.load_data_fn = load_data_shanghaitech |
|
313 |
|
if dataset_name is "shanghaitech_keepfull": |
|
314 |
|
self.load_data_fn = load_data_shanghaitech_keepfull |
258 |
315 |
elif dataset_name is "ucf_cc_50": |
elif dataset_name is "ucf_cc_50": |
259 |
316 |
self.load_data_fn = load_data_ucf_cc50 |
self.load_data_fn = load_data_ucf_cc50 |
260 |
317 |
elif dataset_name is "ucf_cc_50_pacnn": |
elif dataset_name is "ucf_cc_50_pacnn": |
|
... |
... |
class ListDataset(Dataset): |
278 |
335 |
return img, target |
return img, target |
279 |
336 |
|
|
280 |
337 |
|
|
281 |
|
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech"): |
|
|
338 |
|
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False): |
|
339 |
|
if visualize_mode: |
|
340 |
|
transformer = transforms.Compose([ |
|
341 |
|
transforms.ToTensor() |
|
342 |
|
]) |
|
343 |
|
else: |
|
344 |
|
transformer = transforms.Compose([ |
|
345 |
|
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], |
|
346 |
|
std=[0.229, 0.224, 0.225]), |
|
347 |
|
]) |
|
348 |
|
|
282 |
349 |
train_loader = torch.utils.data.DataLoader( |
train_loader = torch.utils.data.DataLoader( |
283 |
350 |
ListDataset(train_list, |
ListDataset(train_list, |
284 |
|
shuffle=True, |
|
285 |
|
transform=transforms.Compose([ |
|
286 |
|
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], |
|
287 |
|
std=[0.229, 0.224, 0.225]), |
|
288 |
|
]), |
|
289 |
|
train=True, |
|
290 |
|
batch_size=1, |
|
291 |
|
num_workers=4, dataset_name=dataset_name), |
|
292 |
|
batch_size=1, num_workers=4) |
|
|
351 |
|
shuffle=True, |
|
352 |
|
transform=transformer, |
|
353 |
|
train=True, |
|
354 |
|
batch_size=1, |
|
355 |
|
num_workers=4, |
|
356 |
|
dataset_name=dataset_name), |
|
357 |
|
batch_size=1, |
|
358 |
|
num_workers=4) |
293 |
359 |
|
|
294 |
360 |
val_loader = torch.utils.data.DataLoader( |
val_loader = torch.utils.data.DataLoader( |
295 |
361 |
ListDataset(val_list, |
ListDataset(val_list, |
296 |
|
shuffle=False, |
|
297 |
|
transform=transforms.Compose([ |
|
298 |
|
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], |
|
299 |
|
std=[0.229, 0.224, 0.225]), |
|
300 |
|
]), train=False, dataset_name=dataset_name), |
|
|
362 |
|
shuffle=False, |
|
363 |
|
transform=transformer, |
|
364 |
|
train=False, |
|
365 |
|
dataset_name=dataset_name), |
301 |
366 |
batch_size=1) |
batch_size=1) |
|
367 |
|
|
302 |
368 |
if test_list is not None: |
if test_list is not None: |
303 |
369 |
test_loader = torch.utils.data.DataLoader( |
test_loader = torch.utils.data.DataLoader( |
304 |
370 |
ListDataset(test_list, |
ListDataset(test_list, |
305 |
371 |
shuffle=False, |
shuffle=False, |
306 |
|
transform=transforms.Compose([ |
|
307 |
|
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], |
|
308 |
|
std=[0.229, 0.224, 0.225]), |
|
309 |
|
]), train=False, dataset_name=dataset_name), |
|
|
372 |
|
transform=transformer, |
|
373 |
|
train=False, |
|
374 |
|
dataset_name=dataset_name), |
310 |
375 |
batch_size=1) |
batch_size=1) |
311 |
376 |
else: |
else: |
312 |
377 |
test_loader = None |
test_loader = None |
File visualize_data_loader.py changed (mode: 100644) (index 09c9f3b..03c78fe) |
... |
... |
def visualize_ucf_cc_50_pacnn(): |
50 |
50 |
print("count3 ", label[2].numpy()[0].sum()) |
print("count3 ", label[2].numpy()[0].sum()) |
51 |
51 |
|
|
52 |
52 |
|
|
|
53 |
|
def visualize_shanghaitech_keepfull():
    """
    Save the first 10 training samples of the "shanghaitech_keepfull" loader as PNGs.

    Builds the train/val split (test_size=0.2) from the ShanghaiTech part A
    train path, creates the data loaders in visualize mode and writes each
    image and its density-map label into
    "visualize/test_dataloader_shanghaitech".
    """
    hard_code = HardCodeVariable()
    data_root = ShanghaiTechDataPath(root=hard_code.SHANGHAITECH_PATH)
    part_a_train_path = data_root.get_a().get_train().get()

    out_dir = "visualize/test_dataloader_shanghaitech"
    os.makedirs(out_dir, exist_ok=True)

    train_list, val_list = get_train_val_list(part_a_train_path, test_size=0.2)
    # No test split is used here, so None is passed as the test list.
    train_loader, _val_loader, _test_loader = get_dataloader(
        train_list,
        val_list,
        None,
        dataset_name="shanghaitech_keepfull",
        visualize_mode=True)

    # Only the train loader is visualized.
    batches = iter(train_loader)
    for idx in range(10):
        img, label = next(batches)
        img_file = os.path.join(out_dir, "train_img" + str(idx) + ".png")
        save_img(img, img_file)
        label_file = os.path.join(out_dir, "train_label" + str(idx) + ".png")
        # label.numpy()[0][0] selects the first entry along the first two axes
        density_map = label.numpy()[0][0]
        save_density_map(density_map, label_file)
|
69 |
|
|
|
70 |
|
|
|
71 |
|
|
53 |
72 |
def visualize_shanghaitech_pacnn_with_perspective(): |
def visualize_shanghaitech_pacnn_with_perspective(): |
54 |
73 |
HARD_CODE = HardCodeVariable() |
HARD_CODE = HardCodeVariable() |
55 |
74 |
shanghaitech_data = ShanghaiTechDataPath(root=HARD_CODE.SHANGHAITECH_PATH) |
shanghaitech_data = ShanghaiTechDataPath(root=HARD_CODE.SHANGHAITECH_PATH) |
|
... |
... |
def visualize_shanghaitech_pacnn_with_perspective(): |
95 |
114 |
print("s5 ", label[4].shape) |
print("s5 ", label[4].shape) |
96 |
115 |
|
|
97 |
116 |
if __name__ == "__main__": |
if __name__ == "__main__": |
98 |
|
visualize_shanghaitech_pacnn_with_perspective() |
|
|
117 |
|
# visualize_shanghaitech_pacnn_with_perspective() |
|
118 |
|
visualize_shanghaitech_keepfull() |