Skip to content

Implement Gauss hypergeometric function 2F1 #28168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

mattbahr
Copy link
Contributor

Addresses #2991

Implementation of the Gauss hypergeometric 2F1 function to match scipy.special.hyp2f1.

Key differences are that some precision is lost on the recurrence relationships, compared with scipy, as recursion is not possible in the jitted function. Precision also takes a hit because 64-bit floating point values are disabled by default.

To simplify the selection logic, and to keep things consistent with jax.scipy.special.hyp1f1, I decided to only include support for positive input values. If it is desired, I can add support for negative inputs easily enough. The sub-functions can technically already handle them.

Logic should be consistent with 1407.7786.

Copy link
Collaborator

@pearu pearu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @mattbahr for the PR!

Here follows a review of the implemented feature.

I tested the hyp2f1 implementation against hyp2f1 in scipy and mpmath for

a, b, c = 2, 1, 4
0 <= x <= 1

while confirming that scipy and mpmath results match more or less well for float64:

float64:
x range          | dULP=0   | dULP=1   | dULP=2   | dULP=3..9   | dULP=10..1000   | dULP=1001..1000000   | dULP > 1000000  
-----------------+----------+----------+----------+-------------+-----------------+----------------------+-----------------
0..tiny          | 2        | -        | -        | -           | -               | -                    | -               
tiny..eps**2     | 10000    | -        | -        | -           | -               | -                    | -               
eps**2..eps      | 9999     | 1        | -        | -           | -               | -                    | -               
eps..sqrt(eps)   | 9911     | 89       | -        | -           | -               | -                    | -               
sqrt(eps)..1     | 5800     | 3494     | 442      | 251         | 13              | -                    | -               

(scipy hyp2f1 on float32 inputs appears to use float64 implementation as max dULP is 0).

Comparing JAX hyp2f1 vs mpath hyp2f1, I noticed that there exists inputs that lead to large errors when x > sqrt(eps):

float64:
x range          | dULP=0   | dULP=1   | dULP=2   | dULP=3..9   | dULP=10..1000   | dULP=1001..1000000   | dULP > 1000000  
-----------------+----------+----------+----------+-------------+-----------------+----------------------+-----------------
0..tiny          | 2        | -        | -        | -           | -               | -                    | -               
tiny..eps**2     | 10000    | -        | -        | -           | -               | -                    | -               
eps**2..eps      | 9999     | 1        | -        | -           | -               | -                    | -               
eps..sqrt(eps)   | -        | 577      | 289      | 745         | 2602            | 3828                 | 1959            
sqrt(eps)..1     | -        | -        | -        | -           | 696             | 2234                 | 7070            

float32:
x range          | dULP=0   | dULP=1   | dULP=2   | dULP=3..9   | dULP=10..1000   | dULP=1001..1000000   | dULP > 1000000  
-----------------+----------+----------+----------+-------------+-----------------+----------------------+-----------------
0..tiny          | 2        | -        | -        | -           | -               | -                    | -               
tiny..eps**2     | 10000    | -        | -        | -           | -               | -                    | -               
eps**2..eps      | 9999     | 1        | -        | -           | -               | -                    | -               
eps..sqrt(eps)   | 8942     | 1058     | -        | -           | -               | -                    | -               
sqrt(eps)..1     | 4870     | 3664     | 546      | 464         | 380             | 76                   | -               

The plots of the differences between JAX and reference (mpmath) implementations confirm that the largest error is around x = 0.9:
jax_hyp2f1
jax_hyp2f1

Can you explain these discrepancies between JAX and scipy implementations of hyp2f1?

Also notice that when XF64 is enabled, calling JAX hyp2f1 with float32 inputs will fail:

>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update("jax_enable_x64", True)
>>> jax.scipy.special.hyp2f1(2, 1, 4, jnp.float64(0.9))
Array(2.17894023, dtype=float64)
>>> jax.scipy.special.hyp2f1(2, 1, 4, jnp.float32(0.9))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/pearu/git/mattbahr/jax/jax/_src/numpy/vectorize.py", line 346, in wrapped
    result = vectorized_func(*squeezed_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pearu/git/mattbahr/jax/jax/_src/numpy/vectorize.py", line 144, in wrapped
    out = func(*args)
          ^^^^^^^^^^^
  File "/home/pearu/git/mattbahr/jax/jax/_src/scipy/special.py", line 2960, in hyp2f1
    s ** d * _hyp2f1_terminal_or_serie(c - a, c - b, c, x),
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pearu/git/mattbahr/jax/jax/_src/scipy/special.py", line 2785, in _hyp2f1_terminal_or_serie
    s ** (-a) * _hyp2f1_serie(a, c - b, c, -x / s),
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pearu/git/mattbahr/jax/jax/_src/scipy/special.py", line 2749, in _hyp2f1_serie
    return lax.while_loop(cond, body, init)[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: while_loop body function carry input and carry output must have equal types, but they differ:

The input carry component state[0] has type float64[] but the corresponding output carry component has type float32[], so the dtypes do not match.

Revise the function so that all output types match the corresponding input types.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

@mattbahr
Copy link
Contributor Author

Hi @pearu thanks for the review! It looks like I needed to be a little stricter with the typing around the loops. The issue you found with x64 enabled should be resolved now!

Copy link
Collaborator

@pearu pearu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing the float32/float64 conflict!

Here follows a review of the implementation. I have number of nits and a suggestion to improve the accuracy of hyp2f1 to the level of scipy hyp2f1 accuracy.

@mattbahr
Copy link
Contributor Author

@pearu Thanks again for the review! I'll try to get these suggestions committed today

@mattbahr
Copy link
Contributor Author

Updated the loop for the terminal case as discussed, added specific test cases for each code branch, and simplified index selection logic as there were some branches should have never hit given that we only support positive inputs

@mattbahr mattbahr requested a review from pearu April 24, 2025 03:47
Copy link
Collaborator

@pearu pearu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, this PR looks good. Thanks, @mattbahr!

I have a LOC reducing nit and a concern how integer values are detected from floating point values with possible rounding errors.

@mattbahr
Copy link
Contributor Author

Overall, this PR looks good. Thanks, @mattbahr!

I have a LOC reducing nit and a concern how integer values are detected from floating point values with possible rounding errors.

Sounds good! I'll get these changes in later today. I appreciate the thorough review!

@mattbahr
Copy link
Contributor Author

@pearu Scaled eps by a factor of 50

@mattbahr
Copy link
Contributor Author

@pearu Will this PR be requiring another review from a maintainer? I am still kind of new to JAX and not totally sure of the process.

Copy link
Collaborator

@pearu pearu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @mattbahr!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 26, 2025
@pearu
Copy link
Collaborator

pearu commented Apr 26, 2025

@jakevdp could you help landing this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants