File debug/evaluate_shb.py changed (mode: 100644) (index 1e9801e..1d637ab) |
... |
... |
from data_flow import get_train_val_list, get_dataloader, create_training_image_ |
33 |
33 |
This file evaluation on SHB and get information on evaluation process |
This file evaluation on SHB and get information on evaluation process |
34 |
34 |
""" |
""" |
35 |
35 |
|
|
|
36 |
|
"/data/ShanghaiTech/part_A/test_data" |
36 |
37 |
|
|
37 |
38 |
def _parse(): |
def _parse(): |
38 |
39 |
parser = argparse.ArgumentParser(description='evaluatiuon SHB') |
parser = argparse.ArgumentParser(description='evaluatiuon SHB') |
39 |
40 |
parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A) |
parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A) |
40 |
41 |
parser.add_argument('--output', action="store", type=str, default="visualize/verify_dataloader_shanghaitech") |
parser.add_argument('--output', action="store", type=str, default="visualize/verify_dataloader_shanghaitech") |
41 |
|
parser.add_argument('--load_model', action="store", type=str, default="visualize/verify_dataloader_shanghaitech") |
|
|
42 |
|
parser.add_argument('--load_model', action="store", type=str, default=None) |
42 |
43 |
parser.add_argument('--model', action="store", type=str, default="visualize/verify_dataloader_shanghaitech") |
parser.add_argument('--model', action="store", type=str, default="visualize/verify_dataloader_shanghaitech") |
43 |
44 |
parser.add_argument('--meta_data', action="store", type=str, default="data_info.txt") |
parser.add_argument('--meta_data', action="store", type=str, default="data_info.txt") |
44 |
45 |
parser.add_argument('--datasetname', action="store", default="shanghaitech_keepfull_r50") |
parser.add_argument('--datasetname', action="store", default="shanghaitech_keepfull_r50") |
|
... |
... |
def visualize_evaluation_shanghaitech_keepfull(model, args): |
60 |
61 |
mae_s = 0 |
mae_s = 0 |
61 |
62 |
mse_s = 0 |
mse_s = 0 |
62 |
63 |
n = 0 |
n = 0 |
|
64 |
|
train_loader_iter = iter(train_loader) |
|
65 |
|
_, gt_density,_ = next(train_loader_iter) |
63 |
66 |
with torch.no_grad(): |
with torch.no_grad(): |
64 |
67 |
for item in test_loader: |
for item in test_loader: |
65 |
68 |
img, gt_density, debug_info = item |
img, gt_density, debug_info = item |
|
... |
... |
def visualize_evaluation_shanghaitech_keepfull(model, args): |
69 |
72 |
file_name_only = file_name[0].split(".")[0] |
file_name_only = file_name[0].split(".")[0] |
70 |
73 |
save_path = os.path.join(saved_folder, "label_" + file_name_only +".png") |
save_path = os.path.join(saved_folder, "label_" + file_name_only +".png") |
71 |
74 |
save_pred_path = os.path.join(saved_folder, "pred_" + file_name_only +".png") |
save_pred_path = os.path.join(saved_folder, "pred_" + file_name_only +".png") |
72 |
|
save_density_map(gt_density.numpy()[0][0], save_path) |
|
|
75 |
|
save_density_map(gt_density.numpy()[0], save_path) |
73 |
76 |
pred = model(img.cuda()) |
pred = model(img.cuda()) |
74 |
77 |
predicted_density_map = pred.detach().cpu().clone().numpy() |
predicted_density_map = pred.detach().cpu().clone().numpy() |
75 |
78 |
predicted_density_map_enlarge = cv2.resize(np.squeeze(predicted_density_map[0][0]), (int(predicted_density_map.shape[3] * 8), int(predicted_density_map.shape[2] * 8)), interpolation=cv2.INTER_CUBIC) / 64 |
predicted_density_map_enlarge = cv2.resize(np.squeeze(predicted_density_map[0][0]), (int(predicted_density_map.shape[3] * 8), int(predicted_density_map.shape[2] * 8)), interpolation=cv2.INTER_CUBIC) / 64 |
|
... |
... |
if __name__ == "__main__": |
104 |
107 |
args = _parse() |
args = _parse() |
105 |
108 |
print(args) |
print(args) |
106 |
109 |
|
|
107 |
|
DATA_PATH = args.input |
|
108 |
|
TRAIN_PATH = os.path.join(DATA_PATH, "train_data_train_split") |
|
109 |
|
VAL_PATH = os.path.join(DATA_PATH, "train_data_validate_split") |
|
110 |
|
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
|
111 |
|
dataset_name = args.datasetname |
|
112 |
|
if dataset_name == "shanghaitech": |
|
113 |
|
print("will use shanghaitech dataset with crop ") |
|
114 |
|
elif dataset_name == "shanghaitech_keepfull": |
|
115 |
|
print("will use shanghaitech_keepfull") |
|
116 |
|
else: |
|
117 |
|
print("cannot detect dataset_name") |
|
118 |
|
print("current dataset_name is ", dataset_name) |
|
|
110 |
|
|
119 |
111 |
|
|
120 |
112 |
# # create list |
# # create list |
121 |
113 |
# train_list = create_image_list(TRAIN_PATH) |
# train_list = create_image_list(TRAIN_PATH) |
|
... |
... |
if __name__ == "__main__": |
200 |
192 |
print("error: you didn't pick a model") |
print("error: you didn't pick a model") |
201 |
193 |
exit(-1) |
exit(-1) |
202 |
194 |
model = model.to(device) |
model = model.to(device) |
203 |
|
checkpoint = torch.load(args.load_model) |
|
204 |
|
model.load_state_dict(checkpoint["model"]) |
|
|
195 |
|
if args.load_model is not None: |
|
196 |
|
checkpoint = torch.load(args.load_model) |
|
197 |
|
model.load_state_dict(checkpoint["model"]) |
205 |
198 |
model.eval() |
model.eval() |
206 |
199 |
visualize_evaluation_shanghaitech_keepfull(model, args) |
visualize_evaluation_shanghaitech_keepfull(model, args) |
207 |
200 |
|
|