fasttrain

A framework that makes building training loops easier and faster


Keywords
python, torch, pytorch
License
Apache-2.0
Install
pip install fasttrain==0.0.7

Documentation

fasttrain

fasttrain is a lightweight framework for building training loops for neural nets as fast as possible. It's designed to take care of the boring details of writing training loops in PyTorch, so you don't have to think about how to pretty-print a loss or metrics, or worry about computing them correctly.
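For contrast, here's the sort of loop fasttrain takes off your hands - a minimal hand-written epoch in plain PyTorch (independent of fasttrain), with the loss and accuracy bookkeeping done by hand:

def train_one_epoch(model, loader, optimizer, loss_fn):
    model.train()
    running_loss, correct, seen = 0.0, 0, 0
    for X, y in loader:
        optimizer.zero_grad()
        logits = model(X)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()
        # Accumulate totals so we can report per-epoch averages
        running_loss += loss.item() * y.size(0)
        correct += (logits.argmax(dim=1) == y).sum().item()
        seen += y.size(0)
    print(f"loss: {running_loss / seen:.4f}  accuracy: {correct / seen:.4f}")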

Installation

$ pip install fasttrain

How do we start?

Let's use a neural network to classify images in the FashionMNIST dataset:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

learning_rate = 1e-3
batch_size = 64
epochs = 5

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
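
To sanity-check the data pipeline, you can pull a single batch and inspect its shape - this is plain PyTorch, nothing fasttrain-specific:

X_batch, y_batch = next(iter(train_dataloader))
print(X_batch.shape)  # torch.Size([64, 1, 28, 28]) - batch of grayscale 28x28 images
print(y_batch.shape)  # torch.Size([64]) - one class label per image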

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
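
A quick forward pass on a dummy batch confirms the wiring: the network maps a batch of images to one raw score (logit) per class:

with torch.no_grad():
    logits = model(torch.rand(64, 1, 28, 28))  # dummy batch, same shape as FashionMNIST
print(logits.shape)  # torch.Size([64, 10]) - 10 logits per image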

Then we define a trainer:

from fasttrain import Trainer
from fasttrain.metrics import accuracy

class MyTrainer(Trainer):

    # Define how we compute the loss
    def compute_loss(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return nn.CrossEntropyLoss()(output_batch, y_batch)

    # Define how we compute metrics
    def eval_metrics(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return {
            "accuracy": accuracy(output_batch, y_batch, task="multiclass")
        }
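
For reference, multiclass accuracy boils down to comparing the argmax of the logits with the labels. Here's a minimal sketch of that computation in plain PyTorch - fasttrain's accuracy may differ in its details:

def multiclass_accuracy(logits, targets):
    # Fraction of samples whose highest-scoring class matches the label
    predictions = logits.argmax(dim=1)
    return (predictions == targets).float().mean().item()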

Finally, let's train our model:

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
trainer = MyTrainer(model, optimizer)
history = trainer.train(train_dataloader, val_data=test_dataloader, num_epochs=epochs)

fasttrain offers some useful callbacks - one of them is Tqdm, which shows a pretty-looking progress bar:

[screenshot: training loop progress bar]
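The snippet above doesn't show how the callback gets attached. The sketch below is only a guess at the wiring: the callbacks parameter and the fasttrain.callbacks import path are assumptions, not confirmed API, so check the package source before relying on them.

from fasttrain.callbacks import Tqdm  # import path is an assumption

# Hypothetical: assumes Trainer.train() accepts a list of callbacks
history = trainer.train(
    train_dataloader,
    val_data=test_dataloader,
    num_epochs=epochs,
    callbacks=[Tqdm()],  # hypothetical parameter name
)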

Trainer.train() returns the training history - it wraps a dict that stores metrics over epochs, and it can plot them:

history.plot("loss", with_val=True)

[plot: training and validation loss over epochs]

history.plot("accuracy", with_val=True)

[plot: training and validation accuracy over epochs]
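
If you'd rather style the plots yourself, the per-epoch values can presumably be read out of the history object directly. A minimal sketch, assuming history can be indexed like a dict of per-epoch lists and that validation metrics live under val_-prefixed keys (both are assumptions):

import matplotlib.pyplot as plt

# Assumption: history["loss"] and history["val_loss"] are lists of per-epoch values
plt.plot(history["loss"], label="train")
plt.plot(history["val_loss"], label="val")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()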