From 37ef829bb8c4109f11b5aedc13866bc4db790d87 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 26 Nov 2024 14:07:15 +0100
Subject: [PATCH] Support consecutive vector indices in Numba backend

---
 pytensor/link/numba/dispatch/subtensor.py | 149 +++++++++++++++++-
 pytensor/tensor/subtensor.py              |  25 ++++
 tests/link/numba/test_basic.py            |  17 ++-
 tests/link/numba/test_subtensor.py        | 174 +++++++++++++++-------
 4 files changed, 298 insertions(+), 67 deletions(-)

diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index c784321bf3..e821fd7da7 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -5,6 +5,7 @@
 from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
 from pytensor.link.utils import compile_function_src, unique_name_generator
 from pytensor.tensor import TensorType
+from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -13,6 +14,7 @@
     IncSubtensor,
     Subtensor,
 )
+from pytensor.tensor.type_other import NoneTypeT, SliceType
 
 
 @numba_funcify.register(Subtensor)
@@ -104,18 +106,72 @@ def {function_name}({", ".join(input_names)}):
 @numba_funcify.register(AdvancedSubtensor)
 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
-    idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
-    adv_idxs_dims = [
-        idx.type.ndim
+    if isinstance(op, AdvancedSubtensor):
+        x, y, idxs = node.inputs[0], None, node.inputs[1:]
+    else:
+        x, y, *idxs = node.inputs
+
+    basic_idxs = [
+        idx
         for idx in idxs
-        if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
+        if (
+            isinstance(idx.type, NoneTypeT)
+            or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
+        )
     ]
+    adv_idxs = [
+        {
+            "axis": i,
+            "dtype": idx.type.dtype,
+            "bcast": idx.type.broadcastable,
+            "ndim": idx.type.ndim,
+        }
+        for i, idx in enumerate(idxs)
+        if isinstance(idx.type, TensorType)
+    ]
+
+    # Special case for consecutive vector indices
+    def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
+        # Check whether x would be broadcasted to the target, based on broadcastable info
+        if len(x_bcast) < len(to_bcast):
+            return True
+        for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
+            if x_bcast_dim and not to_bcast_dim:
+                return True
+        return False
+
+    if (
+        not basic_idxs
+        and len(adv_idxs) >= 2
+        # Must be integer vectors
+        # TODO: we could allow shape=(1,) if this is the shape of x
+        and all(
+            (adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
+            for adv_idx in adv_idxs
+        )
+        # Must be consecutive
+        and not op.non_contiguous_adv_indexing(node)
+        # y in set/inc_subtensor cannot be broadcasted
+        and (
+            y is None
+            or not broadcasted_to(
+                y.type.broadcastable,
+                (
+                    x.type.broadcastable[: adv_idxs[0]["axis"]]
+                    + x.type.broadcastable[adv_idxs[-1]["axis"] :]
+                ),
+            )
+        )
+    ):
+        return numba_funcify_multiple_vector_indexing(op, node, **kwargs)
 
+    # Cases natively supported by Numba
     if (
         # Numba does not support indexes with more than one dimension
+        any(idx["ndim"] > 1 for idx in adv_idxs)
         # Nor multiple vector indexes
-        (len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
-        # The default index implementation does not handle duplicate indices correctly
+        or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
+        # The default PyTensor implementation does not handle duplicate indices correctly
        or (
            isinstance(op, AdvancedIncSubtensor)
            and not op.set_instead_of_inc
@@ -127,6 +183,87 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     return numba_funcify_default_subtensor(op, node, **kwargs)
 
 
+def numba_funcify_multiple_vector_indexing(
+    op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
+):
+    # Special-case implementation for multiple consecutive vector indices (and set/inc_subtensor)
+    if isinstance(op, AdvancedSubtensor):
+        y, idxs = None, node.inputs[1:]
+    else:
+        y, *idxs = node.inputs[1:]
+
+    first_axis = next(
+        i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
+    )
+    try:
+        after_last_axis = next(
+            i
+            for i, idx in enumerate(idxs[first_axis:], start=first_axis)
+            if not isinstance(idx.type, TensorType)
+        )
+    except StopIteration:
+        after_last_axis = len(idxs)
+
+    if isinstance(op, AdvancedSubtensor):
+
+        @numba_njit
+        def advanced_subtensor_multiple_vector(x, *idxs):
+            none_slices = idxs[:first_axis]
+            vec_idxs = idxs[first_axis:after_last_axis]
+
+            x_shape = x.shape
+            idx_shape = vec_idxs[0].shape
+            shape_bef = x_shape[:first_axis]
+            shape_aft = x_shape[after_last_axis:]
+            out_shape = (*shape_bef, *idx_shape, *shape_aft)
+            out_buffer = np.empty(out_shape, dtype=x.dtype)
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
+            return out_buffer
+
+        return advanced_subtensor_multiple_vector
+
+    elif op.set_instead_of_inc:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_set_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] = y[(*outer, i)]
+            return out
+
+        return advanced_set_subtensor_multiple_vector
+
+    else:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] += y[(*outer, i)]
+            return out
+
+        return advanced_inc_subtensor_multiple_vector
+
+
 @numba_funcify.register(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index f37641b557..109c40ee6e 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -2937,6 +2937,31 @@ def grad(self, inpt, output_gradients):
         gy = _sum_grad_over_bcasted_dims(y, gy)
         return [gx, gy] + [DisconnectedType()() for _ in idxs]
 
+    @staticmethod
+    def non_contiguous_adv_indexing(node: Apply) -> bool:
+        """
+        Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
+
+        When the advanced indexing is non-contiguous, the advanced index
+        dimensions are placed on the left of the output array, regardless of
+        their original position.
+
+        See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
+
+
+        Parameters
+        ----------
+        node : Apply
+            The node of the AdvancedSubtensor operation.
+
+        Returns
+        -------
+        bool
+            True if the advanced indexing is non-contiguous, False otherwise.
+        """
+        _, _, *idxs = node.inputs
+        return _non_contiguous_adv_indexing(idxs)
+
 
 advanced_inc_subtensor = AdvancedIncSubtensor()
 advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index dfadc58a69..dd4c5b4967 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -228,9 +228,11 @@ def compare_numba_and_py(
     fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]],
     inputs: Sequence["TensorLike"],
     assert_fn: Callable | None = None,
+    *,
     numba_mode=numba_mode,
     py_mode=py_mode,
     updates=None,
+    inplace: bool = False,
     eval_obj_mode: bool = True,
 ) -> tuple[Callable, Any]:
     """Function to compare python graph output and Numba compiled output for testing equality
@@ -276,7 +278,14 @@ def assert_fn(x, y):
     pytensor_py_fn = function(
         fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
     )
-    py_res = pytensor_py_fn(*inputs)
+
+    test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
+    py_res = pytensor_py_fn(*test_inputs)
+
+    # Get some coverage (and catch errors in Python mode before unreadable Numba ones)
+    if eval_obj_mode:
+        test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
+        eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode)
 
     pytensor_numba_fn = function(
         fn_inputs,
@@ -285,11 +294,9 @@ def assert_fn(x, y):
         accept_inplace=True,
         updates=updates,
     )
-    numba_res = pytensor_numba_fn(*inputs)
 
-    # Get some coverage
-    if eval_obj_mode:
-        eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
+    test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
+    numba_res = pytensor_numba_fn(*test_inputs)
 
     if len(fn_outputs) > 1:
         for j, p in zip(numba_res, py_res, strict=True):
diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py
index ff335e30dc..ea3095408b 100644
--- a/tests/link/numba/test_subtensor.py
+++ b/tests/link/numba/test_subtensor.py
@@ -85,7 +85,11 @@ def test_AdvancedSubtensor1_out_of_bounds():
             (np.array([True, False, False])),
             False,
         ),
-        (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
+        (
+            pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([1, 2], [2, 3]),
+            False,
+        ),
         # Single multidimensional indexing (supported after specialization rewrites)
         (
             as_tensor(np.arange(3 * 3).reshape((3, 3))),
@@ -117,17 +121,23 @@ def test_AdvancedSubtensor1_out_of_bounds():
             (slice(2, None), np.eye(3).astype(bool)),
             False,
         ),
-        # Multiple advanced indexing, only supported in obj mode
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             (slice(None), [1, 2], [3, 4]),
-            True,
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 5 * 7).reshape((3, 5, 7))),
+            ([1, 2], [3, 4], [5, 6]),
+            False,
         ),
+        # Non-contiguous vector indexing, only supported in obj mode
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             ([1, 2], slice(None), [3, 4]),
             True,
         ),
+        # >1d vector indexing, only supported in obj mode
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             ([[1, 2], [2, 1]], [0, 0]),
@@ -135,7 +145,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
         ),
     ],
 )
-@pytest.mark.filterwarnings("error")
+@pytest.mark.filterwarnings("error")  # Raise if we did not expect objmode to be needed
 def test_AdvancedSubtensor(x, indices, objmode_needed):
     """Test NumPy's advanced indexing in more than one dimension."""
     x_pt = x.type()
@@ -268,94 +278,151 @@ def test_AdvancedIncSubtensor1(x, y, indices):
     "x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode",
     [
         (
-            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             -np.arange(3 * 5).reshape(3, 5),
-            (slice(None, None, 2), [1, 2, 3]),
+            (slice(None, None, 2), [1, 2, 3]),  # Mixed basic and vector index
             False,
             False,
             False,
         ),
         (
-            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-            -99,
-            (slice(None, None, 2), [1, 2, 3], -1),
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            np.array(-99),  # Broadcasted value
+            (
+                slice(None, None, 2),
+                [1, 2, 3],
+                -1,
+            ),  # Mixed basic and broadcasted vector index
             False,
             False,
             False,
         ),
         (
-            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-            -99,  # Broadcasted value
-            (slice(None, None, 2), [1, 2, 3]),
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            np.array(-99),  # Broadcasted value
+            (slice(None, None, 2), [1, 2, 3]),  # Mixed basic and vector index
             False,
             False,
             False,
         ),
         (
-            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             -np.arange(4 * 5).reshape(4, 5),
-            (0, [1, 2, 2, 3]),
+            (0, [1, 2, 2, 3]),  # Broadcasted vector index
             True,
             False,
             True,
         ),
         (
-            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-            [-99],  # Broadcsasted value
-            (0, [1, 2, 2, 3]),
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            np.array([-99]),  # Broadcasted value
+            (0, [1, 2, 2, 3]),  # Broadcasted vector index
             True,
             False,
             True,
         ),
         (
-            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             -np.arange(1 * 4 * 5).reshape(1, 4, 5),
-            (np.array([True, False, False])),
+            (np.array([True, False, False])),  # Broadcasted boolean index
             False,
             False,
             False,
         ),
         (
-            as_tensor(np.arange(3 * 3).reshape((3, 3))),
+            np.arange(3 * 3).reshape((3, 3)),
             -np.arange(3),
-            (np.eye(3).astype(bool)),
+            (np.eye(3).astype(bool)),  # Boolean index
             False,
             True,
             True,
         ),
         (
-            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-            as_tensor(rng.poisson(size=(2, 5))),
-            ([1, 2], [2, 3]),
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            rng.poisson(size=(2, 5)),
+            ([1, 2], [2, 3]),  # 2 vector indices
+            False,
+            False,
+            False,
+        ),
+        (
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            rng.poisson(size=(3, 2)),
+            (slice(None), [1, 2], [2, 3]),  # 2 vector indices
+            False,
+            False,
+            False,
+        ),
+        (
+            np.arange(3 * 4 * 6).reshape((3, 4, 6)),
+            rng.poisson(size=(2,)),
+            ([1, 2], [2, 3], [4, 5]),  # 3 vector indices
+            False,
+            False,
+            False,
+        ),
+        (
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            np.array(-99),  # Broadcasted value
+            ([1, 2], [2, 3]),  # 2 vector indices
             False,
             True,
             True,
         ),
         (
-            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-            as_tensor(rng.poisson(size=(2, 4))),
-            ([1, 2], slice(None), [3, 4]),
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            rng.poisson(size=(2, 4)),
+            ([1, 2], slice(None), [3, 4]),  # Non-contiguous vector indices
             False,
             True,
             True,
         ),
-        pytest.param(
-            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-            as_tensor(rng.poisson(size=(2, 5))),
-            ([1, 1], [2, 2]),
+        (
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            rng.poisson(size=(2, 2)),
+            (
+                slice(1, None),
+                [1, 2],
+                [3, 4],
+            ),  # Mixed double vector index and basic index
+            False,
+            True,
+            True,
+        ),
+        (
+            np.arange(5),
+            rng.poisson(size=(2, 2)),
+            ([[1, 2], [2, 3]]),  # Matrix index
             False,
             True,
             True,
         ),
+        pytest.param(
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            rng.poisson(size=(2, 5)),
+            ([1, 1], [2, 2]),  # Repeated indices
+            True,
+            False,
+            False,
+        ),
     ],
 )
-@pytest.mark.filterwarnings("error")
+@pytest.mark.parametrize("inplace", (False, True))
+@pytest.mark.filterwarnings("error")  # Raise if we did not expect objmode to be needed
 def test_AdvancedIncSubtensor(
-    x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode
+    x,
+    y,
+    indices,
+    duplicate_indices,
+    set_requires_objmode,
+    inc_requires_objmode,
+    inplace,
 ):
-    out_pt = set_subtensor(x[indices], y)
+    x_pt = pt.as_tensor(x).type("x")
+    y_pt = pt.as_tensor(y).type("y")
+
+    out_pt = set_subtensor(x_pt[indices], y_pt, inplace=inplace)
     assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
-    out_fg = FunctionGraph([], [out_pt])
     with (
         pytest.warns(
@@ -365,11 +432,18 @@ def test_AdvancedIncSubtensor(
         if set_requires_objmode
         else contextlib.nullcontext()
     ):
-        compare_numba_and_py(out_fg, [])
+        fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], inplace=inplace)
+
+    if inplace:
+        # Test that x is updated in place
+        x_orig = x.copy()
+        fn(x, y + 1)
+        assert not np.all(x == x_orig)
 
-    out_pt = inc_subtensor(x[indices], y, ignore_duplicates=not duplicate_indices)
+    out_pt = inc_subtensor(
+        x_pt[indices], y_pt, ignore_duplicates=not duplicate_indices, inplace=inplace
+    )
     assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
-    out_fg = FunctionGraph([], [out_pt])
     with (
         pytest.warns(
             UserWarning,
@@ -378,21 +452,9 @@ def test_AdvancedIncSubtensor(
         if inc_requires_objmode
         else contextlib.nullcontext()
     ):
-        compare_numba_and_py(out_fg, [])
-
-    x_pt = x.type()
-    out_pt = set_subtensor(x_pt[indices], y)
-    # Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just
-    # hack it on here
-    out_pt.owner.op.inplace = True
-    assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    with (
-        pytest.warns(
-            UserWarning,
-            match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
-        )
-        if set_requires_objmode
-        else contextlib.nullcontext()
-    ):
-        compare_numba_and_py(out_fg, [x.data])
+        fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], inplace=inplace)
+    if inplace:
+        # Test that x is updated in place
+        x_orig = x.copy()
+        fn(x, y)
+        assert not np.all(x == x_orig)
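
Usage sketch (not part of the patch): the snippet below illustrates the indexing pattern this change moves off Numba's object-mode fallback. It is a minimal example assuming a PyTensor build with this patch applied and Numba installed; the variable names are illustrative, not taken from the patch.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    # Two consecutive integer vector indices (axes 1 and 2), preceded only by a
    # full slice: with this patch the Numba backend compiles this natively
    # instead of falling back to object mode.
    x = pt.tensor3("x")
    rows = pt.lvector("rows")
    cols = pt.lvector("cols")

    fn = pytensor.function([x, rows, cols], x[:, rows, cols], mode="NUMBA")
    x_val = np.arange(3 * 4 * 5, dtype="float64").reshape((3, 4, 5))
    print(fn(x_val, np.array([1, 2]), np.array([2, 3])).shape)  # (3, 2)

    # set/inc_subtensor with the same pattern also stays compiled, as long as
    # the value does not need to be broadcast against the indexed region
    # (this is what the broadcasted_to check in the dispatcher guards).
    out = pt.inc_subtensor(x[:, rows, cols], pt.ones((3, 2)))
    fn_inc = pytensor.function([x, rows, cols], out, mode="NUMBA")
    fn_inc(x_val, np.array([1, 2]), np.array([2, 3]))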