File experiment_meow_main.py changed (mode: 100644) (index 84b1083..4441b69) |
... |
... |
from args_util import meow_parse |
4 |
4 |
from data_flow import get_dataloader, create_image_list |
from data_flow import get_dataloader, create_image_list |
5 |
5 |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
6 |
6 |
from ignite.metrics import Loss |
from ignite.metrics import Loss |
7 |
|
from ignite.handlers import Checkpoint, DiskSaver |
|
|
7 |
|
from ignite.handlers import Checkpoint, DiskSaver, Timer |
8 |
8 |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
9 |
9 |
from visualize_util import get_readable_time |
from visualize_util import get_readable_time |
10 |
10 |
|
|
|
... |
... |
if __name__ == "__main__": |
102 |
102 |
|
|
103 |
103 |
print(args) |
print(args) |
104 |
104 |
|
|
|
105 |
|
|
|
106 |
|
# timer |
|
107 |
|
train_timer = Timer() # time to train whole epoch |
|
108 |
|
batch_timer = Timer(average=True) # every batch |
|
109 |
|
evaluate_timer = Timer() |
|
110 |
|
|
|
111 |
|
batch_timer.attach(trainer, |
|
112 |
|
start =Events.EPOCH_STARTED, |
|
113 |
|
resume =Events.ITERATION_STARTED, |
|
114 |
|
pause =Events.ITERATION_COMPLETED, |
|
115 |
|
step =Events.ITERATION_COMPLETED) |
|
116 |
|
|
|
117 |
|
train_timer.attach(trainer, |
|
118 |
|
start =Events.EPOCH_STARTED, |
|
119 |
|
resume =Events.EPOCH_STARTED, |
|
120 |
|
pause =Events.EPOCH_COMPLETED, |
|
121 |
|
step =Events.EPOCH_COMPLETED) |
|
122 |
|
|
105 |
123 |
if len(args.load_model) > 0: |
if len(args.load_model) > 0: |
106 |
124 |
load_model_path = args.load_model |
load_model_path = args.load_model |
107 |
125 |
print("load mode " + load_model_path) |
print("load mode " + load_model_path) |
|
... |
... |
if __name__ == "__main__": |
116 |
134 |
print("do not load, keep training") |
print("do not load, keep training") |
117 |
135 |
|
|
118 |
136 |
|
|
119 |
|
@trainer.on(Events.ITERATION_COMPLETED(every=50)) |
|
|
137 |
|
@trainer.on(Events.ITERATION_COMPLETED(every=100)) |
120 |
138 |
def log_training_loss(trainer): |
def log_training_loss(trainer): |
121 |
139 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
122 |
140 |
print(timestamp + " Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output)) |
print(timestamp + " Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output)) |
|
... |
... |
if __name__ == "__main__": |
135 |
153 |
experiment.log_metric("train_loss", metrics['loss']) |
experiment.log_metric("train_loss", metrics['loss']) |
136 |
154 |
experiment.log_metric("lr", get_lr(optimizer)) |
experiment.log_metric("lr", get_lr(optimizer)) |
137 |
155 |
|
|
|
156 |
|
experiment.log_metric("batch_timer", batch_timer.value()) |
|
157 |
|
experiment.log_metric("train_timer", train_timer.value()) |
|
158 |
|
|
138 |
159 |
@trainer.on(Events.EPOCH_COMPLETED) |
@trainer.on(Events.EPOCH_COMPLETED) |
139 |
160 |
def log_validation_results(trainer): |
def log_validation_results(trainer): |
|
161 |
|
evaluate_timer.resume() |
140 |
162 |
evaluator.run(test_loader) |
evaluator.run(test_loader) |
|
163 |
|
evaluate_timer.pause() |
|
164 |
|
evaluate_timer.step() |
|
165 |
|
|
141 |
166 |
metrics = evaluator.state.metrics |
metrics = evaluator.state.metrics |
142 |
167 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
143 |
168 |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
|
... |
... |
if __name__ == "__main__": |
146 |
171 |
experiment.log_metric("valid_mse", metrics['mse']) |
experiment.log_metric("valid_mse", metrics['mse']) |
147 |
172 |
experiment.log_metric("valid_loss", metrics['loss']) |
experiment.log_metric("valid_loss", metrics['loss']) |
148 |
173 |
|
|
|
174 |
|
# timer |
|
175 |
|
experiment.log_metric("evaluate_timer", evaluate_timer.value()) |
|
176 |
|
|
|
177 |
|
def checkpoint_valid_mae_score_function(engine): |
|
178 |
|
score = engine.state.metrics['valid_mae'] |
|
179 |
|
return score |
|
180 |
|
|
149 |
181 |
|
|
150 |
182 |
# docs on save and load |
# docs on save and load |
151 |
183 |
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer} |
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer} |
|
... |
... |
if __name__ == "__main__": |
153 |
185 |
filename_prefix=args.task_id, |
filename_prefix=args.task_id, |
154 |
186 |
n_saved=5) |
n_saved=5) |
155 |
187 |
|
|
|
188 |
|
save_handler_best = Checkpoint(to_save, DiskSaver('saved_model_best/' + args.task_id, create_dir=True, atomic=True), |
|
189 |
|
filename_prefix=args.task_id, score_name="valid_mae", score_function=checkpoint_valid_mae_score_function, |
|
190 |
|
n_saved=5) |
|
191 |
|
|
156 |
192 |
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=5), save_handler) |
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=5), save_handler) |
|
193 |
|
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler_best) |
|
194 |
|
|
157 |
195 |
|
|
158 |
196 |
trainer.run(train_loader, max_epochs=args.epochs) |
trainer.run(train_loader, max_epochs=args.epochs) |