LPIPS-J
This is a minimal JAX/Flax port of lpips, as implemented in:
Only the essential features have been implemented. Our motivation is to support VQGAN training for DALL•E Mini.
It currently supports the vgg16 backend, leveraging the implementation in flaxmodels.
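For readers new to LPIPS, the sketch below shows the distance computation on top of a backbone such as vgg16. It is a minimal illustration under stated assumptions, not the lpips-j API. It assumes the per-layer feature maps have already been extracted in Flax's NHWC layout and that each layer's linear weights are given as a non-negative per-channel vector. The function names `normalize_channels` and `lpips_distance` are hypothetical. Each layer's features are unit-normalized along the channel axis. The squared difference is then weighted per channel, which plays the role of a bias-free 1x1 linear layer. The result is averaged over space, and the per-layer scores are summed.

```python
import jax
import jax.numpy as jnp


def normalize_channels(feats, eps=1e-10):
    # Unit-normalize every spatial feature vector along the channel axis (NHWC).
    # eps is an assumed small constant that guards against division by zero.
    norm = jnp.sqrt(jnp.sum(feats ** 2, axis=-1, keepdims=True))
    return feats / (norm + eps)


def lpips_distance(feats_x, feats_y, lin_weights):
    """Hypothetical LPIPS-style distance computed from precomputed feature maps.

    feats_x, feats_y: lists of arrays shaped (N, H_l, W_l, C_l), one per layer.
    lin_weights: list of non-negative arrays shaped (C_l,), one per layer
                 (assumed stand-ins for the learned linear layers).
    Returns an array of shape (N,), one distance per image pair.
    """
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, lin_weights):
        # Squared difference between channel-normalized features.
        diff = (normalize_channels(fx) - normalize_channels(fy)) ** 2
        # Per-channel weighted sum, equivalent to a 1x1 linear layer without bias.
        per_pixel = jnp.sum(diff * w, axis=-1)  # shape (N, H_l, W_l)
        # Average over spatial positions, then accumulate across layers.
        total = total + jnp.mean(per_pixel, axis=(1, 2))
    return total


if __name__ == "__main__":
    # Random tensors stand in for real vgg16 activations from two layers.
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    fx = [jax.random.normal(k1, (2, 8, 8, 64)), jax.random.normal(k2, (2, 4, 4, 128))]
    fy = [jax.random.normal(k3, (2, 8, 8, 64)), jax.random.normal(k4, (2, 4, 4, 128))]
    w = [jnp.ones(64), jnp.ones(128)]
    print(lpips_distance(fx, fy, w))  # one non-negative distance per image pair
```

Identical inputs give a distance of zero. Because the features are unit-normalized before comparison, each layer contributes a bounded, scale-invariant score. This property makes the metric usable as a perceptual loss term.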
Pre-trained weights for the network and the linear layers are downloaded from the
Installation
- Install JAX for CUDA or TPU following the instructions at https://github.com/google/jax#installation.
- Install this package from the repository:

      pip install --upgrade git+https://github.com/pcuenca/lpips-j.git