A purely functional machine learning library built on top of Google JAX.


Keywords
deep-learning, jax, machine-learning, neural-networks
License
MIT
Install
pip install mlax-nn==0.2.0

MLAX: Functional NN library built on top of Google JAX

Overview | Installation | Quickstart | Examples | Documentation

What is MLAX?

MLAX is a purely functional neural network library built on top of Google JAX.

MLAX modules follow object-oriented semantics, like those of Keras and PyTorch, yet remain fully compatible with native JAX transformations.

Learn more about MLAX on Read the Docs.

Installation

Install JAX first if you have not already.
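A CPU-only build of JAX can typically be installed with pip (see the JAX installation guide for GPU and TPU wheels):

pip install -U jax

Then install MLAX: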

pip install mlax-nn

Quickstart

Here is a simple, lazily initialized linear layer defined in MLAX.

import jax
from jax import (
    numpy as jnp,
    nn,
    random
)
from mlax import Module, Parameter, Variable

class Linear(Module):
    def __init__(self, rng, out_features):
        super().__init__()
        self.rng = Variable(data=rng)
        self.out_features = out_features
        
        self.kernel_weight = Parameter()
        self.bias_weight = Parameter()
    
    # Define a ``set_up`` method for lazy initialization of parameters
    def set_up(self, x):
        rng1, rng2 = random.split(self.rng.data)
        self.kernel_weight.data = nn.initializers.lecun_normal()(
            rng1, (x.shape[-1], self.out_features)
        )
        self.bias_weight.data = nn.initializers.zeros(rng2, (self.out_features,))

    # Define a ``forward`` method for the forward pass
    def forward(
        self, x, rng=None, inference_mode=False, batch_axis_name=()
    ):
        return x @ self.kernel_weight.data + self.bias_weight.data
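On the first call, MLAX invokes ``set_up`` with the input to lazily initialize the parameters from its shape. A minimal sketch of this behavior, assuming (as in the loss example below) that calling a module returns the output together with the updated module:

layer = Linear(random.PRNGKey(42), out_features=4)
out, layer = layer(jnp.ones((2, 3)), rng=None, inference_mode=True)
print(out.shape)                       # (2, 4)
print(layer.kernel_weight.data.shape)  # (3, 4), inferred from the input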

It is fully compatible with native JAX transformations:

def loss_fn(x, y, model):
    # Calling the module returns the prediction and the updated module
    pred, model = model(x, rng=None, inference_mode=True)
    return jnp.mean((y - pred) ** 2), model  # mean squared error

x = jnp.ones((4, 3), dtype=jnp.float32)
y = jnp.ones((4, 4), dtype=jnp.float32)
model = Linear(random.PRNGKey(0), 4)

loss, updated_model = loss_fn(x, y, model)
print(loss)

# Now let's apply `jax.jit` and `jax.value_and_grad`.
# Note: by default, gradients are taken w.r.t. the first argument, `x`.
(loss, updated_model), grads = jax.jit(
    jax.value_and_grad(
        loss_fn,
        has_aux=True
    )
)(x, y, model)

print(loss)
print(grads)
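MLAX modules are registered pytrees, which is what lets them pass through ``jax.jit`` and ``jax.value_and_grad`` above. Native JAX tree utilities should therefore work on them directly; a small sketch (the exact leaves depend on how MLAX flattens its modules):

# Leaves should include the kernel, the bias, and the rng variable
print(jax.tree_util.tree_leaves(updated_model))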

For end-to-end examples with reference PyTorch implementations, visit MLAX's GitHub.

View the full documentation on Read the Docs.

Bugs and Feature Requests

Please create an issue on MLAX's GitHub repository.

Contribution

If you wish to contribute, thank you! Please contact me by email: y22zong@uwaterloo.ca.