Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(tests): update tolerance levels and PRNGKey usage for improved test stability #1959

Merged
merged 18 commits into from
Jan 31, 2025

Conversation

Qazalbash
Copy link
Contributor

@Qazalbash Qazalbash commented Jan 26, 2025

With the release of jax==0.5.0, the jax_threefry_partitionable mode is now enabled by default. This change has caused some test cases to fail because the random numbers generated with the same seed differ depending on the status of jax_threefry_partitionable. For example,

Example

>>> python -c "import jax; jax.print_environment_info()"
jax:    0.5.0
jaxlib: 0.5.0
numpy:  1.26.4
python: 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='beepboop', release='6.1.0-22-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.1.94-1 (2024-06-21)', machine='x86_64')

>>> JAX_THREEFRY_PARTITIONABLE=False python -c "from jax import random as jrd; key=jrd.PRNGKey(0); print(jrd.uniform(key))"
0.41845703

>>> python -c "from jax import random as jrd; key=jrd.PRNGKey(0); print(jrd.uniform(key))"
0.947667

Tolerance levels are updated for failing test cases.

@martinjankowiak
Copy link
Collaborator

@Qazalbash thanks for taking a close look at this !!

After modifying the test cases to use appropriate seed values, most of the tests passed. However, in certain cases where the appropriate seed could not be identified, using higher numerical precision resolved the issue.

can you please explain why it's not possible (or why it's difficult) to get rid of more "magic numbers" like 19470715? wherever possible it's probably preferably to increase the tolerance, provided the tolerance isn't being increased by too much...

@Qazalbash
Copy link
Contributor Author

Qazalbash commented Jan 26, 2025

wherever possible it's probably preferably to increase the tolerance, provided the tolerance isn't being increased by too much...

While addressing a similar issue in a personal project, I found that the most practical solution was to modify the random seed. Since I am not an expert in probability or statistics, I assumed that the tolerances, as well as the relative and absolute error thresholds, are inherently tied to the algorithm's precision. To avoid altering these parameters, I opted to adjust the input data by changing the random keys.

can you please explain why it's not possible (or why it's difficult) to get rid of more "magic numbers" like 19470715?

Initially, I experimented with various integers, including prime numbers and other two-digit values. However, after some time, I decided to explore more meaningful options, such as historically significant dates or personal milestones, including my own birthday. Surprisingly, these choices proved effective in resolving the issue :).

I can not argue 'to get rid of more "magic numbers"' you can explore if running tests on higher precision works! Because almost all the failing ones were passing on higher precision without a change of seed value.

My suggestion would be, if you don't want to encounter more magic numbers, shifting to higher precision will be a better choice.

@fehiepsi
Copy link
Member

I think we should avoid magic numbers and address the numerical issues directly. It is better to have tests which are rarely failing (rather than sometimes passing).

@fehiepsi
Copy link
Member

I can take a stab at it.

@Qazalbash
Copy link
Contributor Author

I think we should avoid magic numbers and address the numerical issues directly. It is better to have tests which are rarely failing (rather than sometimes passing).

I second that! If you like I can test it.

@fehiepsi
Copy link
Member

Sure, thanks!!

@Qazalbash Qazalbash changed the title fix(tests): using different PRNGKey or high precision for failing tests fix(tests): update tolerance levels and PRNGKey usage for improved test stability Jan 27, 2025
@Qazalbash
Copy link
Contributor Author

@fehiepsi I have fixed all the tests and merged the master branch, too, except for test/infer/test_mcmc.py::test_change_point_x64, which is passing in 85d4982 and also on my machine. Its failure does not make sense to me, so I left this one to you.

test/infer/test_mcmc.py Outdated Show resolved Hide resolved
test/infer/test_mcmc.py Outdated Show resolved Hide resolved
test/test_distributions.py Outdated Show resolved Hide resolved
test/test_distributions.py Outdated Show resolved Hide resolved
test/test_distributions.py Outdated Show resolved Hide resolved
test/test_distributions.py Outdated Show resolved Hide resolved
test/test_distributions.py Outdated Show resolved Hide resolved
test/test_distributions.py Outdated Show resolved Hide resolved
test/test_distributions_util.py Outdated Show resolved Hide resolved
test/test_transforms.py Outdated Show resolved Hide resolved
@fehiepsi
Copy link
Member

Maybe the init values of the change_point test make MCMC stuck at local minimal. How about setting

from numpyro.infer import init_to_value
...

kernel = NUTS(model=model, init_strategy=init_to_value(values={"lambda1": 1, "lambda2": 72}))

@fehiepsi
Copy link
Member

fehiepsi commented Jan 31, 2025

There are a couple of failing tests:

  • test/infer/test_gradient.py: you can simply increase atol I think
  • test/infer/test_autoguide.py::test_dais_vae: increase svi.run steps to e.g. 5000
  • test_discrete_gibbs_multiple_sites_chain: you can increase atol to 0.1

Edit: turns out that you already did

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huge thanks for fixing this annoying issue, @Qazalbash!

@Qazalbash
Copy link
Contributor Author

It is pretty suspicious these test cases were passing in 85d4982 even though py3.10 had jax==0.5.0.

@fehiepsi
Copy link
Member

Could you increase atol there? I feel the test is incorrect

@Qazalbash
Copy link
Contributor Author

Some more test cases have failed!

Can you check test/contrib/einstein/test_stein_loss.py::test_stein_particle_loss? It has a relatively large error.

>           assert_allclose(act_loss, exp_loss)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-07, atol=0
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference among violations: 14.192255
E           Max relative difference among violations: 5.6824265
E            ACTUAL: array(-16.689825, dtype=float32)
E            DESIRED: array(-2.49757, dtype=float32)

test/contrib/einstein/test_stein_loss.py:95: AssertionError

@fehiepsi fehiepsi merged commit 7041846 into pyro-ppl:master Jan 31, 2025
10 checks passed
@Qazalbash
Copy link
Contributor Author

@fehiepsi, I enabled continue on error in 295136f, to get all the failed test cases at once, oblivious to the fact the test suite passed on that commit. Do you like to revert it?

@fehiepsi
Copy link
Member

fehiepsi commented Feb 1, 2025

Sure. I dont think it is needed for other PRs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants