File crowd_counting_error_metrics.py changed (mode: 100644) (index 1165f1a..42efe7b) |
... |
... |
class CrowdCountingMeanSquaredError(Metric): |
52 |
52 |
raise NotComputableError('MeanSquaredError must have at least one example before it can be computed.') |
raise NotComputableError('MeanSquaredError must have at least one example before it can be computed.') |
53 |
53 |
return math.sqrt(self._sum_of_squared_errors / self._num_examples) |
return math.sqrt(self._sum_of_squared_errors / self._num_examples) |
54 |
54 |
|
|
|
55 |
|
########################################### |
|
56 |
|
|
|
57 |
|
|
|
58 |
|
class CrowdCountingMeanAbsoluteErrorWithCount(Metric):
    """Running mean absolute error of the predicted crowd count.

    Unlike the density-map-based metric, the error is taken against the
    annotated head count supplied alongside each batch, not against
    ``sum(y)``.

    - `update` must receive output of the form `(y_pred, y, count)`.
    """

    def reset(self):
        # Accumulators for the running average across batches.
        self._sum_of_absolute_errors = 0.0
        self._num_examples = 0

    def update(self, output):
        y_pred, y, true_count = output
        # Integrate the predicted density map to get a predicted count.
        predicted_total = y_pred.sum()
        # Error is measured against the annotated count, not sum(y).
        batch_error = (predicted_total - true_count).abs()
        self._sum_of_absolute_errors += batch_error.sum().item()
        self._num_examples += y.shape[0]

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('MeanAbsoluteError must have at least one example before it can be computed.')
        return self._sum_of_absolute_errors / self._num_examples
|
81 |
|
|
|
82 |
|
|
|
83 |
|
class CrowdCountingMeanSquaredErrorWithCount(Metric):
    """Running root-mean-squared error of the predicted crowd count.

    Unlike the density-map-based metric, the error is taken against the
    annotated head count supplied alongside each batch, not against
    ``sum(y)``.

    - `update` must receive output of the form `(y_pred, y, count)`.
    """

    def reset(self):
        # Accumulators for the running average across batches.
        self._sum_of_squared_errors = 0.0
        self._num_examples = 0

    def update(self, output):
        y_pred, y, true_count = output
        # Integrate the predicted density map to get a predicted count.
        predicted_total = y_pred.sum()
        # Squared error against the annotated count, not sum(y).
        batch_error = (predicted_total - true_count) ** 2
        self._sum_of_squared_errors += batch_error.sum().item()
        self._num_examples += y.shape[0]

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('MeanSquaredError must have at least one example before it can be computed.')
        # Report RMSE (root of the mean squared error).
        return math.sqrt(self._sum_of_squared_errors / self._num_examples)
55 |
106 |
|
|
File data_flow.py changed (mode: 100644) (index 5c329ef..9f5d04e) |
... |
... |
from torch.utils.data import Dataset |
17 |
17 |
from PIL import Image |
from PIL import Image |
18 |
18 |
import torchvision.transforms.functional as F |
import torchvision.transforms.functional as F |
19 |
19 |
from torchvision import datasets, transforms |
from torchvision import datasets, transforms |
|
20 |
|
import scipy |
|
21 |
|
|
20 |
22 |
|
|
21 |
23 |
""" |
""" |
22 |
24 |
create a list of file (full directory) |
create a list of file (full directory) |
23 |
25 |
""" |
""" |
24 |
26 |
|
|
|
27 |
|
def count_gt_annotation_sha(mat_path):
    """Read a ShanghaiTech annotation file and count the annotated heads.

    :param mat_path: path to a ``.mat`` ground-truth file whose
        ``image_info`` entry wraps the head-point array.
    :return: number of annotated head points (an int).
    """
    # A bare `import scipy` at file level does not reliably expose the
    # `scipy.io` submodule (AttributeError on many scipy versions), so
    # import it explicitly here before use.
    import scipy.io
    mat = scipy.io.loadmat(mat_path)
    # ShanghaiTech layout: image_info is a 1x1 cell holding a 1x1 struct
    # whose first field is the (num_heads, 2) point array.
    gt = mat["image_info"][0, 0][0, 0][0]
    return len(gt)
|
36 |
|
|
25 |
37 |
def create_training_image_list(data_path): |
def create_training_image_list(data_path): |
26 |
38 |
""" |
""" |
27 |
39 |
create a list of absolutely path of jpg file |
create a list of absolutely path of jpg file |
|
... |
... |
def load_data_shanghaitech_rnd(img_path, train=True): |
134 |
146 |
interpolation=cv2.INTER_CUBIC) * 64 |
interpolation=cv2.INTER_CUBIC) * 64 |
135 |
147 |
# target1 = target1.unsqueeze(0) # make dim (batch size, channel size, x, y) to make model output |
# target1 = target1.unsqueeze(0) # make dim (batch size, channel size, x, y) to make model output |
136 |
148 |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
|
149 |
|
|
|
150 |
|
if not train: |
|
151 |
|
# get correct people head count from head annotation |
|
152 |
|
mat_path = img_path.replace('.jpg', '.map').replace('images', 'ground-truth').replace('IMG', 'GT_IMG') |
|
153 |
|
gt_count = count_gt_annotation_sha(mat_path) |
|
154 |
|
return img, target1, gt_count |
|
155 |
|
|
137 |
156 |
return img, target1 |
return img, target1 |
138 |
157 |
|
|
139 |
158 |
|
|
|
... |
... |
def load_data_shanghaitech_20p(img_path, train=True): |
256 |
275 |
interpolation=cv2.INTER_CUBIC) * target_factor * target_factor |
interpolation=cv2.INTER_CUBIC) * target_factor * target_factor |
257 |
276 |
# target1 = target1.unsqueeze(0) # make dim (batch size, channel size, x, y) to make model output |
# target1 = target1.unsqueeze(0) # make dim (batch size, channel size, x, y) to make model output |
258 |
277 |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
target1 = np.expand_dims(target1, axis=0) # make dim (batch size, channel size, x, y) to make model output |
|
278 |
|
|
|
279 |
|
if not train: |
|
280 |
|
# get correct people head count from head annotation |
|
281 |
|
mat_path = img_path.replace('.jpg', '.map').replace('images', 'ground-truth').replace('IMG', 'GT_IMG') |
|
282 |
|
gt_count = count_gt_annotation_sha(mat_path) |
|
283 |
|
return img, target1, gt_count |
|
284 |
|
|
259 |
285 |
return img, target1 |
return img, target1 |
260 |
286 |
|
|
261 |
287 |
|
|
|
... |
... |
def my_collate(batch): # batch size 4 [{tensor image, tensor label},{},{},{}] co |
750 |
776 |
# so how to sample another dataset entry? |
# so how to sample another dataset entry? |
751 |
777 |
return torch.utils.data.dataloader.default_collate(batch) |
return torch.utils.data.dataloader.default_collate(batch) |
752 |
778 |
|
|
|
779 |
|
|
753 |
780 |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1, train_loader_for_eval_check = False): |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1, train_loader_for_eval_check = False): |
754 |
781 |
if visualize_mode: |
if visualize_mode: |
755 |
782 |
transformer = transforms.Compose([ |
transformer = transforms.Compose([ |