File data_flow.py changed (mode: 100644) (index d7fed10..23088d0) |
... |
... |
def load_data_jhucrowd_256(img_path, train=True, debug=False): |
1058 |
1058 |
return img, target1 |
return img, target1 |
1059 |
1059 |
|
|
1060 |
1060 |
|
|
1061 |
|
def load_data_jhucrowd_downsample_256(img_path, train=True, debug=False): |
|
|
1061 |
|
def load_data_jhucrowd_downsample_512(img_path, train=True, debug=False): |
1062 |
1062 |
""" |
""" |
1063 |
1063 |
for jhucrowd |
for jhucrowd |
1064 |
1064 |
crop fixed 256, allow batch in non-uniform dataset |
crop fixed 256, allow batch in non-uniform dataset |
|
... |
... |
def load_data_jhucrowd_downsample_256(img_path, train=True, debug=False): |
1068 |
1068 |
""" |
""" |
1069 |
1069 |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
1070 |
1070 |
img_origin = Image.open(img_path).convert('RGB') |
img_origin = Image.open(img_path).convert('RGB') |
1071 |
|
# downsample by half |
|
1072 |
1071 |
gt_file = h5py.File(gt_path, 'r') |
gt_file = h5py.File(gt_path, 'r') |
1073 |
1072 |
target = np.asarray(gt_file['density']).astype('float32') |
target = np.asarray(gt_file['density']).astype('float32') |
1074 |
|
downsample_factor = 2 |
|
1075 |
|
target_factor = 8 / downsample_factor |
|
1076 |
|
crop_sq_size = 256 * downsample_factor |
|
|
1073 |
|
downsample_rate = 2 |
|
1074 |
|
target_factor = 8 * downsample_rate |
|
1075 |
|
crop_sq_size = 512 |
|
1076 |
|
|
1077 |
1077 |
if train: |
if train: |
1078 |
1078 |
crop_size = (crop_sq_size, crop_sq_size) |
crop_size = (crop_sq_size, crop_sq_size) |
1079 |
1079 |
dx = int(random.random() * (img_origin.size[0] - crop_sq_size)) |
dx = int(random.random() * (img_origin.size[0] - crop_sq_size)) |
|
... |
... |
def load_data_jhucrowd_downsample_256(img_path, train=True, debug=False): |
1081 |
1081 |
if img_origin.size[0] - crop_sq_size < 0 or img_origin.size[1] - crop_sq_size < 0: # we crop more than we can chew, so... |
if img_origin.size[0] - crop_sq_size < 0 or img_origin.size[1] - crop_sq_size < 0: # we crop more than we can chew, so... |
1082 |
1082 |
return None, None |
return None, None |
1083 |
1083 |
img = img_origin.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy)) |
img = img_origin.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy)) |
1084 |
|
img2 = img.resize((int(img.size[0] / 2), int(img.size[1] / 2)), resample=Image.ANTIALIAS) |
|
|
1084 |
|
img2 = img.resize((int(img.size[0] / downsample_rate), int(img.size[1] / downsample_rate)), |
|
1085 |
|
resample=Image.ANTIALIAS) |
1085 |
1086 |
target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
1086 |
1087 |
|
|
1087 |
1088 |
if random.random() > 0.8: |
if random.random() > 0.8: |
|
... |
... |
def load_data_jhucrowd_downsample_256(img_path, train=True, debug=False): |
1092 |
1093 |
# get correct people head count from head annotation |
# get correct people head count from head annotation |
1093 |
1094 |
txt_path = img_path.replace('.jpg', '.txt').replace('images', 'ground-truth') |
txt_path = img_path.replace('.jpg', '.txt').replace('images', 'ground-truth') |
1094 |
1095 |
gt_count = count_gt_annotation_jhu(txt_path) |
gt_count = count_gt_annotation_jhu(txt_path) |
1095 |
|
img_out = img_origin.resize((int(img_origin.size[0] / 2), int(img_origin.size[1] / 2)), resample=Image.ANTIALIAS) |
|
|
1096 |
|
img_eval = img_origin.resize((int(img_origin.size[0] / downsample_rate), int(img_origin.size[1] / downsample_rate)), |
|
1097 |
|
resample=Image.ANTIALIAS) |
1096 |
1098 |
if debug: |
if debug: |
1097 |
1099 |
gt_file = h5py.File(gt_path, 'r') |
gt_file = h5py.File(gt_path, 'r') |
1098 |
1100 |
target = np.asarray(gt_file['density']) |
target = np.asarray(gt_file['density']) |
1099 |
|
return img_origin, gt_count, target |
|
1100 |
|
return img_out, gt_count |
|
|
1101 |
|
return img_eval, gt_count, target |
|
1102 |
|
return img_eval, gt_count |
1101 |
1103 |
|
|
1102 |
1104 |
target1 = cv2.resize(target, (int(target.shape[1] / target_factor), int(target.shape[0] / target_factor)), |
target1 = cv2.resize(target, (int(target.shape[1] / target_factor), int(target.shape[0] / target_factor)), |
1103 |
1105 |
interpolation=cv2.INTER_CUBIC) * target_factor * target_factor |
interpolation=cv2.INTER_CUBIC) * target_factor * target_factor |
|
... |
... |
class ListDataset(Dataset): |
1201 |
1203 |
self.load_data_fn = load_data_shanghaitech_180 |
self.load_data_fn = load_data_shanghaitech_180 |
1202 |
1204 |
elif dataset_name == "shanghaitech_256": |
elif dataset_name == "shanghaitech_256": |
1203 |
1205 |
self.load_data_fn = load_data_shanghaitech_256 |
self.load_data_fn = load_data_shanghaitech_256 |
1204 |
|
elif dataset_name == "jhucrowd_downsample_256": |
|
1205 |
|
self.load_data_fn = load_data_jhucrowd_downsample_256 |
|
|
1206 |
|
elif dataset_name == "jhucrowd_downsample_512": |
|
1207 |
|
self.load_data_fn = load_data_jhucrowd_downsample_512 |
1206 |
1208 |
elif dataset_name == "shanghaitech_non_overlap": |
elif dataset_name == "shanghaitech_non_overlap": |
1207 |
1209 |
self.load_data_fn = load_data_shanghaitech_non_overlap |
self.load_data_fn = load_data_shanghaitech_non_overlap |
1208 |
1210 |
elif dataset_name == "shanghaitech_non_overlap_downsample": |
elif dataset_name == "shanghaitech_non_overlap_downsample": |