File debug/verify_model_shb_best.py added (mode: 100644) (index 0000000..120e444) |
|
1 |
|
import torch |
|
2 |
|
from models.meow_experiment.ccnn_tail import BigTail11i, BigTail10i, BigTail12i, BigTail13i, BigTail14i, BigTail15i |
|
3 |
|
from hard_code_variable import HardCodeVariable |
|
4 |
|
from data_util import ShanghaiTechDataPath |
|
5 |
|
from visualize_util import save_img, save_density_map |
|
6 |
|
import os |
|
7 |
|
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
|
8 |
|
|
|
9 |
|
def visualize_evaluation_shanghaitech_keepfull(model): |
|
10 |
|
model = model.cuda() |
|
11 |
|
model.eval() |
|
12 |
|
HARD_CODE = HardCodeVariable() |
|
13 |
|
shanghaitech_data = ShanghaiTechDataPath(root=HARD_CODE.SHANGHAITECH_PATH) |
|
14 |
|
shanghaitech_data_part_a_train = shanghaitech_data.get_a().get_train().get() |
|
15 |
|
saved_folder = "visualize/evaluation_dataloader_shanghaitech" |
|
16 |
|
os.makedirs(saved_folder, exist_ok=True) |
|
17 |
|
train_list, val_list = get_train_val_list(shanghaitech_data_part_a_train, test_size=0.2) |
|
18 |
|
test_list = None |
|
19 |
|
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech_keepfull", visualize_mode=False, |
|
20 |
|
debug=True) |
|
21 |
|
|
|
22 |
|
# do with train loader |
|
23 |
|
train_loader_iter = iter(train_loader) |
|
24 |
|
for i in range(10): |
|
25 |
|
img, label, count = next(train_loader_iter) |
|
26 |
|
# save_img(img, os.path.join(saved_folder, "train_img_" + str(i) +".png")) |
|
27 |
|
save_path = os.path.join(saved_folder, "train_label_" + str(i) +".png") |
|
28 |
|
save_pred_path = os.path.join(saved_folder, "train_pred_" + str(i) +".png") |
|
29 |
|
save_density_map(label.numpy()[0][0], save_path) |
|
30 |
|
pred = model(img.cuda()) |
|
31 |
|
predicted_density_map = pred.detach().cpu().clone().numpy() |
|
32 |
|
save_density_map(predicted_density_map[0][0], save_pred_path) |
|
33 |
|
print("pred " + save_pred_path + " value " + str(predicted_density_map.sum())) |
|
34 |
|
|
|
35 |
|
|
|
36 |
|
""" |
|
37 |
|
Document on save load model |
|
38 |
|
https://pytorch.org/tutorials/beginner/saving_loading_models.html |
|
39 |
|
""" |
|
40 |
|
|
|
41 |
|
model_path = "/data/save_model/adamw1_bigtail13i_t1_shb/adamw1_bigtail13i_t1_shb_checkpoint_valid_mae=-7.574910521507263.pth" |
|
42 |
|
checkpoint = torch.load(model_path) |
|
43 |
|
|
|
44 |
|
model = BigTail13i() |
|
45 |
|
model.load_state_dict(checkpoint["model"]) |
|
46 |
|
print("done load") |
|
47 |
|
visualize_evaluation_shanghaitech_keepfull(model) |
|
48 |
|
|