KL Divergence Layers
In this post, we will cover an easy way to handle the KL divergence using TensorFlow Probability layer objects. This is a summary of the lecture "Probabilistic Deep Learning with TensorFlow 2" from Imperial College London.
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors
plt.rcParams['figure.figsize'] = (10, 6)
print("Tensorflow Version: ", tf.__version__)
print("Tensorflow Probability Version: ", tfp.__version__)
Samples
The snippets in this section sketch the API; Sequential, Dense, and Model are the usual tf.keras classes, imported explicitly later in the post.
latent_size=4
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))
encoder = Sequential([
Dense(64, activation='relu', input_shape=(12,)),
Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),
tfpl.MultivariateNormalTriL(latent_size),
tfpl.KLDivergenceAddLoss(prior) # automatically adds the KL term to the model's losses, so it is optimized together with the main loss
])
decoder = Sequential([
Dense(64, activation='relu', input_shape=(latent_size,)),
Dense(tfpl.IndependentNormal.params_size(12)),
tfpl.IndependentNormal(12)
])
vae = Model(inputs=encoder.input, outputs=decoder(encoder.output))
vae.compile(loss=lambda x, pred: -pred.log_prob(x))
vae.fit(train_data, epochs=20)
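Taken together, the reconstruction loss passed to compile and the KL term added by KLDivergenceAddLoss form the negative ELBO that the VAE minimizes:

$$ -\text{ELBO}(x) = -\mathbb{E}_{z \sim q(z|x)}\big[\log p(x|z)\big] + \mathrm{KL}\big[\,q(z|x)\,\|\,p(z)\,\big] $$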
You can also make the layer compute the exact (analytic) KL divergence with the use_exact_kl keyword, and scale the KL term with the weight argument. The short sketch after the next snippet compares the exact value with a Monte Carlo estimate.
encoder = Sequential([
Dense(64, activation='relu', input_shape=(12,)),
Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),
tfpl.MultivariateNormalTriL(latent_size),
tfpl.KLDivergenceAddLoss(prior, use_exact_kl=False, weight=10) # Use MC sampling for KL divergence, then weight it by 10
])
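For intuition about what use_exact_kl toggles, here is a minimal sketch (assuming an arbitrary diagonal Gaussian as the approximate posterior) comparing the closed-form KL divergence with a Monte Carlo estimate:

q = tfd.MultivariateNormalDiag(loc=tf.random.normal([latent_size]),
                               scale_diag=tf.random.uniform([latent_size], 0.5, 1.5))
exact_kl = tfd.kl_divergence(q, prior)                     # closed-form KL[q || prior]
z = q.sample(1000)                                         # Monte Carlo estimate from 1000 samples
mc_kl = tf.reduce_mean(q.log_prob(z) - prior.log_prob(z))

With enough samples the two values should be close; with use_exact_kl=False the layer performs this kind of sampling-based estimate inside the model.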
encoder = Sequential([
Dense(64, activation='relu', input_shape=(12,)),
Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),
tfpl.MultivariateNormalTriL(latent_size,
convert_to_tensor_fn=tfp.distributions.Distribution.sample),
tfpl.KLDivergenceAddLoss(prior) # the KL term is still added to the model's losses automatically
])
In this case, the output of the encoder will be a sample from the multivariate normal distribution, and by default that same tensor is what the KL approximation is evaluated on. If you set convert_to_tensor_fn to mean or mode instead, then the mean or mode will be the tensor used in the approximation.
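If you want downstream layers to receive the mean instead of a sample, only convert_to_tensor_fn changes; a minimal variant of the snippet above could look like this. The KL approximation would then be evaluated at the mean, which is why the next snippet passes an explicit test_points_fn:

encoder = Sequential([
    Dense(64, activation='relu', input_shape=(12,)),
    Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),
    tfpl.MultivariateNormalTriL(latent_size,
                                convert_to_tensor_fn=tfp.distributions.Distribution.mean),  # downstream layers see the mean
    tfpl.KLDivergenceAddLoss(prior)
])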
encoder = Sequential([
Dense(64, activation='relu', input_shape=(12,)),
Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),
tfpl.MultivariateNormalTriL(latent_size),
tfpl.KLDivergenceAddLoss(prior, use_exact_kl=False, weight=10,
test_points_fn=lambda q: q.sample(10), # 10 samples for test points
test_points_reduce_axis=0) # average the per-sample KL estimates over axis 0 (the sample axis)
])
In that case, passing a test points function is what gives a proper Monte Carlo estimate. An alternative way to add the KL divergence is to use KLDivergenceRegularizer as the activity regularizer of the distribution layer.
encoder = Sequential([
Dense(64, activation='relu', input_shape=(12,)),
Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),
tfpl.MultivariateNormalTriL(latent_size,
activity_regularizer=tfpl.KLDivergenceRegularizer(
prior, weight=10, use_exact_kl=False,
test_points_fn=lambda q: q.sample(10),
test_points_reduce_axis=0))
])
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Flatten, Reshape
(X_train, _), (X_test, _) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train.astype('float32') / 256. + 0.5 / 256  # rescale pixels into the open interval (0, 1) for the Beta likelihood
X_test = X_test.astype('float32') / 256. + 0.5 / 256
example_X = X_test[:16]
batch_size = 32
X_train = tf.data.Dataset.from_tensor_slices((X_train, X_train)).batch(batch_size)
X_test = tf.data.Dataset.from_tensor_slices((X_test, X_test)).batch(batch_size)
latent_size = 4
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))
event_shape = (28, 28)
encoder = Sequential([
Flatten(input_shape=event_shape),
Dense(128, activation='relu'),
Dense(64, activation='relu'),
Dense(32, activation='relu'),
Dense(16, activation='relu'),
Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),
tfpl.MultivariateNormalTriL(latent_size),
tfpl.KLDivergenceAddLoss(prior) # estimate KL[ q(z|x) || p(z)]
])
# Samples z_j from q(z | x_j)
# then computes log q(z_j | x_j) - log p(z_j)
# inspect encoder.losses before the network has received any inputs
encoder.losses
# pass a batch through the encoder, then inspect the losses again
encoder(example_X)
encoder.losses
encoder = Sequential([
Flatten(input_shape=event_shape),
Dense(128, activation='relu'),
Dense(64, activation='relu'),
Dense(32, activation='relu'),
Dense(16, activation='relu'),
Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),
tfpl.MultivariateNormalTriL(latent_size),
tfpl.KLDivergenceAddLoss(prior,
use_exact_kl=False,
weight=1.5,
test_points_fn=lambda q: q.sample(10),
test_points_reduce_axis=0) # estimate KL[ q(z|x) || p(z)]
])
# test_points_fn returns samples of shape (n_samples, batch_size, dim_z)
# z_{ij} is the ith sample for x_j (at position (i, j, :) in the tensor of samples)
# each sample is mapped to log q(z_{ij}|x_j) - log p(z_{ij})
# => the tensor of KL divergences has shape (n_samples, batch_size)
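As a sanity check, here is a short sketch (reusing example_X from above) of what test_points_fn and test_points_reduce_axis compute under the hood:

q = encoder(example_X)                           # batch of posterior distributions q(z | x_j)
z = q.sample(10)                                 # shape (10, batch_size, latent_size)
log_ratio = q.log_prob(z) - prior.log_prob(z)    # shape (10, batch_size)
kl_estimate = tf.reduce_mean(log_ratio, axis=0)  # average over the sample axis (axis 0)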
divergence_regularizer = tfpl.KLDivergenceRegularizer(prior,
use_exact_kl=False,
test_points_fn=lambda q: q.sample(5),
test_points_reduce_axis=0)
encoder = Sequential([
Flatten(input_shape=event_shape),
Dense(128, activation='relu'),
Dense(64, activation='relu'),
Dense(32, activation='relu'),
Dense(16, activation='relu'),
Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),
tfpl.MultivariateNormalTriL(latent_size,
activity_regularizer=divergence_regularizer),
])
decoder = Sequential([
Dense(16, activation='relu', input_shape=(latent_size,)),
Dense(32, activation='relu'),
Dense(64, activation='relu'),
Dense(128, activation='relu'),
Dense(2*event_shape[0]*event_shape[1], activation='exponential'),
Reshape((event_shape[0], event_shape[1], 2)),
tfpl.DistributionLambda(
lambda t: tfd.Independent(
tfd.Beta(concentration1=t[..., 0],
concentration0=t[..., 1])
)
)
])
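As a quick sanity check (a sketch, assuming the decoder defined above), the decoder maps a batch of latent vectors to an Independent Beta distribution over 28x28 images:

z = prior.sample(3)          # shape (3, latent_size)
p_x_given_z = decoder(z)
print(p_x_given_z.batch_shape, p_x_given_z.event_shape)  # expect (3,) and (28, 28)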
Note: to run this example, you may need NumPy 1.19.x instead of 1.20.x. See the reference.
vae = Model(inputs=encoder.inputs, outputs=decoder(encoder.outputs))
# -E_{z ~ q(z | x)}[log p(x | z)]
def log_loss(X_true, p_x_given_z):
    return -tf.reduce_sum(p_x_given_z.log_prob(X_true))
vae.compile(loss=log_loss)
vae.fit(X_train, validation_data=X_test, epochs=10)
example_reconstruction = vae(example_X).mean().numpy().squeeze()
f, axs = plt.subplots(2, 6, figsize=(16, 5))
for j in range(6):
    axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')
    axs[1, j].imshow(example_reconstruction[j, :, :].squeeze(), cmap='binary')
    axs[0, j].axis('off')
    axs[1, j].axis('off')
example_reconstruction = vae(example_X).sample().numpy().squeeze()
# Plot the example reconstructions
f, axs = plt.subplots(2, 6, figsize=(16, 5))
for j in range(6):
    axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')
    axs[1, j].imshow(example_reconstruction[j, :, :].squeeze(), cmap='binary')
    axs[0, j].axis('off')
    axs[1, j].axis('off')
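Since the latent prior is a standard multivariate normal, we can also decode prior samples to generate new images; a minimal sketch:

z = prior.sample(6)
generated = decoder(z).mean().numpy()
f, axs = plt.subplots(1, 6, figsize=(16, 3))
for j in range(6):
    axs[j].imshow(generated[j], cmap='binary')
    axs[j].axis('off')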