diff --git a/test/test_distributions.py b/test/test_distributions.py index a3adae7df..6d23aa905 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -53,6 +53,8 @@ ) from numpyro.nn import AutoregressiveNN +from .utils import get_python_version_specific_seed + TEST_FAILURE_RATE = 2e-5 # For all goodness-of-fit tests. @@ -1653,7 +1655,7 @@ def test_gof(jax_dist, sp_dist, params): num_samples = 10000 if "BetaProportion" in jax_dist.__name__: num_samples = 20000 - rng_key = random.PRNGKey(19470715) + rng_key = random.PRNGKey(get_python_version_specific_seed(0, 19470715)) d = jax_dist(*params) samples = d.sample(key=rng_key, sample_shape=(num_samples,)) probs = np.exp(d.log_prob(samples)) @@ -1935,7 +1937,7 @@ def test_mean_var(jax_dist, sp_dist, params): else 200000 ) d_jax = jax_dist(*params) - k = random.PRNGKey(19470715) + k = random.PRNGKey(get_python_version_specific_seed(0, 19470715)) samples = d_jax.sample(k, sample_shape=(n,)).astype(np.float32) # check with suitable scipy implementation if available # XXX: VonMises is already tested below @@ -2433,7 +2435,7 @@ def test_biject_to(constraint, shape): assert transform.codomain.upper_bound == constraint.upper_bound if len(shape) < event_dim: return - rng_key = random.PRNGKey(19470715) + rng_key = random.PRNGKey(get_python_version_specific_seed(0, 19470715)) x = random.normal(rng_key, shape) y = transform(x) @@ -2558,7 +2560,7 @@ def inv_vec_transform(y): ) def test_bijective_transforms(transform, event_shape, batch_shape): shape = batch_shape + event_shape - rng_key = random.PRNGKey(20020626) + rng_key = random.PRNGKey(get_python_version_specific_seed(0, 20020626)) x = biject_to(transform.domain)(random.normal(rng_key, shape)) y = transform(x) diff --git a/test/test_distributions_util.py b/test/test_distributions_util.py index ef434201a..14be5d47c 100644 --- a/test/test_distributions_util.py +++ b/test/test_distributions_util.py @@ -26,6 +26,8 @@ von_mises_centered, ) +from .utils import get_python_version_specific_seed + @pytest.mark.parametrize("x, y", [(0.2, 10.0), (0.6, -10.0)]) def test_binary_cross_entropy_with_logits(x, y): @@ -133,7 +135,9 @@ def test_vec_to_tril_matrix(shape, diagonal): @pytest.mark.parametrize("dim", [1, 4]) @pytest.mark.parametrize("coef", [1, -1]) def test_cholesky_update(chol_batch_shape, vec_batch_shape, dim, coef): - key1, key2 = random.split(random.PRNGKey(19470715)) + key1, key2 = random.split( + random.PRNGKey(get_python_version_specific_seed(0, 19470715)) + ) A = random.normal(key1, chol_batch_shape + (dim, dim)) A = A @ jnp.swapaxes(A, -2, -1) + jnp.eye(dim) x = random.normal(key2, vec_batch_shape + (dim,)) * 0.1 diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 000000000..d1ffb910c --- /dev/null +++ b/test/utils.py @@ -0,0 +1,21 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + + +import sys + + +def get_python_version_specific_seed( + seed_for_py_3_9: int, seed_not_for_py_3_9: int +) -> int: + """After release of `jax==0.5.0`, we need different seeds for tests in Python 3.9 + and other versions. This function returns the seed based on the Python version. + + :param seed_for_py_3_9: Seed for Python 3.9 + :param seed_not_for_py_3_9: Seed for other versions of Python + :return: Seed based on the Python version + """ + if sys.version_info.minor == 9: + return seed_for_py_3_9 + else: + return seed_not_for_py_3_9