File args_util.py changed (mode: 100644) (index 5867a39..8eb8fa4) |
... |
... |
def meow_parse(): |
132 |
132 |
# help="if true, use mse and negative ssim as loss function") |
# help="if true, use mse and negative ssim as loss function") |
133 |
133 |
parser.add_argument('--loss_fn', action="store", default="MSE", type=str) |
parser.add_argument('--loss_fn', action="store", default="MSE", type=str) |
134 |
134 |
parser.add_argument('--optim', action="store", default="adam", type=str) |
parser.add_argument('--optim', action="store", default="adam", type=str) |
|
135 |
|
parser.add_argument('--eval_only', action="store_true", default=False) |
135 |
136 |
arg = parser.parse_args() |
arg = parser.parse_args() |
136 |
137 |
return arg |
return arg |
137 |
138 |
|
|
File experiment_main.py changed (mode: 100644) (index 34c6625..74787c7) |
... |
... |
if __name__ == "__main__": |
339 |
339 |
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=10), save_handler) |
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=10), save_handler) |
340 |
340 |
evaluator_validate.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler_best) |
evaluator_validate.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler_best) |
341 |
341 |
|
|
342 |
|
trainer.run(train_loader, max_epochs=args.epochs) |
|
|
342 |
|
if args.eval_only: |
|
343 |
|
print("evaluation only, no training") |
|
344 |
|
evaluate_validate_timer.resume() |
|
345 |
|
evaluator_validate.run(val_loader) |
|
346 |
|
evaluate_validate_timer.pause() |
|
347 |
|
evaluate_validate_timer.step() |
|
348 |
|
|
|
349 |
|
metrics = evaluator_validate.state.metrics |
|
350 |
|
timestamp = get_readable_time() |
|
351 |
|
print(timestamp + " Validation set Results - Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
|
352 |
|
.format( metrics['mae'], metrics['mse'], 0)) |
|
353 |
|
experiment.log_metric("valid_mae", metrics['mae']) |
|
354 |
|
experiment.log_metric("valid_mse", metrics['mse']) |
|
355 |
|
|
|
356 |
|
# timer |
|
357 |
|
experiment.log_metric("evaluate_valid_timer", evaluate_validate_timer.value()) |
|
358 |
|
print("evaluate_valid_timer ", evaluate_validate_timer.value()) |
|
359 |
|
|
|
360 |
|
# check if that validate is best |
|
361 |
|
flag_mae = best_mae.checkAndRecord(metrics['mae'], metrics['mse']) |
|
362 |
|
flag_mse = best_mse.checkAndRecord(metrics['mae'], metrics['mse']) |
|
363 |
|
|
|
364 |
|
if flag_mae or flag_mse: |
|
365 |
|
experiment.log_metric("valid_best_mae", metrics['mae']) |
|
366 |
|
experiment.log_metric("valid_best_mse", metrics['mse']) |
|
367 |
|
print("BEST VAL, evaluating on test set") |
|
368 |
|
evaluate_test_timer.resume() |
|
369 |
|
evaluator_test.run(test_loader) |
|
370 |
|
evaluate_test_timer.pause() |
|
371 |
|
evaluate_test_timer.step() |
|
372 |
|
test_metrics = evaluator_test.state.metrics |
|
373 |
|
timestamp = get_readable_time() |
|
374 |
|
print(timestamp + " Test set Results - Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
|
375 |
|
.format( test_metrics['mae'], test_metrics['mse'], 0)) |
|
376 |
|
experiment.log_metric("test_mae", test_metrics['mae']) |
|
377 |
|
experiment.log_metric("test_mse", test_metrics['mse']) |
|
378 |
|
experiment.log_metric("evaluate_test_timer", evaluate_test_timer.value()) |
|
379 |
|
print("evaluate_test_timer ", evaluate_test_timer.value()) |
|
380 |
|
# experiment.log_metric("test_loss", test_metrics['loss']) |
|
381 |
|
else: |
|
382 |
|
trainer.run(train_loader, max_epochs=args.epochs) |