Subject | Hash | Author | Date (UTC) |
---|---|---|---|
initial commit | 27cb0b3b4f5f20240af6d14ead4f373a1aaa5343 | Xiangru Lian | 2019-01-07 21:26:20 |
Update .gitignore | 52642556fea807b773617938eecd2d9d17f34bc6 | ikzk | 2019-01-07 20:16:37 |
File | Lines added | Lines deleted |
---|---|---|
model_utils.py | 30 | 0 |
tensor_utils.py | 29 | 0 |
utils.py | 4 | 0 |
File model_utils.py added (mode: 100644) (index 0000000..f513af6) | |||
1 | import torch | ||
2 | import os | ||
3 | import utils | ||
4 | |||
5 | |||
6 | def _ensure_module(model): | ||
7 | "handle DataParallel etc" | ||
8 | if hasattr(model, 'module'): | ||
9 | model = model.module | ||
10 | return model | ||
11 | |||
12 | |||
def checkpoint(model, directory, filename, add_time_str=True, extension=".pt"):
    """Save the model's ``state_dict`` under *directory*.

    The file is named ``<filename>.<timestamp><extension>`` when
    *add_time_str* is true, otherwise ``<filename><extension>``.
    (Bug fix: the original ignored *add_time_str* and always appended
    the timestamp.)

    Args:
        model: a ``torch.nn.Module``; ``DataParallel``-style wrappers are
            unwrapped first so the checkpoint loads without the wrapper.
        directory: destination directory, created if missing.
        filename: base file name (no extension).
        add_time_str: append ``utils.current_time_str()`` to the name.
        extension: file suffix, ``.pt`` by default.
    """
    model = _ensure_module(model)
    os.makedirs(directory, exist_ok=True)
    if add_time_str:
        filename = filename + "." + utils.current_time_str()
    torch.save(model.state_dict(), os.path.join(directory, filename + extension))
19 | |||
20 | |||
def get_model_size(model):
    """Return the total number of scalar parameters in *model*.

    Resolves the block's old TODO: uses ``Tensor.numel()`` instead of
    multiplying out ``p.size()`` by hand. Wrappers with a ``module``
    attribute (e.g. ``DataParallel``) are unwrapped first.
    """
    model = _ensure_module(model)
    # numel() is the C-implemented product of all dimensions of each tensor.
    return sum(p.numel() for p in model.parameters())
File tensor_utils.py added (mode: 100644) (index 0000000..5d6460c) | |||
1 | import torch | ||
2 | |||
3 | |||
4 | def _iterate_over_container(inputs, func): | ||
5 | "Run a lambda over tensors in the container" | ||
6 | def iterate(obj): | ||
7 | if isinstance(obj, torch.Tensor): | ||
8 | return func(obj) | ||
9 | if isinstance(obj, tuple) and len(obj) > 0: | ||
10 | return list(zip(*map(iterate, obj))) | ||
11 | if isinstance(obj, list) and len(obj) > 0: | ||
12 | return list(map(list, zip(*map(iterate, obj)))) | ||
13 | if isinstance(obj, dict) and len(obj) > 0: | ||
14 | return list(map(type(obj), zip(*map(iterate, obj.items())))) | ||
15 | # After iterate is called, a iterate cell will exist. This cell has a | ||
16 | # reference to the actual function iterate, which has references to a | ||
17 | # closure that has a reference to the iterate cell (because the fn is | ||
18 | # recursive). To avoid this reference cycle, we set the function to None, | ||
19 | # clearing the cell | ||
20 | try: | ||
21 | return iterate(inputs) | ||
22 | finally: | ||
23 | iterate = None | ||
24 | |||
25 | |||
def to_device(container, device):
    """Move every tensor in *container* to *device* (e.g. ``"cuda:0"``)."""
    target = torch.device(device)
    return _iterate_over_container(container, lambda tensor: tensor.to(target))
File utils.py added (mode: 100644) (index 0000000..9595556) | |||
1 | import datetime | ||
2 | |||
def current_time_str():
    """Return the current local time as a ``YYYY_MM_DD_HH_MM_SS`` string."""
    now = datetime.datetime.now()
    return now.strftime("%Y_%m_%d_%H_%M_%S")