File data_flow.py changed (mode: 100644) (index 0ea457e..07a3340) |
... |
... |
from PIL import Image |
18 |
18 |
import torchvision.transforms.functional as F |
import torchvision.transforms.functional as F |
19 |
19 |
from torchvision import datasets, transforms |
from torchvision import datasets, transforms |
20 |
20 |
import scipy.io # import scipy does not work https://stackoverflow.com/questions/11172623/import-problems-with-scipy-io |
import scipy.io # import scipy does not work https://stackoverflow.com/questions/11172623/import-problems-with-scipy-io |
21 |
|
from data_util.dataset_utils import my_collate |
|
|
21 |
|
from data_util.dataset_utils import my_collate, flatten_collate |
22 |
22 |
|
|
23 |
23 |
""" |
""" |
24 |
24 |
create a list of file (full directory) |
create a list of file (full directory) |
|
... |
... |
def load_data_shanghaitech_60p_random(img_path, train=True): |
430 |
430 |
|
|
431 |
431 |
return img, target1 |
return img, target1 |
432 |
432 |
|
|
|
433 |
|
|
|
434 |
|
def load_data_shanghaitech_non_overlap(img_path, train=True):
    """
    Load one ShanghaiTech sample.

    Train mode: split the image into 4 non-overlapping quadrant patches and,
    for each quadrant, also emit its horizontal flip (8 patches total).
    Eval mode: return the full image together with the ground-truth head count.

    :param img_path: path to the .jpg image; the density map is expected at the
        parallel 'ground-truth-h5' .h5 path, head annotations at the parallel
        'ground-truth' .mat path
    :param train: True -> (list of 8 PIL patches, list of 8 density maps, each
        with a leading channel axis); False -> (PIL image, head count)
    :return: see :param train:
    """
    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5')
    img = Image.open(img_path).convert('RGB')
    # close the h5 file promptly; np.asarray copies the density data out
    with h5py.File(gt_path, 'r') as gt_file:
        target = np.asarray(gt_file['density'])
    target_factor = 8  # density map is emitted at 1/8 resolution of the image

    if train:
        # for each image
        # create 8 patches, 4 non-overlap quadrants
        # for each of 4 patches, create another flip
        crop_img = []
        crop_label = []
        # quadrant size is invariant — compute once from the full image
        crop_size = (int(img.size[0] / 2), int(img.size[1] / 2))
        for i in range(2):
            for j in range(2):
                # crop non-overlap quadrant into LOCAL variables.
                # BUGFIX: the previous code reassigned img/target here, so every
                # iteration after the first cropped a crop-of-a-crop instead of
                # a quadrant of the original image/density map.
                dx = int(i * img.size[0] * 1. / 2)
                dy = int(j * img.size[1] * 1. / 2)
                patch_img = img.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy))
                patch_target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx]

                # flip: x == 0 keeps the patch as-is, x == 1 mirrors it left-right
                for x in range(2):
                    if x == 1:
                        patch_target = np.fliplr(patch_target)
                        patch_img = patch_img.transpose(Image.FLIP_LEFT_RIGHT)

                    # downsample the density map by target_factor and rescale so
                    # the total head count is preserved
                    target1 = cv2.resize(patch_target,
                                         (int(patch_target.shape[1] / target_factor),
                                          int(patch_target.shape[0] / target_factor)),
                                         interpolation=cv2.INTER_CUBIC) * target_factor * target_factor
                    # make dim (channel, x, y) to match model output
                    target1 = np.expand_dims(target1, axis=0)
                    crop_img.append(patch_img)
                    crop_label.append(target1)
        return crop_img, crop_label

    if not train:
        # get correct people head count from head annotation
        mat_path = img_path.replace('.jpg', '.mat').replace('images', 'ground-truth').replace('IMG', 'GT_IMG')
        gt_count = count_gt_annotation_sha(mat_path)
        return img, gt_count
|
|
433 |
485 |
def load_data_shanghaitech_crop_random(img_path, train=True): |
def load_data_shanghaitech_crop_random(img_path, train=True): |
434 |
486 |
""" |
""" |
435 |
487 |
40 percent crop |
40 percent crop |
|
... |
... |
def load_data_shanghaitech_same_size_density_map(img_path, train=True): |
568 |
620 |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
569 |
621 |
return img, target1 |
return img, target1 |
570 |
622 |
|
|
|
623 |
|
|
571 |
624 |
def load_data_shanghaitech_keepfull(img_path, train=True): |
def load_data_shanghaitech_keepfull(img_path, train=True): |
572 |
625 |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
573 |
626 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
|
... |
... |
class ListDataset(Dataset): |
857 |
910 |
self.load_data_fn = load_data_shanghaitech_180 |
self.load_data_fn = load_data_shanghaitech_180 |
858 |
911 |
elif dataset_name == "shanghaitech_256": |
elif dataset_name == "shanghaitech_256": |
859 |
912 |
self.load_data_fn = load_data_shanghaitech_256 |
self.load_data_fn = load_data_shanghaitech_256 |
|
913 |
|
elif dataset_name == "shanghaitech_non_overlap": |
|
914 |
|
self.load_data_fn = load_data_shanghaitech_non_overlap |
860 |
915 |
elif dataset_name == "ucf_cc_50": |
elif dataset_name == "ucf_cc_50": |
861 |
916 |
self.load_data_fn = load_data_ucf_cc50 |
self.load_data_fn = load_data_ucf_cc50 |
862 |
917 |
elif dataset_name == "ucf_cc_50_pacnn": |
elif dataset_name == "ucf_cc_50_pacnn": |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
894 |
949 |
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], |
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], |
895 |
950 |
std=[0.229, 0.224, 0.225]), |
std=[0.229, 0.224, 0.225]), |
896 |
951 |
]) |
]) |
897 |
|
|
|
|
952 |
|
train_collate_fn = my_collate |
|
953 |
|
if dataset_name == "shanghaitech_non_overlap": |
|
954 |
|
train_collate_fn = flatten_collate |
898 |
955 |
train_loader = torch.utils.data.DataLoader( |
train_loader = torch.utils.data.DataLoader( |
899 |
956 |
ListDataset(train_list, |
ListDataset(train_list, |
900 |
957 |
shuffle=True, |
shuffle=True, |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
905 |
962 |
dataset_name=dataset_name), |
dataset_name=dataset_name), |
906 |
963 |
batch_size=batch_size, |
batch_size=batch_size, |
907 |
964 |
num_workers=0, |
num_workers=0, |
908 |
|
collate_fn=my_collate, pin_memory=False) |
|
|
965 |
|
collate_fn=train_collate_fn, pin_memory=False) |
909 |
966 |
|
|
910 |
967 |
train_loader_for_eval = torch.utils.data.DataLoader( |
train_loader_for_eval = torch.utils.data.DataLoader( |
911 |
968 |
ListDataset(train_list, |
ListDataset(train_list, |