File persia_pytorch_toolkit/tensor_utils.py changed (mode: 100644) (index 5d6460c..6740458) |
1 |
1 |
import torch |
import torch |
|
2 |
|
import numpy as np |
2 |
3 |
|
|
3 |
|
|
|
4 |
|
def _iterate_over_container(inputs, func): |
|
|
4 |
|
def _iterate_over_container(inputs, func, instance_type=torch.Tensor): |
5 |
5 |
"Run a lambda over tensors in the container" |
"Run a lambda over tensors in the container" |
6 |
6 |
def iterate(obj): |
def iterate(obj): |
7 |
|
if isinstance(obj, torch.Tensor): |
|
|
7 |
|
if isinstance(obj, instance_type): |
8 |
8 |
return func(obj) |
return func(obj) |
9 |
|
if isinstance(obj, tuple) and len(obj) > 0: |
|
10 |
|
return list(zip(*map(iterate, obj))) |
|
11 |
|
if isinstance(obj, list) and len(obj) > 0: |
|
12 |
|
return list(map(list, zip(*map(iterate, obj)))) |
|
13 |
|
if isinstance(obj, dict) and len(obj) > 0: |
|
14 |
|
return list(map(type(obj), zip(*map(iterate, obj.items())))) |
|
|
9 |
|
if (isinstance(obj, tuple) or isinstance(obj, list)) and len(obj) > 0: |
|
10 |
|
return list(map(iterate, obj)) |
15 |
11 |
# After iterate is called, a iterate cell will exist. This cell has a |
# After iterate is called, a iterate cell will exist. This cell has a |
16 |
12 |
# reference to the actual function iterate, which has references to a |
# reference to the actual function iterate, which has references to a |
17 |
13 |
# closure that has a reference to the iterate cell (because the fn is |
# closure that has a reference to the iterate cell (because the fn is |
|
... |
... |
def to_device(container, device): |
27 |
23 |
"""The device could be cuda:0 for example.""" |
"""The device could be cuda:0 for example.""" |
28 |
24 |
device = torch.device(device) |
device = torch.device(device) |
29 |
25 |
return _iterate_over_container(container, lambda x: x.to(device)) |
return _iterate_over_container(container, lambda x: x.to(device)) |
|
26 |
|
|
|
27 |
|
|
|
28 |
|
def _numpy_dtype_to_torch_dtype(dtype: np.dtype): |
|
29 |
|
t = dtype.type |
|
30 |
|
if t is np.float64: |
|
31 |
|
return torch.float |
|
32 |
|
elif t is np.int64: |
|
33 |
|
return torch.long |
|
34 |
|
else: |
|
35 |
|
raise Exception("Unknown numpy dtype: " + str(t)) |
|
36 |
|
|
|
37 |
|
|
|
38 |
|
def to_device_from_numpy(container, device): |
|
39 |
|
"""The device could be cuda:0 for example.""" |
|
40 |
|
device = torch.device(device) |
|
41 |
|
return _iterate_over_container(container, |
|
42 |
|
lambda x: torch.as_tensor(x, dtype=_numpy_dtype_to_torch_dtype(x.dtype), device=device) |
|
43 |
|
np.ndarray |
|
44 |
|
) |