# TorchABC

`TorchABC` is an abstract class for training and inference in PyTorch that helps you keep your code well organized. It is a minimalist alternative to `pytorch-lightning`: it depends on `torch` only and consists of a single self-contained file.
The `TorchABC` class implements the workflow illustrated above. The workflow begins with raw `data`, which undergoes a `preprocess` step. This step transforms the raw `data` into `input` samples and their corresponding `target` labels.
Next, the individual `input` samples are grouped into batches called `inputs` using a `collate` function. Similarly, the `target` labels are batched into `targets`. The `inputs` are then fed into the `network`, which produces `outputs`.
The `outputs` are compared to the `targets` using a `loss` function, which quantifies the error between the two. The `optimizer` updates the parameters of the `network` to minimize the `loss`. The `scheduler` can dynamically change the learning rate of the `optimizer` during training.
Finally, the raw `outputs` from the `network` undergo a `postprocess` step to generate the final `predictions`. This could involve converting probabilities to class labels, applying thresholds, or other task-specific transformations.
The core logic blocks are abstract. You define their specific behavior with maximum flexibility.
Install the package.
```bash
pip install torchabc
```
Generate a template using the command line interface.
```bash
torchabc --create template.py
```
The template is structured as follows.
```python
from functools import cached_property
from typing import Any, Dict

import torch

from torchabc import TorchABC


class ClassName(TorchABC):

    @cached_property
    def dataloaders(self):
        raise NotImplementedError

    def preprocess(self, data: Any, flag: str = '') -> Any:
        return data

    @cached_property
    def network(self):
        raise NotImplementedError

    @cached_property
    def optimizer(self):
        raise NotImplementedError

    @cached_property
    def scheduler(self):
        return None

    def loss(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def metrics(self, outputs: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        return {}

    def postprocess(self, outputs: torch.Tensor) -> Any:
        return outputs
```
Fill out the template with the dataloaders, preprocessing and postprocessing steps, the neural network, optimizer, scheduler, loss, and optional evaluation metrics.
The `dataloaders` property defines and returns a dictionary containing the `DataLoader` instances for the training, validation, and testing datasets. The dictionary's keys should correspond to the names of the datasets (e.g., 'train', 'val', 'test'), and the values should be their respective `DataLoader` objects. Any transformation of the raw input data for each dataset should be implemented within the `preprocess` method of this class. The `preprocess` method should then be passed as the `transform` argument of the `Dataset` instances.
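For instance, a minimal sketch of `dataloaders` using a torchvision dataset might look as follows. The use of MNIST, the `batch_size` hyperparameter, and the `'train'` flag value are illustrative assumptions, not requirements of the library.

```python
from functools import cached_property, partial

from torch.utils.data import DataLoader
from torchvision import datasets  # assumes torchvision is installed

# inside ClassName
@cached_property
def dataloaders(self):
    # hypothetical example: MNIST, with `preprocess` passed as the transform
    train = datasets.MNIST("data", train=True, download=True,
                           transform=partial(self.preprocess, flag='train'))
    val = datasets.MNIST("data", train=False, download=True,
                         transform=self.preprocess)
    return {
        'train': DataLoader(train, batch_size=self.hparams.batch_size, shuffle=True),
        'val': DataLoader(val, batch_size=self.hparams.batch_size),
    }
```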
The `preprocess` method processes the data differently depending on a `flag`. When `flag` is empty (the default), the data are assumed to represent the model's input used for inference. When `flag` has a specific value, the method may perform different preprocessing steps, such as transforming the target or augmenting the input for training.
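As an illustration, here is a sketch that scales raw numeric data and applies a simple augmentation when `flag` is `'train'`. The `'train'` value and the flip augmentation are assumptions chosen for the example.

```python
import torch

# inside ClassName
def preprocess(self, data, flag=''):
    # assume `data` is an array-like of pixel intensities; scale to [0, 1]
    x = torch.as_tensor(data, dtype=torch.float32) / 255.0
    if flag == 'train':
        # hypothetical augmentation: random horizontal flip
        if torch.rand(()) < 0.5:
            x = torch.flip(x, dims=[-1])
    return x
```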
The `network` property returns a `Module` whose input and output tensors assume the batch size is the first dimension: `(batch_size, ...)`.
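For example, a sketch of a small fully connected classifier for 28x28 inputs and 10 classes (the architecture and sizes are arbitrary assumptions):

```python
import torch

# inside ClassName
@cached_property
def network(self):
    # batch-first input of shape (batch_size, 1, 28, 28), output (batch_size, 10)
    return torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 10),
    )
```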
The `optimizer` property returns an `Optimizer` configured for the `network`.
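A typical sketch, assuming the learning rate was passed as the `lr` hyperparameter at initialization:

```python
# inside ClassName
@cached_property
def optimizer(self):
    # Adam over the network parameters, with the learning rate from hparams
    return torch.optim.Adam(self.network.parameters(), lr=self.hparams.lr)
```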
The `scheduler` property returns an `LRScheduler` or `ReduceLROnPlateau` configured for the `optimizer`.
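For instance, a sketch that reduces the learning rate when the monitored value plateaus (the `factor` and `patience` values are arbitrary):

```python
# inside ClassName
@cached_property
def scheduler(self):
    # halve the learning rate after 2 epochs without improvement
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        self.optimizer, factor=0.5, patience=2
    )
```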
The `loss` method defines the loss function that quantifies the discrepancy between the neural network `outputs` and the corresponding `targets`. The loss function should be differentiable to enable backpropagation.
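For a multi-class classification task, a common choice is cross entropy; a sketch:

```python
import torch.nn.functional as F

# inside ClassName
def loss(self, outputs, targets):
    # cross entropy between raw logits and integer class labels
    return F.cross_entropy(outputs, targets)
```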
The `metrics` method calculates various metrics that quantify the discrepancy between the neural network `outputs` and the corresponding `targets`. Unlike the `loss`, which is primarily used for training, these metrics are only used for evaluation and do not need to be differentiable.
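For example, a sketch that reports classification accuracy (the metric name is an arbitrary choice):

```python
# inside ClassName
def metrics(self, outputs, targets):
    # fraction of correctly classified samples in the batch
    accuracy = (outputs.argmax(dim=-1) == targets).float().mean().item()
    return {'accuracy': accuracy}
```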
The `postprocess` method transforms the neural network `outputs` to generate the final `predictions`.
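For classification, a sketch could map logits to class indices:

```python
# inside ClassName
def postprocess(self, outputs):
    # convert raw logits into predicted class labels
    return outputs.argmax(dim=-1)
```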
After filling out the template above, you can use your class as follows.
Initialize the class with
```python
model = ClassName(
    device: Union[str, torch.device] = None,
    logger: Callable = print,
    **hparams
)
```
The `device` is the `torch.device` to use. It defaults to `None`, which will try CUDA, then MPS, and finally fall back to the CPU.
The `logger` is a logging function that takes a dictionary as input. The default prints to standard output. You can easily log with `wandb`

```python
import wandb

model = ClassName(logger=wandb.log)
```
or with any other custom logger.
You will typically use several hyperparameters to control the behavior of `ClassName`, such as the learning rate or batch size. These parameters should be passed during initialization
```python
model = ClassName(lr=0.001, batch_size=64)
```
and are stored in the `hparams` attribute of the model. For instance, use `hparams.lr` to access the `lr` value.
Train the model with
```python
model.train(
    epochs: int,
    on: str = 'train',
    val: str = 'val',
    gas: int = 1,
    callback: Callable = None
)
```
where

- `epochs` is the number of training epochs to perform.
- `on` is the name of the training dataloader. Defaults to `'train'`.
- `val` is the name of the validation dataloader. Defaults to `'val'`.
- `gas` is the number of gradient accumulation steps. Defaults to 1 (no gradient accumulation).
- `callback` is a function that is called after each epoch. It should accept two arguments: the instance itself and a list of dictionaries containing the loss and evaluation metrics. When this function returns `True`, training stops.
This method returns a list of dictionaries containing the loss and evaluation metrics.
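For example, a hypothetical early-stopping callback could be sketched as follows. The `'val/loss'` key is an assumption about the contents of the log dictionaries; inspect the logs returned by `train` to confirm the actual keys.

```python
def early_stopping(self, logs, patience=3):
    # stop when the best validation loss did not improve
    # over the last `patience` epochs ('val/loss' is an assumed key)
    losses = [log['val/loss'] for log in logs if 'val/loss' in log]
    return len(losses) > patience and min(losses[-patience:]) > min(losses)

model.train(epochs=100, callback=early_stopping)
```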
Save the model to a checkpoint.
```python
model.save("checkpoint.pth")
```
Load the model from a checkpoint.
```python
model.load("checkpoint.pth")
```
You can also use the `callback` function to implement a custom checkpointing strategy. For instance, the following example saves a checkpoint after each training epoch.

```python
callback = lambda self, logs: self.save(f"epoch_{logs[-1]['val/epoch']}.pth")
model.train(epochs=10, val='val', callback=callback)
```
Evaluate the model with
```python
model.eval(on='test')
```
where `on` is the name of the dataloader to evaluate on. It should be one of the keys of `dataloaders`. This method returns a dictionary containing the evaluation metrics.
Predict raw data with
```python
model.predict(data)
```
where `data` is the raw input data. This method returns the postprocessed prediction.
Get started with simple self-contained examples:
Contributions are welcome! Submit pull requests with new examples or improvements to the core `TorchABC` class itself.