From 4fa9bb878b94703063b89b434a20b9dcb72d9472 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Date: Mon, 10 Feb 2025 02:05:23 +0100
Subject: [PATCH] PyTorch inline constants in dispatch to avoid graph breaks
 (#1118)

* Split and inverse

* PyTorch inline constants in dispatch to avoid graph breaks
---
 pytensor/link/pytorch/dispatch/basic.py     | 44 +++++++++++++++---
 pytensor/link/pytorch/dispatch/scalar.py    |  6 +++
 pytensor/link/pytorch/dispatch/shape.py     | 19 ++++++--
 pytensor/link/pytorch/dispatch/subtensor.py | 15 +++++++
 pytensor/link/pytorch/linker.py             |  3 ++
 tests/link/pytorch/test_basic.py            | 50 +++++++++++++++++++++
 6 files changed, 127 insertions(+), 10 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index 11e1d6c63a..ef4bf10637 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -8,6 +8,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
@@ -19,6 +20,7 @@
     Eye,
     Join,
     MakeVector,
+    Split,
     TensorFromScalar,
 )
 
@@ -120,14 +122,23 @@ def arange(start, stop, step):
 @pytorch_funcify.register(Join)
-def pytorch_funcify_Join(op, **kwargs):
-    def join(axis, *tensors):
-        # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
+def pytorch_funcify_Join(op, node, **kwargs):
+    axis = node.inputs[0]
 
-        return torch.cat(tensors, dim=axis)
+    if isinstance(axis, Constant):
+        axis = int(axis.data)
 
-    return join
+        def join_constant_axis(_, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join_constant_axis
+
+    else:
+
+        def join(axis, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join
@@ -172,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
 @pytorch_funcify.register(OpFromGraph)
 def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     kwargs.pop("storage_map", None)
 
-    # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
@@ -185,3 +195,23 @@ def tensorfromscalar(x):
         return torch.as_tensor(x)
 
     return tensorfromscalar
+
+
+@pytorch_funcify.register(Split)
+def pytorch_funcify_Split(op, node, **kwargs):
+    x, dim, split_sizes = node.inputs
+    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
+        dim = int(dim.data)
+        split_sizes = tuple(int(size) for size in split_sizes.data)
+
+        def split_constant_axis_and_sizes(x, *_):
+            return x.split(split_sizes, dim=dim)
+
+        return split_constant_axis_and_sizes
+
+    else:
+
+        def inner_fn(x, dim, split_amounts):
+            return x.split(split_amounts.tolist(), dim=dim.item())
+
+        return inner_fn
diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py
index 65170b1f53..6a1c6b235e 100644
--- a/pytensor/link/pytorch/dispatch/scalar.py
+++ b/pytensor/link/pytorch/dispatch/scalar.py
@@ -5,12 +5,18 @@
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
     Cast,
+    Invert,
     ScalarOp,
 )
 from pytensor.scalar.loop import ScalarLoop
 from pytensor.scalar.math import Softplus
 
 
+@pytorch_funcify.register(Invert)
+def pytorch_funcify_invert(op, node, **kwargs):
+    return torch.bitwise_not
+
+
 @pytorch_funcify.register(ScalarOp)
 def pytorch_funcify_ScalarOp(op, node, **kwargs):
     """Return pytorch function that implements the same computation as the Scalar Op.
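Editor's note: the dispatch changes above all follow one pattern: when an input is a Constant at build time, its Python value is closed over rather than passed to the runtime function, so torch.compile can inline it into a single traced graph. A minimal sketch of that pattern outside PyTensor (the make_join helper and static_axis argument are hypothetical, for illustration only):

import torch

def make_join(static_axis=None):
    if static_axis is not None:
        # Constant axis known at trace time: torch.compile inlines it,
        # producing one graph with no breaks.
        def join_constant_axis(_, *tensors):
            return torch.cat(tensors, dim=static_axis)

        return join_constant_axis

    def join(axis, *tensors):
        # Axis only known at runtime: reading its value forces a graph break.
        return torch.cat(tensors, dim=int(axis))

    return join

fn = torch.compile(make_join(static_axis=0))
out = fn(None, torch.ones(2), torch.zeros(3))  # axis argument is ignored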
diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py
index f771ac7211..c15b3a3779 100644
--- a/pytensor/link/pytorch/dispatch/shape.py
+++ b/pytensor/link/pytorch/dispatch/shape.py
@@ -1,15 +1,28 @@
 import torch
 
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 
 
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
+    _, shape = node.inputs
 
-    return reshape
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape
 
 
 @pytorch_funcify.register(Shape)
diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py
index 75e7ec0776..34358797fb 100644
--- a/pytensor/link/pytorch/dispatch/subtensor.py
+++ b/pytensor/link/pytorch/dispatch/subtensor.py
@@ -1,3 +1,4 @@
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = op.idx_list
+    x, *idxs = node.inputs
 
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
+
+    # Fallback that will introduce a graph break
     def subtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
index d47aa43dda..b8475e3157 100644
--- a/pytensor/link/pytorch/linker.py
+++ b/pytensor/link/pytorch/linker.py
@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
     def jit_compile(self, fn):
         import torch
 
+        # flag that tends to help our graphs
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
         from pytensor.link.pytorch.dispatch import pytorch_typify
 
         class wrapper:
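Editor's note: the dynamo flag set above targets ops whose output shapes depend on tensor values, such as Tensor.split with runtime sizes. A minimal sketch of what it enables, independent of this patch and assuming a recent PyTorch (exact behavior varies by version):

import torch

# Without this flag, ops with data-dependent output shapes (e.g. nonzero)
# trigger a graph break under torch.compile.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(fullgraph=True)
def nonzero_rows(x):
    # Output shape depends on the values in x, not just its shape.
    return torch.nonzero(x)

print(nonzero_rows(torch.tensor([0.0, 1.0, 0.0, 2.0])))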
diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
index 2ac8ee7c3b..d5c23c83e4 100644
--- a/tests/link/pytorch/test_basic.py
+++ b/tests/link/pytorch/test_basic.py
@@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries():
     compare_pytorch_and_py(
         f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
     )
+
+
+rng = np.random.default_rng(42849)
+
+
+@pytest.mark.parametrize(
+    "n_splits, axis, values, sizes",
+    [
+        (
+            0,
+            0,
+            rng.normal(size=20).astype(config.floatX),
+            [],
+        ),
+        (
+            5,
+            0,
+            rng.normal(size=5).astype(config.floatX),
+            rng.multinomial(5, np.ones(5) / 5),
+        ),
+        (
+            5,
+            0,
+            rng.normal(size=10).astype(config.floatX),
+            rng.multinomial(10, np.ones(5) / 5),
+        ),
+        (
+            5,
+            -1,
+            rng.normal(size=(11, 7)).astype(config.floatX),
+            rng.multinomial(7, np.ones(5) / 5),
+        ),
+        (
+            5,
+            -2,
+            rng.normal(size=(11, 7)).astype(config.floatX),
+            rng.multinomial(11, np.ones(5) / 5),
+        ),
+    ],
+)
+def test_Split(n_splits, axis, values, sizes):
+    i = pt.tensor("i", shape=values.shape, dtype=config.floatX)
+    s = pt.vector("s", dtype="int64")
+    g = pt.split(i, s, n_splits, axis=axis)
+    assert len(g) == n_splits
+    if n_splits == 0:
+        return
+    g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g)
+
+    compare_pytorch_and_py(g_fg, [values, sizes])
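Editor's note: a possible end-to-end usage sketch of the new constant paths, assuming a PyTensor installation with the PyTorch backend available:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# Literal axis and sizes enter the graph as Constants, so the dispatch above
# takes the split_constant_axis_and_sizes path and torch.compile sees no break.
parts = pt.split(x, splits_size=[2, 3], n_splits=2, axis=1)
fn = pytensor.function([x], parts, mode="PYTORCH")
left, right = fn(np.arange(20).reshape(4, 5).astype(pytensor.config.floatX))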