File main_pacnn.py changed (mode: 100644) (index 9024f16..a21c0f5) |
|
1 |
|
from comet_ml import Experiment |
1 |
2 |
from args_util import real_args_parse |
from args_util import real_args_parse |
2 |
3 |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
3 |
4 |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
... |
... |
from evaluator import MAECalculator |
18 |
19 |
from model_util import save_checkpoint |
from model_util import save_checkpoint |
19 |
20 |
|
|
20 |
21 |
if __name__ == "__main__": |
if __name__ == "__main__": |
|
22 |
|
# import comet_ml in the top of your file |
|
23 |
|
|
|
24 |
|
|
|
25 |
|
MODEL_SAVE_NAME = "dev4" |
|
26 |
|
# Add the following code anywhere in your machine learning file |
|
27 |
|
experiment = Experiment(api_key="S3mM1eMq6NumMxk2QJAXASkUM", |
|
28 |
|
project_name="pacnn-dev", workspace="ttpro1995") |
|
29 |
|
experiment.set_name(MODEL_SAVE_NAME) |
|
30 |
|
|
21 |
31 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
22 |
32 |
|
|
23 |
33 |
# device = "cpu" |
# device = "cpu" |
|
... |
... |
if __name__ == "__main__": |
28 |
38 |
DATASET_NAME = "shanghaitech" |
DATASET_NAME = "shanghaitech" |
29 |
39 |
PACNN_PERSPECTIVE_AWARE_MODEL = False |
PACNN_PERSPECTIVE_AWARE_MODEL = False |
30 |
40 |
|
|
|
41 |
|
|
|
42 |
|
|
31 |
43 |
# create list |
# create list |
32 |
44 |
if DATASET_NAME is "shanghaitech": |
if DATASET_NAME is "shanghaitech": |
33 |
45 |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
|
... |
... |
if __name__ == "__main__": |
72 |
84 |
optimizer = torch.optim.SGD(net.parameters(), args.lr, |
optimizer = torch.optim.SGD(net.parameters(), args.lr, |
73 |
85 |
momentum=args.momentum, |
momentum=args.momentum, |
74 |
86 |
weight_decay=args.decay) |
weight_decay=args.decay) |
75 |
|
for e in range(1): |
|
|
87 |
|
for e in range(10): |
76 |
88 |
print("start epoch ", e) |
print("start epoch ", e) |
77 |
89 |
loss_sum = 0 |
loss_sum = 0 |
78 |
90 |
sample = 0 |
sample = 0 |
|
... |
... |
if __name__ == "__main__": |
107 |
119 |
sample += 1 |
sample += 1 |
108 |
120 |
optimizer.zero_grad() |
optimizer.zero_grad() |
109 |
121 |
counting += 1 |
counting += 1 |
|
122 |
|
|
110 |
123 |
if counting%10 ==0: |
if counting%10 ==0: |
111 |
|
print("counting ", counting, " -- avg loss", loss_sum/sample) |
|
|
124 |
|
avg_loss_ministep = loss_sum/sample |
|
125 |
|
print("counting ", counting, " -- avg loss ", avg_loss_ministep) |
|
126 |
|
experiment.log_metric("avg_loss_ministep", avg_loss_ministep) |
112 |
127 |
# if counting == 100: |
# if counting == 100: |
113 |
128 |
# break |
# break |
114 |
|
|
|
115 |
129 |
end_time = time() |
end_time = time() |
116 |
130 |
avg_loss = loss_sum/sample |
avg_loss = loss_sum/sample |
117 |
131 |
epoch_time = end_time - start_time |
epoch_time = end_time - start_time |
|
132 |
|
print("==END epoch ", e, " =============================================") |
118 |
133 |
print(epoch_time, avg_loss, sample) |
print(epoch_time, avg_loss, sample) |
119 |
|
|
|
|
134 |
|
experiment.log_metric("avg_loss_epoch", avg_loss) |
|
135 |
|
print("=================================================================") |
120 |
136 |
|
|
121 |
137 |
save_checkpoint({ |
save_checkpoint({ |
122 |
138 |
'state_dict': net.state_dict(), |
'state_dict': net.state_dict(), |
123 |
|
}, False, "test2") |
|
|
139 |
|
}, False, MODEL_SAVE_NAME) |
124 |
140 |
|
|
125 |
141 |
|
|
126 |
142 |
|
|
|
... |
... |
if __name__ == "__main__": |
130 |
146 |
net = PACNNWithPerspectiveMap(PACNN_PERSPECTIVE_AWARE_MODEL).to(device) |
net = PACNNWithPerspectiveMap(PACNN_PERSPECTIVE_AWARE_MODEL).to(device) |
131 |
147 |
print(net) |
print(net) |
132 |
148 |
|
|
133 |
|
# best_checkpoint = torch.load("test2checkpoint.pth.tar") |
|
134 |
|
# net.load_state_dict(best_checkpoint['state_dict']) |
|
|
149 |
|
best_checkpoint = torch.load(MODEL_SAVE_NAME + "checkpoint.pth.tar") |
|
150 |
|
net.load_state_dict(best_checkpoint['state_dict']) |
135 |
151 |
|
|
136 |
152 |
# device = "cpu" |
# device = "cpu" |
|
153 |
|
# TODO d1_val 155.97279205322266 |
|
154 |
|
# d2_val 35.46327234903971 |
|
155 |
|
# d3_val 23.07176342010498 |
|
156 |
|
# why d2 and d3 mse too low |
137 |
157 |
mae_calculator_d1 = MAECalculator() |
mae_calculator_d1 = MAECalculator() |
138 |
158 |
mae_calculator_d2 = MAECalculator() |
mae_calculator_d2 = MAECalculator() |
139 |
159 |
mae_calculator_d3 = MAECalculator() |
mae_calculator_d3 = MAECalculator() |
|
160 |
|
mae_calculator_final = MAECalculator() |
140 |
161 |
with torch.no_grad(): |
with torch.no_grad(): |
141 |
162 |
for val_img, label in val_loader_pacnn: |
for val_img, label in val_loader_pacnn: |
142 |
163 |
net.eval() |
net.eval() |
|
... |
... |
if __name__ == "__main__": |
144 |
165 |
d1_label, d2_label, d3_label = label |
d1_label, d2_label, d3_label = label |
145 |
166 |
|
|
146 |
167 |
# forward pass |
# forward pass |
147 |
|
d1, d2, d3 = net(val_img.to(device)) |
|
|
168 |
|
d1, d2, d3, p_s, p, d = net(val_img.to(device)) |
148 |
169 |
|
|
149 |
170 |
d1_label = d1_label.to(device) |
d1_label = d1_label.to(device) |
150 |
171 |
d2_label = d2_label.to(device) |
d2_label = d2_label.to(device) |
|
... |
... |
if __name__ == "__main__": |
154 |
175 |
mae_calculator_d1.eval(d1.cpu().detach().numpy(), d1_label.cpu().detach().numpy()) |
mae_calculator_d1.eval(d1.cpu().detach().numpy(), d1_label.cpu().detach().numpy()) |
155 |
176 |
mae_calculator_d2.eval(d2.cpu().detach().numpy(), d2_label.cpu().detach().numpy()) |
mae_calculator_d2.eval(d2.cpu().detach().numpy(), d2_label.cpu().detach().numpy()) |
156 |
177 |
mae_calculator_d3.eval(d3.cpu().detach().numpy(), d3_label.cpu().detach().numpy()) |
mae_calculator_d3.eval(d3.cpu().detach().numpy(), d3_label.cpu().detach().numpy()) |
|
178 |
|
mae_calculator_final.eval(d.cpu().detach().numpy(), d1_label.cpu().detach().numpy()) |
157 |
179 |
print("count ", mae_calculator_d1.count) |
print("count ", mae_calculator_d1.count) |
158 |
180 |
print("d1_val ", mae_calculator_d1.get_mae()) |
print("d1_val ", mae_calculator_d1.get_mae()) |
159 |
181 |
print("d2_val ", mae_calculator_d2.get_mae()) |
print("d2_val ", mae_calculator_d2.get_mae()) |
160 |
182 |
print("d3_val ", mae_calculator_d3.get_mae()) |
print("d3_val ", mae_calculator_d3.get_mae()) |
|
183 |
|
print("dfinal_val ", mae_calculator_final.get_mae()) |
|
184 |
|
experiment.log_metric("d1_val", mae_calculator_d1.get_mae()) |
|
185 |
|
experiment.log_metric("d2_val", mae_calculator_d2.get_mae()) |
|
186 |
|
experiment.log_metric("d3_val", mae_calculator_d3.get_mae()) |
|
187 |
|
experiment.log_metric("dfinal_val", mae_calculator_final.get_mae()) |
161 |
188 |
|
|
162 |
189 |
|
|