List of commits:
Subject Hash Author Date (UTC)
DilatedCCNNv2 75d2989232a8a68eba9b4920ab2374ac28438e0e Thai Thien 2020-03-10 05:11:12
fix script for ccnn_v2_t1_c2 57928056d13bc9b1f9b11e14dd305005a3a5aeea Thai Thien 2020-03-10 04:56:33
fix trash code 33c406b13b5d45527b05dfb7f4281c3966c6471e Thai Thien 2020-03-10 04:49:52
repair dir in config baf522825f906a3d1fc5524f42a80da33d059640 Thai Thien 2020-03-10 04:45:11
v3 t1 c2 2d4727f47f4262833dca2087fb9e48f0d117e334 Thai Thien 2020-03-10 04:29:23
dilated ccnn v1 t1 7807d7a979353fa84d0b7319820386e93dbe5cc4 Thai Thien 2020-03-09 17:20:58
new ccnn 44a669c1f918be9d74313f29a5dbbc876c29f2fc Thai Thien 2020-03-09 17:16:49
fix script aa331331b12e5b454d372a550524b30a4bebe706 Thai Thien 2020-03-07 18:32:06
try reproduct ccnn with keepfull and lr 1e-5 814c520cbd1bb2d7fd50d2a8d3579d43da79fe60 Thai Thien 2020-03-07 18:30:42
my simple v4 with addition deform cnn at the end 5392aaf6c14fdd910f52096dbb921bed7470c4f7 Thai Thien 2020-03-07 18:15:22
fix the scheduler 77e6737a040f5aa5745b8a8830f5bec12322b10f Thai Thien 2020-03-07 17:46:02
t4 lr 1e-5 acd41ed30c95f63e01a05a6d9929410637852d9e Thai Thien 2020-03-06 19:41:49
no more lr scheduler 7289adb41de7807258eb8c29e6108fa65f59525a Thai Thien 2020-03-06 19:35:49
reduce learning rate bc8241e5b88b91c18bb7999a8d5d12fc79a5e3f7 Thai Thien 2020-03-06 19:28:27
dilated ccnn 5c5d92bdc0a288dd5d4ec5f1367d8cb928175bbe Thai Thien 2020-03-06 19:04:01
done 9f05e093ec7c10284a4aedf0738f9e61d5ac6fb6 Thai Thien 2020-03-06 18:02:34
with lr scheduler 466c364b60ed22c77319b14ccc9a201614b908bf Thai Thien 2020-03-04 17:57:49
train with learning rate scheduler fcd5a3c8da2dd6763e0d40742edf47b49c95fcfb Thai Thien 2020-03-04 17:55:11
ccnn no padding at output layer 57563fc07f656c63f807de4d80712ff11345109d Thai Thien 2020-03-04 17:36:43
fix dimension of ccnn f4439d9a78273ab3ba450f31a528509816b4352f Thai Thien 2020-03-04 17:32:48
Commit 75d2989232a8a68eba9b4920ab2374ac28438e0e - DilatedCCNNv2
Author: Thai Thien
Author date (UTC): 2020-03-10 05:11
Committer name: Thai Thien
Committer date (UTC): 2020-03-10 05:11
Parent(s): 57928056d13bc9b1f9b11e14dd305005a3a5aeea
Signing key:
Tree: ef4c7a1547265098f139e9a2ae60884c6aa64bea
File Lines added Lines deleted
models/__init__.py 1 1
models/compact_cnn.py 43 1
train_custom_compact_cnn.py 6 3
train_custom_compact_cnn_lrscheduler.py 1 1
train_script/CCNN_custom/dilated_ccnn_v2_t1.sh 3 3
File models/__init__.py changed (mode: 100644) (index a83ed52..3c1cb01)
... ... from .deform_conv_v2 import DeformConv2d
5 5 from .attn_can_adcrowdnet import AttnCanAdcrowdNet from .attn_can_adcrowdnet import AttnCanAdcrowdNet
6 6 from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg
7 7 from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3, AttnCanAdcrowdNetSimpleV4 from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3, AttnCanAdcrowdNetSimpleV4
8 from .compact_cnn import CompactCNN, CompactDilatedCNN, DefDilatedCCNN
8 from .compact_cnn import CompactCNN, CompactDilatedCNN, DefDilatedCCNN, DilatedCCNNv2
File models/compact_cnn.py changed (mode: 100644) (index 25cef10..584995a)
... ... class CompactDilatedCNN(nn.Module):
83 83
84 84 class DefDilatedCCNN(nn.Module): class DefDilatedCCNN(nn.Module):
85 85 """ """
86
86 fail reason: out of cuda memory at red_cnn
87 possible fix: try torchvision deform conv
87 88 """ """
88 89 def __init__(self, load_weights=False): def __init__(self, load_weights=False):
89 90 super(DefDilatedCCNN, self).__init__() super(DefDilatedCCNN, self).__init__()
 
... ... class DefDilatedCCNN(nn.Module):
123 124 x = self.output(x) x = self.output(x)
124 125 return x return x
125 126
127 class DilatedCCNNv2(nn.Module):
128 """
129
130 """
131 def __init__(self, load_weights=False):
132 super(DilatedCCNNv2, self).__init__()
133
134 self.red_cnn = nn.Conv2d(3, 10, 9, padding=4)
135 self.green_cnn = nn.Conv2d(3, 14, 7, padding=3)
136 self.blue_cnn = nn.Conv2d(3, 16, 5, padding=2)
137
138 self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
139
140 self.c1 = nn.Conv2d(40, 60, 3, dilation=2, padding=2, bias=False)
141 self.bn1 = nn.BatchNorm2d(60)
142 self.c2 = nn.Conv2d(60, 40, 3, dilation=2, padding=2, bias=False)
143 self.bn2 = nn.BatchNorm2d(40)
144 self.c3 = nn.Conv2d(40, 20, 3, dilation=2, padding=2, bias=False)
145 self.bn3 = nn.BatchNorm2d(20)
146 self.c4 = nn.Conv2d(20, 10, 3, dilation=2, padding=2, bias=False)
147 self.bn4 = nn.BatchNorm2d(10)
148 self.output = nn.Conv2d(10, 1, 1)
149
150 def forward(self,x):
151 x_red = self.max_pooling(F.relu(self.red_cnn(x), inplace=True))
152 x_green = self.max_pooling(F.relu(self.green_cnn(x), inplace=True))
153 x_blue = self.max_pooling(F.relu(self.blue_cnn(x), inplace=True))
154
155 x = torch.cat((x_red, x_green, x_blue), 1)
156 x = F.relu(self.bn1(self.c1(x)), inplace=True)
157
158 x = F.relu(self.bn2(self.c2(x)), inplace=True)
159 x = self.max_pooling(x)
160
161 x = F.relu(self.bn3(self.c3(x)), inplace=True)
162 x = self.max_pooling(x)
163
164 x = F.relu(self.bn4(self.c4(x)), inplace=True)
165
166 x = self.output(x)
167 return x
File train_custom_compact_cnn.py changed (mode: 100644) (index f28d251..e6bf570)
... ... from ignite.engine import Engine
6 6 from ignite.handlers import Checkpoint, DiskSaver from ignite.handlers import Checkpoint, DiskSaver
7 7 from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError
8 8 from visualize_util import get_readable_time from visualize_util import get_readable_time
9
9 from torchsummary import summary
10 10 import torch import torch
11 11 from torch import nn from torch import nn
12 from models import DefDilatedCCNN
12 from models import DilatedCCNNv2
13 13 import os import os
14 14
15 15 if __name__ == "__main__": if __name__ == "__main__":
 
... ... if __name__ == "__main__":
39 39 print("len train_loader ", len(train_loader)) print("len train_loader ", len(train_loader))
40 40
41 41 # model # model
42 model = DefDilatedCCNN()
42 model = DilatedCCNNv2()
43 43 model = model.to(device) model = model.to(device)
44 44
45 45 # loss function # loss function
 
... ... if __name__ == "__main__":
58 58 }, device=device) }, device=device)
59 59 print(model) print(model)
60 60
61 print (summary(model, (3, 512, 512)))
62
61 63 print(args) print(args)
62 64
65
63 66 if len(args.load_model) > 0: if len(args.load_model) > 0:
64 67 load_model_path = args.load_model load_model_path = args.load_model
65 68 print("load mode " + load_model_path) print("load mode " + load_model_path)
File train_custom_compact_cnn_lrscheduler.py changed (mode: 100644) (index 4a5d6ff..943fa68)
... ... if __name__ == "__main__":
75 75 else: else:
76 76 print("do not load, keep training") print("do not load, keep training")
77 77
78 trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
78 trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler)
79 79
80 80
81 81 @trainer.on(Events.ITERATION_COMPLETED(every=50)) @trainer.on(Events.ITERATION_COMPLETED(every=50))
File train_script/CCNN_custom/dilated_ccnn_v2_t1.sh copied from file train_script/CCNN_custom/dilated_ccnn_v1_t1_scheduler.sh (similarity 65%) (mode: 100644) (index f48d434..5e75eed)
1 1 CUDA_VISIBLE_DEVICES=5 nohup python train_custom_compact_cnn.py \ CUDA_VISIBLE_DEVICES=5 nohup python train_custom_compact_cnn.py \
2 --task_id dilated_ccnn_v1_t3 \
2 --task_id dilated_ccnn_v2_t1 \
3 3 --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \
4 --lr 1e-6 \
4 --lr 1e-5 \
5 5 --decay 5e-5 \ --decay 5e-5 \
6 6 --datasetname shanghaitech \ --datasetname shanghaitech \
7 --epochs 400 > logs/dilated_ccnn_v1_t3.log &
7 --epochs 400 > logs/dilated_ccnn_v2_t1.log &
Hints:
Before your first commit, do not forget to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main