Maximizing the ELBO
In this post, we will cover a complete implementation of a Variational AutoEncoder (VAE) that optimizes the ELBO objective. This is a summary of the lecture "Probabilistic Deep Learning with TensorFlow 2" from Imperial College London.
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from IPython.display import HTML, Image
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['animation.embed_limit'] = 2**128
print("Tensorflow Version: ", tf.__version__)
print("Tensorflow Probability Version: ", tfp.__version__)
Approximating the True Posterior Distribution
$$\text{encoder}(x) = q(z \vert x) \simeq p(z \vert x)$$
$$\begin{aligned} \log p(x) & \ge \mathbb{E}_{z \sim q(z \vert x)}[\log p(z) - \log q(z \vert x) + \log p(x \vert z)] \quad \leftarrow \text{maximize this lower bound} \\ &= - \mathrm{KL}(q(z \vert x) \, \| \, p(z)) + \mathbb{E}_{z \sim q(z \vert x)}[\log p(x \vert z)] \quad \leftarrow \text{Evidence Lower Bound (ELBO)} \end{aligned}$$
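Both forms of the KL term above will appear in the loss functions below: TFP can compute the KL divergence between two Gaussians in closed form, or we can estimate it by sampling. As a minimal illustrative sketch (not part of the lecture code), here is a comparison of the two on a pair of toy diagonal Gaussians:
# Illustrative only: compare the analytic KL divergence with a Monte Carlo estimate
q = tfd.MultivariateNormalDiag(loc=[1., 0.], scale_diag=[0.5, 1.5])  # stand-in for q(z|x)
p = tfd.MultivariateNormalDiag(loc=tf.zeros(2))                      # stand-in for the prior p(z)
analytic_kl = tfd.kl_divergence(q, p)
z_samples = q.sample(10000)
mc_kl = tf.reduce_mean(q.log_prob(z_samples) - p.log_prob(z_samples))
print(analytic_kl.numpy(), mc_kl.numpy())  # the two values should be close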
latent_size = 2
event_shape = (28, 28, 1)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Flatten, Reshape

encoder = Sequential([
    Conv2D(8, (5, 5), strides=2, activation='tanh', input_shape=event_shape),
    Conv2D(8, (5, 5), strides=2, activation='tanh'),
    Flatten(),
    Dense(64, activation='tanh'),
    Dense(2 * latent_size),
    tfpl.DistributionLambda(lambda t: tfd.MultivariateNormalDiag(
        loc=t[..., :latent_size], scale_diag=tf.math.exp(t[..., latent_size:]))),
], name='encoder')
encoder(X_train[:16])  # assumes X_train already holds images of shape event_shape = (28, 28, 1)
The decoder is almost the reverse of the encoder.
decoder = Sequential([
    Dense(64, activation='tanh', input_shape=(latent_size,)),
    Dense(128, activation='tanh'),
    Reshape((4, 4, 8)),  # reshape into the 4x4x8 feature map required by the Conv2DTranspose layers
    Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation='tanh'),
    Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation='tanh'),
    Conv2D(1, (3, 3), padding='SAME'),
    Flatten(),
    tfpl.IndependentBernoulli(event_shape)
], name='decoder')
decoder(tf.random.normal([16, latent_size]))
ELBO objective function
One way to implement the ELBO is to compute the KL divergence analytically.
def loss_fn(X_true, approx_posterior, X_pred, prior_dist):
"""
X_true: batch of data examples
approx_posterior: the output of encoder
X_pred: output of decoder
prior_dist: Prior distribution
"""
return tf.reduce_mean(tfd.kl_divergence(approx_posterior, prior_dist) - X_pred.log_prob(X_true))
The other way is to estimate the KL divergence by Monte Carlo sampling instead of computing it analytically. A single sample per example gives an unbiased (if noisier) estimate, and the noise averages out over many training batches.
def loss_fn(X_true, approx_posterior, X_pred, prior_dist):
reconstruction_loss = -X_pred.log_prob(X_true)
approx_posterior_sample = approx_posterior.sample()
kl_approx = (approx_posterior.log_prob(approx_posterior_sample) - prior_dist.log_prob(approx_posterior_sample))
return tf.reduce_mean(kl_approx + reconstruction_loss)
Calculating the Gradient of the Loss Function
@tf.function
def get_loss_and_grads(x):
    with tf.GradientTape() as tape:
        approx_posterior = encoder(x)
        # Samples from MultivariateNormalDiag are reparameterized, so gradients can
        # flow back through the sampling step to the encoder's parameters
        approx_posterior_sample = approx_posterior.sample()
        X_pred = decoder(approx_posterior_sample)
        current_loss = loss_fn(x, approx_posterior, X_pred, prior)  # `prior` is defined below
    grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)
    return current_loss, grads
Review of terminology:
- $p(z)$ = prior
- $q(z|x)$ = encoding distribution
- $p(x|z)$ = decoding distribution
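In the implementation below, these three objects show up as TFP distributions. As a rough sketch of the mapping (assuming an `encoder`, `decoder`, `latent_size`, and data batch `x` like the ones defined in the rest of this post):
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))  # p(z), the prior
approx_posterior = encoder(x)                                  # q(z|x), the encoding distribution
decoding_dist = decoder(approx_posterior.sample())             # p(x|z), the decoding distribution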
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Flatten, Reshape
(X_train, _), (X_test, _) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.
example_X = X_test[:16]
batch_size = 64
X_train = tf.data.Dataset.from_tensor_slices(X_train).batch(batch_size)
latent_size = 2
event_shape = (28, 28)
encoder = Sequential([
Flatten(input_shape=event_shape),
Dense(256, activation='relu'),
Dense(128, activation='relu'),
Dense(64, activation='relu'),
Dense(32, activation='relu'),
Dense(2 * latent_size),
tfpl.DistributionLambda(
lambda t: tfd.MultivariateNormalDiag(
loc=t[..., :latent_size],
scale_diag=tf.math.exp(t[..., latent_size:])
)
)
])
encoder(example_X)
decoder = Sequential([
Dense(32, activation='relu'),
Dense(64, activation='relu'),
Dense(128, activation='relu'),
Dense(256, activation='relu'),
Dense(tfpl.IndependentBernoulli.params_size(event_shape)),
tfpl.IndependentBernoulli(event_shape)
])
decoder(tf.random.normal([16, latent_size]))
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))
The loss function we need to estimate is $$ -\mathrm{ELBO} = \mathrm{KL}[ \, q(z|x) \, \| \, p(z) \, ] - \mathbb{E}_{Z \sim q(z|x)}[\log p(x|Z)] $$ where $x = (x_1, x_2, \ldots, x_n)$ refers to all observations and $z = (z_1, z_2, \ldots, z_n)$ to the corresponding latent variables.
The assumed independence of the examples implies that we can write this as $$ \sum_j \Big( \mathrm{KL}[ \, q(z_j|x_j) \, \| \, p(z_j) \, ] - \mathbb{E}_{Z_j \sim q(z_j|x_j)}[\log p(x_j|Z_j)] \Big), $$ which is why the loss below sums over the batch rather than averaging.
def loss(x, encoding_dist, sampled_decoding_dist, prior):
    # Sum the per-example analytic KL terms and negative reconstruction log-likelihoods over the batch
    return tf.reduce_sum(
        tfd.kl_divergence(encoding_dist, prior) - sampled_decoding_dist.log_prob(x)
    )
@tf.function
def get_loss_and_grads(x):
with tf.GradientTape() as tape:
encoding_dist = encoder(x)
sampled_z = encoding_dist.sample()
sampled_decoding_dist = decoder(sampled_z)
current_loss = loss(x, encoding_dist, sampled_decoding_dist, prior)
grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)
return current_loss, grads
num_epochs = 10
optimizer = tf.keras.optimizers.Adam()
for i in range(num_epochs):
for train_batch in X_train:
current_loss, grads = get_loss_and_grads(train_batch)
optimizer.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))
print('-ELBO after epoch {}: {:.0f}'.format(i + 1, current_loss.numpy()))
def vae(inputs):
approx_posterior = encoder(inputs)
decoding_dist = decoder(approx_posterior.sample())
return decoding_dist.sample()
example_reconstruction = vae(example_X).numpy().squeeze()
f, axs = plt.subplots(2, 6, figsize=(16, 5))
for j in range(6):
axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')
axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')
axs[0, j].axis('off')
axs[1, j].axis('off')
Sampling from the Bernoulli decoding distribution produces binarized, noisy reconstructions of the grayscale images, so using the mean of the decoding distribution gives visually better results.
def vae_mean(inputs):
approx_posterior = encoder(inputs)
decoding_dist = decoder(approx_posterior.sample())
return decoding_dist.mean()
example_reconstruction = vae_mean(example_X).numpy().squeeze()
f, axs = plt.subplots(2, 6, figsize=(16, 5))
for j in range(6):
axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')
axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')
axs[0, j].axis('off')
axs[1, j].axis('off')
z = prior.sample(6)
generated_x = decoder(z).sample()
f, axs = plt.subplots(1, 6, figsize=(16, 5))
for j in range(6):
axs[j].imshow(generated_x[j, :, :].numpy().squeeze(), cmap='binary')
axs[j].axis('off')
z = prior.sample(6)
generated_x = decoder(z).mean()
f, axs = plt.subplots(1, 6, figsize=(16, 5))
for j in range(6):
axs[j].imshow(generated_x[j, :, :].numpy().squeeze(), cmap='binary')
axs[j].axis('off')
What if we use Monte Carlo sampling for the KL divergence instead of the analytic expression?
encoder = Sequential([
Flatten(input_shape=event_shape),
Dense(256, activation='relu'),
Dense(128, activation='relu'),
Dense(64, activation='relu'),
Dense(32, activation='relu'),
Dense(2 * latent_size),
tfpl.DistributionLambda(
lambda t: tfd.MultivariateNormalDiag(
loc=t[..., :latent_size],
scale_diag=tf.math.exp(t[..., latent_size:])
)
)
])
decoder = Sequential([
Dense(32, activation='relu'),
Dense(64, activation='relu'),
Dense(128, activation='relu'),
Dense(256, activation='relu'),
Dense(tfpl.IndependentBernoulli.params_size(event_shape)),
tfpl.IndependentBernoulli(event_shape)
])
# Define the prior, p(z) - a standard bivariate Gaussian
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))
def loss(x, encoding_dist, sampled_decoding_dist, prior, sampled_z):
reconstruction_loss = -sampled_decoding_dist.log_prob(x)
kl_approx = (encoding_dist.log_prob(sampled_z) - prior.log_prob(sampled_z))
return tf.reduce_sum(kl_approx + reconstruction_loss)
@tf.function
def get_loss_and_grads(x):
with tf.GradientTape() as tape:
encoding_dist = encoder(x)
sampled_z = encoding_dist.sample()
sampled_decoding_dist = decoder(sampled_z)
current_loss = loss(x, encoding_dist, sampled_decoding_dist, prior, sampled_z)
grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)
return current_loss, grads
num_epochs = 10
optimizer = tf.keras.optimizers.Adam()
for i in range(num_epochs):
for train_batch in X_train:
current_loss, grads = get_loss_and_grads(train_batch)
optimizer.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))
print('-ELBO after epoch {}: {:.0f}'.format(i + 1, current_loss.numpy()))
def vae_mean(inputs):
approx_posterior = encoder(inputs)
decoding_dist = decoder(approx_posterior.sample())
return decoding_dist.mean()
example_reconstruction = vae_mean(example_X).numpy().squeeze()
f, axs = plt.subplots(2, 6, figsize=(16, 5))
for j in range(6):
axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')
axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')
axs[0, j].axis('off')
axs[1, j].axis('off')
z = prior.sample(6)
generated_x = decoder(z).mean()
# Display generated_x
f, axs = plt.subplots(1, 6, figsize=(16, 5))
for j in range(6):
axs[j].imshow(generated_x[j, :, :].numpy().squeeze(), cmap='binary')
axs[j].axis('off')