optax-adan

An implementation of the Adan optimization algorithm for optax.


Keywords: deeplearning, jax, optax, optimization-algorithms, optimization-methods
License: Apache-2.0
Install: pip install optax-adan==0.1.5

Documentation

optax-adan

An implementation of the Adan optimizer for optax, based on the paper Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models.

A Colab notebook with a usage example can be found here.

How to use:

Install the package:

python3 -m pip install optax-adan

Import the optimizer:

from optax_adan import adan

Use it as you would use any other optimizer from optax:

import optax  # provides optax.apply_updates

# init
optimizer = adan(learning_rate=0.01)
optimizer_state = optimizer.init(params)
# step (grad_func computes gradients of the loss, e.g. jax.grad(loss_fn))
grad = grad_func(params)
updates, optimizer_state = optimizer.update(grad, optimizer_state, params)
params = optax.apply_updates(params, updates)
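
For reference, here is a minimal end-to-end sketch of a full training loop. The quadratic loss_fn, parameter shape, learning rate, and step count below are illustrative assumptions, not part of the package:

import jax
import jax.numpy as jnp
import optax
from optax_adan import adan

def loss_fn(params):
    # toy quadratic loss with its minimum at params == 3.0 (assumed example)
    return jnp.sum((params - 3.0) ** 2)

params = jnp.zeros(5)
optimizer = adan(learning_rate=0.1)
optimizer_state = optimizer.init(params)

for _ in range(100):
    grad = jax.grad(loss_fn)(params)
    updates, optimizer_state = optimizer.update(grad, optimizer_state, params)
    params = optax.apply_updates(params, updates)

print(loss_fn(params))  # should be close to zero after the loop

Since adan is used here exactly like any built-in optax optimizer, it should also compose with the usual optax utilities (for example optax.chain for gradient clipping), assuming it returns a standard GradientTransformation.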