File main_pacnn.py changed (mode: 100644) (index 43b6c3d..b42c5a0) |
1 |
1 |
from comet_ml import Experiment |
from comet_ml import Experiment |
2 |
2 |
from args_util import real_args_parse |
from args_util import real_args_parse |
3 |
3 |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
4 |
|
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
5 |
|
from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError |
|
6 |
4 |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
7 |
5 |
import torch |
import torch |
8 |
6 |
from torch import nn |
from torch import nn |
|
... |
... |
from evaluator import MAECalculator |
18 |
16 |
|
|
19 |
17 |
from model_util import save_checkpoint |
from model_util import save_checkpoint |
20 |
18 |
|
|
21 |
|
import apex |
|
22 |
|
from apex import amp |
|
|
19 |
|
# import apex |
|
20 |
|
# from apex import amp |
23 |
21 |
|
|
24 |
22 |
if __name__ == "__main__": |
if __name__ == "__main__": |
25 |
23 |
# import comet_ml in the top of your file |
# import comet_ml in the top of your file |
26 |
24 |
|
|
27 |
25 |
|
|
28 |
|
MODEL_SAVE_NAME = "dev5" |
|
|
26 |
|
MODEL_SAVE_NAME = "dev7" |
29 |
27 |
# Add the following code anywhere in your machine learning file |
# Add the following code anywhere in your machine learning file |
30 |
28 |
experiment = Experiment(api_key="S3mM1eMq6NumMxk2QJAXASkUM", |
experiment = Experiment(api_key="S3mM1eMq6NumMxk2QJAXASkUM", |
31 |
29 |
project_name="pacnn-dev2", workspace="ttpro1995") |
project_name="pacnn-dev2", workspace="ttpro1995") |
|
... |
... |
if __name__ == "__main__": |
40 |
38 |
print(args) |
print(args) |
41 |
39 |
DATA_PATH = args.input |
DATA_PATH = args.input |
42 |
40 |
DATASET_NAME = "shanghaitech" |
DATASET_NAME = "shanghaitech" |
|
41 |
|
TOTAL_EPOCH = args.epochs |
43 |
42 |
PACNN_PERSPECTIVE_AWARE_MODEL = True |
PACNN_PERSPECTIVE_AWARE_MODEL = True |
44 |
43 |
|
|
45 |
44 |
|
|
|
... |
... |
if __name__ == "__main__": |
89 |
88 |
momentum=args.momentum, |
momentum=args.momentum, |
90 |
89 |
weight_decay=args.decay) |
weight_decay=args.decay) |
91 |
90 |
# Allow Amp to perform casts as required by the opt_level |
# Allow Amp to perform casts as required by the opt_level |
92 |
|
net, optimizer = amp.initialize(net, optimizer, opt_level="O1", enabled=False) |
|
93 |
|
|
|
94 |
|
for e in range(10): |
|
95 |
|
print("start epoch ", e) |
|
|
91 |
|
# net, optimizer = amp.initialize(net, optimizer, opt_level="O1", enabled=False) |
|
92 |
|
|
|
93 |
|
current_save_model_name = "" |
|
94 |
|
current_epoch = 0 |
|
95 |
|
while (current_epoch < TOTAL_EPOCH): |
|
96 |
|
experiment.log_current_epoch(current_epoch) |
|
97 |
|
current_epoch += 1 |
|
98 |
|
print("start epoch ", current_epoch) |
96 |
99 |
loss_sum = 0 |
loss_sum = 0 |
97 |
100 |
sample = 0 |
sample = 0 |
98 |
101 |
start_time = time() |
start_time = time() |
|
... |
... |
if __name__ == "__main__": |
120 |
123 |
pass |
pass |
121 |
124 |
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
122 |
125 |
loss += loss_d |
loss += loss_d |
123 |
|
# loss.backward() |
|
124 |
|
with amp.scale_loss(loss, optimizer) as scaled_loss: |
|
125 |
|
scaled_loss.backward() |
|
|
126 |
|
loss.backward() |
|
127 |
|
# with amp.scale_loss(loss, optimizer) as scaled_loss: |
|
128 |
|
# scaled_loss.backward() |
126 |
129 |
optimizer.step() |
optimizer.step() |
127 |
130 |
optimizer.zero_grad() |
optimizer.zero_grad() |
128 |
131 |
loss_sum += loss.item() |
loss_sum += loss.item() |
|
... |
... |
if __name__ == "__main__": |
135 |
138 |
experiment.log_metric("avg_loss_ministep", avg_loss_ministep) |
experiment.log_metric("avg_loss_ministep", avg_loss_ministep) |
136 |
139 |
# if counting == 100: |
# if counting == 100: |
137 |
140 |
# break |
# break |
|
141 |
|
|
138 |
142 |
end_time = time() |
end_time = time() |
139 |
143 |
avg_loss = loss_sum/sample |
avg_loss = loss_sum/sample |
140 |
144 |
epoch_time = end_time - start_time |
epoch_time = end_time - start_time |
141 |
|
print("==END epoch ", e, " =============================================") |
|
|
145 |
|
print("==END epoch ", current_epoch, " =============================================") |
142 |
146 |
print(epoch_time, avg_loss, sample) |
print(epoch_time, avg_loss, sample) |
143 |
147 |
experiment.log_metric("avg_loss_epoch", avg_loss) |
experiment.log_metric("avg_loss_epoch", avg_loss) |
144 |
148 |
print("=================================================================") |
print("=================================================================") |
145 |
149 |
|
|
146 |
|
save_checkpoint({ |
|
|
150 |
|
current_save_model_name = save_checkpoint({ |
147 |
151 |
'model': net.state_dict(), |
'model': net.state_dict(), |
148 |
152 |
'optimizer': optimizer.state_dict(), |
'optimizer': optimizer.state_dict(), |
|
153 |
|
'e': current_epoch, |
|
154 |
|
'PACNN_PERSPECTIVE_AWARE_MODEL': PACNN_PERSPECTIVE_AWARE_MODEL |
149 |
155 |
# 'amp': amp.state_dict() |
# 'amp': amp.state_dict() |
150 |
|
}, False, MODEL_SAVE_NAME) |
|
|
156 |
|
}, False, MODEL_SAVE_NAME+"_"+str(current_epoch)+"_") |
|
157 |
|
|
|
158 |
|
experiment.log_asset(current_save_model_name) |
151 |
159 |
|
|
|
160 |
|
# end 1 epoch |
152 |
161 |
|
|
153 |
162 |
# after epoch evaluate |
# after epoch evaluate |
154 |
163 |
mae_calculator_d1 = MAECalculator() |
mae_calculator_d1 = MAECalculator() |
|
... |
... |
if __name__ == "__main__": |
189 |
198 |
net = PACNNWithPerspectiveMap(PACNN_PERSPECTIVE_AWARE_MODEL).to(device) |
net = PACNNWithPerspectiveMap(PACNN_PERSPECTIVE_AWARE_MODEL).to(device) |
190 |
199 |
print(net) |
print(net) |
191 |
200 |
|
|
192 |
|
best_checkpoint = torch.load(MODEL_SAVE_NAME + "checkpoint.pth.tar") |
|
|
201 |
|
best_checkpoint = torch.load(current_save_model_name) |
193 |
202 |
net.load_state_dict(best_checkpoint['model']) |
net.load_state_dict(best_checkpoint['model']) |
194 |
203 |
|
|
195 |
204 |
# device = "cpu" |
# device = "cpu" |
File model_util.py changed (mode: 100644) (index 8bb883a..eeb3468) |
... |
... |
import h5py |
2 |
2 |
import torch |
import torch |
3 |
3 |
import shutil |
import shutil |
4 |
4 |
import numpy as np |
import numpy as np |
|
5 |
|
import os |
|
6 |
|
|
5 |
7 |
|
|
6 |
8 |
def save_net(fname, net): |
def save_net(fname, net): |
7 |
9 |
with h5py.File(fname, 'w') as h5f: |
with h5py.File(fname, 'w') as h5f: |
|
... |
... |
def load_net(fname, net): |
17 |
19 |
|
|
18 |
20 |
|
|
19 |
21 |
def save_checkpoint(state, is_best, task_id, filename='checkpoint.pth.tar'):
    """Save a training checkpoint under the ``saved_model`` directory.

    :param state: picklable object (normally a dict of model/optimizer
        state_dicts plus bookkeeping fields) passed straight to ``torch.save``.
    :param is_best: if True, also copy the checkpoint to a per-task
        ``<task_id>model_best.pth.tar`` file so the best model is kept
        even as later epochs overwrite the regular checkpoint name.
    :param task_id: prefix identifying the run; prepended to ``filename``.
    :param filename: checkpoint file suffix (default ``checkpoint.pth.tar``).
    :return: full path of the saved checkpoint file.
    """
    # exist_ok avoids a race between the existence check and creation
    # when several processes save concurrently.
    os.makedirs("saved_model", exist_ok=True)
    full_file_name = os.path.join("saved_model", task_id + filename)
    torch.save(state, full_file_name)
    if is_best:
        # BUGFIX: the source must be the file we just wrote
        # (full_file_name); the old path ``task_id + filename`` is no
        # longer written anywhere, so copying from it raised
        # FileNotFoundError whenever is_best was True.
        shutil.copyfile(full_file_name, task_id + 'model_best.pth.tar')
    return full_file_name