Skip to content

Commit

Permalink
Add arrays test for Wald, remove useless test and simplify sampling f…
Browse files Browse the repository at this point in the history
…rom numpy wald
  • Loading branch information
matteo-pallini committed Jun 30, 2021
1 parent 44343e3 commit a68aee1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ class WaldRV(RandomVariable):

@classmethod
def rng_fn(cls, rng, mu, lam, alpha, size):
return getattr(np.random.RandomState, cls.name)(rng, mu, lam, size) + alpha
return rng.wald(mu, lam, size=size) + alpha


wald = WaldRV()
Expand Down
15 changes: 0 additions & 15 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2807,21 +2807,6 @@ def test_lower_bounded_broadcasted(self):
assert lower_interval.value == -1
assert upper_interval is None

def test_rich_context(self):
with Model() as model:
sigma = TruncatedNormal("lower_bounded", mu=2, sigma=1.5, lower=0, upper=None)
mu = TruncatedNormal("upper_bounded", mu=0, sigma=2, lower=None, upper=3)
Normal("normal", mu=mu, sigma=sigma, observed=[1.3, -1.4, 2.0])
(
(_, _, lower, upper),
lower_interval,
upper_interval,
) = self.get_dist_params_and_interval_bounds(model, "upper_bounded")
assert lower.value == -np.inf
assert upper.value == 3
assert lower_interval is None
assert upper_interval.value == 3


@pytest.mark.xfail(reason="LaTeX repr and str no longer applicable")
class TestStrAndLatexRepr:
Expand Down
25 changes: 25 additions & 0 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,31 @@ class TestWaldMuPhi(TestWald):
]


class TestTruncatedNormalUpperArray(BaseTestDistribution):
pymc_dist = pm.TruncatedNormal
lower, upper, mu, tau = (
np.array([-np.inf, -np.inf]),
np.array([3, 2]),
np.array([0, 0]),
np.array(
[
1,
1,
]
),
)
size = (15, 2)
tau, sigma = get_tau_sigma(tau=tau, sigma=None)
pymc_dist_params = {"mu": mu, "tau": tau, "upper": upper}
expected_rv_op_params = {"mu": mu, "sigma": sigma, "lower": lower, "upper": upper}
reference_dist_params = {"loc": mu, "scale": sigma, "a": lower, "b": (upper - mu) / sigma}
reference_dist = seeded_scipy_distribution_builder("truncnorm")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
]


class TestSkewNormal(BaseTestDistribution):
pymc_dist = pm.SkewNormal
pymc_dist_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}
Expand Down

0 comments on commit a68aee1

Please sign in to comment.