Provide lower level Numba and Jax functions #222

Closed
ricardoV94 opened this issue Feb 17, 2023 · 3 comments

Comments

@ricardoV94 (Member) commented Feb 17, 2023

Description

pytensor.function() returns a Function object with a complicated __call__ method. That method places inputs into list-like storage containers and allocates outputs in them, a convention tuned to the C backend. As a result, a function compiled to JAX or Numba will generally not work inside a larger JAX/Numba workflow, for example when calling vmap or grad on the compiled function.
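The sketch below makes the contrast between the two calling conventions concrete. All names in it (`pure_fn`, `add_thunk`, `call_like_function`) are hypothetical. PyTensor's real Function and VM machinery is considerably more involved, so treat this as a simplified illustration under that assumption.

- **`add_thunk`** mimics a backend that reads inputs from one-element lists ("storage cells"), writes its result into an output cell, and returns nothing.
- **`pure_fn`** is a plain function that takes arrays and returns arrays.

```python
import numpy as np


def pure_fn(x, y):
    """Plain function: arrays in, arrays out (what JAX/Numba transformations expect)."""
    return x + y


def add_thunk(input_cells, output_cells):
    """Hypothetical C-backend-style thunk: reads inputs from one-element lists
    ("storage cells") and writes its result into an output cell. Returns None."""
    x_cell, y_cell = input_cells
    output_cells[0][0] = x_cell[0] + y_cell[0]


def call_like_function(x, y):
    """Hypothetical __call__-style wrapper around the storage-cell convention."""
    input_cells = [[x], [y]]   # put each input into its own cell
    output_cells = [[None]]    # pre-allocate a cell for the single output
    add_thunk(input_cells, output_cells)
    return output_cells[0][0]  # read the result back out of its cell


x = np.arange(3.0)
y = np.ones(3)
print(pure_fn(x, y))
print(call_like_function(x, y))
```

Both calls produce the same values. Only `pure_fn` returns its result directly, though, and that is the kind of signature a transformation such as `jax.vmap` or `jax.grad` can work with.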

We could provide simpler jax_function and numba_function helpers that just return the compiled function, without the storage-container wrapper. In PyMC we implemented something like this for JAX: https://github.com/pymc-devs/pymc/blob/31c30dc1beea26e4bff52a93037540923feaaa84/pymc/sampling/jax.py#L108-L132

There is one obvious limitation, which concerns the handling of shared variables and updates. Shared variables are global variables that are passed as inputs to the actual inner function but are not provided explicitly by the user. Updates replace the value of a shared variable with a (user-hidden) output of the function every time it is called.
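The following is a minimal NumPy-only sketch of how shared variables and updates can be threaded through a pure inner function. `SharedValue` and `make_function` are hypothetical helpers written for illustration, not PyTensor API. The wrapper reads each shared value, passes it as an extra argument, and then writes the extra outputs back as updates.

```python
import numpy as np


class SharedValue:
    """Hypothetical stand-in for a shared variable: a mutable value container."""

    def __init__(self, value):
        self.value = np.asarray(value)


def make_function(inner_fn, shared, updates):
    """Wrap a pure `inner_fn(*user_inputs, *shared_values)` that returns
    `(output, *new_shared_values)`.

    `shared` lists the SharedValue objects passed implicitly as extra inputs;
    `updates` lists the SharedValue objects overwritten with the extra outputs.
    """

    def fn(*user_inputs):
        shared_values = [s.value for s in shared]
        output, *new_values = inner_fn(*user_inputs, *shared_values)
        # Apply the (user-hidden) updates after every call
        for target, new_value in zip(updates, new_values):
            target.value = new_value
        return output

    return fn


def inner_fn(x, y):
    # Pure: no globals are read or written; state is threaded through arguments
    return x + y, y + 1


shared_y = SharedValue(np.ones(5))
fn = make_function(inner_fn, shared=[shared_y], updates=[shared_y])

x = np.zeros(5)
first = fn(x)   # computed with y == 1, then y becomes 2
second = fn(x)  # computed with y == 2, then y becomes 3
```

All the state lives in the wrapper. `inner_fn` itself is stateless, which is what a jit compiler needs.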

A simple JAX/Numba PyTensor function with global variables and updates looks like this:

import pytensor
import pytensor.tensor as pt
import numpy as np

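# A shared variable: global state passed implicitly to the compiled function
shared_y = pytensor.shared(np.ones((5,)))
x = pt.vector("x")
# Each call returns x + shared_y, then replaces shared_y's value with shared_y + 1
fn = pytensor.function([x], x + shared_y, updates={shared_y: shared_y + 1}, mode="JAX")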

This roughly translates to the following sketch:

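import jax
import numpy as np

# The shared variable lives outside the jitted code as ordinary global state
shared_y = np.ones((5,))

@jax.jit
def inner_fn(x, y):
    # Pure: the shared value comes in as an argument, its update goes out as an output
    return x + y, y + 1

def fn(x):
    # Stateful wrapper: read the global, call the pure inner function, write the update back
    global shared_y
    out, update_y = inner_fn(x, shared_y)
    shared_y[:] = update_y
    return out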

I don't think either JAX or Numba supports stateful jitted functions, so users would need to work with inner_fn directly.

https://numba.pydata.org/numba-doc/dev/user/faq.html#numba-doesn-t-seem-to-care-when-i-modify-a-global-variable
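The linked Numba FAQ explains that global variables are frozen as constants at compile time. JAX's `jit` behaves analogously, as described in JAX's "Sharp Bits" notes on pure functions: a global is read once while tracing, and later calls with the same input shapes and dtypes reuse the cached compilation. The minimal JAX sketch below shows this. `offset` and `add_offset` are illustrative names.

```python
import jax
import jax.numpy as jnp

offset = 0.0  # module-level "state"


@jax.jit
def add_offset(x):
    # `offset` is read while tracing and baked into the compiled function
    return x + offset


first = add_offset(jnp.float32(1.0))   # traced and compiled with offset == 0.0
offset = 10.0                          # rebinding the global...
second = add_offset(jnp.float32(1.0))  # ...is not seen: same shape/dtype hits the cache
```

This is why state such as shared variables has to be passed in and returned explicitly, rather than read from globals inside the jitted function.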

The proposal here is to give users easy access to the compiled (jitted or not) inner_fn.
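The minimal sketch below, written directly in JAX, shows why access to the pure inner function is useful. It does not use any PyTensor API; `inner_fn` is a hand-written equivalent of the example above. Because the shared value and its update are an explicit argument and an explicit output, the function composes with `jax.vmap` and `jax.grad`, and the caller decides what to do with the returned update.

```python
import jax
import jax.numpy as jnp


@jax.jit
def inner_fn(x, y):
    # Pure: the shared value `y` is an explicit input, its update an explicit output
    return x + y, y + 1


y = jnp.ones(5)
xs = jnp.arange(15.0).reshape(3, 5)

# vmap over a batch of `x` while passing the same (unmapped) `y` to every call;
# the unbatched update output is broadcast along the batch axis
outs, new_ys = jax.vmap(inner_fn, in_axes=(0, None))(xs, y)

# Gradient of a scalar loss built from the first output, with respect to `x`
grad_x = jax.grad(lambda x: inner_fn(x, y)[0].sum())(jnp.zeros(5))
```

Neither transformation would work on the stateful wrapper, since its result depends on hidden global state that it also mutates.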

@ammar-s847

Hey, is this still available to contribute to? Would love to get started!

@twiecki (Member) commented May 6, 2023

Yes!

@ricardoV94 (Member, Author)

I think with #1101 we don't need this anymore. A Function with trust_input=True now just calls the jitted function and sets the updates, which was the goal. We can add a mention in the docs that fn.vm.jit_fn gives direct access to the jitted function.
