File experiment_main.py changed (mode: 100644) (index 974dcfc..08c5dd3) |
... |
... |
from data_flow import get_dataloader, create_image_list |
5 |
5 |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
6 |
6 |
from ignite.metrics import Loss |
from ignite.metrics import Loss |
7 |
7 |
from ignite.handlers import Checkpoint, DiskSaver, Timer |
from ignite.handlers import Checkpoint, DiskSaver, Timer |
8 |
|
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
|
|
8 |
|
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError, CrowdCountingMeanAbsoluteErrorWithCount, CrowdCountingMeanSquaredErrorWithCount |
9 |
9 |
from visualize_util import get_readable_time |
from visualize_util import get_readable_time |
10 |
10 |
|
|
11 |
11 |
import torch |
import torch |
|
... |
... |
if __name__ == "__main__": |
114 |
114 |
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) |
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) |
115 |
115 |
evaluator_train = create_supervised_evaluator(model, |
evaluator_train = create_supervised_evaluator(model, |
116 |
116 |
metrics={ |
metrics={ |
117 |
|
'mae': CrowdCountingMeanAbsoluteError(), |
|
118 |
|
'mse': CrowdCountingMeanSquaredError(), |
|
|
117 |
|
'mae': CrowdCountingMeanAbsoluteErrorWithCount(), |
|
118 |
|
'mse': CrowdCountingMeanSquaredErrorWithCount(), |
119 |
119 |
'loss': Loss(loss_fn) |
'loss': Loss(loss_fn) |
120 |
120 |
}, device=device) |
}, device=device) |
121 |
121 |
|
|
122 |
122 |
evaluator_validate = create_supervised_evaluator(model, |
evaluator_validate = create_supervised_evaluator(model, |
123 |
123 |
metrics={ |
metrics={ |
124 |
|
'mae': CrowdCountingMeanAbsoluteError(), |
|
125 |
|
'mse': CrowdCountingMeanSquaredError(), |
|
|
124 |
|
'mae': CrowdCountingMeanAbsoluteErrorWithCount(), |
|
125 |
|
'mse': CrowdCountingMeanSquaredErrorWithCount(), |
126 |
126 |
'loss': Loss(loss_fn) |
'loss': Loss(loss_fn) |
127 |
127 |
}, device=device) |
}, device=device) |
128 |
128 |
|
|
129 |
129 |
evaluator_test = create_supervised_evaluator(model, |
evaluator_test = create_supervised_evaluator(model, |
130 |
130 |
metrics={ |
metrics={ |
131 |
|
'mae': CrowdCountingMeanAbsoluteError(), |
|
132 |
|
'mse': CrowdCountingMeanSquaredError(), |
|
|
131 |
|
'mae': CrowdCountingMeanAbsoluteErrorWithCount(), |
|
132 |
|
'mse': CrowdCountingMeanSquaredErrorWithCount(), |
133 |
133 |
'loss': Loss(loss_fn) |
'loss': Loss(loss_fn) |
134 |
134 |
}, device=device) |
}, device=device) |
135 |
135 |
|
|