File explore_model_summary.py changed (mode: 100644) (index 9d5c32d..651839b) |
1 |
|
from models import CompactCNN, AttnCanAdcrowdNetSimpleV3, CompactDilatedCNN, DefDilatedCCNN |
|
|
1 |
|
from models import CompactCNN, AttnCanAdcrowdNetSimpleV3, CompactDilatedCNN, DefDilatedCCNN, CompactCNNV2 |
2 |
2 |
from torchsummary import summary |
from torchsummary import summary |
3 |
3 |
|
|
4 |
4 |
def very_simple_param_count(model): |
def very_simple_param_count(model): |
|
... |
... |
if __name__ == "__main__": |
15 |
15 |
# print(summary(ccnn, (3, 512, 512))) |
# print(summary(ccnn, (3, 512, 512))) |
16 |
16 |
# print("simple count", very_simple_param_count(ccnn)) |
# print("simple count", very_simple_param_count(ccnn)) |
17 |
17 |
# print("===========================================================================") |
# print("===========================================================================") |
18 |
|
# print("dilate ccnn") |
|
19 |
|
# dcnn1 = CompactDilatedCNN() |
|
20 |
|
# print(summary(dcnn1, (3, 512, 512))) |
|
21 |
|
# print("=============================================================================") |
|
22 |
|
print("dilate ccnn") |
|
|
18 |
|
print("ccnn") |
|
19 |
|
dcnn1 = CompactCNN() |
|
20 |
|
print(summary(dcnn1, (3, 512, 512), device="cpu")) |
|
21 |
|
print("=============================================================================") |
|
22 |
|
print("ccnn v2") |
|
23 |
|
dcnn2 = CompactCNNV2() |
|
24 |
|
print(summary(dcnn2, (3, 512, 512), device="cpu")) |
|
25 |
|
print("=============================================================================") |
|
26 |
|
print("DefDilatedCCNN") |
23 |
27 |
dcnn2 = DefDilatedCCNN() |
dcnn2 = DefDilatedCCNN() |
24 |
|
print(summary(dcnn2, (3, 512, 512))) |
|
|
28 |
|
print(summary(dcnn2, (3, 512, 512), device="cpu")) |
25 |
29 |
print("=============================================================================") |
print("=============================================================================") |
26 |
30 |
# print("simple_v3") |
# print("simple_v3") |
27 |
31 |
# simplev3 = AttnCanAdcrowdNetSimpleV3() |
# simplev3 = AttnCanAdcrowdNetSimpleV3() |
File models/__init__.py changed (mode: 100644) (index 15d156f..8f69b2b) |
... |
... |
from .can_adcrowdnet import CanAdcrowdNet |
6 |
6 |
from .attn_can_adcrowdnet import AttnCanAdcrowdNet |
from .attn_can_adcrowdnet import AttnCanAdcrowdNet |
7 |
7 |
from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg |
from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg |
8 |
8 |
from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3, AttnCanAdcrowdNetSimpleV4 |
from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3, AttnCanAdcrowdNetSimpleV4 |
9 |
|
from .compact_cnn import CompactCNN, CompactDilatedCNN, DefDilatedCCNN, DilatedCCNNv2 |
|
|
9 |
|
from .compact_cnn import CompactCNN, CompactCNNV2, CompactDilatedCNN, DefDilatedCCNN, DilatedCCNNv2 |
File models/compact_cnn.py changed (mode: 100644) (index 584995a..946ef2f) |
... |
... |
class CompactCNN(nn.Module): |
10 |
10 |
""" |
""" |
11 |
11 |
A REAL-TIME DEEP NETWORK FOR CROWD COUNTING |
A REAL-TIME DEEP NETWORK FOR CROWD COUNTING |
12 |
12 |
https://arxiv.org/pdf/2002.06515.pdf |
https://arxiv.org/pdf/2002.06515.pdf |
|
13 |
|
:deprecated: I think implement incorrectly, please use CompactCNNV2 |
13 |
14 |
""" |
""" |
14 |
15 |
def __init__(self, load_weights=False): |
def __init__(self, load_weights=False): |
15 |
16 |
super(CompactCNN, self).__init__() |
super(CompactCNN, self).__init__() |
|
... |
... |
class CompactCNN(nn.Module): |
44 |
45 |
x = self.output(x) |
x = self.output(x) |
45 |
46 |
return x |
return x |
46 |
47 |
|
|
|
48 |
|
|
|
49 |
|
class CompactCNNV2(nn.Module): |
|
50 |
|
""" |
|
51 |
|
A REAL-TIME DEEP NETWORK FOR CROWD COUNTING |
|
52 |
|
https://arxiv.org/pdf/2002.06515.pdf |
|
53 |
|
""" |
|
54 |
|
def __init__(self, load_weights=False): |
|
55 |
|
super(CompactCNNV2, self).__init__() |
|
56 |
|
self.red_cnn = nn.Conv2d(3, 10, 9, padding=4) |
|
57 |
|
self.green_cnn = nn.Conv2d(3, 14, 7, padding=3) |
|
58 |
|
self.blue_cnn = nn.Conv2d(3, 16, 5, padding=2) |
|
59 |
|
self.c0 = nn.Conv2d(40, 40, 3, padding=1) |
|
60 |
|
|
|
61 |
|
self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2) |
|
62 |
|
|
|
63 |
|
self.c1 = nn.Conv2d(40, 60, 3, padding=1) |
|
64 |
|
self.c2 = nn.Conv2d(60, 40, 3, padding=1) |
|
65 |
|
self.c3 = nn.Conv2d(40, 20, 3, padding=1) |
|
66 |
|
self.c4 = nn.Conv2d(20, 10, 3, padding=1) |
|
67 |
|
self.output = nn.Conv2d(10, 1, 1) |
|
68 |
|
|
|
69 |
|
def forward(self,x): |
|
70 |
|
x_red = self.max_pooling(F.relu(self.red_cnn(x), inplace=True)) |
|
71 |
|
x_green = self.max_pooling(F.relu(self.green_cnn(x), inplace=True)) |
|
72 |
|
x_blue = self.max_pooling(F.relu(self.blue_cnn(x), inplace=True)) |
|
73 |
|
|
|
74 |
|
x = torch.cat((x_red, x_green, x_blue), 1) |
|
75 |
|
x = F.relu(self.c0(x), inplace=True) |
|
76 |
|
|
|
77 |
|
x = F.relu(self.c1(x), inplace=True) |
|
78 |
|
|
|
79 |
|
x = F.relu(self.c2(x), inplace=True) |
|
80 |
|
x = self.max_pooling(x) |
|
81 |
|
|
|
82 |
|
x = F.relu(self.c3(x), inplace=True) |
|
83 |
|
x = self.max_pooling(x) |
|
84 |
|
|
|
85 |
|
x = F.relu(self.c4(x), inplace=True) |
|
86 |
|
|
|
87 |
|
x = self.output(x) |
|
88 |
|
return x |
|
89 |
|
|
47 |
90 |
class CompactDilatedCNN(nn.Module): |
class CompactDilatedCNN(nn.Module): |
48 |
91 |
""" |
""" |
49 |
92 |
|
|