Tools for JAX.

pip install tjax==0.14.2

This repository implements a variety of tools for the differentiable programming library JAX.

Major components

Tjax's major components are:

  • A dataclass decorator, dataclass, that makes it easy to define JAX trees and comes with a MyPy plugin. (See dataclass and mypy_plugin.)
  • A fixed-point-finding library heavily based on fax. Our library (fixed_point):
    • supports stochastic iterated functions,
    • uses dataclasses instead of closures to avoid leaking JAX tracers, and
    • supports higher-order differentiation.
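The two major components fit together roughly as in the following plain-JAX sketch. It does not use tjax itself and assumes only that JAX is installed; CosineProblem and find_fixed_point are hypothetical names. The problem's parameter lives in a frozen dataclass that is registered as a pytree by hand (the boilerplate a pytree dataclass decorator is meant to remove), and the dataclass is passed to the jitted solver as an explicit argument rather than being captured by a closure.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class CosineProblem:
    """Hypothetical problem: find x such that x = scale * cos(x)."""

    scale: jax.Array  # stored as a pytree leaf, so JAX traces it like any other input

    def iterate(self, x: jax.Array) -> jax.Array:
        return self.scale * jnp.cos(x)


# Manual pytree registration: the boilerplate a pytree dataclass decorator automates.
jax.tree_util.register_pytree_node(
    CosineProblem,
    lambda problem: ((problem.scale,), None),  # flatten into (children, auxiliary data)
    lambda _, children: CosineProblem(*children),  # rebuild from the children
)


@jax.jit
def find_fixed_point(problem: CosineProblem, initial: jax.Array) -> jax.Array:
    # The problem arrives as an explicit (traced) argument, not via a closure.
    def not_converged(state):
        x, previous_x, iteration = state
        return (jnp.abs(x - previous_x) > 1e-6) & (iteration < 1000)

    def step(state):
        x, _, iteration = state
        return problem.iterate(x), x, iteration + 1

    # Match the parameter's dtype so the loop carry keeps a consistent type.
    x0 = jnp.asarray(initial, dtype=problem.scale.dtype)
    x, _, _ = jax.lax.while_loop(
        not_converged, step, (problem.iterate(x0), x0, jnp.int32(0))
    )
    return x


solution = find_fixed_point(CosineProblem(scale=jnp.asarray(1.0)), 0.0)
```

The loop uses jax.lax.while_loop, which supports forward-mode but not reverse-mode differentiation, so jax.grad cannot be applied to a naive solver like this one. Differentiating through the fixed point is what a dedicated library such as fixed_point is for.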

Minor components

Tjax also includes:

  • An object-oriented wrapper on top of optax. (See gradient.)
  • A pretty printer print_generic for aggregate and vector types, including dataclasses. (See display.)
  • Versions of custom_vjp and custom_jvp that can be applied to methods. (See shims.)
  • Tools for working with cotangents. (See cotangent_tools.)
  • A random number generator class Generator. (See generator.)
  • JAX tree registration for NetworkX graph types. (See graph.)
  • Leaky integration leaky_integrate and Ornstein-Uhlenbeck process iteration diffused_leaky_integrate. (See leaky_integral.)
  • An improved version of jax.tree_util.Partial. (See partial.)
  • A Matplotlib trajectory plotter PlottableTrajectory. (See plottable_trajectory.)
  • A testing function assert_tree_allclose that automatically produces testing code, and a related function tree_allclose. (See testing.)
  • Basic tools like divide_where. (See tools.)
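Leaky integration is commonly defined as the exact solution of the linear ODE dx/dt = drift - decay * x over one time step. The pure-Python sketch below assumes that definition; leaky_step is a hypothetical name, and tjax's leaky_integrate may use a different parametrization, so check leaky_integral before relying on it.

```python
import math


def leaky_step(value: float, time_step: float, drift: float, decay: float) -> float:
    """Exact solution of dx/dt = drift - decay * x after `time_step` (requires decay > 0)."""
    retained = math.exp(-decay * time_step)  # fraction of the old value that survives
    return value * retained + drift * (1.0 - retained) / decay


# With no drift, one half-life (ln 2 / decay) halves the value.
halved = leaky_step(1.0, math.log(2.0), drift=0.0, decay=1.0)
# drift / decay is a fixed point of the update, so the value stays put.
steady = leaky_step(3.0, 0.7, drift=6.0, decay=2.0)
# Because the update is exact, two half steps compose to one full step.
two_halves = leaky_step(leaky_step(1.0, 0.25, 2.0, 0.5), 0.25, 2.0, 0.5)
one_full = leaky_step(1.0, 0.5, 2.0, 0.5)
```

Using the exact exponential update rather than a forward-Euler step keeps the result independent of how a time interval is subdivided, which the last two lines check.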
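A guarded division like divide_where matters in JAX because masking a division with a single jnp.where hides 1/0 in the forward pass but not in the gradient. The sketch below shows the usual "double where" workaround; safe_divide and naive_divide are hypothetical helpers, and the real divide_where signature may differ (see tools).

```python
import jax
import jax.numpy as jnp


def naive_divide(dividend, divisor, where, otherwise):
    # The masked-out branch still computes dividend / 0, and its derivative
    # (0 * inf) turns into NaN during backpropagation.
    return jnp.where(where, dividend / divisor, otherwise)


def safe_divide(dividend, divisor, where, otherwise):
    """Hypothetical helper: dividend / divisor where `where` holds, else `otherwise`."""
    # Swap in a harmless divisor on masked-out entries so no 1/0 is ever computed.
    safe_divisor = jnp.where(where, divisor, 1.0)
    return jnp.where(where, dividend / safe_divisor, otherwise)


def reciprocal_naive(d):
    return naive_divide(1.0, d, d != 0.0, 0.0)


def reciprocal_safe(d):
    return safe_divide(1.0, d, d != 0.0, 0.0)


grad_naive_at_zero = jax.grad(reciprocal_naive)(0.0)
grad_safe_at_zero = jax.grad(reciprocal_safe)(0.0)
grad_safe_at_two = jax.grad(reciprocal_safe)(2.0)
```

Both functions return 0.0 at d = 0, but only the guarded version also has a finite gradient there (zero instead of NaN); away from zero it matches the derivative of 1/d, which is -0.25 at d = 2.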

Also, see the documentation.

Contribution guidelines

  • Conventions: PEP 8.
  • How to run tests: pytest .
  • How to clean the source:
    • isort tjax
    • pylint tjax
    • mypy tjax
    • flake8 tjax