Skip to content

Commit

Permalink
Canonicalize Subtensor slices (pymc-devs#761)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi authored and Ch0ronomato committed Nov 2, 2024
1 parent 1c7136c commit d3b2217
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 29 deletions.
79 changes: 51 additions & 28 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,49 +337,72 @@ def local_subtensor_of_dot(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@register_stabilize
@node_rewriter([Subtensor])
def local_useless_slice(fgraph, node):
"""
Remove Subtensor of the form:
1. X[0, :] -> X[0]
2. X[:] -> X
Also, rewrite Subtensor of the form:
X[0:7:1] -> X[None:None:None]
where X is a vector of length 7
"""
idxs = get_idx_list(node.inputs, node.op.idx_list)
x = node.inputs[0]

if not idxs:
return [node.inputs[0]]

last_useless_slice = len(idxs)
for s in idxs[::-1]:
# check if slice and then check slice indices
new_idxs = list(idxs)
change_flag = False
last_useful_idx = -1
for dim, s in enumerate(new_idxs):
if not isinstance(s, slice):
last_useful_idx = dim
continue

if s == slice(None):
continue

start = s.start
stop = s.stop
step = s.step
if (
isinstance(s, slice)
and s.start is None
and s.stop is None
and (
s.step is None
or extract_constant(s.step, only_process_constants=True) == 1
)
start is not None
and extract_constant(start, only_process_constants=True) == 0
):
last_useless_slice -= 1
else:
break
# check if we removed something
if last_useless_slice < len(idxs):
new_idxs = idxs[:last_useless_slice]
if new_idxs:
new_subtensor = Subtensor(new_idxs)
new_subtensor_inputs = get_slice_elements(
new_idxs, lambda x: isinstance(x, Variable)
)
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
# Copy over previous output stacktrace
copy_stack_trace(node.outputs, out)
return [out]
else:
# Subtensor is not needed at all
return [node.inputs[0]]
change_flag = True
start = None

if (
stop is not None
and x.type.shape[dim] is not None
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
):
change_flag = True
stop = None

if (
step is not None
and extract_constant(step, only_process_constants=True) == 1
):
change_flag = True
step = None

if not (start is None and stop is None and step is None):
last_useful_idx = dim

new_idxs[dim] = slice(start, stop, step)

if change_flag or ((last_useful_idx + 1) < len(idxs)):
out = x[tuple(new_idxs[: last_useful_idx + 1])]
# Copy over previous output stacktrace
copy_stack_trace(node.outputs, out)

return [out]


# fast_compile to allow opt subtensor(cast{float32}(make_vector))
Expand Down
43 changes: 42 additions & 1 deletion tests/tensor/rewriting/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph, vectorize_graph
from pytensor.graph.basic import Constant, Variable, ancestors
from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
Expand Down Expand Up @@ -2402,3 +2402,44 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
else:
expected_out[:, :, core_idxs] += test_y
np.testing.assert_allclose(fn(test_x, test_y), expected_out)


def test_slice_canonicalize():
rng = np.random.default_rng(43)
x = tensor(shape=(3, 5, None, 9))
test_x = rng.normal(size=(3, 5, 8, 9))
# Test case 1
y = x[0:None, 0:5, 0:7, 0:9:1]
f = pytensor.function([x], y, allow_input_downcast=True)

# Get the DeepCopy input and assert that the Op is a DeepCopy
test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)

expected_y = x[None:None:None, None:None:None, None:7:None]

assert equal_computations([test_y], [expected_y])

np.testing.assert_allclose(
f(test_x),
test_x[
0:None, 0:5, 0:7, 0:9:1
], # Use the unoptimized slice to make sure our rewrite logic is correct
)

# Test case 2
y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
f1 = pytensor.function([x], y1, allow_input_downcast=True)

# Get the DeepCopy input and assert that the Op is a DeepCopy
test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp)

expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]

assert equal_computations([test_y1], [expected_y1])

np.testing.assert_allclose(
f1(test_x),
test_x[0:-1, 0:5, 0:7, 0:-1:-1],
)

0 comments on commit d3b2217

Please sign in to comment.