jax-toolkit

A collection of JAX functions to help with common machine learning and deep learning functionality.


License
Apache-2.0
Install
pip install jax-toolkit==0.2.0

This library currently provides basic implementations of a number of losses and metrics. We intend to add more functionality as it is needed. Contributions, pull requests, and bug reports are very welcome if you discover a problem or need something that is currently missing.
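For readers new to JAX, here is a minimal sketch of what a loss and a metric typically look like as plain JAX functions. This is not jax_toolkit's API. The names binary_cross_entropy and binary_accuracy are hypothetical and chosen only for illustration, and the library's own functions may differ in name, signature, and behavior.

```python
import jax
import jax.numpy as jnp

# Hypothetical example functions, NOT part of jax_toolkit's API.


def binary_cross_entropy(y_true, y_pred_logits):
    """Mean binary cross-entropy computed from logits."""
    # log(sigmoid(x)) and log(1 - sigmoid(x)) = log(sigmoid(-x)),
    # computed with log_sigmoid for numerical stability.
    log_p = jax.nn.log_sigmoid(y_pred_logits)
    log_not_p = jax.nn.log_sigmoid(-y_pred_logits)
    return -jnp.mean(y_true * log_p + (1.0 - y_true) * log_not_p)


def binary_accuracy(y_true, y_pred_logits, threshold=0.5):
    """Fraction of predictions whose thresholded probability matches the label."""
    y_pred = (jax.nn.sigmoid(y_pred_logits) >= threshold).astype(y_true.dtype)
    return jnp.mean(y_pred == y_true)


# Loss value and gradient with respect to the logits (argument 1), JIT-compiled.
loss_and_grad = jax.jit(jax.value_and_grad(binary_cross_entropy, argnums=1))

y_true = jnp.array([1.0, 0.0, 1.0, 0.0])
logits = jnp.array([2.0, -1.0, -0.5, 0.3])

loss, grad = loss_and_grad(y_true, logits)
acc = binary_accuracy(y_true, logits)
```

Because the loss is a pure function of arrays, jax.value_and_grad and jax.jit can be applied to it directly. This composability is the main reason to write losses and metrics as plain JAX functions rather than as stateful objects.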

Installation

pip install jax_toolkit

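Or, to also install the optional losses_utils extra for additional loss function utilities (quote the argument so that shells such as zsh do not interpret the brackets):

pip install "jax_toolkit[losses_utils]"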