List of commits:
Subject Hash Author Date (UTC)
add perspective 642d6fff8c9f31e510fda85a7fb631fb855d8a6d Thai Thien 2019-10-06 16:54:44
fix padding with p 86c2fa07822d956a34b3b37e14da485a4249f01b Thai Thien 2019-10-06 02:52:58
pacnn perspective loss fb673e38a5f24ae9004fe2b7b93c88991e0c2304 Thai Thien 2019-10-06 01:38:28
data_flow shanghaitech_pacnn_with_perspective seem working 91d350a06f358e03223966297d124daee94123d0 Thai Thien 2019-10-06 01:31:11
multiscale loss and final loss only mode c65dd0e74ad28503821e5c8651a3b47b4a0c7c64 Thai Thien 2019-10-05 15:58:19
wip : perspective map eac63f2671dc5b064753acc4f40bf0f9f216ad2a Thai Thien 2019-10-04 16:26:56
shell script f2106e700b6f6174d4dd276f25ec6f3d9ff239bb thient 2019-10-04 07:42:51
WIP 42c7c8e1d772fbbda61a4bdf9e329f74e1efb600 tthien 2019-10-03 17:52:47
add readme 580cf43d1edddd67b1f6a2c57fdd5cee3dba925c Thai Thien 2019-10-02 17:44:49
update script, debug ddb68b95389be1c1d398118677dd227a8bb2b70b Thai Thien 2019-10-02 15:52:31
add d (output density map) to loss function) a0c71bf4bf2ab7393d60b06a84db8dfbbfb1a6c2 tthien 2019-09-30 16:32:39
fix the args, add save interval for model, so we don't save them all 9fdf9daa2ac4bd12b7b62521d81e520db0debd01 tthien 2019-09-30 16:30:00
meow 1ad19a22a310992e27a26471feeb37375124d075 tthien 2019-09-29 18:25:43
fix pacnn perspective map 453ece3ccb818889ba895bfc4285f7905d33cba5 Thai Thien 2019-09-25 17:20:33
apex not work so well da8c0dd57297f972201f31d57e66897177922f48 Thai Thien 2019-09-24 17:25:59
fix data loader pacnn so it will scale up with correct number of people 11d55b50d764511f2491291f0208fee0905dec49 Thai Thien 2019-09-24 15:40:56
add comet ml a9d4b89ce594f5e241168ccafdcdf0f150ea0ebb Thai Thien 2019-09-23 17:07:58
fix pacnn avg schema c2140a96886195782e5689c24aeeb4fe7a2db7ad Thai Thien 2019-09-22 17:35:01
debug number not divisible by 8 a568fd7f294a8bd31b3db78437b4b6b51b5b41b9 Thai Thien 2019-09-22 04:36:06
pacnn 967074890d14ab0eefc277801860270a468e8f9f Thai Thien 2019-09-22 03:54:48
Commit 642d6fff8c9f31e510fda85a7fb631fb855d8a6d - add perspective
Author: Thai Thien
Author date (UTC): 2019-10-06 16:54
Committer name: Thai Thien
Committer date (UTC): 2019-10-06 16:54
Parent(s): 86c2fa07822d956a34b3b37e14da485a4249f01b
Signing key:
Tree: 3816f82b3c01435472a8758e7e4322db1b3a6a6f
File Lines added Lines deleted
args_util.py 1 0
data_flow.py 4 2
main_pacnn.py 72 7
train_script/test_pacnn_shanghaitechA.sh 16 0
train_script/train_pacnn_shanghaitechA.sh 51 9
File args_util.py changed (mode: 100644) (index 3fa2c42..2e74c07)
... ... def real_args_parse():
72 72 parser.add_argument('--momentum', action="store", default=0.9, type=float) parser.add_argument('--momentum', action="store", default=0.9, type=float)
73 73 parser.add_argument('--decay', action="store", default=5*1e-3, type=float) parser.add_argument('--decay', action="store", default=5*1e-3, type=float)
74 74 parser.add_argument('--epochs', action="store", default=1, type=int) parser.add_argument('--epochs', action="store", default=1, type=int)
75 parser.add_argument('--test', action="store_true", default=False)
75 76
76 77 # pacnn setting only # pacnn setting only
77 78 parser.add_argument('--PACNN_PERSPECTIVE_AWARE_MODEL', action="store", default=0, type=int) parser.add_argument('--PACNN_PERSPECTIVE_AWARE_MODEL', action="store", default=0, type=int)
File data_flow.py changed (mode: 100644) (index 6faebf8..a3f3d3c)
... ... def load_data_shanghaitech_pacnn_with_perspective(img_path, train=True):
138 138 img = img.transpose(Image.FLIP_LEFT_RIGHT) img = img.transpose(Image.FLIP_LEFT_RIGHT)
139 139 perspective = np.fliplr(perspective) perspective = np.fliplr(perspective)
140 140
141 perspective /= np.max(perspective)
142
141 143 target1 = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)), target1 = cv2.resize(target, (int(target.shape[1] / 8), int(target.shape[0] / 8)),
142 144 interpolation=cv2.INTER_CUBIC) * 64 interpolation=cv2.INTER_CUBIC) * 64
143 145 target2 = cv2.resize(target, (int(target.shape[1] / 16), int(target.shape[0] / 16)), target2 = cv2.resize(target, (int(target.shape[1] / 16), int(target.shape[0] / 16)),
 
... ... def load_data_shanghaitech_pacnn_with_perspective(img_path, train=True):
146 148 interpolation=cv2.INTER_CUBIC) * 1024 interpolation=cv2.INTER_CUBIC) * 1024
147 149
148 150 perspective_s = cv2.resize(perspective, (int(perspective.shape[1] / 16), int(perspective.shape[0] / 16)), perspective_s = cv2.resize(perspective, (int(perspective.shape[1] / 16), int(perspective.shape[0] / 16)),
149 interpolation=cv2.INTER_CUBIC) * 256
151 interpolation=cv2.INTER_CUBIC)
150 152
151 153 perspective_p = cv2.resize(perspective, (int(perspective.shape[1] / 8), int(perspective.shape[0] / 8)), perspective_p = cv2.resize(perspective, (int(perspective.shape[1] / 8), int(perspective.shape[0] / 8)),
152 interpolation=cv2.INTER_CUBIC) * 64
154 interpolation=cv2.INTER_CUBIC)
153 155
154 156 return img, (target1, target2, target3, perspective_s, perspective_p) return img, (target1, target2, target3, perspective_s, perspective_p)
155 157
File main_pacnn.py changed (mode: 100644) (index 823774f..cb90fab)
... ... if __name__ == "__main__":
24 24
25 25 # Add the following code anywhere in your machine learning file # Add the following code anywhere in your machine learning file
26 26 experiment = Experiment(api_key="S3mM1eMq6NumMxk2QJAXASkUM", experiment = Experiment(api_key="S3mM1eMq6NumMxk2QJAXASkUM",
27 project_name="pacnn-dev2", workspace="ttpro1995", disabled=True)
27 project_name="pacnn-dev2", workspace="ttpro1995")
28 28
29 29 args = real_args_parse() args = real_args_parse()
30 device = "cpu"
31 30 print(device) print(device)
32 31 print(args) print(args)
33 32
34 33
35 34
36 35 MODEL_SAVE_NAME = args.task_id MODEL_SAVE_NAME = args.task_id
36 TEST = args.test
37 37 MODEL_SAVE_INTERVAL = 5 MODEL_SAVE_INTERVAL = 5
38 38 DATA_PATH = args.input DATA_PATH = args.input
39 39 TOTAL_EPOCH = args.epochs TOTAL_EPOCH = args.epochs
40 40 PACNN_PERSPECTIVE_AWARE_MODEL = args.PACNN_PERSPECTIVE_AWARE_MODEL PACNN_PERSPECTIVE_AWARE_MODEL = args.PACNN_PERSPECTIVE_AWARE_MODEL
41 41 PACNN_MUTILPLE_SCALE_LOSS = args.PACNN_MUTILPLE_SCALE_LOSS PACNN_MUTILPLE_SCALE_LOSS = args.PACNN_MUTILPLE_SCALE_LOSS
42 42 DATASET_NAME = "shanghaitech_pacnn" DATASET_NAME = "shanghaitech_pacnn"
43
43 44 if PACNN_PERSPECTIVE_AWARE_MODEL: if PACNN_PERSPECTIVE_AWARE_MODEL:
44 45 DATASET_NAME = "shanghaitech_pacnn_with_perspective" DATASET_NAME = "shanghaitech_pacnn_with_perspective"
45 46
 
... ... if __name__ == "__main__":
87 88 num_workers=4, dataset_name="shanghaitech_pacnn"), num_workers=4, dataset_name="shanghaitech_pacnn"),
88 89 batch_size=1, num_workers=4) batch_size=1, num_workers=4)
89 90
91 test_loader_pacnn = torch.utils.data.DataLoader(
92 ListDataset(test_list,
93 shuffle=False,
94 transform=transforms.Compose([
95 transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
96 std=[0.229, 0.224, 0.225]),
97 ]),
98 train=False,
99 batch_size=1,
100 num_workers=4, dataset_name="shanghaitech_pacnn"),
101 batch_size=1, num_workers=4)
102
90 103 # create model # create model
91 104 net = PACNNWithPerspectiveMap(perspective_aware_mode=PACNN_PERSPECTIVE_AWARE_MODEL).to(device) net = PACNNWithPerspectiveMap(perspective_aware_mode=PACNN_PERSPECTIVE_AWARE_MODEL).to(device)
92 105 criterion_mse = nn.MSELoss(size_average=False).to(device) criterion_mse = nn.MSELoss(size_average=False).to(device)
 
... ... if __name__ == "__main__":
112 125 else: else:
113 126 print("new model") print("new model")
114 127
128 if TEST:
129 print("test model")
130 mae_calculator_d1 = MAECalculator()
131 mae_calculator_d2 = MAECalculator()
132 mae_calculator_d3 = MAECalculator()
133 mae_calculator_final = MAECalculator()
134 with torch.no_grad():
135 for val_img, label in test_loader_pacnn:
136 net.eval()
137 # load data
138 d1_label, d2_label, d3_label = label
139
140 # forward pass
141 d1, d2, d3, p_s, p, d = net(val_img.to(device))
142
143 d1_label = d1_label.to(device)
144 d2_label = d2_label.to(device)
145 d3_label = d3_label.to(device)
146
147 # score
148 mae_calculator_d1.eval(d1.cpu().detach().numpy(), d1_label.cpu().detach().numpy())
149 mae_calculator_d2.eval(d2.cpu().detach().numpy(), d2_label.cpu().detach().numpy())
150 mae_calculator_d3.eval(d3.cpu().detach().numpy(), d3_label.cpu().detach().numpy())
151 mae_calculator_final.eval(d.cpu().detach().numpy(), d1_label.cpu().detach().numpy())
152 print("count ", mae_calculator_d1.count)
153 print("d1_val ", mae_calculator_d1.get_mae())
154 print("d2_val ", mae_calculator_d2.get_mae())
155 print("d3_val ", mae_calculator_d3.get_mae())
156 print("dfinal_val ", mae_calculator_final.get_mae())
157 experiment.log_metric("d1_val", mae_calculator_d1.get_mae())
158 experiment.log_metric("d2_val", mae_calculator_d2.get_mae())
159 experiment.log_metric("d3_val", mae_calculator_d3.get_mae())
160 experiment.log_metric("dfinal_val", mae_calculator_final.get_mae())
161 exit()
162
115 163 while current_epoch < TOTAL_EPOCH: while current_epoch < TOTAL_EPOCH:
116 164 experiment.log_current_epoch(current_epoch) experiment.log_current_epoch(current_epoch)
117 165 current_epoch += 1 current_epoch += 1
 
... ... if __name__ == "__main__":
153 201 # TODO: loss for perspective map here # TODO: loss for perspective map here
154 202 pad_p_0 = perspective_p.size()[2] - p.size()[2] pad_p_0 = perspective_p.size()[2] - p.size()[2]
155 203 pad_p_1 = perspective_p.size()[3] - p.size()[3] pad_p_1 = perspective_p.size()[3] - p.size()[3]
156 p_pad = F.pad(p, (0, pad_p_1, 0, pad_p_0), mode='replicate')
204 if pad_p_0 == 0:
205 pad_p_0 = -perspective_p.size()[2]
206 if pad_p_1 == 0:
207 pad_p_1 = -perspective_p.size()[3]
208
209 # p_pad = F.pad(p, (0, pad_p_1, 0, pad_p_0), mode='replicate')
210 perspective_p_pad = perspective_p[:,:, 0:-pad_p_0, 0:-pad_p_1]
157 211
158 loss_p = criterion_mse(p_pad, perspective_p) + criterion_ssim(p_pad, perspective_p)
212 # print(p.shape)
213 # print(perspective_p.shape)
214 # print(pad_p_0, pad_p_1)
215 # print(perspective_p_pad.shape)
216 loss_p = criterion_mse(p, perspective_p_pad) + criterion_ssim(p, perspective_p_pad)
159 217
160 218 loss += loss_p loss += loss_p
161 219 if PACNN_MUTILPLE_SCALE_LOSS: if PACNN_MUTILPLE_SCALE_LOSS:
162 220 pad_s_0 = perspective_s.size()[2] - p_s.size()[2] pad_s_0 = perspective_s.size()[2] - p_s.size()[2]
163 221 pad_s_1 = perspective_s.size()[3] - p_s.size()[3] pad_s_1 = perspective_s.size()[3] - p_s.size()[3]
164 p_s_pad = F.pad(perspective_s, (0, pad_s_1, 0, pad_s_0),
165 mode='replicate')
222 # p_s_pad = F.pad(perspective_s, (0, pad_s_1, 0, pad_s_0),
223 # mode='replicate')
224
225 if pad_s_0 == 0:
226 pad_s_0 = -perspective_s.size()[2]
227 if pad_s_1 == 0:
228 pad_s_1 = -perspective_s.size()[3]
229
230 perspective_s_pad = perspective_s[:,:, 0:-pad_s_0, 0:-pad_s_1]
166 231
167 loss_p_s = criterion_mse(p_s_pad, perspective_s) + criterion_ssim(p_s_pad, perspective_s)
232 loss_p_s = criterion_mse(p_s, perspective_s_pad) + criterion_ssim(p_s, perspective_s_pad)
168 233 loss += loss_p_s loss += loss_p_s
169 234
170 235 # what is this, loss_d count 2 ? # what is this, loss_d count 2 ?
File train_script/test_pacnn_shanghaitechA.sh added (mode: 100644) (index 0000000..b372569)
1 #python main_pacnn.py \
2 #--input data/ShanghaiTech/part_A \
3 #--load_model saved_model/train_state2_attemp4_265_checkpoint.pth.tar \
4 #--PACNN_PERSPECTIVE_AWARE_MODEL 0 \
5 #--PACNN_MUTILPLE_SCALE_LOSS 0 \
6 #--test \
7 #--task_id test
8
9 python main_pacnn.py \
10 --input data/ShanghaiTech/part_A \
11 --load_model saved_model/train_state1_attemp7_180_checkpoint.pth.tar \
12 --PACNN_PERSPECTIVE_AWARE_MODEL 0 \
13 --PACNN_MUTILPLE_SCALE_LOSS 0 \
14 --test \
15 --task_id test
16
File train_script/train_pacnn_shanghaitechA.sh changed (mode: 100644) (index a405a96..ae75a9f)
58 58
59 59 #################### ####################
60 60
61 python main_pacnn.py \
62 --input data/ShanghaiTech/part_A \
63 --load_model saved_model/train_state1_attemp7_180_checkpoint.pth.tar \
64 --epochs 500 \
65 --lr 1e-9 \
66 --PACNN_PERSPECTIVE_AWARE_MODEL 1 \
67 --PACNN_MUTILPLE_SCALE_LOSS 1 \
68 --task_id train_state2_attemp1
61 #python main_pacnn.py \
62 #--input data/ShanghaiTech/part_A \
63 #--load_model saved_model/train_state1_attemp7_180_checkpoint.pth.tar \
64 #--epochs 500 \
65 #--lr 1e-9 \
66 #--PACNN_PERSPECTIVE_AWARE_MODEL 1 \
67 #--PACNN_MUTILPLE_SCALE_LOSS 1 \
68 #--task_id train_state2_attemp1
69
70
71 #python main_pacnn.py \
72 #--input data/ShanghaiTech/part_A \
73 #--load_model saved_model/train_state2_attemp1_185_checkpoint.pth.tar \
74 #--epochs 500 \
75 #--lr 3e-10 \
76 #--momentum 0.7 \
77 #--PACNN_PERSPECTIVE_AWARE_MODEL 1 \
78 #--PACNN_MUTILPLE_SCALE_LOSS 1 \
79 #--task_id train_state2_attemp2
80
69 81
82 # momentum
70 83
71 84 #--input data/ShanghaiTech/part_A \ #--input data/ShanghaiTech/part_A \
72 85 #--load_model saved_model/train_state1_attemp7_180_checkpoint.pth.tar #--load_model saved_model/train_state1_attemp7_180_checkpoint.pth.tar
 
... ... python main_pacnn.py \
74 87 #--lr 1e-9 #--lr 1e-9
75 88 #--PACNN_PERSPECTIVE_AWARE_MODEL 1 #--PACNN_PERSPECTIVE_AWARE_MODEL 1
76 89 #--PACNN_MUTILPLE_SCALE_LOSS 1 #--PACNN_MUTILPLE_SCALE_LOSS 1
77 #--task_id dev
90 #--task_id dev
91
92
93 #python main_pacnn.py \
94 #--input data/ShanghaiTech/part_A \
95 #--load_model saved_model/train_state2_attemp2_220_checkpoint.pth.tar \
96 #--epochs 500 \
97 #--lr 3e-10 \
98 #--momentum 0.7 \
99 #--PACNN_PERSPECTIVE_AWARE_MODEL 1 \
100 #--PACNN_MUTILPLE_SCALE_LOSS 0 \
101 #--task_id train_state2_attemp3
102
103 #python main_pacnn.py \
104 #--input data/ShanghaiTech/part_A \
105 #--load_model saved_model/train_state2_attemp3_240_checkpoint.pth.tar \
106 #--epochs 600 \
107 #--lr 1e-9 \
108 #--momentum 0.9 \
109 #--PACNN_PERSPECTIVE_AWARE_MODEL 1 \
110 #--PACNN_MUTILPLE_SCALE_LOSS 0 \
111 #--task_id train_state2_attemp4
112
113 python main_pacnn.py \
114 --input data/ShanghaiTech/part_A \
115 --load_model saved_model/train_state2_attemp4_265_checkpoint.pth.tar \
116 --PACNN_PERSPECTIVE_AWARE_MODEL 1 \
117 --PACNN_MUTILPLE_SCALE_LOSS 0 \
118 --test \
119 --task_id test
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main