State of jax.scipy.special functions: tested by evaluation or autograd, incorrectness and missing functionality #27088

pearu opened this issue Mar 12, 2025 · 8 comments

pearu commented Mar 12, 2025

Functions in jax.scipy.special

Note: some of the issues reported below may be resolved and the table may need an update.

Notation: R denotes the real line, R+ its subset of positive values. (A sketch of the kind of value/autograd check meant in the table below follows it.)

| Function | Value/Autograd tested | Issues |
| --- | --- | --- |
| bernoulli | 👍/NA | |
| bessel_jn | 👍/👎 | mismatch of naming wrt scipy.special.jn, tests are in lax_scipy_test.py, #16996 |
| beta | 👍/👎 | tested on R+, otherwise extendable to C |
| betainc | 👍/👎 | tested on R+, has untested betainc_gradx, #21900, #27107 |
| betaln | 👍/👎 | tested on R+ |
| digamma | 👍/👍 | tested on R+, #11481 - wrong grad |
| entr | 👍/👎 | tested on R, grad yields nan |
| erf | 👍/👍 | tested on small R+, #14717 - wrong result |
| erfc | 👍/👍 | tested on small R+ |
| erfinv | 👍/👍 | tested on small R+ |
| exp1 | 👍/👍 | tested on R+ (float32 only), #13543 - slow jit |
| expi | 👍/👍 | tested on not small R (float32 only), float64 aborts on GPU (jnp.piecewise?) |
| expit | 👍/👍 | tested on small R+ |
| expn | 👍/👍 | tested on R+ (float32 only) |
| factorial | 👍/👍 | tested on R |
| gamma | 👍/👍 | tested on R |
| gammainc | 👍/👍 | tested on R+, #7922 - 2nd order derivative |
| gammaincc | 👍/👍 | tested on R+ |
| gammaln | 👍/👎 | tested on R+, slightly inaccurate |
| hyp1f1 | 👍/👍 | tested on [0.5, 30], inaccuracies: #21503, fixed?: #21507 |
| i0 | 👍/👍 | tested on R |
| i0e | 👍/👍 | tested on R |
| i1 | 👍/👍 | tested on R |
| i1e | 👍/👍 | tested on R |
| kl_div | 👍/👍 | tested on R+ |
| log_ndtr | 👍/👍 | tested on R |
| logit | 👍/👍 | tested on [0.05, 0.95] |
| logsumexp | 👍/👎 | tested on R+inf, lax_scipy_test.py, wrong grad #22398 |
| lpmn | 👍/👎 | tested on [-0.2, 0.9], lax_scipy_test.py, #10623 - discrepancy wrt scipy |
| lpmn_values | 👍/👎 | tested on [-0.2, 0.9], lax_scipy_test.py, #19157, PR #19158 not landed, issues: #14101 |
| multigammaln | 👍/👎 | tested on R+, lax_scipy_test.py |
| ndtr | 👍/👍 | tested on R |
| ndtri | 👍/👍 | tested on [0.05, 0.95] |
| poch | 👍/👍 | tested on R+ |
| polygamma | 👍/👍 | tested on R+, Z+, #17738 - inaccuracies, openxla/xla#5838 |
| rel_entr | 👍/👍 | tested on R+ |
| spence | 👍/👍 | tested on R+, JIT |
| sph_harm | 👍/👎 | tested on a small set of discrete values |
| xlog1py | 👍/👍 | tested on R |
| xlogy | 👍/👍 | tested on R+ |
| zeta | 👍/👍 | tested on R+, #17734 - float32 only impl, openxla/xla#5838 |
| erfcx | not implemented | #1987 |
| betaincinv | not implemented | #2399 |
| hyp2f1 | not implemented | #2991, #28168 |
| gammaincinv | not implemented | #5350 |
| comb | not implemented | #9709, PR #18389 not landed |
| jv | not implemented | #11002, #12402, unlanded PR: #17038, problems: #16996, #20769 |
| kn, kv, kve | not implemented | #9956 |
| lambertw | not implemented | #13680 |
| eval_legendre | not implemented | #14101 |
| spherical_jn | not implemented | #18119 |
| li | not implemented | #23732 |
| airy | not implemented | #25244 |
| betaincc | not implemented | |
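
For reference, the kind of value/autograd check meant by the table above looks roughly like this (a standalone sketch, not the actual lax_scipy_test.py code; erf and the tolerances are just illustrative choices):

```python
import numpy as np
import scipy.special
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

x = jnp.linspace(0.01, 3.0, 50)  # sample points on "small R+"

# Value check: compare the JAX evaluation against SciPy's reference values.
np.testing.assert_allclose(erf(x), scipy.special.erf(np.asarray(x)), rtol=1e-5)

# Autograd check: compare jax.grad against the closed form d/dx erf(x) = 2/sqrt(pi) * exp(-x**2).
grads = jax.vmap(jax.grad(erf))(x)
np.testing.assert_allclose(grads, 2 / np.sqrt(np.pi) * np.exp(-np.asarray(x) ** 2), rtol=1e-5)
```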

eltrompetero commented Apr 7, 2025

Here's an example where expn doesn't work:

```python
import scipy as sc
import jax.numpy as jnp
from jax.scipy.special import expn

print(sc.special.expn(2.4, 1e-1))
print(expn(jnp.array([2.4]), jnp.array([1e-1])))
```

This prints

```
0.14849550677592194
[nan]
```

The disagreement seems to be the case for all z<=1, where z is the second argument given to expn.


jakevdp commented Apr 7, 2025

> Here's an example where expn doesn't work:

When I run your snippet I see the following warning:

```
RuntimeWarning: floating point number truncated to an integer
```

It seems scipy.special.expn only accepts integers as the first argument. If you change the input so that it is an integer, you get matching results:

```python
print(sc.special.expn(2.0, 1e-1))
print(expn(jnp.array([2.0]), jnp.array([1e-1])))
```

```
0.7225450221940204
[0.7225450221940204]
```

Perhaps JAX's version should round non-integer inputs to integers rather than returning NaN; however, because of the way JAX is executed, we could not raise a warning based on an invalid value. To me, that kind of silent conversion of invalid inputs seems more problematic than returning a NaN. What do you think?
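
If rounding is preferred, a minimal caller-side sketch (just illustrative, not a proposed change to expn itself) would be to snap the order to an integer value before evaluating:

```python
import jax.numpy as jnp
from jax.scipy.special import expn

n = jnp.array([2.4])   # non-integer order, which JAX currently maps to NaN
z = jnp.array([1e-1])

# Round the order to the nearest integer value (SciPy itself truncates),
# then evaluate; this should agree with sc.special.expn(2, 1e-1).
print(expn(jnp.round(n), z))
```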


eltrompetero commented Apr 7, 2025

Ah, good catch. I suppose my question answers your point.

(Even more confusingly, mpmath.expint allows for both n and z to be complex. I was playing with old code and didn't notice this choice.)

((I would lobby for the general implementation!))


pearu commented Apr 10, 2025

expn is defined on the whole complex plane, and it can be expressed in terms of power and gamma functions, so a straightforward general implementation would be possible, at least on a certain subset of the complex plane.
It is likely that such a straightforward implementation would be inaccurate on a considerable portion of the complex plane. I am currently developing generalized hypergeometric function support for JAX which, in addition to Bessel functions, could be used for expn as well.
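
For reference, the identity alluded to above (a standard relation between the generalized exponential integral and the upper incomplete gamma function) is

$$E_n(z) = z^{\,n-1}\,\Gamma(1-n,\, z),$$

which is the kind of expression a straightforward implementation could build on.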

I agree that silent transformation of inputs that leads to incorrect results is the worst behavior; in the short term it should be replaced with raising an exception or returning NaN.

@mdmeeker

Hi, I'm very interested in having hyp2f1. The linked issue looks to be stale. Is anyone working on that?


pearu commented Apr 18, 2025

> Hi, I'm very interested in having hyp2f1. The linked issue looks to be stale. Is anyone working on that?

I am working on implementing generalized hypergeometric series for JAX. @mdmeeker, I wonder whether you'll need complex argument support, or whether real argument support will be sufficient for your use case?
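
For context, the generalized hypergeometric series in question is the standard one,

$${}_pF_q(a_1,\dots,a_p;\, b_1,\dots,b_q;\, z) = \sum_{k=0}^{\infty} \frac{(a_1)_k \cdots (a_p)_k}{(b_1)_k \cdots (b_q)_k}\, \frac{z^k}{k!},$$

where $(a)_k$ is the Pochhammer symbol; hyp2f1 is the $p=2$, $q=1$ case.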

@mdmeeker

@pearu real z is all I would need. How is it coming along? Any sort of timeline?


mvsoom commented Apr 28, 2025

Drive-by comment: there are TFP JAX-substrate implementations of kve, which saved my day.

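For anyone landing here, a minimal sketch of that route, assuming the TensorFlow Probability JAX substrate exposes bessel_kve under tfp.substrates.jax.math (check your installed TFP version):

```python
import jax.numpy as jnp
# TFP's JAX substrate: same API as regular TFP, backed by JAX.
from tensorflow_probability.substrates import jax as tfp

v = jnp.array([0.5, 1.0, 2.5])   # orders
z = jnp.array([0.1, 1.0, 10.0])  # arguments

# Exponentially scaled modified Bessel function of the second kind,
# kve(v, z) = K_v(z) * exp(z)  (assumed function: tfp.math.bessel_kve).
print(tfp.math.bessel_kve(v, z))
```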
