Support consecutive vector indices in Numba backend
ricardoV94 committed Nov 27, 2024
1 parent ae66e82 commit 37ef829
Showing 4 changed files with 298 additions and 67 deletions.
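
For context, this is the indexing pattern the commit targets. The snippet below is a minimal NumPy sketch, not part of the diff, and the array names are illustrative: when integer vector indices are applied to consecutive axes, NumPy selects element-wise pairs and keeps the resulting advanced dimension in place of the indexed axes.

import numpy as np

x = np.arange(12).reshape(3, 4)
rows = np.array([0, 2, 1])
cols = np.array([1, 3, 0])

# Two vector indices on consecutive axes select element-wise pairs:
# out[i] == x[rows[i], cols[i]]
out = x[rows, cols]
print(out)        # [ 1 11  4]
print(out.shape)  # (3,) -- the advanced dimension replaces the two indexed axes
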
149 changes: 143 additions & 6 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -5,6 +5,7 @@
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
@@ -13,6 +14,7 @@
IncSubtensor,
Subtensor,
)
from pytensor.tensor.type_other import NoneTypeT, SliceType


@numba_funcify.register(Subtensor)
@@ -104,18 +106,72 @@ def {function_name}({", ".join(input_names)}):
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
adv_idxs_dims = [
idx.type.ndim
if isinstance(op, AdvancedSubtensor):
x, y, idxs = node.inputs[0], None, node.inputs[1:]
else:
x, y, *idxs = node.inputs

basic_idxs = [
idx
for idx in idxs
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
if (
isinstance(idx.type, NoneTypeT)
or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
)
]
adv_idxs = [
{
"axis": i,
"dtype": idx.type.dtype,
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
for i, idx in enumerate(idxs)
if isinstance(idx.type, TensorType)
]

# Special case for consecutive vector indices
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Return True if x would need to be broadcast to match `to_bcast`, based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True
return False

if (
not basic_idxs
and len(adv_idxs) >= 2
# Must be integer vectors
# Todo: we could allow shape=(1,) if this is the shape of x
and all(
(adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
for adv_idx in adv_idxs
)
# Must be consecutive
and not op.non_contiguous_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted
and (
y is None
or not broadcasted_to(
y.type.broadcastable,
(
x.type.broadcastable[: adv_idxs[0]["axis"]]
+ x.type.broadcastable[adv_idxs[-1]["axis"] :]
),
)
)
):
return numba_funcify_multiple_vector_indexing(op, node, **kwargs)

# Cases natively supported by Numba
if (
# Numba does not support indexes with more than one dimension
any(idx["ndim"] > 1 for idx in adv_idxs)
# Nor multiple vector indexes
(len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
# The default index implementation does not handle duplicate indices correctly
or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
# The default PyTensor implementation does not handle duplicate indices correctly
or (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
@@ -127,6 +183,87 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
return numba_funcify_default_subtensor(op, node, **kwargs)


def numba_funcify_multiple_vector_indexing(
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
):
# Special-case implementation for multiple consecutive vector indices (and set/incsubtensor)
if isinstance(op, AdvancedSubtensor):
y, idxs = None, node.inputs[1:]
else:
y, *idxs = node.inputs[1:]

first_axis = next(
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
)
try:
after_last_axis = next(
i
for i, idx in enumerate(idxs[first_axis:], start=first_axis)
if not isinstance(idx.type, TensorType)
)
except StopIteration:
after_last_axis = len(idxs)

if isinstance(op, AdvancedSubtensor):

@numba_njit
def advanced_subtensor_multiple_vector(x, *idxs):
none_slices = idxs[:first_axis]
vec_idxs = idxs[first_axis:after_last_axis]

x_shape = x.shape
idx_shape = vec_idxs[0].shape
shape_bef = x_shape[:first_axis]
shape_aft = x_shape[after_last_axis:]
out_shape = (*shape_bef, *idx_shape, *shape_aft)
out_buffer = np.empty(out_shape, dtype=x.dtype)
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
return out_buffer

return advanced_subtensor_multiple_vector

elif op.set_instead_of_inc:
inplace = op.inplace

@numba_njit
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape

if inplace:
out = x
else:
out = x.copy()

for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
return out

return advanced_set_subtensor_multiple_vector

else:
inplace = op.inplace

@numba_njit
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape

if inplace:
out = x
else:
out = x.copy()

for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
return out

return advanced_inc_subtensor_multiple_vector


@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
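
The zip-based loop in the advanced_subtensor_multiple_vector kernel added above can be read as the following plain-NumPy reference. This is a sketch outside the commit, and the name subtensor_multiple_vector_reference is hypothetical: for leading full slices followed by consecutive integer vectors, it reproduces NumPy's fancy-indexing result.

import numpy as np

def subtensor_multiple_vector_reference(x, first_axis, vec_idxs):
    # Leading axes that are taken in full, mirroring the none_slices tuple
    # used in the Numba kernel.
    none_slices = (slice(None),) * first_axis
    after_last_axis = first_axis + len(vec_idxs)
    out_shape = (
        *x.shape[:first_axis],
        *vec_idxs[0].shape,
        *x.shape[after_last_axis:],
    )
    out = np.empty(out_shape, dtype=x.dtype)
    # One tuple of scalar indices per position along the advanced dimension.
    for i, scalar_idxs in enumerate(zip(*vec_idxs)):
        out[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
    return out

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
vec0 = np.array([0, 2, 1])
vec1 = np.array([3, 0, 1])
np.testing.assert_array_equal(
    subtensor_multiple_vector_reference(x, 1, (vec0, vec1)),
    x[:, vec0, vec1],
)
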
25 changes: 25 additions & 0 deletions pytensor/tensor/subtensor.py
@@ -2937,6 +2937,31 @@ def grad(self, inpt, output_gradients):
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + [DisconnectedType()() for _ in idxs]

@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
This function checks if the advanced indexing is non-contiguous,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their opriginal position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
Parameters
----------
node : Apply
The node of the AdvancedSubtensor operation.
Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
"""
_, _, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs)


advanced_inc_subtensor = AdvancedIncSubtensor()
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
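
The NumPy rule that non_contiguous_adv_indexing documents can be seen with a short sketch (not part of the diff): when the advanced indices are separated by a basic index, NumPy moves the advanced dimension to the front of the output.

import numpy as np

x = np.zeros((5, 6, 7))
vec = np.array([0, 1, 2, 3])

# Consecutive advanced indices: the advanced dimension stays where the
# indexed axes were.
print(x[:, vec, vec].shape)  # (5, 4)

# Advanced indices separated by a basic slice: the advanced dimension is
# moved to the front of the output.
print(x[vec, :, vec].shape)  # (4, 6)
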
17 changes: 12 additions & 5 deletions tests/link/numba/test_basic.py
@@ -228,9 +228,11 @@ def compare_numba_and_py(
fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]],
inputs: Sequence["TensorLike"],
assert_fn: Callable | None = None,
*,
numba_mode=numba_mode,
py_mode=py_mode,
updates=None,
inplace: bool = False,
eval_obj_mode: bool = True,
) -> tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality
Expand Down Expand Up @@ -276,7 +278,14 @@ def assert_fn(x, y):
pytensor_py_fn = function(
fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
)
py_res = pytensor_py_fn(*inputs)

test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
py_res = pytensor_py_fn(*test_inputs)

# Get some coverage (and catch errors in python mode before unreadable numba ones)
if eval_obj_mode:
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode)

pytensor_numba_fn = function(
fn_inputs,
@@ -285,11 +294,9 @@ def assert_fn(x, y):
accept_inplace=True,
updates=updates,
)
numba_res = pytensor_numba_fn(*inputs)

# Get some coverage
if eval_obj_mode:
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
numba_res = pytensor_numba_fn(*test_inputs)

if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res, strict=True):
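
A hypothetical test body showing how the new keyword-only inplace flag of compare_numba_and_py might be used; none of it is part of this commit, and the graph below is only assumed to exercise the new consecutive-vector-index path. With inplace=True the helper evaluates each compiled function on fresh copies of the inputs, so a destructive op cannot leak state between the Python, object-mode, and Numba runs.

import numpy as np
import pytensor.tensor as pt

x = pt.tensor3("x")
idx0 = pt.lvector("idx0")
idx1 = pt.lvector("idx1")

# Consecutive vector indices in an inc_subtensor graph.
out = pt.inc_subtensor(x[:, idx0, idx1], 1.0)

compare_numba_and_py(
    ([x, idx0, idx1], [out]),
    [np.zeros((2, 3, 4)), np.array([0, 2]), np.array([1, 3])],
    inplace=True,  # copy inputs before each evaluation
)
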
