@@ -2637,6 +2637,260 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
2637
2637
)
2638
2638
2639
2639
2640
+ def _hyp2f1_terminal (a , b , c , x ):
2641
+ """
2642
+ The Taylor series representation of the 2F1 hypergeometric function
2643
+ terminates when either a or b is a non-positive integer. See Eq. 4.1 and
2644
+ Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014
2645
+ https://doi.org/10.48550/arXiv.1407.7786
2646
+ """
2647
+ # Ensure that between a and b, the negative integer parameter with the greater
2648
+ # absolute value - that still has a magnitude less than the absolute value of
2649
+ # c if c is non-positive - is used for the upper limit in the loop.
2650
+ eps = jnp .finfo (x .dtype ).eps * 50
2651
+ ib = jnp .round (b )
2652
+ mask = jnp .logical_and (
2653
+ b < a ,
2654
+ jnp .logical_and (
2655
+ jnp .abs (b - ib ) < eps ,
2656
+ jnp .logical_not (
2657
+ jnp .logical_and (
2658
+ c % 1 == 0 ,
2659
+ jnp .logical_and (
2660
+ c <= 0 ,
2661
+ c > b
2662
+ )
2663
+ )
2664
+ )
2665
+ )
2666
+ )
2667
+ orig_a = a
2668
+ a = jnp .where (mask , b , a )
2669
+ b = jnp .where (mask , orig_a , b )
2670
+
2671
+ a = jnp .abs (a )
2672
+
2673
+ def body (i , state ):
2674
+ serie , term = state
2675
+
2676
+ term *= - (a - i + 1 ) / (c + i - 1 ) * (b + i - 1 ) / i * x
2677
+ serie += term
2678
+
2679
+ return serie , term
2680
+
2681
+ init = (jnp .array (1 , dtype = x .dtype ), jnp .array (1 , dtype = x .dtype ))
2682
+
2683
+ return lax .fori_loop (jnp .array (1 , dtype = a .dtype ),
2684
+ a + 1 ,
2685
+ body ,
2686
+ init )[0 ]
2687
+
2688
+
2689
+ def _hyp2f1_serie (a , b , c , x ):
2690
+ """
2691
+ Compute the 2F1 hypergeometric function using the Taylor expansion.
2692
+ See Eq. 4.1 from PEARSON, OLVER & PORTER 2014
2693
+ https://doi.org/10.48550/arXiv.1407.7786
2694
+ """
2695
+ rtol = jnp .finfo (x .dtype ).eps
2696
+
2697
+ def body (state ):
2698
+ serie , k , term = state
2699
+
2700
+ serie += term
2701
+ term *= (a + k - 1 ) * (b + k - 1 ) / (c + k - 1 ) / k * x
2702
+ k += 1
2703
+
2704
+ return serie , k , term
2705
+
2706
+ def cond (state ):
2707
+ serie , k , term = state
2708
+
2709
+ return (k < 250 ) & (lax .abs (term ) > rtol * lax .abs (serie ))
2710
+
2711
+ init = (jnp .array (0 , dtype = x .dtype ),
2712
+ jnp .array (1 , dtype = x .dtype ),
2713
+ jnp .array (1 , dtype = x .dtype ))
2714
+
2715
+ return lax .while_loop (cond , body , init )[0 ]
2716
+
2717
+
2718
+ def _hyp2f1_terminal_or_serie (a , b , c , x ):
2719
+ """
2720
+ Check for recurrence relations along with whether or not the series
2721
+ terminates. True recursion is not possible; however, the recurrence
2722
+ relation may still be approximated.
2723
+ See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014
2724
+ https://doi.org/10.48550/arXiv.1407.7786
2725
+ """
2726
+ eps = jnp .finfo (x .dtype ).eps * 50
2727
+
2728
+ d = c - a - b
2729
+
2730
+ ia = jnp .round (a )
2731
+ ib = jnp .round (b )
2732
+ id = jnp .round (d )
2733
+
2734
+ neg_int_a = jnp .logical_and (a <= 0 , jnp .abs (a - ia ) < eps )
2735
+ neg_int_b = jnp .logical_and (b <= 0 , jnp .abs (b - ib ) < eps )
2736
+ neg_int_a_or_b = jnp .logical_or (neg_int_a , neg_int_b )
2737
+ not_neg_int_a_or_b = jnp .logical_not (neg_int_a_or_b )
2738
+
2739
+ index = jnp .where (jnp .logical_and (x > 0.9 , not_neg_int_a_or_b ),
2740
+ jnp .where (jnp .abs (d - id ) >= eps , 0 , 1 ),
2741
+ jnp .where (neg_int_a_or_b , 2 , 0 ))
2742
+
2743
+ return lax .select_n (index ,
2744
+ _hyp2f1_serie (a , b , c , x ),
2745
+ _hyp2f1_digamma_transform (a , b , c , x ),
2746
+ _hyp2f1_terminal (a , b , c , x ))
2747
+
2748
+
2749
+ def _hyp2f1_digamma_transform (a , b , c , x ):
2750
+ """
2751
+ Digamma transformation of the 2F1 hypergeometric function.
2752
+ See AMS55 #15.3.10, #15.3.11, #15.3.12
2753
+ """
2754
+ rtol = jnp .finfo (x .dtype ).eps
2755
+
2756
+ d = c - a - b
2757
+ s = 1 - x
2758
+ rd = jnp .round (d )
2759
+
2760
+ e = jnp .where (rd >= 0 , d , - d )
2761
+ d1 = jnp .where (rd >= 0 , d , jnp .array (0 , dtype = d .dtype ))
2762
+ d2 = jnp .where (rd >= 0 , jnp .array (0 , dtype = d .dtype ), d )
2763
+ ard = jnp .where (rd >= 0 , rd , - rd ).astype ('int32' )
2764
+
2765
+ ax = jnp .log (s )
2766
+
2767
+ y = digamma (1.0 ) + digamma (1.0 + e ) - digamma (a + d1 ) - digamma (b + d1 ) - ax
2768
+ y /= gamma (e + 1.0 )
2769
+
2770
+ p = (a + d1 ) * (b + d1 ) * s / gamma (e + 2.0 )
2771
+
2772
+ def cond (state ):
2773
+ _ , _ , _ , _ , _ , _ , q , _ , _ , t , y = state
2774
+
2775
+ return jnp .logical_and (
2776
+ t < 250 ,
2777
+ jnp .abs (q ) >= rtol * jnp .abs (y )
2778
+ )
2779
+
2780
+ def body (state ):
2781
+ a , ax , b , d1 , e , p , q , r , s , t , y = state
2782
+
2783
+ r = digamma (1.0 + t ) + digamma (1.0 + t + e ) - digamma (a + t + d1 ) \
2784
+ - digamma (b + t + d1 ) - ax
2785
+ q = p * r
2786
+ y += q
2787
+ p *= s * (a + t + d1 ) / (t + 1.0 )
2788
+ p *= (b + t + d1 ) / (t + 1.0 + e )
2789
+ t += 1.0
2790
+
2791
+ return a , ax , b , d1 , e , p , q , r , s , t , y
2792
+
2793
+ init = (a , ax , b , d1 , e , p , y , jnp .array (0 , dtype = x .dtype ), s ,
2794
+ jnp .array (1 , dtype = x .dtype ), y )
2795
+ _ , _ , _ , _ , _ , _ , q , r , _ , _ , y = lax .while_loop (cond , body , init )
2796
+
2797
+ def compute_sum (y ):
2798
+ y1 = jnp .array (1 , dtype = x .dtype )
2799
+ t = jnp .array (0 , dtype = x .dtype )
2800
+ p = jnp .array (1 , dtype = x .dtype )
2801
+
2802
+ def for_body (i , state ):
2803
+ a , b , d2 , e , p , s , t , y1 = state
2804
+
2805
+ r = 1.0 - e + t
2806
+ p *= s * (a + t + d2 ) * (b + t + d2 ) / r
2807
+ t += 1.0
2808
+ p /= t
2809
+ y1 += p
2810
+
2811
+ return a , b , d2 , e , p , s , t , y1
2812
+
2813
+ init_val = a , b , d2 , e , p , s , t , y1
2814
+ y1 = lax .fori_loop (1 , ard , for_body , init_val )[- 1 ]
2815
+
2816
+ p = gamma (c )
2817
+ y1 *= gamma (e ) * p / (gamma (a + d1 ) * gamma (b + d1 ))
2818
+ y *= p / (gamma (a + d2 ) * gamma (b + d2 ))
2819
+
2820
+ y = jnp .where ((ard & 1 ) != 0 , - y , y )
2821
+ q = s ** rd
2822
+
2823
+ return jnp .where (rd > 0 , y * q + y1 , y + y1 * q )
2824
+
2825
+ return jnp .where (
2826
+ rd == 0 ,
2827
+ y * gamma (c ) / (gamma (a ) * gamma (b )),
2828
+ compute_sum (y )
2829
+ )
2830
+
2831
+
2832
+ @jit
2833
+ @jnp .vectorize
2834
+ def hyp2f1 (a : ArrayLike , b : ArrayLike , c : ArrayLike , x : ArrayLike ) -> Array :
2835
+ r"""The 2F1 hypergeometric function.
2836
+
2837
+ JAX implementation of :obj:`scipy.special.hyp2f1`.
2838
+
2839
+ .. math::
2840
+
2841
+ \mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!}
2842
+
2843
+ where :math:`(\cdot)_k` is the Pochammer symbol.
2844
+
2845
+ The JAX version only accepts positive and real inputs. Values of
2846
+ ``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may
2847
+ lead to erroneous results; consider enabling double precision in this case.
2848
+
2849
+ Args:
2850
+ a: arraylike, real-valued
2851
+ b: arraylike, real-valued
2852
+ c: arraylike, real-valued
2853
+ x: arraylike, real-valued
2854
+
2855
+ Returns:
2856
+ array of 2F1 values.
2857
+ """
2858
+ # This is backed by https://doi.org/10.48550/arXiv.1407.7786
2859
+ a , b , c , x = promote_args_inexact ('hyp2f1' , a , b , c , x )
2860
+ eps = jnp .finfo (x .dtype ).eps * 50
2861
+
2862
+ d = c - a - b
2863
+ s = 1 - x
2864
+ ca = c - a
2865
+ cb = c - b
2866
+
2867
+ id = jnp .round (d )
2868
+ ica = jnp .round (ca )
2869
+ icb = jnp .round (cb )
2870
+
2871
+ neg_int_ca = jnp .logical_and (ca <= 0 , jnp .abs (ca - ica ) < eps )
2872
+ neg_int_cb = jnp .logical_and (cb <= 0 , jnp .abs (cb - icb ) < eps )
2873
+ neg_int_ca_or_cb = jnp .logical_or (neg_int_ca , neg_int_cb )
2874
+
2875
+ index = jnp .where (jnp .logical_or (x == 0 , jnp .logical_and (jnp .logical_or (a == 0 , b == 0 ), c != 0 )), 0 ,
2876
+ jnp .where (jnp .logical_or (c == 0 , jnp .logical_and (c < 0 , c % 1 == 0 )), 1 ,
2877
+ jnp .where (jnp .logical_and (d <= - 1 , jnp .logical_not (jnp .logical_and (jnp .abs (d - id ) >= eps , s < 0 ))), 2 ,
2878
+ jnp .where (jnp .logical_and (d <= 0 , x == 1 ), 1 ,
2879
+ jnp .where (jnp .logical_and (x < 1 , b == c ), 3 ,
2880
+ jnp .where (jnp .logical_and (x < 1 , a == c ), 4 ,
2881
+ jnp .where (x > 1 , 1 ,
2882
+ jnp .where (x == 1 , 5 , 6 ))))))))
2883
+
2884
+ return lax .select_n (index ,
2885
+ jnp .array (1 , dtype = x .dtype ),
2886
+ jnp .array (jnp .inf , dtype = x .dtype ),
2887
+ s ** d * _hyp2f1_terminal_or_serie (ca , cb , c , x ),
2888
+ s ** (- a ),
2889
+ s ** (- b ),
2890
+ gamma (c ) * gamma (d ) / (gamma (ca ) * gamma (cb )),
2891
+ _hyp2f1_terminal_or_serie (a , b , c , x ))
2892
+
2893
+
2640
2894
def softmax (x : ArrayLike ,
2641
2895
/ ,
2642
2896
* ,
0 commit comments