/persia_pytorch_toolkit/tensor_utils.py (e40fab8ef5435c55d3c0b7b0ed780b3769ef74f9) (1944 bytes) (mode 100644) (type blob)
import torch
import numpy as np
def _iterate_over_container(inputs, func, instance_type=torch.Tensor):
"Run a lambda over tensors in the container"
def iterate(obj):
if isinstance(obj, instance_type):
return func(obj)
if (isinstance(obj, tuple) or isinstance(obj, list)) and len(obj) > 0:
return list(map(iterate, obj))
# After iterate is called, a iterate cell will exist. This cell has a
# reference to the actual function iterate, which has references to a
# closure that has a reference to the iterate cell (because the fn is
# recursive). To avoid this reference cycle, we set the function to None,
# clearing the cell
try:
return iterate(inputs)
finally:
iterate = None
def to_device(container, device, pin_memory=False, shared_memory=False):
    """Move every tensor in a (possibly nested) container to ``device``.

    The device could be cuda:0 for example; anything accepted by
    ``torch.device`` works.

    Args:
        container: A tensor or a nested container of tensors, traversed by
            ``_iterate_over_container``.
        device: Target device (a ``torch.device`` or a string such as "cpu").
        pin_memory: Pin each moved tensor in page-locked host memory.
            ``Tensor.pin_memory()`` is only defined for CPU tensors, so pinning
            is applied only when the moved tensor lives on the CPU and is
            skipped for other devices (previously this raised RuntimeError).
        shared_memory: Call ``share_memory_()`` on each moved tensor.

    Returns:
        The container as rebuilt by ``_iterate_over_container`` with every
        tensor replaced by its moved counterpart.
    """
    device = torch.device(device)

    def do_to_device(x):
        x = x.to(device)
        # Pinning a tensor that already lives on an accelerator is an error;
        # pinned memory only makes sense for CPU-resident tensors.
        if pin_memory and x.device.type == "cpu":
            x = x.pin_memory()
        if shared_memory:
            # In-place; returns the same tensor.
            x.share_memory_()
        return x

    return _iterate_over_container(container, do_to_device)
def _numpy_dtype_to_torch_dtype(dtype: np.dtype):
t = dtype.type
if t is np.float64 or t is np.float32:
return torch.float
elif t is np.int64 or t is np.int32:
return torch.long
else:
raise Exception("Unknown numpy dtype: " + str(t))
def to_device_from_numpy(container, device, pin_memory=False, shared_memory=False):
    """Convert every numpy array in a nested container to a tensor on ``device``.

    The device could be cuda:0 for example; anything accepted by
    ``torch.device`` works. Each array's dtype is converted with
    ``_numpy_dtype_to_torch_dtype`` before the tensor is created.

    Args:
        container: A ``np.ndarray`` or a nested container of arrays, traversed
            by ``_iterate_over_container``.
        device: Target device (a ``torch.device`` or a string such as "cpu").
        pin_memory: Pin each resulting tensor in page-locked host memory.
            ``Tensor.pin_memory()`` is only defined for CPU tensors, so pinning
            is applied only when the tensor lives on the CPU and is skipped
            for other devices (previously this raised RuntimeError).
        shared_memory: Call ``share_memory_()`` on each resulting tensor.

    Returns:
        The container as rebuilt by ``_iterate_over_container`` with every
        array replaced by a tensor.
    """
    device = torch.device(device)

    def as_tensor(x):
        tensor = torch.as_tensor(x, dtype=_numpy_dtype_to_torch_dtype(x.dtype), device=device)
        # The tensor is created directly on `device`; pinning is only valid
        # for CPU-resident tensors.
        if pin_memory and tensor.device.type == "cpu":
            tensor = tensor.pin_memory()
        if shared_memory:
            # In-place; returns the same tensor.
            tensor.share_memory_()
        return tensor

    return _iterate_over_container(container, as_tensor, np.ndarray)
| Mode | Type | Size | Ref | File |
|---|---|---|---|---|
| 100644 | blob | 1203 | 894a44cc066a027465cd26d634948d56d13af9af | .gitignore |
| 100644 | blob | 106 | 31e29114018d9b0254e2e8d2d46797c484362eff | README.md |
| 040000 | tree | - | db847373d98d5ed84adf118786b63925b7aa504d | persia_pytorch_toolkit |
| 100644 | blob | 354 | 60b2cf036b92d2c0c4be4ceec72b98c30ca04ccb | setup.py |
Hints:
Before your first commit, do not forget to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"
Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/ikzk/persia-pytorch-toolkit
Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/ikzk/persia-pytorch-toolkit
Clone this repository using git:
git clone git://git.rocketgit.com/user/ikzk/persia-pytorch-toolkit
You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a
merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main