jax-enums

JAX-compatible Enumerations.


License: Other
Install: pip install jax-enums==0.1.2

Documentation

JAX_ENUMS: JAX-compatible enumerations


Installation | Example | Cite

Installation

pip install jax_enums

Example

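import jax
import jax.numpy as jnp
from jax_enums import Enumerable  # assumed import path; not shown in the original snippet

class Foo(Enumerable):
    BAR = 0
    BAZ = 1

def f(array: jax.Array, enumerable: Enumerable) -> jax.Array:
    # Index the array with the member's underlying value.
    return array[enumerable.value]

array = jnp.zeros((2, 2))
enumerable = Foo.BAR

f(array, enumerable)           # eager call
jax.jit(f)(array, enumerable)  # the same call under jit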
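
For contrast, here is a minimal sketch of what happens without jax_enums. It uses only the standard library's enum.Enum and core JAX. The names Color, pick_row and static_pick_row are hypothetical and exist only for this illustration. A plain Enum member is not a valid JAX type, so passing it to a jitted function raises a TypeError unless the argument is marked static.

```python
import enum

import jax
import jax.numpy as jnp


class Color(enum.Enum):  # hypothetical stdlib enum, not part of jax_enums
    RED = 0
    GREEN = 1


def pick_row(array: jax.Array, member: Color) -> jax.Array:
    # Index the array with the member's plain Python integer value.
    return array[member.value]


array = jnp.arange(4.0).reshape(2, 2)

# A plain enum member cannot be traced as a JAX array, so jit rejects it.
try:
    jax.jit(pick_row)(array, Color.GREEN)
    plain_enum_jits = True
except TypeError:
    plain_enum_jits = False

# Workaround without jax_enums: declare the enum argument as static.
# Enum members are hashable, which static arguments require.
static_pick_row = jax.jit(pick_row, static_argnames="member")
row = static_pick_row(array, Color.GREEN)
```

Marking the argument static works, but JAX recompiles the function for each distinct member passed in. The example above instead passes a Foo member to jax.jit(f) directly, with no static arguments.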

Cite

@misc{pignatelli2023jax_enums,
  author = {Pignatelli, Eduardo},
  title = {JAX_ENUMS: JAX-compatible enumerations},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/epignatelli/jax_enums}}
  }