File crowd_counting_error_metrics.py changed (mode: 100644) (index ea8ceba..f8d2d31) |
... |
... |
class CrowdCountingMeanSquaredErrorWithCount(Metric): |
112 |
112 |
import piq |
import piq |
113 |
113 |
|
|
114 |
114 |
|
|
115 |
|
class CrowdCountingMeanSSIM(Metric): |
|
|
115 |
|
class CrowdCountingMeanSSIMabs(Metric): |
116 |
116 |
""" |
""" |
117 |
117 |
Calculates ssim |
Calculates ssim |
118 |
118 |
require package https://github.com/photosynthesis-team/piq |
require package https://github.com/photosynthesis-team/piq |
|
... |
... |
class CrowdCountingMeanSSIM(Metric): |
144 |
144 |
return self._sum / self._num_examples |
return self._sum / self._num_examples |
145 |
145 |
|
|
146 |
146 |
|
|
147 |
|
class CrowdCountingMeanPSNR(Metric): |
|
|
147 |
|
class CrowdCountingMeanPSNRabs(Metric): |
148 |
148 |
""" |
""" |
149 |
149 |
Calculates ssim |
Calculates ssim |
150 |
150 |
require package https://github.com/photosynthesis-team/piq |
require package https://github.com/photosynthesis-team/piq |
|
... |
... |
class CrowdCountingMeanPSNR(Metric): |
176 |
176 |
raise NotComputableError('CrowdCountingMeanPSNR must have at least one example before it can be computed.') |
raise NotComputableError('CrowdCountingMeanPSNR must have at least one example before it can be computed.') |
177 |
177 |
return self._sum / self._num_examples |
return self._sum / self._num_examples |
178 |
178 |
|
|
|
179 |
|
#################3 |
|
180 |
|
|
|
181 |
|
|
|
182 |
|
class CrowdCountingMeanSSIMclamp(Metric): |
|
183 |
|
""" |
|
184 |
|
Calculates ssim |
|
185 |
|
require package https://github.com/photosynthesis-team/piq |
|
186 |
|
- `update` must receive output of the form `(y_pred, y)`. |
|
187 |
|
""" |
|
188 |
|
def reset(self): |
|
189 |
|
self._sum = 0.0 |
|
190 |
|
self._num_examples = 0 |
|
191 |
|
|
|
192 |
|
def update(self, output): |
|
193 |
|
y_pred = output[0] |
|
194 |
|
y = output[1] |
|
195 |
|
y_pred = torch.clamp_min(y_pred, min=0.0) |
|
196 |
|
y = torch.clamp_min(y, min=0.0) |
|
197 |
|
|
|
198 |
|
|
|
199 |
|
ssim_metric = piq.ssim(y, y_pred) |
|
200 |
|
|
|
201 |
|
self._sum += ssim_metric.item() * y.shape[0] |
|
202 |
|
# we multiply because ssim calculate mean of each image in batch |
|
203 |
|
# we multiply so we will divide correctly |
|
204 |
|
|
|
205 |
|
self._num_examples += y.shape[0] |
|
206 |
|
|
|
207 |
|
def compute(self): |
|
208 |
|
if self._num_examples == 0: |
|
209 |
|
raise NotComputableError('CrowdCountingMeanSSIM must have at least one example before it can be computed.') |
|
210 |
|
return self._sum / self._num_examples |
|
211 |
|
|
|
212 |
|
|
|
213 |
|
class CrowdCountingMeanPSNRclamp(Metric): |
|
214 |
|
""" |
|
215 |
|
Calculates ssim |
|
216 |
|
require package https://github.com/photosynthesis-team/piq |
|
217 |
|
- `update` must receive output of the form `(y_pred, y)`. |
|
218 |
|
""" |
|
219 |
|
def reset(self): |
|
220 |
|
self._sum = 0.0 |
|
221 |
|
self._num_examples = 0 |
|
222 |
|
|
|
223 |
|
def update(self, output): |
|
224 |
|
y_pred = output[0] |
|
225 |
|
y_pred = torch.clamp_min(y_pred, min=0.0) |
|
226 |
|
y = output[1] |
|
227 |
|
y = torch.clamp_min(y, min=0.0) |
|
228 |
|
|
|
229 |
|
psnr_metric = piq.psnr(y, y_pred) |
|
230 |
|
|
|
231 |
|
self._sum += psnr_metric.item() * y.shape[0] |
|
232 |
|
# we multiply because ssim calculate mean of each image in batch |
|
233 |
|
# we multiply so we will divide correctly |
|
234 |
|
|
|
235 |
|
self._num_examples += y.shape[0] |
|
236 |
|
|
|
237 |
|
def compute(self): |
|
238 |
|
if self._num_examples == 0: |
|
239 |
|
raise NotComputableError('CrowdCountingMeanPSNR must have at least one example before it can be computed.') |
|
240 |
|
return self._sum / self._num_examples |
File experiment_main.py changed (mode: 100644) (index 5210431..10aac00) |
... |
... |
from data_flow import get_dataloader, create_image_list |
5 |
5 |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
6 |
6 |
from ignite.metrics import Loss |
from ignite.metrics import Loss |
7 |
7 |
from ignite.handlers import Checkpoint, DiskSaver, Timer |
from ignite.handlers import Checkpoint, DiskSaver, Timer |
8 |
|
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError, CrowdCountingMeanAbsoluteErrorWithCount, CrowdCountingMeanSquaredErrorWithCount, CrowdCountingMeanSSIM, CrowdCountingMeanPSNR |
|
|
8 |
|
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError,\ |
|
9 |
|
CrowdCountingMeanAbsoluteErrorWithCount, CrowdCountingMeanSquaredErrorWithCount,\ |
|
10 |
|
CrowdCountingMeanSSIMabs, CrowdCountingMeanPSNRabs, \ |
|
11 |
|
CrowdCountingMeanSSIMclamp, CrowdCountingMeanPSNRclamp |
|
12 |
|
|
9 |
13 |
from visualize_util import get_readable_time |
from visualize_util import get_readable_time |
10 |
14 |
from mse_l1_loss import MSEL1Loss, MSE4L1Loss |
from mse_l1_loss import MSEL1Loss, MSE4L1Loss |
11 |
15 |
import torch |
import torch |
|
... |
... |
if __name__ == "__main__": |
228 |
232 |
if args.eval_density: |
if args.eval_density: |
229 |
233 |
evaluator_test = create_supervised_evaluator(model, |
evaluator_test = create_supervised_evaluator(model, |
230 |
234 |
metrics={ |
metrics={ |
231 |
|
'ssim': CrowdCountingMeanSSIM(), |
|
232 |
|
'psnr': CrowdCountingMeanPSNR(), |
|
|
235 |
|
'ssimabs': CrowdCountingMeanSSIMabs(), |
|
236 |
|
'psnrabs': CrowdCountingMeanPSNRabs(), |
|
237 |
|
'ssimclamp': CrowdCountingMeanSSIMclamp(), |
|
238 |
|
'psnrclamp': CrowdCountingMeanPSNRclamp(), |
233 |
239 |
}, device=device) |
}, device=device) |
234 |
240 |
else: |
else: |
235 |
241 |
evaluator_test = create_supervised_evaluator(model, |
evaluator_test = create_supervised_evaluator(model, |
|
... |
... |
if __name__ == "__main__": |
371 |
377 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
372 |
378 |
|
|
373 |
379 |
if args.eval_density: |
if args.eval_density: |
374 |
|
print(timestamp + " Test set Results - Avg ssim: {:.2f} Avg psnr: {:.2f} Avg loss: {:.2f}" |
|
375 |
|
.format(test_metrics['ssim'], test_metrics['psnr'], 0)) |
|
376 |
|
experiment.log_metric("test_ssim", test_metrics['ssim']) |
|
377 |
|
experiment.log_metric("test_psnr", test_metrics['psnr']) |
|
|
380 |
|
print(timestamp + " Test set Results ABS - Avg ssim: {:.2f} Avg psnr: {:.2f} Avg loss: {:.2f}" |
|
381 |
|
.format(test_metrics['ssimabs'], test_metrics['psnrabs'], 0)) |
|
382 |
|
experiment.log_metric("test_ssim abs", test_metrics['ssimabs']) |
|
383 |
|
experiment.log_metric("test_psnr abs", test_metrics['psnrabs']) |
|
384 |
|
|
|
385 |
|
print(timestamp + " Test set Results CLAMP - Avg ssim: {:.2f} Avg psnr: {:.2f} Avg loss: {:.2f}" |
|
386 |
|
.format(test_metrics['ssimclamp'], test_metrics['psnrclamp'], 0)) |
|
387 |
|
experiment.log_metric("test_ssim clamp", test_metrics['ssimclamp']) |
|
388 |
|
experiment.log_metric("test_psnr clamp", test_metrics['psnrclamp']) |
378 |
389 |
else: |
else: |
379 |
390 |
print(timestamp + " Test set Results - Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Test set Results - Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
380 |
391 |
.format( test_metrics['mae'], test_metrics['mse'], 0)) |
.format( test_metrics['mae'], test_metrics['mse'], 0)) |