torch2jax

Run PyTorch in JAX. 🤝


Install
pip install torch2jax==0.0.1