From a68aee161a7f78d2780c7ceb3e67dfdb7ffaff6f Mon Sep 17 00:00:00 2001 From: drabbit17 Date: Wed, 30 Jun 2021 23:34:20 +0100 Subject: [PATCH] Add arrays test for Wald, remove useless test and simplify sampling from numpy wald --- pymc3/distributions/continuous.py | 2 +- pymc3/tests/test_distributions.py | 15 -------------- pymc3/tests/test_distributions_random.py | 25 ++++++++++++++++++++++++ 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index 39f30b84119..ded45a59680 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -891,7 +891,7 @@ class WaldRV(RandomVariable): @classmethod def rng_fn(cls, rng, mu, lam, alpha, size): - return getattr(np.random.RandomState, cls.name)(rng, mu, lam, size) + alpha + return rng.wald(mu, lam, size=size) + alpha wald = WaldRV() diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index b61ccd1b717..90008cbd92f 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -2807,21 +2807,6 @@ def test_lower_bounded_broadcasted(self): assert lower_interval.value == -1 assert upper_interval is None - def test_rich_context(self): - with Model() as model: - sigma = TruncatedNormal("lower_bounded", mu=2, sigma=1.5, lower=0, upper=None) - mu = TruncatedNormal("upper_bounded", mu=0, sigma=2, lower=None, upper=3) - Normal("normal", mu=mu, sigma=sigma, observed=[1.3, -1.4, 2.0]) - ( - (_, _, lower, upper), - lower_interval, - upper_interval, - ) = self.get_dist_params_and_interval_bounds(model, "upper_bounded") - assert lower.value == -np.inf - assert upper.value == 3 - assert lower_interval is None - assert upper_interval.value == 3 - @pytest.mark.xfail(reason="LaTeX repr and str no longer applicable") class TestStrAndLatexRepr: diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index 1d82886cec2..e486afc33c4 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -681,6 +681,31 @@ class TestWaldMuPhi(TestWald): ] +class TestTruncatedNormalUpperArray(BaseTestDistribution): + pymc_dist = pm.TruncatedNormal + lower, upper, mu, tau = ( + np.array([-np.inf, -np.inf]), + np.array([3, 2]), + np.array([0, 0]), + np.array( + [ + 1, + 1, + ] + ), + ) + size = (15, 2) + tau, sigma = get_tau_sigma(tau=tau, sigma=None) + pymc_dist_params = {"mu": mu, "tau": tau, "upper": upper} + expected_rv_op_params = {"mu": mu, "sigma": sigma, "lower": lower, "upper": upper} + reference_dist_params = {"loc": mu, "scale": sigma, "a": lower, "b": (upper - mu) / sigma} + reference_dist = seeded_scipy_distribution_builder("truncnorm") + tests_to_run = [ + "check_pymc_params_match_rv_op", + "check_pymc_draws_match_reference", + ] + + class TestSkewNormal(BaseTestDistribution): pymc_dist = pm.SkewNormal pymc_dist_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}