List of commits:
Subject Hash Author Date (UTC)
setting up code 6e83edd7283bfeec6ec1a6e5057719e8ae95e8ea Thai Thien 2020-04-02 10:25:18
fix typo f7565a25dc7058d1f344cc0d6e610fa2bf8f19e9 Thai Thien 2020-04-01 10:58:38
t4 aug 36b5bbc17111dca87cfa082f5a512b84576c8fca Thai Thien 2020-04-01 10:55:12
reproduce ccnn without batchnorm e4678b328ba9d00b9d94617a1cd0f34732f7b2ac Thai Thien 2020-03-30 16:48:35
change gpu number f4c616536d13d4caf383166caae0d8298cb73738 Thai Thien 2020-03-28 04:17:37
train custom ccnn_v2 8b6fca9cd3cb15111befd7332bc7e764d5494a66 Thai Thien 2020-03-28 03:51:07
v5 24e4a8a9c8948c21b0d6beadec0418adceec8577 Thai Thien 2020-03-20 17:56:32
hope it fix cuda bug 7b8cbb70550df279035ef56a22450d17fed6b29a Thai Thien 2020-03-20 17:55:24
ccnnv4 sha a4acad0caf5a18fc9ea037a5e3d6db9d6b6106c7 Thai Thien 2020-03-20 17:38:30
change gpu fa2c1c54858d8a57585c3801dc168c464ba689dd Thai Thien 2020-03-18 17:31:38
fix typo again d71ce17c4efbcc2cd31917f0e27e90bd964ac964 Thai Thien 2020-03-18 17:30:34
typo again c16871fb2cc035e168d62df9c5d4401bc465df95 Thai Thien 2020-03-18 17:27:45
fix minor typo 01d907dda5c1e4c08b5ca5607aa77365cfd70a2e Thai Thien 2020-03-18 17:25:32
forgot to change the model cb8bde150d7e764f67f33c980759675a2f5cce8e Thai Thien 2020-03-18 17:23:51
log summary for adcrowdnet simple lrscheduler cefb188f185dbba3dd95f5034f17bd43dd194d44 Thai Thien 2020-03-18 17:18:36
prepare exp 4658d0a852d952f1e1f600d9c494b8d511b9d52c Thai Thien 2020-03-18 17:17:05
CompactCNNV6 with a bunch of batchnorm a9f21eac8ab6a328fac7a23bb5eddbbaa4496d04 Thai Thien 2020-03-18 17:09:26
implement attn_can_adcrowdnet_v5 c905bd76d3a5d3b933d1e6eab6e50da187cee3a2 Thai Thien 2020-03-18 16:56:02
forget add --momentum 58131da12819203ec13ea62e7ad6862cf6db301d thient 2020-03-18 03:41:30
shell script ebb46f6853a666ecdfbaf42bb927cad0ec9e548e thient 2020-03-18 03:38:56
Commit 6e83edd7283bfeec6ec1a6e5057719e8ae95e8ea - setting up code
Author: Thai Thien
Author date (UTC): 2020-04-02 10:25
Committer name: Thai Thien
Committer date (UTC): 2020-04-02 10:25
Parent(s): f7565a25dc7058d1f344cc0d6e610fa2bf8f19e9
Signer:
Signing key:
Signing status: N
Tree: f7ac71a6644180c8a54c15ce2abb5c8769ce5cf0
File Lines added Lines deleted
args_util.py 21 0
experiment_meow_main.py 13 5
models/meow_experiment/__init__.py 0 0
models/meow_experiment/kitten_meow_1.py 119 0
File args_util.py changed (mode: 100644) (index bbe5616..5f8c6f4)
... ... def my_args_parse():
100 100 return arg return arg
101 101
102 102
103 def meow_parse():
104 parser = argparse.ArgumentParser(description='CrowdCounting Context Aware Network')
105 parser.add_argument("--task_id", action="store", default="dev")
106 parser.add_argument("--model", action="store", default="dev")
107 parser.add_argument('--note', action="store", default="write anything")
108
109 parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A)
110 parser.add_argument('--datasetname', action="store", default="shanghaitech_keepfull")
111
112 # args with default value
113 parser.add_argument('--load_model', action="store", default="", type=str)
114 parser.add_argument('--lr', action="store", default=1e-8, type=float)
115 parser.add_argument('--momentum', action="store", default=0.9, type=float)
116 parser.add_argument('--decay', action="store", default=5*1e-3, type=float)
117 parser.add_argument('--epochs', action="store", default=1, type=int)
118 parser.add_argument('--batch_size', action="store", default=1, type=int,
119 help="only set batch_size > 0 for dataset with image size equal")
120 parser.add_argument('--test', action="store_true", default=False)
121 arg = parser.parse_args()
122 return arg
123
103 124 def sanity_check_dataloader_parse(): def sanity_check_dataloader_parse():
104 125 parser = argparse.ArgumentParser(description='Dataloader') parser = argparse.ArgumentParser(description='Dataloader')
105 126 parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A) parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A)
File experiment_meow_main.py copied from file train_custom_compact_cnn.py (similarity 92%) (mode: 100644) (index 07a23b0..7b76dc2)
1 1 from comet_ml import Experiment from comet_ml import Experiment
2 2
3 from args_util import my_args_parse
3 from args_util import meow_parse
4 4 from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list from data_flow import get_train_val_list, get_dataloader, create_training_image_list, create_image_list
5 5 from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
6 6 from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError
 
... ... from visualize_util import get_readable_time
11 11
12 12 import torch import torch
13 13 from torch import nn from torch import nn
14 from models import CustomCNNv2
14 from models.meow_experiment.kitten_meow_1 import M1, M2
15 15 import os import os
16 16 from model_util import get_lr from model_util import get_lr
17 17
18 18 COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM" COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM"
19 PROJECT_NAME = "crowd-counting-framework"
19 PROJECT_NAME = "meow-one-experiment-insita"
20 20
21 21 def very_simple_param_count(model): def very_simple_param_count(model):
22 22 result = sum([p.numel() for p in model.parameters()]) result = sum([p.numel() for p in model.parameters()])
 
... ... if __name__ == "__main__":
26 26 experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API) experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API)
27 27 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 28 print(device) print(device)
29 args = my_args_parse()
29 args = meow_parse()
30 30 print(args) print(args)
31 31
32 32 experiment.set_name(args.task_id) experiment.set_name(args.task_id)
 
... ... if __name__ == "__main__":
55 55 print("len train_loader ", len(train_loader)) print("len train_loader ", len(train_loader))
56 56
57 57 # model # model
58 model = CustomCNNv2()
58 model_name = args.model
59 experiment.log_other("model", model_name)
60 if model_name == "M1":
61 model = M1()
62 elif model_name == "M2":
63 model = M2()
64 else:
65 print("error: you didn't pick a model")
66 exit(-1)
59 67 n_param = very_simple_param_count(model) n_param = very_simple_param_count(model)
60 68 experiment.log_other("n_param", n_param) experiment.log_other("n_param", n_param)
61 69 if hasattr(model, 'model_note'): if hasattr(model, 'model_note'):
File models/meow_experiment/__init__.py copied from file playground/__init__.py (similarity 100%)
File models/meow_experiment/kitten_meow_1.py added (mode: 100644) (index 0000000..055ce5f)
1 import torch.nn as nn
2 import torch
3 from torchvision import models
4 from models.deform_conv_v2 import DeformConv2d, TorchVisionBasicDeformConv2d
5 import collections
6 import torch.nn.functional as F
7
8
9 class M1(nn.Module):
10 """
11 A REAL-TIME DEEP NETWORK FOR CROWD COUNTING
12 https://arxiv.org/pdf/2002.06515.pdf
13 the improve version
14
15 we change 5x5 7x7 9x9 with 3x3
16 Keep the tail
17 """
18 def __init__(self, load_weights=False):
19 super(M1, self).__init__()
20 self.model_note = "We replace 5x5 7x7 9x9 with 3x3, no batchnorm yet, keep tail, no dilated"
21 # self.red_cnn = nn.Conv2d(3, 10, 9, padding=4)
22 # self.green_cnn = nn.Conv2d(3, 14, 7, padding=3)
23 # self.blue_cnn = nn.Conv2d(3, 16, 5, padding=2)
24
25 # ideal from crowd counting using DMCNN
26 self.front_cnn_1 = nn.Conv2d(3, 20, 3, padding=1)
27 self.front_cnn_2 = nn.Conv2d(20, 16, 3, padding=1)
28 self.front_cnn_3 = nn.Conv2d(16, 14, 3, padding=1)
29 self.front_cnn_4 = nn.Conv2d(14, 10, 3, padding=1)
30
31 self.c0 = nn.Conv2d(40, 40, 3, padding=1)
32 self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
33
34 self.c1 = nn.Conv2d(40, 60, 3, padding=1)
35
36 # ideal from CSRNet
37 self.c2 = nn.Conv2d(60, 40, 3, padding=1)
38 self.c3 = nn.Conv2d(40, 20, 3, padding=1)
39 self.c4 = nn.Conv2d(20, 10, 3, padding=1)
40 self.output = nn.Conv2d(10, 1, 1)
41
42 def forward(self,x):
43 #x_red = self.max_pooling(F.relu(self.red_cnn(x), inplace=True))
44 #x_green = self.max_pooling(F.relu(self.green_cnn(x), inplace=True))
45 #x_blue = self.max_pooling(F.relu(self.blue_cnn(x), inplace=True))
46
47 x_red = F.relu(self.front_cnn_1(x), inplace=True)
48 x_red = F.relu(self.front_cnn_2(x_red), inplace=True)
49 x_red = F.relu(self.front_cnn_3(x_red), inplace=True)
50 x_red = F.relu(self.front_cnn_4(x_red), inplace=True)
51 x_red = self.max_pooling(x_red)
52
53 x_green = F.relu(self.front_cnn_1(x), inplace=True)
54 x_green = F.relu(self.front_cnn_2(x_green), inplace=True)
55 x_green = F.relu(self.front_cnn_3(x_green), inplace=True)
56 x_green = self.max_pooling(x_green)
57
58 x_blue = F.relu(self.front_cnn_1(x), inplace=True)
59 x_blue = F.relu(self.front_cnn_2(x_blue), inplace=True)
60 x_blue = self.max_pooling(x_blue)
61
62 x = torch.cat((x_red, x_green, x_blue), 1)
63 x = F.relu(self.c0(x), inplace=True)
64
65 x = F.relu(self.c1(x), inplace=True)
66
67 x = F.relu(self.c2(x), inplace=True)
68 x = self.max_pooling(x)
69
70 x = F.relu(self.c3(x), inplace=True)
71 x = self.max_pooling(x)
72
73 x = F.relu(self.c4(x), inplace=True)
74
75 x = self.output(x)
76 return x
77
78
79 class M2(nn.Module):
80 """
81 A REAL-TIME DEEP NETWORK FOR CROWD COUNTING
82 https://arxiv.org/pdf/2002.06515.pdf
83 """
84 def __init__(self, load_weights=False):
85 super(M2, self).__init__()
86 self.model_note = "No batchnorm, keep head, but dilated tail"
87 self.red_cnn = nn.Conv2d(3, 10, 9, padding=4)
88 self.green_cnn = nn.Conv2d(3, 14, 7, padding=3)
89 self.blue_cnn = nn.Conv2d(3, 16, 5, padding=2)
90 self.c0 = nn.Conv2d(40, 40, 3, padding=1)
91
92 self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
93
94 self.c1 = nn.Conv2d(40, 60, 3, padding=2, dilation=2)
95 self.c2 = nn.Conv2d(60, 40, 3, padding=2, dilation=2)
96 self.c3 = nn.Conv2d(40, 20, 3, padding=2, dilation=2)
97 self.c4 = nn.Conv2d(20, 10, 3, padding=2, dilation=2)
98 self.output = nn.Conv2d(10, 1, 1)
99
100 def forward(self,x):
101 x_red = self.max_pooling(F.relu(self.red_cnn(x), inplace=True))
102 x_green = self.max_pooling(F.relu(self.green_cnn(x), inplace=True))
103 x_blue = self.max_pooling(F.relu(self.blue_cnn(x), inplace=True))
104
105 x = torch.cat((x_red, x_green, x_blue), 1)
106 x = F.relu(self.c0(x), inplace=True)
107
108 x = F.relu(self.c1(x), inplace=True)
109
110 x = F.relu(self.c2(x), inplace=True)
111 x = self.max_pooling(x)
112
113 x = F.relu(self.c3(x), inplace=True)
114 x = self.max_pooling(x)
115
116 x = F.relu(self.c4(x), inplace=True)
117
118 x = self.output(x)
119 return x
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main