I think with #1101 we don't need this anymore. A Function with `trust_input=True` now just calls the jitted function and sets the updates, which was the goal. We can add a mention in the docs that `fn.vm.jit_fn` gives direct access to the jitted function.
Description
`pytensor.function()` returns an object with a complicated `__call__` method that puts inputs and allocates outputs in list-like objects that are very much tuned to the C backend. This means that a generic function compiled to JAX or Numba will in general not work within a longer JAX / Numba workflow (e.g., calling `vmap` or `grad` on a compiled function).

We could provide a simpler `jax_function` and `numba_function` that can be used in such workflows. In PyMC we implemented something like that for JAX: https://github.com/pymc-devs/pymc/blob/31c30dc1beea26e4bff52a93037540923feaaa84/pymc/sampling/jax.py#L108-L132

There is one obvious limitation, which concerns the handling of shared variables and updates. Shared variables are global variables that are passed as inputs to the actual inner function but not provided explicitly by the user. Updates replace the original value of shared variables with a (user-hidden) output of the function every time it is called.
A simple JAX/Numba PyTensor function with global variables and updates looks like this:
And roughly translates to the following pseudo code:
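A hedged sketch of that translation in plain Python, for a function that returns `x + a` and updates `a` to `a + 1`. A dict stands in for the shared variable's storage, and all names are hypothetical:

```python
# Stand-in for the shared variable's storage (hypothetical).
state = {"a": 0.0}


def inner_fn(x, a):
    """Pure function: the shared value is an extra input, its update an extra output."""
    return x + a, a + 1.0


def fn(x):
    """Stateful wrapper: feeds in the hidden input and stores the hidden output."""
    out, state["a"] = inner_fn(x, state["a"])
    return out


first = fn(1.0)
second = fn(1.0)
```

Only `inner_fn` is something a backend would jit; the bookkeeping in `fn` is the part that gets in the way inside a larger JAX or Numba program.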
I don't think either JAX or Numba supports stateful jitted functions, so users would need to work with `inner_fn` directly. For Numba, see https://numba.pydata.org/numba-doc/dev/user/faq.html#numba-doesn-t-seem-to-care-when-i-modify-a-global-variable
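Working with `inner_fn` directly means the caller owns the state. A minimal plain-Python sketch, where `inner_fn` and `run` are hypothetical names and the shared value is passed in and returned explicitly:

```python
def inner_fn(x, a):
    # Hypothetical pure function: returns the output and the updated shared value.
    return x + a, a + 1.0


def run(xs, a0):
    """Call inner_fn once per element of xs, carrying the shared value by hand."""
    a = a0
    outs = []
    for x in xs:
        out, a = inner_fn(x, a)
        outs.append(out)
    return outs, a


outs, final_a = run([1.0, 1.0, 1.0], a0=0.0)
```

Because nothing is hidden in a global, a function of this shape is the kind that jitting and transformations such as `vmap` or `grad` expect.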
The proposal here is to give users easy access to the compiled (jitted or not) `inner_fn`.