diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index a24736ccfec0..d5c99b6e3b4f 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -2637,6 +2637,260 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: ) +def _hyp2f1_terminal(a, b, c, x): + """ + The Taylor series representation of the 2F1 hypergeometric function + terminates when either a or b is a non-positive integer. See Eq. 4.1 and + Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + # Ensure that between a and b, the negative integer parameter with the greater + # absolute value - that still has a magnitude less than the absolute value of + # c if c is non-positive - is used for the upper limit in the loop. + eps = jnp.finfo(x.dtype).eps * 50 + ib = jnp.round(b) + mask = jnp.logical_and( + b < a, + jnp.logical_and( + jnp.abs(b - ib) < eps, + jnp.logical_not( + jnp.logical_and( + c % 1 == 0, + jnp.logical_and( + c <= 0, + c > b + ) + ) + ) + ) + ) + orig_a = a + a = jnp.where(mask, b, a) + b = jnp.where(mask, orig_a, b) + + a = jnp.abs(a) + + def body(i, state): + serie, term = state + + term *= -(a - i + 1) / (c + i - 1) * (b + i - 1) / i * x + serie += term + + return serie, term + + init = (jnp.array(1, dtype=x.dtype), jnp.array(1, dtype=x.dtype)) + + return lax.fori_loop(jnp.array(1, dtype=a.dtype), + a + 1, + body, + init)[0] + + +def _hyp2f1_serie(a, b, c, x): + """ + Compute the 2F1 hypergeometric function using the Taylor expansion. + See Eq. 4.1 from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + rtol = jnp.finfo(x.dtype).eps + + def body(state): + serie, k, term = state + + serie += term + term *= (a + k - 1) * (b + k - 1) / (c + k - 1) / k * x + k += 1 + + return serie, k, term + + def cond(state): + serie, k, term = state + + return (k < 250) & (lax.abs(term) > rtol * lax.abs(serie)) + + init = (jnp.array(0, dtype=x.dtype), + jnp.array(1, dtype=x.dtype), + jnp.array(1, dtype=x.dtype)) + + return lax.while_loop(cond, body, init)[0] + + +def _hyp2f1_terminal_or_serie(a, b, c, x): + """ + Check for recurrence relations along with whether or not the series + terminates. True recursion is not possible; however, the recurrence + relation may still be approximated. + See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + eps = jnp.finfo(x.dtype).eps * 50 + + d = c - a - b + + ia = jnp.round(a) + ib = jnp.round(b) + id = jnp.round(d) + + neg_int_a = jnp.logical_and(a <= 0, jnp.abs(a - ia) < eps) + neg_int_b = jnp.logical_and(b <= 0, jnp.abs(b - ib) < eps) + neg_int_a_or_b = jnp.logical_or(neg_int_a, neg_int_b) + not_neg_int_a_or_b = jnp.logical_not(neg_int_a_or_b) + + index = jnp.where(jnp.logical_and(x > 0.9, not_neg_int_a_or_b), + jnp.where(jnp.abs(d - id) >= eps, 0, 1), + jnp.where(neg_int_a_or_b, 2, 0)) + + return lax.select_n(index, + _hyp2f1_serie(a, b, c, x), + _hyp2f1_digamma_transform(a, b, c, x), + _hyp2f1_terminal(a, b, c, x)) + + +def _hyp2f1_digamma_transform(a, b, c, x): + """ + Digamma transformation of the 2F1 hypergeometric function. + See AMS55 #15.3.10, #15.3.11, #15.3.12 + """ + rtol = jnp.finfo(x.dtype).eps + + d = c - a - b + s = 1 - x + rd = jnp.round(d) + + e = jnp.where(rd >= 0, d, -d) + d1 = jnp.where(rd >= 0, d, jnp.array(0, dtype=d.dtype)) + d2 = jnp.where(rd >= 0, jnp.array(0, dtype=d.dtype), d) + ard = jnp.where(rd >= 0, rd, -rd).astype('int32') + + ax = jnp.log(s) + + y = digamma(1.0) + digamma(1.0 + e) - digamma(a + d1) - digamma(b + d1) - ax + y /= gamma(e + 1.0) + + p = (a + d1) * (b + d1) * s / gamma(e + 2.0) + + def cond(state): + _, _, _, _, _, _, q, _, _, t, y = state + + return jnp.logical_and( + t < 250, + jnp.abs(q) >= rtol * jnp.abs(y) + ) + + def body(state): + a, ax, b, d1, e, p, q, r, s, t, y = state + + r = digamma(1.0 + t) + digamma(1.0 + t + e) - digamma(a + t + d1) \ + - digamma(b + t + d1) - ax + q = p * r + y += q + p *= s * (a + t + d1) / (t + 1.0) + p *= (b + t + d1) / (t + 1.0 + e) + t += 1.0 + + return a, ax, b, d1, e, p, q, r, s, t, y + + init = (a, ax, b, d1, e, p, y, jnp.array(0, dtype=x.dtype), s, + jnp.array(1, dtype=x.dtype), y) + _, _, _, _, _, _, q, r, _, _, y = lax.while_loop(cond, body, init) + + def compute_sum(y): + y1 = jnp.array(1, dtype=x.dtype) + t = jnp.array(0, dtype=x.dtype) + p = jnp.array(1, dtype=x.dtype) + + def for_body(i, state): + a, b, d2, e, p, s, t, y1 = state + + r = 1.0 - e + t + p *= s * (a + t + d2) * (b + t + d2) / r + t += 1.0 + p /= t + y1 += p + + return a, b, d2, e, p, s, t, y1 + + init_val = a, b, d2, e, p, s, t, y1 + y1 = lax.fori_loop(1, ard, for_body, init_val)[-1] + + p = gamma(c) + y1 *= gamma(e) * p / (gamma(a + d1) * gamma(b + d1)) + y *= p / (gamma(a + d2) * gamma(b + d2)) + + y = jnp.where((ard & 1) != 0, -y, y) + q = s ** rd + + return jnp.where(rd > 0, y * q + y1, y + y1 * q) + + return jnp.where( + rd == 0, + y * gamma(c) / (gamma(a) * gamma(b)), + compute_sum(y) + ) + + +@jit +@jnp.vectorize +def hyp2f1(a: ArrayLike, b: ArrayLike, c: ArrayLike, x: ArrayLike) -> Array: + r"""The 2F1 hypergeometric function. + + JAX implementation of :obj:`scipy.special.hyp2f1`. + + .. math:: + + \mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!} + + where :math:`(\cdot)_k` is the Pochammer symbol. + + The JAX version only accepts positive and real inputs. Values of + ``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may + lead to erroneous results; consider enabling double precision in this case. + + Args: + a: arraylike, real-valued + b: arraylike, real-valued + c: arraylike, real-valued + x: arraylike, real-valued + + Returns: + array of 2F1 values. + """ + # This is backed by https://doi.org/10.48550/arXiv.1407.7786 + a, b, c, x = promote_args_inexact('hyp2f1', a, b, c, x) + eps = jnp.finfo(x.dtype).eps * 50 + + d = c - a - b + s = 1 - x + ca = c - a + cb = c - b + + id = jnp.round(d) + ica = jnp.round(ca) + icb = jnp.round(cb) + + neg_int_ca = jnp.logical_and(ca <= 0, jnp.abs(ca - ica) < eps) + neg_int_cb = jnp.logical_and(cb <= 0, jnp.abs(cb - icb) < eps) + neg_int_ca_or_cb = jnp.logical_or(neg_int_ca, neg_int_cb) + + index = jnp.where(jnp.logical_or(x == 0, jnp.logical_and(jnp.logical_or(a == 0, b == 0), c != 0)), 0, + jnp.where(jnp.logical_or(c == 0, jnp.logical_and(c < 0, c % 1 == 0)), 1, + jnp.where(jnp.logical_and(d <= -1, jnp.logical_not(jnp.logical_and(jnp.abs(d - id) >= eps, s < 0))), 2, + jnp.where(jnp.logical_and(d <= 0, x == 1), 1, + jnp.where(jnp.logical_and(x < 1, b == c), 3, + jnp.where(jnp.logical_and(x < 1, a == c), 4, + jnp.where(x > 1, 1, + jnp.where(x == 1, 5, 6)))))))) + + return lax.select_n(index, + jnp.array(1, dtype=x.dtype), + jnp.array(jnp.inf, dtype=x.dtype), + s ** d * _hyp2f1_terminal_or_serie(ca, cb, c, x), + s ** (-a), + s ** (-b), + gamma(c) * gamma(d) / (gamma(ca) * gamma(cb)), + _hyp2f1_terminal_or_serie(a, b, c, x)) + + def softmax(x: ArrayLike, /, *, diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 2ffc65a1abe1..e1330d4b6cf3 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -37,6 +37,7 @@ gammaln as gammaln, gammasgn as gammasgn, hyp1f1 as hyp1f1, + hyp2f1 as hyp2f1, i0 as i0, i0e as i0e, i1 as i1, diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 4b3945a84453..e02581626a53 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -157,6 +157,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t "hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True ), + op_record( + "hyp2f1", 4, float_dtypes, + functools.partial(jtu.rand_uniform, low=0.5, high=30), False + ), op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True), op_record("softmax", 1, float_dtypes, jtu.rand_default, True), ] @@ -354,5 +358,18 @@ def testBetaIncBoundaryValues(self): self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol) + def testHyp2f1SpecialCases(self): + dtype = jax.dtypes.canonicalize_dtype(float) + + a_samples = np.array([0, 1, 1, 1, 1, 5, 5, 0.245, 0.45, 0.45, 2, 0.4, 0.32, 4, 4], dtype=dtype) + b_samples = np.array([1, 0, 1, 1, 1, 1, 1, 3, 0.7, 0.7, 1, 0.7, 0.76, 2, 3], dtype=dtype) + c_samples = np.array([1, 1, 0, 1, -1, 3, 3, 3, 0.45, 0.45, 5, 0.3, 0.11, 7, 7], dtype=dtype) + x_samples = np.array([1, 1, 1, 0, 1, 0.5, 1, 0.35, 0.35, 1.5, 1, 0.4, 0.95, 0.95, 0.95], dtype=dtype) + + args_maker = lambda: (a_samples, b_samples, c_samples, x_samples) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5 + self._CheckAgainstNumpy(osp_special.hyp2f1, lsp_special.hyp2f1, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.hyp2f1, args_maker, rtol=rtol) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())