File data_flow.py changed (mode: 100644) (index 7e0162d..924b307) |
... |
... |
def load_data_shanghaitech_pacnn(img_path, train=True): |
98 |
98 |
target1 = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)), |
target1 = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)), |
99 |
99 |
interpolation=cv2.INTER_CUBIC) * 64 |
interpolation=cv2.INTER_CUBIC) * 64 |
100 |
100 |
target2 = cv2.resize(target, (int(target.shape[1] / 16), int(target.shape[0] / 16)), |
target2 = cv2.resize(target, (int(target.shape[1] / 16), int(target.shape[0] / 16)), |
101 |
|
interpolation=cv2.INTER_CUBIC) * 64 #*2 |
|
|
101 |
|
interpolation=cv2.INTER_CUBIC) * 64 *2 |
102 |
102 |
target3 = cv2.resize(target, (int(target.shape[1] / 32), int(target.shape[0] / 32)), |
target3 = cv2.resize(target, (int(target.shape[1] / 32), int(target.shape[0] / 32)), |
103 |
|
interpolation=cv2.INTER_CUBIC) * 64 #*4 |
|
|
103 |
|
interpolation=cv2.INTER_CUBIC) * 64 *4 |
104 |
104 |
|
|
105 |
105 |
return img, (target1, target2, target3) |
return img, (target1, target2, target3) |
106 |
106 |
|
|
File main_pacnn.py changed (mode: 100644) (index 9902b26..449320c) |
... |
... |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCo |
6 |
6 |
import torch |
import torch |
7 |
7 |
from torch import nn |
from torch import nn |
8 |
8 |
import torch.nn.functional as F |
import torch.nn.functional as F |
9 |
|
from models import CSRNet,PACNN |
|
|
9 |
|
from models import CSRNet, PACNN, PACNNWithPerspectiveMap |
10 |
10 |
import os |
import os |
11 |
11 |
import cv2 |
import cv2 |
12 |
12 |
from torchvision import datasets, transforms |
from torchvision import datasets, transforms |
|
... |
... |
if __name__ == "__main__": |
26 |
26 |
print(args) |
print(args) |
27 |
27 |
DATA_PATH = args.input |
DATA_PATH = args.input |
28 |
28 |
DATASET_NAME = "shanghaitech" |
DATASET_NAME = "shanghaitech" |
|
29 |
|
PACNN_PERSPECTIVE_AWARE_MODEL = False |
29 |
30 |
|
|
30 |
31 |
# create list |
# create list |
31 |
32 |
if DATASET_NAME is "shanghaitech": |
if DATASET_NAME is "shanghaitech": |
|
... |
... |
if __name__ == "__main__": |
64 |
65 |
batch_size=1, num_workers=4) |
batch_size=1, num_workers=4) |
65 |
66 |
|
|
66 |
67 |
# create model |
# create model |
67 |
|
net = PACNN().to(device) |
|
|
68 |
|
net = PACNNWithPerspectiveMap(perspective_aware_mode=PACNN_PERSPECTIVE_AWARE_MODEL).to(device) |
68 |
69 |
criterion_mse = nn.MSELoss(size_average=False).to(device) |
criterion_mse = nn.MSELoss(size_average=False).to(device) |
69 |
|
criterion_ssim = pytorch_ssim.SSIM(window_size=11).to(device) |
|
|
70 |
|
criterion_ssim = pytorch_ssim.SSIM(window_size=5).to(device) |
70 |
71 |
|
|
71 |
72 |
optimizer = torch.optim.SGD(net.parameters(), args.lr, |
optimizer = torch.optim.SGD(net.parameters(), args.lr, |
72 |
73 |
momentum=args.momentum, |
momentum=args.momentum, |
|
... |
... |
if __name__ == "__main__": |
84 |
85 |
|
|
85 |
86 |
# load data |
# load data |
86 |
87 |
d1_label, d2_label, d3_label = label |
d1_label, d2_label, d3_label = label |
87 |
|
d1_label = d1_label.to(device) |
|
88 |
|
d2_label = d2_label.to(device) |
|
89 |
|
d3_label = d3_label.to(device) |
|
|
88 |
|
d1_label = d1_label.to(device).unsqueeze(0) |
|
89 |
|
d2_label = d2_label.to(device).unsqueeze(0) |
|
90 |
|
d3_label = d3_label.to(device).unsqueeze(0) |
90 |
91 |
|
|
91 |
92 |
# forward pass |
# forward pass |
92 |
93 |
|
|
93 |
|
d1, d2, d3 = net(train_img.to(device)) |
|
94 |
|
loss_1 = criterion_mse(d1, d1_label) + criterion_ssim(d1.unsqueeze(0), d1_label.unsqueeze(0)) |
|
95 |
|
loss_2 = criterion_mse(d2, d2_label) + criterion_ssim(d2.unsqueeze(0), d2_label.unsqueeze(0)) |
|
96 |
|
loss_3 = criterion_mse(d3, d3_label) + criterion_ssim(d3.unsqueeze(0), d3_label.unsqueeze(0)) |
|
97 |
|
|
|
|
94 |
|
d1, d2, d3, p_s, p, d = net(train_img.to(device)) |
|
95 |
|
loss_1 = criterion_mse(d1, d1_label) + criterion_ssim(d1, d1_label) |
|
96 |
|
loss_2 = criterion_mse(d2, d2_label) + criterion_ssim(d2, d2_label) |
|
97 |
|
loss_3 = criterion_mse(d3, d3_label) + criterion_ssim(d3, d3_label) |
98 |
98 |
loss = loss_1 + loss_2 + loss_3 |
loss = loss_1 + loss_2 + loss_3 |
|
99 |
|
if PACNN_PERSPECTIVE_AWARE_MODEL: |
|
100 |
|
# TODO: loss for perspective map here |
|
101 |
|
pass |
|
102 |
|
loss_d = criterion_mse(d, d1_label) + criterion_ssim(d, d1_label) |
|
103 |
|
loss += loss_d |
99 |
104 |
loss.backward() |
loss.backward() |
100 |
105 |
optimizer.step() |
optimizer.step() |
101 |
106 |
loss_sum += loss.item() |
loss_sum += loss.item() |
File models/pacnn.py changed (mode: 100644) (index 82897d0..9e7203e) |
... |
... |
class PACNNWithPerspectiveMap(nn.Module): |
75 |
75 |
de23 = pespective_w_s * de2 + (1 - pespective_w_s)*(de2 + self.up23(de3)) |
de23 = pespective_w_s * de2 + (1 - pespective_w_s)*(de2 + self.up23(de3)) |
76 |
76 |
de = pespective_w * de1 + (1 - pespective_w)*(de1 + self.up12(de23)) |
de = pespective_w * de1 + (1 - pespective_w)*(de1 + self.up12(de23)) |
77 |
77 |
else: |
else: |
|
78 |
|
#try: |
|
79 |
|
pespective_w_s = None |
|
80 |
|
pespective_w = None |
78 |
81 |
de23 = (de2 + self.up23(de3))/2 |
de23 = (de2 + self.up23(de3))/2 |
79 |
82 |
de = (de1 + self.up12(de23))/2 |
de = (de1 + self.up12(de23))/2 |
80 |
|
return de |
|
|
83 |
|
# except Exception as e: |
|
84 |
|
# print("EXECEPTION ", e) |
|
85 |
|
# print(x.size()) |
|
86 |
|
# print(de2.size(), de3.size()) |
|
87 |
|
return de1, de2, de3, pespective_w_s, pespective_w, de |
81 |
88 |
|
|
82 |
89 |
def count_param(net): |
def count_param(net): |
83 |
90 |
pytorch_total_params = sum(p.numel() for p in net.parameters()) |
pytorch_total_params = sum(p.numel() for p in net.parameters()) |
File models/test_PACNNWithPerspectiveMap.py changed (mode: 100644) (index cd3eab6..af8340f) |
... |
... |
import torch |
4 |
4 |
|
|
5 |
5 |
class TestPACNNWithPerspectiveMap(TestCase): |
class TestPACNNWithPerspectiveMap(TestCase): |
6 |
6 |
|
|
|
7 |
|
def test_debug_avg_schema_pacnn(self): |
|
8 |
|
net = PACNNWithPerspectiveMap() |
|
9 |
|
image = torch.rand(1, 3, 330, 512) |
|
10 |
|
_, _, _, _, _, density_map = net(image) |
|
11 |
|
print(density_map.size()) |
|
12 |
|
|
|
13 |
|
|
7 |
14 |
def test_avg_schema_pacnn(self): |
def test_avg_schema_pacnn(self): |
8 |
15 |
net = PACNNWithPerspectiveMap() |
net = PACNNWithPerspectiveMap() |
9 |
16 |
# image |
# image |
10 |
17 |
# batch size, channel, h, w |
# batch size, channel, h, w |
11 |
|
image = torch.rand(1, 3, 224, 224) |
|
12 |
|
density_map = net(image) |
|
|
18 |
|
image = torch.rand(1, 3, 330, 512) |
|
19 |
|
_, _, _, _, _, density_map = net(image) |
|
20 |
|
print(density_map.size()) |
|
21 |
|
image2 = torch.rand(1, 3, 225, 225) |
|
22 |
|
_, _, _, _, _, density_map2 = net(image2) |
|
23 |
|
print(density_map2.size()) |
|
24 |
|
image3 = torch.rand(1, 3, 226, 226) |
|
25 |
|
_, _, _, _, _, density_map3 = net(image3) |
|
26 |
|
print(density_map3.size()) |
|
27 |
|
|
|
28 |
|
image = torch.rand(1, 3, 227, 227) |
|
29 |
|
_, _, _, _, _, density_map = net(image) |
13 |
30 |
print(density_map.size()) |
print(density_map.size()) |
|
31 |
|
image2 = torch.rand(1, 3, 228, 228) |
|
32 |
|
_, _, _, _, _, density_map2 = net(image2) |
|
33 |
|
print(density_map2.size()) |
|
34 |
|
image3 = torch.rand(1, 3, 229, 229) |
|
35 |
|
_, _, _, _, _, density_map3 = net(image3) |
|
36 |
|
print(density_map3.size()) |
14 |
37 |
|
|
15 |
38 |
def test_perspective_aware_schema_pacnn(self): |
def test_perspective_aware_schema_pacnn(self): |
16 |
39 |
net = PACNNWithPerspectiveMap(perspective_aware_mode=True) |
net = PACNNWithPerspectiveMap(perspective_aware_mode=True) |
17 |
40 |
# image |
# image |
18 |
41 |
# batch size, channel, h, w |
# batch size, channel, h, w |
19 |
42 |
image = torch.rand(1, 3, 224, 224) |
image = torch.rand(1, 3, 224, 224) |
20 |
|
density_map = net(image) |
|
|
43 |
|
_, _, _, _, _, density_map = net(image) |
21 |
44 |
print(density_map.size()) |
print(density_map.size()) |