/model_util.py (f37b3bb572c53dd942c51243bd5b0853228c6ddb) (2794 bytes) (mode 100644) (type blob)

import h5py
import torch
import shutil
import numpy as np
import os


def save_net(fname, net):
    """Write every entry of ``net.state_dict()`` to the HDF5 file ``fname``.

    The file is opened in mode ``'w'`` (created or truncated). Each state-dict
    key becomes a dataset name, and its tensor is moved to the CPU and stored
    as a NumPy array.
    """
    with h5py.File(fname, 'w') as out_file:
        state = net.state_dict()
        for name in state:
            out_file.create_dataset(name, data=state[name].cpu().numpy())


def load_net(fname, net):
    """Load tensors from the HDF5 file ``fname`` into ``net`` in place.

    Expects the layout ``save_net`` produces: one dataset per state-dict key.
    Each stored array is copied into the matching tensor from
    ``net.state_dict()``. Those tensors share storage with the module's
    parameters and buffers, so the module itself is updated.

    Only keys present in ``net.state_dict()`` are read; extra datasets in the
    file are ignored. A key missing from the file makes the h5py lookup fail.
    """
    with h5py.File(fname, 'r') as in_file:
        state = net.state_dict()
        for name in state:
            stored = np.asarray(in_file[name])
            state[name].copy_(torch.from_numpy(stored))


def save_checkpoint(state, is_best, task_id, filename='checkpoint.pth.tar'):
    """Save a training checkpoint under ``saved_model/`` and optionally mark it as best.

    :param state: object passed to ``torch.save`` (e.g. a dict of model/optimizer state).
    :param is_best: when True, also copy the checkpoint to
        ``task_id + 'model_best.pth.tar'`` in the current working directory.
    :param task_id: prefix prepended to the checkpoint file names.
    :param filename: file-name suffix of the checkpoint.
    :return: path of the checkpoint written, ``saved_model/<task_id><filename>``.
    """
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs("saved_model", exist_ok=True)
    full_file_name = os.path.join("saved_model", task_id + filename)
    torch.save(state, full_file_name)
    if is_best:
        # Copy from the path that was actually written above (inside
        # saved_model/), not from a same-named file in the working directory.
        shutil.copyfile(full_file_name, task_id + 'model_best.pth.tar')
    return full_file_name


def calculate_padding(kernel_size, dilation):
    """
    Return the padding that keeps a stride-1 convolution's output the same
    spatial size as its input.

    https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338

    With stride 1 the convolution output size is ``o = i + 2p - d*(k-1)``,
    so ``o == i`` requires ``p = d*(k-1)/2``.

    o = output
    p = padding
    k = kernel_size
    d = dilation

    :param kernel_size: convolution kernel size ``k``.
    :param dilation: convolution dilation ``d``.
    :return: padding ``p`` as an ``int``, usable directly as the ``padding``
        argument of ``torch.nn.Conv2d``.
    :raises ValueError: if ``d*(k-1)`` is odd, since then no symmetric padding
        preserves the input size.
    """
    k = kernel_size
    d = dilation
    # -1 + k + (k-1)*(d-1) simplifies to (k-1)*d.
    total = (k - 1) * d
    if total % 2:
        raise ValueError(
            "no symmetric 'same' padding exists for kernel_size=%r, dilation=%r "
            "(dilation * (kernel_size - 1) must be even)" % (kernel_size, dilation)
        )
    # Integer division: Conv2d rejects float padding such as 4.0.
    return total // 2


def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)


class BestMetrics:
    """Track the best score seen so far and the MAE/MSE recorded with it."""

    def __init__(self, best_metric="mae"):
        """
        :param best_metric: whether "mae" or "mse" will be used to determine the best result
        :raises ValueError: if best_metric is neither "mae" nor "mse"
        """
        if best_metric not in ("mae", "mse"):
            raise ValueError("best_metric must be 'mae' or 'mse', got %r" % (best_metric,))
        # inf rather than a magic 9999 so the first recorded result always
        # becomes the best, however large its error is.
        self.best = float("inf")
        self.cur_mae = float("inf")
        self.cur_mse = float("inf")
        # Honor the caller's choice (it was previously hard-coded to "mae").
        self.best_metric = best_metric

    def checkAndRecord(self, mae, mse):
        """
        Record (mae, mse) if it improves on the best value of the tracked metric.

        :param mae: mean absolute error of the current evaluation.
        :param mse: mean squared error of the current evaluation.
        :return: True if this result is the new best (``best``, ``cur_mae`` and
            ``cur_mse`` are updated), otherwise False.
        """
        if self.best_metric == "mae":
            score = mae
        elif self.best_metric == "mse":
            score = mse
        else:
            return False
        # Strict '<' keeps the earlier result on ties and never records NaN.
        if score < self.best:
            self.best = score
            self.cur_mae = mae
            self.cur_mse = mse
            return True
        return False


if __name__ == "__main__":
    # Print the "same" padding for a few kernel/dilation combinations.
    for kernel, dil in ((3, 4), (5, 1), (7, 1), (9, 1), (3, 1)):
        print(calculate_padding(kernel_size=kernel, dilation=dil))

    print("-----compact dilated cnn -----------------")
    for dil in (3, 2, 1):
        print(calculate_padding(kernel_size=5, dilation=dil))

    print("---dilated 3x3 with dilated 2 4 6")
    # For a 3x3 kernel the expected padding equals the dilation: 2, 4, 6.
    for dil in (2, 4, 6):
        print(calculate_padding(kernel_size=3, dilation=dil))

Mode Type Size Ref File
100644 blob 61 169fe2b7d512a59cfedf86ddb7ed040173c7434d .gitignore
100644 blob 1255 1dfa426237bc174a2ba2186240191a6b7041bc86 README.md
100644 blob 8019 b07453bc637967304c072283d44cef1a6a6ef2ac args_util.py
040000 tree - 5e9d7f0e1fd3a9e4d5a37f3d6de0c3ecd3125af8 backup_notebook
040000 tree - 55d1d196f5b6ed4bfc1e8a715df1cfff1dd18117 bug
100644 blob 3591 7b4c18e8cf2c0417cd13d3f77ea0571c9e0e493f crowd_counting_error_metrics.py
100644 blob 38997 acedc960a0a0b5b4115e7a47bb881cdc9e7e5288 data_flow.py
040000 tree - 17c9c74641b7acc37008a7f940a62323dd5b2b6b data_util
040000 tree - 98585ee75deafce7149ed6db25a4068a6bb2777a dataset_script
040000 tree - f4655ae7590cf976350887782cf9929c9e313f20 debug
040000 tree - 9862b9cbc6e7a1d43565f12d85d9b17d1bf1814e env_file
100644 blob 4460 9b254c348a3453f4df2c3ccbf21fb175a16852de eval_context_aware_network.py
100644 blob 428 35cc7bfe48a4ed8dc56635fd3a6763612d8af771 evaluator.py
100644 blob 12605 b6b690d5fa639867bfa892cd1218be83e804c9bb experiment_main.py
100644 blob 8876 049432d6bde50245a4acba4e116d59605b5b6315 experiment_meow_main.py
100644 blob 1916 1d228fa4fa2887927db069f0c93c61a920279d1f explore_model_summary.py
100644 blob 2718 b09b84e8b761137654ba6904669799c4866554b3 hard_code_variable.py
040000 tree - b3aa858a157f5e1e22c00fdb6f9dd071f4c6c163 local_train_script
040000 tree - 927d159228536a86499de8a294700f8599b8a60b logs
100644 blob 15300 cb90faba0bd4a45f2606a1e60975ed05bfacdb07 main_pacnn.py
100644 blob 2760 3c2d5ba1c81ef2770ad216c566e268f4ece17262 main_shanghaitech.py
100644 blob 2683 29189260c1a2c03c8e59cd0b4bd61df19d5ce098 main_ucfcc50.py
100644 blob 2794 f37b3bb572c53dd942c51243bd5b0853228c6ddb model_util.py
040000 tree - ff265ef837660d8534d7bd4af6d53819ee13ed3b models
100644 blob 1066 811554259182e63240d7aa9406f315377b3be1ac mse_ssim_loss.py
040000 tree - 2cc497edce5da8793879cc5e82718d1562ef17e8 playground
040000 tree - c7c295e9e418154ae7c754dc888a77df8f50aa61 pytorch_ssim
100644 blob 1727 1cd14cbff636cb6145c8bacf013e97eb3f7ed578 sanity_check_dataloader.py
040000 tree - a1e8ea43eba8a949288a00fff12974aec8692003 saved_model_best
100644 blob 3525 27067234ad3deddd743dcab0d7b3ba4812902656 train_attn_can_adcrowdnet.py
100644 blob 3488 e47bfc7e91c46ca3c61be0c5258302de4730b06d train_attn_can_adcrowdnet_freeze_vgg.py
100644 blob 5352 3ee3269d6fcc7408901af46bed52b1d86ee9818c train_attn_can_adcrowdnet_simple.py
100644 blob 5728 90b846b68f15bdc58e3fd60b41aa4b5d82864ec4 train_attn_can_adcrowdnet_simple_lrscheduler.py
100644 blob 9081 664051f8838434c386e34e6dd6e6bca862cb3ccd train_compact_cnn.py
100644 blob 5702 fdec7cd1ee062aa4a2182a91e2fb1bd0db3ab35f train_compact_cnn_lrscheduler.py
100644 blob 5611 2a241c876015db34681d73ce534221de482b0b90 train_compact_cnn_sgd.py
100644 blob 3525 eb52f7a4462687c9b2bf1c3a887014c4afefa26d train_context_aware_network.py
100644 blob 5651 48631e36a1fdc063a6d54d9206d2fd45521d8dc8 train_custom_compact_cnn.py
100644 blob 5594 07d6c9c056db36082545b5b60b1c00d9d9f6396d train_custom_compact_cnn_lrscheduler.py
100644 blob 5281 8a92eb87b54f71ad2a799a7e05020344a22e22d3 train_custom_compact_cnn_sgd.py
040000 tree - d26184f6d67fdafc7459082ea19e0e01d93b4586 train_script
100644 blob 5392 03c78fe177520b309ee21e5c2b7ca67598fad99a visualize_data_loader.py
100644 blob 1146 1b0f845587f0f37166d44fa0c74b51f89cf8b349 visualize_util.py
Hints:
Before your first commit, do not forget to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using SSH (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main