File data_flow.py changed (mode: 100644) (index c4ebf96..44b669e) |
... |
... |
def load_data(img_path, train=True): |
94 |
94 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
95 |
95 |
gt_file = h5py.File(gt_path, 'r') |
gt_file = h5py.File(gt_path, 'r') |
96 |
96 |
target = np.asarray(gt_file['density']) |
target = np.asarray(gt_file['density']) |
|
97 |
|
gt_file.close() |
97 |
98 |
|
|
98 |
99 |
target = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)), |
target = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)), |
99 |
100 |
interpolation=cv2.INTER_CUBIC) * 64 |
interpolation=cv2.INTER_CUBIC) * 64 |
|
... |
... |
def load_data_shanghaitech(img_path, train=True): |
106 |
107 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
107 |
108 |
gt_file = h5py.File(gt_path, 'r') |
gt_file = h5py.File(gt_path, 'r') |
108 |
109 |
target = np.asarray(gt_file['density']) |
target = np.asarray(gt_file['density']) |
|
110 |
|
gt_file.close() |
109 |
111 |
|
|
110 |
112 |
if train: |
if train: |
111 |
113 |
crop_size = (int(img.size[0] / 2), int(img.size[1] / 2)) |
crop_size = (int(img.size[0] / 2), int(img.size[1] / 2)) |
|
... |
... |
def load_data_shanghaitech_rnd(img_path, train=True): |
142 |
144 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
143 |
145 |
gt_file = h5py.File(gt_path, 'r') |
gt_file = h5py.File(gt_path, 'r') |
144 |
146 |
target = np.asarray(gt_file['density']) |
target = np.asarray(gt_file['density']) |
145 |
|
|
|
|
147 |
|
gt_file.close() |
146 |
148 |
if train: |
if train: |
147 |
149 |
crop_size = (int(img.size[0] / 2), int(img.size[1] / 2)) |
crop_size = (int(img.size[0] / 2), int(img.size[1] / 2)) |
148 |
150 |
if random.randint(0, 9) <= 4: |
if random.randint(0, 9) <= 4: |
|
... |
... |
def load_data_shanghaitech_256(img_path, train=True): |
735 |
737 |
img = Image.open(img_path).convert('RGB') |
img = Image.open(img_path).convert('RGB') |
736 |
738 |
gt_file = h5py.File(gt_path, 'r') |
gt_file = h5py.File(gt_path, 'r') |
737 |
739 |
target = np.asarray(gt_file['density']) |
target = np.asarray(gt_file['density']) |
|
740 |
|
gt_file.close() |
738 |
741 |
target_factor = 8 |
target_factor = 8 |
739 |
742 |
crop_sq_size = 256 |
crop_sq_size = 256 |
740 |
743 |
if train: |
if train: |
|
... |
... |
def load_data_shanghaitech_256(img_path, train=True): |
756 |
759 |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
757 |
760 |
return img, target1 |
return img, target1 |
758 |
761 |
|
|
|
762 |
|
def load_data_shanghaitech_256_v2(img_path, train=True):
    """
    Load one sample, cropping a fixed 256x256 patch when training so that
    images of differing resolution can be batched together.

    :param img_path: path to the '.jpg' image. The density-map path is derived
        from it by replacing '.jpg' with '.h5' and 'images' with
        'ground-truth-h5'.
    :param train: if True, return a random fixed-size crop (horizontally
        flipped roughly 20% of the time) and its downsampled density map;
        if False, return the full image and the annotated head count.
    :return:
        - train=True: ``(img, target1)`` where ``target1`` is the cropped
          density map downsampled by ``target_factor`` with a leading axis
          added.
        - train=True and the image is smaller than the crop in either
          dimension: ``(None, None)``; the caller must handle this case.
        - train=False: ``(img_origin, gt_count)`` where ``gt_count`` is the
          value returned by ``count_gt_annotation_sha`` for the matching
          '.mat' annotation file.
    """
    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5')
    img_origin = Image.open(img_path).convert('RGB')
    # Context manager guarantees the HDF5 handle is released even if reading
    # the dataset raises; np.asarray copies the data into memory first.
    with h5py.File(gt_path, 'r') as gt_file:
        target = np.asarray(gt_file['density'])
    target_factor = 8
    crop_sq_size = 256

    if not train:
        # get correct people head count from head annotation
        mat_path = img_path.replace('.jpg', '.mat').replace('images', 'ground-truth').replace('IMG', 'GT_IMG')
        gt_count = count_gt_annotation_sha(mat_path)
        return img_origin, gt_count

    # PIL's size is (width, height).
    max_dx = img_origin.size[0] - crop_sq_size
    max_dy = img_origin.size[1] - crop_sq_size
    if max_dx < 0 or max_dy < 0:  # we crop more than we can chew, so...
        # TODO if exception, do something here
        # Checked before drawing offsets so no negative offsets are computed.
        return None, None

    crop_size = (crop_sq_size, crop_sq_size)
    dx = int(random.random() * max_dx)
    dy = int(random.random() * max_dy)
    img = img_origin.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy))
    # The density map is indexed [row, col], i.e. [y, x].
    target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx]

    if random.random() > 0.8:
        target = np.fliplr(target)
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

    # Downsample by target_factor and rescale by target_factor**2
    # (presumably so the map's sum, the crowd count, is roughly preserved).
    target1 = cv2.resize(target, (int(target.shape[1] / target_factor), int(target.shape[0] / target_factor)),
                         interpolation=cv2.INTER_CUBIC) * target_factor * target_factor
    target1 = np.expand_dims(target1, axis=0)  # make dim (batch size, channel size, x, y) to make model output
    return img, target1
759 |
801 |
|
|
760 |
802 |
def load_data_shanghaitech_same_size_density_map(img_path, train=True): |
def load_data_shanghaitech_same_size_density_map(img_path, train=True): |
761 |
803 |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
|
... |
... |
class ListDataset(Dataset): |
1255 |
1297 |
self.load_data_fn = load_data_shanghaitech_180 |
self.load_data_fn = load_data_shanghaitech_180 |
1256 |
1298 |
elif dataset_name == "shanghaitech_256": |
elif dataset_name == "shanghaitech_256": |
1257 |
1299 |
self.load_data_fn = load_data_shanghaitech_256 |
self.load_data_fn = load_data_shanghaitech_256 |
|
1300 |
|
elif dataset_name == "shanghaitech_256_v2": |
|
1301 |
|
self.load_data_fn = load_data_shanghaitech_256_v2 |
1258 |
1302 |
elif dataset_name == "jhucrowd_downsample_512": |
elif dataset_name == "jhucrowd_downsample_512": |
1259 |
1303 |
self.load_data_fn = load_data_jhucrowd_downsample_512 |
self.load_data_fn = load_data_jhucrowd_downsample_512 |
1260 |
1304 |
elif dataset_name == "jhucrowd_downsample_testonly_512": |
elif dataset_name == "jhucrowd_downsample_testonly_512": |
File experiment_main.py changed (mode: 100644) (index 917f86e..f8ecc52) |
1 |
1 |
from comet_ml import Experiment |
from comet_ml import Experiment |
2 |
|
|
|
|
2 |
|
import sys |
3 |
3 |
from args_util import meow_parse, lr_scheduler_milestone_builder |
from args_util import meow_parse, lr_scheduler_milestone_builder |
4 |
4 |
from data_flow import get_dataloader, create_image_list |
from data_flow import get_dataloader, create_image_list |
5 |
5 |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
... |
... |
if __name__ == "__main__": |
140 |
140 |
model = CompactCNNV9() |
model = CompactCNNV9() |
141 |
141 |
else: |
else: |
142 |
142 |
print("error: you didn't pick a model") |
print("error: you didn't pick a model") |
143 |
|
exit(-1) |
|
|
143 |
|
sys.exit(-1) |
144 |
144 |
n_param = very_simple_param_count(model) |
n_param = very_simple_param_count(model) |
145 |
145 |
experiment.log_other("n_param", n_param) |
experiment.log_other("n_param", n_param) |
146 |
146 |
if hasattr(model, 'model_note'): |
if hasattr(model, 'model_note'): |