File train_attn_can_adcrowdnet_simple.py changed (mode: 100644) (index 31334b0..224ef56) |
|
1 |
|
from comet_ml import Experiment |
|
2 |
|
|
1 |
3 |
from args_util import my_args_parse |
from args_util import my_args_parse |
2 |
4 |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list |
3 |
5 |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
... |
... |
from torch import nn |
12 |
14 |
from models import AttnCanAdcrowdNetSimpleV4 |
from models import AttnCanAdcrowdNetSimpleV4 |
13 |
15 |
import os |
import os |
14 |
16 |
|
|
|
17 |
|
from ignite.contrib.handlers import PiecewiseLinear |
|
18 |
|
from model_util import get_lr |
|
19 |
|
|
|
20 |
|
COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM" |
|
21 |
|
PROJECT_NAME = "crowd-counting-framework" |
15 |
22 |
|
|
16 |
23 |
if __name__ == "__main__": |
if __name__ == "__main__": |
|
24 |
|
experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API) |
17 |
25 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
18 |
26 |
print(device) |
print(device) |
19 |
27 |
args = my_args_parse() |
args = my_args_parse() |
20 |
28 |
print(args) |
print(args) |
|
29 |
|
|
|
30 |
|
experiment.set_name(args.task_id) |
|
31 |
|
experiment.set_cmd_args() |
|
32 |
|
|
21 |
33 |
DATA_PATH = args.input |
DATA_PATH = args.input |
22 |
34 |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
23 |
35 |
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
|
... |
... |
if __name__ == "__main__": |
54 |
66 |
metrics={ |
metrics={ |
55 |
67 |
'mae': CrowdCountingMeanAbsoluteError(), |
'mae': CrowdCountingMeanAbsoluteError(), |
56 |
68 |
'mse': CrowdCountingMeanSquaredError(), |
'mse': CrowdCountingMeanSquaredError(), |
57 |
|
'nll': Loss(loss_fn) |
|
|
69 |
|
'loss': Loss(loss_fn) |
58 |
70 |
}, device=device) |
}, device=device) |
59 |
71 |
print(model) |
print(model) |
60 |
72 |
|
|
|
... |
... |
if __name__ == "__main__": |
86 |
98 |
metrics = evaluator.state.metrics |
metrics = evaluator.state.metrics |
87 |
99 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
88 |
100 |
print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
89 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
|
90 |
|
|
|
|
101 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss'])) |
|
102 |
|
experiment.log_metric("epoch", trainer.state.epoch) |
|
103 |
|
experiment.log_metric("train_mae", metrics['mae']) |
|
104 |
|
experiment.log_metric("train_mse", metrics['mse']) |
|
105 |
|
experiment.log_metric("train_loss", metrics['loss']) |
|
106 |
|
experiment.log_metric("lr", get_lr(optimizer)) |
91 |
107 |
|
|
92 |
108 |
@trainer.on(Events.EPOCH_COMPLETED) |
@trainer.on(Events.EPOCH_COMPLETED) |
93 |
109 |
def log_validation_results(trainer): |
def log_validation_results(trainer): |
|
... |
... |
if __name__ == "__main__": |
95 |
111 |
metrics = evaluator.state.metrics |
metrics = evaluator.state.metrics |
96 |
112 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
97 |
113 |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
98 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
|
|
114 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss'])) |
|
115 |
|
experiment.log_metric("valid_mae", metrics['mae']) |
|
116 |
|
experiment.log_metric("valid_mse", metrics['mse']) |
|
117 |
|
experiment.log_metric("valid_loss", metrics['loss']) |
99 |
118 |
|
|
100 |
119 |
|
|
101 |
120 |
|
|
File train_attn_can_adcrowdnet_simple_lrscheduler.py copied from file train_compact_cnn_lrscheduler.py (similarity 94%) (mode: 100644) (index d1cf26d..a7146e3) |
... |
... |
from visualize_util import get_readable_time |
11 |
11 |
|
|
12 |
12 |
import torch |
import torch |
13 |
13 |
from torch import nn |
from torch import nn |
14 |
|
from models import CompactCNN |
|
|
14 |
|
from models import AttnCanAdcrowdNetSimpleV4 |
15 |
15 |
import os |
import os |
|
16 |
|
|
16 |
17 |
from ignite.contrib.handlers import PiecewiseLinear |
from ignite.contrib.handlers import PiecewiseLinear |
17 |
18 |
from model_util import get_lr |
from model_util import get_lr |
18 |
19 |
|
|
|
... |
... |
PROJECT_NAME = "crowd-counting-framework" |
21 |
22 |
|
|
22 |
23 |
if __name__ == "__main__": |
if __name__ == "__main__": |
23 |
24 |
experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API) |
experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API) |
24 |
|
|
|
25 |
25 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
26 |
26 |
print(device) |
print(device) |
27 |
27 |
args = my_args_parse() |
args = my_args_parse() |
28 |
|
experiment.set_name(args.task_id) |
|
29 |
28 |
print(args) |
print(args) |
|
29 |
|
|
|
30 |
|
experiment.set_name(args.task_id) |
30 |
31 |
experiment.set_cmd_args() |
experiment.set_cmd_args() |
31 |
32 |
|
|
32 |
33 |
DATA_PATH = args.input |
DATA_PATH = args.input |
|
... |
... |
if __name__ == "__main__": |
51 |
52 |
print("len train_loader ", len(train_loader)) |
print("len train_loader ", len(train_loader)) |
52 |
53 |
|
|
53 |
54 |
# model |
# model |
54 |
|
model = CompactCNN() |
|
|
55 |
|
model = AttnCanAdcrowdNetSimpleV4() |
55 |
56 |
model = model.to(device) |
model = model.to(device) |
56 |
57 |
|
|
57 |
58 |
# loss function |
# loss function |
|
... |
... |
if __name__ == "__main__": |
60 |
61 |
optimizer = torch.optim.Adam(model.parameters(), args.lr, |
optimizer = torch.optim.Adam(model.parameters(), args.lr, |
61 |
62 |
weight_decay=args.decay) |
weight_decay=args.decay) |
62 |
63 |
|
|
63 |
|
milestones_values = [(50, 1e-4), (100, 5e-5), (200, 1e-5), (400, 9e-6), (500, 5e-6), (600, 1e-6)] |
|
64 |
|
experiment.log_parameter("milestones_values", "[(50, 1e-4), (100, 5e-5), (200, 1e-5), (400, 9e-6), (500, 5e-6), (600, 1e-6)]") |
|
65 |
|
lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values) |
|
66 |
|
|
|
67 |
64 |
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) |
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) |
68 |
65 |
evaluator = create_supervised_evaluator(model, |
evaluator = create_supervised_evaluator(model, |
69 |
66 |
metrics={ |
metrics={ |
|
... |
... |
if __name__ == "__main__": |
75 |
72 |
|
|
76 |
73 |
print(args) |
print(args) |
77 |
74 |
|
|
|
75 |
|
milestones_values = [(10, 1e-4), (20, 1e-5), (60, 1e-5), (100, 1e-6)] |
|
76 |
|
experiment.log_parameter("milestones_values", str(milestones_values)) |
|
77 |
|
lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values) |
|
78 |
|
|
78 |
79 |
if len(args.load_model) > 0: |
if len(args.load_model) > 0: |
79 |
80 |
load_model_path = args.load_model |
load_model_path = args.load_model |
80 |
81 |
print("load mode " + load_model_path) |
print("load mode " + load_model_path) |
|
... |
... |
if __name__ == "__main__": |
87 |
88 |
print("change lr to ", args.lr) |
print("change lr to ", args.lr) |
88 |
89 |
else: |
else: |
89 |
90 |
print("do not load, keep training") |
print("do not load, keep training") |
|
91 |
|
|
90 |
92 |
trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler) |
trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler) |
91 |
93 |
|
|
92 |
94 |
|
|
|
... |
... |
if __name__ == "__main__": |
121 |
123 |
experiment.log_metric("valid_loss", metrics['loss']) |
experiment.log_metric("valid_loss", metrics['loss']) |
122 |
124 |
|
|
123 |
125 |
|
|
|
126 |
|
|
124 |
127 |
# docs on save and load |
# docs on save and load |
125 |
128 |
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler} |
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler} |
126 |
129 |
save_handler = Checkpoint(to_save, DiskSaver('saved_model/' + args.task_id, create_dir=True, atomic=True), |
save_handler = Checkpoint(to_save, DiskSaver('saved_model/' + args.task_id, create_dir=True, atomic=True), |