File data_flow.py changed (mode: 100644) (index 92c16da..4499e05) |
... |
... |
import torch |
15 |
15 |
import numpy as np |
import numpy as np |
16 |
16 |
from torch.utils.data import Dataset |
from torch.utils.data import Dataset |
17 |
17 |
from PIL import Image |
from PIL import Image |
|
18 |
|
import pandas as pd |
18 |
19 |
import torchvision.transforms.functional as F |
import torchvision.transforms.functional as F |
19 |
20 |
from torchvision import datasets, transforms |
from torchvision import datasets, transforms |
20 |
21 |
import scipy.io # import scipy does not work https://stackoverflow.com/questions/11172623/import-problems-with-scipy-io |
import scipy.io # import scipy does not work https://stackoverflow.com/questions/11172623/import-problems-with-scipy-io |
|
... |
... |
def count_gt_annotation_sha(mat_path): |
35 |
36 |
gt = mat["image_info"][0, 0][0, 0][0] |
gt = mat["image_info"][0, 0][0, 0][0] |
36 |
37 |
return len(gt) |
return len(gt) |
37 |
38 |
|
|
|
39 |
|
def count_gt_annotation_jhu(txt_path):
    """
    read the annotation and count number of head from annotation
    :param txt_path: path to a space-separated annotation text file (no header row);
        every row is counted as one head
    :return: count (number of rows in the file, 0 when the file is empty)
    """
    try:
        df = pd.read_csv(txt_path, sep=" ", header=None)
    except pd.errors.EmptyDataError:
        # pandas refuses to parse a file with no rows; treat it as zero heads
        # instead of crashing the whole evaluation loop
        return 0
    return len(df)
|
48 |
|
|
|
49 |
|
|
38 |
50 |
|
|
39 |
51 |
def create_training_image_list(data_path): |
def create_training_image_list(data_path): |
40 |
52 |
""" |
""" |
|
... |
... |
def load_data_ucf_cc50_pacnn(img_path, train=True): |
997 |
1009 |
return img, (target1, target2, target3) |
return img, (target1, target2, target3) |
998 |
1010 |
|
|
999 |
1011 |
|
|
|
1012 |
|
def load_data_jhucrowd_256(img_path, train=True, debug=False):
    """
    for jhucrowd
    crop fixed 256, allow batch in non-uniform dataset
    :param img_path: path to a .jpg image. Related files are located by string replacement:
        density map  -> '.jpg' to '.h5'  and 'images' to 'ground-truth-h5'
        annotation   -> '.jpg' to '.txt' and 'images' to 'ground-truth'
    :param train: True  -> random 256x256 crop, flipped left-right when random.random() > 0.8,
                           with the density map of the crop downsampled by 8
                  False -> the full image and the head count read from the annotation file
    :param debug: only used when train is False; also return the full-size density map
    :return:
        train:     (img, target1), or (None, None) when the image is smaller than 256
                   in width or height (caller has to drop such samples)
        not train: (img_origin, gt_count), or (img_origin, gt_count, target) when debug
    """
    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5')
    img_origin = Image.open(img_path).convert('RGB')
    target_factor = 8
    crop_sq_size = 256

    if not train:
        # get correct people head count from head annotation
        txt_path = img_path.replace('.jpg', '.txt').replace('images', 'ground-truth')
        gt_count = count_gt_annotation_jhu(txt_path)
        if debug:
            # the density map is only needed for debugging, so read it lazily;
            # np.asarray materialises the data before the file is closed
            with h5py.File(gt_path, 'r') as gt_file:
                target = np.asarray(gt_file['density'])
            return img_origin, gt_count, target
        return img_origin, gt_count

    width, height = img_origin.size
    if width < crop_sq_size or height < crop_sq_size:  # we crop more than we can chew, so...
        return None, None

    with h5py.File(gt_path, 'r') as gt_file:
        target = np.asarray(gt_file['density'])

    # random top-left corner such that the crop stays inside the image
    dx = int(random.random() * (width - crop_sq_size))
    dy = int(random.random() * (height - crop_sq_size))
    img = img_origin.crop((dx, dy, dx + crop_sq_size, dy + crop_sq_size))
    # assumes the density map is indexed (row=y, col=x) at image resolution -- TODO confirm
    target = target[dy:dy + crop_sq_size, dx:dx + crop_sq_size]

    if random.random() > 0.8:
        target = np.fliplr(target)
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

    # shrink the map by target_factor per side; multiplying by target_factor ** 2 is
    # intended to compensate for the smaller number of pixels
    target1 = cv2.resize(target,
                         (target.shape[1] // target_factor, target.shape[0] // target_factor),
                         interpolation=cv2.INTER_CUBIC) * target_factor * target_factor
    target1 = np.expand_dims(target1, axis=0)  # add leading (channel) axis to match model output
    return img, target1
|
1054 |
|
|
|
1055 |
|
|
1000 |
1056 |
def data_augmentation(img, target): |
def data_augmentation(img, target): |
1001 |
1057 |
""" |
""" |
1002 |
1058 |
return 1 pair of img, target after apply augmentation |
return 1 pair of img, target after apply augmentation |
|
... |
... |
class ListDataset(Dataset): |
1099 |
1155 |
self.load_data_fn = load_data_shanghaitech_non_overlap_downsample |
self.load_data_fn = load_data_shanghaitech_non_overlap_downsample |
1100 |
1156 |
elif dataset_name == "shanghaitech_flip_only": |
elif dataset_name == "shanghaitech_flip_only": |
1101 |
1157 |
self.load_data_fn = load_data_shanghaitech_flip_only |
self.load_data_fn = load_data_shanghaitech_flip_only |
1102 |
|
|
|
|
1158 |
|
elif dataset_name == "jhucrowd_256": |
|
1159 |
|
self.load_data_fn = load_data_jhucrowd_256 |
1103 |
1160 |
elif dataset_name == "ucf_cc_50": |
elif dataset_name == "ucf_cc_50": |
1104 |
1161 |
self.load_data_fn = load_data_ucf_cc50 |
self.load_data_fn = load_data_ucf_cc50 |
1105 |
1162 |
elif dataset_name == "ucf_cc_50_pacnn": |
elif dataset_name == "ucf_cc_50_pacnn": |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
1214 |
1271 |
transform=transformer, |
transform=transformer, |
1215 |
1272 |
train=False, |
train=False, |
1216 |
1273 |
debug=debug, |
debug=debug, |
1217 |
|
dataset_name=dataset_name, cache=cache), |
|
|
1274 |
|
dataset_name=dataset_name, cache=True), # evaluation set always cache |
1218 |
1275 |
num_workers=0, |
num_workers=0, |
1219 |
1276 |
batch_size=test_size, |
batch_size=test_size, |
1220 |
1277 |
pin_memory=pin_memory) |
pin_memory=pin_memory) |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
1228 |
1285 |
transform=transformer, |
transform=transformer, |
1229 |
1286 |
train=False, |
train=False, |
1230 |
1287 |
debug=debug, |
debug=debug, |
1231 |
|
dataset_name=dataset_name), |
|
|
1288 |
|
dataset_name=dataset_name, cache=True), # evaluation set always cache |
1232 |
1289 |
num_workers=0, |
num_workers=0, |
1233 |
1290 |
batch_size=test_size, |
batch_size=test_size, |
1234 |
1291 |
pin_memory=pin_memory) |
pin_memory=pin_memory) |