File main_pacnn.py changed (mode: 100644) (index fc87076..4eee4ac) |
... |
... |
if __name__ == "__main__": |
34 |
34 |
MODEL_SAVE_NAME = args.task_id |
MODEL_SAVE_NAME = args.task_id |
35 |
35 |
MODEL_SAVE_INTERVAL = 5 |
MODEL_SAVE_INTERVAL = 5 |
36 |
36 |
DATA_PATH = args.input |
DATA_PATH = args.input |
37 |
|
DATASET_NAME = "shanghaitech" |
|
38 |
37 |
TOTAL_EPOCH = args.epochs |
TOTAL_EPOCH = args.epochs |
39 |
38 |
PACNN_PERSPECTIVE_AWARE_MODEL = args.PACNN_PERSPECTIVE_AWARE_MODEL |
PACNN_PERSPECTIVE_AWARE_MODEL = args.PACNN_PERSPECTIVE_AWARE_MODEL |
40 |
39 |
PACNN_MUTILPLE_SCALE_LOSS = args.PACNN_MUTILPLE_SCALE_LOSS |
PACNN_MUTILPLE_SCALE_LOSS = args.PACNN_MUTILPLE_SCALE_LOSS |
|
40 |
|
DATASET_NAME = "shanghaitech_pacnn" |
|
41 |
|
if PACNN_PERSPECTIVE_AWARE_MODEL: |
|
42 |
|
DATASET_NAME = "shanghaitech_pacnn_with_perspective" |
41 |
43 |
|
|
42 |
44 |
experiment.set_name(args.task_id) |
experiment.set_name(args.task_id) |
43 |
45 |
experiment.log_parameter("DATA_PATH", DATA_PATH) |
experiment.log_parameter("DATA_PATH", DATA_PATH) |
|
... |
... |
if __name__ == "__main__": |
122 |
124 |
optimizer.zero_grad() |
optimizer.zero_grad() |
123 |
125 |
|
|
124 |
126 |
# load data |
# load data |
125 |
|
d1_label, d2_label, d3_label = label |
|
|
127 |
|
if PACNN_PERSPECTIVE_AWARE_MODEL: |
|
128 |
|
d1_label, d2_label, d3_label, perspective_s, perspective_p = label |
|
129 |
|
perspective_s = perspective_s.to(device).unsqueeze(0) |
|
130 |
|
perspective_p = perspective_p.to(device).unsqueeze(0) |
|
131 |
|
else: |
|
132 |
|
d1_label, d2_label, d3_label = label |
126 |
133 |
d1_label = d1_label.to(device).unsqueeze(0) |
d1_label = d1_label.to(device).unsqueeze(0) |
127 |
134 |
d2_label = d2_label.to(device).unsqueeze(0) |
d2_label = d2_label.to(device).unsqueeze(0) |
128 |
135 |
d3_label = d3_label.to(device).unsqueeze(0) |
d3_label = d3_label.to(device).unsqueeze(0) |
|
... |
... |
if __name__ == "__main__": |
142 |
149 |
|
|
143 |
150 |
if PACNN_PERSPECTIVE_AWARE_MODEL: |
if PACNN_PERSPECTIVE_AWARE_MODEL: |
144 |
151 |
# TODO: loss for perspective map here |
# TODO: loss for perspective map here |
145 |
|
pass |
|
|
152 |
|
loss_p = criterion_mse(p, perspective_p) + criterion_ssim(p, perspective_p) |
|
153 |
|
loss += loss_p |
|
154 |
|
if PACNN_MUTILPLE_SCALE_LOSS: |
|
155 |
|
loss_p_s = criterion_mse(p_s, perspective_s) + criterion_ssim(p_s, perspective_s) |
|
156 |
|
loss += loss_p_s |
|
157 |
|
|
146 |
158 |
# what is this, loss_d count 2 ? |
# what is this, loss_d count 2 ? |
147 |
159 |
## loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
## loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
148 |
160 |
## loss += loss_d |
## loss += loss_d |
File visualize_data_loader.py changed (mode: 100644) (index f4aa432..b6d376e) |
... |
... |
def visualize_ucf_cc_50_pacnn(): |
48 |
48 |
print("count2 ", label[1].numpy()[0].sum()) |
print("count2 ", label[1].numpy()[0].sum()) |
49 |
49 |
print("count3 ", label[2].numpy()[0].sum()) |
print("count3 ", label[2].numpy()[0].sum()) |
50 |
50 |
|
|
|
51 |
|
|
51 |
52 |
def visualize_shanghaitech_pacnn_with_perspective(): |
def visualize_shanghaitech_pacnn_with_perspective(): |
52 |
53 |
HARD_CODE = HardCodeVariable() |
HARD_CODE = HardCodeVariable() |
53 |
54 |
saved_folder = "visualize/test_dataloader" |
saved_folder = "visualize/test_dataloader" |