File main_pacnn.py changed (mode: 100644) (index a21c0f5..1078b73) |
... |
... |
from evaluator import MAECalculator |
18 |
18 |
|
|
19 |
19 |
from model_util import save_checkpoint |
from model_util import save_checkpoint |
20 |
20 |
|
|
|
21 |
|
import apex |
|
22 |
|
from apex import amp |
|
23 |
|
|
21 |
24 |
if __name__ == "__main__": |
if __name__ == "__main__": |
22 |
25 |
# import comet_ml in the top of your file |
# import comet_ml in the top of your file |
23 |
26 |
|
|
24 |
27 |
|
|
25 |
|
MODEL_SAVE_NAME = "dev4" |
|
|
28 |
|
MODEL_SAVE_NAME = "dev5" |
26 |
29 |
# Add the following code anywhere in your machine learning file |
# Add the following code anywhere in your machine learning file |
27 |
30 |
experiment = Experiment(api_key="S3mM1eMq6NumMxk2QJAXASkUM", |
experiment = Experiment(api_key="S3mM1eMq6NumMxk2QJAXASkUM", |
28 |
31 |
project_name="pacnn-dev", workspace="ttpro1995") |
project_name="pacnn-dev", workspace="ttpro1995") |
|
... |
... |
if __name__ == "__main__": |
84 |
87 |
optimizer = torch.optim.SGD(net.parameters(), args.lr, |
optimizer = torch.optim.SGD(net.parameters(), args.lr, |
85 |
88 |
momentum=args.momentum, |
momentum=args.momentum, |
86 |
89 |
weight_decay=args.decay) |
weight_decay=args.decay) |
|
90 |
|
# Allow Amp to perform casts as required by the opt_level |
|
91 |
|
net, optimizer = amp.initialize(net, optimizer, opt_level="O1", enabled=False) |
|
92 |
|
|
87 |
93 |
for e in range(10): |
for e in range(10): |
88 |
94 |
print("start epoch ", e) |
print("start epoch ", e) |
89 |
95 |
loss_sum = 0 |
loss_sum = 0 |
|
... |
... |
if __name__ == "__main__": |
113 |
119 |
pass |
pass |
114 |
120 |
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
115 |
121 |
loss += loss_d |
loss += loss_d |
116 |
|
loss.backward() |
|
|
122 |
|
# loss.backward() |
|
123 |
|
with amp.scale_loss(loss, optimizer) as scaled_loss: |
|
124 |
|
scaled_loss.backward() |
117 |
125 |
optimizer.step() |
optimizer.step() |
|
126 |
|
optimizer.zero_grad() |
118 |
127 |
loss_sum += loss.item() |
loss_sum += loss.item() |
119 |
128 |
sample += 1 |
sample += 1 |
120 |
|
optimizer.zero_grad() |
|
121 |
129 |
counting += 1 |
counting += 1 |
122 |
130 |
|
|
123 |
131 |
if counting%10 ==0: |
if counting%10 ==0: |
|
... |
... |
if __name__ == "__main__": |
135 |
143 |
print("=================================================================") |
print("=================================================================") |
136 |
144 |
|
|
137 |
145 |
save_checkpoint({ |
save_checkpoint({ |
138 |
|
'state_dict': net.state_dict(), |
|
|
146 |
|
'model': net.state_dict(), |
|
147 |
|
'optimizer': optimizer.state_dict(), |
|
148 |
|
# 'amp': amp.state_dict() |
139 |
149 |
}, False, MODEL_SAVE_NAME) |
}, False, MODEL_SAVE_NAME) |
140 |
150 |
|
|
141 |
151 |
|
|
142 |
|
|
|
143 |
|
# evaluate |
|
144 |
|
|
|
145 |
|
|
|
|
152 |
|
# after epoch evaluate |
|
153 |
|
mae_calculator_d1 = MAECalculator() |
|
154 |
|
mae_calculator_d2 = MAECalculator() |
|
155 |
|
mae_calculator_d3 = MAECalculator() |
|
156 |
|
mae_calculator_final = MAECalculator() |
|
157 |
|
with torch.no_grad(): |
|
158 |
|
for val_img, label in val_loader_pacnn: |
|
159 |
|
net.eval() |
|
160 |
|
# load data |
|
161 |
|
d1_label, d2_label, d3_label = label |
|
162 |
|
|
|
163 |
|
# forward pass |
|
164 |
|
d1, d2, d3, p_s, p, d = net(val_img.to(device)) |
|
165 |
|
|
|
166 |
|
d1_label = d1_label.to(device) |
|
167 |
|
d2_label = d2_label.to(device) |
|
168 |
|
d3_label = d3_label.to(device) |
|
169 |
|
|
|
170 |
|
# score |
|
171 |
|
mae_calculator_d1.eval(d1.cpu().detach().numpy(), d1_label.cpu().detach().numpy()) |
|
172 |
|
mae_calculator_d2.eval(d2.cpu().detach().numpy(), d2_label.cpu().detach().numpy()) |
|
173 |
|
mae_calculator_d3.eval(d3.cpu().detach().numpy(), d3_label.cpu().detach().numpy()) |
|
174 |
|
mae_calculator_final.eval(d.cpu().detach().numpy(), d1_label.cpu().detach().numpy()) |
|
175 |
|
print("count ", mae_calculator_d1.count) |
|
176 |
|
print("d1_val ", mae_calculator_d1.get_mae()) |
|
177 |
|
print("d2_val ", mae_calculator_d2.get_mae()) |
|
178 |
|
print("d3_val ", mae_calculator_d3.get_mae()) |
|
179 |
|
print("dfinal_val ", mae_calculator_final.get_mae()) |
|
180 |
|
experiment.log_metric("d1_val", mae_calculator_d1.get_mae()) |
|
181 |
|
experiment.log_metric("d2_val", mae_calculator_d2.get_mae()) |
|
182 |
|
experiment.log_metric("d3_val", mae_calculator_d3.get_mae()) |
|
183 |
|
experiment.log_metric("dfinal_val", mae_calculator_final.get_mae()) |
|
184 |
|
|
|
185 |
|
|
|
186 |
|
############################################# |
|
187 |
|
# done training evaluate |
146 |
188 |
net = PACNNWithPerspectiveMap(PACNN_PERSPECTIVE_AWARE_MODEL).to(device) |
net = PACNNWithPerspectiveMap(PACNN_PERSPECTIVE_AWARE_MODEL).to(device) |
147 |
189 |
print(net) |
print(net) |
148 |
190 |
|
|
149 |
191 |
best_checkpoint = torch.load(MODEL_SAVE_NAME + "checkpoint.pth.tar") |
best_checkpoint = torch.load(MODEL_SAVE_NAME + "checkpoint.pth.tar") |
150 |
|
net.load_state_dict(best_checkpoint['state_dict']) |
|
|
192 |
|
net.load_state_dict(best_checkpoint['model']) |
151 |
193 |
|
|
152 |
194 |
# device = "cpu" |
# device = "cpu" |
153 |
195 |
# TODO d1_val 155.97279205322266 |
# TODO d1_val 155.97279205322266 |