The ADVI class is an object, but ADVI.objective_fun is a pure function that can be optimized with optax, jaxopt, or any other JAX-compatible optimizer.
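Because the objective is pure (its output depends only on its inputs, with no hidden state), it slots directly into the standard JAX gradient workflow. A minimal sketch of that workflow, using a toy quadratic objective as a stand-in for ADVI.objective_fun (the objective and names here are illustrative, not the real ELBO):

```python
import jax
import jax.numpy as jnp

# Stand-in for a pure objective such as ADVI.objective_fun:
# the output is fully determined by the inputs.
def objective(params):
    return jnp.sum((params - 3.0) ** 2)

value_and_grad = jax.jit(jax.value_and_grad(objective))

params = jnp.zeros(2)
learning_rate = 0.1
for _ in range(100):
    loss, grads = value_and_grad(params)
    # Plain SGD; optax or jaxopt optimizers plug in at this step.
    params = params - learning_rate * grads
```

The same structure appears in the coin toss example below, with the variational distribution playing the role of `params`.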
Variational distribution parameters can be initialized with ADVI.init using a distrax or TFP distribution as the initializer (or any JAX distribution that implements a .sample() method in a similar way).
Users can pass suitable bijectors of class distrax.Bijector to the variational distribution.
The transformation is applied directly to the posterior, so the prior and likelihood stay untouched during the entire process. This way, after training, the variational distribution is ready for sampling without any additional transformations. It also gives the freedom to construct the variational distribution in more complex ways, since it is separated from the other parts of the model (see the example below).
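The idea can be sketched in plain JAX: the variational sample lives in unconstrained space, a sigmoid maps it into the support of the posterior, and the prior and likelihood only ever see the transformed value (the parameter names below are illustrative):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Unconstrained mean-field parameters (illustrative names).
u_mean, u_scale = 0.0, 0.0

# Sample in unconstrained space, then push through the bijector.
z = u_mean + jnp.exp(u_scale) * jax.random.normal(key, shape=(5,))
theta = jax.nn.sigmoid(z)  # now in (0, 1), e.g. the support of a Beta posterior

# The prior and likelihood are evaluated at theta directly;
# they never need to know about the transformation.
```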
If we do not change the random key during training, the method is called deterministic ADVI.
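In other words, a fixed key fixes the Monte Carlo samples, so the objective becomes a deterministic function of the variational parameters. A quick illustration in plain JAX:

```python
import jax
import jax.numpy as jnp

def mc_estimate(key, mean):
    # Monte Carlo estimate of E[(x - mean)^2] with x ~ N(0, 1).
    samples = jax.random.normal(key, shape=(10,))
    return jnp.mean((samples - mean) ** 2)

key = jax.random.PRNGKey(0)
a = mc_estimate(key, 1.0)
b = mc_estimate(key, 1.0)  # same key -> identical samples -> identical value
```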
Users can implement their own likelihood_log_prob_fun, because the likelihood does not necessarily have to be a distribution.
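For instance, the Bernoulli log-likelihood of the coin tosses can be written by hand as an ordinary function of theta, with no distribution object involved (a hypothetical drop-in for likelihood_log_prob_fun):

```python
import jax.numpy as jnp

tosses = jnp.array([0, 1, 0, 0, 1, 0])

def likelihood_log_prob_fun(theta):
    # Hand-written Bernoulli log-likelihood:
    # sum_i  t_i * log(theta) + (1 - t_i) * log(1 - theta)
    return jnp.sum(tosses * jnp.log(theta) + (1 - tosses) * jnp.log(1 - theta))

log_prob = likelihood_log_prob_fun(0.5)  # 6 * log(0.5)
```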
A Coin Toss Example
```python
import jax
import jax.numpy as jnp
import distrax
from advi_jax import ADVI
from advi_jax.variational_distributions import MeanField
from advi_jax.init import initialize
import tensorflow_probability.substrates.jax as tfp
dist = tfp.distributions

# Data
tosses = jnp.array([0, 1, 0, 0, 1, 0])

# Prior and likelihood
prior_dist = dist.Beta(2.0, 3.0)
likelihood_log_prob_fun = lambda theta: dist.Bernoulli(probs=theta).log_prob(tosses).sum()

# ADVI model
model = ADVI(prior_dist, likelihood_log_prob_fun, tosses)

# Variational distribution and bijector
bijector = distrax.Sigmoid()
variational_dist = MeanField(u_mean=jnp.array(0.0), u_scale=jnp.array(0.0), bijector=bijector)

# Initialize the parameters of the variational distribution
key = jax.random.PRNGKey(0)
variational_dist = initialize(key, variational_dist, initializer=dist.Normal(0.0, 1.0))

# Define the value and grad function
value_and_grad_fun = jax.jit(jax.value_and_grad(model.objective_fun, argnums=1), static_argnums=2)

# Do gradient descent!
learning_rate = 0.01
for i in range(100):
    key = jax.random.PRNGKey(i)  # If this key is constant, this becomes deterministic ADVI
    loss_value, grads = value_and_grad_fun(key, variational_dist, 10)
    variational_dist = variational_dist - learning_rate * grads

# Get the posterior samples
key = jax.random.PRNGKey(2)
posterior_samples = variational_dist.sample(seed=key, sample_shape=(100,))
```