Flexible Modules for JAX


Keywords: deep-learning, jax
License: MIT
Install: pip install ninjax==3.1.0

🥷 Ninjax: Flexible Modules for JAX

Ninjax is a general and practical module system for JAX. It gives users full and transparent control over updating the state of each module, bringing flexibility to JAX and enabling new use cases.

Overview

Ninjax provides a simple and general nj.Module class.

  • Modules can store state such as model parameters, Adam momentum buffers, BatchNorm statistics, and recurrent state.
  • Modules can read and write their state entries. For example, this allows modules to have train methods, because they can update their parameters from the inside.
  • Any method can initialize, read, and write state entries. This avoids the need for a special build() method or @compact decorator used in Flax.
  • Ninjax makes it easy to mix and match modules from different libraries, such as Flax and Haiku.
  • Instead of PyTrees, Ninjax state is a flat dict that maps string keys like net/layer1/weights to JAX arrays. This makes it easy to iterate over, modify, and save or load state.
  • Modules can specify typed hyperparameters using the dataclass syntax.
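
Because the state is a flat dict of arrays, ordinary Python dict operations are enough to inspect, edit, and persist it. The sketch below illustrates this with a hand-built dict rather than one produced by Ninjax; apart from model/bias, the key names and shapes are hypothetical, and saving through NumPy's .npz format is one possible choice, not something Ninjax prescribes.

```python
import io

import jax.numpy as jnp
import numpy as np

# Hypothetical state in the flat layout described above (keys are assumptions).
state = {
    'model/h1/kernel': jnp.ones((32, 128)),
    'model/h3/kernel': jnp.ones((128, 3)),
    'model/bias': jnp.full((3,), 0.5),
}

# Iterate: list entries and count the total number of scalars.
for key in sorted(state):
  print(key, state[key].shape)
num_params = sum(int(v.size) for v in state.values())

# Modify: reset every bias entry by building a new dict.
state = {
    k: jnp.zeros_like(v) if k.endswith('/bias') else v
    for k, v in state.items()}

# Save and load: round-trip through an in-memory .npz archive.
buffer = io.BytesIO()
np.savez(buffer, **{k: np.asarray(v) for k, v in state.items()})
buffer.seek(0)
with np.load(buffer) as data:
  restored = {k: jnp.asarray(data[k]) for k in data.files}
```

To write to disk instead, pass a file path to np.savez and np.load in place of the buffer.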

Installation

Ninjax is a single file, so you can copy it into your project directory. Alternatively, install the package:

pip install ninjax

Quickstart

import flax
import jax
import jax.numpy as jnp
import ninjax as nj
import optax

Linear = nj.FromFlax(flax.linen.Dense)


class MyModel(nj.Module):

  lr: float = 1e-3

  def __init__(self, size):
    self.size = size
    # Define submodules upfront
    self.h1 = Linear(128, name='h1')
    self.h2 = Linear(128, name='h2')
    self.opt = optax.adam(self.lr)

  def predict(self, x):
    x = jax.nn.relu(self.h1(x))
    x = jax.nn.relu(self.h2(x))
    # Define submodules inline
    x = self.sub('h3', Linear, self.size, use_bias=False)(x)
    # Create state entries inline
    x += self.value('bias', jnp.zeros, self.size)
    # Update state entries inline
    self.write('bias', self.read('bias') + 0.1)
    return x

  def loss(self, x, y):
    return ((self.predict(x) - y) ** 2).mean()

  def train(self, x, y):
    # Take grads with respect to submodules or state keys
    wrt = [self.h1, self.h2, f'{self.path}/h3', f'{self.path}/bias']
    loss, params, grads = nj.grad(self.loss, wrt)(x, y)
    # Update weights
    state = self.sub('optstate', nj.Tree, self.opt.init, params)
    updates, new_state = self.opt.update(grads, state.read(), params)
    params = optax.apply_updates(params, updates)
    nj.context().update(params)  # Store the new params
    state.write(new_state)       # Store new optimizer state
    return loss


# Create model and example data
model = MyModel(3, name='model')
x = jnp.ones((64, 32), jnp.float32)
y = jnp.ones((64, 3), jnp.float32)

# Populate initial state from one or more functions
state = {}
state = nj.init(model.train)(state, x, y, seed=0)
print(state['model/bias'])

# Purify for JAX transformations
train = jax.jit(nj.pure(model.train))

# Training loop
for x, y in [(x, y)] * 10:
  state, loss = train(state, x, y)
  print('Loss:', float(loss))

# Look at the parameters
print(state['model/bias'])
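
In the quickstart, the purified train function takes the state dict as its first argument and returns the updated state alongside the loss, which is the form jax.jit expects. The sketch below illustrates that calling convention in plain JAX with a hand-written step function. It is not how Ninjax is implemented: train_step, the model/kernel key, and the plain gradient descent update (used here in place of Adam) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def train_step(state, x, y, lr=0.01):
  # Hypothetical pure step: reads parameters from a flat dict and
  # returns a new dict instead of mutating anything in place.
  def loss_fn(params):
    pred = x @ params['model/kernel'] + params['model/bias']
    return ((pred - y) ** 2).mean()
  loss, grads = jax.value_and_grad(loss_fn)(state)
  # Plain gradient descent on every entry of the state dict.
  new_state = {k: v - lr * grads[k] for k, v in state.items()}
  return new_state, loss


# Same data shapes as the quickstart.
state = {
    'model/kernel': jnp.zeros((32, 3), jnp.float32),
    'model/bias': jnp.zeros((3,), jnp.float32),
}
x = jnp.ones((64, 32), jnp.float32)
y = jnp.ones((64, 3), jnp.float32)

# A pure function can be passed to jax.jit directly.
step = jax.jit(train_step)
losses = []
for _ in range(10):
  state, loss = step(state, x, y)
  losses.append(float(loss))
print('First and last loss:', losses[0], losses[-1])
```

Threading the state through the function explicitly is what keeps the step free of side effects, so the same pattern also composes with other JAX transformations.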

Questions

If you have a question, please file an issue.