File data_flow.py changed (mode: 100644) (index 0418e46..3347278) |
...
 def load_data_jhucrowd_256(img_path, train=True, debug=False):
...
     return img, target1
 
 
+def load_data_jhucrowd_downsample_256(img_path, train=True, debug=False):
+    """
+    for jhucrowd
+    fixed 256 crop of the half-resolution image, allowing batching on a non-uniform dataset
+    :param img_path:
+    :param train:
+    :return:
+    """
+    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5')
+    img_origin = Image.open(img_path).convert('RGB')
+    # downsample by half
+    gt_file = h5py.File(gt_path, 'r')
+    target = np.asarray(gt_file['density']).astype('float32')
+    downsample_factor = 2
+    target_factor = 8 / downsample_factor
+    crop_sq_size = 256 * downsample_factor
+    if train:
+        crop_size = (crop_sq_size, crop_sq_size)
+        dx = int(random.random() * (img_origin.size[0] - crop_sq_size))
+        dy = int(random.random() * (img_origin.size[1] - crop_sq_size))
+        if img_origin.size[0] - crop_sq_size < 0 or img_origin.size[1] - crop_sq_size < 0:  # we crop more than we can chew, so...
+            return None, None
+        img = img_origin.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy))
+        # halve the crop itself (img.size) so the 512x512 patch becomes the 256x256 network input
+        img2 = img.resize((int(img.size[0] / 2), int(img.size[1] / 2)), resample=Image.ANTIALIAS)
+        target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx]
+
+        if random.random() > 0.8:
+            target = np.fliplr(target).copy()  # copy(): np.fliplr returns a negative-stride view
+            img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
+
+    if not train:
+        # get correct people head count from head annotation
+        txt_path = img_path.replace('.jpg', '.txt').replace('images', 'ground-truth')
+        gt_count = count_gt_annotation_jhu(txt_path)
+        img_out = img_origin.resize((int(img_origin.size[0] / 2), int(img_origin.size[1] / 2)), resample=Image.ANTIALIAS)
+        if debug:
+            gt_file = h5py.File(gt_path, 'r')
+            target = np.asarray(gt_file['density'])
+            return img_origin, gt_count, target
+        return img_out, gt_count
+
+    target1 = cv2.resize(target, (int(target.shape[1] / target_factor), int(target.shape[0] / target_factor)),
+                         interpolation=cv2.INTER_CUBIC) * target_factor * target_factor
+    # target1 = target1.unsqueeze(0)  # make dim (batch size, channel size, x, y) to match model output
+    target1 = np.expand_dims(target1, axis=0)  # make dim (batch size, channel size, x, y) to match model output
+    return img2, target1  # the half-resolution crop, consistent with any flip applied above
+
 def data_augmentation(img, target):
     """
     return 1 pair of img, target after apply augmentation
...
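As a sanity check on the shape arithmetic in the new loader (the numbers below are derived from the constants in the diff; the variable names are mine, not part of the commit): the crop is taken at full resolution, halved for the network input, and the density map is shrunk by target_factor with a compensating scale so the total head count is preserved.

    # Worked shape check for load_data_jhucrowd_downsample_256 (train path); illustrative only.
    downsample_factor = 2
    crop_sq_size = 256 * downsample_factor             # 512: square crop taken at full resolution
    input_side = crop_sq_size // downsample_factor     # 256: crop after halving -> network input
    target_factor = 8 / downsample_factor              # 4.0: density-map shrink factor
    density_side = int(crop_sq_size / target_factor)   # 128: density-map side after cv2.resize
    count_scale = target_factor * target_factor        # 16.0: restores sum(density) after shrinking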
 class ListDataset(Dataset):
...
             self.load_data_fn = load_data_shanghaitech_180
         elif dataset_name == "shanghaitech_256":
             self.load_data_fn = load_data_shanghaitech_256
+        elif dataset_name == "jhucrowd_downsample_256":
+            self.load_data_fn = load_data_jhucrowd_downsample_256
         elif dataset_name == "shanghaitech_non_overlap":
             self.load_data_fn = load_data_shanghaitech_non_overlap
         elif dataset_name == "shanghaitech_non_overlap_downsample":
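Finally, a minimal usage sketch under stated assumptions: the image path is hypothetical, ListDataset's full constructor signature is outside this hunk, and the DataLoader wiring is a guess at how the docstring's batching claim is meant to be exercised.

    import torch

    # Hypothetical paths; the JHU-CROWD layout is assumed from the replace() calls above.
    train_list = ['jhu_crowd/train/images/0001.jpg']

    # Direct call: train mode yields (256x256 crop, density map with a leading dim),
    # or (None, None) when the image is smaller than the 512x512 crop window.
    img, target = load_data_jhucrowd_downsample_256(train_list[0], train=True)
    if img is not None:
        print(img.size, target.shape)  # (256, 256) and (1, 128, 128)

    # Through the dataset: the new name dispatches to the loader registered above.
    dataset = ListDataset(train_list, train=True, dataset_name="jhucrowd_downsample_256")
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)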