-
Notifications
You must be signed in to change notification settings - Fork 3k
Implement Gauss hypergeometric function 2F1 #28168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @mattbahr for the PR!
Here follows a review of the implemented feature.
I tested the hyp2f1 implementation against hyp2f1 in scipy and mpmath for
a, b, c = 2, 1, 4
0 <= x <= 1
while confirming that scipy and mpmath results match more or less well for float64:
float64:
x range | dULP=0 | dULP=1 | dULP=2 | dULP=3..9 | dULP=10..1000 | dULP=1001..1000000 | dULP > 1000000
-----------------+----------+----------+----------+-------------+-----------------+----------------------+-----------------
0..tiny | 2 | - | - | - | - | - | -
tiny..eps**2 | 10000 | - | - | - | - | - | -
eps**2..eps | 9999 | 1 | - | - | - | - | -
eps..sqrt(eps) | 9911 | 89 | - | - | - | - | -
sqrt(eps)..1 | 5800 | 3494 | 442 | 251 | 13 | - | -
(scipy hyp2f1 on float32 inputs appears to use float64 implementation as max dULP is 0).
Comparing JAX hyp2f1 vs mpath hyp2f1, I noticed that there exists inputs that lead to large errors when x > sqrt(eps)
:
float64:
x range | dULP=0 | dULP=1 | dULP=2 | dULP=3..9 | dULP=10..1000 | dULP=1001..1000000 | dULP > 1000000
-----------------+----------+----------+----------+-------------+-----------------+----------------------+-----------------
0..tiny | 2 | - | - | - | - | - | -
tiny..eps**2 | 10000 | - | - | - | - | - | -
eps**2..eps | 9999 | 1 | - | - | - | - | -
eps..sqrt(eps) | - | 577 | 289 | 745 | 2602 | 3828 | 1959
sqrt(eps)..1 | - | - | - | - | 696 | 2234 | 7070
float32:
x range | dULP=0 | dULP=1 | dULP=2 | dULP=3..9 | dULP=10..1000 | dULP=1001..1000000 | dULP > 1000000
-----------------+----------+----------+----------+-------------+-----------------+----------------------+-----------------
0..tiny | 2 | - | - | - | - | - | -
tiny..eps**2 | 10000 | - | - | - | - | - | -
eps**2..eps | 9999 | 1 | - | - | - | - | -
eps..sqrt(eps) | 8942 | 1058 | - | - | - | - | -
sqrt(eps)..1 | 4870 | 3664 | 546 | 464 | 380 | 76 | -
The plots of the differences between JAX and reference (mpmath) implementations confirm that the largest error is around x = 0.9
:
Can you explain these discrepancies between JAX and scipy implementations of hyp2f1?
Also notice that when XF64 is enabled, calling JAX hyp2f1 with float32 inputs will fail:
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update("jax_enable_x64", True)
>>> jax.scipy.special.hyp2f1(2, 1, 4, jnp.float64(0.9))
Array(2.17894023, dtype=float64)
>>> jax.scipy.special.hyp2f1(2, 1, 4, jnp.float32(0.9))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/pearu/git/mattbahr/jax/jax/_src/numpy/vectorize.py", line 346, in wrapped
result = vectorized_func(*squeezed_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pearu/git/mattbahr/jax/jax/_src/numpy/vectorize.py", line 144, in wrapped
out = func(*args)
^^^^^^^^^^^
File "/home/pearu/git/mattbahr/jax/jax/_src/scipy/special.py", line 2960, in hyp2f1
s ** d * _hyp2f1_terminal_or_serie(c - a, c - b, c, x),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pearu/git/mattbahr/jax/jax/_src/scipy/special.py", line 2785, in _hyp2f1_terminal_or_serie
s ** (-a) * _hyp2f1_serie(a, c - b, c, -x / s),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pearu/git/mattbahr/jax/jax/_src/scipy/special.py", line 2749, in _hyp2f1_serie
return lax.while_loop(cond, body, init)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: while_loop body function carry input and carry output must have equal types, but they differ:
The input carry component state[0] has type float64[] but the corresponding output carry component has type float32[], so the dtypes do not match.
Revise the function so that all output types match the corresponding input types.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
47a16ec
to
28cbf8d
Compare
Hi @pearu thanks for the review! It looks like I needed to be a little stricter with the typing around the loops. The issue you found with x64 enabled should be resolved now! |
28cbf8d
to
aa1a521
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing the float32/float64 conflict!
Here follows a review of the implementation. I have number of nits and a suggestion to improve the accuracy of hyp2f1 to the level of scipy hyp2f1 accuracy.
@pearu Thanks again for the review! I'll try to get these suggestions committed today |
aa1a521
to
d826395
Compare
d826395
to
3e94f1b
Compare
Updated the loop for the terminal case as discussed, added specific test cases for each code branch, and simplified index selection logic as there were some branches should have never hit given that we only support positive inputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, this PR looks good. Thanks, @mattbahr!
I have a LOC reducing nit and a concern how integer values are detected from floating point values with possible rounding errors.
Sounds good! I'll get these changes in later today. I appreciate the thorough review! |
3e94f1b
to
0ee3b35
Compare
0ee3b35
to
cec2661
Compare
cec2661
to
eb35a36
Compare
@pearu Scaled |
@pearu Will this PR be requiring another review from a maintainer? I am still kind of new to JAX and not totally sure of the process. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @mattbahr!
@jakevdp could you help landing this? |
Addresses #2991
Implementation of the Gauss hypergeometric 2F1 function to match scipy.special.hyp2f1.
Key differences are that some precision is lost on the recurrence relationships, compared with scipy, as recursion is not possible in the jitted function. Precision also takes a hit because 64-bit floating point values are disabled by default.
To simplify the selection logic, and to keep things consistent with jax.scipy.special.hyp1f1, I decided to only include support for positive input values. If it is desired, I can add support for negative inputs easily enough. The sub-functions can technically already handle them.
Logic should be consistent with 1407.7786.