File persia_pytorch_toolkit/model_utils.py changed (mode: 100644) (index cb5c347..fc58766) |
... |
... |
def get_model_size(model): |
27 |
27 |
tmp *= x |
tmp *= x |
28 |
28 |
params += tmp |
params += tmp |
29 |
29 |
return params |
return params |
|
30 |
|
|
|
31 |
|
|
|
32 |
|
def flatten_model_and_grad_tensor(model, verbose=False): |
|
33 |
|
""" |
|
34 |
|
caveats: |
|
35 |
|
1. can only be called after first loss.backward (so that grad variables are created) |
|
36 |
|
2. the tensors are all zero (so if you want to initialization plz reinitialize) |
|
37 |
|
""" |
|
38 |
|
total_size = 0 |
|
39 |
|
tensor_type = None |
|
40 |
|
for parameter in network.parameters(): |
|
41 |
|
if tensor_type == None: |
|
42 |
|
tensor_type = parameter.data.type() # https://github.com/pytorch/pytorch/wiki/Breaking-Changes-from-Variable-and-Tensor-merge |
|
43 |
|
total_size += parameter.nelement() |
|
44 |
|
|
|
45 |
|
if verbose: |
|
46 |
|
print("total size: ", total_size) |
|
47 |
|
print("tensor type: ", tensor_type) |
|
48 |
|
|
|
49 |
|
tensor = torch.Tensor(total_size).type(tensor_type) |
|
50 |
|
storage = tensor.storage() |
|
51 |
|
grad_tensor = torch.Tensor(total_size).type(tensor_type) |
|
52 |
|
grad_storage = grad_tensor.storage() |
|
53 |
|
|
|
54 |
|
|
|
55 |
|
if verbose: |
|
56 |
|
print("create new continuous storage") |
|
57 |
|
|
|
58 |
|
current_offset = 0 |
|
59 |
|
|
|
60 |
|
for parameter in network.parameters(): |
|
61 |
|
backup = parameter.data.clone() |
|
62 |
|
parameter.data.set_(storage, current_offset, parameter.data.size()) |
|
63 |
|
parameter.data.copy_(backup) |
|
64 |
|
parameter.grad.data.set_(grad_storage, current_offset, parameter.data.size()) |
|
65 |
|
current_offset += parameter.data.nelement() |
|
66 |
|
print("parameter storage offset: ", parameter.data.storage_offset()) |
|
67 |
|
|
|
68 |
|
if verbose: |
|
69 |
|
print("The C pointer: ", storage.data_ptr()) |
|
70 |
|
print("The C pointer for grad: ", grad_storage.data_ptr()) |
|
71 |
|
print("Is contiguous memory? ", tensor.is_contiguous()) |
|
72 |
|
print("Is grad contiguous memory? ", grad_tensor.is_contiguous()) |
|
73 |
|
|
|
74 |
|
return tensor, grad_tensor |
File setup.py changed (mode: 100644) (index b73fb55..93fd586) |
1 |
1 |
from setuptools import setup |
from setuptools import setup |
2 |
2 |
|
|
3 |
3 |
setup(name='persia_pytorch_toolkit', |
setup(name='persia_pytorch_toolkit', |
4 |
|
version='0.0.2', |
|
|
4 |
|
version='0.0.3', |
5 |
5 |
description="Xiangru Lian's toolkit with PyTorch.", |
description="Xiangru Lian's toolkit with PyTorch.", |
6 |
6 |
url='https://github.com/ikzk/persia_pytorch_toolkit', |
url='https://github.com/ikzk/persia_pytorch_toolkit', |
7 |
7 |
author="Xiangru Lian", |
author="Xiangru Lian", |