File train_compact_cnn.py changed (mode: 100644) (index 966c973..661d6a7)
@@ ... @@
 from comet_ml import Experiment
 
-from args_util import my_args_parse
-from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list
+from args_util import meow_parse
+from data_flow import get_dataloader, create_image_list
 from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
-from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError
-from ignite.engine import Engine
-from ignite.handlers import Checkpoint, DiskSaver
+from ignite.metrics import Loss
+from ignite.handlers import Checkpoint, DiskSaver, Timer
 from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError
 from visualize_util import get_readable_time
 
 import torch
 from torch import nn
-from models import CompactCNNV2
+
+from models import CompactCNNV2, CompactCNNV3
+
 import os
 from model_util import get_lr
-from torchsummary import summary
 
 COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM"
 PROJECT_NAME = "crowd-counting-framework"
+# PROJECT_NAME = "crowd-counting-debug"
 
 
 def very_simple_param_count(model):
@@ ... @@ if __name__ == "__main__":
     experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(device)
-    args = my_args_parse()
+    args = meow_parse()
     print(args)
 
     experiment.set_name(args.task_id)
     experiment.set_cmd_args()
-    experiment.log_other("note", args.note)
+    experiment.log_text(args.note)
 
     DATA_PATH = args.input
     TRAIN_PATH = os.path.join(DATA_PATH, "train_data")
@@ ... @@ if __name__ == "__main__":
     test_list = create_image_list(TEST_PATH)
 
     # create data loader
-    train_loader, val_loader, test_loader = get_dataloader(train_list,
-                                                           None,
+    train_loader, val_loader, test_loader = get_dataloader(train_list, None,
                                                            test_list,
                                                            dataset_name=dataset_name,
                                                            batch_size=args.batch_size,
@@ ... @@ if __name__ == "__main__":
     print("len train_loader ", len(train_loader))
 
     # model
-    model = CompactCNNV2()
-    summary(model, (3, 128, 128), device="cpu")
-    experiment.log_other("n_param", very_simple_param_count(model))
+    model_name = args.model
+    experiment.log_other("model", model_name)
+    if model_name == "CompactCNNV2":
+        model = CompactCNNV2()
+    elif model_name == "CompactCNNV3":
+        model = CompactCNNV3()
+    else:
+        print("error: you didn't pick a model")
+        exit(-1)
+    n_param = very_simple_param_count(model)
+    experiment.log_other("n_param", n_param)
     if hasattr(model, 'model_note'):
         experiment.log_other("model_note", model.model_note)
     model = model.to(device)
 
-
-
     # loss function
     loss_fn = nn.MSELoss(reduction='sum').to(device)
 
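Note on the model-selection block above: the if/elif chain is fine for two architectures, but a name-to-constructor table keeps the dispatch in one place and fails loudly on an unknown name. A hypothetical sketch, not part of this patch, assuming only the `models` module already imported at the top of the script:

    # Hypothetical alternative to the if/elif dispatch (illustrative only).
    from models import CompactCNNV2, CompactCNNV3

    MODEL_REGISTRY = {
        "CompactCNNV2": CompactCNNV2,
        "CompactCNNV3": CompactCNNV3,
    }

    def build_model(name):
        try:
            return MODEL_REGISTRY[name]()  # instantiate the requested architecture
        except KeyError:
            raise SystemExit("unknown model %r, expected one of %s"
                             % (name, sorted(MODEL_REGISTRY)))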
@@ ... @@ if __name__ == "__main__":
                                  weight_decay=args.decay)
 
     trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
-    evaluator = create_supervised_evaluator(model,
+    evaluator_train = create_supervised_evaluator(model,
+                                                  metrics={
+                                                      'mae': CrowdCountingMeanAbsoluteError(),
+                                                      'mse': CrowdCountingMeanSquaredError(),
+                                                      'loss': Loss(loss_fn)
+                                                  }, device=device)
+
+    evaluator_validate = create_supervised_evaluator(model,
                                             metrics={
                                                 'mae': CrowdCountingMeanAbsoluteError(),
                                                 'mse': CrowdCountingMeanSquaredError(),
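The single shared `evaluator` becomes two engines, `evaluator_train` and `evaluator_validate`. Each ignite `Engine` keeps its own `state`, so training-set metrics can no longer be clobbered by the next validation pass, and handlers attach per engine (the best-model checkpoint added later in this patch hooks only `evaluator_validate`). A minimal sketch of that separation, using a stock ignite metric instead of the project's CrowdCounting* ones:

    # Sketch: separate Engine instances keep state and handlers apart.
    from ignite.engine import Events, create_supervised_evaluator
    from ignite.metrics import MeanAbsoluteError

    eval_train = create_supervised_evaluator(model, metrics={'mae': MeanAbsoluteError()})
    eval_valid = create_supervised_evaluator(model, metrics={'mae': MeanAbsoluteError()})

    @eval_valid.on(Events.COMPLETED)  # fires only for validation runs
    def on_valid_done(engine):
        print('valid mae:', engine.state.metrics['mae'])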
@@ ... @@ if __name__ == "__main__":
 
     print(args)
 
+
+    # timer
+    train_timer = Timer(average=True)  # time to train whole epoch
+    batch_timer = Timer(average=True)  # every batch
+    evaluate_timer = Timer(average=True)
+
+    batch_timer.attach(trainer,
+                       start=Events.EPOCH_STARTED,
+                       resume=Events.ITERATION_STARTED,
+                       pause=Events.ITERATION_COMPLETED,
+                       step=Events.ITERATION_COMPLETED)
+
+    train_timer.attach(trainer,
+                       start=Events.EPOCH_STARTED,
+                       resume=Events.EPOCH_STARTED,
+                       pause=Events.EPOCH_COMPLETED,
+                       step=Events.EPOCH_COMPLETED)
+
     if len(args.load_model) > 0:
         load_model_path = args.load_model
         print("load mode " + load_model_path)
@@ ... @@ if __name__ == "__main__":
         print("load model complete")
         for param_group in optimizer.param_groups:
             param_group['lr'] = args.lr
-            param_group['weight_decay'] = args.decay
-            print("change lr to ", param_group['lr'])
-            print("change weight_decay to ", param_group['weight_decay'])
+            print("change lr to ", args.lr)
     else:
         print("do not load, keep training")
 
 
-    @trainer.on(Events.ITERATION_COMPLETED(every=50))
+    @trainer.on(Events.ITERATION_COMPLETED(every=100))
     def log_training_loss(trainer):
         timestamp = get_readable_time()
         print(timestamp + " Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))
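The console-loss interval widens from every 50th to every 100th iteration. `Events.ITERATION_COMPLETED(every=100)` is ignite's built-in event filter, and for an engine built by `create_supervised_trainer` the default `trainer.state.output` is the per-batch loss as a float, which is what this handler prints. The other filter forms use the same call syntax; a short sketch on the existing trainer (handler names are illustrative):

    # Sketch of ignite event filters (assumes the trainer defined above).
    @trainer.on(Events.ITERATION_COMPLETED(every=100))  # every 100th batch
    def log_every_100(engine):
        print(engine.state.iteration, engine.state.output)

    @trainer.on(Events.EPOCH_COMPLETED(once=1))  # exactly once, after epoch 1
    def after_first_epoch(engine):
        print('first epoch finished')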
@@ ... @@ if __name__ == "__main__":
 
     @trainer.on(Events.EPOCH_COMPLETED)
     def log_training_results(trainer):
-        evaluator.run(train_loader)
-        metrics = evaluator.state.metrics
+        evaluator_train.run(train_loader)
+        metrics = evaluator_train.state.metrics
         timestamp = get_readable_time()
         print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
               .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss']))
@@ ... @@ if __name__ == "__main__":
         experiment.log_metric("train_loss", metrics['loss'])
         experiment.log_metric("lr", get_lr(optimizer))
 
+        experiment.log_metric("batch_timer", batch_timer.value())
+        experiment.log_metric("train_timer", train_timer.value())
+
+        print("batch_timer ", batch_timer.value())
+        print("train_timer ", train_timer.value())
+
     @trainer.on(Events.EPOCH_COMPLETED)
     def log_validation_results(trainer):
-        evaluator.run(test_loader)
-        metrics = evaluator.state.metrics
+        evaluate_timer.resume()
+        evaluator_validate.run(test_loader)
+        evaluate_timer.pause()
+        evaluate_timer.step()
+
+        metrics = evaluator_validate.state.metrics
         timestamp = get_readable_time()
         print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
               .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss']))
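Unlike the other two timers, `evaluate_timer` is never attached to an engine, so the handler drives it by hand: `resume()` before the validation run, `pause()` after it, then `step()` to record one completed measurement; `value()` is then the mean seconds per validation pass. The same pattern in isolation (a sleep stands in for the evaluation):

    # Manual Timer control; with average=True, value() averages over step() calls.
    import time
    from ignite.handlers import Timer

    evaluate_timer = Timer(average=True)
    evaluate_timer.pause()  # Timer runs from construction; stop it until the first resume()
    for _ in range(3):
        evaluate_timer.resume()   # start measuring
        time.sleep(0.1)           # stand-in for evaluator_validate.run(test_loader)
        evaluate_timer.pause()    # stop measuring
        evaluate_timer.step()     # one measurement recorded
    print(evaluate_timer.value())  # roughly 0.1 s, the mean per "evaluation"

One caveat, assuming ignite's Timer behaves as documented: it starts in the running state at construction and `resume()` is a no-op while running, so the patch's first measurement also absorbs everything between `Timer(average=True)` and the first `pause()`, including the whole first training epoch; pausing the timer right after constructing it, as in the sketch, avoids that.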
@@ ... @@ if __name__ == "__main__":
         experiment.log_metric("valid_mse", metrics['mse'])
         experiment.log_metric("valid_loss", metrics['loss'])
 
+        # timer
+        experiment.log_metric("evaluate_timer", evaluate_timer.value())
+        print("evaluate_timer ", evaluate_timer.value())
+
+    def checkpoint_valid_mae_score_function(engine):
+        score = engine.state.metrics['mae']
+        return score
+
 
     # docs on save and load
     to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
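One caveat on `checkpoint_valid_mae_score_function`: ignite's `Checkpoint` retains the `n_saved` checkpoints with the largest scores, and MAE is lower-is-better, so returning the raw MAE keeps the worst epochs rather than the best. The usual convention is to negate the metric; a corrected sketch of the same function:

    # Checkpoint keeps the n_saved LARGEST scores, so negate a
    # lower-is-better metric like MAE.
    def checkpoint_valid_mae_score_function(engine):
        return -engine.state.metrics['mae']

Hooking the handler to `evaluator_validate` (as the next hunk does) is otherwise the right place for it: each `run()` on the evaluator completes one evaluation epoch after metrics are computed, so `engine.state.metrics['mae']` is fresh when the score function reads it.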
@@ ... @@ if __name__ == "__main__":
                                filename_prefix=args.task_id,
                                n_saved=5)
 
-    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=3), save_handler)
+    save_handler_best = Checkpoint(to_save, DiskSaver('saved_model_best/' + args.task_id, create_dir=True, atomic=True),
+                                   filename_prefix=args.task_id, score_name="valid_mae", score_function=checkpoint_valid_mae_score_function,
+                                   n_saved=5)
+
+    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=5), save_handler)
+    evaluator_validate.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler_best)
+
 
     trainer.run(train_loader, max_epochs=args.epochs)
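For completeness: the `to_save` mapping is the same one the loading side needs via `Checkpoint.load_objects`, which is presumably what the `args.load_model` branch earlier in the script performs. A minimal sketch, with the path illustrative (the script takes it from `args.load_model`):

    # Resuming from a checkpoint written by the handlers above (sketch).
    to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
    ckpt = torch.load(load_model_path, map_location=device)
    Checkpoint.load_objects(to_load=to_load, checkpoint=ckpt)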