From 0e03119cd334937a0a08feaf811d84d8b52e43bc Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 28 Dec 2024 16:50:13 -0500 Subject: [PATCH] Add jax dispatch for `searchsorted` --- pytensor/link/jax/dispatch/extra_ops.py | 11 +++++++++++ tests/link/jax/test_extra_ops.py | 8 ++++++++ tests/tensor/test_interpolate.py | 18 ++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/pytensor/link/jax/dispatch/extra_ops.py b/pytensor/link/jax/dispatch/extra_ops.py index a9e36667ef..87e55f1007 100644 --- a/pytensor/link/jax/dispatch/extra_ops.py +++ b/pytensor/link/jax/dispatch/extra_ops.py @@ -10,6 +10,7 @@ FillDiagonalOffset, RavelMultiIndex, Repeat, + SearchsortedOp, Unique, UnravelIndex, ) @@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs): # return filldiagonaloffset raise NotImplementedError("flatiter not implemented in JAX") + + +@jax_funcify.register(SearchsortedOp) +def jax_funcify_SearchsortedOp(op, **kwargs): + side = op.side + + def searchsorted(a, v, side=side, sorter=None): + return jnp.searchsorted(a=a, v=v, side=side, sorter=sorter) + + return searchsorted diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index 1427413379..0c8fb92810 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -6,6 +6,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value from pytensor.tensor import extra_ops as pt_extra_ops +from pytensor.tensor.sort import argsort from pytensor.tensor.type import matrix, tensor from tests.link.jax.test_basic import compare_jax_and_py @@ -55,6 +56,13 @@ def test_extra_ops(): fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False ) + v = ptb.as_tensor_variable(6.0) + sorted_idx = argsort(a.ravel()) + + out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [a_test]) + @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") def test_bartlett_dynamic_shape(): diff --git a/tests/tensor/test_interpolate.py b/tests/tensor/test_interpolate.py index 95ebae10e2..b98e0ce371 100644 --- a/tests/tensor/test_interpolate.py +++ b/tests/tensor/test_interpolate.py @@ -8,6 +8,7 @@ InterpolationMethod, interp, interpolate1d, + polynomial_interpolate1d, valid_methods, ) @@ -105,3 +106,20 @@ def test_interpolate_scalar_extrapolate(method: InterpolationMethod): # and last should take the right. interior_point = x[3] + 0.1 assert f(interior_point) == (y[4] if method == "last" else y[3]) + + +def test_polynomial_interpolate1d(): + x = np.linspace(-2, 6, 10) + y = np.sin(x) + + f_op = polynomial_interpolate1d(x, y) + x_hat_pt = pt.dvector("x_hat") + degree = pt.iscalar("degree") + + f = pytensor.function( + [x_hat_pt, degree], f_op(x_hat_pt, degree, True), mode="FAST_RUN" + ) + x_grid = np.linspace(-2, 6, 100) + y_hat = f(x_grid, 0) + + assert_allclose(y_hat, np.mean(y))