File data_flow.py changed (mode: 100644) (index 7763ae4..1260368) |
... |
... |
def load_data_shanghaitech_keepfull(img_path, train=True): |
147 |
147 |
return img, target1 |
return img, target1 |
148 |
148 |
|
|
149 |
149 |
|
|
|
150 |
|
def load_data_shanghaitech_keepfull_and_crop(img_path, train=True): |
|
151 |
|
""" |
|
152 |
|
loader might give full image, or crop |
|
153 |
|
:param img_path: |
|
154 |
|
:param train: |
|
155 |
|
:return: |
|
156 |
|
""" |
|
157 |
|
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
|
158 |
|
img = Image.open(img_path).convert('RGB') |
|
159 |
|
gt_file = h5py.File(gt_path, 'r') |
|
160 |
|
target = np.asarray(gt_file['density']) |
|
161 |
|
|
|
162 |
|
if train: |
|
163 |
|
|
|
164 |
|
if random.random() > 0.5: # 50% chance crop |
|
165 |
|
crop_size = (int(img.size[0] / 2), int(img.size[1] / 2)) |
|
166 |
|
if random.randint(0, 9) <= -1: |
|
167 |
|
|
|
168 |
|
dx = int(random.randint(0, 1) * img.size[0] * 1. / 2) |
|
169 |
|
dy = int(random.randint(0, 1) * img.size[1] * 1. / 2) |
|
170 |
|
else: |
|
171 |
|
dx = int(random.random() * img.size[0] * 1. / 2) |
|
172 |
|
dy = int(random.random() * img.size[1] * 1. / 2) |
|
173 |
|
|
|
174 |
|
img = img.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy)) |
|
175 |
|
target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
|
176 |
|
|
|
177 |
|
if random.random() > 0.8: # 20 % chance flip |
|
178 |
|
target = np.fliplr(target) |
|
179 |
|
img = img.transpose(Image.FLIP_LEFT_RIGHT) |
|
180 |
|
|
|
181 |
|
target1 = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)), |
|
182 |
|
interpolation=cv2.INTER_CUBIC) * 64 |
|
183 |
|
|
|
184 |
|
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
|
185 |
|
# np.expand_dims(target1, axis=0) # again |
|
186 |
|
return img, target1 |
|
187 |
|
|
|
188 |
|
|
|
189 |
|
|
150 |
190 |
def load_data_ucf_cc50(img_path, train=True): |
def load_data_ucf_cc50(img_path, train=True): |
151 |
191 |
gt_path = img_path.replace('.jpg', '.h5') |
gt_path = img_path.replace('.jpg', '.h5') |
152 |
192 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
|
... |
... |
class ListDataset(Dataset): |
346 |
386 |
# load data fn |
# load data fn |
347 |
387 |
if dataset_name == "shanghaitech": |
if dataset_name == "shanghaitech": |
348 |
388 |
self.load_data_fn = load_data_shanghaitech |
self.load_data_fn = load_data_shanghaitech |
349 |
|
if dataset_name == "shanghaitech_same_size_density_map": |
|
|
389 |
|
elif dataset_name == "shanghaitech_same_size_density_map": |
350 |
390 |
self.load_data_fn = load_data_shanghaitech_same_size_density_map |
self.load_data_fn = load_data_shanghaitech_same_size_density_map |
351 |
|
if dataset_name == "shanghaitech_keepfull": |
|
|
391 |
|
elif dataset_name == "shanghaitech_keepfull": |
352 |
392 |
self.load_data_fn = load_data_shanghaitech_keepfull |
self.load_data_fn = load_data_shanghaitech_keepfull |
|
393 |
|
elif dataset_name == "shanghaitech_keepfull_and_crop": |
|
394 |
|
self.load_data_fn = load_data_shanghaitech_keepfull_and_crop |
353 |
395 |
elif dataset_name == "ucf_cc_50": |
elif dataset_name == "ucf_cc_50": |
354 |
396 |
self.load_data_fn = load_data_ucf_cc50 |
self.load_data_fn = load_data_ucf_cc50 |
355 |
397 |
elif dataset_name == "ucf_cc_50_pacnn": |
elif dataset_name == "ucf_cc_50_pacnn": |