javiche

A JAX wrapper around ceviche to make interoperability easier. In the future it might make sense to update ceviche itself to use JAX internally.


Keywords
nbdev, jupyter, notebook, python, ceviche, fdtd, inverse-design
License
Apache-2.0
Install
pip install javiche==0.0.6

Documentation

javiche

Small package to enable using ceviche with a JAX optimizer easily.

Install

This package is not yet published. Once it is, install it with:

pip install javiche

or

conda install javiche

How to use

Import the decorator

from javiche import jaxit

Decorate your function (it will be differentiated using ceviche's jacobian, which relies on HIPS autograd):

@jaxit()
def square(A):
  """squares number/array"""
  return A**2

Now you can use jax as usual:

import jax

grad_fn = jax.grad(square)
grad_fn(2.0)
Array(4., dtype=float32, weak_type=True)

In this toy example, that was already possible without the jaxit() decorator. However, jaxit()-decorated functions can contain autograd operators (but no jax operators). Without the decorator, jax.grad fails on a function that uses autograd's numpy:

import autograd.numpy as npa

# Not decorated: jax.grad cannot trace through autograd's numpy
def sin(A):
  """computes sin of number/array using autograd's numpy"""
  return npa.sin(A)

grad_sin = jax.grad(sin)
try:
  print(grad_sin(0.0))
except Exception as e:
  print(e)
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(0.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
  primal = 0.0
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[], weak_type=True), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

@jaxit()
def cos(A):
  """computes cos of number/array using autograd's numpy"""
  return npa.cos(A)

grad_cos = jax.grad(cos)
try:
  print(grad_cos(0.0))
except Exception as e:
  print(e)
-0.0
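
The decorator's internals are not shown here. As a rough sketch of the general technique for exposing a non-JAX function to jax.grad, the following assumes jax.custom_vjp combined with jax.pure_callback, and uses a hand-written derivative in place of ceviche's jacobian. The names host_sin, host_sin_vjp and wrapped_sin are hypothetical, and this is not necessarily how javiche is implemented.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    """Hypothetical stand-in for a non-JAX function; runs on plain NumPy."""
    return np.sin(np.asarray(x))


def host_sin_vjp(x, g):
    """Hand-written vector-Jacobian product of host_sin (assumption:
    a real wrapper would obtain this from autograd instead)."""
    return np.asarray(g) * np.cos(np.asarray(x))


@jax.custom_vjp
def wrapped_sin(x):
    # Run the NumPy code on the host; JAX only sees the result's shape/dtype.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_type, x)


def wrapped_sin_fwd(x):
    # Forward pass: return the output and keep x for the backward pass.
    return wrapped_sin(x), x


def wrapped_sin_bwd(x, g):
    # Backward pass: also evaluated on the host, outside of JAX tracing.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(host_sin_vjp, out_type, x, g),)


wrapped_sin.defvjp(wrapped_sin_fwd, wrapped_sin_bwd)

grad_value = jax.grad(wrapped_sin)(jnp.asarray(0.0))
print(grad_value)
```

Because the derivative is supplied explicitly, JAX never has to trace through the NumPy code, which is what avoids the TracerArrayConversionError shown above.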

Use case

This library is intended for use with ceviche while running a JAX optimization stack, as demonstrated in the inverse design example.
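
As a minimal sketch of what the JAX side of such an optimization could look like, the loop below runs plain gradient descent with jax.grad. The objective function here is a hypothetical quadratic stand-in; in practice it would presumably be a jaxit()-decorated function wrapping a ceviche simulation. The step size and iteration count are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


def objective(params):
    """Hypothetical stand-in for a figure of merit; a real use case would
    call a jaxit()-decorated ceviche simulation here instead."""
    target = jnp.array([1.0, -2.0])
    return jnp.sum((params - target) ** 2)


grad_fn = jax.grad(objective)

params = jnp.zeros(2)
step_size = 0.1  # arbitrary value chosen for this sketch

for _ in range(100):
    # Plain gradient descent: step against the gradient of the objective.
    params = params - step_size * grad_fn(params)

print(params)
```

Any JAX-based optimizer could take the place of the hand-written update step, since it only needs the gradient returned by jax.grad.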