Skip to content

Commit 28cbf8d

Browse files
committed
implement hyp2f1
1 parent f346fd0 commit 28cbf8d

File tree

3 files changed

+340
-0
lines changed

3 files changed

+340
-0
lines changed

jax/_src/scipy/special.py

+335
Original file line numberDiff line numberDiff line change
@@ -2637,6 +2637,341 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
26372637
)
26382638

26392639

2640+
def _binom(n, k):
2641+
a = lax.lgamma(n + 1.0)
2642+
b = lax.lgamma(n - k + 1.0)
2643+
c = lax.lgamma(k + 1.0)
2644+
2645+
return lax.exp(a - b - c)
2646+
2647+
2648+
def _poch(q, n):
2649+
"""
2650+
`jax.scipy.special.poch` does not allow for non-positive integer q.
2651+
"""
2652+
def body(i, state):
2653+
q, prod = state
2654+
2655+
prod *= q + i
2656+
2657+
return q, prod
2658+
2659+
return lax.cond(
2660+
n == 0,
2661+
lambda: jnp.array(1, dtype=q.dtype),
2662+
lambda: lax.fori_loop(jnp.array(1, dtype=n.dtype), n, body, (q, q))[1]
2663+
)
2664+
2665+
2666+
def _hyp2f1_terminal(a, b, c, x):
2667+
"""
2668+
The Taylor series representation of the 2F1 hypergeometric function
2669+
terminates when either a or b is a non-positive integer. See Eq. 4.1 and
2670+
Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014
2671+
https://doi.org/10.48550/arXiv.1407.7786
2672+
"""
2673+
# Ensure that between a and b, the negative integer parameter with the greater
2674+
# absolute value - that still has a magnitude less than the absolute value of
2675+
# c if c is non-positive - is used for the upper limit in the loop.
2676+
temp = a
2677+
a = jnp.where(
2678+
jnp.logical_and(
2679+
b < a,
2680+
jnp.logical_and(
2681+
b % 1 == 0,
2682+
jnp.logical_not(
2683+
jnp.logical_and(
2684+
c % 1 == 0,
2685+
jnp.logical_and(
2686+
c <= 0,
2687+
c > b
2688+
)
2689+
)
2690+
)
2691+
)
2692+
), b, a
2693+
)
2694+
b = jnp.where(
2695+
jnp.logical_and(
2696+
b < temp,
2697+
jnp.logical_and(
2698+
b % 1 == 0,
2699+
jnp.logical_not(
2700+
jnp.logical_and(
2701+
c % 1 == 0,
2702+
jnp.logical_and(
2703+
c <= 0,
2704+
c > b
2705+
)
2706+
)
2707+
)
2708+
)
2709+
), temp, b
2710+
)
2711+
2712+
def body(i, sum):
2713+
sum += (-1) ** i * _binom(jnp.abs(a), i) / _poch(c, i) * _poch(b, i) * x ** i
2714+
2715+
return sum
2716+
2717+
return lax.fori_loop(jnp.array(0, dtype=a.dtype),
2718+
jnp.abs(a) + 1,
2719+
body,
2720+
jnp.array(0, dtype=x.dtype))
2721+
2722+
2723+
def _hyp2f1_serie(a, b, c, x):
2724+
"""
2725+
Compute the 2F1 hypergeometric function using the Taylor expansion.
2726+
See Eq. 4.1 from PEARSON, OLVER & PORTER 2014
2727+
https://doi.org/10.48550/arXiv.1407.7786
2728+
"""
2729+
precision = jnp.finfo(jnp.float32).eps
2730+
2731+
s = 1 - x
2732+
2733+
neg_int_a = jnp.logical_and(a <= 0, a % 1 == 0)
2734+
neg_int_b = jnp.logical_and(b <= 0, b % 1 == 0)
2735+
neg_int_c = jnp.logical_and(c <= 0, c % 1 == 0)
2736+
2737+
def body(state):
2738+
serie, k, term = state
2739+
serie += term
2740+
term = _poch(a, k) / _poch(c, k) * _poch(b, k) / factorial(k) * x ** k
2741+
k += 1
2742+
2743+
return serie, k, term
2744+
2745+
def cond(state):
2746+
serie, k, term = state
2747+
2748+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision)
2749+
2750+
init = (jnp.array(0, dtype=x.dtype),
2751+
jnp.array(1, dtype=x.dtype),
2752+
jnp.array(1, dtype=x.dtype))
2753+
2754+
return lax.while_loop(cond, body, init)[0]
2755+
2756+
2757+
def _hyp2f1_terminal_or_serie(a, b, c, x):
2758+
"""
2759+
Check for recurrence relations along with whether or not the series
2760+
terminates. True recursion is not possible; however, the recurrence
2761+
relation may still be approximated.
2762+
See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014
2763+
https://doi.org/10.48550/arXiv.1407.7786
2764+
"""
2765+
neg_int_a = jnp.logical_and(a <= 0, a % 1 == 0)
2766+
neg_int_b = jnp.logical_and(b <= 0, b % 1 == 0)
2767+
neg_int_c = jnp.logical_and(c <= 0, c % 1 == 0)
2768+
neg_int_a_or_b = jnp.logical_or(neg_int_a, neg_int_b)
2769+
not_neg_int_a_or_b = jnp.logical_not(neg_int_a_or_b)
2770+
2771+
s = 1 - x
2772+
d = c - a - b
2773+
2774+
index = jnp.where(
2775+
jnp.logical_and(
2776+
neg_int_c,
2777+
jnp.logical_and(
2778+
jnp.logical_not(jnp.logical_and(neg_int_a, a > c)),
2779+
jnp.logical_not(jnp.logical_and(neg_int_b, b > c))
2780+
)
2781+
), 0,
2782+
jnp.where(jnp.logical_and(x < -0.5, not_neg_int_a_or_b),
2783+
jnp.where(b > a, 1, 2),
2784+
jnp.where(jnp.logical_and(x > 0.9, not_neg_int_a_or_b),
2785+
jnp.where(d % 1 != 0, 3, 4),
2786+
jnp.where(jnp.logical_and(jnp.logical_not(neg_int_c), neg_int_a_or_b), 5, 3))))
2787+
2788+
return lax.select_n(index,
2789+
jnp.array(jnp.inf, dtype=x.dtype),
2790+
s ** (-a) * _hyp2f1_serie(a, c - b, c, -x / s),
2791+
s ** (-b) * _hyp2f1_serie(c - a, b, c, -x / s),
2792+
_hyp2f1_serie(a, b, c, x),
2793+
_hyp2f1_digamma_transform(a, b, c, x),
2794+
_hyp2f1_terminal(a, b, c, x))
2795+
2796+
2797+
def _hyp2f1_gamma_transform(a, b, c, x):
2798+
"""
2799+
Gamma transformations of the 2F1 hypergeometric function.
2800+
"""
2801+
2802+
def transform_1():
2803+
"""
2804+
See Eq. 4.10 and Analytic Continuation Formulas from PEARSON, OLVER & PORTER 2014
2805+
https://doi.org/10.48550/arXiv.1407.7786
2806+
"""
2807+
p = _hyp2f1_serie(a, 1 - c + a, 1 - b + a, 1 / x)
2808+
q = _hyp2f1_serie(b, 1 - c + b, 1 - a + b, 1 / x)
2809+
p *= (-x) ** (-a)
2810+
q *= (-x) ** (-b)
2811+
t1 = gamma(c)
2812+
s = t1 * gamma(b - a) / (gamma(b) * gamma(c - a))
2813+
y = t1 * gamma(a - b) / (gamma(a) * gamma(c - b))
2814+
2815+
return s * p + y * q
2816+
2817+
def transform_2():
2818+
"""
2819+
See 4.1 Properties of F from PEARSON, OLVER & PORTER 2014
2820+
https://doi.org/10.48550/arXiv.1407.7786
2821+
"""
2822+
return gamma(c) * gamma(c - a - b) / (gamma(c - a) * gamma(c - b))
2823+
2824+
return jnp.where(
2825+
x < -2,
2826+
transform_1(),
2827+
transform_2()
2828+
)
2829+
2830+
2831+
def _hyp2f1_digamma_transform(a, b, c, x):
2832+
"""
2833+
Digamma transformation of the 2F1 hypergeometric function.
2834+
See AMS55 #15.3.10, #15.3.11, #15.3.12
2835+
"""
2836+
precision = jnp.finfo(jnp.float32).eps
2837+
2838+
d = c - a - b
2839+
s = 1 - x
2840+
id = jnp.round(d)
2841+
2842+
e = jnp.where(id >= 0, d, -d)
2843+
d1 = jnp.where(id >= 0, d, jnp.array(0, dtype=d.dtype))
2844+
d2 = jnp.where(id >= 0, jnp.array(0, dtype=d.dtype), d)
2845+
aid = jnp.where(id >= 0, id, -id).astype('int32')
2846+
2847+
ax = jnp.log(s)
2848+
2849+
y = digamma(1.0) + digamma(1.0 + e) - digamma(a + d1) - digamma(b + d1) - ax
2850+
y /= gamma(e + 1.0)
2851+
2852+
p = (a + d1) * (b + d1) * s / gamma(e + 2.0)
2853+
2854+
def cond(state):
2855+
_, _, _, _, _, _, q, _, _, t, y = state
2856+
2857+
return jnp.logical_and(
2858+
t < 250,
2859+
jnp.logical_or(y == 0, jnp.abs(q / y) > precision)
2860+
)
2861+
2862+
def body(state):
2863+
a, ax, b, d1, e, p, q, r, s, t, y = state
2864+
2865+
r = digamma(1.0 + t) + digamma(1.0 + t + e) - digamma(a + t + d1) \
2866+
- digamma(b + t + d1) - ax
2867+
q = p * r
2868+
y += q
2869+
p *= s * (a + t + d1) / (t + 1.0)
2870+
p *= (b + t + d1) / (t + 1.0 + e)
2871+
t += 1.0
2872+
2873+
return a, ax, b, d1, e, p, q, r, s, t, y
2874+
2875+
init = (a, ax, b, d1, e, p, y, jnp.array(0, dtype=x.dtype), s,
2876+
jnp.array(1, dtype=x.dtype), y)
2877+
_, _, _, _, _, _, q, r, _, _, y = lax.while_loop(cond, body, init)
2878+
2879+
def compute_sum(y):
2880+
y1 = jnp.array(1, dtype=x.dtype)
2881+
t = jnp.array(0, dtype=x.dtype)
2882+
p = jnp.array(1, dtype=x.dtype)
2883+
2884+
def for_body(i, state):
2885+
a, b, d2, e, p, s, t, y1 = state
2886+
2887+
r = 1.0 - e + t
2888+
p *= s * (a + t + d2) * (b + t + d2) / r
2889+
t += 1.0
2890+
p /= t
2891+
y1 += p
2892+
2893+
return a, b, d2, e, p, s, t, y1
2894+
2895+
init_val = a, b, d2, e, p, s, t, y1
2896+
y1 = lax.fori_loop(1, aid, for_body, init_val)[-1]
2897+
2898+
p = gamma(c)
2899+
y1 *= gamma(e) * p / (gamma(a + d1) * gamma(b + d1))
2900+
y *= p / (gamma(a + d2) * gamma(b + d2))
2901+
2902+
y = jnp.where((aid & 1) != 0, -y, y)
2903+
q = s ** id
2904+
2905+
return jnp.where(id > 0, y * q + y1, y + y1 * q)
2906+
2907+
return jnp.where(
2908+
id == 0,
2909+
y * gamma(c) / (gamma(a) * gamma(b)),
2910+
compute_sum(y)
2911+
)
2912+
2913+
2914+
@jit
2915+
@jnp.vectorize
2916+
def hyp2f1(a: ArrayLike, b: ArrayLike, c: ArrayLike, x: ArrayLike) -> Array:
2917+
r"""The 2F1 hypergeometric function.
2918+
2919+
JAX implementation of :obj:`scipy.special.hyp2f1`.
2920+
2921+
.. math::
2922+
2923+
\mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!}
2924+
2925+
where :math:`(\cdot)_k` is the Pochammer symbol.
2926+
2927+
The JAX version only accepts positive and real inputs. Values of
2928+
``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may
2929+
lead to erroneous results; consider enabling double precision in this case.
2930+
2931+
Args:
2932+
a: arraylike, real-valued
2933+
b: arraylike, real-valued
2934+
c: arraylike, real-valued
2935+
x: arraylike, real-valued
2936+
2937+
Returns:
2938+
array of 2F1 values.
2939+
"""
2940+
# This is backed by https://doi.org/10.48550/arXiv.1407.7786
2941+
a, b, c, x = promote_args_inexact('hyp2f1', a, b, c, x)
2942+
2943+
d = c - a - b
2944+
s = 1 - x
2945+
2946+
neg_int_ca = jnp.logical_and(c - a <= 0, (c - a) % 1 == 0)
2947+
neg_int_cb = jnp.logical_and(c - b <= 0, (c - b) % 1 == 0)
2948+
neg_int_ca_or_cb = jnp.logical_or(neg_int_ca, neg_int_cb)
2949+
2950+
index = jnp.where(jnp.logical_or(x == 0, jnp.logical_and(jnp.logical_or(a == 0, b == 0), c != 0)), 0,
2951+
jnp.where(c == 0, 2,
2952+
jnp.where(jnp.logical_and(d <= -1, jnp.logical_not(jnp.logical_and(d % 1 != 0, s < 0))), 1,
2953+
jnp.where(jnp.logical_and(d <= 0, x == 1), 2,
2954+
jnp.where(jnp.logical_and(x < 1, b == c), 3,
2955+
jnp.where(jnp.logical_and(x < 1, a == c), 4,
2956+
jnp.where(x > 1, 2,
2957+
jnp.where(x == 1,
2958+
jnp.where(neg_int_ca_or_cb,
2959+
jnp.where(d >= 0, 5, 2),
2960+
jnp.where(d <= 0, 2, 6)),
2961+
jnp.where(d < 0, 7,
2962+
jnp.where(neg_int_ca_or_cb, 5, 7))))))))))
2963+
2964+
return lax.select_n(index,
2965+
jnp.array(1, dtype=x.dtype),
2966+
s ** d * _hyp2f1_terminal_or_serie(c - a, c - b, c, x),
2967+
jnp.array(jnp.inf, dtype=x.dtype),
2968+
s ** (-a),
2969+
s ** (-b),
2970+
s ** d * _hyp2f1_serie(c - a, c - b, c, x),
2971+
_hyp2f1_gamma_transform(a, b, c, x),
2972+
_hyp2f1_terminal_or_serie(a, b, c, x))
2973+
2974+
26402975
def softmax(x: ArrayLike,
26412976
/,
26422977
*,

jax/scipy/special.py

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
gammaln as gammaln,
3838
gammasgn as gammasgn,
3939
hyp1f1 as hyp1f1,
40+
hyp2f1 as hyp2f1,
4041
i0 as i0,
4142
i0e as i0e,
4243
i1 as i1,

tests/lax_scipy_special_functions_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
157157
"hyp1f1", 3, float_dtypes,
158158
functools.partial(jtu.rand_uniform, low=0.5, high=30), True
159159
),
160+
op_record(
161+
"hyp2f1", 4, float_dtypes,
162+
functools.partial(jtu.rand_uniform, low=0.5, high=30), False
163+
),
160164
op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True),
161165
op_record("softmax", 1, float_dtypes, jtu.rand_default, True),
162166
]

0 commit comments

Comments
 (0)