File data_flow.py changed (mode: 100644) (index c50eac3..a9da806) |
... |
... |
def load_data_shanghaitech_180(img_path, train=True): |
255 |
255 |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
256 |
256 |
return img, target1 |
return img, target1 |
257 |
257 |
|
|
|
258 |
|
|
|
259 |
|
def load_data_shanghaitech_256(img_path, train=True):
    """Load an image / density-map pair, cropping a fixed 256x256 patch for training.

    Fixed-size crops allow batching on a dataset with non-uniform image sizes.

    :param img_path: path to the .jpg image; the ground-truth .h5 path is
        derived by replacing '.jpg' -> '.h5' and 'images' -> 'ground-truth-h5'
    :param train: when True, take a random 256x256 crop and apply a random
        horizontal flip (p = 0.2)
    :return: (PIL image, density map of shape (1, H/8, W/8)), or (None, None)
        when the image is too small for the crop (caller/collate must skip it)
    """
    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5')
    img = Image.open(img_path).convert('RGB')
    # Close the HDF5 file after reading so repeated __getitem__ calls over a
    # long training run do not leak file handles (the original never closed it).
    with h5py.File(gt_path, 'r') as gt_file:
        target = np.asarray(gt_file['density'])
    target_factor = 8
    crop_sq_size = 256
    if train:
        # Reject images smaller than the crop in either dimension.
        # NOTE: the previous check `dx < 0 or dy < 0` was insufficient:
        # int() truncates toward zero, so e.g. width 250 with random()=0.05
        # yields int(-0.3) == 0 and the undersized image slipped through,
        # producing a zero-padded 256-wide crop but a 250-wide density map.
        if img.size[0] < crop_sq_size or img.size[1] < crop_sq_size:
            return None, None  # we crop more than we can chew, so...
        crop_size = (crop_sq_size, crop_sq_size)
        dx = int(random.random() * (img.size[0] - crop_sq_size))
        dy = int(random.random() * (img.size[1] - crop_sq_size))
        img = img.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy))
        target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx]

        # Random horizontal flip (p = 0.2) as data augmentation; image and
        # density map must be flipped together to stay aligned.
        if random.random() > 0.8:
            target = np.fliplr(target)
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

    # Downsample the density map by target_factor to match the model's output
    # stride; multiplying by target_factor**2 preserves the total count (sum).
    target1 = cv2.resize(target,
                         (int(target.shape[1] / target_factor), int(target.shape[0] / target_factor)),
                         interpolation=cv2.INTER_CUBIC) * target_factor * target_factor
    # target1 = target1.unsqueeze(0)  # make dim (batch size, channel size, x, y) to make model output
    target1 = np.expand_dims(target1, axis=0)  # make dim (batch size, channel size, x, y) to make model output
    return img, target1
|
|
258 |
291 |
def load_data_shanghaitech_same_size_density_map(img_path, train=True): |
def load_data_shanghaitech_same_size_density_map(img_path, train=True): |
259 |
292 |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
260 |
293 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
|
... |
... |
class ListDataset(Dataset): |
555 |
588 |
self.load_data_fn = load_data_shanghaitech_20p_rnd |
self.load_data_fn = load_data_shanghaitech_20p_rnd |
556 |
589 |
elif dataset_name == "shanghaitech_180": |
elif dataset_name == "shanghaitech_180": |
557 |
590 |
self.load_data_fn = load_data_shanghaitech_180 |
self.load_data_fn = load_data_shanghaitech_180 |
558 |
|
|
|
|
591 |
|
elif dataset_name == "shanghaitech_256": |
|
592 |
|
self.load_data_fn = load_data_shanghaitech_256 |
559 |
593 |
elif dataset_name == "ucf_cc_50": |
elif dataset_name == "ucf_cc_50": |
560 |
594 |
self.load_data_fn = load_data_ucf_cc50 |
self.load_data_fn = load_data_ucf_cc50 |
561 |
595 |
elif dataset_name == "ucf_cc_50_pacnn": |
elif dataset_name == "ucf_cc_50_pacnn": |
|
... |
... |
class ListDataset(Dataset): |
574 |
608 |
if self.debug: |
if self.debug: |
575 |
609 |
print(img_path) |
print(img_path) |
576 |
610 |
img, target = self.load_data_fn(img_path, self.train) |
img, target = self.load_data_fn(img_path, self.train) |
|
611 |
|
if img is None or target is None: |
|
612 |
|
return None |
577 |
613 |
if self.transform is not None: |
if self.transform is not None: |
578 |
614 |
img = self.transform(img) |
img = self.transform(img) |
579 |
615 |
return img, target |
return img, target |
580 |
616 |
|
|
581 |
617 |
|
|
|
618 |
|
def my_collate(batch):
    """Collate function that drops ``None`` samples before batching.

    The dataset's loader may return ``None`` for samples it cannot use
    (e.g. images too small to crop), so a batch of 4 may look like
    ``[None, s1, s2, s3]``; those entries are filtered out here.
    However, if every sample in the batch is None we still fail, so use
    a batch size large enough that this is unlikely.
    https://stackoverflow.com/questions/57815001/pytorch-collate-fn-reject-sample-and-yield-another

    :param batch: list of samples (some possibly ``None``)
    :return: the default-collated batch built from the surviving samples
    """
    valid_samples = [sample for sample in batch if sample is not None]
    # After filtering, len(valid_samples) may be smaller than the requested
    # batch size; re-sampling replacements is an open TODO.
    return torch.utils.data.dataloader.default_collate(valid_samples)
|
630 |
|
|
582 |
631 |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1): |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1): |
583 |
632 |
if visualize_mode: |
if visualize_mode: |
584 |
633 |
transformer = transforms.Compose([ |
transformer = transforms.Compose([ |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
599 |
648 |
num_workers=0, |
num_workers=0, |
600 |
649 |
dataset_name=dataset_name), |
dataset_name=dataset_name), |
601 |
650 |
batch_size=batch_size, |
batch_size=batch_size, |
602 |
|
num_workers=4) |
|
|
651 |
|
num_workers=4, |
|
652 |
|
collate_fn=my_collate) |
603 |
653 |
|
|
604 |
654 |
if val_list is not None: |
if val_list is not None: |
605 |
655 |
val_loader = torch.utils.data.DataLoader( |
val_loader = torch.utils.data.DataLoader( |
File sanity_check_dataloader.py changed (mode: 100644) (index a1885b5..9f28698) |
... |
... |
if __name__ == "__main__": |
13 |
13 |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
14 |
14 |
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
15 |
15 |
dataset_name = args.datasetname |
dataset_name = args.datasetname |
|
16 |
|
dataset_name = "shanghaitech_256" |
16 |
17 |
|
|
17 |
|
|
|
|
18 |
|
count_below_256 = 0 |
18 |
19 |
# create list |
# create list |
19 |
20 |
train_list, val_list = get_train_val_list(TRAIN_PATH) |
train_list, val_list = get_train_val_list(TRAIN_PATH) |
20 |
21 |
test_list = None |
test_list = None |
21 |
22 |
|
|
22 |
23 |
# create data loader |
# create data loader |
23 |
|
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name=dataset_name) |
|
|
24 |
|
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name=dataset_name, batch_size=5) |
24 |
25 |
|
|
25 |
26 |
print("============== TRAIN LOADER ====================================================") |
print("============== TRAIN LOADER ====================================================") |
26 |
27 |
min_1 = 500 |
min_1 = 500 |
|
... |
... |
if __name__ == "__main__": |
33 |
34 |
min_1 = size_1 |
min_1 = size_1 |
34 |
35 |
if min_2 > size_2: |
if min_2 > size_2: |
35 |
36 |
min_2 = size_2 |
min_2 = size_2 |
|
37 |
|
if size_1 < 256 or size_2 < 256: |
|
38 |
|
count_below_256+=1 |
36 |
39 |
# example: img shape:torch.Size([1, 3, 716, 1024]) == label shape torch.Size([1, 1, 89, 128]) |
# example: img shape:torch.Size([1, 3, 716, 1024]) == label shape torch.Size([1, 1, 89, 128]) |
37 |
40 |
|
|
38 |
41 |
print("============== VAL LOADER ====================================================") |
print("============== VAL LOADER ====================================================") |
39 |
42 |
for img, label in val_loader: |
for img, label in val_loader: |
40 |
43 |
print("img shape:" + str(img.shape) + " == " + "label shape " + str(label.shape)) |
print("img shape:" + str(img.shape) + " == " + "label shape " + str(label.shape)) |
41 |
44 |
print(min_1) |
print(min_1) |
42 |
|
print(min_2) |
|
|
45 |
|
print(min_2) |
|
46 |
|
print("count < 256 = ", count_below_256) |