/model_util.py (eeb34682e8425776dc9e70326a77c2333ae10e5c) (815 bytes) (mode 100644) (type blob)
import h5py
import torch
import shutil
import numpy as np
import os
def save_net(fname, net):
    """Dump every entry of ``net.state_dict()`` into an HDF5 file.

    Each state-dict key becomes a dataset of the same name whose data is
    the corresponding tensor moved to the CPU and converted to a NumPy
    array.

    Args:
        fname: Path of the HDF5 file; it is opened in ``'w'`` mode.
        net: Object exposing ``state_dict()`` whose values support
            ``.cpu().numpy()`` (presumably a ``torch.nn.Module``).
    """
    with h5py.File(fname, 'w') as archive:
        weights = net.state_dict()
        for name in weights:
            as_array = weights[name].cpu().numpy()
            archive.create_dataset(name, data=as_array)
def load_net(fname, net):
    """Fill the tensors of ``net.state_dict()`` from an HDF5 file.

    For every key in the state dict, the dataset stored under the same
    name is read, converted to a tensor, and copied in place into the
    existing state-dict tensor.

    Args:
        fname: Path of the HDF5 file; it is opened in ``'r'`` mode.
        net: Object exposing ``state_dict()`` whose values support the
            in-place ``copy_`` (presumably a ``torch.nn.Module``).
    """
    with h5py.File(fname, 'r') as archive:
        for name, target in net.state_dict().items():
            stored = np.asarray(archive[name])
            # In-place copy so the network's own tensor receives the values.
            target.copy_(torch.from_numpy(stored))
def save_checkpoint(state, is_best, task_id, filename='checkpoint.pth.tar'):
    """Save ``state`` under the ``saved_model`` directory.

    The checkpoint is written with ``torch.save`` to
    ``saved_model/<task_id><filename>``. When ``is_best`` is true, that
    file is also copied to ``saved_model/<task_id>model_best.pth.tar``.

    Args:
        state: Object to serialize; passed unchanged to ``torch.save``.
        is_best: If true, also keep a "best model" copy of the checkpoint.
        task_id: String prefix prepended to the file names.
        filename: Base name of the checkpoint file.

    Returns:
        The path of the checkpoint file that was written.
    """
    save_dir = "saved_model"
    # exist_ok=True replaces the separate exists() check, which could race
    # with another process creating the directory.
    os.makedirs(save_dir, exist_ok=True)
    full_file_name = os.path.join(save_dir, task_id + filename)
    torch.save(state, full_file_name)
    if is_best:
        # Copy the file that was just written. The checkpoint lives in
        # save_dir, so a path relative to the working directory would
        # point at a file that does not exist.
        best_file_name = os.path.join(save_dir, task_id + 'model_best.pth.tar')
        shutil.copyfile(full_file_name, best_file_name)
    return full_file_name
Hints:
Before your first commit, do not forget to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"
Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework
Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework
Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework
You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a
merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main