File args_util.py changed (mode: 100644) (index fa2f9fc..33e540f) |
... |
... |
def like_real_args_parse(data_input): |
52 |
52 |
args.print_freq = 30 |
args.print_freq = 30 |
53 |
53 |
|
|
54 |
54 |
|
|
|
55 |
|
def context_aware_network_args_parse(): |
|
56 |
|
""" |
|
57 |
|
this is not dummy |
|
58 |
|
if you are going to make all-in-one notebook, ignore this |
|
59 |
|
:return: |
|
60 |
|
""" |
|
61 |
|
parser = argparse.ArgumentParser(description='CrowdCounting Context Aware Network') |
|
62 |
|
parser.add_argument("--task_id", action="store", default="dev") |
|
63 |
|
parser.add_argument('-a', action="store_true", default=False) |
|
64 |
|
|
|
65 |
|
parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A) |
|
66 |
|
parser.add_argument('--output', action="store", type=str, default="saved_model/context_aware_network") |
|
67 |
|
parser.add_argument('--datasetname', action="store", default="shanghaitech_keepfull") |
|
68 |
|
|
|
69 |
|
# args with default value |
|
70 |
|
parser.add_argument('--load_model', action="store", default="", type=str) |
|
71 |
|
parser.add_argument('--lr', action="store", default=1e-8, type=float) |
|
72 |
|
parser.add_argument('--momentum', action="store", default=0.9, type=float) |
|
73 |
|
parser.add_argument('--decay', action="store", default=5*1e-3, type=float) |
|
74 |
|
parser.add_argument('--epochs', action="store", default=1, type=int) |
|
75 |
|
parser.add_argument('--test', action="store_true", default=False) |
|
76 |
|
|
|
77 |
|
|
|
78 |
|
arg = parser.parse_args() |
|
79 |
|
return arg |
|
80 |
|
|
55 |
81 |
def real_args_parse(): |
def real_args_parse(): |
56 |
82 |
""" |
""" |
57 |
83 |
this is not dummy |
this is not dummy |
File train_context_aware_network.py copied from file main_shanghaitech.py (similarity 70%) (mode: 100644) (index 3c2d5ba..0194e9b) |
1 |
|
from args_util import real_args_parse |
|
|
1 |
|
from args_util import context_aware_network_args_parse |
2 |
2 |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
3 |
3 |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
4 |
4 |
from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError |
from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError |
|
5 |
|
from ignite.engine import Engine |
|
6 |
|
from ignite.handlers import Checkpoint, DiskSaver |
5 |
7 |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
|
8 |
|
|
6 |
9 |
import torch |
import torch |
7 |
10 |
from torch import nn |
from torch import nn |
8 |
|
from models import CSRNet |
|
|
11 |
|
from models import CANNet |
9 |
12 |
import os |
import os |
10 |
13 |
|
|
|
14 |
|
|
11 |
15 |
if __name__ == "__main__": |
if __name__ == "__main__": |
12 |
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
13 |
17 |
print(device) |
print(device) |
14 |
|
args = real_args_parse() |
|
|
18 |
|
args = context_aware_network_args_parse() |
15 |
19 |
print(args) |
print(args) |
16 |
20 |
DATA_PATH = args.input |
DATA_PATH = args.input |
17 |
21 |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
|
... |
... |
if __name__ == "__main__": |
23 |
27 |
test_list = create_training_image_list(TEST_PATH) |
test_list = create_training_image_list(TEST_PATH) |
24 |
28 |
|
|
25 |
29 |
# create data loader |
# create data loader |
26 |
|
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list) |
|
|
30 |
|
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech_keepfull") |
27 |
31 |
|
|
28 |
32 |
|
|
29 |
33 |
# model |
# model |
30 |
|
model = CSRNet() |
|
|
34 |
|
model = CANNet() |
31 |
35 |
model = model.to(device) |
model = model.to(device) |
32 |
36 |
|
|
33 |
37 |
# loss function |
# loss function |
34 |
|
loss_fn = nn.MSELoss(size_average=False).cuda() |
|
|
38 |
|
loss_fn = nn.MSELoss(size_average=False).to(device) |
35 |
39 |
|
|
36 |
40 |
optimizer = torch.optim.SGD(model.parameters(), args.lr, |
optimizer = torch.optim.SGD(model.parameters(), args.lr, |
37 |
41 |
momentum=args.momentum, |
momentum=args.momentum, |
|
... |
... |
if __name__ == "__main__": |
49 |
53 |
print(args) |
print(args) |
50 |
54 |
|
|
51 |
55 |
|
|
52 |
|
@trainer.on(Events.ITERATION_COMPLETED) |
|
|
56 |
|
@trainer.on(Events.ITERATION_COMPLETED(every=50)) |
53 |
57 |
def log_training_loss(trainer): |
def log_training_loss(trainer): |
54 |
58 |
print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output)) |
print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output)) |
55 |
59 |
|
|
|
... |
... |
if __name__ == "__main__": |
58 |
62 |
def log_training_results(trainer): |
def log_training_results(trainer): |
59 |
63 |
evaluator.run(train_loader) |
evaluator.run(train_loader) |
60 |
64 |
metrics = evaluator.state.metrics |
metrics = evaluator.state.metrics |
61 |
|
print("Validation Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
|
|
65 |
|
print("Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
62 |
66 |
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
63 |
67 |
|
|
64 |
68 |
|
|
|
... |
... |
if __name__ == "__main__": |
66 |
70 |
def log_validation_results(trainer): |
def log_validation_results(trainer): |
67 |
71 |
evaluator.run(val_loader) |
evaluator.run(val_loader) |
68 |
72 |
metrics = evaluator.state.metrics |
metrics = evaluator.state.metrics |
69 |
|
print("Validation Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
|
|
73 |
|
print("Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
70 |
74 |
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
71 |
75 |
|
|
72 |
|
|
|
73 |
|
trainer.run(train_loader, max_epochs=10) |
|
|
76 |
|
# docs on save and load |
|
77 |
|
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer} |
|
78 |
|
save_handler = Checkpoint(to_save, DiskSaver('saved_model/context_aware_network', create_dir=True)) |
|
79 |
|
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler) |
|
80 |
|
trainer.run(train_loader, max_epochs=args.epochs) |