/models/can_adcrowdnet.py (e022339d854a5e31c891397fa5742b32351b0f08) (5883 bytes) (mode 100644) (type blob)

import torch.nn as nn
import torch
from torchvision import models
import collections
import torch.nn.functional as F
import os
from .deform_conv_v2 import DeformConv2d
# from dcn.modules.deform_conv import DeformConvPack, ModulatedDeformConvPack


class CanAdcrowdNet(nn.Module):
    """Density-map regressor: VGG-16 frontend, context module, deformable backend.

    The frontend is built by ``make_layers`` from ``frontend_feat`` (ten 3x3
    convs and three 2x2 stride-2 max-pools), so its output has 1/8 of the input
    resolution. A context-aware (CAN-style) module blends pyramid-pooled
    features at scales 1, 2, 3 and 6 into the frontend features. Three stages of
    parallel 3/5/7 deformable convolutions then decode the result into a
    single-channel map. That map is bilinearly upsampled by 8 before it is
    returned.

    Args:
        load_weights: when ``False`` (default), every ``nn.Conv2d`` is randomly
            initialised and the frontend is then overwritten with the
            ImageNet-pretrained VGG-16 weights from torchvision. When ``True``,
            no initialisation happens here; presumably the caller loads a
            checkpoint afterwards (confirm against the training scripts).
    """

    def __init__(self, load_weights=False):
        super().__init__()
        # VGG-16 layout up to the third block of 512-channel convs (output stride 8).
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.frontend = make_layers(self.frontend_feat)
        # Fuses frontend features (512 ch) with context features (512 ch).
        self.concat_filter_layer = nn.Conv2d(1024, 512, kernel_size=3, padding=2, dilation=2)

        # Backend: each stage runs 3x3 / 5x5 / 7x7 deformable convs in parallel
        # (padding = k // 2). A dilated 3x3 conv then fuses their concatenation.
        # NOTE: attribute names and registration order are kept as-is, so
        # existing checkpoints and the order of _initialize_weights still match.
        self.deform_conv_1_3 = DeformConv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.deform_conv_1_5 = DeformConv2d(512, 256, kernel_size=5, stride=1, padding=2)
        self.deform_conv_1_7 = DeformConv2d(512, 256, kernel_size=7, stride=1, padding=3)
        self.concat_filter_layer_1 = nn.Conv2d(256 * 3, 256, kernel_size=3, padding=2, dilation=2)

        self.deform_conv_2_3 = DeformConv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.deform_conv_2_5 = DeformConv2d(256, 128, kernel_size=5, stride=1, padding=2)
        self.deform_conv_2_7 = DeformConv2d(256, 128, kernel_size=7, stride=1, padding=3)
        self.concat_filter_layer_2 = nn.Conv2d(128 * 3, 128, kernel_size=3, padding=2, dilation=2)

        self.deform_conv_3_3 = DeformConv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.deform_conv_3_5 = DeformConv2d(128, 64, kernel_size=5, stride=1, padding=2)
        self.deform_conv_3_7 = DeformConv2d(128, 64, kernel_size=7, stride=1, padding=3)
        self.concat_filter_layer_3 = nn.Conv2d(64 * 3, 64, kernel_size=3, padding=2, dilation=2)

        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)

        # Context module, one pair per pyramid scale S:
        #   convS_1 projects the S x S average-pooled features;
        #   convS_2 maps the contrast (s - fv) to blending weights.
        self.conv1_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv1_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv2_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv2_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv3_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv3_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv6_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv6_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)

        if not load_weights:
            # The VGG model is built before our own init, exactly as before, so
            # random-number consumption (and seeded runs) is unchanged.
            vgg = models.vgg16(pretrained=True)
            self._initialize_weights()
            # Copy parameters by position. The frontend mirrors the leading conv
            # layers of vgg16, so its i-th state entry matches the i-th VGG entry.
            # zip stops after the frontend's (shorter) list of entries.
            fsd = collections.OrderedDict(
                (key, value)
                for key, value in zip(self.frontend.state_dict().keys(),
                                      vgg.state_dict().values())
            )
            self.frontend.load_state_dict(fsd)

    @staticmethod
    def _scale_context(fv, size, pool_conv, weight_conv):
        """Compute the ``(s, w)`` pair for one pyramid scale of the context module.

        Args:
            fv: frontend feature map; indexed as ``(N, C, H, W)`` below.
            size: side length of the adaptive average pool (1, 2, 3 or 6 here).
            pool_conv: projects the pooled features (the ``convS_1`` layer).
            weight_conv: maps the contrast ``s - fv`` to weights (the ``convS_2`` layer).

        Returns:
            ``s``: pooled, projected features resized back to ``fv``'s H x W.
            ``w``: ``sigmoid(weight_conv(s - fv))``.
        """
        pooled = pool_conv(F.adaptive_avg_pool2d(fv, (size, size)))
        # F.upsample(mode='bilinear') used to fall back to align_corners=False.
        # Passing it explicitly keeps identical output without deprecation warnings.
        s = F.interpolate(pooled, size=(fv.shape[2], fv.shape[3]),
                          mode='bilinear', align_corners=False)
        w = torch.sigmoid(weight_conv(s - fv))
        return s, w

    @staticmethod
    def _multi_kernel_stage(x, branches, fuse):
        """Apply ``branches`` to ``x`` in parallel, concatenate on channels, fuse with ReLU."""
        x = torch.cat([branch(x) for branch in branches], 1)
        return F.relu(fuse(x))

    def forward(self, x):
        """Map an image batch to a density map at the input resolution (x8 upsampled)."""
        fv = self.frontend(x)

        # Context module: per-scale features s_j and weights w_j.
        contexts = [
            self._scale_context(fv, 1, self.conv1_1, self.conv1_2),
            self._scale_context(fv, 2, self.conv2_1, self.conv2_2),
            self._scale_context(fv, 3, self.conv3_1, self.conv3_2),
            self._scale_context(fv, 6, self.conv6_1, self.conv6_2),
        ]
        # fi = sum(w_j * s_j) / sum(w_j). The tiny epsilon only guards against a
        # zero denominator; the summation order matches the original expression.
        weighted_sum = sum(w * s for s, w in contexts)
        weight_total = sum(w for _, w in contexts) + 1e-12
        fi = weighted_sum / weight_total

        x = torch.cat((fv, fi), 1)
        x = F.relu(self.concat_filter_layer(x))

        x = self._multi_kernel_stage(
            x, (self.deform_conv_1_3, self.deform_conv_1_5, self.deform_conv_1_7),
            self.concat_filter_layer_1)
        x = self._multi_kernel_stage(
            x, (self.deform_conv_2_3, self.deform_conv_2_5, self.deform_conv_2_7),
            self.concat_filter_layer_2)
        x = self._multi_kernel_stage(
            x, (self.deform_conv_3_3, self.deform_conv_3_5, self.deform_conv_3_7),
            self.concat_filter_layer_3)

        x = self.output_layer(x)
        # Upsampling by 8 in each dimension multiplies the pixel count by 64.
        # Dividing by 64 keeps the map's total (the count) approximately unchanged.
        x = F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=False) / 64.0
        return x

    def _initialize_weights(self):
        """Initialise conv weights ~ N(0, 0.01) with zero bias; BatchNorm to identity.

        ``self.modules()`` recurses into every submodule. Any ``nn.Conv2d``
        nested inside ``DeformConv2d`` (if it has any) is therefore re-initialised
        here as well. NOTE(review): confirm that is intended for offset convs.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build a VGG-style ``nn.Sequential`` from a layer configuration list.

    Args:
        cfg: iterable of items. The string ``'M'`` adds a 2x2, stride-2
            max-pool. Any other item is the output-channel count of a 3x3 conv,
            followed by ``nn.ReLU(inplace=True)`` and preceded by
            ``nn.BatchNorm2d`` when ``batch_norm`` is set.
        in_channels: channel count fed to the first conv.
        batch_norm: insert ``nn.BatchNorm2d`` between each conv and its ReLU.
        dilation: if truthy, convs use dilation 2 and padding 2; otherwise
            dilation 1 and padding 1.

    Returns:
        ``nn.Sequential`` holding the layers in configuration order.
    """
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for item in cfg:
        if item == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, item, kernel_size=3, padding=rate, dilation=rate))
        if batch_norm:
            modules.append(nn.BatchNorm2d(item))
        modules.append(nn.ReLU(inplace=True))
        # The next conv consumes this one's output channels.
        channels = item
    return nn.Sequential(*modules)


Mode Type Size Ref File
100644 blob 112 54a0bfa5d13ea1dd49622ed3704ad36f6cd68749 .gitignore
100644 blob 1342 f2eb3073ff4a8536cf4e8104ff942b525e3c7f34 .travis.yml
100644 blob 1379 ebf62a5eacea82b71a44bab3c917754f2de5cb10 README.md
100644 blob 9232 d68650a3479daf7de65eb76b0cfa40e9e9dff7f4 args_util.py
040000 tree - 5e9d7f0e1fd3a9e4d5a37f3d6de0c3ecd3125af8 backup_notebook
040000 tree - 55d1d196f5b6ed4bfc1e8a715df1cfff1dd18117 bug
100644 blob 3591 7b4c18e8cf2c0417cd13d3f77ea0571c9e0e493f crowd_counting_error_metrics.py
100644 blob 63592 812e169de39d3f40f6425035db0deb5e4f878d7c data_flow.py
040000 tree - 7b2560d2cb223bf0574eb278bafeda5a8577c7db data_util
040000 tree - a7a950c8174433e2acc656a49ade044f054ac9f7 dataset_script
040000 tree - b26c5fa1187778134a8801e61e1f28055c3624c1 debug
040000 tree - 74e02cec26c0d98f846ab7ab573419265856500b demo
040000 tree - 13debfeebc3df105633887f857e8b709318cf661 demo_app
040000 tree - 9862b9cbc6e7a1d43565f12d85d9b17d1bf1814e env_file
100644 blob 4460 9b254c348a3453f4df2c3ccbf21fb175a16852de eval_context_aware_network.py
100644 blob 428 35cc7bfe48a4ed8dc56635fd3a6763612d8af771 evaluator.py
100644 blob 17422 e896debd86b01578a0e5f6bf886fefdf3922a5bd experiment_main.py
100644 blob 8876 049432d6bde50245a4acba4e116d59605b5b6315 experiment_meow_main.py
100644 blob 1916 1d228fa4fa2887927db069f0c93c61a920279d1f explore_model_summary.py
100644 blob 2718 b09b84e8b761137654ba6904669799c4866554b3 hard_code_variable.py
040000 tree - b3aa858a157f5e1e22c00fdb6f9dd071f4c6c163 local_train_script
040000 tree - 927d159228536a86499de8a294700f8599b8a60b logs
100644 blob 15300 cb90faba0bd4a45f2606a1e60975ed05bfacdb07 main_pacnn.py
100644 blob 2760 3c2d5ba1c81ef2770ad216c566e268f4ece17262 main_shanghaitech.py
100644 blob 2683 29189260c1a2c03c8e59cd0b4bd61df19d5ce098 main_ucfcc50.py
100644 blob 2794 f37b3bb572c53dd942c51243bd5b0853228c6ddb model_util.py
040000 tree - 3ae76ede817d90ddfa6fe982440dfbbe193974a2 models
100644 blob 870 8f5ce4f7e0b168add5ff2a363faa973a5b56ca48 mse_l1_loss.py
100644 blob 1066 811554259182e63240d7aa9406f315377b3be1ac mse_ssim_loss.py
040000 tree - 1a8318c65dcb1ddae26e7058904e3b8848933d19 notebook
040000 tree - 06633cb1846e29f71faab849a74ac0896541c3c4 playground
040000 tree - 072abdcb8a8ad064d60f8dc7daf480cf48b3ad06 predict
040000 tree - c7c295e9e418154ae7c754dc888a77df8f50aa61 pytorch_ssim
100644 blob 1727 1cd14cbff636cb6145c8bacf013e97eb3f7ed578 sanity_check_dataloader.py
040000 tree - a1e8ea43eba8a949288a00fff12974aec8692003 saved_model_best
100644 blob 3525 27067234ad3deddd743dcab0d7b3ba4812902656 train_attn_can_adcrowdnet.py
100644 blob 3488 e47bfc7e91c46ca3c61be0c5258302de4730b06d train_attn_can_adcrowdnet_freeze_vgg.py
100644 blob 5352 3ee3269d6fcc7408901af46bed52b1d86ee9818c train_attn_can_adcrowdnet_simple.py
100644 blob 5728 90b846b68f15bdc58e3fd60b41aa4b5d82864ec4 train_attn_can_adcrowdnet_simple_lrscheduler.py
100644 blob 9081 664051f8838434c386e34e6dd6e6bca862cb3ccd train_compact_cnn.py
100644 blob 5702 fdec7cd1ee062aa4a2182a91e2fb1bd0db3ab35f train_compact_cnn_lrscheduler.py
100644 blob 5611 2a241c876015db34681d73ce534221de482b0b90 train_compact_cnn_sgd.py
100644 blob 3525 eb52f7a4462687c9b2bf1c3a887014c4afefa26d train_context_aware_network.py
100644 blob 5651 48631e36a1fdc063a6d54d9206d2fd45521d8dc8 train_custom_compact_cnn.py
100644 blob 5594 07d6c9c056db36082545b5b60b1c00d9d9f6396d train_custom_compact_cnn_lrscheduler.py
100644 blob 5281 8a92eb87b54f71ad2a799a7e05020344a22e22d3 train_custom_compact_cnn_sgd.py
040000 tree - abb0a8c0267147fdba6e467bf8f7ed3024e44594 train_script
100644 blob 6595 5b8afd4fb322dd7cbffd1a589ff5276b0e3edeb5 visualize_data_loader.py
100644 blob 1772 449bb484143443c125566907a4b862d1c283c3f3 visualize_util.py
Hints:
Before your first commit, do not forget to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main