File args_util.py changed (mode: 100644) (index e7f4e6d..c9c322c) |
... |
... |
def real_args_parse(): |
67 |
67 |
parser.add_argument('--model', action="store", default="pacnn") |
parser.add_argument('--model', action="store", default="pacnn") |
68 |
68 |
|
|
69 |
69 |
# args with default value |
# args with default value |
|
70 |
|
parser.add_argument('--load_model', action="store", default="", type=str) |
70 |
71 |
parser.add_argument('--lr', action="store", default=1e-8, type=float) |
parser.add_argument('--lr', action="store", default=1e-8, type=float) |
71 |
|
parser.add_argument('--momentum', action="store", default=0.95, type=float) |
|
|
72 |
|
parser.add_argument('--momentum', action="store", default=0.9, type=float) |
72 |
73 |
parser.add_argument('--decay', action="store", default=5*1e-3, type=float) |
parser.add_argument('--decay', action="store", default=5*1e-3, type=float) |
73 |
74 |
parser.add_argument('--epochs', action="store", default=1, type=int) |
parser.add_argument('--epochs', action="store", default=1, type=int) |
74 |
75 |
|
|
File main_pacnn.py changed (mode: 100644) (index 9cfb5cc..24d3d1c) |
... |
... |
if __name__ == "__main__": |
32 |
32 |
|
|
33 |
33 |
|
|
34 |
34 |
MODEL_SAVE_NAME = args.task_id |
MODEL_SAVE_NAME = args.task_id |
35 |
|
MODEL_SAVE_INTERVAL = 10 |
|
|
35 |
|
MODEL_SAVE_INTERVAL = 5 |
36 |
36 |
DATA_PATH = args.input |
DATA_PATH = args.input |
37 |
37 |
DATASET_NAME = "shanghaitech" |
DATASET_NAME = "shanghaitech" |
38 |
38 |
TOTAL_EPOCH = args.epochs |
TOTAL_EPOCH = args.epochs |
|
... |
... |
if __name__ == "__main__": |
42 |
42 |
experiment.log_parameter("DATA_PATH", DATA_PATH) |
experiment.log_parameter("DATA_PATH", DATA_PATH) |
43 |
43 |
experiment.log_parameter("PACNN_PERSPECTIVE_AWARE_MODEL", PACNN_PERSPECTIVE_AWARE_MODEL) |
experiment.log_parameter("PACNN_PERSPECTIVE_AWARE_MODEL", PACNN_PERSPECTIVE_AWARE_MODEL) |
44 |
44 |
experiment.log_parameter("train", "train without p") |
experiment.log_parameter("train", "train without p") |
|
45 |
|
experiment.log_parameter("momentum", args.momentum) |
|
46 |
|
experiment.log_parameter("lr", args.lr) |
45 |
47 |
|
|
46 |
48 |
# create list |
# create list |
47 |
49 |
if DATASET_NAME is "shanghaitech": |
if DATASET_NAME is "shanghaitech": |
|
... |
... |
if __name__ == "__main__": |
92 |
94 |
|
|
93 |
95 |
current_save_model_name = "" |
current_save_model_name = "" |
94 |
96 |
current_epoch = 0 |
current_epoch = 0 |
|
97 |
|
|
|
98 |
|
# load model |
|
99 |
|
load_model = args.load_model |
|
100 |
|
if len(load_model) > 0: |
|
101 |
|
checkpoint = torch.load(load_model) |
|
102 |
|
net.load_state_dict(checkpoint['model']) |
|
103 |
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
104 |
|
current_epoch = checkpoint['e'] |
|
105 |
|
print("load ", load_model, " epoch ", str(current_epoch)) |
|
106 |
|
else: |
|
107 |
|
print("new model") |
|
108 |
|
|
95 |
109 |
while current_epoch < TOTAL_EPOCH: |
while current_epoch < TOTAL_EPOCH: |
96 |
110 |
experiment.log_current_epoch(current_epoch) |
experiment.log_current_epoch(current_epoch) |
97 |
111 |
current_epoch += 1 |
current_epoch += 1 |
File train_script/train_pacnn_shanghaitechA.sh changed (mode: 100644) (index 448ea1a..b8ab1df) |
1 |
1 |
#python /home/tt/project/crowd_counting_framework/main_pacnn.py --input /home/tt/project/crowd_counting_framework/data/ShanghaiTech/part_A |
#python /home/tt/project/crowd_counting_framework/main_pacnn.py --input /home/tt/project/crowd_counting_framework/data/ShanghaiTech/part_A |
2 |
2 |
|
|
|
3 |
|
#python main_pacnn.py \ |
|
4 |
|
#--input data/ShanghaiTech/part_A \ |
|
5 |
|
#--epochs 151 \ |
|
6 |
|
#--task_id train_state1_attemp1 |
|
7 |
|
|
|
8 |
|
#python main_pacnn.py \ |
|
9 |
|
#--input data/ShanghaiTech/part_A \ |
|
10 |
|
#--load_model saved_model/train_state1_attemp1_10_checkpoint.pth.tar \ |
|
11 |
|
#--epochs 151 \ |
|
12 |
|
#--lr 1e-6 \ |
|
13 |
|
#--task_id train_state1_attemp3 |
|
14 |
|
|
|
15 |
|
# trained 30 |
|
16 |
|
|
3 |
17 |
python main_pacnn.py \ |
python main_pacnn.py \ |
4 |
18 |
--input data/ShanghaiTech/part_A \ |
--input data/ShanghaiTech/part_A \ |
5 |
|
--epochs 20 \ |
|
6 |
|
--task_id train_1 |
|
|
19 |
|
--load_model saved_model/train_state1_attemp3_30_checkpoint.pth.tar \ |
|
20 |
|
--epochs 151 \ |
|
21 |
|
--lr 1e-7 \ |
|
22 |
|
--task_id train_state1_attemp4 |