File main_pacnn.py changed (mode: 100644) (index 24d3d1c..2abdbad) |
... |
... |
from data_flow import ListDataset |
13 |
13 |
import pytorch_ssim |
import pytorch_ssim |
14 |
14 |
from time import time |
from time import time |
15 |
15 |
from evaluator import MAECalculator |
from evaluator import MAECalculator |
16 |
|
|
|
|
16 |
|
import torch.backends.cudnn as cudnn |
17 |
17 |
from model_util import save_checkpoint |
from model_util import save_checkpoint |
18 |
18 |
|
|
19 |
|
# import apex |
|
20 |
|
# from apex import amp |
|
|
19 |
|
import apex |
|
20 |
|
from apex import amp |
21 |
21 |
|
|
22 |
22 |
if __name__ == "__main__": |
if __name__ == "__main__": |
23 |
23 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
... |
... |
if __name__ == "__main__": |
38 |
38 |
TOTAL_EPOCH = args.epochs |
TOTAL_EPOCH = args.epochs |
39 |
39 |
PACNN_PERSPECTIVE_AWARE_MODEL = args.PACNN_PERSPECTIVE_AWARE_MODEL |
PACNN_PERSPECTIVE_AWARE_MODEL = args.PACNN_PERSPECTIVE_AWARE_MODEL |
40 |
40 |
|
|
|
41 |
|
APEX_AMP = True |
|
42 |
|
if APEX_AMP: |
|
43 |
|
cudnn.benchmark = True |
|
44 |
|
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." |
|
45 |
|
|
41 |
46 |
experiment.set_name(args.task_id) |
experiment.set_name(args.task_id) |
42 |
47 |
experiment.log_parameter("DATA_PATH", DATA_PATH) |
experiment.log_parameter("DATA_PATH", DATA_PATH) |
43 |
48 |
experiment.log_parameter("PACNN_PERSPECTIVE_AWARE_MODEL", PACNN_PERSPECTIVE_AWARE_MODEL) |
experiment.log_parameter("PACNN_PERSPECTIVE_AWARE_MODEL", PACNN_PERSPECTIVE_AWARE_MODEL) |
|
... |
... |
if __name__ == "__main__": |
89 |
94 |
optimizer = torch.optim.SGD(net.parameters(), args.lr, |
optimizer = torch.optim.SGD(net.parameters(), args.lr, |
90 |
95 |
momentum=args.momentum, |
momentum=args.momentum, |
91 |
96 |
weight_decay=args.decay) |
weight_decay=args.decay) |
92 |
|
# Allow Amp to perform casts as required by the opt_level |
|
93 |
|
# net, optimizer = amp.initialize(net, optimizer, opt_level="O1", enabled=False) |
|
|
97 |
|
|
|
98 |
|
if APEX_AMP: |
|
99 |
|
# Allow Amp to perform casts as required by the opt_level |
|
100 |
|
net, optimizer = amp.initialize(net, optimizer, opt_level="O3", enabled=APEX_AMP, |
|
101 |
|
keep_batchnorm_fp32=True) |
|
102 |
|
# loss_scale="dynamic" |
|
103 |
|
|
|
104 |
|
|
94 |
105 |
|
|
95 |
106 |
current_save_model_name = "" |
current_save_model_name = "" |
96 |
107 |
current_epoch = 0 |
current_epoch = 0 |
|
... |
... |
if __name__ == "__main__": |
102 |
113 |
net.load_state_dict(checkpoint['model']) |
net.load_state_dict(checkpoint['model']) |
103 |
114 |
optimizer.load_state_dict(checkpoint['optimizer']) |
optimizer.load_state_dict(checkpoint['optimizer']) |
104 |
115 |
current_epoch = checkpoint['e'] |
current_epoch = checkpoint['e'] |
|
116 |
|
if APEX_AMP: |
|
117 |
|
amp.load_state_dict(checkpoint['amp']) |
105 |
118 |
print("load ", load_model, " epoch ", str(current_epoch)) |
print("load ", load_model, " epoch ", str(current_epoch)) |
106 |
119 |
else: |
else: |
107 |
120 |
print("new model") |
print("new model") |
|
... |
... |
if __name__ == "__main__": |
139 |
152 |
pass |
pass |
140 |
153 |
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
141 |
154 |
loss += loss_d |
loss += loss_d |
142 |
|
loss.backward() |
|
143 |
|
# with amp.scale_loss(loss, optimizer) as scaled_loss: |
|
144 |
|
# scaled_loss.backward() |
|
145 |
|
optimizer.step() |
|
146 |
|
optimizer.zero_grad() |
|
|
155 |
|
|
|
156 |
|
optimizer.zero_grad() # make optimizer grad = 0 |
|
157 |
|
|
|
158 |
|
if APEX_AMP: |
|
159 |
|
with amp.scale_loss(loss, optimizer) as scaled_loss: |
|
160 |
|
scaled_loss.backward() |
|
161 |
|
else: |
|
162 |
|
loss.backward() # calculate grad for optimizer |
|
163 |
|
|
|
164 |
|
optimizer.step() # optimize param |
|
165 |
|
|
147 |
166 |
loss_sum += loss.item() |
loss_sum += loss.item() |
148 |
167 |
sample += 1 |
sample += 1 |
149 |
168 |
counting += 1 |
counting += 1 |
|
... |
... |
if __name__ == "__main__": |
166 |
185 |
print("=================================================================") |
print("=================================================================") |
167 |
186 |
|
|
168 |
187 |
if current_epoch % MODEL_SAVE_INTERVAL == 0: |
if current_epoch % MODEL_SAVE_INTERVAL == 0: |
|
188 |
|
amp_state_dict = None |
|
189 |
|
if APEX_AMP: |
|
190 |
|
amp_state_dict = amp.state_dict() |
|
191 |
|
|
169 |
192 |
current_save_model_name = save_checkpoint({ |
current_save_model_name = save_checkpoint({ |
170 |
193 |
'model': net.state_dict(), |
'model': net.state_dict(), |
171 |
194 |
'optimizer': optimizer.state_dict(), |
'optimizer': optimizer.state_dict(), |
172 |
195 |
'e': current_epoch, |
'e': current_epoch, |
173 |
|
'PACNN_PERSPECTIVE_AWARE_MODEL': PACNN_PERSPECTIVE_AWARE_MODEL |
|
174 |
|
# 'amp': amp.state_dict() |
|
|
196 |
|
'PACNN_PERSPECTIVE_AWARE_MODEL': PACNN_PERSPECTIVE_AWARE_MODEL, |
|
197 |
|
'amp': amp_state_dict # amp.state_dict() |
175 |
198 |
}, False, MODEL_SAVE_NAME+"_"+str(current_epoch)+"_") |
}, False, MODEL_SAVE_NAME+"_"+str(current_epoch)+"_") |
176 |
199 |
experiment.log_asset(current_save_model_name) |
experiment.log_asset(current_save_model_name) |
177 |
200 |
print("saved ", current_save_model_name) |
print("saved ", current_save_model_name) |