Skip to content

Commit

Permalink
fix(tests): use version-specific PRNGKey seeds for improved test reli…
Browse files Browse the repository at this point in the history
…ability
  • Loading branch information
Qazalbash committed Jan 26, 2025
1 parent a90e0d2 commit 7822ace
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
10 changes: 6 additions & 4 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
)
from numpyro.nn import AutoregressiveNN

from .utils import get_python_version_specific_seed

TEST_FAILURE_RATE = 2e-5 # For all goodness-of-fit tests.


Expand Down Expand Up @@ -1653,7 +1655,7 @@ def test_gof(jax_dist, sp_dist, params):
num_samples = 10000
if "BetaProportion" in jax_dist.__name__:
num_samples = 20000
rng_key = random.PRNGKey(19470715)
rng_key = random.PRNGKey(get_python_version_specific_seed(0, 19470715))
d = jax_dist(*params)
samples = d.sample(key=rng_key, sample_shape=(num_samples,))
probs = np.exp(d.log_prob(samples))
Expand Down Expand Up @@ -1935,7 +1937,7 @@ def test_mean_var(jax_dist, sp_dist, params):
else 200000
)
d_jax = jax_dist(*params)
k = random.PRNGKey(19470715)
k = random.PRNGKey(get_python_version_specific_seed(0, 19470715))
samples = d_jax.sample(k, sample_shape=(n,)).astype(np.float32)
# check with suitable scipy implementation if available
# XXX: VonMises is already tested below
Expand Down Expand Up @@ -2433,7 +2435,7 @@ def test_biject_to(constraint, shape):
assert transform.codomain.upper_bound == constraint.upper_bound
if len(shape) < event_dim:
return
rng_key = random.PRNGKey(19470715)
rng_key = random.PRNGKey(get_python_version_specific_seed(0, 19470715))
x = random.normal(rng_key, shape)
y = transform(x)

Expand Down Expand Up @@ -2558,7 +2560,7 @@ def inv_vec_transform(y):
)
def test_bijective_transforms(transform, event_shape, batch_shape):
shape = batch_shape + event_shape
rng_key = random.PRNGKey(20020626)
rng_key = random.PRNGKey(get_python_version_specific_seed(0, 20020626))
x = biject_to(transform.domain)(random.normal(rng_key, shape))
y = transform(x)

Expand Down
6 changes: 5 additions & 1 deletion test/test_distributions_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
von_mises_centered,
)

from .utils import get_python_version_specific_seed


@pytest.mark.parametrize("x, y", [(0.2, 10.0), (0.6, -10.0)])
def test_binary_cross_entropy_with_logits(x, y):
Expand Down Expand Up @@ -133,7 +135,9 @@ def test_vec_to_tril_matrix(shape, diagonal):
@pytest.mark.parametrize("dim", [1, 4])
@pytest.mark.parametrize("coef", [1, -1])
def test_cholesky_update(chol_batch_shape, vec_batch_shape, dim, coef):
key1, key2 = random.split(random.PRNGKey(19470715))
key1, key2 = random.split(
random.PRNGKey(get_python_version_specific_seed(0, 19470715))
)
A = random.normal(key1, chol_batch_shape + (dim, dim))
A = A @ jnp.swapaxes(A, -2, -1) + jnp.eye(dim)
x = random.normal(key2, vec_batch_shape + (dim,)) * 0.1
Expand Down
21 changes: 21 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


import sys


def get_python_version_specific_seed(
seed_for_py_3_9: int, seed_not_for_py_3_9: int
) -> int:
"""After release of `jax==0.5.0`, we need different seeds for tests in Python 3.9
and other versions. This function returns the seed based on the Python version.
:param seed_for_py_3_9: Seed for Python 3.9
:param seed_not_for_py_3_9: Seed for other versions of Python
:return: Seed based on the Python version
"""
if sys.version_info.minor == 9:
return seed_for_py_3_9
else:
return seed_not_for_py_3_9

0 comments on commit 7822ace

Please sign in to comment.