
jax.scipy.special.exp1 is slow when applied to a vector #13543



Open
DyeKuu opened this issue Dec 7, 2022 · 1 comment
Labels: bug (Something isn't working), performance (make things lean and fast)


DyeKuu commented Dec 7, 2022

Description

The code takes a very long time to JIT when x64 is enabled.

Below is a minimal working example (MWE) that reproduces the issue on CPU.

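import jax
import jax.numpy as jnp
import jax.scipy.special

# Enable 64-bit floats (jax.config.update works in old and current JAX;
# the `from jax.config import config` import was removed in later releases).
jax.config.update("jax_enable_x64", True)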

def E1_scaled(x):
  return jax.scipy.special.exp1(x) * jnp.exp(x)

if __name__ == "__main__":
  fn = jax.jit(jax.vmap(E1_scaled))
  x = jnp.array(
    [
      1.00475018e-07, 1.83070789e-06, 4.33937425e-05, 5.15584770e-02,
      1.28789450e-01, 1.23364297e-04, 3.65711501e-05, 2.96945880e-02,
      3.88449715e-04, 3.41324223e-02
    ],
    dtype=jnp.float64
  )
  print(fn(x))
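
To tell whether the time goes into compilation or into execution, the two phases can be timed separately. This is a minimal sketch that assumes a JAX release with the ahead-of-time `jax.jit(...).lower(...).compile()` API (it may not exist in the jax 0.2.24 listed below); the helper name `time_compile_and_run` is hypothetical.

```python
import time

import jax
import jax.numpy as jnp


def time_compile_and_run(fn, *args):
    """Jit fn, then time compilation and one execution separately.

    Returns (compile_seconds, run_seconds, output). Assumes fn returns a
    single array.
    """
    jitted = jax.jit(fn)
    t0 = time.perf_counter()
    compiled = jitted.lower(*args).compile()  # tracing + XLA compilation
    t1 = time.perf_counter()
    out = compiled(*args)
    out.block_until_ready()  # JAX dispatches asynchronously; wait for the result
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, out


if __name__ == "__main__":
    compile_s, run_s, _ = time_compile_and_run(jnp.sin, jnp.arange(4.0))
    print(f"compile: {compile_s:.3f} s, run: {run_s:.3f} s")
```

Applied to the MWE, `time_compile_and_run(jax.vmap(E1_scaled), x)` should report the two phases separately, which helps distinguish slow compilation from slow execution.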

What jax/jaxlib version are you using?

jax v0.2.24, jaxlib v0.1.71+cuda111

Which accelerator(s) are you using?

CPU

Additional system info

Python version 3.8, Ubuntu 20.04.3 LTS (Focal Fossa)

NVIDIA GPU info

No response

DyeKuu added the bug (Something isn't working) label Dec 7, 2022
DyeKuu changed the title from "JIT takes long time and errors for jax.scipy.special.exp1 multiplying jnp.exp" to "JIT takes long time for jax.scipy.special.exp1 multiplying jnp.exp" Dec 7, 2022
hawkinsp added the performance (make things lean and fast) label Dec 7, 2022
hawkinsp (Collaborator) commented Dec 7, 2022

Huh, this is an interesting bug.

As a workaround on CPU, use lax.map instead of vmap, i.e., write:

  fn = jax.jit(lambda xs: jax.lax.map(E1_scaled, xs))
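
Why this helps: lax.map is a sequential loop (a scan) over the leading axis, so each element reaches lax.cond as a scalar and only the selected branch runs. A minimal sketch on a toy function, where `piecewise` is a hypothetical stand-in for exp1 and not its real implementation:

```python
import jax
import jax.numpy as jnp
from jax import lax


def piecewise(x):
    # Hypothetical two-branch function standing in for exp1.
    return lax.cond(x < 1.0, lax.sin, lax.cos, x)


xs = jnp.array([0.25, 0.5, 2.0, 4.0])

vmapped = jax.jit(jax.vmap(piecewise))
mapped = jax.jit(lambda v: lax.map(piecewise, v))

# lax.map traces to a single scan over the leading axis; the cond lives
# inside the scan body and sees one scalar element per iteration.
map_prims = [e.primitive.name
             for e in jax.make_jaxpr(lambda v: lax.map(piecewise, v))(xs).jaxpr.eqns]

print(map_prims)                              # includes 'scan'
print(jnp.allclose(vmapped(xs), mapped(xs)))  # both give the same values
```

The results match; the difference is only in how much work is done per element, which is what matters when a branch is expensive.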

I think the issue is that expn is defined piecewise over its domain. When vmapped, we actually evaluate all of the pieces for each argument, even those corresponding to conditions that are not taken. Some of these pieces are iterative and very slow to converge!
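
The effect can be observed on a toy function by inspecting jaxprs. A scalar call keeps a real `cond`, while `vmap` with a per-element predicate replaces it with both branches plus a select. This sketch assumes current JAX batching behavior; `piecewise` and `top_level_primitives` are hypothetical helpers, not the expn code. The same applies to loops inside the branches: a batched `lax.while_loop` keeps iterating until its condition is false for every element, so the slowest lane sets the cost, even if its result is later discarded.

```python
import jax
import jax.numpy as jnp
from jax import lax


def piecewise(x):
    # Hypothetical two-branch function; not the real expn implementation.
    return lax.cond(x < 1.0, lax.sin, lax.cos, x)


def top_level_primitives(fn, *args):
    """Primitive names in the outermost level of fn's jaxpr."""
    return [eqn.primitive.name for eqn in jax.make_jaxpr(fn)(*args).jaxpr.eqns]


# Scalar call: a real `cond`; only the selected branch executes at run time.
scalar_prims = top_level_primitives(piecewise, 0.5)

# vmap with a batched predicate: the cond is replaced by both branches
# (sin and cos applied to every element) followed by a select.
batched_prims = top_level_primitives(jax.vmap(piecewise), jnp.array([0.5, 2.0]))

print(scalar_prims)
print(batched_prims)
```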

One idea: perhaps vmap(cond(..., while(...), ...)) could short-circuit the evaluation of the inner loop when we are in a branch that is predicated out?

Smaller repro (no explicit jit or vmap):

import jax
import jax.numpy as jnp
import jax.scipy.special

jax.config.update("jax_enable_x64", True)

x = jnp.array(
  [
    1.00475018e-07, 1.83070789e-06, 4.33937425e-05, 5.15584770e-02,
    1.28789450e-01, 1.23364297e-04, 3.65711501e-05, 2.96945880e-02,
    3.88449715e-04, 3.41324223e-02
  ],
  dtype=jnp.float64
)
print(jax.scipy.special.exp1(x))
