Skip to content

Commit

Permalink
Reenable old MatrixNormal test
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 25, 2022
1 parent dd06623 commit db0b762
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 28 deletions.
55 changes: 29 additions & 26 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,7 +1710,12 @@ class TestMatrixNormal(BaseTestDistributionRandom):
pymc_dist_params = {"mu": mu, "rowcov": row_cov, "colcov": col_cov}
expected_rv_op_params = {"mu": mu, "rowcov": row_cov, "colcov": col_cov}

tests_to_run = ["check_pymc_params_match_rv_op", "test_matrix_normal", "test_errors"]
tests_to_run = [
"check_pymc_params_match_rv_op",
"test_matrix_normal",
"check_errors",
"check_matrix_normal_random_with_random_variables",
]

def test_matrix_normal(self):
delta = 0.05 # limit for KS p-value
Expand Down Expand Up @@ -1746,7 +1751,7 @@ def ref_rand(mu, rowcov, colcov):

assert p > delta

def test_errors(self):
def check_errors(self):
msg = "MatrixNormal doesn't support size argument"
with pm.Model():
with pytest.raises(NotImplementedError, match=msg):
Expand Down Expand Up @@ -1778,6 +1783,28 @@ def test_errors(self):
shape=15,
)

def check_matrix_normal_random_with_random_variables(self):
"""
This test checks for shape correctness when using MatrixNormal distribution
with parameters as random variables.
Originally reported - https://github.com/pymc-devs/pymc/issues/3585
"""
K = 3
D = 15
mu_0 = np.zeros((D, K))
lambd = 1.0
with pm.Model() as model:
sd_dist = pm.HalfCauchy.dist(beta=2.5, size=D)
packedL = pm.LKJCholeskyCov("packedL", eta=2, n=D, sd_dist=sd_dist, compute_corr=False)
L = pm.expand_packed_triangular(D, packedL, lower=True)
Sigma = pm.Deterministic("Sigma", L.dot(L.T)) # D x D covariance
mu = pm.MatrixNormal(
"mu", mu=mu_0, rowcov=(1 / lambd) * Sigma, colcov=np.eye(K), shape=(D, K)
)
prior = pm.sample_prior_predictive(2, return_inferencedata=False)

assert prior["mu"].shape == (2, D, K)


class TestInterpolated(BaseTestDistributionRandom):
def interpolated_rng_fn(self, size, mu, sigma, rng):
Expand Down Expand Up @@ -2386,30 +2413,6 @@ def generate_shapes(include_params=False):
return data


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_matrix_normal_random_with_random_variables():
"""
This test checks for shape correctness when using MatrixNormal distribution
with parameters as random variables.
Originally reported - https://github.com/pymc-devs/pymc/issues/3585
"""
K = 3
D = 15
mu_0 = np.zeros((D, K))
lambd = 1.0
with pm.Model() as model:
sd_dist = pm.HalfCauchy.dist(beta=2.5)
packedL = pm.LKJCholeskyCov("packedL", eta=2, n=D, sd_dist=sd_dist)
L = pm.expand_packed_triangular(D, packedL, lower=True)
Sigma = pm.Deterministic("Sigma", L.dot(L.T)) # D x D covariance
mu = pm.MatrixNormal(
"mu", mu=mu_0, rowcov=(1 / lambd) * Sigma, colcov=np.eye(K), shape=(D, K)
)
prior = pm.sample_prior_predictive(2)

assert prior["mu"].shape == (2, D, K)


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestMvGaussianRandomWalk(SeededTest):
@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/test_idata_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,14 @@ def test_missing_data_model(self):
# See https://github.com/pymc-devs/pymc/issues/5255
assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)

@pytest.mark.xfal(reason="Multivariate partial observed RVs not implemented for V4")
@pytest.mark.xfail(reason="Multivariate partial observed RVs not implemented for V4")
def test_mv_missing_data_model(self):
data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)

model = pm.Model()
with model:
mu = pm.Normal("mu", 0, 1, size=2)
sd_dist = pm.HalfNormal.dist(1.0)
sd_dist = pm.HalfNormal.dist(1.0, size=2)
# pylint: disable=unpacking-non-sequence
chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True)
# pylint: enable=unpacking-non-sequence
Expand Down

0 comments on commit db0b762

Please sign in to comment.