torch-fn

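A decorator that lets PyTorch functions (run primarily on CUDA) accept `list`, `numpy.ndarray`, and `pandas.DataFrame` inputs: non-tensor inputs are converted to `torch.Tensor`, and the result is returned as a `numpy.ndarray` (tensor inputs are returned as tensors).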


Keywords
torch, GPU, numpy, pandas
License
MIT
Install
pip install torch-fn==1.0.0

Documentation

CI CI

Installation

$ pip install torch_fn

Usage

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from torch_fn import torch_fn

@torch_fn
def torch_softmax(*args, **kwargs):
    """Forward all arguments unchanged to ``torch.nn.functional.softmax``.

    Input conversion (list / ndarray / DataFrame -> tensor) and conversion
    of the result back is left entirely to the ``@torch_fn`` decorator.
    """
    softmaxed = F.softmax(*args, **kwargs)
    return softmaxed

def custom_print(x):
    """Print the type of *x* followed by *x* itself (space-separated)."""
    kind_of_x = type(x)
    print(kind_of_x, x)

# Test the decorator with different input types.
# NOTE: the outputs recorded below were captured on a machine with a CUDA
# device (see "cuda:0" in the warnings); on CPU-only machines they may differ.
x = [1, 2, 3]
x_list = x
x_tensor = torch.tensor(x).float()
x_array = np.array(x)
x_df = pd.DataFrame({"col1": x})

custom_print(torch_softmax(x_list, dim=-1))
# /home/ywatanabe/proj/torch_fn/src/torch_fn/_torch_fn.py:57: UserWarning: Converted from  <class 'list'> to <class 'torch.Tensor'> (cuda:0)
#   warnings.warn(
# <class 'numpy.ndarray'> [0.09003057 0.24472848 0.6652409 ]

custom_print(torch_softmax(x_array, dim=-1))
# /home/ywatanabe/proj/torch_fn/src/torch_fn/_torch_fn.py:57: UserWarning: Converted from  <class 'numpy.ndarray'> to <class 'torch.Tensor'> (cuda:0)
#   warnings.warn(
# <class 'numpy.ndarray'> [0.09003057 0.24472848 0.6652409 ]

custom_print(torch_softmax(x_df, dim=-1))
# /home/ywatanabe/proj/torch_fn/src/torch_fn/_torch_fn.py:49: UserWarning: Converted from  <class 'pandas.core.frame.DataFrame'> to <class 'torch.Tensor'> (cuda:0)
#   warnings.warn(
# <class 'numpy.ndarray'> [0.09003057 0.24472848 0.6652409 ]

custom_print(torch_softmax(x_tensor, dim=-1))
# <class 'torch.Tensor'> tensor([0.0900, 0.2447, 0.6652])

# `.cuda()` raises on machines without a CUDA device, so only run the
# GPU-tensor example when one is present.
if torch.cuda.is_available():
    x_tensor_cuda = torch.tensor(x).float().cuda()
    custom_print(torch_softmax(x_tensor_cuda, dim=-1))
    # <class 'torch.Tensor'> tensor([0.0900, 0.2447, 0.6652], device='cuda:0')