@@ -2637,6 +2637,341 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
2637
2637
)
2638
2638
2639
2639
2640
+ def _binom (n , k ):
2641
+ a = lax .lgamma (n + 1.0 )
2642
+ b = lax .lgamma (n - k + 1.0 )
2643
+ c = lax .lgamma (k + 1.0 )
2644
+
2645
+ return lax .exp (a - b - c )
2646
+
2647
+
2648
+ def _poch (q , n ):
2649
+ """
2650
+ `jax.scipy.special.poch` does not allow for non-positive integer q.
2651
+ """
2652
+ def body (i , state ):
2653
+ q , prod = state
2654
+
2655
+ prod *= q + i
2656
+
2657
+ return q , prod
2658
+
2659
+ return lax .cond (
2660
+ n == 0 ,
2661
+ lambda : jnp .array (1 , dtype = q .dtype ),
2662
+ lambda : lax .fori_loop (jnp .array (1 , dtype = n .dtype ), n , body , (q , q ))[1 ]
2663
+ )
2664
+
2665
+
2666
+ def _hyp2f1_terminal (a , b , c , x ):
2667
+ """
2668
+ The Taylor series representation of the 2F1 hypergeometric function
2669
+ terminates when either a or b is a non-positive integer. See Eq. 4.1 and
2670
+ Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014
2671
+ https://doi.org/10.48550/arXiv.1407.7786
2672
+ """
2673
+ # Ensure that between a and b, the negative integer parameter with the greater
2674
+ # absolute value - that still has a magnitude less than the absolute value of
2675
+ # c if c is non-positive - is used for the upper limit in the loop.
2676
+ temp = a
2677
+ a = jnp .where (
2678
+ jnp .logical_and (
2679
+ b < a ,
2680
+ jnp .logical_and (
2681
+ b % 1 == 0 ,
2682
+ jnp .logical_not (
2683
+ jnp .logical_and (
2684
+ c % 1 == 0 ,
2685
+ jnp .logical_and (
2686
+ c <= 0 ,
2687
+ c > b
2688
+ )
2689
+ )
2690
+ )
2691
+ )
2692
+ ), b , a
2693
+ )
2694
+ b = jnp .where (
2695
+ jnp .logical_and (
2696
+ b < temp ,
2697
+ jnp .logical_and (
2698
+ b % 1 == 0 ,
2699
+ jnp .logical_not (
2700
+ jnp .logical_and (
2701
+ c % 1 == 0 ,
2702
+ jnp .logical_and (
2703
+ c <= 0 ,
2704
+ c > b
2705
+ )
2706
+ )
2707
+ )
2708
+ )
2709
+ ), temp , b
2710
+ )
2711
+
2712
+ def body (i , sum ):
2713
+ sum += (- 1 ) ** i * _binom (jnp .abs (a ), i ) / _poch (c , i ) * _poch (b , i ) * x ** i
2714
+
2715
+ return sum
2716
+
2717
+ return lax .fori_loop (jnp .array (0 , dtype = a .dtype ),
2718
+ jnp .abs (a ) + 1 ,
2719
+ body ,
2720
+ jnp .array (0 , dtype = x .dtype ))
2721
+
2722
+
2723
+ def _hyp2f1_serie (a , b , c , x ):
2724
+ """
2725
+ Compute the 2F1 hypergeometric function using the Taylor expansion.
2726
+ See Eq. 4.1 from PEARSON, OLVER & PORTER 2014
2727
+ https://doi.org/10.48550/arXiv.1407.7786
2728
+ """
2729
+ precision = jnp .finfo (jnp .float32 ).eps
2730
+
2731
+ s = 1 - x
2732
+
2733
+ neg_int_a = jnp .logical_and (a <= 0 , a % 1 == 0 )
2734
+ neg_int_b = jnp .logical_and (b <= 0 , b % 1 == 0 )
2735
+ neg_int_c = jnp .logical_and (c <= 0 , c % 1 == 0 )
2736
+
2737
+ def body (state ):
2738
+ serie , k , term = state
2739
+ serie += term
2740
+ term = _poch (a , k ) / _poch (c , k ) * _poch (b , k ) / factorial (k ) * x ** k
2741
+ k += 1
2742
+
2743
+ return serie , k , term
2744
+
2745
+ def cond (state ):
2746
+ serie , k , term = state
2747
+
2748
+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > precision )
2749
+
2750
+ init = (jnp .array (0 , dtype = x .dtype ),
2751
+ jnp .array (1 , dtype = x .dtype ),
2752
+ jnp .array (1 , dtype = x .dtype ))
2753
+
2754
+ return lax .while_loop (cond , body , init )[0 ]
2755
+
2756
+
2757
+ def _hyp2f1_terminal_or_serie (a , b , c , x ):
2758
+ """
2759
+ Check for recurrence relations along with whether or not the series
2760
+ terminates. True recursion is not possible; however, the recurrence
2761
+ relation may still be approximated.
2762
+ See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014
2763
+ https://doi.org/10.48550/arXiv.1407.7786
2764
+ """
2765
+ neg_int_a = jnp .logical_and (a <= 0 , a % 1 == 0 )
2766
+ neg_int_b = jnp .logical_and (b <= 0 , b % 1 == 0 )
2767
+ neg_int_c = jnp .logical_and (c <= 0 , c % 1 == 0 )
2768
+ neg_int_a_or_b = jnp .logical_or (neg_int_a , neg_int_b )
2769
+ not_neg_int_a_or_b = jnp .logical_not (neg_int_a_or_b )
2770
+
2771
+ s = 1 - x
2772
+ d = c - a - b
2773
+
2774
+ index = jnp .where (
2775
+ jnp .logical_and (
2776
+ neg_int_c ,
2777
+ jnp .logical_and (
2778
+ jnp .logical_not (jnp .logical_and (neg_int_a , a > c )),
2779
+ jnp .logical_not (jnp .logical_and (neg_int_b , b > c ))
2780
+ )
2781
+ ), 0 ,
2782
+ jnp .where (jnp .logical_and (x < - 0.5 , not_neg_int_a_or_b ),
2783
+ jnp .where (b > a , 1 , 2 ),
2784
+ jnp .where (jnp .logical_and (x > 0.9 , not_neg_int_a_or_b ),
2785
+ jnp .where (d % 1 != 0 , 3 , 4 ),
2786
+ jnp .where (jnp .logical_and (jnp .logical_not (neg_int_c ), neg_int_a_or_b ), 5 , 3 ))))
2787
+
2788
+ return lax .select_n (index ,
2789
+ jnp .array (jnp .inf , dtype = x .dtype ),
2790
+ s ** (- a ) * _hyp2f1_serie (a , c - b , c , - x / s ),
2791
+ s ** (- b ) * _hyp2f1_serie (c - a , b , c , - x / s ),
2792
+ _hyp2f1_serie (a , b , c , x ),
2793
+ _hyp2f1_digamma_transform (a , b , c , x ),
2794
+ _hyp2f1_terminal (a , b , c , x ))
2795
+
2796
+
2797
+ def _hyp2f1_gamma_transform (a , b , c , x ):
2798
+ """
2799
+ Gamma transformations of the 2F1 hypergeometric function.
2800
+ """
2801
+
2802
+ def transform_1 ():
2803
+ """
2804
+ See Eq. 4.10 and Analytic Continuation Formulas from PEARSON, OLVER & PORTER 2014
2805
+ https://doi.org/10.48550/arXiv.1407.7786
2806
+ """
2807
+ p = _hyp2f1_serie (a , 1 - c + a , 1 - b + a , 1 / x )
2808
+ q = _hyp2f1_serie (b , 1 - c + b , 1 - a + b , 1 / x )
2809
+ p *= (- x ) ** (- a )
2810
+ q *= (- x ) ** (- b )
2811
+ t1 = gamma (c )
2812
+ s = t1 * gamma (b - a ) / (gamma (b ) * gamma (c - a ))
2813
+ y = t1 * gamma (a - b ) / (gamma (a ) * gamma (c - b ))
2814
+
2815
+ return s * p + y * q
2816
+
2817
+ def transform_2 ():
2818
+ """
2819
+ See 4.1 Properties of F from PEARSON, OLVER & PORTER 2014
2820
+ https://doi.org/10.48550/arXiv.1407.7786
2821
+ """
2822
+ return gamma (c ) * gamma (c - a - b ) / (gamma (c - a ) * gamma (c - b ))
2823
+
2824
+ return jnp .where (
2825
+ x < - 2 ,
2826
+ transform_1 (),
2827
+ transform_2 ()
2828
+ )
2829
+
2830
+
2831
+ def _hyp2f1_digamma_transform (a , b , c , x ):
2832
+ """
2833
+ Digamma transformation of the 2F1 hypergeometric function.
2834
+ See AMS55 #15.3.10, #15.3.11, #15.3.12
2835
+ """
2836
+ precision = jnp .finfo (jnp .float32 ).eps
2837
+
2838
+ d = c - a - b
2839
+ s = 1 - x
2840
+ id = jnp .round (d )
2841
+
2842
+ e = jnp .where (id >= 0 , d , - d )
2843
+ d1 = jnp .where (id >= 0 , d , jnp .array (0 , dtype = d .dtype ))
2844
+ d2 = jnp .where (id >= 0 , jnp .array (0 , dtype = d .dtype ), d )
2845
+ aid = jnp .where (id >= 0 , id , - id ).astype ('int32' )
2846
+
2847
+ ax = jnp .log (s )
2848
+
2849
+ y = digamma (1.0 ) + digamma (1.0 + e ) - digamma (a + d1 ) - digamma (b + d1 ) - ax
2850
+ y /= gamma (e + 1.0 )
2851
+
2852
+ p = (a + d1 ) * (b + d1 ) * s / gamma (e + 2.0 )
2853
+
2854
+ def cond (state ):
2855
+ _ , _ , _ , _ , _ , _ , q , _ , _ , t , y = state
2856
+
2857
+ return jnp .logical_and (
2858
+ t < 250 ,
2859
+ jnp .logical_or (y == 0 , jnp .abs (q / y ) > precision )
2860
+ )
2861
+
2862
+ def body (state ):
2863
+ a , ax , b , d1 , e , p , q , r , s , t , y = state
2864
+
2865
+ r = digamma (1.0 + t ) + digamma (1.0 + t + e ) - digamma (a + t + d1 ) \
2866
+ - digamma (b + t + d1 ) - ax
2867
+ q = p * r
2868
+ y += q
2869
+ p *= s * (a + t + d1 ) / (t + 1.0 )
2870
+ p *= (b + t + d1 ) / (t + 1.0 + e )
2871
+ t += 1.0
2872
+
2873
+ return a , ax , b , d1 , e , p , q , r , s , t , y
2874
+
2875
+ init = (a , ax , b , d1 , e , p , y , jnp .array (0 , dtype = x .dtype ), s ,
2876
+ jnp .array (1 , dtype = x .dtype ), y )
2877
+ _ , _ , _ , _ , _ , _ , q , r , _ , _ , y = lax .while_loop (cond , body , init )
2878
+
2879
+ def compute_sum (y ):
2880
+ y1 = jnp .array (1 , dtype = x .dtype )
2881
+ t = jnp .array (0 , dtype = x .dtype )
2882
+ p = jnp .array (1 , dtype = x .dtype )
2883
+
2884
+ def for_body (i , state ):
2885
+ a , b , d2 , e , p , s , t , y1 = state
2886
+
2887
+ r = 1.0 - e + t
2888
+ p *= s * (a + t + d2 ) * (b + t + d2 ) / r
2889
+ t += 1.0
2890
+ p /= t
2891
+ y1 += p
2892
+
2893
+ return a , b , d2 , e , p , s , t , y1
2894
+
2895
+ init_val = a , b , d2 , e , p , s , t , y1
2896
+ y1 = lax .fori_loop (1 , aid , for_body , init_val )[- 1 ]
2897
+
2898
+ p = gamma (c )
2899
+ y1 *= gamma (e ) * p / (gamma (a + d1 ) * gamma (b + d1 ))
2900
+ y *= p / (gamma (a + d2 ) * gamma (b + d2 ))
2901
+
2902
+ y = jnp .where ((aid & 1 ) != 0 , - y , y )
2903
+ q = s ** id
2904
+
2905
+ return jnp .where (id > 0 , y * q + y1 , y + y1 * q )
2906
+
2907
+ return jnp .where (
2908
+ id == 0 ,
2909
+ y * gamma (c ) / (gamma (a ) * gamma (b )),
2910
+ compute_sum (y )
2911
+ )
2912
+
2913
+
2914
+ @jit
2915
+ @jnp .vectorize
2916
+ def hyp2f1 (a : ArrayLike , b : ArrayLike , c : ArrayLike , x : ArrayLike ) -> Array :
2917
+ r"""The 2F1 hypergeometric function.
2918
+
2919
+ JAX implementation of :obj:`scipy.special.hyp2f1`.
2920
+
2921
+ .. math::
2922
+
2923
+ \mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!}
2924
+
2925
+ where :math:`(\cdot)_k` is the Pochammer symbol.
2926
+
2927
+ The JAX version only accepts positive and real inputs. Values of
2928
+ ``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may
2929
+ lead to erroneous results; consider enabling double precision in this case.
2930
+
2931
+ Args:
2932
+ a: arraylike, real-valued
2933
+ b: arraylike, real-valued
2934
+ c: arraylike, real-valued
2935
+ x: arraylike, real-valued
2936
+
2937
+ Returns:
2938
+ array of 2F1 values.
2939
+ """
2940
+ # This is backed by https://doi.org/10.48550/arXiv.1407.7786
2941
+ a , b , c , x = promote_args_inexact ('hyp2f1' , a , b , c , x )
2942
+
2943
+ d = c - a - b
2944
+ s = 1 - x
2945
+
2946
+ neg_int_ca = jnp .logical_and (c - a <= 0 , (c - a ) % 1 == 0 )
2947
+ neg_int_cb = jnp .logical_and (c - b <= 0 , (c - b ) % 1 == 0 )
2948
+ neg_int_ca_or_cb = jnp .logical_or (neg_int_ca , neg_int_cb )
2949
+
2950
+ index = jnp .where (jnp .logical_or (x == 0 , jnp .logical_and (jnp .logical_or (a == 0 , b == 0 ), c != 0 )), 0 ,
2951
+ jnp .where (c == 0 , 2 ,
2952
+ jnp .where (jnp .logical_and (d <= - 1 , jnp .logical_not (jnp .logical_and (d % 1 != 0 , s < 0 ))), 1 ,
2953
+ jnp .where (jnp .logical_and (d <= 0 , x == 1 ), 2 ,
2954
+ jnp .where (jnp .logical_and (x < 1 , b == c ), 3 ,
2955
+ jnp .where (jnp .logical_and (x < 1 , a == c ), 4 ,
2956
+ jnp .where (x > 1 , 2 ,
2957
+ jnp .where (x == 1 ,
2958
+ jnp .where (neg_int_ca_or_cb ,
2959
+ jnp .where (d >= 0 , 5 , 2 ),
2960
+ jnp .where (d <= 0 , 2 , 6 )),
2961
+ jnp .where (d < 0 , 7 ,
2962
+ jnp .where (neg_int_ca_or_cb , 5 , 7 ))))))))))
2963
+
2964
+ return lax .select_n (index ,
2965
+ jnp .array (1 , dtype = x .dtype ),
2966
+ s ** d * _hyp2f1_terminal_or_serie (c - a , c - b , c , x ),
2967
+ jnp .array (jnp .inf , dtype = x .dtype ),
2968
+ s ** (- a ),
2969
+ s ** (- b ),
2970
+ s ** d * _hyp2f1_serie (c - a , c - b , c , x ),
2971
+ _hyp2f1_gamma_transform (a , b , c , x ),
2972
+ _hyp2f1_terminal_or_serie (a , b , c , x ))
2973
+
2974
+
2640
2975
def softmax (x : ArrayLike ,
2641
2976
/ ,
2642
2977
* ,
0 commit comments