File data_flow.py changed (mode: 100644) (index aaf7b31..806ba3c) |
... |
... |
def load_data_shanghaitech_non_overlap(img_path, train=True): |
473 |
473 |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
474 |
474 |
crop_img.append(img) |
crop_img.append(img) |
475 |
475 |
crop_label.append(target1) |
crop_label.append(target1) |
476 |
|
return crop_img, crop_label |
|
|
476 |
|
return crop_img, crop_label |
477 |
477 |
|
|
478 |
478 |
if not train: |
if not train: |
479 |
479 |
# get correct people head count from head annotation |
# get correct people head count from head annotation |
|
... |
... |
class ListDataset(Dataset): |
937 |
937 |
if img is None or target is None: |
if img is None or target is None: |
938 |
938 |
return None |
return None |
939 |
939 |
if self.transform is not None: |
if self.transform is not None: |
940 |
|
img = self.transform(img) |
|
|
940 |
|
if isinstance(img, list): |
|
941 |
|
# for case of generate multiple augmentation per sample |
|
942 |
|
img_r = [self.transform(img_item) for img_item in img] |
|
943 |
|
img = img_r |
|
944 |
|
else: |
|
945 |
|
img = self.transform(img) |
941 |
946 |
return img, target |
return img, target |
942 |
947 |
|
|
943 |
|
|
|
944 |
|
|
|
945 |
|
|
|
946 |
948 |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1, train_loader_for_eval_check = False): |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1, train_loader_for_eval_check = False): |
947 |
949 |
if visualize_mode: |
if visualize_mode: |
948 |
950 |
transformer = transforms.Compose([ |
transformer = transforms.Compose([ |
File data_util/dataset_utils.py changed (mode: 100644) (index 6e232b2..fdb1754) |
... |
... |
def my_collate(batch): # batch size 4 [{tensor image, tensor label},{},{},{}] c |
14 |
14 |
# so how to sample another dataset entry? |
# so how to sample another dataset entry? |
15 |
15 |
return torch.utils.data.dataloader.default_collate(batch) |
return torch.utils.data.dataloader.default_collate(batch) |
16 |
16 |
|
|
17 |
|
def flatten_collate(batch): |
|
|
17 |
|
def flatten_collate_broken(batch): |
18 |
18 |
""" |
""" |
19 |
19 |
|
|
20 |
|
:param batch: |
|
|
20 |
|
:param batch: tuple of (data, label) |
21 |
21 |
:return: |
:return: |
22 |
22 |
""" |
""" |
23 |
23 |
# remove null batch |
# remove null batch |
|
... |
... |
def flatten_collate(batch): |
30 |
30 |
return out_batch |
return out_batch |
31 |
31 |
|
|
32 |
32 |
|
|
|
33 |
|
def flatten_collate(batch): |
|
34 |
|
""" |
|
35 |
|
|
|
36 |
|
:param batch: tuple of (data, label) |
|
37 |
|
:return: |
|
38 |
|
""" |
|
39 |
|
# remove null batch |
|
40 |
|
batch = list(filter(lambda x: x is not None, batch)) |
|
41 |
|
|
|
42 |
|
# flattening array |
|
43 |
|
|
|
44 |
|
# more clarify version |
|
45 |
|
# out_batch = [] |
|
46 |
|
# for data_pair in batch: |
|
47 |
|
# for img, label in zip(*data_pair): |
|
48 |
|
# out_batch.append((img, label)) |
|
49 |
|
|
|
50 |
|
# python List Comprehensions |
|
51 |
|
out_batch = [(img, label) for data_pair in batch for img, label in zip(*data_pair)] |
|
52 |
|
|
|
53 |
|
return out_batch |
File data_util/test_dataset_utils.py changed (mode: 100644) (index 3229ece..f62719e) |
... |
... |
def test_flatten_collate_should_remove_null(): |
10 |
10 |
|
|
11 |
11 |
|
|
12 |
12 |
def test_flatten_list(): |
def test_flatten_list(): |
13 |
|
in_batch = [["s11", "s12", "s13"], ["s21", "s22", "s23"], ["s31", "s32", "s33"]] |
|
14 |
|
out_batch = ["s11", "s12", "s13", "s21", "s22", "s23", "s31", "s32", "s33"] |
|
|
13 |
|
in_batch = [(["d11", "d12", "d13"],["l11", "l12", "l13"]),(["d21", "d22", "d23"],["l21", "l22", "l23"]), (["d31", "d32", "d33"],["l31", "l32", "l33"])] |
|
14 |
|
out_batch = [("d11", "l11"), ("d12", "l12"), ("d13", "l13"), ("d21", "l21"), ("d22", "l22"), ("d23", "l23"), ("d31", "l31"), ("d32", "l32"), ("d33", "l33")] |
15 |
15 |
actual_output = flatten_collate(in_batch) |
actual_output = flatten_collate(in_batch) |
16 |
16 |
assert actual_output == out_batch |
assert actual_output == out_batch |
File debug/explore_shb_fatten_list.py copied from file debug/explore_shb.py (similarity 94%) (mode: 100644) (index d6814c7..c9fd3d5) |
... |
... |
if __name__ == "__main__": |
32 |
32 |
test_list = create_image_list(TEST_PATH) |
test_list = create_image_list(TEST_PATH) |
33 |
33 |
|
|
34 |
34 |
train_loader, train_loader_eval, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, |
train_loader, train_loader_eval, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, |
35 |
|
dataset_name="shanghaitech_more_random" |
|
|
35 |
|
dataset_name="shanghaitech_non_overlap" |
36 |
36 |
, batch_size=20, |
, batch_size=20, |
37 |
37 |
train_loader_for_eval_check=True) |
train_loader_for_eval_check=True) |
38 |
38 |
print(len(train_loader)) |
print(len(train_loader)) |
39 |
39 |
print(len(val_loader)) |
print(len(val_loader)) |
40 |
40 |
|
|
41 |
|
for data, label in val_loader: |
|
42 |
|
print(label) |
|
|
41 |
|
for obs in train_loader: |
|
42 |
|
print(len(obs)) |