/persia_pytorch_toolkit/tensor_utils.py (e40fab8ef5435c55d3c0b7b0ed780b3769ef74f9) (1944 bytes) (mode 100644) (type blob)

import torch
import numpy as np

def _iterate_over_container(inputs, func, instance_type=torch.Tensor):
    """Apply ``func`` to every ``instance_type`` element found in a nested list/tuple container."""
    def iterate(obj):
        if isinstance(obj, instance_type):
            return func(obj)
        if isinstance(obj, (tuple, list)) and len(obj) > 0:
            return list(map(iterate, obj))
        # Anything else (dicts, strings, empty containers, ...) is not
        # traversed and maps to None.
        return None
    # After iterate is called, an `iterate` cell will exist. This cell has a
    # reference to the actual function iterate, which has references to a
    # closure that has a reference to the iterate cell (because the function is
    # recursive). To avoid this reference cycle, we set the function to None,
    # clearing the cell.
    try:
        return iterate(inputs)
    finally:
        iterate = None
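

# Usage sketch (illustrative only, not part of the original module): the helper
# maps an arbitrary callable over every tensor leaf of a nested list/tuple,
# preserving the nesting (tuples come back as lists).
def _example_iterate_over_container():
    nested = [torch.ones(2), (torch.zeros(3), torch.ones(1))]
    shapes = _iterate_over_container(nested, lambda t: tuple(t.shape))
    assert shapes == [(2,), [(3,), (1,)]]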


def to_device(container, device, pin_memory=False, shared_memory=False):
    """The device could be cuda:0 for example."""
    device = torch.device(device)
    def do_to_device(x):
        x = x.to(device)
        if pin_memory:
            x = x.pin_memory()
        if shared_memory:
            x.share_memory_()
        return x
    return _iterate_over_container(container, do_to_device)
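

# Usage sketch (illustrative only, not part of the original module): to_device
# keeps the nesting of lists/tuples while moving every tensor to the target
# device. Note that tuples in the input come back as lists.
def _example_to_device():
    batch = [torch.ones(2, 3), (torch.zeros(4), torch.arange(5))]
    moved = to_device(batch, "cpu")  # e.g. "cuda:0" when a GPU is available
    assert moved[1][0].device.type == "cpu"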


def _numpy_dtype_to_torch_dtype(dtype: np.dtype):
    t = dtype.type
    # Note: float64 arrays are mapped to torch.float (float32) and
    # int32 arrays to torch.long (int64).
    if t is np.float64 or t is np.float32:
        return torch.float
    elif t is np.int64 or t is np.int32:
        return torch.long
    else:
        raise TypeError(f"Unknown numpy dtype: {t}")


def to_device_from_numpy(container, device, pin_memory=False, shared_memory=False):
    """The device could be cuda:0 for example."""
    device = torch.device(device)
    def as_tensor(x):
        tensor = torch.as_tensor(x, dtype=_numpy_dtype_to_torch_dtype(x.dtype), device=device)
        if pin_memory:
            tensor = tensor.pin_memory()
        if shared_memory:
            tensor.share_memory_()
        return tensor
    return _iterate_over_container(container, as_tensor, np.ndarray)
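

# Usage sketch (illustrative only, not part of the original module): convert a
# nested container of numpy arrays into torch tensors on the target device in
# one pass, with dtypes mapped by _numpy_dtype_to_torch_dtype.
def _example_to_device_from_numpy():
    arrays = [np.zeros((2, 3), dtype=np.float32), np.arange(4, dtype=np.int64)]
    tensors = to_device_from_numpy(arrays, "cpu")
    assert tensors[0].dtype == torch.float32 and tensors[1].dtype == torch.int64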

