Packages

import tensorflow as tf
import tensorflow_probability as tfp

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors

plt.rcParams['figure.figsize'] = (10, 6)
print("Tensorflow Version: ", tf.__version__)
print("Tensorflow Probability Version: ", tfp.__version__)
Tensorflow Version:  2.5.0
Tensorflow Probability Version:  0.13.0

[CelebA overview image]

The Large-scale CelebFaces Attributes (CelebA) Dataset

For this assignment you will use a subset of the CelebFaces Attributes (CelebA) dataset. The full dataset contains over 200K colour images of the faces of celebrities, together with tagged attributes such as 'Smiling', 'Wearing glasses', or 'Wearing lipstick'. It also contains information about bounding boxes and facial part localisation. CelebA is a popular dataset that is commonly used for face attribute recognition, face detection, landmark (or facial part) localisation, and face editing & synthesis.

  • Z. Liu, P. Luo, X. Wang, and X. Tang. "Deep Learning Face Attributes in the Wild", Proceedings of International Conference on Computer Vision (ICCV), 2015.

You can read about the dataset in more detail here.

Load the dataset

The following functions will be useful for loading and preprocessing the dataset. The subset you will use for this assignment consists of 10,000 training images, 1000 validation images and 1000 test images. These examples have been chosen to respect the original training/validation/test split of the dataset.

Note: the original dataset is too large to host on GitHub. If you want the full dataset, please check the official pages.
def load_dataset(split):
    # Load the pre-processed images for the given split ('train', 'val' or 'test') and
    # map each image to an (input, target) pair, since the VAE reconstructs its input.
    images = np.load('./dataset/vae-celeba/{}.npy'.format(split))
    ds = tf.data.Dataset.from_tensor_slices(images)
    return ds.map(lambda x: (x, x))
train_ds = load_dataset('train')
val_ds = load_dataset('val')
test_ds = load_dataset('test')
n_examples_shown = 6
f, axs = plt.subplots(1, n_examples_shown, figsize=(16, 3))

for j, image_pair in enumerate(train_ds.take(n_examples_shown)):
    axs[j].imshow(image_pair[0])  # each element is an (input, target) pair; show the input image
    axs[j].axis('off')
batch_size = 32
train_ds = train_ds.batch(batch_size)
val_ds = val_ds.batch(batch_size)
test_ds = test_ds.batch(batch_size)

Mixture of Gaussians distribution

We will define a prior distribution that is a mixture of Gaussians. This is a more flexible distribution, composed of $K$ separate Gaussians that are combined with a weighting assigned to each.

Recall that the probability density function for a multivariate Gaussian distribution with mean $\mu\in\mathbb{R}^D$ and covariance matrix $\Sigma\in\mathbb{R}^{D\times D}$ is given by

$$ \mathcal{N}(\mathbf{z}; \mathbf{\mu}, \Sigma) = \frac{1}{(2\pi)^{D/2}|\Sigma|^{1/2}} \exp\left(-\frac{1}{2}(\mathbf{z}-\mathbf{\mu})^T\Sigma^{-1}(\mathbf{z}-\mathbf{\mu})\right). $$

A mixture of Gaussians with $K$ components consists of $K$ Gaussians with means $\mathbf{\mu}_k$ and covariance matrices $\Sigma_k$, for $k=1,\ldots,K$. It also requires mixing coefficients $\pi_k$, $k=1,\ldots,K$, with $\sum_{k} \pi_k = 1$. These coefficients define a categorical distribution over the $K$ Gaussian components. To sample an event, we first sample a component from the categorical distribution, and then sample from the corresponding Gaussian component.

The probability density function of the mixture of Gaussians is simply the weighted sum of probability density functions for each Gaussian component:

$$ p(\mathbf{z}) = \sum_{k=1}^K \pi_k \mathcal{N}(\mathbf{z}; \mathbf{\mu}_k, \Sigma_k) $$
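
To make the two-stage sampling procedure and the weighted-sum density above concrete, here is a minimal illustrative sketch with a small, hypothetical 2-D mixture of $K=3$ components (the mixing coefficients and component parameters below are made up purely for illustration):

# Hypothetical mixing coefficients and component parameters (2-D, K=3)
pi = [0.2, 0.5, 0.3]
mus = tf.constant([[-2., 0.], [0., 1.], [3., -1.]])
scales = tf.ones([3, 2])

components = tfd.MultivariateNormalDiag(loc=mus, scale_diag=scales)
mixture = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=pi),
    components_distribution=components)

# Two-stage sampling: first pick a component k ~ Categorical(pi),
# then sample z from the corresponding Gaussian component
k = tfd.Categorical(probs=pi).sample()
z = tfd.MultivariateNormalDiag(loc=mus[k], scale_diag=scales[k]).sample()

# The mixture density is the weighted sum of the component densities
manual_prob = tf.reduce_sum(tf.constant(pi) * tf.exp(components.log_prob(z)))
print(manual_prob.numpy(), tf.exp(mixture.log_prob(z)).numpy())  # the two values should agree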

Define the prior distribution

We will define the mixture of Gaussians distribution for the prior, for a given number of components and latent space dimension. Each Gaussian component will have a diagonal covariance matrix. This distribution will have fixed mixing coefficients, but trainable means and standard deviations.

def get_prior(num_modes, latent_dim):
    # Mixture of Gaussians prior with fixed, uniform mixing coefficients and
    # trainable component means and (diagonal) standard deviations
    prior = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=[1 / num_modes] * num_modes),
        components_distribution=tfd.MultivariateNormalDiag(
            loc=tf.Variable(tf.random.normal(shape=[num_modes, latent_dim])),
            # TransformedVariable keeps the scales positive via the softplus bijector
            scale_diag=tfp.util.TransformedVariable(
                tf.ones(shape=[num_modes, latent_dim]), bijector=tfb.Softplus())
        )
    )
    return prior
prior = get_prior(num_modes=2, latent_dim=50)
prior
<tfp.distributions.MixtureSameFamily 'MixtureSameFamily' batch_shape=[] event_shape=[50] dtype=float32>
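
As a quick sanity check (purely illustrative), we can confirm that samples from the prior live in the 50-dimensional latent space and that only the component means and scales are trainable:

z = prior.sample(4)
print(z.shape)                         # expected: (4, 50)
print(len(prior.trainable_variables))  # expected: 2 (component means and pre-softplus scales)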

Define the encoder network

We will now define the encoder network as part of the VAE. First, we will define the KLDivergenceRegularizer to be used in the encoder network, which adds the KL divergence part of the loss.

def get_kl_regularizer(prior_distribution):
    # Adds a Monte Carlo estimate of KL(q(z|x) || p(z)) to the loss: the KL is
    # approximated with 3 sample points from q, averaged over the sample and batch axes
    divergence_regularizer = tfpl.KLDivergenceRegularizer(
        prior_distribution,
        use_exact_kl=False,
        weight=1.0,
        test_points_fn=lambda q: q.sample(3),
        test_points_reduce_axis=(0, 1)
    )
    return divergence_regularizer
kl_regularizer = get_kl_regularizer(prior)
kl_regularizer
<tensorflow_probability.python.layers.distribution_layer.KLDivergenceRegularizer at 0x7f6317ae0dd0>
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, Flatten, Dense, UpSampling2D, Reshape
def get_encoder(latent_dim, kl_regularizer):
    encoder = Sequential([
        Conv2D(32, (4, 4), activation='relu', strides=2, padding='SAME', input_shape=(64, 64, 3)),
        BatchNormalization(),
        Conv2D(64, (4, 4), activation='relu', strides=2, padding='SAME'),
        BatchNormalization(),
        Conv2D(128, (4, 4), activation='relu', strides=2, padding='SAME'),
        BatchNormalization(),
        Conv2D(256, (4, 4), activation='relu', strides=2, padding='SAME'),
        BatchNormalization(),
        Flatten(),
        # Parameterise a full-covariance Gaussian posterior q(z|x); the KL term is added
        # to the loss through this layer's activity regularizer
        Dense(tfpl.MultivariateNormalTriL.params_size(latent_dim)),
        tfpl.MultivariateNormalTriL(latent_dim, activity_regularizer=kl_regularizer)
    ])
    return encoder
encoder = get_encoder(latent_dim=50, kl_regularizer=kl_regularizer)
encoder.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_4 (Conv2D)            (None, 32, 32, 32)        1568      
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 32)        128       
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 16, 16, 64)        32832     
_________________________________________________________________
batch_normalization_5 (Batch (None, 16, 16, 64)        256       
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 8, 8, 128)         131200    
_________________________________________________________________
batch_normalization_6 (Batch (None, 8, 8, 128)         512       
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 4, 4, 256)         524544    
_________________________________________________________________
batch_normalization_7 (Batch (None, 4, 4, 256)         1024      
_________________________________________________________________
flatten_1 (Flatten)          (None, 4096)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1325)              5428525   
_________________________________________________________________
multivariate_normal_tri_l_1  multiple                  200       
=================================================================
Total params: 6,120,789
Trainable params: 6,119,829
Non-trainable params: 960
_________________________________________________________________
tf.keras.utils.plot_model(encoder)
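
Optionally, we can check that the KL regularizer is wired up: calling the encoder on a batch triggers the activity regularizer, so a Monte Carlo estimate of KL(q(z|x) || p(z)) appears in encoder.losses (the random batch below is purely illustrative, not real data):

dummy_batch = tf.random.uniform((2, 64, 64, 3))
_ = encoder(dummy_batch)
print(encoder.losses)  # a single scalar tensor: the estimated KL term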

Define the decoder network

We'll now define the decoder network for the VAE, which returns an IndependentBernoulli distribution with event_shape=(64, 64, 3).

def get_decoder(latent_dim):
    decoder = Sequential([
        Dense(4096, activation='relu', input_shape=(latent_dim, )),
        Reshape((4, 4, 256)),
        UpSampling2D(size=(2, 2)),
        Conv2D(128, (3, 3), activation='relu', padding='SAME'),
        UpSampling2D(size=(2, 2)),
        Conv2D(64, (3, 3), activation='relu', padding='SAME'),
        UpSampling2D(size=(2, 2)),
        Conv2D(32, (3, 3), activation='relu', padding='SAME'),
        UpSampling2D(size=(2, 2)),
        Conv2D(128, (3, 3), activation='relu', padding='SAME'),
        # No activation here: these outputs are the logits of the pixelwise Bernoulli distribution
        Conv2D(3, (3, 3), padding='SAME'),
        Flatten(),
        tfpl.IndependentBernoulli(event_shape=(64, 64, 3))
    ])
    return decoder
decoder = get_decoder(latent_dim=50)
decoder.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_2 (Dense)              (None, 4096)              208896    
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 256)         0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 8, 8, 128)         295040    
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 16, 16, 64)        73792     
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 32, 32, 32)        18464     
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 64, 64, 128)       36992     
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 64, 64, 3)         3459      
_________________________________________________________________
flatten_2 (Flatten)          (None, 12288)             0         
_________________________________________________________________
independent_bernoulli (Indep multiple                  0         
=================================================================
Total params: 636,643
Trainable params: 636,643
Non-trainable params: 0
_________________________________________________________________
tf.keras.utils.plot_model(decoder)

The following cell connects the encoder and decoder to form the end-to-end VAE architecture.

vae = Model(inputs=encoder.inputs, outputs=decoder(encoder.outputs))

Define the average reconstruction loss

You should now define the reconstruction loss that forms the remaining part of the negative ELBO objective. This function should take a batch of images of shape (batch_size, 64, 64, 3) as its first argument, and the output distribution of the decoder (after passing the batch of images through vae) as its second argument.

The loss should be defined so that it returns $$ -\frac{1}{n}\sum_{i=1}^n \log p(x_i|z_i) \tag{1}$$ where $n$ is the batch size and $z_i$ is sampled from $q(z|x_i)$, the encoding distribution (the approximate posterior). The value of this expression is always a scalar.

Expression (1) is a single-sample Monte Carlo estimate of the batch's average expected reconstruction loss:

$$ -\frac{1}{n}\sum_{i=1}^n \mathrm{E}_{Z\sim q(z|x_i)}\big[\log p(x_i|Z)\big] $$
def reconstruction_loss(batch_of_images, decoding_dist):
    """
    The function takes batch_of_images (Tensor containing a batch of input images to
    the encoder) and decoding_dist (output distribution of decoder after passing the 
    image batch through the encoder and decoder) as arguments.
    The function should return the scalar average expected reconstruction loss.
    """
    # log_prob sums over the event dimensions (64, 64, 3); averaging over the batch axis gives a scalar
    return -tf.reduce_mean(decoding_dist.log_prob(batch_of_images), axis=0)
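
As a quick check (a minimal sketch; the random batch below is not real data), the decoder applied to prior samples returns an IndependentBernoulli distribution with batch shape [4], so the loss evaluates to a scalar:

z = prior.sample(4)
decoding_dist = decoder(z)
dummy_batch = tf.random.uniform((4, 64, 64, 3))
print(reconstruction_loss(dummy_batch, decoding_dist))  # expected: a scalar tensor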

Compile and fit the model

It's now time to compile and train the model. The KL divergence term is added to the loss automatically via the encoder's activity regularizer, so the loss reported during training is an estimate of the full negative ELBO. Note that using a hardware accelerator (GPU) for training is recommended.

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
vae.compile(optimizer=optimizer, loss=reconstruction_loss)
vae.fit(train_ds, validation_data=val_ds, epochs=30)
Epoch 1/30
313/313 [==============================] - 6s 19ms/step - loss: 6898.9551 - val_loss: 6774.6699
Epoch 2/30
313/313 [==============================] - 5s 16ms/step - loss: 6673.9536 - val_loss: 6601.0190
Epoch 3/30
313/313 [==============================] - 5s 16ms/step - loss: 6565.9702 - val_loss: 6534.0977
Epoch 4/30
313/313 [==============================] - 5s 16ms/step - loss: 6501.4233 - val_loss: 6488.2212
Epoch 5/30
313/313 [==============================] - 5s 16ms/step - loss: 6455.9448 - val_loss: 6454.4487
Epoch 6/30
313/313 [==============================] - 5s 16ms/step - loss: 6419.7598 - val_loss: 6418.2305
Epoch 7/30
313/313 [==============================] - 5s 16ms/step - loss: 6389.2236 - val_loss: 6404.3799
Epoch 8/30
313/313 [==============================] - 5s 16ms/step - loss: 6363.1782 - val_loss: 6377.2412
Epoch 9/30
313/313 [==============================] - 5s 16ms/step - loss: 6340.8530 - val_loss: 6350.7163
Epoch 10/30
313/313 [==============================] - 5s 16ms/step - loss: 6321.7163 - val_loss: 6339.8857
Epoch 11/30
313/313 [==============================] - 5s 16ms/step - loss: 6303.2959 - val_loss: 6321.4873
Epoch 12/30
313/313 [==============================] - 5s 16ms/step - loss: 6288.5278 - val_loss: 6311.4629
Epoch 13/30
313/313 [==============================] - 5s 16ms/step - loss: 6276.3354 - val_loss: 6302.6299
Epoch 14/30
313/313 [==============================] - 5s 16ms/step - loss: 6267.0430 - val_loss: 6302.4316
Epoch 15/30
313/313 [==============================] - 5s 16ms/step - loss: 6258.6162 - val_loss: 6298.0195
Epoch 16/30
313/313 [==============================] - 5s 16ms/step - loss: 6252.2017 - val_loss: 6284.3271
Epoch 17/30
313/313 [==============================] - 5s 16ms/step - loss: 6247.2666 - val_loss: 6283.5566
Epoch 18/30
313/313 [==============================] - 5s 16ms/step - loss: 6244.1333 - val_loss: 6282.6460
Epoch 19/30
313/313 [==============================] - 5s 16ms/step - loss: 6246.8740 - val_loss: 6326.9819
Epoch 20/30
313/313 [==============================] - 5s 16ms/step - loss: 6238.1558 - val_loss: 6334.7046
Epoch 21/30
313/313 [==============================] - 5s 16ms/step - loss: 6230.5674 - val_loss: 6318.6094
Epoch 22/30
313/313 [==============================] - 5s 16ms/step - loss: 6226.9346 - val_loss: 6306.0063
Epoch 23/30
313/313 [==============================] - 5s 16ms/step - loss: 6222.0898 - val_loss: 6292.9888
Epoch 24/30
313/313 [==============================] - 5s 16ms/step - loss: 6217.1611 - val_loss: 6305.8818
Epoch 25/30
313/313 [==============================] - 5s 16ms/step - loss: 6213.1318 - val_loss: 6301.5923
Epoch 26/30
313/313 [==============================] - 5s 16ms/step - loss: 6209.0781 - val_loss: 6278.6670
Epoch 27/30
313/313 [==============================] - 5s 16ms/step - loss: 6205.8843 - val_loss: 6266.9268
Epoch 28/30
313/313 [==============================] - 5s 16ms/step - loss: 6202.8608 - val_loss: 6282.7930
Epoch 29/30
313/313 [==============================] - 5s 16ms/step - loss: 6206.2085 - val_loss: 6308.5342
Epoch 30/30
313/313 [==============================] - 5s 16ms/step - loss: 6221.4204 - val_loss: 6400.3052
<tensorflow.python.keras.callbacks.History at 0x7f6401084bd0>
test_loss = vae.evaluate(test_ds)
print("Test loss: {}".format(test_loss))
32/32 [==============================] - 0s 5ms/step - loss: 6327.6440
Test loss: 6327.64404296875

Compute reconstructions of test images

We will now take a look at some image reconstructions from the encoder-decoder architecture.

You should complete the following function, which uses the encoder and decoder to reconstruct images from the test dataset. This function takes the encoder, decoder and a Tensor batch of test images as arguments. The function should be completed according to the following specification:

  • Get the mean of the encoding distributions from passing the batch of images into the encoder
  • Pass these latent vectors through the decoder to get the output distribution

Your function should then return the mean of the output distribution, which will be a Tensor of shape (batch_size, 64, 64, 3).

def reconstruct(encoder, decoder, batch_of_images):
    """
    The function takes the encoder, decoder and batch_of_images as inputs, which
    should be used to compute the reconstructions.
    The function should then return the reconstructions Tensor.
    """
    # Use the mean of the approximate posterior (rather than a sample) as the latent vector
    approx_posterior = encoder(batch_of_images)
    decoding_dist = decoder(approx_posterior.mean())
    return decoding_dist.mean()
n_reconstructions = 7
num_test_files = np.load('./dataset/vae-celeba/test.npy').shape[0]
test_ds_for_reconstructions = load_dataset('test')
for all_test_images, _ in test_ds_for_reconstructions.batch(num_test_files).take(1):
    all_test_images_np = all_test_images.numpy()
example_images = all_test_images_np[np.random.choice(num_test_files, n_reconstructions, replace=False)]

reconstructions = reconstruct(encoder, decoder, example_images).numpy()
f, axs = plt.subplots(2, n_reconstructions, figsize=(16, 6))
axs[0, n_reconstructions // 2].set_title("Original test images")
axs[1, n_reconstructions // 2].set_title("Reconstructed images")
for j in range(n_reconstructions):
    axs[0, j].imshow(example_images[j])
    axs[1, j].imshow(reconstructions[j])
    axs[0, j].axis('off')
    axs[1, j].axis('off')
    
plt.tight_layout();

Sample new images from the generative model

Now we will sample from the generative model; that is, first sample latent vectors from the prior, and then decode those latent vectors with the decoder.

You should complete the following function to generate new images. This function takes the prior distribution and decoder network as arguments, as well as the number of samples to generate. This function should be completed according to the following:

  • Sample a batch of n_samples images from the prior distribution, to obtain a latent vector Tensor of shape (n_samples, 50)
  • Pass this batch of latent vectors through the decoder, to obtain an Independent Bernoulli distribution with batch shape equal to [n_samples] and event shape equal to [64, 64, 3].

The function should then return the mean of the Bernoulli distribution, which will be a Tensor of shape (n_samples, 64, 64, 3).

def generate_images(prior, decoder, n_samples):
    """
    The function takes the prior distribution, decoder and number of samples as inputs, which
    should be used to generate the images.
    The function should then return the batch of generated images.
    """
    # Sample latent vectors from the prior, decode them, and return the Bernoulli means
    z = prior.sample(n_samples)
    return decoder(z).mean()
n_samples = 10
sampled_images = generate_images(prior, decoder, n_samples)

f, axs = plt.subplots(1, n_samples, figsize=(16, 6))

for j in range(n_samples):
    axs[j].imshow(sampled_images[j])
    axs[j].axis('off')
    
plt.tight_layout();