List of commits:
Subject Hash Author Date (UTC)
bigtal4 1a434051955cfaeda47780ab801924cfc9c22f94 Thai Thien 2020-06-02 18:00:43
l1 1e-4 2ea7293b60423704648df4a610a012f1da822ec9 Thai Thien 2020-05-29 16:38:44
fix log epoch 84e8beff8e305f81de7a05add2e58dcfe389ede5 Thai Thien 2020-05-29 16:38:31
skip evaluate on training set a38d5731a9d6209bc98a9cca0ad023b4798c68d5 Thai Thien 2020-05-28 16:15:29
t7 ac1f0d71c2d3e0a01241622c639a9090a5b3c427 Thai Thien 2020-05-28 15:55:56
rename 4e9754c3dff98577616d3f043f7f260aeef40ac6 Thai Thien 2020-05-27 15:54:54
not use gpu2 8324400a3d65c58cfdd5a63b0a940217f3277769 Thai Thien 2020-05-27 15:54:11
remove call loss in eval 06f13484f24b85b12cab15e5af16db4d086db497 Thai Thien 2020-05-27 15:52:44
expermient, bug fix e83e7e058b6f1d7ec1237677a2d4770f1f788ea0 Thai Thien 2020-05-27 15:43:58
surpress python warning 82806340334b3beb9916b34bb00a66381ce57aae Thai Thien 2020-05-26 12:00:49
sha total crop no keepfull d0783ddab88ed042607738fb665b34028d485d0d Thai Thien 2020-05-26 11:49:10
sha 60p d5146c9f28b62fac137a15c862d99e1af30ce5fa Thai Thien 2020-05-25 17:28:12
t12 35b7a40082505da5f79caf4cc16023144e18a7d8 Thai Thien 2020-05-24 12:00:45
add sha 20p random 597c821a0103344b5e3b7b80eb2f61c335275ed2 Thai Thien 2020-05-24 11:52:18
h2 t3 shb babec8c08d4aebd8af14558eddfa2898918ff152 Thai Thien 2020-05-23 16:49:58
y_pred e1c1e8096230344dabfc1c85ccae832ba08aaad6 Thai Thien 2020-05-23 16:37:24
when not train, return img and true count 93d3b48fc1cd3a5b43ac0e8782c57f2ea00a48bc Thai Thien 2020-05-23 16:33:00
map->mat c64a1548b1e9745c109c211f76e8dc2554cf6747 Thai Thien 2020-05-23 16:27:48
fix wrong file name 950ad302d9b0ec6ed04fb4d4a870087dda57281c Thai Thien 2020-05-23 16:17:30
fix scipy.io 766d0f3ba48e0c3e0d2562136c5549e20070b805 Thai Thien 2020-05-23 15:44:44
Commit 1a434051955cfaeda47780ab801924cfc9c22f94 - bigtal4
Author: Thai Thien
Author date (UTC): 2020-06-02 18:00
Committer name: Thai Thien
Committer date (UTC): 2020-06-02 18:00
Parent(s): 2ea7293b60423704648df4a610a012f1da822ec9
Signer:
Signing key:
Signing status: N
Tree: f571e845b39b2b38a1568b05a6fbd8697f78dbde
File Lines added Lines deleted
experiment_main.py 3 1
models/meow_experiment/ccnn_tail.py 42 0
train_script/meow_one/big_tail/bigtail4_t1_sha.sh 5 5
train_script/meow_one/big_tail/bigtail4_t2_sha.sh 5 5
File experiment_main.py changed (mode: 100644) (index 2d2b3cf..99f4a01)
... ... from visualize_util import get_readable_time
11 11 import torch import torch
12 12 from torch import nn from torch import nn
13 13 from models.meow_experiment.kitten_meow_1 import M1, M2, M3, M4 from models.meow_experiment.kitten_meow_1 import M1, M2, M3, M4
14 from models.meow_experiment.ccnn_tail import BigTailM1, BigTailM2, BigTail3
14 from models.meow_experiment.ccnn_tail import BigTailM1, BigTailM2, BigTail3, BigTail4
15 15 from models.meow_experiment.ccnn_head import H1, H2 from models.meow_experiment.ccnn_head import H1, H2
16 16 from models.meow_experiment.kitten_meow_1 import H1_Bigtail3 from models.meow_experiment.kitten_meow_1 import H1_Bigtail3
17 17 from models import CustomCNNv2, CompactCNNV7 from models import CustomCNNv2, CompactCNNV7
 
... ... if __name__ == "__main__":
83 83 model = BigTailM2() model = BigTailM2()
84 84 elif model_name == "BigTail3": elif model_name == "BigTail3":
85 85 model = BigTail3() model = BigTail3()
86 elif model_name == "BigTail4":
87 model = BigTail4()
86 88 elif model_name == "H1": elif model_name == "H1":
87 89 model = H1() model = H1()
88 90 elif model_name == "H2": elif model_name == "H2":
File models/meow_experiment/ccnn_tail.py changed (mode: 100644) (index 88dd56a..c40cd57)
... ... class BigTail3(nn.Module):
147 147 return x return x
148 148
149 149
150 class BigTail4(nn.Module):
151 """
152 we set max tail at 60 only
153 remove c5 comparing to bigtal3
154 """
155 def __init__(self, load_weights=False):
156 super(BigTail4, self).__init__()
157 self.model_note = "small taill 100"
158 self.red_cnn = nn.Conv2d(3, 10, 9, padding=4)
159 self.green_cnn = nn.Conv2d(3, 14, 7, padding=3)
160 self.blue_cnn = nn.Conv2d(3, 16, 5, padding=2)
161
162 self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
163
164 self.c0 = nn.Conv2d(40, 60, 3, padding=2, dilation=2)
165 self.c1 = nn.Conv2d(60, 60, 3, padding=2, dilation=2)
166 self.c2 = nn.Conv2d(60, 60, 3, padding=2, dilation=2)
167 self.c3 = nn.Conv2d(60, 30, 3, padding=2, dilation=2)
168 self.c4 = nn.Conv2d(30, 10, 3, padding=2, dilation=2)
169 self.output = nn.Conv2d(10, 1, 1)
170
171 def forward(self,x):
172 x_red = self.max_pooling(F.relu(self.red_cnn(x), inplace=True))
173 x_green = self.max_pooling(F.relu(self.green_cnn(x), inplace=True))
174 x_blue = self.max_pooling(F.relu(self.blue_cnn(x), inplace=True))
175
176 x = torch.cat((x_red, x_green, x_blue), 1)
177 x = F.relu(self.c0(x), inplace=True)
178
179 x = F.relu(self.c1(x), inplace=True)
180
181 x = F.relu(self.c2(x), inplace=True)
182 x = self.max_pooling(x)
183
184 x = F.relu(self.c3(x), inplace=True)
185 x = self.max_pooling(x)
186
187 x = F.relu(self.c4(x), inplace=True)
188
189 x = self.output(x)
190 return x
191
File train_script/meow_one/big_tail/bigtail4_t1_sha.sh copied from file train_script/meow_one/head/H2_t8_sha.sh (similarity 72%) (mode: 100644) (index 0c77d99..494338d)
1 task="H2_t8_sha"
1 task="bigtail4_t1_sha"
2 2
3 3 CUDA_VISIBLE_DEVICES=5 OMP_NUM_THREADS=2 PYTHONWARNINGS="ignore" HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_main.py \ CUDA_VISIBLE_DEVICES=5 OMP_NUM_THREADS=2 PYTHONWARNINGS="ignore" HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_main.py \
4 4 --task_id $task \ --task_id $task \
5 --note "a H2 with L1, hope better than baseline" \
6 --model "H2" \
5 --note "bigtail4" \
6 --model "BigTail4" \
7 7 --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \
8 8 --lr 1e-4 \ --lr 1e-4 \
9 --decay 1e-4 \
9 --decay 1e-4 \
10 10 --loss_fn "L1" \ --loss_fn "L1" \
11 11 --skip_train_eval \ --skip_train_eval \
12 12 --datasetname shanghaitech_crop_random \ --datasetname shanghaitech_crop_random \
13 --epochs 1201 > logs/$task.log &
13 --epochs 800 > logs/$task.log &
14 14
15 15 echo logs/$task.log # for convenience echo logs/$task.log # for convenience
File train_script/meow_one/big_tail/bigtail4_t2_sha.sh copied from file train_script/meow_one/big_tail/bigtail3_t7_sha.sh (similarity 64%) (mode: 100644) (index 9503373..8874e0f)
1 task="bigtail3_t7_sha"
1 task="bigtail4_t2_sha"
2 2
3 CUDA_VISIBLE_DEVICES=4 OMP_NUM_THREADS=2 PYTHONWARNINGS="ignore" HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_main.py \
3 CUDA_VISIBLE_DEVICES=6 OMP_NUM_THREADS=2 PYTHONWARNINGS="ignore" HTTPS_PROXY="http://10.60.28.99:86" nohup python experiment_main.py \
4 4 --task_id $task \ --task_id $task \
5 --note "bigtail3 L1" \
6 --model "BigTail3" \
5 --note "bigtail4" \
6 --model "BigTail4" \
7 7 --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \
8 8 --lr 1e-4 \ --lr 1e-4 \
9 9 --decay 1e-4 \ --decay 1e-4 \
10 10 --loss_fn "MSE" \ --loss_fn "MSE" \
11 11 --skip_train_eval \ --skip_train_eval \
12 12 --datasetname shanghaitech_crop_random \ --datasetname shanghaitech_crop_random \
13 --epochs 1001 > logs/$task.log &
13 --epochs 800 > logs/$task.log &
14 14
15 15 echo logs/$task.log # for convenience echo logs/$task.log # for convenience
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main