Skip to content

Commit eb35a36

Browse files
committed
implement hyp2f1
1 parent f346fd0 commit eb35a36

File tree

3 files changed

+272
-0
lines changed

3 files changed

+272
-0
lines changed

jax/_src/scipy/special.py

+254
Original file line numberDiff line numberDiff line change
@@ -2637,6 +2637,260 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
26372637
)
26382638

26392639

2640+
def _hyp2f1_terminal(a, b, c, x):
2641+
"""
2642+
The Taylor series representation of the 2F1 hypergeometric function
2643+
terminates when either a or b is a non-positive integer. See Eq. 4.1 and
2644+
Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014
2645+
https://doi.org/10.48550/arXiv.1407.7786
2646+
"""
2647+
# Ensure that between a and b, the negative integer parameter with the greater
2648+
# absolute value - that still has a magnitude less than the absolute value of
2649+
# c if c is non-positive - is used for the upper limit in the loop.
2650+
eps = jnp.finfo(x.dtype).eps * 50
2651+
ib = jnp.round(b)
2652+
mask = jnp.logical_and(
2653+
b < a,
2654+
jnp.logical_and(
2655+
jnp.abs(b - ib) < eps,
2656+
jnp.logical_not(
2657+
jnp.logical_and(
2658+
c % 1 == 0,
2659+
jnp.logical_and(
2660+
c <= 0,
2661+
c > b
2662+
)
2663+
)
2664+
)
2665+
)
2666+
)
2667+
orig_a = a
2668+
a = jnp.where(mask, b, a)
2669+
b = jnp.where(mask, orig_a, b)
2670+
2671+
a = jnp.abs(a)
2672+
2673+
def body(i, state):
2674+
serie, term = state
2675+
2676+
term *= -(a - i + 1) / (c + i - 1) * (b + i - 1) / i * x
2677+
serie += term
2678+
2679+
return serie, term
2680+
2681+
init = (jnp.array(1, dtype=x.dtype), jnp.array(1, dtype=x.dtype))
2682+
2683+
return lax.fori_loop(jnp.array(1, dtype=a.dtype),
2684+
a + 1,
2685+
body,
2686+
init)[0]
2687+
2688+
2689+
def _hyp2f1_serie(a, b, c, x):
2690+
"""
2691+
Compute the 2F1 hypergeometric function using the Taylor expansion.
2692+
See Eq. 4.1 from PEARSON, OLVER & PORTER 2014
2693+
https://doi.org/10.48550/arXiv.1407.7786
2694+
"""
2695+
rtol = jnp.finfo(x.dtype).eps
2696+
2697+
def body(state):
2698+
serie, k, term = state
2699+
2700+
serie += term
2701+
term *= (a + k - 1) * (b + k - 1) / (c + k - 1) / k * x
2702+
k += 1
2703+
2704+
return serie, k, term
2705+
2706+
def cond(state):
2707+
serie, k, term = state
2708+
2709+
return (k < 250) & (lax.abs(term) > rtol * lax.abs(serie))
2710+
2711+
init = (jnp.array(0, dtype=x.dtype),
2712+
jnp.array(1, dtype=x.dtype),
2713+
jnp.array(1, dtype=x.dtype))
2714+
2715+
return lax.while_loop(cond, body, init)[0]
2716+
2717+
2718+
def _hyp2f1_terminal_or_serie(a, b, c, x):
2719+
"""
2720+
Check for recurrence relations along with whether or not the series
2721+
terminates. True recursion is not possible; however, the recurrence
2722+
relation may still be approximated.
2723+
See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014
2724+
https://doi.org/10.48550/arXiv.1407.7786
2725+
"""
2726+
eps = jnp.finfo(x.dtype).eps * 50
2727+
2728+
d = c - a - b
2729+
2730+
ia = jnp.round(a)
2731+
ib = jnp.round(b)
2732+
id = jnp.round(d)
2733+
2734+
neg_int_a = jnp.logical_and(a <= 0, jnp.abs(a - ia) < eps)
2735+
neg_int_b = jnp.logical_and(b <= 0, jnp.abs(b - ib) < eps)
2736+
neg_int_a_or_b = jnp.logical_or(neg_int_a, neg_int_b)
2737+
not_neg_int_a_or_b = jnp.logical_not(neg_int_a_or_b)
2738+
2739+
index = jnp.where(jnp.logical_and(x > 0.9, not_neg_int_a_or_b),
2740+
jnp.where(jnp.abs(d - id) >= eps, 0, 1),
2741+
jnp.where(neg_int_a_or_b, 2, 0))
2742+
2743+
return lax.select_n(index,
2744+
_hyp2f1_serie(a, b, c, x),
2745+
_hyp2f1_digamma_transform(a, b, c, x),
2746+
_hyp2f1_terminal(a, b, c, x))
2747+
2748+
2749+
def _hyp2f1_digamma_transform(a, b, c, x):
2750+
"""
2751+
Digamma transformation of the 2F1 hypergeometric function.
2752+
See AMS55 #15.3.10, #15.3.11, #15.3.12
2753+
"""
2754+
rtol = jnp.finfo(x.dtype).eps
2755+
2756+
d = c - a - b
2757+
s = 1 - x
2758+
rd = jnp.round(d)
2759+
2760+
e = jnp.where(rd >= 0, d, -d)
2761+
d1 = jnp.where(rd >= 0, d, jnp.array(0, dtype=d.dtype))
2762+
d2 = jnp.where(rd >= 0, jnp.array(0, dtype=d.dtype), d)
2763+
ard = jnp.where(rd >= 0, rd, -rd).astype('int32')
2764+
2765+
ax = jnp.log(s)
2766+
2767+
y = digamma(1.0) + digamma(1.0 + e) - digamma(a + d1) - digamma(b + d1) - ax
2768+
y /= gamma(e + 1.0)
2769+
2770+
p = (a + d1) * (b + d1) * s / gamma(e + 2.0)
2771+
2772+
def cond(state):
2773+
_, _, _, _, _, _, q, _, _, t, y = state
2774+
2775+
return jnp.logical_and(
2776+
t < 250,
2777+
jnp.abs(q) >= rtol * jnp.abs(y)
2778+
)
2779+
2780+
def body(state):
2781+
a, ax, b, d1, e, p, q, r, s, t, y = state
2782+
2783+
r = digamma(1.0 + t) + digamma(1.0 + t + e) - digamma(a + t + d1) \
2784+
- digamma(b + t + d1) - ax
2785+
q = p * r
2786+
y += q
2787+
p *= s * (a + t + d1) / (t + 1.0)
2788+
p *= (b + t + d1) / (t + 1.0 + e)
2789+
t += 1.0
2790+
2791+
return a, ax, b, d1, e, p, q, r, s, t, y
2792+
2793+
init = (a, ax, b, d1, e, p, y, jnp.array(0, dtype=x.dtype), s,
2794+
jnp.array(1, dtype=x.dtype), y)
2795+
_, _, _, _, _, _, q, r, _, _, y = lax.while_loop(cond, body, init)
2796+
2797+
def compute_sum(y):
2798+
y1 = jnp.array(1, dtype=x.dtype)
2799+
t = jnp.array(0, dtype=x.dtype)
2800+
p = jnp.array(1, dtype=x.dtype)
2801+
2802+
def for_body(i, state):
2803+
a, b, d2, e, p, s, t, y1 = state
2804+
2805+
r = 1.0 - e + t
2806+
p *= s * (a + t + d2) * (b + t + d2) / r
2807+
t += 1.0
2808+
p /= t
2809+
y1 += p
2810+
2811+
return a, b, d2, e, p, s, t, y1
2812+
2813+
init_val = a, b, d2, e, p, s, t, y1
2814+
y1 = lax.fori_loop(1, ard, for_body, init_val)[-1]
2815+
2816+
p = gamma(c)
2817+
y1 *= gamma(e) * p / (gamma(a + d1) * gamma(b + d1))
2818+
y *= p / (gamma(a + d2) * gamma(b + d2))
2819+
2820+
y = jnp.where((ard & 1) != 0, -y, y)
2821+
q = s ** rd
2822+
2823+
return jnp.where(rd > 0, y * q + y1, y + y1 * q)
2824+
2825+
return jnp.where(
2826+
rd == 0,
2827+
y * gamma(c) / (gamma(a) * gamma(b)),
2828+
compute_sum(y)
2829+
)
2830+
2831+
2832+
@jit
2833+
@jnp.vectorize
2834+
def hyp2f1(a: ArrayLike, b: ArrayLike, c: ArrayLike, x: ArrayLike) -> Array:
2835+
r"""The 2F1 hypergeometric function.
2836+
2837+
JAX implementation of :obj:`scipy.special.hyp2f1`.
2838+
2839+
.. math::
2840+
2841+
\mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!}
2842+
2843+
where :math:`(\cdot)_k` is the Pochammer symbol.
2844+
2845+
The JAX version only accepts positive and real inputs. Values of
2846+
``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may
2847+
lead to erroneous results; consider enabling double precision in this case.
2848+
2849+
Args:
2850+
a: arraylike, real-valued
2851+
b: arraylike, real-valued
2852+
c: arraylike, real-valued
2853+
x: arraylike, real-valued
2854+
2855+
Returns:
2856+
array of 2F1 values.
2857+
"""
2858+
# This is backed by https://doi.org/10.48550/arXiv.1407.7786
2859+
a, b, c, x = promote_args_inexact('hyp2f1', a, b, c, x)
2860+
eps = jnp.finfo(x.dtype).eps * 50
2861+
2862+
d = c - a - b
2863+
s = 1 - x
2864+
ca = c - a
2865+
cb = c - b
2866+
2867+
id = jnp.round(d)
2868+
ica = jnp.round(ca)
2869+
icb = jnp.round(cb)
2870+
2871+
neg_int_ca = jnp.logical_and(ca <= 0, jnp.abs(ca - ica) < eps)
2872+
neg_int_cb = jnp.logical_and(cb <= 0, jnp.abs(cb - icb) < eps)
2873+
neg_int_ca_or_cb = jnp.logical_or(neg_int_ca, neg_int_cb)
2874+
2875+
index = jnp.where(jnp.logical_or(x == 0, jnp.logical_and(jnp.logical_or(a == 0, b == 0), c != 0)), 0,
2876+
jnp.where(jnp.logical_or(c == 0, jnp.logical_and(c < 0, c % 1 == 0)), 1,
2877+
jnp.where(jnp.logical_and(d <= -1, jnp.logical_not(jnp.logical_and(jnp.abs(d - id) >= eps, s < 0))), 2,
2878+
jnp.where(jnp.logical_and(d <= 0, x == 1), 1,
2879+
jnp.where(jnp.logical_and(x < 1, b == c), 3,
2880+
jnp.where(jnp.logical_and(x < 1, a == c), 4,
2881+
jnp.where(x > 1, 1,
2882+
jnp.where(x == 1, 5, 6))))))))
2883+
2884+
return lax.select_n(index,
2885+
jnp.array(1, dtype=x.dtype),
2886+
jnp.array(jnp.inf, dtype=x.dtype),
2887+
s ** d * _hyp2f1_terminal_or_serie(ca, cb, c, x),
2888+
s ** (-a),
2889+
s ** (-b),
2890+
gamma(c) * gamma(d) / (gamma(ca) * gamma(cb)),
2891+
_hyp2f1_terminal_or_serie(a, b, c, x))
2892+
2893+
26402894
def softmax(x: ArrayLike,
26412895
/,
26422896
*,

jax/scipy/special.py

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
gammaln as gammaln,
3838
gammasgn as gammasgn,
3939
hyp1f1 as hyp1f1,
40+
hyp2f1 as hyp2f1,
4041
i0 as i0,
4142
i0e as i0e,
4243
i1 as i1,

tests/lax_scipy_special_functions_test.py

+17
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
157157
"hyp1f1", 3, float_dtypes,
158158
functools.partial(jtu.rand_uniform, low=0.5, high=30), True
159159
),
160+
op_record(
161+
"hyp2f1", 4, float_dtypes,
162+
functools.partial(jtu.rand_uniform, low=0.5, high=30), False
163+
),
160164
op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True),
161165
op_record("softmax", 1, float_dtypes, jtu.rand_default, True),
162166
]
@@ -354,5 +358,18 @@ def testBetaIncBoundaryValues(self):
354358
self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol)
355359
self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol)
356360

361+
def testHyp2f1SpecialCases(self):
362+
dtype = jax.dtypes.canonicalize_dtype(float)
363+
364+
a_samples = np.array([0, 1, 1, 1, 1, 5, 5, 0.245, 0.45, 0.45, 2, 0.4, 0.32, 4, 4], dtype=dtype)
365+
b_samples = np.array([1, 0, 1, 1, 1, 1, 1, 3, 0.7, 0.7, 1, 0.7, 0.76, 2, 3], dtype=dtype)
366+
c_samples = np.array([1, 1, 0, 1, -1, 3, 3, 3, 0.45, 0.45, 5, 0.3, 0.11, 7, 7], dtype=dtype)
367+
x_samples = np.array([1, 1, 1, 0, 1, 0.5, 1, 0.35, 0.35, 1.5, 1, 0.4, 0.95, 0.95, 0.95], dtype=dtype)
368+
369+
args_maker = lambda: (a_samples, b_samples, c_samples, x_samples)
370+
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5
371+
self._CheckAgainstNumpy(osp_special.hyp2f1, lsp_special.hyp2f1, args_maker, rtol=rtol)
372+
self._CompileAndCheck(lsp_special.hyp2f1, args_maker, rtol=rtol)
373+
357374
if __name__ == "__main__":
358375
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)