File eval_context_aware_network.py changed (mode: 100644) (index 6c8d057..9b254c3) |
... |
... |
from torchvision import transforms |
10 |
10 |
from models.context_aware_network import CANNet |
from models.context_aware_network import CANNet |
11 |
11 |
from data_util import ShanghaiTechDataPath |
from data_util import ShanghaiTechDataPath |
12 |
12 |
from hard_code_variable import HardCodeVariable |
from hard_code_variable import HardCodeVariable |
13 |
|
from visualize_util import save_img, save_density_map |
|
|
13 |
|
from visualize_util import save_img, save_density_map, save_density_map_with_colorrange |
14 |
14 |
|
|
15 |
15 |
_description=""" |
_description=""" |
16 |
16 |
This file run predict |
This file run predict |
|
... |
... |
for i in range(len(img_paths)): |
92 |
92 |
pred.append(pred_sum) |
pred.append(pred_sum) |
93 |
93 |
gt.append(np.sum(groundtruth)) |
gt.append(np.sum(groundtruth)) |
94 |
94 |
print("done ", i, "pred ",pred_sum, " gt ", np.sum(groundtruth)) |
print("done ", i, "pred ",pred_sum, " gt ", np.sum(groundtruth)) |
|
95 |
|
|
|
96 |
|
max_people_per_pix = 0 |
|
97 |
|
if density_1.max() > max_people_per_pix: |
|
98 |
|
max_people_per_pix = density_1.max() |
|
99 |
|
if density_2.max() > max_people_per_pix: |
|
100 |
|
max_people_per_pix = density_2.max() |
|
101 |
|
if density_3.max() > max_people_per_pix: |
|
102 |
|
max_people_per_pix = density_3.max() |
|
103 |
|
if density_4.max() > max_people_per_pix: |
|
104 |
|
max_people_per_pix = density_4.max() |
|
105 |
|
|
95 |
106 |
## print out visual |
## print out visual |
96 |
|
name_prefix = os.path.join(saved_folder, "sample_"+str(i)) |
|
97 |
|
save_img(img_original_1, name_prefix+"_img1.png") |
|
98 |
|
save_img(img_original_2, name_prefix + "_img2.png") |
|
99 |
|
save_img(img_original_3, name_prefix + "_img3.png") |
|
100 |
|
save_img(img_original_4, name_prefix + "_img4.png") |
|
101 |
|
|
|
102 |
|
save_density_map(density_1.squeeze(), name_prefix + "_pred1.png") |
|
103 |
|
save_density_map(density_2.squeeze(), name_prefix + "_pred2.png") |
|
104 |
|
save_density_map(density_3.squeeze(), name_prefix + "_pred3.png") |
|
105 |
|
save_density_map(density_4.squeeze(), name_prefix + "_pred4.png") |
|
|
107 |
|
if IS_VISUAL: |
|
108 |
|
name_prefix = os.path.join(saved_folder, "sample_"+str(i)) |
|
109 |
|
save_img(img_original_1, name_prefix+"_img1.png") |
|
110 |
|
save_img(img_original_2, name_prefix + "_img2.png") |
|
111 |
|
save_img(img_original_3, name_prefix + "_img3.png") |
|
112 |
|
save_img(img_original_4, name_prefix + "_img4.png") |
|
113 |
|
|
|
114 |
|
save_density_map_with_colorrange(density_1.squeeze(), name_prefix + "_pred1.png", 0, 0.18) |
|
115 |
|
save_density_map_with_colorrange(density_2.squeeze(), name_prefix + "_pred2.png", 0, 0.18) |
|
116 |
|
save_density_map_with_colorrange(density_3.squeeze(), name_prefix + "_pred3.png", 0, 0.18) |
|
117 |
|
save_density_map_with_colorrange(density_4.squeeze(), name_prefix + "_pred4.png", 0, 0.18) |
106 |
118 |
## |
## |
107 |
119 |
|
|
108 |
120 |
print(len(pred)) |
print(len(pred)) |
|
... |
... |
rmse = np.sqrt(mean_squared_error(pred,gt)) |
112 |
124 |
|
|
113 |
125 |
print('MAE: ',mae) |
print('MAE: ',mae) |
114 |
126 |
print('RMSE: ',rmse) |
print('RMSE: ',rmse) |
|
127 |
|
print("max people per pix ", max_people_per_pix) |
File visualize_util.py changed (mode: 100644) (index 4064de3..0190315) |
... |
... |
def save_density_map(density_map, name): |
14 |
14 |
plt.margins(0, 0) |
plt.margins(0, 0) |
15 |
15 |
plt.imshow(density_map, cmap=CM.jet) |
plt.imshow(density_map, cmap=CM.jet) |
16 |
16 |
plt.savefig(name, dpi=600, bbox_inches='tight', pad_inches=0) |
plt.savefig(name, dpi=600, bbox_inches='tight', pad_inches=0) |
|
17 |
|
plt.close() |
17 |
18 |
|
|
|
19 |
|
def save_density_map_with_colorrange(density_map, name, vmin, vmax): |
|
20 |
|
plt.figure(dpi=600) |
|
21 |
|
plt.axis('off') |
|
22 |
|
plt.margins(0, 0) |
|
23 |
|
plt.imshow(density_map, cmap=CM.jet) |
|
24 |
|
plt.clim(vmin, vmax) |
|
25 |
|
plt.savefig(name, dpi=600, bbox_inches='tight', pad_inches=0) |
|
26 |
|
plt.close() |
18 |
27 |
|
|
19 |
28 |
def save_img(imgnp, name): |
def save_img(imgnp, name): |
20 |
29 |
# plt.imshow(imgnp[0].permute(1, 2, 0).numpy()) |
# plt.imshow(imgnp[0].permute(1, 2, 0).numpy()) |