State of jax.scipy.special functions: tested by evaluation or autograd, incorrectness and missing functionality #27088
Comments
I have an example where expn doesn't work:
import scipy as sc
from jax.scipy.special import expn
import jax.numpy as jnp
import jax
print(sc.special.expn(2.4, 1e-1))
print(expn(jnp.array([2.4]), jnp.array([1e-1])))
This prints
The disagreement seems to hold for all z <= 1, where z is the second argument given to expn.
When I run your snippet I see the following warning:
It seems
print(sc.special.expn(2.0, 1e-1))
print(expn(jnp.array([2.0]), jnp.array([1e-1])))
Perhaps JAX's version should round invalid inputs to integers rather than returning NaN for them; however, because of the way JAX is executed, we could not raise a warning based on an invalid value. To me, that kind of silent conversion of invalid inputs seems more problematic than returning a NaN. What do you think?
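To illustrate the trade-off discussed above, here is a minimal sketch of the "return NaN" option written so that it still works under jax.jit. The wrapper name expn_integer_only is hypothetical and not part of JAX; it only assumes that jax.scipy.special.expn behaves correctly for integer-valued n.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expn


@jax.jit
def expn_integer_only(n, x):
    """Hypothetical wrapper: NaN wherever n is not integer-valued."""
    n = jnp.asarray(n)
    x = jnp.asarray(x)
    is_integer = n == jnp.floor(n)
    # Under jit, n is a tracer, so Python-level branching or warnings based
    # on its value are not available; jnp.where selects elementwise instead.
    # Invalid entries are swapped for a harmless n = 1 before calling expn.
    safe_n = jnp.where(is_integer, n, 1.0)
    return jnp.where(is_integer, expn(safe_n, x), jnp.nan)
```

Calling expn_integer_only(jnp.array([2.0, 2.4]), jnp.array([0.1, 0.1])) should give a finite value for the first entry and NaN for the second, with no silent rounding of 2.4.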
Ah, good catch. I suppose my question answers your point. (Even more confusingly, mpmath.expint allows both n and z to be complex. I was playing with old code and didn't notice this choice.) (I would lobby for the general implementation!)
expn is defined on the whole complex plane and can be expressed in terms of power and gamma functions, so a straightforward general implementation would be possible, at least in a certain subset of the complex plane. I agree that silently transforming inputs in a way that leads to incorrect results is the worst behavior and should be replaced with an exception or, in the short term, a NaN return value.
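As a rough illustration of the "power and gamma functions" remark, the sketch below evaluates E_nu(z) for real, non-integer nu and small positive z using the standard series E_nu(z) = Gamma(1 - nu) z^(nu - 1) - sum_k (-z)^k / (k! (k + 1 - nu)). It is a plain-Python reference using only the standard library, not JAX code and not the implementation proposed in this thread; the name expn_series and the terms parameter are assumptions made for the example.

```python
import math


def expn_series(nu, z, terms=60):
    """Reference sketch of E_nu(z) for real non-integer nu and small z > 0."""
    if float(nu).is_integer():
        raise ValueError("this series is singular for integer nu")
    if z <= 0:
        raise ValueError("z must be positive")
    # Power and gamma part: Gamma(1 - nu) * z**(nu - 1).
    head = math.gamma(1.0 - nu) * z ** (nu - 1.0)
    # Series part: sum over k of (-z)**k / (k! * (k + 1 - nu)).
    total = 0.0
    term = 1.0  # holds (-z)**k / k!, starting at k = 0
    for k in range(terms):
        total += term / (k + 1.0 - nu)
        term *= -z / (k + 1.0)
    return head - total
```

For the inputs from the first comment, expn_series(2.4, 0.1) gives a value for the non-integer order that can be cross-checked against mpmath.expint. The series alternates and suffers cancellation for larger z, so a general implementation would need a different method there.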
Hi, I'm very interested in having
I am working on implementing generalized hypergeometric series for JAX. @mdmeeker, will you need complex argument support, or will real argument support be sufficient for your use case?
@pearu real z is all I would need. How is it coming along? Any sort of timeline?
Functions in jax.scipy.special
Note: some of the issues reported below may be resolved and the table may need an update.
Notation: R is the real line; R+ is the subset of R consisting of positive values.
bernoulli
bessel_jn
beta
betainc: betainc_gradx, #21900 : #27107
betaln
digamma
entr
erf
erfc
erfinv
exp1
expi (jnp.piecewise?)
expit
expn
factorial
gamma
gammainc
gammaincc
gammaln
hyp1f1: [0.5, 30], inaccuracies: #21503, fixed?: #21507
i0
i0e
i1
i1e
kl_div
log_ndtr
logit: [0.05, 0.95]
logsumexp
lpmn: [-0.2, 0.9], lax_scipy_test.py, #10623 - discrepancy wrt scipy
lpmn_values: [-0.2, 0.9], lax_scipy_test.py, #19157, PR #19158 not landed, issues: #14101
multigammaln
ndtr
ndtri: [0.05, 0.95]
poch
polygamma
rel_entr
spence
sph_harm
xlog1py
xlogy
zeta
erfcx
betaincinv
hyp2f1
gammaincinv
comb
jv
kn, kv, kve
lambertw
eval_legendre
spherical_jn
li
airy
betaincc