File models/meow_experiment/ccnn_head.py added (mode: 100644) (index 0000000..1d2e772) |
|
1 |
|
import torch.nn as nn |
|
2 |
|
import torch |
|
3 |
|
import collections |
|
4 |
|
import torch.nn.functional as F |
|
5 |
|
|
|
6 |
|
|
|
7 |
|
class H1(nn.Module): |
|
8 |
|
""" |
|
9 |
|
A REAL-TIME DEEP NETWORK FOR CROWD COUNTING |
|
10 |
|
https://arxiv.org/pdf/2002.06515.pdf |
|
11 |
|
the improve version |
|
12 |
|
|
|
13 |
|
we change 5x5 7x7 9x9 with 3x3 |
|
14 |
|
Keep the tail |
|
15 |
|
""" |
|
16 |
|
def __init__(self, load_weights=False): |
|
17 |
|
super(H1, self).__init__() |
|
18 |
|
self.model_note = "We replace 5x5 7x7 9x9 with 3x3, no batchnorm yet, keep tail, no dilated" |
|
19 |
|
# self.red_cnn = nn.Conv2d(3, 10, 9, padding=4) |
|
20 |
|
# self.green_cnn = nn.Conv2d(3, 14, 7, padding=3) |
|
21 |
|
# self.blue_cnn = nn.Conv2d(3, 16, 5, padding=2) |
|
22 |
|
|
|
23 |
|
# ideal from crowd counting using DMCNN |
|
24 |
|
self.front_cnn_1 = nn.Conv2d(3, 10, 3, padding=1) |
|
25 |
|
self.front_cnn_2 = nn.Conv2d(10, 20, 3, padding=1) |
|
26 |
|
self.front_cnn_3 = nn.Conv2d(20, 20, 3, padding=1) |
|
27 |
|
self.front_cnn_4 = nn.Conv2d(20, 20, 3, padding=1) |
|
28 |
|
|
|
29 |
|
self.c0 = nn.Conv2d(60, 40, 3, padding=1) |
|
30 |
|
self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2) |
|
31 |
|
|
|
32 |
|
self.c1 = nn.Conv2d(40, 60, 3, padding=1) |
|
33 |
|
|
|
34 |
|
# ideal from CSRNet |
|
35 |
|
self.c2 = nn.Conv2d(60, 40, 3, padding=1) |
|
36 |
|
self.c3 = nn.Conv2d(40, 20, 3, padding=1) |
|
37 |
|
self.c4 = nn.Conv2d(20, 10, 3, padding=1) |
|
38 |
|
self.output = nn.Conv2d(10, 1, 1) |
|
39 |
|
|
|
40 |
|
def forward(self,x): |
|
41 |
|
#x_red = self.max_pooling(F.relu(self.red_cnn(x), inplace=True)) |
|
42 |
|
#x_green = self.max_pooling(F.relu(self.green_cnn(x), inplace=True)) |
|
43 |
|
#x_blue = self.max_pooling(F.relu(self.blue_cnn(x), inplace=True)) |
|
44 |
|
|
|
45 |
|
x_red = F.relu(self.front_cnn_1(x), inplace=True) |
|
46 |
|
x_red = F.relu(self.front_cnn_2(x_red), inplace=True) |
|
47 |
|
x_red = F.relu(self.front_cnn_3(x_red), inplace=True) |
|
48 |
|
x_red = F.relu(self.front_cnn_4(x_red), inplace=True) |
|
49 |
|
x_red = self.max_pooling(x_red) |
|
50 |
|
|
|
51 |
|
x_green = F.relu(self.front_cnn_1(x), inplace=True) |
|
52 |
|
x_green = F.relu(self.front_cnn_2(x_green), inplace=True) |
|
53 |
|
x_green = F.relu(self.front_cnn_3(x_green), inplace=True) |
|
54 |
|
x_green = self.max_pooling(x_green) |
|
55 |
|
|
|
56 |
|
x_blue = F.relu(self.front_cnn_1(x), inplace=True) |
|
57 |
|
x_blue = F.relu(self.front_cnn_2(x_blue), inplace=True) |
|
58 |
|
x_blue = self.max_pooling(x_blue) |
|
59 |
|
|
|
60 |
|
# x = self.max_pooling(x) |
|
61 |
|
x = torch.cat((x_red, x_green, x_blue), 1) |
|
62 |
|
x = F.relu(self.c0(x), inplace=True) |
|
63 |
|
|
|
64 |
|
x = F.relu(self.c1(x), inplace=True) |
|
65 |
|
|
|
66 |
|
x = F.relu(self.c2(x), inplace=True) |
|
67 |
|
x = self.max_pooling(x) |
|
68 |
|
|
|
69 |
|
x = F.relu(self.c3(x), inplace=True) |
|
70 |
|
x = self.max_pooling(x) |
|
71 |
|
|
|
72 |
|
x = F.relu(self.c4(x), inplace=True) |
|
73 |
|
|
|
74 |
|
x = self.output(x) |
|
75 |
|
return x |