File models/can_adcrowdnet.py added (mode: 100644) (index 0000000..e022339) |
|
1 |
|
import torch.nn as nn |
|
2 |
|
import torch |
|
3 |
|
from torchvision import models |
|
4 |
|
import collections |
|
5 |
|
import torch.nn.functional as F |
|
6 |
|
import os |
|
7 |
|
from .deform_conv_v2 import DeformConv2d |
|
8 |
|
# from dcn.modules.deform_conv import DeformConvPack, ModulatedDeformConvPack |
|
9 |
|
|
|
10 |
|
|
|
11 |
|
class CanAdcrowdNet(nn.Module): |
|
12 |
|
def __init__(self, load_weights=False): |
|
13 |
|
super(CanAdcrowdNet, self).__init__() |
|
14 |
|
self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512] |
|
15 |
|
self.frontend = make_layers(self.frontend_feat) |
|
16 |
|
self.concat_filter_layer = nn.Conv2d(1024, 512, kernel_size=3, padding=2, dilation=2) |
|
17 |
|
|
|
18 |
|
self.deform_conv_1_3 = DeformConv2d(512, 256, kernel_size=3, stride=1, padding=1) |
|
19 |
|
self.deform_conv_1_5 = DeformConv2d(512, 256, kernel_size=5, stride=1, padding=2) |
|
20 |
|
self.deform_conv_1_7 = DeformConv2d(512, 256, kernel_size=7, stride=1, padding=3) |
|
21 |
|
self.concat_filter_layer_1 = nn.Conv2d(256 * 3, 256, kernel_size=3, padding=2, dilation=2) |
|
22 |
|
|
|
23 |
|
self.deform_conv_2_3 = DeformConv2d(256, 128, kernel_size=3, stride=1, padding=1) |
|
24 |
|
self.deform_conv_2_5 = DeformConv2d(256, 128, kernel_size=5, stride=1, padding=2) |
|
25 |
|
self.deform_conv_2_7 = DeformConv2d(256, 128, kernel_size=7, stride=1, padding=3) |
|
26 |
|
self.concat_filter_layer_2 = nn.Conv2d(128 * 3, 128, kernel_size=3, padding=2, dilation=2) |
|
27 |
|
|
|
28 |
|
self.deform_conv_3_3 = DeformConv2d(128, 64, kernel_size=3, stride=1, padding=1) |
|
29 |
|
self.deform_conv_3_5 = DeformConv2d(128, 64, kernel_size=5, stride=1, padding=2) |
|
30 |
|
self.deform_conv_3_7 = DeformConv2d(128, 64, kernel_size=7, stride=1, padding=3) |
|
31 |
|
self.concat_filter_layer_3 = nn.Conv2d(64 * 3, 64, kernel_size=3, padding=2, dilation=2) |
|
32 |
|
|
|
33 |
|
self.output_layer = nn.Conv2d(64, 1, kernel_size=1) |
|
34 |
|
self.conv1_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False) |
|
35 |
|
self.conv1_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False) |
|
36 |
|
self.conv2_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False) |
|
37 |
|
self.conv2_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False) |
|
38 |
|
self.conv3_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False) |
|
39 |
|
self.conv3_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False) |
|
40 |
|
self.conv6_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False) |
|
41 |
|
self.conv6_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False) |
|
42 |
|
if not load_weights: |
|
43 |
|
mod = models.vgg16(pretrained=True) |
|
44 |
|
self._initialize_weights() |
|
45 |
|
fsd = collections.OrderedDict() |
|
46 |
|
for i in range(len(self.frontend.state_dict().items())): |
|
47 |
|
temp_key = list(self.frontend.state_dict().items())[i][0] |
|
48 |
|
fsd[temp_key] = list(mod.state_dict().items())[i][1] |
|
49 |
|
self.frontend.load_state_dict(fsd) |
|
50 |
|
|
|
51 |
|
def forward(self, x): |
|
52 |
|
fv = self.frontend(x) |
|
53 |
|
# S=1 |
|
54 |
|
ave1 = nn.functional.adaptive_avg_pool2d(fv, (1, 1)) |
|
55 |
|
ave1 = self.conv1_1(ave1) |
|
56 |
|
s1 = nn.functional.upsample(ave1, size=(fv.shape[2], fv.shape[3]), mode='bilinear') |
|
57 |
|
c1 = s1 - fv |
|
58 |
|
w1 = self.conv1_2(c1) |
|
59 |
|
w1 = nn.functional.sigmoid(w1) |
|
60 |
|
# S=2 |
|
61 |
|
ave2 = nn.functional.adaptive_avg_pool2d(fv, (2, 2)) |
|
62 |
|
ave2 = self.conv2_1(ave2) |
|
63 |
|
s2 = nn.functional.upsample(ave2, size=(fv.shape[2], fv.shape[3]), mode='bilinear') |
|
64 |
|
c2 = s2 - fv |
|
65 |
|
w2 = self.conv2_2(c2) |
|
66 |
|
w2 = nn.functional.sigmoid(w2) |
|
67 |
|
# S=3 |
|
68 |
|
ave3 = nn.functional.adaptive_avg_pool2d(fv, (3, 3)) |
|
69 |
|
ave3 = self.conv3_1(ave3) |
|
70 |
|
s3 = nn.functional.upsample(ave3, size=(fv.shape[2], fv.shape[3]), mode='bilinear') |
|
71 |
|
c3 = s3 - fv |
|
72 |
|
w3 = self.conv3_2(c3) |
|
73 |
|
w3 = nn.functional.sigmoid(w3) |
|
74 |
|
# S=6 |
|
75 |
|
ave6 = nn.functional.adaptive_avg_pool2d(fv, (6, 6)) |
|
76 |
|
ave6 = self.conv6_1(ave6) |
|
77 |
|
s6 = nn.functional.upsample(ave6, size=(fv.shape[2], fv.shape[3]), mode='bilinear') |
|
78 |
|
c6 = s6 - fv |
|
79 |
|
w6 = self.conv6_2(c6) |
|
80 |
|
w6 = nn.functional.sigmoid(w6) |
|
81 |
|
|
|
82 |
|
fi = (w1 * s1 + w2 * s2 + w3 * s3 + w6 * s6) / (w1 + w2 + w3 + w6 + 0.000000000001) |
|
83 |
|
x = torch.cat((fv, fi), 1) |
|
84 |
|
x = F.relu(self.concat_filter_layer(x)) |
|
85 |
|
|
|
86 |
|
x3 = self.deform_conv_1_3(x) |
|
87 |
|
x5 = self.deform_conv_1_5(x) |
|
88 |
|
x7 = self.deform_conv_1_7(x) |
|
89 |
|
x = torch.cat((x3, x5, x7), 1) |
|
90 |
|
x = F.relu(self.concat_filter_layer_1(x)) |
|
91 |
|
|
|
92 |
|
x3 = self.deform_conv_2_3(x) |
|
93 |
|
x5 = self.deform_conv_2_5(x) |
|
94 |
|
x7 = self.deform_conv_2_7(x) |
|
95 |
|
x = torch.cat((x3, x5, x7), 1) |
|
96 |
|
x = F.relu(self.concat_filter_layer_2(x)) |
|
97 |
|
|
|
98 |
|
x3 = self.deform_conv_3_3(x) |
|
99 |
|
x5 = self.deform_conv_3_5(x) |
|
100 |
|
x7 = self.deform_conv_3_7(x) |
|
101 |
|
x = torch.cat((x3, x5, x7), 1) |
|
102 |
|
x = F.relu(self.concat_filter_layer_3(x)) |
|
103 |
|
|
|
104 |
|
x = self.output_layer(x) |
|
105 |
|
x = nn.functional.upsample(x, scale_factor=8, mode='bilinear') / 64.0 |
|
106 |
|
return x |
|
107 |
|
|
|
108 |
|
def _initialize_weights(self): |
|
109 |
|
for m in self.modules(): |
|
110 |
|
if isinstance(m, nn.Conv2d): |
|
111 |
|
nn.init.normal_(m.weight, std=0.01) |
|
112 |
|
if m.bias is not None: |
|
113 |
|
nn.init.constant_(m.bias, 0) |
|
114 |
|
elif isinstance(m, nn.BatchNorm2d): |
|
115 |
|
nn.init.constant_(m.weight, 1) |
|
116 |
|
nn.init.constant_(m.bias, 0) |
|
117 |
|
|
|
118 |
|
|
|
119 |
|
def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False): |
|
120 |
|
if dilation: |
|
121 |
|
d_rate = 2 |
|
122 |
|
else: |
|
123 |
|
d_rate = 1 |
|
124 |
|
layers = [] |
|
125 |
|
for v in cfg: |
|
126 |
|
if v == 'M': |
|
127 |
|
layers += [nn.MaxPool2d(kernel_size=2, stride=2)] |
|
128 |
|
else: |
|
129 |
|
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate) |
|
130 |
|
if batch_norm: |
|
131 |
|
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] |
|
132 |
|
else: |
|
133 |
|
layers += [conv2d, nn.ReLU(inplace=True)] |
|
134 |
|
in_channels = v |
|
135 |
|
return nn.Sequential(*layers) |
File models/deform_conv_v2.py added (mode: 100644) (index 0000000..02e98e8) |
|
1 |
|
import torch |
|
2 |
|
from torch import nn |
|
3 |
|
|
|
4 |
|
|
|
5 |
|
class DeformConv2d(nn.Module): |
|
6 |
|
def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False): |
|
7 |
|
""" |
|
8 |
|
Args: |
|
9 |
|
modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2). |
|
10 |
|
""" |
|
11 |
|
super(DeformConv2d, self).__init__() |
|
12 |
|
self.kernel_size = kernel_size |
|
13 |
|
self.padding = padding |
|
14 |
|
self.stride = stride |
|
15 |
|
self.zero_padding = nn.ZeroPad2d(padding) |
|
16 |
|
self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) |
|
17 |
|
|
|
18 |
|
self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) |
|
19 |
|
nn.init.constant_(self.p_conv.weight, 0) |
|
20 |
|
self.p_conv.register_backward_hook(self._set_lr) |
|
21 |
|
|
|
22 |
|
self.modulation = modulation |
|
23 |
|
if modulation: |
|
24 |
|
self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) |
|
25 |
|
nn.init.constant_(self.m_conv.weight, 0) |
|
26 |
|
self.m_conv.register_backward_hook(self._set_lr) |
|
27 |
|
|
|
28 |
|
@staticmethod |
|
29 |
|
def _set_lr(module, grad_input, grad_output): |
|
30 |
|
grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input))) |
|
31 |
|
grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output))) |
|
32 |
|
|
|
33 |
|
def forward(self, x): |
|
34 |
|
offset = self.p_conv(x) |
|
35 |
|
if self.modulation: |
|
36 |
|
m = torch.sigmoid(self.m_conv(x)) |
|
37 |
|
|
|
38 |
|
dtype = offset.data.type() |
|
39 |
|
ks = self.kernel_size |
|
40 |
|
N = offset.size(1) // 2 |
|
41 |
|
|
|
42 |
|
if self.padding: |
|
43 |
|
x = self.zero_padding(x) |
|
44 |
|
|
|
45 |
|
# (b, 2N, h, w) |
|
46 |
|
p = self._get_p(offset, dtype) |
|
47 |
|
|
|
48 |
|
# (b, h, w, 2N) |
|
49 |
|
p = p.contiguous().permute(0, 2, 3, 1) |
|
50 |
|
q_lt = p.detach().floor() |
|
51 |
|
q_rb = q_lt + 1 |
|
52 |
|
|
|
53 |
|
q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long() |
|
54 |
|
q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long() |
|
55 |
|
q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) |
|
56 |
|
q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) |
|
57 |
|
|
|
58 |
|
# clip p |
|
59 |
|
p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1) |
|
60 |
|
|
|
61 |
|
# bilinear kernel (b, h, w, N) |
|
62 |
|
g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) |
|
63 |
|
g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) |
|
64 |
|
g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) |
|
65 |
|
g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) |
|
66 |
|
|
|
67 |
|
# (b, c, h, w, N) |
|
68 |
|
x_q_lt = self._get_x_q(x, q_lt, N) |
|
69 |
|
x_q_rb = self._get_x_q(x, q_rb, N) |
|
70 |
|
x_q_lb = self._get_x_q(x, q_lb, N) |
|
71 |
|
x_q_rt = self._get_x_q(x, q_rt, N) |
|
72 |
|
|
|
73 |
|
# (b, c, h, w, N) |
|
74 |
|
x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ |
|
75 |
|
g_rb.unsqueeze(dim=1) * x_q_rb + \ |
|
76 |
|
g_lb.unsqueeze(dim=1) * x_q_lb + \ |
|
77 |
|
g_rt.unsqueeze(dim=1) * x_q_rt |
|
78 |
|
|
|
79 |
|
# modulation |
|
80 |
|
if self.modulation: |
|
81 |
|
m = m.contiguous().permute(0, 2, 3, 1) |
|
82 |
|
m = m.unsqueeze(dim=1) |
|
83 |
|
m = torch.cat([m for _ in range(x_offset.size(1))], dim=1) |
|
84 |
|
x_offset *= m |
|
85 |
|
|
|
86 |
|
x_offset = self._reshape_x_offset(x_offset, ks) |
|
87 |
|
out = self.conv(x_offset) |
|
88 |
|
|
|
89 |
|
return out |
|
90 |
|
|
|
91 |
|
def _get_p_n(self, N, dtype): |
|
92 |
|
p_n_x, p_n_y = torch.meshgrid( |
|
93 |
|
torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), |
|
94 |
|
torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) |
|
95 |
|
# (2N, 1) |
|
96 |
|
p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) |
|
97 |
|
p_n = p_n.view(1, 2*N, 1, 1).type(dtype) |
|
98 |
|
|
|
99 |
|
return p_n |
|
100 |
|
|
|
101 |
|
def _get_p_0(self, h, w, N, dtype): |
|
102 |
|
p_0_x, p_0_y = torch.meshgrid( |
|
103 |
|
torch.arange(1, h*self.stride+1, self.stride), |
|
104 |
|
torch.arange(1, w*self.stride+1, self.stride)) |
|
105 |
|
p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) |
|
106 |
|
p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) |
|
107 |
|
p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) |
|
108 |
|
|
|
109 |
|
return p_0 |
|
110 |
|
|
|
111 |
|
def _get_p(self, offset, dtype): |
|
112 |
|
N, h, w = offset.size(1)//2, offset.size(2), offset.size(3) |
|
113 |
|
|
|
114 |
|
# (1, 2N, 1, 1) |
|
115 |
|
p_n = self._get_p_n(N, dtype) |
|
116 |
|
# (1, 2N, h, w) |
|
117 |
|
p_0 = self._get_p_0(h, w, N, dtype) |
|
118 |
|
p = p_0 + p_n + offset |
|
119 |
|
return p |
|
120 |
|
|
|
121 |
|
def _get_x_q(self, x, q, N): |
|
122 |
|
b, h, w, _ = q.size() |
|
123 |
|
padded_w = x.size(3) |
|
124 |
|
c = x.size(1) |
|
125 |
|
# (b, c, h*w) |
|
126 |
|
x = x.contiguous().view(b, c, -1) |
|
127 |
|
|
|
128 |
|
# (b, h, w, N) |
|
129 |
|
index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y |
|
130 |
|
# (b, c, h*w*N) |
|
131 |
|
index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) |
|
132 |
|
|
|
133 |
|
x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) |
|
134 |
|
|
|
135 |
|
return x_offset |
|
136 |
|
|
|
137 |
|
@staticmethod |
|
138 |
|
def _reshape_x_offset(x_offset, ks): |
|
139 |
|
b, c, h, w, N = x_offset.size() |
|
140 |
|
x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1) |
|
141 |
|
x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks) |
|
142 |
|
|
|
143 |
|
return x_offset |