File args_util.py changed (mode: 100644) (index ff5520d..25625bd) |
... |
... |
def context_aware_network_args_parse(): |
82 |
82 |
def my_args_parse(): |
def my_args_parse(): |
83 |
83 |
parser = argparse.ArgumentParser(description='CrowdCounting Context Aware Network') |
parser = argparse.ArgumentParser(description='CrowdCounting Context Aware Network') |
84 |
84 |
parser.add_argument("--task_id", action="store", default="dev") |
parser.add_argument("--task_id", action="store", default="dev") |
85 |
|
parser.add_argument('-a', action="store_true", default=False) |
|
86 |
85 |
|
|
87 |
86 |
parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A) |
parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A) |
88 |
87 |
parser.add_argument('--datasetname', action="store", default="shanghaitech_keepfull") |
parser.add_argument('--datasetname', action="store", default="shanghaitech_keepfull") |
|
... |
... |
def my_args_parse(): |
90 |
89 |
# args with default value |
# args with default value |
91 |
90 |
parser.add_argument('--load_model', action="store", default="", type=str) |
parser.add_argument('--load_model', action="store", default="", type=str) |
92 |
91 |
parser.add_argument('--lr', action="store", default=1e-8, type=float) |
parser.add_argument('--lr', action="store", default=1e-8, type=float) |
93 |
|
parser.add_argument('--momentum', action="store", default=0.9, type=float) |
|
|
92 |
|
# parser.add_argument('--momentum', action="store", default=0.9, type=float) |
94 |
93 |
parser.add_argument('--decay', action="store", default=5*1e-3, type=float) |
parser.add_argument('--decay', action="store", default=5*1e-3, type=float) |
95 |
94 |
parser.add_argument('--epochs', action="store", default=1, type=int) |
parser.add_argument('--epochs', action="store", default=1, type=int) |
96 |
95 |
parser.add_argument('--test', action="store_true", default=False) |
parser.add_argument('--test', action="store_true", default=False) |
File train_compact_cnn_lrscheduler.py changed (mode: 100644) (index 3e3b3d0..848f4e7) |
|
1 |
|
from comet_ml import Experiment |
|
2 |
|
|
1 |
3 |
from args_util import my_args_parse |
from args_util import my_args_parse |
2 |
4 |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list |
3 |
5 |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
... |
... |
from torch import nn |
12 |
14 |
from models import CompactCNN |
from models import CompactCNN |
13 |
15 |
import os |
import os |
14 |
16 |
from ignite.contrib.handlers import PiecewiseLinear |
from ignite.contrib.handlers import PiecewiseLinear |
|
17 |
|
from model_util import get_lr |
|
18 |
|
|
|
19 |
|
COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM" |
|
20 |
|
PROJECT_NAME = "crowd-counting-framework" |
15 |
21 |
|
|
16 |
22 |
if __name__ == "__main__": |
if __name__ == "__main__": |
|
23 |
|
experiment = Experiment(project_name=PROJECT_NAME) |
|
24 |
|
|
17 |
25 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
18 |
26 |
print(device) |
print(device) |
19 |
27 |
args = my_args_parse() |
args = my_args_parse() |
|
28 |
|
experiment.set_name(args.task_id) |
20 |
29 |
print(args) |
print(args) |
|
30 |
|
experiment.set_cmd_args() |
|
31 |
|
|
21 |
32 |
DATA_PATH = args.input |
DATA_PATH = args.input |
22 |
33 |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
23 |
34 |
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
|
... |
... |
if __name__ == "__main__": |
49 |
60 |
optimizer = torch.optim.Adam(model.parameters(), args.lr, |
optimizer = torch.optim.Adam(model.parameters(), args.lr, |
50 |
61 |
weight_decay=args.decay) |
weight_decay=args.decay) |
51 |
62 |
|
|
52 |
|
milestones_values = [(50, 1e-4), (50, 5e-5), (50, 1e-5), (50, 5e-6), (50, 1e-6), (100, 1e-7)] |
|
|
63 |
|
milestones_values = [(20, 1e-4), (30, 5e-5), (100, 1e-5), (50, 5e-6), (50, 1e-6), (100, 1e-7)] |
53 |
64 |
lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values) |
lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values) |
54 |
65 |
|
|
55 |
66 |
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) |
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) |
|
... |
... |
if __name__ == "__main__": |
57 |
68 |
metrics={ |
metrics={ |
58 |
69 |
'mae': CrowdCountingMeanAbsoluteError(), |
'mae': CrowdCountingMeanAbsoluteError(), |
59 |
70 |
'mse': CrowdCountingMeanSquaredError(), |
'mse': CrowdCountingMeanSquaredError(), |
60 |
|
'nll': Loss(loss_fn) |
|
|
71 |
|
'loss': Loss(loss_fn) |
61 |
72 |
}, device=device) |
}, device=device) |
62 |
73 |
print(model) |
print(model) |
63 |
74 |
|
|
|
... |
... |
if __name__ == "__main__": |
75 |
86 |
print("change lr to ", args.lr) |
print("change lr to ", args.lr) |
76 |
87 |
else: |
else: |
77 |
88 |
print("do not load, keep training") |
print("do not load, keep training") |
78 |
|
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler) |
|
|
89 |
|
trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler) |
79 |
90 |
|
|
80 |
91 |
|
|
81 |
92 |
@trainer.on(Events.ITERATION_COMPLETED(every=50)) |
@trainer.on(Events.ITERATION_COMPLETED(every=50)) |
|
... |
... |
if __name__ == "__main__": |
90 |
101 |
metrics = evaluator.state.metrics |
metrics = evaluator.state.metrics |
91 |
102 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
92 |
103 |
print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
93 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
|
94 |
|
|
|
|
104 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss'])) |
|
105 |
|
experiment.set_epoch(epoch=trainer.state.epoch) |
|
106 |
|
experiment.log_metric("train_mae", metrics['mae']) |
|
107 |
|
experiment.log_metric("train_mse", metrics['mse']) |
|
108 |
|
experiment.log_metric("train_loss", metrics['loss']) |
|
109 |
|
experiment.log_metric("lr", get_lr(optimizer)) |
95 |
110 |
|
|
96 |
111 |
@trainer.on(Events.EPOCH_COMPLETED) |
@trainer.on(Events.EPOCH_COMPLETED) |
97 |
112 |
def log_validation_results(trainer): |
def log_validation_results(trainer): |
|
... |
... |
if __name__ == "__main__": |
99 |
114 |
metrics = evaluator.state.metrics |
metrics = evaluator.state.metrics |
100 |
115 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
101 |
116 |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
102 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
|
103 |
|
|
|
|
117 |
|
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss'])) |
|
118 |
|
experiment.set_epoch(epoch=trainer.state.epoch) |
|
119 |
|
experiment.log_metric("valid_mae", metrics['mae']) |
|
120 |
|
experiment.log_metric("valid_mse", metrics['mse']) |
|
121 |
|
experiment.log_metric("valid_loss", metrics['loss']) |
104 |
122 |
|
|
105 |
123 |
|
|
106 |
124 |
# docs on save and load |
# docs on save and load |