File data_flow.py changed (mode: 100644) (index 4d6f927..2dffa53) |
... |
... |
def load_data_shanghaitech_non_overlap(img_path, train=True): |
464 |
464 |
if x==1: |
if x==1: |
465 |
465 |
target = np.fliplr(target) |
target = np.fliplr(target) |
466 |
466 |
img = img.transpose(Image.FLIP_LEFT_RIGHT) |
img = img.transpose(Image.FLIP_LEFT_RIGHT) |
467 |
|
|
|
468 |
|
|
|
469 |
467 |
target1 = cv2.resize(target, (int(target.shape[1] / target_factor), int(target.shape[0] / target_factor)), |
target1 = cv2.resize(target, (int(target.shape[1] / target_factor), int(target.shape[0] / target_factor)), |
470 |
468 |
interpolation=cv2.INTER_CUBIC) * target_factor * target_factor |
interpolation=cv2.INTER_CUBIC) * target_factor * target_factor |
471 |
469 |
# target1 = target1.unsqueeze(0) # make dim (batch size, channel size, x, y) to make model output |
# target1 = target1.unsqueeze(0) # make dim (batch size, channel size, x, y) to make model output |
472 |
470 |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
473 |
471 |
crop_img.append(img) |
crop_img.append(img) |
474 |
472 |
crop_label.append(target1) |
crop_label.append(target1) |
|
473 |
|
# shuffle in pair |
|
474 |
|
tmp_pair = list(zip(crop_img, crop_label)) |
|
475 |
|
random.shuffle(tmp_pair) |
|
476 |
|
crop_img, crop_label = zip(*tmp_pair) |
475 |
477 |
return crop_img, crop_label |
return crop_img, crop_label |
476 |
478 |
|
|
477 |
479 |
if not train: |
if not train: |
|
... |
... |
def load_data_shanghaitech_non_overlap(img_path, train=True): |
481 |
483 |
return img_origin, gt_count |
return img_origin, gt_count |
482 |
484 |
|
|
483 |
485 |
|
|
|
486 |
|
def load_data_shanghaitech_non_overlap_noflip(img_path, train=True): |
|
487 |
|
""" |
|
488 |
|
per sample, crop 4, non-overlap |
|
489 |
|
:param img_path: |
|
490 |
|
:param train: |
|
491 |
|
:return: |
|
492 |
|
""" |
|
493 |
|
gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground-truth-h5') |
|
494 |
|
img_origin = Image.open(img_path).convert('RGB') |
|
495 |
|
crop_size = (int(img_origin.size[0] / 2), int(img_origin.size[1] / 2)) |
|
496 |
|
gt_file = h5py.File(gt_path, 'r') |
|
497 |
|
target_origin = np.asarray(gt_file['density']) |
|
498 |
|
target_factor = 8 |
|
499 |
|
|
|
500 |
|
if train: |
|
501 |
|
# for each image |
|
502 |
|
# create 8 patches, 4 non-overlap 4 corner |
|
503 |
|
# for each of 4 patch, create another 4 flip |
|
504 |
|
crop_img = [] |
|
505 |
|
crop_label = [] |
|
506 |
|
for i in range(2): |
|
507 |
|
for j in range(2): |
|
508 |
|
# crop non-overlap |
|
509 |
|
dx = int(i * img_origin.size[0] * 1. / 2) |
|
510 |
|
dy = int(j * img_origin.size[1] * 1. / 2) |
|
511 |
|
img = img_origin.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy)) |
|
512 |
|
target = target_origin[dy:crop_size[1] + dy, dx:crop_size[0] + dx] |
|
513 |
|
|
|
514 |
|
|
|
515 |
|
target1 = cv2.resize(target, (int(target.shape[1] / target_factor), int(target.shape[0] / target_factor)), |
|
516 |
|
interpolation=cv2.INTER_CUBIC) * target_factor * target_factor |
|
517 |
|
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
|
518 |
|
crop_img.append(img) |
|
519 |
|
crop_label.append(target1) |
|
520 |
|
return crop_img, crop_label |
|
521 |
|
|
|
522 |
|
if not train: |
|
523 |
|
# get correct people head count from head annotation |
|
524 |
|
mat_path = img_path.replace('.jpg', '.mat').replace('images', 'ground-truth').replace('IMG', 'GT_IMG') |
|
525 |
|
gt_count = count_gt_annotation_sha(mat_path) |
|
526 |
|
return img_origin, gt_count |
|
527 |
|
|
484 |
528 |
def load_data_shanghaitech_crop_random(img_path, train=True): |
def load_data_shanghaitech_crop_random(img_path, train=True): |
485 |
529 |
""" |
""" |
486 |
530 |
40 percent crop |
40 percent crop |
File experiment_main.py changed (mode: 100644) (index b6b690d..1ea7cdf) |
... |
... |
from ignite.metrics import Loss |
7 |
7 |
from ignite.handlers import Checkpoint, DiskSaver, Timer |
from ignite.handlers import Checkpoint, DiskSaver, Timer |
8 |
8 |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError, CrowdCountingMeanAbsoluteErrorWithCount, CrowdCountingMeanSquaredErrorWithCount |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError, CrowdCountingMeanAbsoluteErrorWithCount, CrowdCountingMeanSquaredErrorWithCount |
9 |
9 |
from visualize_util import get_readable_time |
from visualize_util import get_readable_time |
10 |
|
|
|
|
10 |
|
from mse_l1_loss import MSEL1Loss |
11 |
11 |
import torch |
import torch |
12 |
12 |
from torch import nn |
from torch import nn |
13 |
13 |
from models.meow_experiment.kitten_meow_1 import M1, M2, M3, M4 |
from models.meow_experiment.kitten_meow_1 import M1, M2, M3, M4 |
|
... |
... |
if __name__ == "__main__": |
130 |
130 |
elif args.loss_fn == "MSEMean": |
elif args.loss_fn == "MSEMean": |
131 |
131 |
loss_fn = nn.MSELoss(reduction='mean').to(device) |
loss_fn = nn.MSELoss(reduction='mean').to(device) |
132 |
132 |
print("use MSEMean") |
print("use MSEMean") |
|
133 |
|
elif args.loss_fn == "MSEL1Mean": |
|
134 |
|
loss_fn = MSEL1Loss(reduction='mean').to(device) |
|
135 |
|
print("use MSEL1Mean") |
133 |
136 |
elif args.loss_fn == "MSENone": |
elif args.loss_fn == "MSENone": |
134 |
137 |
""" |
""" |
135 |
138 |
Doesnt work |
Doesnt work |