torchpipe

Simple ETL Pipeline for PyTorch


Keywords: pytorch
License: MIT
Install: pip install torchpipe==0.0.3

Documentation

PyTorch Pipeline: Simple ETL Pipeline for PyTorch

PyTorch Pipeline is a simple ETL (extract, transform, load) framework for PyTorch. It is an alternative to tf.data in TensorFlow.

Requirements

  • Python 3.6+
  • PyTorch 1.2+

Installation

To install PyTorch Pipeline:

pip install pytorch_pipeline

Basic Usage

import pytorch_pipeline as pp

d = pp.TextDataset('/path/to/your/text')
d.shuffle(buffer_size=100).batch(batch_size=10).first()
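
To make the roles of `buffer_size` and `batch_size` concrete, here is a minimal pure-Python sketch of a buffered shuffle followed by batching. It is not the library's implementation. It assumes the semantics follow tf.data, where a shuffle buffer of size N only mixes elements that are at most about N positions apart in the stream. The helper names `buffered_shuffle` and `batched` are hypothetical and are not part of pytorch_pipeline.

```python
import random
from itertools import islice


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Yield items in pseudo-random order using a fixed-size buffer.

    Hypothetical helper for illustration; buffer_size must be at least 1.
    """
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and store the new item in its slot.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Source exhausted: drain the remaining items in random order.
    rng.shuffle(buffer)
    yield from buffer


def batched(iterable, batch_size):
    """Group consecutive items into lists of at most batch_size elements."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


# Roughly analogous to shuffle(buffer_size=100).batch(batch_size=10).first()
first_batch = next(batched(buffered_shuffle(range(1000), buffer_size=100, seed=0), 10))
print(first_batch)
```

Because only the buffer is held in memory, this style of shuffle works on streams that are too large to load at once. The cost is that a small buffer gives a weaker shuffle than a full permutation.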

Usage with PyTorch

from torch.utils.data import DataLoader
import pytorch_pipeline as pp


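# Build the pipeline: the dataset already yields batches of 10 items
d = pp.Dataset(range(1_000)).parallel().shuffle(100).batch(10)
# The identity collate_fn returns the samples unchanged instead of stacking them into tensors
loader = DataLoader(d, num_workers=4, collate_fn=lambda x: x)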
for x in loader:
    ...
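
When an iterable-style dataset is read by several DataLoader workers, each worker runs its own copy of the iterator, so the data has to be split between workers to avoid duplicates. The sketch below shows one common way to do that, round-robin sharding by worker index. That `parallel()` works this way is an assumption, not something the text above confirms, and the helper name `shard` is hypothetical.

```python
from itertools import islice


def shard(iterable, worker_id, num_workers):
    """Return every num_workers-th item, starting at position worker_id.

    Hypothetical helper for illustration only.
    """
    return islice(iterable, worker_id, None, num_workers)


# Simulate four workers reading the same ten-item source.
shards = [list(shard(range(10), worker_id, 4)) for worker_id in range(4)]
for worker_id, items in enumerate(shards):
    print(worker_id, items)
```

Inside a real worker process, the worker index and worker count would typically come from `torch.utils.data.get_worker_info()`.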

Usage with LineFlow

You can use PyTorch Pipeline with pre-defined datasets in LineFlow:

from torch.utils.data import DataLoader
from lineflow.datasets.wikitext import cached_get_wikitext
import pytorch_pipeline as pp

dataset = cached_get_wikitext('wikitext-2')
# Preprocess the dataset: tokenize each line, then window, shuffle, and batch
train_data = pp.Dataset(dataset['train']) \
    .flat_map(lambda x: x.split() + ['<eos>']) \
    .window(35) \
    .parallel() \
    .shuffle(64 * 100) \
    .batch(64)

# Iterate over the dataset
loader = DataLoader(train_data, num_workers=4, collate_fn=lambda x: x)
for x in loader:
    ...
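
The `flat_map` and `window` steps above can be pictured with a small pure-Python sketch. It assumes that `flat_map` flattens the lists returned by the function into a single token stream, and that `window(n)` yields consecutive, non-overlapping chunks of n tokens. The windowing behaviour in particular is an assumption and the library may differ. Both helper functions are hypothetical stand-ins, not pytorch_pipeline code.

```python
from itertools import chain, islice


def flat_map(func, iterable):
    """Apply func to each item and flatten the resulting iterables into one stream."""
    return chain.from_iterable(map(func, iterable))


def window(iterable, size):
    """Yield consecutive, non-overlapping lists of up to size items (assumed semantics)."""
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:
            return
        yield chunk


lines = ['the cat sat', 'on the mat']
# Each line becomes its tokens followed by an end-of-sentence marker.
tokens = list(flat_map(lambda x: x.split() + ['<eos>'], lines))
windows = list(window(tokens, 3))
print(tokens)
print(windows)
```

In the WikiText example, a window size of 35 would then correspond to fixed-length token sequences for language modelling, with sentence boundaries marked by `<eos>`.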