File args_util.py changed (mode: 100644) (index 25625bd..b4c50b3) |
... |
... |
def context_aware_network_args_parse(): |
82 |
82 |
def my_args_parse(): |
def my_args_parse(): |
83 |
83 |
parser = argparse.ArgumentParser(description='CrowdCounting Context Aware Network') |
parser = argparse.ArgumentParser(description='CrowdCounting Context Aware Network') |
84 |
84 |
parser.add_argument("--task_id", action="store", default="dev") |
parser.add_argument("--task_id", action="store", default="dev") |
|
85 |
|
parser.add_argument('--note', action="store", default="write anything") |
85 |
86 |
|
|
86 |
87 |
parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A) |
parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A) |
87 |
88 |
parser.add_argument('--datasetname', action="store", default="shanghaitech_keepfull") |
parser.add_argument('--datasetname', action="store", default="shanghaitech_keepfull") |
File train_compact_cnn.py changed (mode: 100644) (index 1876089..dda90bb) |
|
1 |
|
from comet_ml import Experiment |
|
2 |
|
|
1 |
3 |
from args_util import my_args_parse |
from args_util import my_args_parse |
2 |
4 |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list |
from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list |
3 |
5 |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
... |
... |
import torch |
11 |
13 |
from torch import nn |
from torch import nn |
12 |
14 |
from models import CompactCNN |
from models import CompactCNN |
13 |
15 |
import os |
import os |
|
16 |
|
from model_util import get_lr |
14 |
17 |
|
|
|
18 |
|
COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM" |
|
19 |
|
PROJECT_NAME = "crowd-counting-framework" |
15 |
20 |
|
|
16 |
21 |
if __name__ == "__main__": |
if __name__ == "__main__": |
|
22 |
|
experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API) |
17 |
23 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
18 |
24 |
print(device) |
print(device) |
19 |
25 |
args = my_args_parse() |
args = my_args_parse() |
20 |
26 |
print(args) |
print(args) |
|
27 |
|
|
|
28 |
|
experiment.set_name(args.task_id) |
|
29 |
|
experiment.set_cmd_args() |
|
30 |
|
|
21 |
31 |
DATA_PATH = args.input |
DATA_PATH = args.input |
22 |
32 |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
TRAIN_PATH = os.path.join(DATA_PATH, "train_data") |
23 |
33 |
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
TEST_PATH = os.path.join(DATA_PATH, "test_data") |
|
... |
... |
if __name__ == "__main__": |
87 |
97 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
88 |
98 |
print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
89 |
99 |
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
90 |
|
|
|
|
100 |
|
experiment.log_metric("epoch", trainer.state.epoch) |
|
101 |
|
experiment.log_metric("train_mae", metrics['mae']) |
|
102 |
|
experiment.log_metric("train_mse", metrics['mse']) |
|
103 |
|
experiment.log_metric("train_loss", metrics['loss']) |
|
104 |
|
experiment.log_metric("lr", get_lr(optimizer)) |
91 |
105 |
|
|
92 |
106 |
@trainer.on(Events.EPOCH_COMPLETED) |
@trainer.on(Events.EPOCH_COMPLETED) |
93 |
107 |
def log_validation_results(trainer): |
def log_validation_results(trainer): |
|
... |
... |
if __name__ == "__main__": |
96 |
110 |
timestamp = get_readable_time() |
timestamp = get_readable_time() |
97 |
111 |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
98 |
112 |
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
.format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
99 |
|
|
|
|
113 |
|
experiment.log_metric("valid_mae", metrics['mae']) |
|
114 |
|
experiment.log_metric("valid_mse", metrics['mse']) |
|
115 |
|
experiment.log_metric("valid_loss", metrics['loss']) |
100 |
116 |
|
|
101 |
117 |
|
|
102 |
118 |
# docs on save and load |
# docs on save and load |