latentshift

A method to generate counterfactuals


License: Apache-2.0

Install: pip install latentshift==0.0.5


Latent Shift - A Simple Autoencoder Approach to Counterfactual Generation


The idea

Read the paper about Latent Shift: https://arxiv.org/abs/2102.09475

Watch a video: https://www.youtube.com/watch?v=1fxSDP8DheI

Read the paper about Counterfactual Alignment: https://arxiv.org/abs/2312.02186

The main diagram: latentshift.gif
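For reference, here is the core Latent Shift update as we understand it from the paper linked above. The input x is encoded to a latent z = E(x). The latent is shifted by λ times the gradient of the classifier output f, taken through the decoder D, and the shifted latent is then decoded. λ = 0 gives the plain reconstruction. To first order, λ < 0 lowers the prediction and λ > 0 raises it.

```latex
\mathbf{z} = E(\mathbf{x}), \qquad
\mathbf{x}_\lambda = D\!\left(\mathbf{z} + \lambda \, \frac{\partial f\big(D(\mathbf{z})\big)}{\partial \mathbf{z}}\right)
```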

Animations/GIFs

Example animations for the attributes Smiling, Arched Eyebrows, Mouth Slightly Open, and Young.

Generating a transition sequence

For a prediction of smiling:

gen_sequence.png
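A transition sequence comes from sweeping λ and decoding one frame per value. Below is a minimal pure-Python sketch of that loop. It uses a hypothetical toy linear encoder/decoder (`ENC`, `DEC`) and a hypothetical logistic classifier (`W`, `B`) in place of the package's VQGAN and face-attribute model. A central finite difference stands in for the autograd gradient. It illustrates the idea only and is not the library's implementation.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]

# Hypothetical toy autoencoder: 3-pixel "images", 2-d latent space (both maps are linear).
ENC: Matrix = [[1.0, 0.0, 0.0],
               [0.0, 1.0, 1.0]]          # E: R^3 -> R^2
DEC: Matrix = [[1.0, 0.0],
               [0.0, 0.5],
               [0.0, 0.5]]               # D: R^2 -> R^3

# Hypothetical toy classifier: logistic regression on the pixels.
W: Vector = [2.0, -1.0, 3.0]
B: float = -1.0


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def encode(x: Vector) -> Vector:
    return matvec(ENC, x)


def decode(z: Vector) -> Vector:
    return matvec(DEC, z)


def classify(x: Vector) -> float:
    logit = sum(w * xi for w, xi in zip(W, x)) + B
    return 1.0 / (1.0 + math.exp(-logit))


def latent_gradient(z: Vector, eps: float = 1e-5) -> Vector:
    """Central finite-difference estimate of d f(D(z)) / dz (stands in for autograd)."""
    grad = []
    for i in range(len(z)):
        z_plus, z_minus = list(z), list(z)
        z_plus[i] += eps
        z_minus[i] -= eps
        grad.append((classify(decode(z_plus)) - classify(decode(z_minus))) / (2 * eps))
    return grad


def latent_shift_sequence(x: Vector, lambdas: List[float]) -> List[Tuple[float, Vector, float]]:
    """Return (lambda, decoded frame, prediction) for each shift size."""
    z = encode(x)
    g = latent_gradient(z)  # gradient taken once, at the unshifted latent
    frames = []
    for lam in lambdas:
        x_lam = decode([zi + lam * gi for zi, gi in zip(z, g)])
        frames.append((lam, x_lam, classify(x_lam)))
    return frames


frames = latent_shift_sequence([1.0, 0.5, 0.5], [0.0, -1.0, -2.0, -4.0])
for lam, frame, pred in frames:
    print(f"lambda={lam:+.1f}  frame={[round(p, 3) for p in frame]}  pred={pred:.3f}")
```

With these toy weights, the prediction falls as λ becomes more negative, because each step moves the latent against the gradient of the classifier output. The λ = 0 frame is the autoencoder's reconstruction, not the raw input.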

Multiple different targets

Comparison to traditional methods

For a prediction of pointy_nose:

comparison.png
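Gradient-based saliency methods read a map directly off the input gradient. A counterfactual instead supports an image-space heatmap of what actually changed. The sketch below assumes the heatmap is the per-pixel absolute difference between the λ = 0 reconstruction and one counterfactual frame, rescaled to [0, 1]. The actual method may aggregate over several frames, so treat `difference_heatmap` as a hypothetical helper.

```python
from typing import List

Image = List[List[float]]  # grayscale image as rows of pixel values


def difference_heatmap(reconstruction: Image, counterfactual: Image) -> Image:
    """Per-pixel |reconstruction - counterfactual|, rescaled so the largest change is 1."""
    diffs = [
        [abs(r - c) for r, c in zip(r_row, c_row)]
        for r_row, c_row in zip(reconstruction, counterfactual)
    ]
    peak = max(max(row) for row in diffs)
    if peak == 0.0:
        # Identical images: nothing changed, so return an all-zero map.
        return [[0.0 for _ in row] for row in diffs]
    return [[d / peak for d in row] for row in diffs]


recon = [[0.2, 0.5], [0.9, 0.1]]
counterfactual = [[0.2, 0.1], [0.7, 0.1]]
print(difference_heatmap(recon, counterfactual))
```

Comparing against the reconstruction rather than the raw input keeps the autoencoder's own reconstruction error out of the heatmap.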

Getting Started

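$ pip install latentshift captum

import torch
import captum.attr
import latentshift

# Load the face-attribute classifier and the autoencoder (download=True fetches pretrained weights)
model = latentshift.classifiers.FaceAttribute(download=True)
ae = latentshift.autoencoders.VQGAN(weights="faceshq", download=True)

# Placeholder input: a random 1024x1024 RGB tensor with a batch dimension.
# Replace it with a real, preprocessed face image of the same shape.
image = torch.randn(1, 3, 1024, 1024)

# Define the Latent Shift attribution module from Captum
attr = captum.attr.LatentShift(model, ae)

# Compute the counterfactual explanation for class (target) 3
output = attr.attribute(image, target=3)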