List of commits:
Subject Hash Author Date (UTC)
initial commit 27cb0b3b4f5f20240af6d14ead4f373a1aaa5343 Xiangru Lian 2019-01-07 21:26:20
Update .gitignore 52642556fea807b773617938eecd2d9d17f34bc6 ikzk 2019-01-07 20:16:37
Commit 27cb0b3b4f5f20240af6d14ead4f373a1aaa5343 - initial commit
Author: Xiangru Lian
Author date (UTC): 2019-01-07 21:26
Committer name: Xiangru Lian
Committer date (UTC): 2019-01-07 21:26
Parent(s): 52642556fea807b773617938eecd2d9d17f34bc6
Signing key:
Tree: bdd5250279a237885153325a69aea0592642824d
File Lines added Lines deleted
model_utils.py 30 0
tensor_utils.py 29 0
utils.py 4 0
File model_utils.py added (mode: 100644) (index 0000000..f513af6)
1 import torch
2 import os
3 import utils
4
5
6 def _ensure_module(model):
7 "handle DataParallel etc"
8 if hasattr(model, 'module'):
9 model = model.module
10 return model
11
12
13 def checkpoint(model, directory, filename, add_time_str=True, extension=".pt"):
14 model = _ensure_module(model)
15 os.makedirs(directory, exist_ok=True)
16 state_dict = model.state_dict()
17 time_str = utils.current_time_str()
18 torch.save(state_dict, os.path.join(directory, filename + "." + time_str + extension))
19
20
21 def get_model_size(model):
22 """TODO: consider using nelements()"""
23 model = _ensure_module(model)
24 params = 0
25 for p in model.parameters():
26 tmp = 1
27 for x in p.size():
28 tmp *= x
29 params += tmp
30 return params
File tensor_utils.py added (mode: 100644) (index 0000000..5d6460c)
1 import torch
2
3
4 def _iterate_over_container(inputs, func):
5 "Run a lambda over tensors in the container"
6 def iterate(obj):
7 if isinstance(obj, torch.Tensor):
8 return func(obj)
9 if isinstance(obj, tuple) and len(obj) > 0:
10 return list(zip(*map(iterate, obj)))
11 if isinstance(obj, list) and len(obj) > 0:
12 return list(map(list, zip(*map(iterate, obj))))
13 if isinstance(obj, dict) and len(obj) > 0:
14 return list(map(type(obj), zip(*map(iterate, obj.items()))))
15 # After iterate is called, a iterate cell will exist. This cell has a
16 # reference to the actual function iterate, which has references to a
17 # closure that has a reference to the iterate cell (because the fn is
18 # recursive). To avoid this reference cycle, we set the function to None,
19 # clearing the cell
20 try:
21 return iterate(inputs)
22 finally:
23 iterate = None
24
25
26 def to_device(container, device):
27 """The device could be cuda:0 for example."""
28 device = torch.device(device)
29 return _iterate_over_container(container, lambda x: x.to(device))
File utils.py added (mode: 100644) (index 0000000..9595556)
1 import datetime
2
3 def current_time_str():
4 return datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/ikzk/persia-pytorch-toolkit

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/ikzk/persia-pytorch-toolkit

Clone this repository using git:
git clone git://git.rocketgit.com/user/ikzk/persia-pytorch-toolkit

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main