import torch.nn as nn import torch from torchvision import models class CSRNet(nn.Module): def __init__(self, load_weights=False): super(CSRNet, self).__init__() self.seen = 0 self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512] self.backend_feat = [512, 512, 512, 256, 128, 64] self.frontend = make_layers(self.frontend_feat) self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True, batch_norm=True) self.output_layer = nn.Conv2d(64, 1, kernel_size=1) if not load_weights: mod = models.vgg16(pretrained=True) self._initialize_weights() for i in range(len(list(self.frontend.state_dict().items()))): list(self.frontend.state_dict().items())[i][1].data[:] = list(mod.state_dict().items())[i][1].data[:] # freeze the pretrain vgg for param in self.frontend.parameters(): param.requires_grad = False def forward(self, x): x = self.frontend(x) x = self.backend(x) x = self.output_layer(x) # remove channel dimension # (N, C_{out}, H_{out}, W_{out}) => (N, H_{out}, W_{out}) x = torch.squeeze(x, dim=1) return x def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.normal_(m.weight, std=0.01) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False): if dilation: d_rate = 2 else: d_rate = 1 layers = [] for v in cfg: if v == 'M': layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate) if batch_norm: layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=False)] else: layers += [conv2d, nn.ReLU(inplace=False)] in_channels = v return nn.Sequential(*layers)
Mode | Type | Size | Ref | File |
---|---|---|---|---|
100644 | blob | 112 | 54a0bfa5d13ea1dd49622ed3704ad36f6cd68749 | .gitignore |
100644 | blob | 1342 | f2eb3073ff4a8536cf4e8104ff942b525e3c7f34 | .travis.yml |
100644 | blob | 1379 | ebf62a5eacea82b71a44bab3c917754f2de5cb10 | README.md |
100644 | blob | 9232 | d68650a3479daf7de65eb76b0cfa40e9e9dff7f4 | args_util.py |
040000 | tree | - | 5e9d7f0e1fd3a9e4d5a37f3d6de0c3ecd3125af8 | backup_notebook |
040000 | tree | - | 55d1d196f5b6ed4bfc1e8a715df1cfff1dd18117 | bug |
100644 | blob | 3591 | 7b4c18e8cf2c0417cd13d3f77ea0571c9e0e493f | crowd_counting_error_metrics.py |
100644 | blob | 63592 | 812e169de39d3f40f6425035db0deb5e4f878d7c | data_flow.py |
040000 | tree | - | 7b2560d2cb223bf0574eb278bafeda5a8577c7db | data_util |
040000 | tree | - | a7a950c8174433e2acc656a49ade044f054ac9f7 | dataset_script |
040000 | tree | - | b26c5fa1187778134a8801e61e1f28055c3624c1 | debug |
040000 | tree | - | 74e02cec26c0d98f846ab7ab573419265856500b | demo |
040000 | tree | - | 13debfeebc3df105633887f857e8b709318cf661 | demo_app |
040000 | tree | - | 9862b9cbc6e7a1d43565f12d85d9b17d1bf1814e | env_file |
100644 | blob | 4460 | 9b254c348a3453f4df2c3ccbf21fb175a16852de | eval_context_aware_network.py |
100644 | blob | 428 | 35cc7bfe48a4ed8dc56635fd3a6763612d8af771 | evaluator.py |
100644 | blob | 17422 | e896debd86b01578a0e5f6bf886fefdf3922a5bd | experiment_main.py |
100644 | blob | 8876 | 049432d6bde50245a4acba4e116d59605b5b6315 | experiment_meow_main.py |
100644 | blob | 1916 | 1d228fa4fa2887927db069f0c93c61a920279d1f | explore_model_summary.py |
100644 | blob | 2718 | b09b84e8b761137654ba6904669799c4866554b3 | hard_code_variable.py |
040000 | tree | - | b3aa858a157f5e1e22c00fdb6f9dd071f4c6c163 | local_train_script |
040000 | tree | - | 927d159228536a86499de8a294700f8599b8a60b | logs |
100644 | blob | 15300 | cb90faba0bd4a45f2606a1e60975ed05bfacdb07 | main_pacnn.py |
100644 | blob | 2760 | 3c2d5ba1c81ef2770ad216c566e268f4ece17262 | main_shanghaitech.py |
100644 | blob | 2683 | 29189260c1a2c03c8e59cd0b4bd61df19d5ce098 | main_ucfcc50.py |
100644 | blob | 2794 | f37b3bb572c53dd942c51243bd5b0853228c6ddb | model_util.py |
040000 | tree | - | 3ae76ede817d90ddfa6fe982440dfbbe193974a2 | models |
100644 | blob | 870 | 8f5ce4f7e0b168add5ff2a363faa973a5b56ca48 | mse_l1_loss.py |
100644 | blob | 1066 | 811554259182e63240d7aa9406f315377b3be1ac | mse_ssim_loss.py |
040000 | tree | - | 1a8318c65dcb1ddae26e7058904e3b8848933d19 | notebook |
040000 | tree | - | 06633cb1846e29f71faab849a74ac0896541c3c4 | playground |
040000 | tree | - | 072abdcb8a8ad064d60f8dc7daf480cf48b3ad06 | predict |
040000 | tree | - | c7c295e9e418154ae7c754dc888a77df8f50aa61 | pytorch_ssim |
100644 | blob | 1727 | 1cd14cbff636cb6145c8bacf013e97eb3f7ed578 | sanity_check_dataloader.py |
040000 | tree | - | a1e8ea43eba8a949288a00fff12974aec8692003 | saved_model_best |
100644 | blob | 3525 | 27067234ad3deddd743dcab0d7b3ba4812902656 | train_attn_can_adcrowdnet.py |
100644 | blob | 3488 | e47bfc7e91c46ca3c61be0c5258302de4730b06d | train_attn_can_adcrowdnet_freeze_vgg.py |
100644 | blob | 5352 | 3ee3269d6fcc7408901af46bed52b1d86ee9818c | train_attn_can_adcrowdnet_simple.py |
100644 | blob | 5728 | 90b846b68f15bdc58e3fd60b41aa4b5d82864ec4 | train_attn_can_adcrowdnet_simple_lrscheduler.py |
100644 | blob | 9081 | 664051f8838434c386e34e6dd6e6bca862cb3ccd | train_compact_cnn.py |
100644 | blob | 5702 | fdec7cd1ee062aa4a2182a91e2fb1bd0db3ab35f | train_compact_cnn_lrscheduler.py |
100644 | blob | 5611 | 2a241c876015db34681d73ce534221de482b0b90 | train_compact_cnn_sgd.py |
100644 | blob | 3525 | eb52f7a4462687c9b2bf1c3a887014c4afefa26d | train_context_aware_network.py |
100644 | blob | 5651 | 48631e36a1fdc063a6d54d9206d2fd45521d8dc8 | train_custom_compact_cnn.py |
100644 | blob | 5594 | 07d6c9c056db36082545b5b60b1c00d9d9f6396d | train_custom_compact_cnn_lrscheduler.py |
100644 | blob | 5281 | 8a92eb87b54f71ad2a799a7e05020344a22e22d3 | train_custom_compact_cnn_sgd.py |
040000 | tree | - | abb0a8c0267147fdba6e467bf8f7ed3024e44594 | train_script |
100644 | blob | 6595 | 5b8afd4fb322dd7cbffd1a589ff5276b0e3edeb5 | visualize_data_loader.py |
100644 | blob | 1772 | 449bb484143443c125566907a4b862d1c283c3f3 | visualize_util.py |