# persia_pytorch_toolkit/model_utils.py
import os

import torch

from persia_pytorch_toolkit import utils

def _ensure_module(model):
    """Unwrap containers such as DataParallel that hold the real network in `.module`."""
    if hasattr(model, 'module'):
        model = model.module
    return model
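
# A minimal usage sketch (hypothetical example, not part of the toolkit's API):
# wrappers such as torch.nn.DataParallel expose the wrapped network through a
# `.module` attribute, which is exactly what _ensure_module unwraps.
def _example_ensure_module():
    net = torch.nn.Linear(4, 2)
    wrapped = torch.nn.DataParallel(net)
    assert _ensure_module(wrapped) is net  # unwrapped back to the inner module
    assert _ensure_module(net) is net      # plain modules pass through unchanged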

def checkpoint(model, directory, filename, add_time_str=True, extension=".pt"):
    """Save the model's state_dict under `directory`, optionally tagged with a timestamp."""
    model = _ensure_module(model)
    os.makedirs(directory, exist_ok=True)
    state_dict = model.state_dict()
    if add_time_str:
        filename = filename + "." + utils.current_time_str()
    torch.save(state_dict, os.path.join(directory, filename + extension))
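
# A minimal restore sketch (hypothetical helper mirroring checkpoint above):
# checkpoint() saves a bare state_dict, so loading is just torch.load followed
# by load_state_dict on the unwrapped module.
def _example_load_checkpoint(model, path, map_location="cpu"):
    state_dict = torch.load(path, map_location=map_location)
    _ensure_module(model).load_state_dict(state_dict)
    return model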

def get_model_size(model):
    """Return the total number of parameters in the model."""
    model = _ensure_module(model)
    return sum(p.nelement() for p in model.parameters())
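
# Quick sanity check for get_model_size (hypothetical example): a Linear layer
# with in_features=10 and out_features=5 has a 5x10 weight plus 5 biases.
def _example_get_model_size():
    assert get_model_size(torch.nn.Linear(10, 5)) == 10 * 5 + 5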

def flatten_model_and_grad_tensor(network, verbose=False):
    """Give all parameters (and their grads) a single contiguous backing storage.

    Caveats:
    1. Can only be called after the first loss.backward(), so that the .grad
       tensors already exist.
    2. The new backing storage is not meaningfully initialized (parameter
       values are copied back below, but gradient values are not), so
       reinitialize the gradients before relying on them.
    """
    total_size = 0
    tensor_type = None
    for parameter in network.parameters():
        if tensor_type is None:
            # Use .data.type() rather than .type() after the Variable/Tensor merge:
            # https://github.com/pytorch/pytorch/wiki/Breaking-Changes-from-Variable-and-Tensor-merge
            tensor_type = parameter.data.type()
        total_size += parameter.nelement()
    if verbose:
        print("total size: ", total_size)
        print("tensor type: ", tensor_type)
    # One flat tensor for parameters and one for gradients; their storages
    # become the shared backing memory for every parameter below.
    tensor = torch.Tensor(total_size).type(tensor_type)
    storage = tensor.storage()
    grad_tensor = torch.Tensor(total_size).type(tensor_type)
    grad_storage = grad_tensor.storage()
    if verbose:
        print("created new contiguous storage")
    current_offset = 0
    for parameter in network.parameters():
        backup = parameter.data.clone()
        # Re-point the parameter at a slice of the shared storage, then
        # restore its original values; the gradient gets the matching slice.
        parameter.data.set_(storage, current_offset, parameter.data.size())
        parameter.data.copy_(backup)
        parameter.grad.data.set_(grad_storage, current_offset, parameter.data.size())
        current_offset += parameter.data.nelement()
        if verbose:
            print("parameter storage offset: ", parameter.data.storage_offset())
    if verbose:
        print("The C pointer: ", storage.data_ptr())
        print("The C pointer for grad: ", grad_storage.data_ptr())
        print("Is contiguous memory? ", tensor.is_contiguous())
        print("Is grad contiguous memory? ", grad_tensor.is_contiguous())
    return tensor, grad_tensor
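
# Illustrative call order for flatten_model_and_grad_tensor (hypothetical
# example, following the caveats in its docstring): run one forward/backward
# pass first so every parameter has a .grad, then flatten.
def _example_flatten_model_and_grad():
    net = torch.nn.Linear(4, 2)
    net(torch.randn(8, 4)).sum().backward()  # creates .grad on every parameter
    flat, flat_grad = flatten_model_and_grad_tensor(net)
    # The flat tensors alias the parameters' storage, so in-place writes to
    # `flat` are visible through net.parameters() and vice versa.
    assert flat.nelement() == sum(p.nelement() for p in net.parameters())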

def model_to_flatten_parameters(network):
    """Return a single flat tensor containing copies of all parameters."""
    return torch._utils._flatten_dense_tensors(list(network.parameters()))


def model_to_flatten_gradients(network):
    """Return a single flat tensor containing copies of all parameter gradients."""
    return torch._utils._flatten_dense_tensors([p.grad.data for p in network.parameters()])


def flatten_parameters_to_model(flatten_parameters, model):
    """Split a flat tensor back into views shaped like the model's parameters."""
    return torch._utils._unflatten_dense_tensors(flatten_parameters, list(model.parameters()))
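
# Round-trip sketch (hypothetical example): _unflatten_dense_tensors returns
# per-parameter views of the flat tensor, which can be copied back in-place.
def _example_flatten_roundtrip(model):
    flat = model_to_flatten_parameters(model)
    for param, piece in zip(model.parameters(), flatten_parameters_to_model(flat, model)):
        param.data.copy_(piece)  # restores the original values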