File data_flow.py changed (mode: 100644) (index 806ba3c..eea892e) |
... |
... |
def load_data_shanghaitech_non_overlap(img_path, train=True): |
439 |
439 |
:return: |
:return: |
440 |
440 |
""" |
""" |
441 |
441 |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
442 |
|
img = Image.open(img_path).convert('RGB') |
|
|
442 |
|
img_origin = Image.open(img_path).convert('RGB') |
|
443 |
|
crop_size = (int(img_origin.size[0] / 2), int(img_origin.size[1] / 2)) |
443 |
444 |
gt_file = h5py.File(gt_path, 'r') |
gt_file = h5py.File(gt_path, 'r') |
444 |
|
target = np.asarray(gt_file['density']) |
|
|
445 |
|
target_origin = np.asarray(gt_file['density']) |
445 |
446 |
target_factor = 8 |
target_factor = 8 |
446 |
447 |
|
|
447 |
448 |
if train: |
if train: |
|
... |
... |
def load_data_shanghaitech_non_overlap(img_path, train=True): |
452 |
453 |
crop_label = [] |
crop_label = [] |
453 |
454 |
for i in range(2): |
for i in range(2): |
454 |
455 |
for j in range(2): |
for j in range(2): |
455 |
|
crop_size = (int(img.size[0] / 2), int(img.size[1] / 2)) |
|
456 |
|
|
|
457 |
456 |
# crop non-overlap |
# crop non-overlap |
458 |
|
dx = int(i * img.size[0] * 1. / 2) |
|
459 |
|
dy = int(j * img.size[1] * 1. / 2) |
|
460 |
|
img = img.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy)) |
|
461 |
|
target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
|
|
457 |
|
dx = int(i * img_origin.size[0] * 1. / 2) |
|
458 |
|
dy = int(j * img_origin.size[1] * 1. / 2) |
|
459 |
|
img = img_origin.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy)) |
|
460 |
|
target = target_origin[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
462 |
461 |
|
|
463 |
462 |
# flip |
# flip |
464 |
463 |
for x in range(2): |
for x in range(2): |
File data_util/dataset_utils.py changed (mode: 100644) (index fdb1754..c8e710e) |
... |
... |
def flatten_collate_broken(batch): |
30 |
30 |
return out_batch |
return out_batch |
31 |
31 |
|
|
32 |
32 |
|
|
33 |
|
def flatten_collate(batch): |
|
|
33 |
|
def _flatten_collate(batch): |
34 |
34 |
""" |
""" |
35 |
35 |
|
|
36 |
|
:param batch: tuple of (data, label) |
|
37 |
|
:return: |
|
|
36 |
|
:param batch: tuple of (data, label) with type(data) == list, type(label) == list |
|
37 |
|
:return: flatten data, label |
38 |
38 |
""" |
""" |
39 |
39 |
# remove null batch |
# remove null batch |
40 |
40 |
batch = list(filter(lambda x: x is not None, batch)) |
batch = list(filter(lambda x: x is not None, batch)) |
|
... |
... |
def flatten_collate(batch): |
51 |
51 |
out_batch = [(img, label) for data_pair in batch for img, label in zip(*data_pair)] |
out_batch = [(img, label) for data_pair in batch for img, label in zip(*data_pair)] |
52 |
52 |
|
|
53 |
53 |
return out_batch |
return out_batch |
|
54 |
|
|
|
55 |
|
|
|
56 |
|
def flatten_collate(batch): |
|
57 |
|
""" |
|
58 |
|
|
|
59 |
|
:param batch: tuple of (data, label) with type(data) == list, type(label) == list |
|
60 |
|
:return: flatten data, label |
|
61 |
|
""" |
|
62 |
|
# remove null batch |
|
63 |
|
batch1 = _flatten_collate(batch) |
|
64 |
|
out_batch = torch.utils.data.dataloader.default_collate(batch1) |
|
65 |
|
return out_batch |
File data_util/test_dataset_utils.py changed (mode: 100644) (index f62719e..744f4cd) |
1 |
1 |
import pytest |
import pytest |
2 |
|
from data_util.dataset_utils import flatten_collate |
|
|
2 |
|
from data_util.dataset_utils import _flatten_collate |
3 |
3 |
|
|
4 |
4 |
|
|
5 |
|
def test_flatten_collate_should_remove_null(): |
|
6 |
|
in_batch = [None, "a", "b", None, "c"] |
|
7 |
|
expected_output = ["a", "b", "c"] |
|
8 |
|
actual_output = flatten_collate(in_batch) |
|
9 |
|
assert actual_output == expected_output |
|
|
5 |
|
# def test_flatten_collate_should_remove_null(): |
|
6 |
|
# in_batch = [None, "a", "b", None, "c"] |
|
7 |
|
# expected_output = ["a", "b", "c"] |
|
8 |
|
# actual_output = _flatten_collate(in_batch) |
|
9 |
|
# assert actual_output == expected_output |
10 |
10 |
|
|
11 |
11 |
|
|
12 |
12 |
def test_flatten_list(): |
def test_flatten_list(): |
13 |
13 |
in_batch = [(["d11", "d12", "d13"],["l11", "l12", "l13"]),(["d21", "d22", "d23"],["l21", "l22", "l23"]), (["d31", "d32", "d33"],["l31", "l32", "l33"])] |
in_batch = [(["d11", "d12", "d13"],["l11", "l12", "l13"]),(["d21", "d22", "d23"],["l21", "l22", "l23"]), (["d31", "d32", "d33"],["l31", "l32", "l33"])] |
14 |
14 |
out_batch = [("d11", "l11"), ("d12", "l12"), ("d13", "l13"), ("d21", "l21"), ("d22", "l22"), ("d23", "l23"), ("d31", "l31"), ("d32", "l32"), ("d33", "l33")] |
out_batch = [("d11", "l11"), ("d12", "l12"), ("d13", "l13"), ("d21", "l21"), ("d22", "l22"), ("d23", "l23"), ("d31", "l31"), ("d32", "l32"), ("d33", "l33")] |
15 |
|
actual_output = flatten_collate(in_batch) |
|
16 |
|
assert actual_output == out_batch |
|
|
15 |
|
actual_output = _flatten_collate(in_batch) |
|
16 |
|
assert actual_output == out_batch |
|
17 |
|
for data in actual_output: |
|
18 |
|
print(len(data)) |
File debug/explore_shb_fatten_list.py changed (mode: 100644) (index c9fd3d5..6e74d94) |
... |
... |
if __name__ == "__main__": |
31 |
31 |
val_list = create_image_list(VAL_PATH) |
val_list = create_image_list(VAL_PATH) |
32 |
32 |
test_list = create_image_list(TEST_PATH) |
test_list = create_image_list(TEST_PATH) |
33 |
33 |
|
|
|
34 |
|
# train_loader, train_loader_eval, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, |
|
35 |
|
# dataset_name="shanghaitech_more_random" |
|
36 |
|
# , batch_size=1, |
|
37 |
|
# train_loader_for_eval_check=True) |
|
38 |
|
|
34 |
39 |
train_loader, train_loader_eval, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, |
train_loader, train_loader_eval, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, |
35 |
40 |
dataset_name="shanghaitech_non_overlap" |
dataset_name="shanghaitech_non_overlap" |
36 |
|
, batch_size=20, |
|
|
41 |
|
, batch_size=1, |
37 |
42 |
train_loader_for_eval_check=True) |
train_loader_for_eval_check=True) |
38 |
43 |
print(len(train_loader)) |
print(len(train_loader)) |
39 |
44 |
print(len(val_loader)) |
print(len(val_loader)) |
40 |
45 |
|
|
41 |
|
for obs in train_loader: |
|
42 |
|
print(len(obs)) |
|
|
46 |
|
for img, label in train_loader: |
|
47 |
|
print(img.shape, label.shape) |