The TransformedDistribution class
In this post, we take a look at transformed distribution objects, which are built from a base distribution and a bijector. This is a summary of the lecture "Probabilistic Deep Learning with TensorFlow 2" from Imperial College London.
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors
plt.rcParams['figure.figsize'] = (10, 6)
print("Tensorflow Version: ", tf.__version__)
print("Tensorflow Probability Version: ", tfp.__version__)
Overview
A TransformedDistribution is a distribution defined by a base distribution and a bijector object. TensorFlow Probability gives it the same API as any other distribution, so it exposes the same methods and properties, such as sample and log_prob.
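For reference, the density that a transformed distribution evaluates follows from the change of variables formula. The notation below is an assumption made for this note: f is the bijector, p_0 and p_1 are the densities of the base distribution P0 and the data distribution P1, and J denotes a Jacobian. The two log-determinant terms correspond to the forward_log_det_jacobian and inverse_log_det_jacobian calls used in this post.

```latex
\begin{aligned}
x &= f(z), \qquad z \sim P_0, \qquad x \sim P_1 \\
\log p_1(x) &= \log p_0\!\left(f^{-1}(x)\right) + \log\left|\det J_{f^{-1}}(x)\right| \\
            &= \log p_0(z) - \log\left|\det J_{f}(z)\right|
\end{aligned}
```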
normal = tfd.Normal(loc=0., scale=1.)
z = normal.sample(3)
z
scale_and_shift = tfb.Chain([tfb.Shift(1.), tfb.Scale(2.)])
x = scale_and_shift.forward(z)
x
log_prob_z = normal.log_prob(z)
log_prob_z
log_prob_x = (log_prob_z - scale_and_shift.forward_log_det_jacobian(z, event_ndims=0))
log_prob_x
Note that the event_ndims argument is the number of rightmost dimensions of z that make up the event shape. Here event_ndims=0, so the log of the Jacobian determinant is calculated for each element of the tensor z.
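To make the role of event_ndims concrete, here is a minimal plain-Python sketch that does not use TFP. The helper exp_forward_log_det_jacobian is hypothetical and only mimics the semantics for the elementwise map f(z) = exp(z), where the log of the derivative is simply z. With event_ndims=0 every element is its own event, and with event_ndims=1 the contributions are summed over the rightmost dimension.

```python
def exp_forward_log_det_jacobian(z, event_ndims):
    """Hypothetical helper (not a TFP function) for f(z) = exp(z).

    z is a list of rows, i.e. a tensor of shape [batch, d].
    Elementwise, log|f'(z)| = log(exp(z)) = z.
    """
    if event_ndims == 0:
        # Each scalar is one event: return one value per element.
        return [[v for v in row] for row in z]
    if event_ndims == 1:
        # Each row is one event: sum over the rightmost dimension.
        return [sum(row) for row in z]
    raise ValueError("this sketch only handles event_ndims of 0 or 1")


z = [[0.5, -1.0, 2.0], [0.0, 1.0, 1.5]]
per_element = exp_forward_log_det_jacobian(z, event_ndims=0)  # shape [2, 3]
per_row = exp_forward_log_det_jacobian(z, event_ndims=1)      # shape [2]
print(per_element)
print(per_row)
```

The result for event_ndims=1 has one fewer dimension than the input, because the rightmost dimension has been absorbed into the event.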
Alternatively, we can express it with the log-det-Jacobian of the inverse transformation, evaluated at x.
log_prob_x = (log_prob_z + scale_and_shift.inverse_log_det_jacobian(x, event_ndims=0))
log_prob_x
We get the same result when z is recovered from x with the inverse transformation, so that only x is needed.
log_prob_x = (normal.log_prob(scale_and_shift.inverse(x)) + scale_and_shift.inverse_log_det_jacobian(x, event_ndims=0))
log_prob_x
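As a sanity check, the expressions above can be reproduced by hand for this particular bijector. The sketch below uses only the standard library and assumes a fixed value z = 0.3 chosen for illustration. Since the chain scales by 2 and then shifts by 1, x = 2z + 1 follows a normal distribution with mean 1 and standard deviation 2, which gives a closed form to compare against.

```python
import math


def normal_log_prob(v, loc, scale):
    """Log-density of a univariate normal, written out by hand."""
    u = (v - loc) / scale
    return -0.5 * u * u - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


z = 0.3            # illustrative value standing in for a base sample
x = 2.0 * z + 1.0  # forward pass: scale by 2, then shift by 1

# Forward form: log p(x) = log p0(z) - log|df/dz|, with df/dz = 2.
via_forward = normal_log_prob(z, 0.0, 1.0) - math.log(2.0)

# Inverse form: log p(x) = log p0(f^{-1}(x)) + log|d f^{-1}/dx|, with d f^{-1}/dx = 1/2.
via_inverse = normal_log_prob((x - 1.0) / 2.0, 0.0, 1.0) + math.log(0.5)

# Closed form: x = 2z + 1 is distributed as N(1, 2).
closed_form = normal_log_prob(x, 1.0, 2.0)

print(via_forward, via_inverse, closed_form)
```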
You may notice that the log probability of x can be calculated using either z or x. In practice, the second form, which needs only x, is used in most cases. The reason is that z comes from the base distribution, so in terms of analysis it is a latent variable, whereas x comes from the data distribution and is also the output of the transformed distribution. With this approach, the transformation is expressed as a bijector (an invertible function), and its parameters can be learned by maximum likelihood.
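To illustrate what learning the bijector by maximum likelihood means, here is a minimal sketch in plain Python rather than TFP. Everything in it is an assumption made for the example: the synthetic data drawn from N(3, 2), the affine bijector x = exp(log_scale) * z + shift, the learning rate, and the number of steps. It minimizes the mean negative log-likelihood computed with the inverse form, using gradients derived by hand.

```python
import math
import random

random.seed(0)
# Synthetic "data distribution" for this sketch: N(3, 2).
data = [random.gauss(3.0, 2.0) for _ in range(500)]

# Parameters of the affine bijector x = exp(log_scale) * z + shift.
shift, log_scale = 0.0, 0.0
learning_rate = 0.1

for _ in range(2000):
    scale = math.exp(log_scale)
    # Inverse pass: u = f^{-1}(x) should look like a standard normal sample.
    u = [(x - shift) / scale for x in data]
    # Mean NLL = mean(0.5 * u^2) + log_scale + const, so:
    grad_shift = -sum(u) / (len(u) * scale)
    grad_log_scale = 1.0 - sum(v * v for v in u) / len(u)
    shift -= learning_rate * grad_shift
    log_scale -= learning_rate * grad_log_scale

fitted_scale = math.exp(log_scale)
sample_mean = sum(data) / len(data)
sample_std = math.sqrt(sum((x - sample_mean) ** 2 for x in data) / len(data))
print(shift, fitted_scale)
print(sample_mean, sample_std)
```

For an affine bijector on a standard normal base, the maximum likelihood solution is the sample mean and the (biased) sample standard deviation, so the two printed lines should agree closely.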
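# Base distribution        Transformation        Data distribution
#     z ~ P0          <=>      x = f(z)      <=>      x ~ P1
# (base_dist and bijector stand for any base distribution and bijector)

# Training: evaluate the log probability of observed data x
log_prob_x = (base_dist.log_prob(bijector.inverse(x)) + bijector.inverse_log_det_jacobian(x, event_ndims=0))

# Sampling: push a base sample through the bijector
x_sample = bijector.forward(base_dist.sample())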
normal = tfd.Normal(loc=0., scale=1.)
z = normal.sample(3)
z
exp = tfb.Exp()
x = exp.forward(z)
x
log_normal = tfd.TransformedDistribution(normal, exp)
log_normal
The expression above is equivalent to calling the bijector directly on the base distribution:
log_normal = exp(normal)
log_normal
log_normal.sample()
log_normal.log_prob(x)
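The log probability returned here can be checked by hand. The following standard-library sketch assumes the same standard normal base and exp bijector as above. It computes the density once through the inverse form (the inverse of exp is log, whose derivative is 1/x) and once with the textbook log-normal density.

```python
import math


def normal_log_prob(v, loc=0.0, scale=1.0):
    """Log-density of a univariate normal, written out by hand."""
    u = (v - loc) / scale
    return -0.5 * u * u - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def log_prob_via_bijector(x):
    z = math.log(x)                          # inverse of exp
    inverse_log_det_jacobian = -math.log(x)  # log|d(log x)/dx| = log(1/x)
    return normal_log_prob(z) + inverse_log_det_jacobian


def log_normal_log_prob(x, mu=0.0, sigma=1.0):
    """Textbook log-normal log-density."""
    return (-math.log(x * sigma * math.sqrt(2.0 * math.pi))
            - (math.log(x) - mu) ** 2 / (2.0 * sigma ** 2))


for x in (0.2, 1.0, 3.5):
    print(x, log_prob_via_bijector(x), log_normal_log_prob(x))
```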
We can also define a specific event_shape and batch_shape for a TransformedDistribution.
normal = tfd.Normal(loc=0., scale=1.)
scale_tril = [[1., 0.], [1., 1.]]
scale = tfb.ScaleMatvecTriL(scale_tril=scale_tril)
mvn = tfd.TransformedDistribution(tfd.Sample(normal, sample_shape=[2]), scale)
mvn
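To see why this construction yields a multivariate normal, here is a plain-Python sketch for the 2x2 case that does not use TFP. It assumes the event is a pair of independent standard normals transformed by x = L z, where L is the lower-triangular matrix, and compares the bijector form of the log probability with the multivariate normal density whose covariance is L L^T. The test point x is made up for illustration.

```python
import math


def std_normal_log_prob(v):
    return -0.5 * v * v - 0.5 * math.log(2.0 * math.pi)


def log_prob_via_bijector(x, tril):
    (a, _), (b, c) = tril
    # Inverse pass: solve L z = x by forward substitution.
    z1 = x[0] / a
    z2 = (x[1] - b * z1) / c
    # Inverse log-det-Jacobian: -log|det L| = -(log|a| + log|c|).
    inverse_log_det_jacobian = -(math.log(abs(a)) + math.log(abs(c)))
    # The event has two dimensions, so the base log-densities are summed.
    return std_normal_log_prob(z1) + std_normal_log_prob(z2) + inverse_log_det_jacobian


def mvn_log_prob(x, tril):
    (a, _), (b, c) = tril
    # Covariance = L L^T.
    s11, s12, s22 = a * a, a * b, b * b + c * c
    det = s11 * s22 - s12 * s12
    # Quadratic form x^T Sigma^{-1} x for a 2x2 matrix.
    quad = (s22 * x[0] ** 2 - 2.0 * s12 * x[0] * x[1] + s11 * x[1] ** 2) / det
    return -0.5 * quad - 0.5 * math.log((2.0 * math.pi) ** 2 * det)


scale_tril = [[1.0, 0.0], [1.0, 1.0]]
x = [0.7, -0.4]  # illustrative test point
print(log_prob_via_bijector(x, scale_tril), mvn_log_prob(x, scale_tril))
```

Summing the two base log-densities is what an event_shape of [2] expresses: the pair is treated as a single event with one log probability.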
scale_tril = [[[1., 0.], [1., 1.]], [[0.5, 0.], [-1., 0.5]]]
scale = tfb.ScaleMatvecTriL(scale_tril=scale_tril)
mvn = tfd.TransformedDistribution(tfd.Sample(tfd.Normal(loc=[0., 0.], scale=1.), sample_shape=[2]), scale)
mvn
n = 10000
loc = 0
scale = 0.5
normal = tfd.Normal(loc=loc, scale=scale)
print('batch shape: ', normal.batch_shape)
print('event shape: ', normal.event_shape)
exp = tfb.Exp()
log_normal_td = exp(normal)
print('batch shape: ', log_normal_td.batch_shape)
print('event shape: ', log_normal_td.event_shape)
z = normal.sample(n)
plt.hist(z.numpy(), bins=100, density=True)
plt.show()
x = log_normal_td.sample(n)
plt.hist(x.numpy(), bins=100, density=True)
plt.show()
log_normal = tfd.LogNormal(loc=loc, scale=scale)
l = log_normal.sample(n)
plt.hist(l.numpy(), bins=100, density=True)
plt.show()
log_prob = log_normal.log_prob(x)
log_prob_td = log_normal_td.log_prob(x)
tf.norm(log_prob - log_prob_td)
tril = tf.random.normal((2, 4, 4))
scale_low_tri = tf.linalg.LinearOperatorLowerTriangular(tril)
scale_low_tri.to_dense()
scale_lin_op = tfb.ScaleMatvecLinearOperator(scale_low_tri)
mvn = tfd.TransformedDistribution(tfd.Sample(tfd.Normal(loc=[0., 0.], scale=1.), sample_shape=[4]), scale_lin_op)
print('batch shape: ', mvn.batch_shape)
print('event shape: ', mvn.event_shape)
y1 = mvn.sample(sample_shape=(n,))
print(y1.shape)
mvn2 = tfd.MultivariateNormalLinearOperator(loc=0, scale=scale_low_tri)
mvn2
y2 = mvn2.sample(sample_shape=(n, ))
y2.shape
xn = normal.sample((n, 2, 4))
tf.norm(mvn.log_prob(xn) - mvn2.log_prob(xn)) / tf.norm(mvn.log_prob(xn))