File data_flow.py changed (mode: 100644) (index 19347e3..a7b2270) |
... |
... |
def load_data_shanghaitech_pacnn_with_perspective(img_path, train=True): |
114 |
114 |
:return: |
:return: |
115 |
115 |
""" |
""" |
116 |
116 |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
117 |
|
p_path = img_path.replace(".jpg", ".mat").replace("images", "p_map") |
|
|
117 |
|
p_path = img_path.replace(".jpg", ".mat").replace("images", "pmap") |
118 |
118 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
119 |
119 |
gt_file = h5py.File(gt_path, 'r') |
gt_file = h5py.File(gt_path, 'r') |
120 |
120 |
target = np.asarray(gt_file['density']) |
target = np.asarray(gt_file['density']) |
121 |
|
perspective = np.array(h5py.File(p_path, "r")) |
|
122 |
|
|
|
|
121 |
|
perspective = np.array(h5py.File(p_path, "r")['pmap']) |
|
122 |
|
perspective = np.rot90(perspective, k=3) |
123 |
123 |
if train: |
if train: |
124 |
124 |
crop_size = (int(img.size[0] / 2), int(img.size[1] / 2)) |
crop_size = (int(img.size[0] / 2), int(img.size[1] / 2)) |
125 |
125 |
if random.randint(0, 9) <= -1: |
if random.randint(0, 9) <= -1: |
|
... |
... |
def load_data_shanghaitech_pacnn_with_perspective(img_path, train=True): |
132 |
132 |
|
|
133 |
133 |
img = img.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy)) |
img = img.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy)) |
134 |
134 |
target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
135 |
|
perspective = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
|
|
135 |
|
perspective = perspective[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
136 |
136 |
if random.random() > 0.8: |
if random.random() > 0.8: |
137 |
137 |
target = np.fliplr(target) |
target = np.fliplr(target) |
138 |
138 |
img = img.transpose(Image.FLIP_LEFT_RIGHT) |
img = img.transpose(Image.FLIP_LEFT_RIGHT) |
|
... |
... |
def load_data_shanghaitech_pacnn_with_perspective(img_path, train=True): |
148 |
148 |
perspective_s = cv2.resize(perspective, (int(perspective.shape[1] / 16), int(perspective.shape[0] / 16)), |
perspective_s = cv2.resize(perspective, (int(perspective.shape[1] / 16), int(perspective.shape[0] / 16)), |
149 |
149 |
interpolation=cv2.INTER_CUBIC) * 256 |
interpolation=cv2.INTER_CUBIC) * 256 |
150 |
150 |
|
|
151 |
|
perspective_m = cv2.resize(perspective, (int(perspective.shape[1] / 8), int(perspective.shape[0] / 8)), |
|
|
151 |
|
perspective_p = cv2.resize(perspective, (int(perspective.shape[1] / 8), int(perspective.shape[0] / 8)), |
152 |
152 |
interpolation=cv2.INTER_CUBIC) * 64 |
interpolation=cv2.INTER_CUBIC) * 64 |
153 |
153 |
|
|
154 |
|
return img, (target1, target2, target3, perspective_s, perspective_m) |
|
|
154 |
|
return img, (target1, target2, target3, perspective_s, perspective_p) |
|
155 |
|
|
155 |
156 |
|
|
156 |
157 |
def load_data_ucf_cc50_pacnn(img_path, train=True): |
def load_data_ucf_cc50_pacnn(img_path, train=True): |
157 |
158 |
""" |
""" |
|
... |
... |
class ListDataset(Dataset): |
258 |
259 |
self.load_data_fn = load_data_ucf_cc50_pacnn |
self.load_data_fn = load_data_ucf_cc50_pacnn |
259 |
260 |
elif dataset_name is "shanghaitech_pacnn": |
elif dataset_name is "shanghaitech_pacnn": |
260 |
261 |
self.load_data_fn = load_data_shanghaitech_pacnn |
self.load_data_fn = load_data_shanghaitech_pacnn |
|
262 |
|
elif dataset_name is "shanghaitech_pacnn_with_perspective": |
|
263 |
|
self.load_data_fn = load_data_shanghaitech_pacnn_with_perspective |
261 |
264 |
|
|
262 |
265 |
def __len__(self): |
def __len__(self): |
263 |
266 |
return self.nSamples |
return self.nSamples |
File visualize_data_loader.py changed (mode: 100644) (index b515bbd..f4aa432) |
... |
... |
def visualize_ucf_cc_50_pacnn(): |
48 |
48 |
print("count2 ", label[1].numpy()[0].sum()) |
print("count2 ", label[1].numpy()[0].sum()) |
49 |
49 |
print("count3 ", label[2].numpy()[0].sum()) |
print("count3 ", label[2].numpy()[0].sum()) |
50 |
50 |
|
|
|
51 |
|
def visualize_shanghaitech_pacnn_with_perspective(): |
|
52 |
|
HARD_CODE = HardCodeVariable() |
|
53 |
|
saved_folder = "visualize/test_dataloader" |
|
54 |
|
os.makedirs(saved_folder, exist_ok=True) |
|
55 |
|
DATA_PATH = HARD_CODE.SHANGHAITECH_PATH |
|
56 |
|
train_list, val_list = get_train_val_list(DATA_PATH, test_size=0.2) |
|
57 |
|
test_list = None |
|
58 |
|
|
|
59 |
|
# create data loader |
|
60 |
|
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name="ucf_cc_50") |
|
61 |
|
train_loader_pacnn = torch.utils.data.DataLoader( |
|
62 |
|
ListDataset(train_list, |
|
63 |
|
shuffle=True, |
|
64 |
|
transform=transforms.Compose([ |
|
65 |
|
transforms.ToTensor() |
|
66 |
|
]), |
|
67 |
|
train=True, |
|
68 |
|
batch_size=1, |
|
69 |
|
num_workers=4, dataset_name="shanghaitech_pacnn_with_perspective", debug=True), |
|
70 |
|
batch_size=1, num_workers=4) |
|
71 |
|
|
|
72 |
|
img, label = next(iter(train_loader_pacnn)) |
|
73 |
|
|
|
74 |
|
print(img.shape) |
|
75 |
|
save_img(img, os.path.join(saved_folder, "pacnn_loader_img.png")) |
|
76 |
|
save_density_map(label[0].numpy()[0], os.path.join(saved_folder,"pacnn_loader_with_p_density1.png")) |
|
77 |
|
save_density_map(label[1].numpy()[0], os.path.join(saved_folder,"pacnn_loader_with_p_density2.png")) |
|
78 |
|
save_density_map(label[2].numpy()[0], os.path.join(saved_folder,"pacnn_loader_with_p_density3.png")) |
|
79 |
|
save_density_map(label[3].numpy()[0], os.path.join(saved_folder, "pacnn_loader_p_s_4.png")) |
|
80 |
|
save_density_map(label[4].numpy()[0], os.path.join(saved_folder, "pacnn_loader_p_5.png")) |
|
81 |
|
print("count1 ", label[0].numpy()[0].sum()) |
|
82 |
|
print("count2 ", label[1].numpy()[0].sum()) |
|
83 |
|
print("count3 ", label[2].numpy()[0].sum()) |
|
84 |
|
print("count4 ", label[3].numpy()[0].sum()) |
|
85 |
|
print("count5 ", label[4].numpy()[0].sum()) |
|
86 |
|
|
|
87 |
|
print("s1 ", label[0].shape) |
|
88 |
|
print("s2 ", label[1].shape) |
|
89 |
|
print("s3 ", label[2].shape) |
|
90 |
|
print("s4 ", label[3].shape) |
|
91 |
|
print("s5 ", label[4].shape) |
51 |
92 |
|
|
52 |
93 |
if __name__ == "__main__": |
if __name__ == "__main__": |
53 |
|
visualize_ucf_cc_50_pacnn() |
|
|
94 |
|
visualize_shanghaitech_pacnn_with_perspective() |