/model_util.py (eeb34682e8425776dc9e70326a77c2333ae10e5c) (815 bytes) (mode 100644) (type blob)

import h5py
import torch
import shutil
import numpy as np
import os


def save_net(fname, net):
    """Dump every entry of ``net.state_dict()`` into an HDF5 file.

    One dataset is created per state-dict key, named after that key. Each
    tensor is moved to CPU and converted to a NumPy array before writing.
    The file is opened in ``'w'`` mode, so an existing file is overwritten.

    Args:
        fname: Path of the HDF5 file to create.
        net: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
    """
    state = net.state_dict()
    with h5py.File(fname, 'w') as out_file:
        for name in state:
            # h5py cannot store torch tensors directly; hand it a host-side
            # NumPy array instead.
            out_file.create_dataset(name, data=state[name].cpu().numpy())


def load_net(fname, net):
    """Fill ``net``'s state-dict tensors in place from an HDF5 file.

    This is the counterpart of ``save_net``. For every key in
    ``net.state_dict()``, the dataset of the same name is read, wrapped as a
    tensor, and copied into the existing tensor with ``copy_``. Keys present
    in the file but absent from the state dict are ignored. A key missing
    from the file makes the h5py lookup fail.

    Args:
        fname: Path of the HDF5 file to read.
        net: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
    """
    with h5py.File(fname, 'r') as in_file:
        for name, target in net.state_dict().items():
            stored = np.asarray(in_file[name])
            # state_dict() tensors share storage with the module's
            # parameters/buffers, so an in-place copy updates the module.
            target.copy_(torch.from_numpy(stored))


def save_checkpoint(state, is_best, task_id, filename='checkpoint.pth.tar',
                    save_dir="saved_model"):
    """Serialize a checkpoint and optionally keep a copy as the "best" model.

    The checkpoint is written with ``torch.save`` to
    ``<save_dir>/<task_id><filename>``. When ``is_best`` is truthy, that file
    is also copied to ``<save_dir>/<task_id>model_best.pth.tar``.

    Args:
        state: Any object ``torch.save`` can serialize.
        is_best: When truthy, also store a copy as the best checkpoint.
        task_id: Prefix prepended (by string concatenation) to both file
            names, so it must be a ``str``.
        filename: Suffix of the regular checkpoint's file name.
        save_dir: Directory receiving both files; created if missing.
            Defaults to ``"saved_model"`` (relative to the current working
            directory).

    Returns:
        Path of the regular checkpoint file.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)
    full_file_name = os.path.join(save_dir, task_id + filename)
    torch.save(state, full_file_name)
    if is_best:
        # Copy from where the checkpoint was actually written. The old code
        # copied from ``task_id + filename`` in the CWD, which does not exist
        # (FileNotFoundError) or is a stale, unrelated file.
        best_file_name = os.path.join(save_dir, task_id + 'model_best.pth.tar')
        shutil.copyfile(full_file_name, best_file_name)
    return full_file_name


Mode Type Size Ref File
100644 blob 61 169fe2b7d512a59cfedf86ddb7ed040173c7434d .gitignore
100644 blob 426 1ac57815ba8c20b0d7fb0946a607a760004c603c README.md
100644 blob 4226 33e540f2c232cf5ae5c6618d26d84840a005deb8 args_util.py
040000 tree - 5e9d7f0e1fd3a9e4d5a37f3d6de0c3ecd3125af8 backup_notebook
040000 tree - 55d1d196f5b6ed4bfc1e8a715df1cfff1dd18117 bug
100644 blob 1775 1165f1aba0814b448a3595a32bd74f1967509509 crowd_counting_error_metrics.py
100644 blob 15424 356acaf91b046ea6d7e7a624bd56077c9f8756fa data_flow.py
040000 tree - 17c9c74641b7acc37008a7f940a62323dd5b2b6b data_util
040000 tree - 2a46ff24b8b8997b4ca07c18e2326cb3c35dc5a0 dataset_script
100644 blob 4460 9b254c348a3453f4df2c3ccbf21fb175a16852de eval_context_aware_network.py
100644 blob 428 35cc7bfe48a4ed8dc56635fd3a6763612d8af771 evaluator.py
100644 blob 2718 b09b84e8b761137654ba6904669799c4866554b3 hard_code_variable.py
100644 blob 15300 cb90faba0bd4a45f2606a1e60975ed05bfacdb07 main_pacnn.py
100644 blob 2760 3c2d5ba1c81ef2770ad216c566e268f4ece17262 main_shanghaitech.py
100644 blob 2683 29189260c1a2c03c8e59cd0b4bd61df19d5ce098 main_ucfcc50.py
100644 blob 815 eeb34682e8425776dc9e70326a77c2333ae10e5c model_util.py
040000 tree - af056613d4ec202129a222b7f7e7a839af466c9d models
040000 tree - d1c13a0fa59c995bbc5c766ea807108aabbc41a8 playground
040000 tree - 970ac54d8293aed6667e016f2245547f3a5449c3 pytorch_ssim
100644 blob 3525 27067234ad3deddd743dcab0d7b3ba4812902656 train_attn_can_adcrowdnet.py
100644 blob 3488 e47bfc7e91c46ca3c61be0c5258302de4730b06d train_attn_can_adcrowdnet_freeze_vgg.py
100644 blob 3458 3d5c65643671e8866f293bf6c5457ee289da70cf train_context_aware_network.py
040000 tree - fc69bfdcecfe10ff3c03c9280f023b67edfb55da train_script
100644 blob 5392 03c78fe177520b309ee21e5c2b7ca67598fad99a visualize_data_loader.py
100644 blob 1146 1b0f845587f0f37166d44fa0c74b51f89cf8b349 visualize_util.py
Hints:
Before your first commit, do not forget to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to push to this repository anonymously.
This means that your pushed commits will automatically be converted into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main