Skip to content

Commit

Permalink
Add jax dispatch for searchsorted
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Dec 28, 2024
1 parent 6d3a2a4 commit 0e03119
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pytensor/link/jax/dispatch/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
FillDiagonalOffset,
RavelMultiIndex,
Repeat,
SearchsortedOp,
Unique,
UnravelIndex,
)
Expand Down Expand Up @@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs):
# return filldiagonaloffset

raise NotImplementedError("flatiter not implemented in JAX")


@jax_funcify.register(SearchsortedOp)
def jax_funcify_SearchsortedOp(op, **kwargs):
side = op.side

def searchsorted(a, v, side=side, sorter=None):
return jnp.searchsorted(a=a, v=v, side=side, sorter=sorter)

return searchsorted
8 changes: 8 additions & 0 deletions tests/link/jax/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.sort import argsort
from pytensor.tensor.type import matrix, tensor
from tests.link.jax.test_basic import compare_jax_and_py

Expand Down Expand Up @@ -55,6 +56,13 @@ def test_extra_ops():
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)

v = ptb.as_tensor_variable(6.0)
sorted_idx = argsort(a.ravel())

out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])


@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_bartlett_dynamic_shape():
Expand Down
18 changes: 18 additions & 0 deletions tests/tensor/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
InterpolationMethod,
interp,
interpolate1d,
polynomial_interpolate1d,
valid_methods,
)

Expand Down Expand Up @@ -105,3 +106,20 @@ def test_interpolate_scalar_extrapolate(method: InterpolationMethod):
# and last should take the right.
interior_point = x[3] + 0.1
assert f(interior_point) == (y[4] if method == "last" else y[3])


def test_polynomial_interpolate1d():
x = np.linspace(-2, 6, 10)
y = np.sin(x)

f_op = polynomial_interpolate1d(x, y)
x_hat_pt = pt.dvector("x_hat")
degree = pt.iscalar("degree")

f = pytensor.function(
[x_hat_pt, degree], f_op(x_hat_pt, degree, True), mode="FAST_RUN"
)
x_grid = np.linspace(-2, 6, 100)
y_hat = f(x_grid, 0)

assert_allclose(y_hat, np.mean(y))

0 comments on commit 0e03119

Please sign in to comment.