From 3cdcfde4155b5aa8d4ac29fd7f07fe21f776a712 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 18 Feb 2025 17:49:46 +0100
Subject: [PATCH] Fix Blockwise and RandomVariable in Numba with repeated arguments

---
 pytensor/tensor/blockwise.py       |  7 +++++++
 tests/link/numba/test_blockwise.py | 16 ++++++++++++++--
 tests/link/numba/test_random.py    | 12 ++++++++++++
 3 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index b3366f21af..be5e048c77 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -443,6 +443,13 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
 class OpWithCoreShape(OpFromGraph):
     """Generalizes an `Op` to include core shape as an additional input."""
 
+    def __init__(self, *args, on_unused_input="ignore", **kwargs):
+        # We set on_unused_input="ignore" so that we can easily wrap nodes with repeated inputs.
+        # In this case the subsequent appearances of repeated inputs get disconnected in the inner graph.
+        # I can't think of a scenario where this will backfire, but if there's one,
+        # I bet on inplacing operations (time will tell).
+        return super().__init__(*args, on_unused_input=on_unused_input, **kwargs)
+
 
 class BlockwiseWithCoreShape(OpWithCoreShape):
     """Generalizes a Blockwise `Op` to include a core shape parameter."""
diff --git a/tests/link/numba/test_blockwise.py b/tests/link/numba/test_blockwise.py
index 43056f9f56..702efe6ed9 100644
--- a/tests/link/numba/test_blockwise.py
+++ b/tests/link/numba/test_blockwise.py
@@ -2,9 +2,9 @@
 import pytest
 
 from pytensor import function
-from pytensor.tensor import tensor
+from pytensor.tensor import tensor, tensor3
 from pytensor.tensor.basic import ARange
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
 from pytensor.tensor.nlinalg import SVD, Det
 from pytensor.tensor.slinalg import Cholesky, cholesky
 from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
@@ -58,3 +58,15 @@ def test_blockwise_benchmark(benchmark):
     x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
     fn(x_test)  # JIT compile
     benchmark(fn, x_test)
+
+
+def test_repeated_args():
+    x = tensor3("x")
+    x_test = np.full((1, 1, 1), 2.0, dtype=x.type.dtype)
+    out = x @ x
+    fn, _ = compare_numba_and_py([x], [out], [x_test], eval_obj_mode=False)
+
+    # Confirm we are testing a Blockwise with repeated inputs
+    final_node = fn.maker.fgraph.outputs[0].owner
+    assert isinstance(final_node.op, BlockwiseWithCoreShape)
+    assert final_node.inputs[0] is final_node.inputs[1]
diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py
index 9443775a39..d2301a54cb 100644
--- a/tests/link/numba/test_random.py
+++ b/tests/link/numba/test_random.py
@@ -10,6 +10,7 @@
 from pytensor import shared
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
+from pytensor.tensor.random.op import RandomVariableWithCoreShape
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
     numba_mode,
@@ -693,3 +694,14 @@ def test_rv_inside_ofg():
 def test_unnatural_batched_dims(batch_dims_tester):
     """Tests for RVs that don't have natural batch dims in Numba API."""
     batch_dims_tester(mode="NUMBA")
+
+
+def test_repeated_args():
+    v = pt.scalar()
+    x = ptr.beta(v, v)
+    fn, _ = compare_numba_and_py([v], [x], [0.5 * 1e6], eval_obj_mode=False)
+
+    # Confirm we are testing a RandomVariable with repeated inputs
+    final_node = fn.maker.fgraph.outputs[0].owner
+    assert isinstance(final_node.op, RandomVariableWithCoreShape)
+    assert final_node.inputs[-2] is final_node.inputs[-1]
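
Note: a minimal end-to-end sketch of the scenario this patch fixes, for reviewers who
want to try it locally. It is illustrative only: it assumes a PyTensor checkout with
this patch applied and Numba installed, and the variable names are hypothetical.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor3("x")
    # Batched matmul of a tensor with itself lowers to a Blockwise node
    # that receives the same variable twice as input.
    out = x @ x
    fn = pytensor.function([x], out, mode="NUMBA")

    x_test = np.full((1, 2, 2), 2.0, dtype=x.type.dtype)
    np.testing.assert_allclose(fn(x_test), x_test @ x_test)

Per the new comment in OpWithCoreShape, the wrapping OpFromGraph now ignores unused
inputs, so the second appearance of `x` can simply be disconnected in the inner graph
rather than tripping up compilation as it did before this patch.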