jaxformers

An implementation of the Transformer from 'Attention Is All You Need' (Vaswani et al., 2017) in JAX (Flax).
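
The core operation of the Transformer is scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V. The sketch below writes it in plain JAX. It does not use or reflect the jaxformers API. The function name `scaled_dot_product_attention` and its boolean mask convention (True means "may attend") are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def scaled_dot_product_attention(q, k, v, mask=None):
    """Illustrative sketch (hypothetical helper, not the jaxformers API).

    Computes softmax(Q K^T / sqrt(d_k)) V.
    q: [..., len_q, d_k], k: [..., len_k, d_k], v: [..., len_k, d_v].
    mask (optional): boolean, broadcastable to [..., len_q, len_k];
    True marks key positions a query is allowed to attend to.
    Returns (output [..., len_q, d_v], weights [..., len_q, len_k]).
    """
    d_k = q.shape[-1]
    # Dot-product similarity of every query with every key, scaled by sqrt(d_k).
    scores = jnp.einsum("...qd,...kd->...qk", q, k) / jnp.sqrt(d_k)
    if mask is not None:
        # Disallowed positions get the most negative float, so softmax assigns them ~0 weight.
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    # Normalize over the key axis: each query's weights sum to 1.
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum of the value vectors.
    return weights @ v, weights


# Usage: a batch of 2 sequences, 4 queries attending over 6 keys/values.
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (2, 4, 8))   # d_k = 8
k = jax.random.normal(key_k, (2, 6, 8))
v = jax.random.normal(key_v, (2, 6, 16))  # d_v = 16
out, w = scaled_dot_product_attention(q, k, v)
print(out.shape, w.shape)  # (2, 4, 16) (2, 4, 6)

# Causal (decoder-style) self-attention: position i may only attend to positions <= i.
x = jax.random.normal(jax.random.PRNGKey(1), (5, 8))
causal = jnp.tril(jnp.ones((5, 5), dtype=bool))
_, w_causal = scaled_dot_product_attention(x, x, x, mask=causal)
```

Dividing by sqrt(d_k) keeps the dot products from growing with the key dimension; without it, large scores push the softmax into regions where its gradients are very small.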


Install
pip install jaxformers==0.0.1.dev1