Fix a bug in the CholeskyLKJ density.
Fixes #694

PiperOrigin-RevId: 288722702
bloops authored and tensorflower-gardener committed Jan 8, 2020
1 parent fd7cd5f commit 205919f
Showing 3 changed files with 132 additions and 23 deletions.
tensorflow_probability/python/distributions/BUILD (2 changes: 1 addition & 1 deletion)

@@ -1705,7 +1705,7 @@ multi_substrate_py_test(
     srcs = [
         "cholesky_lkj_test.py",
     ],
-    jax_size = "large",
+    jax_tags = ["notap"],
     deps = [
         # absl/testing:parameterized dep,
         # numpy dep,
tensorflow_probability/python/distributions/cholesky_lkj.py (40 changes: 31 additions & 9 deletions)

@@ -46,6 +46,16 @@ class CholeskyLKJ(distribution.Distribution):
-  In other words, if `X ~ CholeskyLKJ(c)`, then `X @ X^T ~ LKJ(c)`.
+  If `X ~ CholeskyLKJ(c)`, then `X @ X^T ~ LKJ(c)`.
+
+  The distribution is supported on lower triangular N x N matrices which are
+  Cholesky factors of correlation matrices; equivalently, matrices whose rows
+  have Euclidean norm 1 and whose diagonal entries are positive. The
+  probability density function is given by:
+
+    pdf(X; c) = (1/Z(c)) prod_i X_ii^{n-i+2c-3}   (0 <= i < n)
+
+  where there are n(n-1)/2 independent variables X_ij (0 <= j < i < n) and
+  Z(c) is the normalizing constant, the same one as that of LKJ(c).
+
+  For more details on the LKJ distribution, see `tfp.distributions.LKJ`.
 
   #### Examples
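
The stated pdf can be sanity-checked numerically: Z(c) depends only on c, so the log-density difference between two samples should equal the weighted difference of their log-diagonals. A minimal sketch, not part of this commit, assuming a TFP build that includes this fix:

    import numpy as np
    import tensorflow.compat.v2 as tf
    import tensorflow_probability as tfp
    tf.enable_v2_behavior()
    tfd = tfp.distributions

    dim, c = 4, 2.5
    dist = tfd.CholeskyLKJ(dimension=dim, concentration=np.float64(c))
    x = dist.sample(2, seed=42)
    # Exponents n - i + 2c - 3 for 0 <= i < n, as in the docstring above.
    exponents = dim - np.arange(dim) + 2. * c - 3.
    logdiag = tf.math.log(tf.linalg.diag_part(x))
    # Z(c) cancels in the difference, leaving only the weighted log-diagonals.
    lhs = dist.log_prob(x[0]) - dist.log_prob(x[1])
    rhs = tf.reduce_sum(exponents * (logdiag[0] - logdiag[1]))
    # lhs and rhs should agree up to float tolerance.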
@@ -169,11 +179,9 @@ def _log_prob(self, x):
     # This log_prob comes from using a change of variables via the Cholesky
     # decomposition on the LKJ's log_prob.
     # The first term represents the change of variables of the LKJ's
-    # unnormalized log_prob, the second is the normalization term coming
-    # from the LKJ distribution, and the final is a normalization term
-    # coming from the change of variables.
-    return (self._log_unnorm_prob(x, concentration) -
-            normalizer + self.dimension * np.log(2.))
+    # unnormalized log_prob and the second is the normalization term coming
+    # from the LKJ distribution.
+    return self._log_unnorm_prob(x, concentration) - normalizer

def _log_unnorm_prob(self, x, concentration, name=None):
"""Returns the unnormalized log density of a CholeskyLKJ distribution.
@@ -195,12 +203,26 @@ def _log_unnorm_prob(self, x, concentration, name=None):
     x = tf.convert_to_tensor(x, name='x')
     logdiag = tf.math.log(tf.linalg.diag_part(x))
     # We pick up a weighted sum of the log(diag) due to the jacobian
-    # of the cholesky decomposition. See `tfp.bijectors.CholeskyOuterProduct`
-    # for details.
+    # of the cholesky decomposition. By an argument similar to that of
+    # `tfp.bijectors.CholeskyOuterProduct`, the jacobian is given by:
+    #   prod_i x_ii^{n-i-1} (0 <= i < n).
+    #
+    # To see this, observe that if x @ x^T = p, then p_ij depends only on
+    # those x_kl where k <= i and l <= j. Therefore, on vectorizing the
+    # strictly lower triangular parts of x and p, we get that the jacobian
+    # matrix [d/dvec(x) vec(p)] is lower triangular. The jacobian determinant
+    # is then the product of the n(n-1)/2 diagonal entries:
+    #   J = prod_ij d/dx_ij p_ij   (0 <= j < i < n)
+    #     = prod_ij d/dx_ij (x_i0 * x_j0 + x_i1 * x_j1 + ... + x_ij * x_jj)
+    #     = prod_ij x_jj
+    #     = prod_i x_ii^{n-i-1}
+    #
+    # For more details, see `tfp.bijectors.CholeskyOuterProduct`.
     dimension_range = np.linspace(
-        self.dimension,
-        1., self.dimension, dtype=dtype_util.as_numpy_dtype(
-            concentration.dtype))
+        self.dimension - 1,
+        0.,
+        self.dimension,
+        dtype=dtype_util.as_numpy_dtype(concentration.dtype))
     return tf.reduce_sum(
         (2. * concentration[..., tf.newaxis] - 2. + dimension_range) *
         logdiag, axis=-1)
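The claimed jacobian can also be checked by brute force: parameterize x by its strictly lower triangular entries (the diagonal is pinned by the unit-norm rows), push forward to the strictly lower triangle of p = x @ x^T, and compare a finite-difference jacobian determinant with prod_i x_ii^{n-i-1}. A numpy sketch under those assumptions, illustrative only and not from the commit:

    import numpy as np

    def vec_to_chol(v, n):
      # Fill the strictly lower triangle; the diagonal is then determined by
      # the unit Euclidean norm of each row.
      x = np.zeros((n, n))
      x[np.tril_indices(n, -1)] = v
      x[np.diag_indices(n)] = np.sqrt(1. - np.sum(x ** 2, axis=-1))
      return x

    def forward(v, n):
      # Map the free parameters of x to the free parameters of p = x @ x^T.
      x = vec_to_chol(v, n)
      p = x @ x.T
      return p[np.tril_indices(n, -1)]

    n = 4
    rng = np.random.RandomState(0)
    v = 0.3 * rng.uniform(-1., 1., size=n * (n - 1) // 2)

    # Forward-difference jacobian of vec(p) with respect to vec(x).
    eps, m = 1e-6, n * (n - 1) // 2
    jac = np.zeros((m, m))
    f0 = forward(v, n)
    for k in range(m):
      dv = v.copy()
      dv[k] += eps
      jac[:, k] = (forward(dv, n) - f0) / eps

    x = vec_to_chol(v, n)
    claimed = np.prod(np.diag(x) ** (n - 1 - np.arange(n)))
    print(abs(np.linalg.det(jac)), claimed)  # should agree closely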
tensorflow_probability/python/distributions/cholesky_lkj_test.py (113 changes: 100 additions & 13 deletions)

@@ -36,21 +36,108 @@
 class CholeskyLKJTest(test_util.TestCase):
 
   def testLogProbMatchesTransformedDistribution(self, dtype):
-    dtype = np.float64
-    for dims in (3, 4, 5):
-      concentration = np.linspace(2., 5., 10, dtype=dtype)
-      cholesky_lkj = tfd.CholeskyLKJ(
-          concentration=concentration, dimension=dims, validate_args=True)
-      transformed_lkj = tfd.TransformedDistribution(
-          bijector=tfb.Invert(tfb.CholeskyOuterProduct()),
-          distribution=tfd.LKJ(concentration=concentration, dimension=dims),
-          validate_args=True)
-
-      # Choose input that has well conditioned matrices.
-      x = self.evaluate(cholesky_lkj.sample(10, seed=test_util.test_seed()))
-      self.assertAllClose(
-          self.evaluate(cholesky_lkj.log_prob(x)),
-          self.evaluate(transformed_lkj.log_prob(x)))
+
+    # In this test, we start with a distribution supported on N x N SPD
+    # matrices which factorizes as an LKJ-supported correlation matrix and a
+    # diagonal of exponential random variables. The total number of
+    # independent parameters is N(N-1)/2 + N = N(N+1)/2. Using the
+    # CholeskyOuterProduct bijector (which requires N(N+1)/2 independent
+    # parameters), we map it to the space of lower triangular Cholesky
+    # factors. We show that the resulting distribution factorizes as the
+    # product of CholeskyLKJ and Rayleigh distributions.
+
+    # Given a sample from the 'LKJ + Exponential' distribution on SPD
+    # matrices and the corresponding log prob, returns the transformed
+    # sample and log prob in Cholesky-space, factored into a diagonal
+    # matrix and a Cholesky factor of a correlation matrix.
+    def _get_transformed_sample_and_log_prob(lkj_exp_sample, lkj_exp_log_prob,
+                                             dimension):
+      # We change variables to the space of SPD matrices parameterized by the
+      # lower triangular entries (including the diagonal) of \Sigma. The
+      # transformation is given by \Sigma = \sqrt{S} P \sqrt{S}.
+      #
+      # The jacobian matrix J of the forward transform has the block form
+      # [[I, 0]; [*, D]], where I is the N x N identity matrix, D is an
+      # N(N-1)/2 x N(N-1)/2 diagonal matrix, and * need not be computed.
+      # Here, N is the dimension. D_ij (0 <= j < i < N) equals
+      #   d/d(P_ij) \Sigma_ij = d/d(P_ij) \sqrt{S_i S_j} P_ij = \sqrt{S_i S_j}.
+      # Hence, detJ = \prod_{i>j} \sqrt{S_i S_j}   [N(N-1)/2 terms]
+      #             = \prod_i S_i^{.5 * (N - 1)}.
+      # [1] Zhenxun Wang, Yunan Wu and Haitao Chu. 'On equivalence of the LKJ
+      #     distribution and the restricted Wishart distribution'. 2018.
+      exp_variance, lkj_corr = lkj_exp_sample
+      sqrt_exp_variance = tf.math.sqrt(exp_variance[..., tf.newaxis])
+      sigma_sample = (tf.linalg.matrix_transpose(sqrt_exp_variance) *
+                      lkj_corr * sqrt_exp_variance)
+      sigma_log_prob = lkj_exp_log_prob - .5 * (dimension - 1) * tf.reduce_sum(
+          tf.math.log(tf.linalg.diag_part(sigma_sample)), axis=-1)
+
+      # We change variables again to lower triangular L, where L L^T = \Sigma.
+      # This is just the inverse of the tfb.CholeskyOuterProduct bijector.
+      cholesky_sigma_sample = tf.linalg.cholesky(sigma_sample)
+      cholesky_sigma_log_prob = sigma_log_prob + tfb.Invert(
+          tfb.CholeskyOuterProduct()).inverse_log_det_jacobian(
+              cholesky_sigma_sample, event_ndims=2)
+
+      # Change of variables to R, A; where L = RA, R is a diagonal matrix
+      # holding each dimension's standard deviation, and A is a Cholesky
+      # factor of a correlation matrix. A is parameterized by its strictly
+      # lower triangular entries; i.e. N(N-1)/2 entries.
+      #
+      # The jacobian determinant is the product of each row's transformation,
+      # as each row is transformed independently: R_ii = ||L_i|| and
+      # A_ij = L_ij / R_ii, where ||.|| denotes the Euclidean norm. Direct
+      # computation shows that the jacobian determinant for the ith row
+      # (1-indexed) is R_ii^{i-1} / A_ii.
+      std_dev_sample = tf.linalg.norm(cholesky_sigma_sample, axis=-1)
+      cholesky_corr_sample = (
+          cholesky_sigma_sample / std_dev_sample[..., tf.newaxis])
+
+      cholesky_corr_std_dev_sample = (std_dev_sample, cholesky_corr_sample)
+      cholesky_corr_std_dev_log_prob = (
+          cholesky_sigma_log_prob + tf.reduce_sum(
+              tf.range(dimension, dtype=dtype) * tf.math.log(std_dev_sample) -
+              tf.math.log(tf.linalg.diag_part(cholesky_corr_sample)),
+              axis=-1))
+
+      return cholesky_corr_std_dev_sample, cholesky_corr_std_dev_log_prob
+
+    for dimension in (2, 3, 4, 5):
+      rate = np.linspace(.5, 2., 10, dtype=dtype)
+      concentration = np.linspace(2., 5., 10, dtype=dtype)
+
+      # We start with a distribution on SPD matrices given by the product of
+      # LKJ and Exponential random variables.
+      lkj_exponential_covariance_dist = tfd.JointDistributionSequential([
+          tfd.Sample(tfd.Exponential(rate=rate), sample_shape=dimension),
+          tfd.LKJ(dimension=dimension, concentration=concentration)
+      ])
+      x = self.evaluate(
+          lkj_exponential_covariance_dist.sample(
+              10, seed=test_util.test_seed()))
+
+      # We transform a sample from the space of SPD matrices to the space of
+      # its lower triangular Cholesky factors, decomposed into the product of
+      # a diagonal matrix and a Cholesky factor of a correlation matrix.
+      transformed_x, transformed_log_prob = (
+          _get_transformed_sample_and_log_prob(
+              x, lkj_exponential_covariance_dist.log_prob(x), dimension))
+
+      # We now show that the transformation resulted in a distribution which
+      # factors as the product of a Rayleigh (the square root of an
+      # exponential) and a CholeskyLKJ distribution with the same parameters
+      # as the LKJ.
+      rayleigh_dist = tfd.TransformedDistribution(
+          bijector=tfb.Invert(tfb.Square()),
+          distribution=tfd.Exponential(rate=rate))
+      cholesky_lkj_rayleigh_dist = tfd.JointDistributionSequential([
+          tfd.Sample(rayleigh_dist, sample_shape=dimension),
+          tfd.CholeskyLKJ(dimension=dimension, concentration=concentration)
+      ])
+      self.assertAllClose(
+          self.evaluate(transformed_log_prob),
+          self.evaluate(cholesky_lkj_rayleigh_dist.log_prob(transformed_x)),
+          rtol=1e-3 if dtype == np.float32 else 1e-6)

def testDimensionGuard(self, dtype):
testee_lkj = tfd.CholeskyLKJ(
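As a side check on the factor used above: `tfb.Invert(tfb.Square())` applied to an exponential gives the square-root-of-exponential ("Rayleigh") law with density 2 * rate * r * exp(-rate * r^2). A small sketch, illustrative and not part of the commit:

    import numpy as np
    import tensorflow.compat.v2 as tf
    import tensorflow_probability as tfp
    tf.enable_v2_behavior()
    tfb, tfd = tfp.bijectors, tfp.distributions

    rate = np.float64(1.7)
    rayleigh_dist = tfd.TransformedDistribution(
        bijector=tfb.Invert(tfb.Square()),
        distribution=tfd.Exponential(rate=rate))
    r = np.linspace(0.1, 3., 5)
    # If X ~ Exponential(rate) and R = sqrt(X), then
    #   p(r) = 2 * rate * r * exp(-rate * r**2).
    closed_form = np.log(2. * rate * r) - rate * r ** 2
    # rayleigh_dist.log_prob(r) should match closed_form elementwise.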

