List of commits:
Subject Hash Author Date (UTC)
my simple v4 with addition deform cnn at the end 5392aaf6c14fdd910f52096dbb921bed7470c4f7 Thai Thien 2020-03-07 18:15:22
fix the scheduler 77e6737a040f5aa5745b8a8830f5bec12322b10f Thai Thien 2020-03-07 17:46:02
t4 lr 1e-5 acd41ed30c95f63e01a05a6d9929410637852d9e Thai Thien 2020-03-06 19:41:49
no more lr scheduler 7289adb41de7807258eb8c29e6108fa65f59525a Thai Thien 2020-03-06 19:35:49
reduce learning rate bc8241e5b88b91c18bb7999a8d5d12fc79a5e3f7 Thai Thien 2020-03-06 19:28:27
dilated ccnn 5c5d92bdc0a288dd5d4ec5f1367d8cb928175bbe Thai Thien 2020-03-06 19:04:01
done 9f05e093ec7c10284a4aedf0738f9e61d5ac6fb6 Thai Thien 2020-03-06 18:02:34
with lr scheduler 466c364b60ed22c77319b14ccc9a201614b908bf Thai Thien 2020-03-04 17:57:49
train with learning rate scheduler fcd5a3c8da2dd6763e0d40742edf47b49c95fcfb Thai Thien 2020-03-04 17:55:11
ccnn no padding at output layer 57563fc07f656c63f807de4d80712ff11345109d Thai Thien 2020-03-04 17:36:43
fix dimension of ccnn f4439d9a78273ab3ba450f31a528509816b4352f Thai Thien 2020-03-04 17:32:48
ready to train dbe0d6c3271dbb22490f0877fa31ba9cd7852b99 Thai Thien 2020-03-04 15:55:05
done implement c-cnn 2deecef953baf1e07ce5cf5477d208bc7ffa34cf Thai Thien 2020-03-03 17:25:03
fix script 0e9d372b9ad60b32939f1e558b2a59fc7d518fa2 Thai Thien 2020-03-02 16:23:55
simple v3 to 91 epoch 539fdd03c3e3497fd22b7db2aaa14f067cbf6f8d Thai Thien 2020-03-02 16:09:43
we train on all training data and validate on test data 9407ef8d5b7c47c53d6f98dcb3c20208aad1d7a9 Thai Thien 2020-03-01 15:36:46
load and continue train v3 12421fb7330e5c9d2eed4f6e574dfe69bdfddefc Thai Thien 2020-03-01 14:50:01
add env file b1ed02088b01af42efc8d6963b3699e0a5c31c01 Thai Thien 2020-03-01 11:10:56
sanity check dataloader 034daa8bef69daff92891cc42b988a6c77b010f9 Thai Thien 2020-03-01 10:24:38
print train loader len eef3995f63a631e0ec5d92e31f5d7db27fd04401 Thai Thien 2020-03-01 05:17:55
Commit 5392aaf6c14fdd910f52096dbb921bed7470c4f7 - my simple v4 with addition deform cnn at the end
Author: Thai Thien
Author date (UTC): 2020-03-07 18:15
Committer name: Thai Thien
Committer date (UTC): 2020-03-07 18:15
Parent(s): 77e6737a040f5aa5745b8a8830f5bec12322b10f
Signing key:
Tree: ad8e31ebda24e28d02998e77d2a127a55768d127
File Lines added Lines deleted
models/__init__.py 1 1
models/attn_can_adcrowdnet_simple.py 132 0
train_attn_can_adcrowdnet_simple.py 2 2
train_script/my_simple1/my_simple_v4_t1.sh 2 2
File models/__init__.py changed (mode: 100644) (index f1cf370..e02e0ec)
... ... from .context_aware_network import CANNet
4 4 from .deform_conv_v2 import DeformConv2d from .deform_conv_v2 import DeformConv2d
5 5 from .attn_can_adcrowdnet import AttnCanAdcrowdNet from .attn_can_adcrowdnet import AttnCanAdcrowdNet
6 6 from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg from .attn_can_adcrowdnet_freeze_vgg import AttnCanAdcrowdNetFreezeVgg
7 from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3
7 from .attn_can_adcrowdnet_simple import AttnCanAdcrowdNetSimpleV1, AttnCanAdcrowdNetSimpleV2, AttnCanAdcrowdNetSimpleV3, AttnCanAdcrowdNetSimpleV4
8 8 from .compact_cnn import CompactCNN, CompactDilatedCNN from .compact_cnn import CompactCNN, CompactDilatedCNN
File models/attn_can_adcrowdnet_simple.py changed (mode: 100644) (index 135c6ca..85b6f95)
... ... class AttnCanAdcrowdNetSimpleV3(nn.Module):
392 392 nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
393 393
394 394
395 class AttnCanAdcrowdNetSimpleV4(nn.Module):
396 """
397 compare with v3: add 1 layer (1 branch) of deformable cnn before output layer
398 """
399 def __init__(self, load_weights=False):
400 super(AttnCanAdcrowdNetSimpleV4, self).__init__()
401 self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
402 self.frontend = make_layers(self.frontend_feat)
403
404 # freeze vgg layer
405 for param in self.frontend.parameters():
406 param.requires_grad = False
407
408 self.sSE = SpatialSELayer(num_channels=512)
409
410 self.concat_filter_layer = nn.Conv2d(1024, 512, kernel_size=3, padding=2, dilation=2)
411
412 # we skip one formation of deformconv
413 # self.deform_conv_1_3 = DeformConv2d(512, 256, kernel_size=3, stride=1, padding=1)
414 # self.deform_conv_1_5 = DeformConv2d(512, 256, kernel_size=5, stride=1, padding=2)
415 # self.deform_conv_1_7 = DeformConv2d(512, 256, kernel_size=7, stride=1, padding=3)
416 self.concat_filter_layer_1 = nn.Conv2d(512, 256, kernel_size=3, padding=2, dilation=2)
417
418 self.dilated_conv_2_3 = nn.Conv2d(256, 128, kernel_size=3, stride=1, dilation=2, padding=2)
419 self.dilated_conv_2_5 = nn.Conv2d(256, 128, kernel_size=3, stride=1, dilation=4, padding=4)
420 # self.deform_conv_2_7 = DeformConv2d(256, 128, kernel_size=7, stride=1, padding=3)
421 self.concat_filter_layer_2 = nn.Conv2d(128 * 2, 128, kernel_size=3, padding=2, dilation=2)
422
423 self.deform_conv_3_3 = DeformConv2d(128, 64, kernel_size=3, stride=1, padding=1)
424 self.deform_conv_3_5 = DeformConv2d(128, 64, kernel_size=5, stride=1, padding=2)
425 # self.deform_conv_3_7 = DeformConv2d(128, 64, kernel_size=7, stride=1, padding=3)
426 self.concat_filter_layer_3 = nn.Conv2d(64 * 2, 64, kernel_size=3, padding=2, dilation=2)
427
428 self.deform_conv_4_3 = DeformConv2d(64, 32, kernel_size=3, stride=1, padding=1)
429
430 self.output_layer = nn.Conv2d(32, 1, kernel_size=1)
431 self.conv1_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
432 self.conv1_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
433 self.conv2_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
434 self.conv2_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
435 self.conv3_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
436 self.conv3_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
437 self.conv6_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
438 self.conv6_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
439 if not load_weights:
440 mod = models.vgg16(pretrained=True)
441 self._initialize_weights()
442 fsd = collections.OrderedDict()
443 for i in range(len(self.frontend.state_dict().items())):
444 temp_key = list(self.frontend.state_dict().items())[i][0]
445 fsd[temp_key] = list(mod.state_dict().items())[i][1]
446 self.frontend.load_state_dict(fsd)
447
448 def forward(self, x):
449 fv = self.frontend(x)
450
451 # spatial squeeze & excitation
452 fv = self.sSE(fv)
453
454 # S=1
455 ave1 = nn.functional.adaptive_avg_pool2d(fv, (1, 1))
456 ave1 = self.conv1_1(ave1)
457 s1 = nn.functional.upsample(ave1, size=(fv.shape[2], fv.shape[3]), mode='bilinear')
458 c1 = s1 - fv
459 w1 = self.conv1_2(c1)
460 w1 = nn.functional.sigmoid(w1)
461 # S=2
462 ave2 = nn.functional.adaptive_avg_pool2d(fv, (2, 2))
463 ave2 = self.conv2_1(ave2)
464 s2 = nn.functional.upsample(ave2, size=(fv.shape[2], fv.shape[3]), mode='bilinear')
465 c2 = s2 - fv
466 w2 = self.conv2_2(c2)
467 w2 = nn.functional.sigmoid(w2)
468 # S=3
469 ave3 = nn.functional.adaptive_avg_pool2d(fv, (3, 3))
470 ave3 = self.conv3_1(ave3)
471 s3 = nn.functional.upsample(ave3, size=(fv.shape[2], fv.shape[3]), mode='bilinear')
472 c3 = s3 - fv
473 w3 = self.conv3_2(c3)
474 w3 = nn.functional.sigmoid(w3)
475 # S=6
476 ave6 = nn.functional.adaptive_avg_pool2d(fv, (6, 6))
477 ave6 = self.conv6_1(ave6)
478 s6 = nn.functional.upsample(ave6, size=(fv.shape[2], fv.shape[3]), mode='bilinear')
479 c6 = s6 - fv
480 w6 = self.conv6_2(c6)
481 w6 = nn.functional.sigmoid(w6)
482
483 fi = (w1 * s1 + w2 * s2 + w3 * s3 + w6 * s6) / (w1 + w2 + w3 + w6 + 0.000000000001)
484 x = torch.cat((fv, fi), 1)
485 x = F.relu(self.concat_filter_layer(x), inplace=True)
486
487 # x3 = self.deform_conv_1_3(x)
488 # x5 = self.deform_conv_1_5(x)
489 # x7 = self.deform_conv_1_7(x)
490 # x = torch.cat((x3, x5, x7), 1)
491 # x = torch.cat((x3, x5), 1)
492 x = F.relu(self.concat_filter_layer_1(x), inplace=True)
493
494 x3 = self.dilated_conv_2_3(x)
495 x5 = self.dilated_conv_2_5(x)
496 # x7 = self.deform_conv_2_7(x)
497 # x = torch.cat((x3, x5, x7), 1)
498 x = F.relu(torch.cat((x3, x5), 1), inplace=True)
499 x = F.relu(self.concat_filter_layer_2(x), inplace=True)
500
501 x3 = self.deform_conv_3_3(x)
502 x5 = self.deform_conv_3_5(x)
503 # x7 = self.deform_conv_3_7(x)
504 # x = torch.cat((x3, x5, x7), 1)
505 x = F.relu(torch.cat((x3, x5), 1), inplace=True)
506 x = F.relu(self.concat_filter_layer_3(x), inplace=True)
507
508 x = F.relu(self.deform_conv_4_3(x))
509 x = self.output_layer(x)
510
511 # this cause too much dimension mismatch problem
512 # so we desampling label instead
513 # x = nn.functional.upsample(x, scale_factor=8, mode='bilinear') / 64.0
514 return x
515
516 def _initialize_weights(self):
517 for m in self.modules():
518 if isinstance(m, nn.Conv2d):
519 nn.init.normal_(m.weight, std=0.01)
520 if m.bias is not None:
521 nn.init.constant_(m.bias, 0)
522 elif isinstance(m, nn.BatchNorm2d):
523 nn.init.constant_(m.weight, 1)
524 nn.init.constant_(m.bias, 0)
525
526
395 527 def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False): def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
396 528 if dilation: if dilation:
397 529 d_rate = 2 d_rate = 2
File train_attn_can_adcrowdnet_simple.py changed (mode: 100644) (index 4096800..31334b0)
... ... from visualize_util import get_readable_time
9 9
10 10 import torch import torch
11 11 from torch import nn from torch import nn
12 from models import AttnCanAdcrowdNetSimpleV3
12 from models import AttnCanAdcrowdNetSimpleV4
13 13 import os import os
14 14
15 15
 
... ... if __name__ == "__main__":
40 40 print("len train_loader ", len(train_loader)) print("len train_loader ", len(train_loader))
41 41
42 42 # model # model
43 model = AttnCanAdcrowdNetSimpleV3()
43 model = AttnCanAdcrowdNetSimpleV4()
44 44 model = model.to(device) model = model.to(device)
45 45
46 46 # loss function # loss function
File train_script/my_simple1/my_simple_v4_t1.sh copied from file train_script/my_simple1/my_simplev3_noload_t1.sh (similarity 75%) (mode: 100644) (index fe6e65a..423cead)
1 1 CUDA_VISIBLE_DEVICES=4 nohup python train_attn_can_adcrowdnet_simple.py \ CUDA_VISIBLE_DEVICES=4 nohup python train_attn_can_adcrowdnet_simple.py \
2 --task_id simple_v3_t1 \
2 --task_id simple_v4_t1 \
3 3 --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \
4 4 --lr 1e-5 \ --lr 1e-5 \
5 5 --decay 5e-4 \ --decay 5e-4 \
6 6 --datasetname shanghaitech_keepfull \ --datasetname shanghaitech_keepfull \
7 --epochs 22 > logs/simple_v3_t1.log &
7 --epochs 32 > logs/simple_v4_t1.log &
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main