File main_pacnn.py changed (mode: 100644) (index fc19585..fc87076) |
... |
... |
if __name__ == "__main__": |
37 |
37 |
DATASET_NAME = "shanghaitech" |
DATASET_NAME = "shanghaitech" |
38 |
38 |
TOTAL_EPOCH = args.epochs |
TOTAL_EPOCH = args.epochs |
39 |
39 |
PACNN_PERSPECTIVE_AWARE_MODEL = args.PACNN_PERSPECTIVE_AWARE_MODEL |
PACNN_PERSPECTIVE_AWARE_MODEL = args.PACNN_PERSPECTIVE_AWARE_MODEL |
|
40 |
|
PACNN_MUTILPLE_SCALE_LOSS = args.PACNN_MUTILPLE_SCALE_LOSS |
40 |
41 |
|
|
41 |
42 |
experiment.set_name(args.task_id) |
experiment.set_name(args.task_id) |
42 |
43 |
experiment.log_parameter("DATA_PATH", DATA_PATH) |
experiment.log_parameter("DATA_PATH", DATA_PATH) |
43 |
44 |
experiment.log_parameter("PACNN_PERSPECTIVE_AWARE_MODEL", PACNN_PERSPECTIVE_AWARE_MODEL) |
experiment.log_parameter("PACNN_PERSPECTIVE_AWARE_MODEL", PACNN_PERSPECTIVE_AWARE_MODEL) |
|
45 |
|
experiment.log_parameter("PACNN_MUTILPLE_SCALE_LOSS", PACNN_MUTILPLE_SCALE_LOSS) |
44 |
46 |
experiment.log_parameter("train", "train without p") |
experiment.log_parameter("train", "train without p") |
45 |
47 |
experiment.log_parameter("momentum", args.momentum) |
experiment.log_parameter("momentum", args.momentum) |
46 |
48 |
experiment.log_parameter("lr", args.lr) |
experiment.log_parameter("lr", args.lr) |
|
... |
... |
if __name__ == "__main__": |
128 |
130 |
# forward pass |
# forward pass |
129 |
131 |
|
|
130 |
132 |
d1, d2, d3, p_s, p, d = net(train_img.to(device)) |
d1, d2, d3, p_s, p, d = net(train_img.to(device)) |
131 |
|
loss_1 = criterion_mse(d1, d1_label) + criterion_ssim(d1, d1_label) |
|
132 |
|
loss_2 = criterion_mse(d2, d2_label) + criterion_ssim(d2, d2_label) |
|
133 |
|
loss_3 = criterion_mse(d3, d3_label) + criterion_ssim(d3, d3_label) |
|
134 |
133 |
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
135 |
|
loss = loss_d + loss_1 + loss_2 + loss_3 |
|
|
134 |
|
loss = loss_d |
|
135 |
|
|
|
136 |
|
if PACNN_MUTILPLE_SCALE_LOSS: |
|
137 |
|
loss_1 = criterion_mse(d1, d1_label) + criterion_ssim(d1, d1_label) |
|
138 |
|
loss_2 = criterion_mse(d2, d2_label) + criterion_ssim(d2, d2_label) |
|
139 |
|
loss_3 = criterion_mse(d3, d3_label) + criterion_ssim(d3, d3_label) |
|
140 |
|
loss += loss_1 + loss_2 + loss_3 |
|
141 |
|
|
136 |
142 |
|
|
137 |
143 |
if PACNN_PERSPECTIVE_AWARE_MODEL: |
if PACNN_PERSPECTIVE_AWARE_MODEL: |
138 |
144 |
# TODO: loss for perspective map here |
# TODO: loss for perspective map here |
139 |
145 |
pass |
pass |
140 |
|
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
|
141 |
|
loss += loss_d |
|
|
146 |
|
# what is this, loss_d count 2 ? |
|
147 |
|
## loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
|
148 |
|
## loss += loss_d |
142 |
149 |
|
|
143 |
150 |
# with amp.scale_loss(loss, optimizer) as scaled_loss: |
# with amp.scale_loss(loss, optimizer) as scaled_loss: |
144 |
151 |
# scaled_loss.backward() |
# scaled_loss.backward() |
File train_script/train_pacnn_shanghaitechA.sh changed (mode: 100644) (index 72ae080..9f2955d) |
22 |
22 |
#--task_id train_state1_attemp4 |
#--task_id train_state1_attemp4 |
23 |
23 |
|
|
24 |
24 |
|
|
25 |
|
python main_pacnn.py \ |
|
26 |
|
--input data/ShanghaiTech/part_A \ |
|
27 |
|
--load_model saved_model/train_state1_attemp4_35_checkpoint.pth.tar \ |
|
28 |
|
--epochs 151 \ |
|
29 |
|
--lr 1e-8 \ |
|
30 |
|
--task_id train_state1_attemp5 |
|
|
25 |
|
#python main_pacnn.py \ |
|
26 |
|
#--input data/ShanghaiTech/part_A \ |
|
27 |
|
#--load_model saved_model/train_state1_attemp4_35_checkpoint.pth.tar \ |
|
28 |
|
#--epochs 151 \ |
|
29 |
|
#--lr 1e-8 \ |
|
30 |
|
#--task_id train_state1_attemp5 |
|
31 |
|
|
|
32 |
|
################3 |
|
33 |
|
|
|
34 |
|
## TODO: train this |
|
35 |
|
#python main_pacnn.py \ |
|
36 |
|
#--input data/ShanghaiTech/part_A \ |
|
37 |
|
#--load_model saved_model/train_state1_attemp5_40_checkpoint.pth.tar \ |
|
38 |
|
#--epochs 300 \ |
|
39 |
|
#--lr 1e-8 \ |
|
40 |
|
#--task_id train_state1_attemp6 |
|
41 |
|
|
|
42 |
|
|
|
43 |
|
#python main_pacnn.py \ |
|
44 |
|
#--input data/ShanghaiTech/part_A \ |
|
45 |
|
#--load_model saved_model/train_state1_attemp6_120_checkpoint.pth.tar \ |
|
46 |
|
#--epochs 300 \ |
|
47 |
|
#--lr 1e-9 \ |
|
48 |
|
#--task_id train_state1_attemp7 |