List of commits:
Subject Hash Author Date (UTC)
customccnn v3 with many batchnorm 42ea62efcce0489673cdf71db792eb06d88fa5e6 Thai Thien 2020-03-16 17:22:30
CustomCNNv3 b1b117d627147bc23727e55ab0c3026e920d9f20 Thai Thien 2020-03-16 16:05:02
fix dimension mismake again d2341d70250a38071fefa9e6af24d7382ccf3138 Thai Thien 2020-03-15 09:29:03
fix custom ccnn dimension mismake 8b2fe4f669469196685a37974b5e7fee5e9a5fe0 Thai Thien 2020-03-15 08:57:16
custom ccnn fa81a2a28140cc8a84e3d5df49443ebe79e89268 Thai Thien 2020-03-15 08:54:19
forgot to add batch size arg in "train_compact_cnn" bc9f7ca249f6719d8ee67a631bce184914d6b199 Thai Thien 2020-03-14 17:58:26
ccnn_v5_t2shb 85dcfe49ecbcef2ccb1bb70b2a1b898440180286 Thai Thien 2020-03-14 17:43:00
add batchsize, ready train shb with batch5 254ba309e8031fcffb911f29d92a46f56106e8aa Thai Thien 2020-03-14 17:39:18
gpu 5 44fb230848afb391fa4efa4e412d3def3e32965a Thai Thien 2020-03-14 11:29:26
reduce lr to 1e-4 9c5e5b64621b33ae7ecc0124c204c4053296d426 Thai Thien 2020-03-14 11:24:24
train with scheduler 390958d81f108ed3ca3cfe668ceef2a4ebf6a69f Thai Thien 2020-03-14 10:57:44
prepare to train 8c16b70c805d48e4f944fa469cb370f1ee1297f0 Thai Thien 2020-03-14 10:15:22
add 1 move layer to ccnn 10d5d3711dda204d7b059b79d775c0359d1a964d Thai Thien 2020-03-14 10:04:17
continue train 3627b8cbf4192856d7453b72972a11659c34ae5f Thai Thien 2020-03-14 01:58:25
simple v4 t2 4dee2eba24246cf21a84dd7f9ef74d6c434edf1d Thai Thien 2020-03-13 18:36:33
nll to loss 958f0895b81d42e6d31a0bbd5787211538c081fe Thai Thien 2020-03-13 16:27:27
ccnn_v4_t1 1f950b91f4fb89a0f08baf08dffb2a501546b64f Thai Thien 2020-03-13 16:23:39
add proxy 9afb66a73e3ae24b2144faf311b128dbe5768f3c Thai Thien 2020-03-13 16:09:30
add comet, add scheduler to simple 4ef5939124745dfc54ad8a87936954a2bde8a5a2 Thai Thien 2020-03-13 16:06:24
ccnn v1_t5 lr scheduler 31f7a693eff8e60a07bd7bc439575cc7712fe31b Thai Thien 2020-03-13 15:37:37
Commit 42ea62efcce0489673cdf71db792eb06d88fa5e6 - customccnn v3 with many batchnorm
Author: Thai Thien
Author date (UTC): 2020-03-16 17:22
Committer name: Thai Thien
Committer date (UTC): 2020-03-16 17:22
Parent(s): b1b117d627147bc23727e55ab0c3026e920d9f20
Signer:
Signing key:
Signing status: N
Tree: 1dda8fdadf443fdd048300fec6037aef97ce87fc
File Lines added Lines deleted
models/__init__.py 1 1
models/my_ccnn.py 15 10
train_custom_compact_cnn.py 3 3
train_custom_compact_cnn_lrscheduler.py 29 11
train_script/CCNN_custom/custom_ccnn_v3_t1_scheduler_shb.sh 9 0
train_script/CCNN_custom/custom_ccnn_v3_t1_shb.sh 4 4
File models/__init__.py changed (mode: 100644) (index 94cb9a5..b90477d)
... ... from .attn_can_adcrowdnet import AttnCanAdcrowdNet
7 7 from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg
8 8 from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3, AttnCanAdcrowdNetSimpleV4 from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3, AttnCanAdcrowdNetSimpleV4
9 9 from .compact_cnn import CompactCNN, CompactCNNV2, CompactDilatedCNN, DefDilatedCCNN, DilatedCCNNv2 from .compact_cnn import CompactCNN, CompactCNNV2, CompactDilatedCNN, DefDilatedCCNN, DilatedCCNNv2
10 from .my_ccnn import CustomCNNv1, CustomCNNv2
10 from .my_ccnn import CustomCNNv1, CustomCNNv2, CustomCNNv3
File models/my_ccnn.py changed (mode: 100644) (index 1bd1c0a..9ee3cc0)
... ... class CustomCNNv3(nn.Module):
169 169
170 170 # ideal from crowd counting using DMCNN # ideal from crowd counting using DMCNN
171 171 self.front_cnn_1 = nn.Conv2d(3, 20, 3, padding=1) self.front_cnn_1 = nn.Conv2d(3, 20, 3, padding=1)
172 self.front_bn1 = nn.BatchNorm2d(20)
172 173 self.front_cnn_2 = nn.Conv2d(20, 16, 3, padding=1) self.front_cnn_2 = nn.Conv2d(20, 16, 3, padding=1)
174 self.front_bn2 = nn.BatchNorm2d(16)
173 175 self.front_cnn_3 = nn.Conv2d(16, 14, 3, padding=1) self.front_cnn_3 = nn.Conv2d(16, 14, 3, padding=1)
176 self.front_bn3 = nn.BatchNorm2d(14)
174 177 self.front_cnn_4 = nn.Conv2d(14, 10, 3, padding=1) self.front_cnn_4 = nn.Conv2d(14, 10, 3, padding=1)
178 self.front_bn4 = nn.BatchNorm2d(10)
175 179
176 180 self.c0 = nn.Conv2d(40, 40, 3, padding=1) self.c0 = nn.Conv2d(40, 40, 3, padding=1)
181 self.bn0 = nn.BatchNorm2d(40)
177 182 self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
178 183
179 184 self.c1 = nn.Conv2d(40, 60, 3, padding=1) self.c1 = nn.Conv2d(40, 60, 3, padding=1)
180
185 self.bn1 = nn.BatchNorm2d(60)
181 186 # ideal from CSRNet # ideal from CSRNet
182 187 self.c2 = nn.Conv2d(60, 40, 3, padding=2, dilation=2, bias=False) self.c2 = nn.Conv2d(60, 40, 3, padding=2, dilation=2, bias=False)
183 188 self.bn2 = nn.BatchNorm2d(40) self.bn2 = nn.BatchNorm2d(40)
 
... ... class CustomCNNv3(nn.Module):
192 197 #x_green = self.max_pooling(F.relu(self.green_cnn(x), inplace=True)) #x_green = self.max_pooling(F.relu(self.green_cnn(x), inplace=True))
193 198 #x_blue = self.max_pooling(F.relu(self.blue_cnn(x), inplace=True)) #x_blue = self.max_pooling(F.relu(self.blue_cnn(x), inplace=True))
194 199
195 x_red = F.relu(self.front_cnn_1(x), inplace=True)
196 x_red = F.relu(self.front_cnn_2(x_red), inplace=True)
197 x_red = F.relu(self.front_cnn_3(x_red), inplace=True)
198 x_red = F.relu(self.front_cnn_4(x_red), inplace=True)
200 x_red = F.relu(self.front_bn1(self.front_cnn_1(x)), inplace=True)
201 x_red = F.relu(self.front_bn2(self.front_cnn_2(x_red)), inplace=True)
202 x_red = F.relu(self.front_bn3(self.front_cnn_3(x_red)), inplace=True)
203 x_red = F.relu(self.front_bn4(self.front_cnn_4(x_red)), inplace=True)
199 204 x_red = self.max_pooling(x_red) x_red = self.max_pooling(x_red)
200 205
201 x_green = F.relu(self.front_cnn_1(x), inplace=True)
202 x_green = F.relu(self.front_cnn_2(x_green), inplace=True)
203 x_green = F.relu(self.front_cnn_3(x_green), inplace=True)
206 x_green = F.relu(self.front_bn1(self.front_cnn_1(x)), inplace=True)
207 x_green = F.relu(self.front_bn2(self.front_cnn_2(x_green)), inplace=True)
208 x_green = F.relu(self.front_bn3(self.front_cnn_3(x_green)), inplace=True)
204 209 x_green = self.max_pooling(x_green) x_green = self.max_pooling(x_green)
205 210
206 x_blue = F.relu(self.front_cnn_1(x), inplace=True)
207 x_blue = F.relu(self.front_cnn_2(x_blue), inplace=True)
211 x_blue = F.relu(self.front_bn1(self.front_cnn_1(x)), inplace=True)
212 x_blue = F.relu(self.front_bn2(self.front_cnn_2(x_blue)), inplace=True)
208 213 x_blue = self.max_pooling(x_blue) x_blue = self.max_pooling(x_blue)
209 214
210 215 x = torch.cat((x_red, x_green, x_blue), 1) x = torch.cat((x_red, x_green, x_blue), 1)
File train_custom_compact_cnn.py changed (mode: 100644) (index 3514028..24183e4)
... ... from visualize_util import get_readable_time
11 11
12 12 import torch import torch
13 13 from torch import nn from torch import nn
14 from models import CustomCNNv2
14 from models import CustomCNNv3
15 15 import os import os
16 16 from model_util import get_lr from model_util import get_lr
17 17
 
... ... if __name__ == "__main__":
51 51 print("len train_loader ", len(train_loader)) print("len train_loader ", len(train_loader))
52 52
53 53 # model # model
54 model = CustomCNNv2()
54 model = CustomCNNv3()
55 55 model = model.to(device) model = model.to(device)
56 56
57 57 # loss function # loss function
 
... ... if __name__ == "__main__":
122 122 filename_prefix=args.task_id, filename_prefix=args.task_id,
123 123 n_saved=5) n_saved=5)
124 124
125 trainer.add_event_handler(Events.EPOCH_COMPLETED(every=3), save_handler)
125 trainer.add_event_handler(Events.EPOCH_COMPLETED(every=5), save_handler)
126 126
127 127 trainer.run(train_loader, max_epochs=args.epochs) trainer.run(train_loader, max_epochs=args.epochs)
File train_custom_compact_cnn_lrscheduler.py changed (mode: 100644) (index 943fa68..07d6c9c)
1 from comet_ml import Experiment
2
1 3 from args_util import my_args_parse from args_util import my_args_parse
2 4 from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list
3 5 from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
4 6 from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError
7 from ignite.engine import Engine
5 8 from ignite.handlers import Checkpoint, DiskSaver from ignite.handlers import Checkpoint, DiskSaver
6 9 from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError
7 10 from visualize_util import get_readable_time from visualize_util import get_readable_time
8 11
9 12 import torch import torch
10 13 from torch import nn from torch import nn
11 from models import CompactDilatedCNN
14 from models import CustomCNNv3
12 15 import os import os
13 16 from ignite.contrib.handlers import PiecewiseLinear from ignite.contrib.handlers import PiecewiseLinear
17 from model_util import get_lr
14 18
19 COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM"
20 PROJECT_NAME = "crowd-counting-framework"
15 21
16 22 if __name__ == "__main__": if __name__ == "__main__":
23 experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API)
24
17 25 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 26 print(device) print(device)
19 27 args = my_args_parse() args = my_args_parse()
28 experiment.set_name(args.task_id)
20 29 print(args) print(args)
30 experiment.set_cmd_args()
31
21 32 DATA_PATH = args.input DATA_PATH = args.input
22 33 TRAIN_PATH = os.path.join(DATA_PATH, "train_data") TRAIN_PATH = os.path.join(DATA_PATH, "train_data")
23 34 TEST_PATH = os.path.join(DATA_PATH, "test_data") TEST_PATH = os.path.join(DATA_PATH, "test_data")
 
... ... if __name__ == "__main__":
35 46 test_list = create_image_list(TEST_PATH) test_list = create_image_list(TEST_PATH)
36 47
37 48 # create data loader # create data loader
38 train_loader, val_loader, test_loader = get_dataloader(train_list, None, test_list, dataset_name=dataset_name)
49 train_loader, val_loader, test_loader = get_dataloader(train_list, None, test_list, dataset_name=dataset_name, batch_size=args.batch_size)
39 50
40 51 print("len train_loader ", len(train_loader)) print("len train_loader ", len(train_loader))
41 52
42 53 # model # model
43 model = CompactDilatedCNN()
54 model = CustomCNNv3()
44 55 model = model.to(device) model = model.to(device)
45 56
46 57 # loss function # loss function
 
... ... if __name__ == "__main__":
49 60 optimizer = torch.optim.Adam(model.parameters(), args.lr, optimizer = torch.optim.Adam(model.parameters(), args.lr,
50 61 weight_decay=args.decay) weight_decay=args.decay)
51 62
52 milestones_values = [(50, 1e-4), (50, 5e-5), (50, 1e-5), (50, 5e-6), (50, 1e-6), (100, 1e-7)]
63 milestones_values = [(70, 1e-4), (100, 1e-5), (200, 1e-5)]
64 experiment.log_parameter("milestones_values", str(milestones_values))
53 65 lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values) lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)
66
54 67 trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
55 68 evaluator = create_supervised_evaluator(model, evaluator = create_supervised_evaluator(model,
56 69 metrics={ metrics={
57 70 'mae': CrowdCountingMeanAbsoluteError(), 'mae': CrowdCountingMeanAbsoluteError(),
58 71 'mse': CrowdCountingMeanSquaredError(), 'mse': CrowdCountingMeanSquaredError(),
59 'nll': Loss(loss_fn)
72 'loss': Loss(loss_fn)
60 73 }, device=device) }, device=device)
61 74 print(model) print(model)
62 75
 
... ... if __name__ == "__main__":
74 87 print("change lr to ", args.lr) print("change lr to ", args.lr)
75 88 else: else:
76 89 print("do not load, keep training") print("do not load, keep training")
77
78 90 trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler) trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler)
79 91
80 92
 
... ... if __name__ == "__main__":
90 102 metrics = evaluator.state.metrics metrics = evaluator.state.metrics
91 103 timestamp = get_readable_time() timestamp = get_readable_time()
92 104 print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
93 .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll']))
94
105 .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss']))
106 experiment.log_metric("epoch", trainer.state.epoch)
107 experiment.log_metric("train_mae", metrics['mae'])
108 experiment.log_metric("train_mse", metrics['mse'])
109 experiment.log_metric("train_loss", metrics['loss'])
110 experiment.log_metric("lr", get_lr(optimizer))
95 111
96 112 @trainer.on(Events.EPOCH_COMPLETED) @trainer.on(Events.EPOCH_COMPLETED)
97 113 def log_validation_results(trainer): def log_validation_results(trainer):
 
... ... if __name__ == "__main__":
99 115 metrics = evaluator.state.metrics metrics = evaluator.state.metrics
100 116 timestamp = get_readable_time() timestamp = get_readable_time()
101 117 print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
102 .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll']))
103
118 .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss']))
119 experiment.log_metric("valid_mae", metrics['mae'])
120 experiment.log_metric("valid_mse", metrics['mse'])
121 experiment.log_metric("valid_loss", metrics['loss'])
104 122
105 123
106 124 # docs on save and load # docs on save and load
 
... ... if __name__ == "__main__":
109 127 filename_prefix=args.task_id, filename_prefix=args.task_id,
110 128 n_saved=5) n_saved=5)
111 129
112 trainer.add_event_handler(Events.EPOCH_COMPLETED(every=3), save_handler)
130 trainer.add_event_handler(Events.EPOCH_COMPLETED(every=5), save_handler)
113 131
114 132 trainer.run(train_loader, max_epochs=args.epochs) trainer.run(train_loader, max_epochs=args.epochs)
File train_script/CCNN_custom/custom_ccnn_v3_t1_scheduler_shb.sh added (mode: 100644) (index 0000000..559f17d)
1 CUDA_VISIBLE_DEVICES=2 HTTPS_PROXY="http://10.30.58.36:81" nohup python train_custom_compact_cnn_lrscheduler.py \
2 --task_id custom_ccnn_v3_t1_scheduler_shb \
3 --note "train custom ccnn v1 branching 3 branch" \
4 --input /data/rnd/thient/thient_data/ShanghaiTech/part_B \
5 --lr 1e-4 \
6 --decay 1e-4 \
7 --datasetname shanghaitech_keepfull \
8 --batch_size 5 \
9 --epochs 400 > logs/custom_ccnn_v3_t1_scheduler_shb.log &
File train_script/CCNN_custom/custom_ccnn_v3_t1_shb.sh copied from file train_script/CCNN_custom/custom_ccnn_v2_t1_shb.sh (similarity 57%) (mode: 100644) (index 873fb1c..58efe29)
1 CUDA_VISIBLE_DEVICES=6 HTTPS_PROXY="http://10.30.58.36:81" nohup python train_custom_compact_cnn.py \
2 --task_id custom_ccnn_v2_t1_shb \
1 CUDA_VISIBLE_DEVICES=2 HTTPS_PROXY="http://10.30.58.36:81" nohup python train_custom_compact_cnn.py \
2 --task_id custom_ccnn_v3_t1_shb \
3 3 --note "train custom ccnn v1 branching 3 branch" \ --note "train custom ccnn v1 branching 3 branch" \
4 4 --input /data/rnd/thient/thient_data/ShanghaiTech/part_B \ --input /data/rnd/thient/thient_data/ShanghaiTech/part_B \
5 5 --lr 1e-4 \ --lr 1e-4 \
6 --decay 0 \
6 --decay 1e-4 \
7 7 --datasetname shanghaitech_keepfull \ --datasetname shanghaitech_keepfull \
8 8 --batch_size 5 \ --batch_size 5 \
9 --epochs 502 > logs/custom_ccnn_v2_t1_shb.log &
9 --epochs 400 > logs/custom_ccnn_v3_t1_shb.log &
Hints:
Before your first commit, do not forget to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main