Subclassing Bijectors
In this post, we are going to build customized transformations with our own bijectors for extra flexibility. This is a summary of the lecture "Probabilistic Deep Learning with Tensorflow 2" from Imperial College London.
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors
plt.rcParams['figure.figsize'] = (10, 6)
print("Tensorflow Version: ", tf.__version__)
print("Tensorflow Probability Version: ", tfp.__version__)
class MySigmoid(tfb.Bijector):
    def __init__(self, validate_args=False, name='sigmoid'):
        super(MySigmoid, self).__init__(validate_args=validate_args, forward_min_event_ndims=0, name=name)

    def _forward(self, x):
        return tf.math.sigmoid(x)

    def _inverse(self, y):
        # Logit function: the inverse of the sigmoid
        return tf.math.log(y) - tf.math.log(1 - y)

    def _inverse_log_det_jacobian(self, y):
        # d/dy logit(y) = 1 / (y * (1 - y))
        return -tf.math.log(y) - tf.math.log(1 - y)

    def _forward_log_det_jacobian(self, x):
        # fldj(x) = -ildj(f(x))
        return -self._inverse_log_det_jacobian(self._forward(x))
Note that when implementing our own bijector, each method name must be prefixed with an underscore (_forward, _inverse, and so on); the underscore-free public methods are provided by the tfb.Bijector base class. The two log-det-Jacobian methods are redundant: each can be derived from the other via fldj(x) = -ildj(f(x)), as the two MySigmoid implementations show.
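A quick sanity check (my own addition, not from the lecture): the forward and inverse methods should round-trip, and the two log-det-Jacobians should cancel at corresponding points.

sigmoid = MySigmoid()
x = tf.constant([-2.0, 0.5, 3.0])
y = sigmoid.forward(x)
print(sigmoid.inverse(y))  # should recover x up to floating-point error
fldj = sigmoid.forward_log_det_jacobian(x, event_ndims=0)
ildj = sigmoid.inverse_log_det_jacobian(y, event_ndims=0)
print(fldj + ildj)  # should be ~0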
class MySigmoid(tfb.Bijector):
    def __init__(self, validate_args=False, name='sigmoid'):
        super(MySigmoid, self).__init__(validate_args=validate_args, forward_min_event_ndims=0, name=name)

    def _forward(self, x):
        return tf.math.sigmoid(x)

    def _inverse(self, y):
        return tf.math.log(y) - tf.math.log(1 - y)

    def _inverse_log_det_jacobian(self, y):
        # This time, derive the ILDJ from the forward log-det-Jacobian
        return -self._forward_log_det_jacobian(self._inverse(y))

    def _forward_log_det_jacobian(self, x):
        # log sigmoid'(x) = log sigmoid(x) + log sigmoid(-x) = -softplus(-x) - softplus(x)
        return -tf.math.softplus(-x) - tf.math.softplus(x)
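The softplus form above is more numerically stable for large |x| than evaluating log sigmoid(x) + log(1 - sigmoid(x)) directly, since log sigmoid(x) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x). As a check (my own snippet), the result matches the built-in tfb.Sigmoid:

ours, ref = MySigmoid(), tfb.Sigmoid()
x = tf.constant([-5.0, 0.0, 5.0])
print(ours.forward(x) - ref.forward(x))  # ~0
print(ours.forward_log_det_jacobian(x, event_ndims=0)
      - ref.forward_log_det_jacobian(x, event_ndims=0))  # ~0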
class MyShift(tfb.Bijector):
    def __init__(self, shift, validate_args=False, name='shift'):
        self.shift = shift
        # is_constant_jacobian=True: the Jacobian does not depend on the input
        super(MyShift, self).__init__(validate_args=validate_args, forward_min_event_ndims=0, name=name, is_constant_jacobian=True)

    def _forward(self, x):
        return x + self.shift

    def _inverse(self, y):
        return y - self.shift

    def _forward_log_det_jacobian(self, x):
        # A shift is volume-preserving, so the log-det-Jacobian is zero
        return tf.constant(0., x.dtype)
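Since a shift is volume-preserving, implementing _forward_log_det_jacobian alone is enough; the base class derives the inverse log-det-Jacobian from it. A small usage example (my own):

shift = MyShift(tf.constant(5.0))
x = tf.constant([1.0, 2.0])
print(shift.forward(x))  # [6., 7.]
print(shift.inverse_log_det_jacobian(x, event_ndims=0))  # zero, derived from _forward_log_det_jacobian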
class MyTanh(tfb.Bijector):
    def __init__(self, validate_args=False, name='tanh'):
        super(MyTanh, self).__init__(validate_args=validate_args, forward_min_event_ndims=0, name=name)

    def _forward(self, x):
        return tf.math.tanh(x)

    def _inverse(self, y):
        return tf.math.atanh(y)

    def _forward_log_det_jacobian(self, x):
        # d/dx tanh(x) = 1 - tanh^2(x)
        return tf.math.log1p(-tf.square(tf.tanh(x)))
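Here too we rely on the base class to derive the inverse log-det-Jacobian. A quick check against the built-in tfb.Tanh (my own snippet):

my_tanh = MyTanh()
x = tf.constant([-1.0, 0.0, 1.0])
print(my_tanh.inverse(my_tanh.forward(x)))  # should recover x
print(my_tanh.forward_log_det_jacobian(x, event_ndims=0)
      - tfb.Tanh().forward_log_det_jacobian(x, event_ndims=0))  # ~0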
class Cubic(tfb.Bijector):
    def __init__(self, a, b, validate_args=False, name='Cubic'):
        self.a = tf.cast(a, tf.float32)
        self.b = tf.cast(b, tf.float32)
        if validate_args:
            # Guard against (near-)zero parameters, which would make the bijector non-invertible
            assert tf.reduce_mean(tf.cast(tf.math.greater_equal(tf.abs(self.a), 1e-5), tf.float32)) == 1.0
            assert tf.reduce_mean(tf.cast(tf.math.greater_equal(tf.abs(self.b), 1e-5), tf.float32)) == 1.0
        super(Cubic, self).__init__(validate_args=validate_args, forward_min_event_ndims=0, name=name)

    def _forward(self, x):
        # y = (a * x + b)^3
        x = tf.cast(x, tf.float32)
        return tf.squeeze(tf.pow(self.a * x + self.b, 3))

    def _inverse(self, y):
        # x = (cbrt(y) - b) / a, with the sign trick handling negative y
        y = tf.cast(y, tf.float32)
        return (tf.math.sign(y) * tf.pow(tf.abs(y), 1/3) - self.b) / self.a

    def _forward_log_det_jacobian(self, x):
        # dy/dx = 3a(ax + b)^2  =>  log|dy/dx| = log(3|a|) + 2 log|ax + b|
        x = tf.cast(x, tf.float32)
        return tf.math.log(3. * tf.abs(self.a)) + 2. * tf.math.log(tf.abs(self.a * x + self.b))
cubic = Cubic([1.0, -2.0], [-1.0, 0.4], validate_args=True)
x = tf.constant([[1., 2.], [3., 4.]])
y = cubic.forward(x)
y
np.linalg.norm(x - cubic.inverse(y))
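The analytic log-det-Jacobian can also be double-checked with automatic differentiation: since y = (ax + b)^3, we have dy/dx = 3a(ax + b)^2 and hence log|dy/dx| = log(3|a|) + 2 log|ax + b|. A quick autodiff comparison (my own sketch):

x0 = tf.constant([0.5, 1.0])
with tf.GradientTape() as tape:
    tape.watch(x0)
    y0 = cubic.forward(x0)
dydx = tape.gradient(y0, x0)  # elementwise derivatives, since forward acts elementwise
print(tf.math.log(tf.abs(dydx)))
print(cubic.forward_log_det_jacobian(x0, event_ndims=0))  # should match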
x = np.linspace(-10, 10, 500).reshape(-1, 1)
plt.plot(x, cubic.forward(x))
plt.show()
cubic.forward(x).shape
plt.plot(x, cubic.inverse(x))
plt.show()
plt.plot(x, cubic.forward_log_det_jacobian(x, event_ndims=0))
plt.show()
plt.plot(x, cubic.inverse_log_det_jacobian(x, event_ndims=0))
plt.show()
normal = tfd.Normal(loc=0., scale=1.)
cubed_normal = tfd.TransformedDistribution(tfd.Sample(normal, sample_shape=[2]), cubic)
n = 1000
g = cubed_normal.sample(n)
g.shape
plt.subplot(1, 2, 1)
plt.hist(g[..., 0].numpy(), bins=50, density=True)
plt.subplot(1, 2, 2)
plt.hist(g[..., 1].numpy(), bins=50, density=True)
plt.show()
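Under the hood, TransformedDistribution evaluates densities via the change-of-variables formula, log p_Y(y) = log p_X(f^{-1}(y)) + ILDJ(y). We can reproduce its log_prob by hand (a sketch; event_ndims=1 matches the event shape of the base Sample distribution):

y0 = tf.constant([1.0, 2.0])
manual = (tfd.Sample(normal, sample_shape=[2]).log_prob(cubic.inverse(y0))
          + cubic.inverse_log_det_jacobian(y0, event_ndims=1))
print(manual, cubed_normal.log_prob(y0))  # should agree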
xx = np.linspace(-0.5, 0.5, 100)
yy = np.linspace(-0.5, 0.5, 100)
X, y = np.meshgrid(xx, yy)
fig, ax = plt.subplots(1, 1)
Z = cubed_normal.prob(np.dstack((X, y)))
cp = ax.contourf(X, y, Z)
fig.colorbar(cp)
ax.set_title('Filled Contours Plot')
ax.set_xlabel('X')
ax.set_ylabel('y')
plt.show()
inverse_cubic = tfb.Invert(cubic)
inv_cubed_normal = inverse_cubic(tfd.Sample(normal, sample_shape=[2]))
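Calling a bijector on a distribution like this is shorthand for wrapping the two in a TransformedDistribution; the explicit equivalent would be:

inv_cubed_normal = tfd.TransformedDistribution(tfd.Sample(normal, sample_shape=[2]), inverse_cubic)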
g = inv_cubed_normal.sample(n)
g.shape
xx = np.linspace(-3.0, 3.0, 100)
yy = np.linspace(-2.0, 2.0, 100)
X, y = np.meshgrid(xx, yy)
fig, ax = plt.subplots(1, 1)
Z = inv_cubed_normal.prob(np.dstack((X, y)))
cp = ax.contourf(X, y, Z)
fig.colorbar(cp)
ax.set_title('Filled Contours Plot')
ax.set_xlabel('X')
ax.set_ylabel('y')
plt.show()
plt.subplot(1, 2, 1)
plt.hist(g[..., 0].numpy(), bins=50, density=True)
plt.subplot(1, 2, 2)
plt.hist(g[..., 1].numpy(), bins=50, density=True)
plt.show()
probs = [0.45, 0.55]
mix_gauss = tfd.Mixture(
    cat=tfd.Categorical(probs=probs),
    components=[
        tfd.Normal(loc=2.3, scale=0.4),
        tfd.Normal(loc=-0.8, scale=0.4)
    ]
)
X_train = mix_gauss.sample(10000)
X_train = tf.data.Dataset.from_tensor_slices(X_train).batch(128)
X_valid = mix_gauss.sample(10000)
X_valid = tf.data.Dataset.from_tensor_slices(X_valid).batch(128)
print(X_train.element_spec)
print(X_valid.element_spec)
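Since the samples are i.i.d. draws, the fixed batch order makes little difference here, but with real data you would typically shuffle each epoch, e.g. (my own tweak, not in the lecture):

X_train = tf.data.Dataset.from_tensor_slices(mix_gauss.sample(10000)).shuffle(10000).batch(128)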
x = np.linspace(-5.0, 5.0, 100)
plt.plot(x, mix_gauss.prob(x))
plt.title('Data Distribution')
plt.show()
trainable_inv_cubic = tfb.Invert(Cubic(tf.Variable(0.25), tf.Variable(-0.1)))
trainable_inv_cubic.trainable_variables
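tfb.Invert is a tf.Module, so it surfaces the trainable variables of the bijector it wraps; these are the very tf.Variable objects we passed in as a and b:

print(trainable_inv_cubic.bijector.trainable_variables)  # same variables, via the wrapped Cubic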
trainable_dist = tfd.TransformedDistribution(normal, trainable_inv_cubic)
x = np.linspace(-5.0, 5.0, 100)
plt.plot(x, mix_gauss.prob(x), label='data')
plt.plot(x, trainable_dist.prob(x), label='trainable')
plt.title('Data and trainable distribution')
plt.legend(loc='best')
plt.show()
num_epochs = 10
opt = tf.keras.optimizers.Adam()
train_losses = []
valid_losses = []
for epoch in range(num_epochs):
    print("Epoch {}...".format(epoch))
    train_loss = tf.keras.metrics.Mean()
    val_loss = tf.keras.metrics.Mean()

    # Training: minimize the negative log likelihood of each batch
    for train_batch in X_train:
        with tf.GradientTape() as tape:
            tape.watch(trainable_inv_cubic.trainable_variables)  # harmless; variables are watched automatically
            loss = -trainable_dist.log_prob(train_batch)
        train_loss(loss)
        grads = tape.gradient(loss, trainable_inv_cubic.trainable_variables)
        opt.apply_gradients(zip(grads, trainable_inv_cubic.trainable_variables))
    train_losses.append(train_loss.result().numpy())

    # Validation: evaluate the NLL without updating the variables
    for valid_batch in X_valid:
        loss = -trainable_dist.log_prob(valid_batch)
        val_loss(loss)
    valid_losses.append(val_loss.result().numpy())
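As an aside, the per-batch work can be compiled with tf.function for a noticeable speed-up; a minimal sketch of the same training step (not part of the original lecture code):

@tf.function
def train_step(batch):
    with tf.GradientTape() as tape:
        loss = -trainable_dist.log_prob(batch)
    grads = tape.gradient(loss, trainable_inv_cubic.trainable_variables)
    opt.apply_gradients(zip(grads, trainable_inv_cubic.trainable_variables))
    return loss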
plt.plot(train_losses, label='train')
plt.plot(valid_losses, label='valid')
plt.legend()
plt.xlabel("Epochs")
plt.ylabel("Negative log likelihood")
plt.title("Training and validation loss curves")
plt.show()
x = np.linspace(-5.0, 5.0, 100)
plt.plot(x, mix_gauss.prob(x), label='data')
plt.plot(x, trainable_dist.prob(x), label='learned')
plt.title('Data and learned distribution')
plt.legend(loc='best')
plt.show()
trainable_inv_cubic.trainable_variables