File train_compact_cnn.py changed (mode: 100644) (index dda90bb..a79ba4a) |
... |
... |
if __name__ == "__main__": |
64 |
64 |
metrics={ |
metrics={ |
65 |
65 |
'mae': CrowdCountingMeanAbsoluteError(), |
'mae': CrowdCountingMeanAbsoluteError(), |
66 |
66 |
'mse': CrowdCountingMeanSquaredError(), |
'mse': CrowdCountingMeanSquaredError(), |
67 |
|
'nll': Loss(loss_fn) |
|
|
67 |
|
'loss': Loss(loss_fn) |
68 |
68 |
}, device=device) |
}, device=device) |
69 |
69 |
print(model) |
print(model) |
70 |
70 |
|
|
|
... |
... |
if __name__ == "__main__": |
96 |
96 |
metrics = evaluator.state.metrics |
metrics = evaluator.state.metrics |
97 |
97 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
98 |
98 |
print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
99 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
|
|
99 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss'])) |
100 |
100 |
experiment.log_metric("epoch", trainer.state.epoch) |
experiment.log_metric("epoch", trainer.state.epoch) |
101 |
101 |
experiment.log_metric("train_mae", metrics['mae']) |
experiment.log_metric("train_mae", metrics['mae']) |
102 |
102 |
experiment.log_metric("train_mse", metrics['mse']) |
experiment.log_metric("train_mse", metrics['mse']) |
|
... |
... |
if __name__ == "__main__": |
109 |
109 |
metrics = evaluator.state.metrics |
metrics = evaluator.state.metrics |
110 |
110 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
111 |
111 |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
112 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
|
|
112 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss'])) |
113 |
113 |
experiment.log_metric("valid_mae", metrics['mae']) |
experiment.log_metric("valid_mae", metrics['mae']) |
114 |
114 |
experiment.log_metric("valid_mse", metrics['mse']) |
experiment.log_metric("valid_mse", metrics['mse']) |
115 |
115 |
experiment.log_metric("valid_loss", metrics['loss']) |
experiment.log_metric("valid_loss", metrics['loss']) |