Skip to content

Implement Gauss hypergeometric function 2F1 #28168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 254 additions & 0 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -2637,6 +2637,260 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
)


def _hyp2f1_terminal(a, b, c, x):
"""
The Taylor series representation of the 2F1 hypergeometric function
terminates when either a or b is a non-positive integer. See Eq. 4.1 and
Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014
https://doi.org/10.48550/arXiv.1407.7786
"""
# Ensure that between a and b, the negative integer parameter with the greater
# absolute value - that still has a magnitude less than the absolute value of
# c if c is non-positive - is used for the upper limit in the loop.
eps = jnp.finfo(x.dtype).eps * 50
ib = jnp.round(b)
mask = jnp.logical_and(
b < a,
jnp.logical_and(
jnp.abs(b - ib) < eps,
jnp.logical_not(
jnp.logical_and(
c % 1 == 0,
jnp.logical_and(
c <= 0,
c > b
)
)
)
)
)
orig_a = a
a = jnp.where(mask, b, a)
b = jnp.where(mask, orig_a, b)

a = jnp.abs(a)

def body(i, state):
serie, term = state

term *= -(a - i + 1) / (c + i - 1) * (b + i - 1) / i * x
serie += term

return serie, term

init = (jnp.array(1, dtype=x.dtype), jnp.array(1, dtype=x.dtype))

return lax.fori_loop(jnp.array(1, dtype=a.dtype),
a + 1,
body,
init)[0]


def _hyp2f1_serie(a, b, c, x):
"""
Compute the 2F1 hypergeometric function using the Taylor expansion.
See Eq. 4.1 from PEARSON, OLVER & PORTER 2014
https://doi.org/10.48550/arXiv.1407.7786
"""
rtol = jnp.finfo(x.dtype).eps

def body(state):
serie, k, term = state

serie += term
term *= (a + k - 1) * (b + k - 1) / (c + k - 1) / k * x
k += 1

return serie, k, term

def cond(state):
serie, k, term = state

return (k < 250) & (lax.abs(term) > rtol * lax.abs(serie))

init = (jnp.array(0, dtype=x.dtype),
jnp.array(1, dtype=x.dtype),
jnp.array(1, dtype=x.dtype))

return lax.while_loop(cond, body, init)[0]


def _hyp2f1_terminal_or_serie(a, b, c, x):
"""
Check for recurrence relations along with whether or not the series
terminates. True recursion is not possible; however, the recurrence
relation may still be approximated.
See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014
https://doi.org/10.48550/arXiv.1407.7786
"""
eps = jnp.finfo(x.dtype).eps * 50

d = c - a - b

ia = jnp.round(a)
ib = jnp.round(b)
id = jnp.round(d)

neg_int_a = jnp.logical_and(a <= 0, jnp.abs(a - ia) < eps)
neg_int_b = jnp.logical_and(b <= 0, jnp.abs(b - ib) < eps)
neg_int_a_or_b = jnp.logical_or(neg_int_a, neg_int_b)
not_neg_int_a_or_b = jnp.logical_not(neg_int_a_or_b)

index = jnp.where(jnp.logical_and(x > 0.9, not_neg_int_a_or_b),
jnp.where(jnp.abs(d - id) >= eps, 0, 1),
jnp.where(neg_int_a_or_b, 2, 0))

return lax.select_n(index,
_hyp2f1_serie(a, b, c, x),
_hyp2f1_digamma_transform(a, b, c, x),
_hyp2f1_terminal(a, b, c, x))


def _hyp2f1_digamma_transform(a, b, c, x):
"""
Digamma transformation of the 2F1 hypergeometric function.
See AMS55 #15.3.10, #15.3.11, #15.3.12
"""
rtol = jnp.finfo(x.dtype).eps

d = c - a - b
s = 1 - x
rd = jnp.round(d)

e = jnp.where(rd >= 0, d, -d)
d1 = jnp.where(rd >= 0, d, jnp.array(0, dtype=d.dtype))
d2 = jnp.where(rd >= 0, jnp.array(0, dtype=d.dtype), d)
ard = jnp.where(rd >= 0, rd, -rd).astype('int32')

ax = jnp.log(s)

y = digamma(1.0) + digamma(1.0 + e) - digamma(a + d1) - digamma(b + d1) - ax
y /= gamma(e + 1.0)

p = (a + d1) * (b + d1) * s / gamma(e + 2.0)

def cond(state):
_, _, _, _, _, _, q, _, _, t, y = state

return jnp.logical_and(
t < 250,
jnp.abs(q) >= rtol * jnp.abs(y)
)

def body(state):
a, ax, b, d1, e, p, q, r, s, t, y = state

r = digamma(1.0 + t) + digamma(1.0 + t + e) - digamma(a + t + d1) \
- digamma(b + t + d1) - ax
q = p * r
y += q
p *= s * (a + t + d1) / (t + 1.0)
p *= (b + t + d1) / (t + 1.0 + e)
t += 1.0

return a, ax, b, d1, e, p, q, r, s, t, y

init = (a, ax, b, d1, e, p, y, jnp.array(0, dtype=x.dtype), s,
jnp.array(1, dtype=x.dtype), y)
_, _, _, _, _, _, q, r, _, _, y = lax.while_loop(cond, body, init)

def compute_sum(y):
y1 = jnp.array(1, dtype=x.dtype)
t = jnp.array(0, dtype=x.dtype)
p = jnp.array(1, dtype=x.dtype)

def for_body(i, state):
a, b, d2, e, p, s, t, y1 = state

r = 1.0 - e + t
p *= s * (a + t + d2) * (b + t + d2) / r
t += 1.0
p /= t
y1 += p

return a, b, d2, e, p, s, t, y1

init_val = a, b, d2, e, p, s, t, y1
y1 = lax.fori_loop(1, ard, for_body, init_val)[-1]

p = gamma(c)
y1 *= gamma(e) * p / (gamma(a + d1) * gamma(b + d1))
y *= p / (gamma(a + d2) * gamma(b + d2))

y = jnp.where((ard & 1) != 0, -y, y)
q = s ** rd

return jnp.where(rd > 0, y * q + y1, y + y1 * q)

return jnp.where(
rd == 0,
y * gamma(c) / (gamma(a) * gamma(b)),
compute_sum(y)
)


@jit
@jnp.vectorize
def hyp2f1(a: ArrayLike, b: ArrayLike, c: ArrayLike, x: ArrayLike) -> Array:
r"""The 2F1 hypergeometric function.

JAX implementation of :obj:`scipy.special.hyp2f1`.

.. math::

\mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!}

where :math:`(\cdot)_k` is the Pochammer symbol.

The JAX version only accepts positive and real inputs. Values of
``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may
lead to erroneous results; consider enabling double precision in this case.

Args:
a: arraylike, real-valued
b: arraylike, real-valued
c: arraylike, real-valued
x: arraylike, real-valued

Returns:
array of 2F1 values.
"""
# This is backed by https://doi.org/10.48550/arXiv.1407.7786
a, b, c, x = promote_args_inexact('hyp2f1', a, b, c, x)
eps = jnp.finfo(x.dtype).eps * 50

d = c - a - b
s = 1 - x
ca = c - a
cb = c - b

id = jnp.round(d)
ica = jnp.round(ca)
icb = jnp.round(cb)

neg_int_ca = jnp.logical_and(ca <= 0, jnp.abs(ca - ica) < eps)
neg_int_cb = jnp.logical_and(cb <= 0, jnp.abs(cb - icb) < eps)
neg_int_ca_or_cb = jnp.logical_or(neg_int_ca, neg_int_cb)

index = jnp.where(jnp.logical_or(x == 0, jnp.logical_and(jnp.logical_or(a == 0, b == 0), c != 0)), 0,
jnp.where(jnp.logical_or(c == 0, jnp.logical_and(c < 0, c % 1 == 0)), 1,
jnp.where(jnp.logical_and(d <= -1, jnp.logical_not(jnp.logical_and(jnp.abs(d - id) >= eps, s < 0))), 2,
jnp.where(jnp.logical_and(d <= 0, x == 1), 1,
jnp.where(jnp.logical_and(x < 1, b == c), 3,
jnp.where(jnp.logical_and(x < 1, a == c), 4,
jnp.where(x > 1, 1,
jnp.where(x == 1, 5, 6))))))))

return lax.select_n(index,
jnp.array(1, dtype=x.dtype),
jnp.array(jnp.inf, dtype=x.dtype),
s ** d * _hyp2f1_terminal_or_serie(ca, cb, c, x),
s ** (-a),
s ** (-b),
gamma(c) * gamma(d) / (gamma(ca) * gamma(cb)),
_hyp2f1_terminal_or_serie(a, b, c, x))


def softmax(x: ArrayLike,
/,
*,
Expand Down
1 change: 1 addition & 0 deletions jax/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
gammaln as gammaln,
gammasgn as gammasgn,
hyp1f1 as hyp1f1,
hyp2f1 as hyp2f1,
i0 as i0,
i0e as i0e,
i1 as i1,
Expand Down
17 changes: 17 additions & 0 deletions tests/lax_scipy_special_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
"hyp1f1", 3, float_dtypes,
functools.partial(jtu.rand_uniform, low=0.5, high=30), True
),
op_record(
"hyp2f1", 4, float_dtypes,
functools.partial(jtu.rand_uniform, low=0.5, high=30), False
),
op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True),
op_record("softmax", 1, float_dtypes, jtu.rand_default, True),
]
Expand Down Expand Up @@ -354,5 +358,18 @@ def testBetaIncBoundaryValues(self):
self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol)

def testHyp2f1SpecialCases(self):
dtype = jax.dtypes.canonicalize_dtype(float)

a_samples = np.array([0, 1, 1, 1, 1, 5, 5, 0.245, 0.45, 0.45, 2, 0.4, 0.32, 4, 4], dtype=dtype)
b_samples = np.array([1, 0, 1, 1, 1, 1, 1, 3, 0.7, 0.7, 1, 0.7, 0.76, 2, 3], dtype=dtype)
c_samples = np.array([1, 1, 0, 1, -1, 3, 3, 3, 0.45, 0.45, 5, 0.3, 0.11, 7, 7], dtype=dtype)
x_samples = np.array([1, 1, 1, 0, 1, 0.5, 1, 0.35, 0.35, 1.5, 1, 0.4, 0.95, 0.95, 0.95], dtype=dtype)

args_maker = lambda: (a_samples, b_samples, c_samples, x_samples)
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5
self._CheckAgainstNumpy(osp_special.hyp2f1, lsp_special.hyp2f1, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.hyp2f1, args_maker, rtol=rtol)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())