List of commits:
Subject Hash Author Date (UTC)
train h1 bigtail ea6391257cd243098cbbb771e705f1f115b845df Thai Thien 2020-05-12 16:58:26
mse mean e96c22a36e305681d7fed415a5a949fa0c1791c9 Thai Thien 2020-05-10 18:32:02
no fix 7bd97e91de5d7c2d307407287c82e60e893c0c92 Thai Thien 2020-05-10 18:22:45
no fix fc20ae6922c2e53f7d37f4228fb921894cd78eab Thai Thien 2020-05-10 18:19:59
t9 d8ef865ea602670548e897d8b7ac4c925cc9b393 Thai Thien 2020-05-10 18:19:30
test with L1 loss 6492b65da4bdf6351b661f39b6bce6f08d37f17c Thai Thien 2020-05-10 18:10:49
H2 1d6d11b2eeecb67dd7d329e38de61b872870a9aa Thai Thien 2020-05-06 17:42:52
do something with l1 loss 5268c4fc163bb512f293fbac381a64a75c4fe462 Thai Thien 2020-05-06 17:32:45
typo b7b8e2303ce99b2196402ec93334598598e71e5a Thai Thien 2020-05-05 17:32:31
increase epoch 67f89509e4294c4310b42e790425c82279df16b3 Thai Thien 2020-05-05 17:25:17
H1 t8 1c692b37536bd72abaa0995001d3a396b82bc2f0 Thai Thien 2020-05-05 17:24:56
OMP_NUM_THREADS=5 ac76431f8ca1ada27ca7ffdaa289996baee064c1 Thai Thien 2020-05-05 17:14:41
train da020f46703ca4fae867a09960593ef6818b4a91 Thai Thien 2020-05-05 17:05:06
batch_size 10 6b6478b9570f9133489c8a9427a857c14a14fb13 Thai Thien 2020-05-02 11:26:38
change dataset preprocess for t3 267d31931fd80178714812fced9f86a27479d54f Thai Thien 2020-05-02 11:23:19
t3 e2a1c6f6e8a6d34b36aa8d6c86a5509bc8d41cdd Thai Thien 2020-05-02 11:20:05
batch size 20 ea5737c694cb2967cb041db99ca391d06a66100d Thai Thien 2020-05-02 11:19:18
ccn v7 shb fixed 15 4b28c4049c4b25a6afeb563864f76907a1e2360e Thai Thien 2020-05-02 11:16:14
shb 7af5a7bb61d8858a2f6ef36d44844506cde917c3 Thai Thien 2020-05-02 11:14:03
batch đéo 62e1e9124b7e611c6749c1544c60687abd30895e Thai Thien 2020-05-01 17:10:44
Commit ea6391257cd243098cbbb771e705f1f115b845df - train h1 bigtail
Author: Thai Thien
Author date (UTC): 2020-05-12 16:58
Committer name: Thai Thien
Committer date (UTC): 2020-05-12 16:58
Parent(s): e96c22a36e305681d7fed415a5a949fa0c1791c9
Signing key:
Tree: b35f42516fd34b989d386cafec1755f088c3709c
File Lines added Lines deleted
experiment_meow_main.py 5 1
models/meow_experiment/ccnn_head.py 0 1
models/meow_experiment/kitten_meow_1.py 66 1
train_script/combine/h1_bigtail_t1_sha.sh 3 3
train_script/combine/h1_bigtail_t1_shb_fixed.sh 4 5
train_script/combine/h1_bigtail_t2_sha.sh 4 4
train_script/combine/h1_bigtail_t2_shb_fixed.sh 5 6
File experiment_meow_main.py changed (mode: 100644) (index 9beb6c8..049432d)
... ... from torch import nn
13 13 from models.meow_experiment.kitten_meow_1 import M1, M2, M3, M4 from models.meow_experiment.kitten_meow_1 import M1, M2, M3, M4
14 14 from models.meow_experiment.ccnn_tail import BigTailM1, BigTailM2, BigTail3 from models.meow_experiment.ccnn_tail import BigTailM1, BigTailM2, BigTail3
15 15 from models.meow_experiment.ccnn_head import H1, H2 from models.meow_experiment.ccnn_head import H1, H2
16 from models.meow_experiment.kitten_meow_1 import H1_Bigtail3
16 17 from models import CustomCNNv2 from models import CustomCNNv2
17 18 import os import os
18 19 from model_util import get_lr from model_util import get_lr
 
... ... if __name__ == "__main__":
82 83 elif model_name == "H1": elif model_name == "H1":
83 84 model = H1() model = H1()
84 85 elif model_name == "H2": elif model_name == "H2":
85 model = H1()
86 model = H2()
87 elif model_name == "H1_Bigtail3":
88 model = H1_Bigtail3()
89
86 90 else: else:
87 91 print("error: you didn't pick a model") print("error: you didn't pick a model")
88 92 exit(-1) exit(-1)
File models/meow_experiment/ccnn_head.py changed (mode: 100644) (index a7d6359..2b34f59)
... ... class H1(nn.Module):
31 31
32 32 self.c1 = nn.Conv2d(40, 60, 3, padding=1) self.c1 = nn.Conv2d(40, 60, 3, padding=1)
33 33
34 # ideal from CSRNet
35 34 self.c2 = nn.Conv2d(60, 40, 3, padding=1) self.c2 = nn.Conv2d(60, 40, 3, padding=1)
36 35 self.c3 = nn.Conv2d(40, 20, 3, padding=1) self.c3 = nn.Conv2d(40, 20, 3, padding=1)
37 36 self.c4 = nn.Conv2d(20, 10, 3, padding=1) self.c4 = nn.Conv2d(20, 10, 3, padding=1)
File models/meow_experiment/kitten_meow_1.py changed (mode: 100644) (index 2daba3f..11453a6)
... ... class M4(nn.Module):
256 256 x = F.relu(self.c5(x), inplace=True) x = F.relu(self.c5(x), inplace=True)
257 257
258 258 x = self.output(x) x = self.output(x)
259 return x
259 return x
260
261
262 class H1_Bigtail3(nn.Module):
263 """
264 A REAL-TIME DEEP NETWORK FOR CROWD COUNTING
265 https://arxiv.org/pdf/2002.06515.pdf
266 the improve version
267
268 we change 5x5 7x7 9x9 with 3x3
269 Keep the tail
270 """
271 def __init__(self, load_weights=False):
272 super(H1_Bigtail3, self).__init__()
273 self.model_note = "headh1, all 20, with bigtail3 dilated"
274
275 # ideal from crowd counting using DMCNN
276 self.front_cnn_1 = nn.Conv2d(3, 10, 3, padding=1)
277 self.front_cnn_2 = nn.Conv2d(10, 20, 3, padding=1)
278 self.front_cnn_3 = nn.Conv2d(20, 20, 3, padding=1)
279 self.front_cnn_4 = nn.Conv2d(20, 20, 3, padding=1)
280
281 self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
282
283 self.c0 = nn.Conv2d(40, 60, 3, padding=2, dilation=2)
284 self.c1 = nn.Conv2d(60, 60, 3, padding=2, dilation=2)
285 self.c2 = nn.Conv2d(60, 60, 3, padding=2, dilation=2)
286 self.c3 = nn.Conv2d(60, 30, 3, padding=2, dilation=2)
287 self.c4 = nn.Conv2d(30, 15, 3, padding=2, dilation=2)
288 self.c5 = nn.Conv2d(15, 10, 3, padding=2, dilation=2)
289
290 self.output = nn.Conv2d(10, 1, 1)
291
292 def forward(self, x):
293 x_red = F.relu(self.front_cnn_1(x), inplace=True)
294 x_red = F.relu(self.front_cnn_2(x_red), inplace=True)
295 x_red = F.relu(self.front_cnn_3(x_red), inplace=True)
296 x_red = F.relu(self.front_cnn_4(x_red), inplace=True)
297 x_red = self.max_pooling(x_red)
298
299 x_green = F.relu(self.front_cnn_1(x), inplace=True)
300 x_green = F.relu(self.front_cnn_2(x_green), inplace=True)
301 x_green = F.relu(self.front_cnn_3(x_green), inplace=True)
302 x_green = self.max_pooling(x_green)
303
304 x_blue = F.relu(self.front_cnn_1(x), inplace=True)
305 x_blue = F.relu(self.front_cnn_2(x_blue), inplace=True)
306 x_blue = self.max_pooling(x_blue)
307
308 # x = self.max_pooling(x)
309 x = torch.cat((x_red, x_green, x_blue), 1)
310 x = F.relu(self.c0(x), inplace=True)
311
312 x = F.relu(self.c1(x), inplace=True)
313
314 x = F.relu(self.c2(x), inplace=True)
315 x = self.max_pooling(x)
316
317 x = F.relu(self.c3(x), inplace=True)
318 x = self.max_pooling(x)
319
320 x = F.relu(self.c4(x), inplace=True)
321 x = F.relu(self.c5(x), inplace=True)
322
323 x = self.output(x)
324 return x
File train_script/combine/h1_bigtail_t1_sha.sh copied from file train_script/meow_one/big_tail/bigtail3_t4_sha.sh (similarity 73%) (mode: 100644) (index 8777e85..fae9b31)
1 task="bigtail3_t4_sha"
1 task="h1_bigtail_t1_sha"
2 2
3 CUDA_VISIBLE_DEVICES=4 OMP_NUM_THREADS=5 HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_meow_main.py \
3 CUDA_VISIBLE_DEVICES=1 OMP_NUM_THREADS=4 HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_meow_main.py \
4 4 --task_id $task \ --task_id $task \
5 5 --note "bigtail3 L1" \ --note "bigtail3 L1" \
6 --model "BigTail3" \
6 --model "H1_Bigtail3" \
7 7 --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \
8 8 --lr 1e-4 \ --lr 1e-4 \
9 9 --decay 1e-4 \ --decay 1e-4 \
File train_script/combine/h1_bigtail_t1_shb_fixed.sh copied from file train_script/meow_one/big_tail/bigtail3_t4_shb_fixed.sh (similarity 66%) (mode: 100644) (index 1de7d38..7d5aba9)
1 task="bigtail3_t4_shb_fixed"
1 task="h1_bigtail_t1_shb_fixed"
2 2
3 CUDA_VISIBLE_DEVICES=2 OMP_NUM_THREADS=5 HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_meow_main.py \
3 CUDA_VISIBLE_DEVICES=3 OMP_NUM_THREADS=4 HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_meow_main.py \
4 4 --task_id $task \ --task_id $task \
5 5 --note "bigtail3 L1" \ --note "bigtail3 L1" \
6 --model "BigTail3" \
6 --model "H1_Bigtail3" \
7 7 --input /data/rnd/thient/thient_data/shanghaitech_with_people_density_map/ShanghaiTech_fixed_sigma/part_B \ --input /data/rnd/thient/thient_data/shanghaitech_with_people_density_map/ShanghaiTech_fixed_sigma/part_B \
8 8 --lr 1e-4 \ --lr 1e-4 \
9 9 --decay 1e-4 \ --decay 1e-4 \
10 --batch_size 8 \
11 10 --loss_fn "L1" \ --loss_fn "L1" \
12 11 --datasetname shanghaitech_rnd \ --datasetname shanghaitech_rnd \
13 --epochs 601 > logs/$task.log &
12 --epochs 1201 > logs/$task.log &
14 13
15 14 echo logs/$task.log # for convenience echo logs/$task.log # for convenience
File train_script/combine/h1_bigtail_t2_sha.sh copied from file train_script/meow_one/big_tail/bigtail3_t4_sha.sh (similarity 68%) (mode: 100644) (index 8777e85..08e7382)
1 task="bigtail3_t4_sha"
1 task="h1_bigtail_t2_sha"
2 2
3 CUDA_VISIBLE_DEVICES=4 OMP_NUM_THREADS=5 HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_meow_main.py \
3 CUDA_VISIBLE_DEVICES=2 OMP_NUM_THREADS=4 HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_meow_main.py \
4 4 --task_id $task \ --task_id $task \
5 5 --note "bigtail3 L1" \ --note "bigtail3 L1" \
6 --model "BigTail3" \
6 --model "H1_Bigtail3" \
7 7 --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \
8 8 --lr 1e-4 \ --lr 1e-4 \
9 9 --decay 1e-4 \ --decay 1e-4 \
10 --loss_fn "L1" \
10 --loss_fn "MSE" \
11 11 --datasetname shanghaitech_20p \ --datasetname shanghaitech_20p \
12 12 --epochs 1201 > logs/$task.log & --epochs 1201 > logs/$task.log &
13 13
File train_script/combine/h1_bigtail_t2_shb_fixed.sh copied from file train_script/meow_one/big_tail/bigtail3_t4_shb_fixed.sh (similarity 62%) (mode: 100644) (index 1de7d38..7586edf)
1 task="bigtail3_t4_shb_fixed"
1 task="h1_bigtail_t2_shb_fixed"
2 2
3 CUDA_VISIBLE_DEVICES=2 OMP_NUM_THREADS=5 HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_meow_main.py \
3 CUDA_VISIBLE_DEVICES=4 OMP_NUM_THREADS=4 HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_meow_main.py \
4 4 --task_id $task \ --task_id $task \
5 5 --note "bigtail3 L1" \ --note "bigtail3 L1" \
6 --model "BigTail3" \
6 --model "H1_Bigtail3" \
7 7 --input /data/rnd/thient/thient_data/shanghaitech_with_people_density_map/ShanghaiTech_fixed_sigma/part_B \ --input /data/rnd/thient/thient_data/shanghaitech_with_people_density_map/ShanghaiTech_fixed_sigma/part_B \
8 8 --lr 1e-4 \ --lr 1e-4 \
9 9 --decay 1e-4 \ --decay 1e-4 \
10 --batch_size 8 \
11 --loss_fn "L1" \
10 --loss_fn "MSE" \
12 11 --datasetname shanghaitech_rnd \ --datasetname shanghaitech_rnd \
13 --epochs 601 > logs/$task.log &
12 --epochs 1201 > logs/$task.log &
14 13
15 14 echo logs/$task.log # for convenience echo logs/$task.log # for convenience
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main