latentshift

A method to generate counterfactuals


License: Apache-2.0
Install: pip install latentshift==0.0.4

Documentation

Latent Shift - A Simple Autoencoder Approach to Counterfactual Generation


The idea

Read the paper: https://arxiv.org/abs/2102.09475

Watch a video: https://www.youtube.com/watch?v=1fxSDP8DheI

The main diagram: latentshift.gif
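As described in the paper, the core update is x_λ = D(z + λ ∂f(D(z))/∂z) with z = E(x). Here E and D are the autoencoder's encoder and decoder, f is the classifier's output for the target, and λ controls how far the latent code is shifted (λ = 0 gives the plain reconstruction). Below is a minimal, dependency-free sketch of that update, not the package's implementation. The 2-D "image", the linear autoencoder (`encode`/`decode`) and the logistic `classifier` are hypothetical toys, and the gradient is estimated with central finite differences instead of autograd.

```python
import math

# Hypothetical toy components (not part of the latentshift package).
W_CLS = [2.0, -1.0]  # classifier weights
B_CLS = 0.5          # classifier bias


def classifier(x):
    """f(x): sigmoid of a linear score, standing in for e.g. P(smiling)."""
    score = sum(w * xi for w, xi in zip(W_CLS, x)) + B_CLS
    return 1.0 / (1.0 + math.exp(-score))


def encode(x):
    """E(x): toy encoder that halves each coordinate."""
    return [0.5 * xi for xi in x]


def decode(z):
    """D(z): toy decoder, the exact inverse of encode."""
    return [2.0 * zi for zi in z]


def latent_gradient(z, eps=1e-5):
    """Central-difference estimate of d f(D(z)) / dz."""
    grad = []
    for i in range(len(z)):
        z_plus = list(z)
        z_plus[i] += eps
        z_minus = list(z)
        z_minus[i] -= eps
        diff = classifier(decode(z_plus)) - classifier(decode(z_minus))
        grad.append(diff / (2 * eps))
    return grad


def latent_shift(x, lam):
    """Return x_lambda = D(z + lam * df(D(z))/dz) with z = E(x)."""
    z = encode(x)
    grad = latent_gradient(z)
    return decode([zi + lam * gi for zi, gi in zip(z, grad)])
```

For example, with `x = [0.3, 0.1]`, `classifier(latent_shift(x, -2.0))` is smaller than `classifier(x)` because the shift moves against the gradient, while `latent_shift(x, 0.0)` returns the reconstruction, which for this toy autoencoder equals `x`.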

Animations/GIFs

GIFs for the attributes: Smiling, Arched Eyebrows, Mouth Slightly Open, Young

Generating a transition sequence

For a classifier predicting smiling:

gen_sequence.png
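A transition sequence can be produced by computing the latent gradient once and sweeping λ. The sketch below assumes a simple search rule: step λ downward until the prediction has fallen to a chosen fraction of the reconstruction's prediction, then take evenly spaced λ values between 0 and that endpoint. The toy identity autoencoder, the logistic `classifier`, and the `pred_drop`, `step`, `max_steps` and `n_frames` parameters are hypothetical illustrations, not the latentshift or Captum API.

```python
import math


# Hypothetical toy stand-ins: identity autoencoder and a logistic classifier.
def classifier(x):
    return 1.0 / (1.0 + math.exp(-(1.5 * x[0] + 0.5 * x[1])))


def encode(x):
    return list(x)


def decode(z):
    return list(z)


def shift(x, lam, eps=1e-5):
    """x_lambda = D(z + lam * grad_z f(D(z))), gradient by central differences."""
    z = encode(x)
    grad = []
    for i in range(len(z)):
        up, down = list(z), list(z)
        up[i] += eps
        down[i] -= eps
        grad.append((classifier(decode(up)) - classifier(decode(down))) / (2 * eps))
    return decode([zi + lam * gi for zi, gi in zip(z, grad)])


def find_lambda(x, pred_drop=0.5, step=0.5, max_steps=1000):
    """Step lambda downward until f(x_lambda) <= (1 - pred_drop) * f(x_0).

    pred_drop, step and max_steps are illustrative choices, not package defaults.
    """
    target = (1.0 - pred_drop) * classifier(shift(x, 0.0))
    lam = 0.0
    for _ in range(max_steps):
        if classifier(shift(x, lam)) <= target:
            break
        lam -= step
    return lam


def transition_sequence(x, n_frames=5):
    """(lambda, prediction) pairs from the reconstruction to the counterfactual.

    n_frames must be at least 2 so that both endpoints are included.
    """
    lam_end = find_lambda(x)
    lams = [lam_end * k / (n_frames - 1) for k in range(n_frames)]
    return [(lam, classifier(shift(x, lam))) for lam in lams]
```

Each `(lam, prediction)` pair corresponds to one frame; decoding `shift(x, lam)` for each pair gives the images to assemble into a GIF. Reusing the gradient taken at z for every frame keeps the path a straight line in latent space.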

Multiple different targets

Comparison to traditional methods

For a classifier predicting pointy_nose:

comparison.png
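Gradient-based saliency methods read a heatmap directly off input gradients. With a counterfactual in hand, a heatmap can instead be taken from what changed between the reconstruction x_0 (λ = 0) and the counterfactual x_λ. A minimal sketch, assuming the heatmap is the absolute per-pixel difference summed over channels and then scaled to [0, 1] for display; images are nested lists indexed [channel][row][col], and the function names are hypothetical.

```python
def difference_heatmap(x0, x_lam):
    """Absolute per-pixel difference |x0 - x_lam|, summed over channels.

    x0, x_lam: nested lists indexed [channel][row][col] with identical shapes.
    Returns a 2-D list indexed [row][col].
    """
    channels, rows, cols = len(x0), len(x0[0]), len(x0[0][0])
    heat = [[0.0] * cols for _ in range(rows)]
    for c in range(channels):
        for r in range(rows):
            for k in range(cols):
                heat[r][k] += abs(x0[c][r][k] - x_lam[c][r][k])
    return heat


def normalize(heat):
    """Scale a heatmap to [0, 1] for display; an all-zero map stays all zero."""
    peak = max(max(row) for row in heat)
    return [[v / peak if peak > 0 else 0.0 for v in row] for row in heat]
```

Comparing against x_0 rather than the raw input keeps reconstruction error out of the map, since both images pass through the same decoder.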

Getting Started

```bash
pip install latentshift
```

```python
import torch
import captum.attr
import latentshift

# Load classifier and autoencoder
model = latentshift.classifiers.FaceAttribute()
ae = latentshift.autoencoders.Transformer(weights="celeba")

# Stand-in for a real image: a random 1x3x1024x1024 tensor
image = torch.randn(1, 3, 1024, 1024)

# Define the Latent Shift attribution module
attr = captum.attr.LatentShift(model, ae)

# Compute a counterfactual for target class index 3
output = attr.attribute(image, target=3)
```