File debug/perfomance_test_on_shb.py added (mode: 100644) (index 0000000..cc415fd) |
|
1 |
|
from comet_ml import Experiment |
|
2 |
|
|
|
3 |
|
from args_util import meow_parse, lr_scheduler_milestone_builder |
|
4 |
|
from data_flow import get_dataloader, create_image_list |
|
5 |
|
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
6 |
|
from ignite.metrics import Loss |
|
7 |
|
from ignite.handlers import Checkpoint, DiskSaver, Timer |
|
8 |
|
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError, CrowdCountingMeanAbsoluteErrorWithCount, CrowdCountingMeanSquaredErrorWithCount |
|
9 |
|
from visualize_util import get_readable_time |
|
10 |
|
from mse_l1_loss import MSEL1Loss, MSE4L1Loss |
|
11 |
|
import torch |
|
12 |
|
from torch import nn |
|
13 |
|
from models.meow_experiment.kitten_meow_1 import M1, M2, M3, M4 |
|
14 |
|
from models.meow_experiment.ccnn_tail import BigTailM1, BigTailM2, BigTail3, BigTail4, BigTail5, BigTail6, BigTail7, BigTail8, BigTail6i, BigTail9i |
|
15 |
|
from models.meow_experiment.ccnn_tail import BigTail11i, BigTail10i, BigTail12i, BigTail13i, BigTail14i, BigTail15i |
|
16 |
|
from models.meow_experiment.ccnn_head import H1, H2, H3, H3i, H4i |
|
17 |
|
from models.meow_experiment.kitten_meow_1 import H1_Bigtail3 |
|
18 |
|
from models import CustomCNNv2, CompactCNNV7 |
|
19 |
|
from models.compact_cnn import CompactCNNV8, CompactCNNV9, CompactCNNV7i |
|
20 |
|
import os |
|
21 |
|
from model_util import get_lr, BestMetrics |
|
22 |
|
from ignite.contrib.handlers import PiecewiseLinear |
|
23 |
|
import time |
|
24 |
|
|
|
25 |
|
|
|
26 |
|
|
|
27 |
|
|
|
28 |
|
|
|
29 |
|
|
|
30 |
|
|
|
31 |
|
""" |
|
32 |
|
Document on save load model |
|
33 |
|
https://pytorch.org/tutorials/beginner/saving_loading_models.html |
|
34 |
|
""" |
|
35 |
|
|
|
36 |
|
# Path to the trained BigTail13i checkpoint; read by the __main__ block below.
# NOTE(review): the original module also ran torch.load(), built a hard-coded
# BigTail13i and called run_test_loader() right here at import time —
# run_test_loader is defined nowhere in this file, so merely importing the
# module raised NameError. That import-time work is removed: the __main__
# block re-loads this same checkpoint and builds the model from args.model.
model_path = "/data/save_model/adamw1_bigtail13i_t1_shb/adamw1_bigtail13i_t1_shb_checkpoint_valid_mae=-7.574910521507263.pth"
|
43 |
|
|
|
44 |
|
if __name__ == "__main__": |
|
45 |
|
torch.set_num_threads(2) # 4 thread |
|
46 |
|
|
|
47 |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
48 |
|
print(device) |
|
49 |
|
args = meow_parse() |
|
50 |
|
print(args) |
|
51 |
|
|
|
52 |
|
DATA_PATH = args.input |
|
53 |
|
TRAIN_PATH = os.path.join(DATA_PATH, "train_data_train_split") |
|
54 |
|
VAL_PATH = os.path.join(DATA_PATH, "train_data_validate_split") |
|
55 |
|
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
|
56 |
|
dataset_name = args.datasetname |
|
57 |
|
if dataset_name=="shanghaitech": |
|
58 |
|
print("will use shanghaitech dataset with crop ") |
|
59 |
|
elif dataset_name == "shanghaitech_keepfull": |
|
60 |
|
print("will use shanghaitech_keepfull") |
|
61 |
|
else: |
|
62 |
|
print("cannot detect dataset_name") |
|
63 |
|
print("current dataset_name is ", dataset_name) |
|
64 |
|
|
|
65 |
|
# create list |
|
66 |
|
train_list = create_image_list(TRAIN_PATH) |
|
67 |
|
val_list = create_image_list(VAL_PATH) |
|
68 |
|
test_list = create_image_list(TEST_PATH) |
|
69 |
|
train_loader, train_loader_eval, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name=dataset_name, batch_size=args.batch_size, |
|
70 |
|
train_loader_for_eval_check=True, |
|
71 |
|
cache=args.cache, |
|
72 |
|
pin_memory=args.pin_memory) |
|
73 |
|
|
|
74 |
|
print("len train_loader ", len(train_loader)) |
|
75 |
|
|
|
76 |
|
# model |
|
77 |
|
model_name = args.model |
|
78 |
|
|
|
79 |
|
if model_name == "M1": |
|
80 |
|
model = M1() |
|
81 |
|
elif model_name == "M2": |
|
82 |
|
model = M2() |
|
83 |
|
elif model_name == "M3": |
|
84 |
|
model = M3() |
|
85 |
|
elif model_name == "M4": |
|
86 |
|
model = M4() |
|
87 |
|
elif model_name == "CustomCNNv2": |
|
88 |
|
model = CustomCNNv2() |
|
89 |
|
elif model_name == "BigTailM1": |
|
90 |
|
model = BigTailM1() |
|
91 |
|
elif model_name == "BigTailM2": |
|
92 |
|
model = BigTailM2() |
|
93 |
|
elif model_name == "BigTail3": |
|
94 |
|
model = BigTail3() |
|
95 |
|
elif model_name == "BigTail4": |
|
96 |
|
model = BigTail4() |
|
97 |
|
elif model_name == "BigTail5": |
|
98 |
|
model = BigTail5() |
|
99 |
|
elif model_name == "BigTail6": |
|
100 |
|
model = BigTail6() |
|
101 |
|
elif model_name == "BigTail6i": |
|
102 |
|
model = BigTail6i() |
|
103 |
|
elif model_name == "BigTail9i": |
|
104 |
|
model = BigTail9i() |
|
105 |
|
elif model_name == "BigTail10i": |
|
106 |
|
model = BigTail10i() |
|
107 |
|
elif model_name == "BigTail11i": |
|
108 |
|
model = BigTail11i() |
|
109 |
|
elif model_name == "BigTail12i": |
|
110 |
|
model = BigTail12i() |
|
111 |
|
elif model_name == "BigTail13i": |
|
112 |
|
model = BigTail13i() |
|
113 |
|
elif model_name == "BigTail14i": |
|
114 |
|
model = BigTail14i() |
|
115 |
|
elif model_name == "BigTail15i": |
|
116 |
|
model = BigTail15i() |
|
117 |
|
elif model_name == "BigTail7": |
|
118 |
|
model = BigTail7() |
|
119 |
|
elif model_name == "BigTail8": |
|
120 |
|
model = BigTail8() |
|
121 |
|
elif model_name == "H1": |
|
122 |
|
model = H1() |
|
123 |
|
elif model_name == "H2": |
|
124 |
|
model = H2() |
|
125 |
|
elif model_name == "H3": |
|
126 |
|
model = H3() |
|
127 |
|
elif model_name == "H3i": |
|
128 |
|
model = H3i() |
|
129 |
|
elif model_name == "H4i": |
|
130 |
|
model = H4i() |
|
131 |
|
elif model_name == "H1_Bigtail3": |
|
132 |
|
model = H1_Bigtail3() |
|
133 |
|
elif model_name == "CompactCNNV7": |
|
134 |
|
model = CompactCNNV7() |
|
135 |
|
elif model_name == "CompactCNNV7i": |
|
136 |
|
model = CompactCNNV7i() |
|
137 |
|
elif model_name == "CompactCNNV8": |
|
138 |
|
model = CompactCNNV8() |
|
139 |
|
elif model_name == "CompactCNNV9": |
|
140 |
|
model = CompactCNNV9() |
|
141 |
|
else: |
|
142 |
|
print("error: you didn't pick a model") |
|
143 |
|
exit(-1) |
|
144 |
|
model = model.to(device) |
|
145 |
|
checkpoint = torch.load(model_path) |
|
146 |
|
model.load_state_dict(checkpoint["model"]) |
|
147 |
|
|
|
148 |
|
s1 = time.perf_counter() |
|
149 |
|
for img, label in test_loader: |
|
150 |
|
pred = model(img.cuda()) |
|
151 |
|
print("done") |
|
152 |
|
s2 = time.perf_counter() |
|
153 |
|
time1 = s1 - s2 |
|
154 |
|
print("test 1 time " + str(s1 - s2)) |
|
155 |
|
|
|
156 |
|
s3 = time.perf_counter() |
|
157 |
|
for img, label in test_loader: |
|
158 |
|
pred = model(img.cuda()) |
|
159 |
|
print("done") |
|
160 |
|
s4 = time.perf_counter() |
|
161 |
|
time1 = s3 - s4 |
|
162 |
|
print("test 1 time " + str(s3 - s4)) |
|
163 |
|
|