List of commits:
Subject Hash Author Date (UTC)
add numpy support 2e10283edfc427ea4da6a76c8c3a6f85a44bace4 Xiangru Lian 2019-01-07 22:24:12
update fcfc8a26e275395688fb3fa50538f890c1f612e4 Xiangru Lian 2019-01-07 21:55:04
use markdown for readme f46e30761b4f361d5494484d9bdb20dc04a20d83 Xiangru Lian 2019-01-07 21:53:46
add readme a9a0b91ea0e0485039ea76b6b69c5f983815e8a3 Xiangru Lian 2019-01-07 21:52:31
fix 16a965862eaefa554a540fd2eeee26f73907d5dd Xiangru Lian 2019-01-07 21:37:12
fix setup.py 24e85045cc859ffa7d6f514c80257fde94c0794a Xiangru Lian 2019-01-07 21:35:15
compatible with setup.py 657782ac87ea9251f2e3726b483089eca5384caf Xiangru Lian 2019-01-07 21:33:48
initial commit 27cb0b3b4f5f20240af6d14ead4f373a1aaa5343 Xiangru Lian 2019-01-07 21:26:20
Update .gitignore 52642556fea807b773617938eecd2d9d17f34bc6 ikzk 2019-01-07 20:16:37
Commit 2e10283edfc427ea4da6a76c8c3a6f85a44bace4 - add numpy support
Author: Xiangru Lian
Author date (UTC): 2019-01-07 22:24
Committer name: Xiangru Lian
Committer date (UTC): 2019-01-07 22:24
Parent(s): fcfc8a26e275395688fb3fa50538f890c1f612e4
Signing key:
Tree: 697d264d80ef6f8fae7ef0d840f2a7cad52dcf72
File Lines added Lines deleted
persia_pytorch_toolkit/tensor_utils.py 24 9
File persia_pytorch_toolkit/tensor_utils.py changed (mode: 100644) (index 5d6460c..6740458)
1 1 import torch import torch
2 import numpy as np
2 3
3
4 def _iterate_over_container(inputs, func):
4 def _iterate_over_container(inputs, func, instance_type=torch.Tensor):
5 5 "Run a lambda over tensors in the container" "Run a lambda over tensors in the container"
6 6 def iterate(obj): def iterate(obj):
7 if isinstance(obj, torch.Tensor):
7 if isinstance(obj, instance_type):
8 8 return func(obj) return func(obj)
9 if isinstance(obj, tuple) and len(obj) > 0:
10 return list(zip(*map(iterate, obj)))
11 if isinstance(obj, list) and len(obj) > 0:
12 return list(map(list, zip(*map(iterate, obj))))
13 if isinstance(obj, dict) and len(obj) > 0:
14 return list(map(type(obj), zip(*map(iterate, obj.items()))))
9 if (isinstance(obj, tuple) or isinstance(obj, list)) and len(obj) > 0:
10 return list(map(iterate, obj))
15 11 # After iterate is called, a iterate cell will exist. This cell has a # After iterate is called, a iterate cell will exist. This cell has a
16 12 # reference to the actual function iterate, which has references to a # reference to the actual function iterate, which has references to a
17 13 # closure that has a reference to the iterate cell (because the fn is # closure that has a reference to the iterate cell (because the fn is
 
... ... def to_device(container, device):
27 23 """The device could be cuda:0 for example.""" """The device could be cuda:0 for example."""
28 24 device = torch.device(device) device = torch.device(device)
29 25 return _iterate_over_container(container, lambda x: x.to(device)) return _iterate_over_container(container, lambda x: x.to(device))
26
27
28 def _numpy_dtype_to_torch_dtype(dtype: np.dtype):
29 t = dtype.type
30 if t is np.float64:
31 return torch.float
32 elif t is np.int64:
33 return torch.long
34 else:
35 raise Exception("Unknown numpy dtype: " + str(t))
36
37
38 def to_device_from_numpy(container, device):
39 """The device could be cuda:0 for example."""
40 device = torch.device(device)
41 return _iterate_over_container(container,
42 lambda x: torch.as_tensor(x, dtype=_numpy_dtype_to_torch_dtype(x.dtype), device=device)
43 np.ndarray
44 )
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/ikzk/persia-pytorch-toolkit

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/ikzk/persia-pytorch-toolkit

Clone this repository using git:
git clone git://git.rocketgit.com/user/ikzk/persia-pytorch-toolkit

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main