List of commits:
Subject Hash Author Date (UTC)
new ccnn 44a669c1f918be9d74313f29a5dbbc876c29f2fc Thai Thien 2020-03-09 17:16:49
fix script aa331331b12e5b454d372a550524b30a4bebe706 Thai Thien 2020-03-07 18:32:06
try reproduct ccnn with keepfull and lr 1e-5 814c520cbd1bb2d7fd50d2a8d3579d43da79fe60 Thai Thien 2020-03-07 18:30:42
my simple v4 with addition deform cnn at the end 5392aaf6c14fdd910f52096dbb921bed7470c4f7 Thai Thien 2020-03-07 18:15:22
fix the scheduler 77e6737a040f5aa5745b8a8830f5bec12322b10f Thai Thien 2020-03-07 17:46:02
t4 lr 1e-5 acd41ed30c95f63e01a05a6d9929410637852d9e Thai Thien 2020-03-06 19:41:49
no more lr scheduler 7289adb41de7807258eb8c29e6108fa65f59525a Thai Thien 2020-03-06 19:35:49
reduce learning rate bc8241e5b88b91c18bb7999a8d5d12fc79a5e3f7 Thai Thien 2020-03-06 19:28:27
dilated ccnn 5c5d92bdc0a288dd5d4ec5f1367d8cb928175bbe Thai Thien 2020-03-06 19:04:01
done 9f05e093ec7c10284a4aedf0738f9e61d5ac6fb6 Thai Thien 2020-03-06 18:02:34
with lr scheduler 466c364b60ed22c77319b14ccc9a201614b908bf Thai Thien 2020-03-04 17:57:49
train with learning rate scheduler fcd5a3c8da2dd6763e0d40742edf47b49c95fcfb Thai Thien 2020-03-04 17:55:11
ccnn no padding at output layer 57563fc07f656c63f807de4d80712ff11345109d Thai Thien 2020-03-04 17:36:43
fix dimension of ccnn f4439d9a78273ab3ba450f31a528509816b4352f Thai Thien 2020-03-04 17:32:48
ready to train dbe0d6c3271dbb22490f0877fa31ba9cd7852b99 Thai Thien 2020-03-04 15:55:05
done implement c-cnn 2deecef953baf1e07ce5cf5477d208bc7ffa34cf Thai Thien 2020-03-03 17:25:03
fix script 0e9d372b9ad60b32939f1e558b2a59fc7d518fa2 Thai Thien 2020-03-02 16:23:55
simple v3 to 91 epoch 539fdd03c3e3497fd22b7db2aaa14f067cbf6f8d Thai Thien 2020-03-02 16:09:43
we train on all training data and validate on test data 9407ef8d5b7c47c53d6f98dcb3c20208aad1d7a9 Thai Thien 2020-03-01 15:36:46
load and continue train v3 12421fb7330e5c9d2eed4f6e574dfe69bdfddefc Thai Thien 2020-03-01 14:50:01
Commit 44a669c1f918be9d74313f29a5dbbc876c29f2fc - new ccnn
Author: Thai Thien
Author date (UTC): 2020-03-09 17:16
Committer name: Thai Thien
Committer date (UTC): 2020-03-09 17:16
Parent(s): aa331331b12e5b454d372a550524b30a4bebe706
Signing key:
Tree: 49ae38e2f210bc15caeb5bea151c6698f017b2c3
File Lines added Lines deleted
models/__init__.py 1 1
models/compact_cnn.py 49 0
train_custom_compact_cnn.py 2 2
File models/__init__.py changed (mode: 100644) (index e02e0ec..a83ed52)
... ... from .deform_conv_v2 import DeformConv2d
5 5 from .attn_can_adcrowdnet import AttnCanAdcrowdNet from .attn_can_adcrowdnet import AttnCanAdcrowdNet
6 6 from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg
7 7 from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3, AttnCanAdcrowdNetSimpleV4 from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3, AttnCanAdcrowdNetSimpleV4
8 from .compact_cnn import CompactCNN, CompactDilatedCNN
8 from .compact_cnn import CompactCNN, CompactDilatedCNN, DefDilatedCCNN
File models/compact_cnn.py changed (mode: 100644) (index 78d918a..3521e86)
1 1 import torch.nn as nn import torch.nn as nn
2 2 import torch import torch
3 3 from torchvision import models from torchvision import models
4 from .deform_conv_v2 import DeformConv2d
4 5 import collections import collections
5 6 import torch.nn.functional as F import torch.nn.functional as F
6 7
 
... ... class CompactCNN(nn.Module):
15 16 self.red_cnn = nn.Conv2d(3, 10, 9, padding=4) self.red_cnn = nn.Conv2d(3, 10, 9, padding=4)
16 17 self.green_cnn = nn.Conv2d(3, 14, 7, padding=3) self.green_cnn = nn.Conv2d(3, 14, 7, padding=3)
17 18 self.blue_cnn = nn.Conv2d(3, 16, 5, padding=2) self.blue_cnn = nn.Conv2d(3, 16, 5, padding=2)
19
20 self.red_cnn = DeformConv2d(3, 10, 9, padding=4)
21 self.green_cnn = DeformConv2d(3, 14, 7, padding=3)
22 self.blue_cnn = DeformConv2d(3, 16, 5, padding=2)
23
18 24 self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
19 25
20 26 self.c1 = nn.Conv2d(40, 60, 3, padding=1) self.c1 = nn.Conv2d(40, 60, 3, padding=1)
 
... ... class CompactDilatedCNN(nn.Module):
78 84 x = self.output(x) x = self.output(x)
79 85 return x return x
80 86
87
88 class DefDilatedCCNN(nn.Module):
89 """
90
91 """
92 def __init__(self, load_weights=False):
93 super(DefDilatedCCNN, self).__init__()
94
95 self.red_cnn = DeformConv2d(3, 10, 9, padding=4)
96 self.green_cnn = DeformConv2d(3, 14, 7, padding=3)
97 self.blue_cnn = DeformConv2d(3, 16, 5, padding=2)
98
99 self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
100
101 self.c1 = nn.Conv2d(40, 60, 3, dilation=2, padding=2, bias=False)
102 self.bn1 = nn.BatchNorm2d(60)
103 self.c2 = nn.Conv2d(60, 40, 3, dilation=2, padding=2, bias=False)
104 self.bn2 = nn.BatchNorm2d(40)
105 self.c3 = nn.Conv2d(40, 20, 3, dilation=2, padding=2, bias=False)
106 self.bn3 = nn.BatchNorm2d(20)
107 self.c4 = nn.Conv2d(20, 10, 3, dilation=2, padding=2, bias=False)
108 self.bn4 = nn.BatchNorm2d(10)
109 self.output = nn.Conv2d(10, 1, 1)
110
111 def forward(self,x):
112 x_red = self.max_pooling(F.relu(self.red_cnn(x), inplace=True))
113 x_green = self.max_pooling(F.relu(self.green_cnn(x), inplace=True))
114 x_blue = self.max_pooling(F.relu(self.blue_cnn(x), inplace=True))
115
116 x = torch.cat((x_red, x_green, x_blue), 1)
117 x = F.relu(self.bn1(self.c1(x)), inplace=True)
118
119 x = F.relu(self.bn2(self.c2(x)), inplace=True)
120 x = self.max_pooling(x)
121
122 x = F.relu(self.bn3(self.c3(x)), inplace=True)
123 x = self.max_pooling(x)
124
125 x = F.relu(self.bn4(self.c4(x)), inplace=True)
126
127 x = self.output(x)
128 return x
129
File train_custom_compact_cnn.py changed (mode: 100644) (index 7c5d235..f28d251)
... ... from visualize_util import get_readable_time
9 9
10 10 import torch import torch
11 11 from torch import nn from torch import nn
12 from models import CompactDilatedCNN
12 from models import DefDilatedCCNN
13 13 import os import os
14 14
15 15 if __name__ == "__main__": if __name__ == "__main__":
 
... ... if __name__ == "__main__":
39 39 print("len train_loader ", len(train_loader)) print("len train_loader ", len(train_loader))
40 40
41 41 # model # model
42 model = CompactDilatedCNN()
42 model = DefDilatedCCNN()
43 43 model = model.to(device) model = model.to(device)
44 44
45 45 # loss function # loss function
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main