File debug/evaluate_shb.py changed (mode: 100644) (index 5a731e0..3539281) |
... |
... |
from data_util import ShanghaiTechDataPath |
27 |
27 |
import argparse |
import argparse |
28 |
28 |
import cv2 |
import cv2 |
29 |
29 |
import numpy as np |
import numpy as np |
|
30 |
|
import math |
30 |
31 |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
31 |
32 |
""" |
""" |
32 |
33 |
This file evaluation on SHB and get information on evaluation process |
This file evaluation on SHB and get information on evaluation process |
|
... |
... |
def visualize_evaluation_shanghaitech_keepfull(model, args): |
56 |
57 |
debug=True) |
debug=True) |
57 |
58 |
|
|
58 |
59 |
log_f = open(args.meta_data, "w") |
log_f = open(args.meta_data, "w") |
|
60 |
|
mae_s = 0 |
|
61 |
|
mse_s = 0 |
|
62 |
|
n = 0 |
59 |
63 |
with torch.no_grad(): |
with torch.no_grad(): |
60 |
64 |
for item in test_loader: |
for item in test_loader: |
61 |
65 |
img, gt_density, debug_info = item |
img, gt_density, debug_info = item |
|
... |
... |
def visualize_evaluation_shanghaitech_keepfull(model, args): |
75 |
79 |
print("shape compare " + str(predicted_density_map.shape) + " " + str(predicted_density_map_enlarge.shape)) |
print("shape compare " + str(predicted_density_map.shape) + " " + str(predicted_density_map_enlarge.shape)) |
76 |
80 |
density_map_count = gt_density.detach().sum() |
density_map_count = gt_density.detach().sum() |
77 |
81 |
pred_count = pred.detach().cpu().sum() |
pred_count = pred.detach().cpu().sum() |
78 |
|
log_str = str(file_name) + " " + str(density_map_count.item()) + " " + str(gt_count.item()) + str(pred_count.item()) |
|
|
82 |
|
density_map_count_num = density_map_count.item() |
|
83 |
|
gt_count_num = gt_count.item() |
|
84 |
|
pred_count_num = pred_count.item() |
|
85 |
|
error = abs(pred_count_num-gt_count_num) |
|
86 |
|
mae_s += error |
|
87 |
|
mse_s += error*error |
|
88 |
|
log_str = str(file_name_only) + " " + str(density_map_count_num) + " " + str(gt_count.item()) + str(pred_count.item()) |
79 |
89 |
print(log_str) |
print(log_str) |
80 |
90 |
log_f.write(log_str+"\n") |
log_f.write(log_str+"\n") |
81 |
91 |
log_f.close() |
log_f.close() |
82 |
|
|
|
|
92 |
|
mae = mae_s / n |
|
93 |
|
mse = math.sqrt(mse_s / n) |
|
94 |
|
print("mae ", mae) |
|
95 |
|
print("mse", mse) |
83 |
96 |
|
|
84 |
97 |
|
|
85 |
98 |
if __name__ == "__main__": |
if __name__ == "__main__": |