PyTorch inline constants in dispatch to avoid graph breaks (#1118)
* Split and inverse

* PyTorch inline constants in dispatch to avoid graph breaks
ricardoV94 authored Feb 10, 2025
1 parent 17748b7 commit 4fa9bb8
Showing 6 changed files with 127 additions and 10 deletions.
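
For context, the pattern applied throughout the dispatchers below can be summarized with a minimal sketch (not part of the commit; the helper names are illustrative): when an input such as a shape or an axis is a compile-time Constant, the dispatcher bakes its Python value into the returned closure instead of receiving it as a runtime tensor, so torch.compile can trace one graph rather than breaking on a tensor-to-int conversion.

import torch


def make_reshape_dynamic():
    # the shape arrives as a runtime tensor; pulling Python ints out of it
    # is data-dependent, so torch.compile has to break the graph
    def reshape(x, shape):
        return torch.reshape(x, tuple(int(dim) for dim in shape))

    return reshape


def make_reshape_constant(constant_shape):
    # the shape is a compile-time Python tuple baked into the closure,
    # so the compiled function never sees a shape tensor at all
    def reshape(x, *_):
        return torch.reshape(x, constant_shape)

    return reshape


x = torch.arange(6.0)
dynamic_fn = torch.compile(make_reshape_dynamic())  # graph break expected
constant_fn = torch.compile(make_reshape_constant((2, 3)), fullgraph=True)

print(dynamic_fn(x, torch.tensor([2, 3])).shape)  # torch.Size([2, 3])
print(constant_fn(x).shape)  # torch.Size([2, 3])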
44 changes: 37 additions & 7 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -8,6 +8,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
@@ -19,6 +20,7 @@
     Eye,
     Join,
     MakeVector,
+    Split,
     TensorFromScalar,
 )

@@ -120,14 +122,23 @@ def arange(start, stop, step):
 
 
 @pytorch_funcify.register(Join)
-def pytorch_funcify_Join(op, **kwargs):
-    def join(axis, *tensors):
-        # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
-
-        return torch.cat(tensors, dim=axis)
-
-    return join
+def pytorch_funcify_Join(op, node, **kwargs):
+    axis = node.inputs[0]
+
+    if isinstance(axis, Constant):
+        axis = int(axis.data)
+
+        def join_constant_axis(_, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join_constant_axis
+
+    else:
+
+        def join(axis, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join
 
 
 @pytorch_funcify.register(Eye)
@@ -172,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
 @pytorch_funcify.register(OpFromGraph)
 def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     kwargs.pop("storage_map", None)
-
     # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
@@ -185,3 +195,23 @@ def tensorfromscalar(x):
         return torch.as_tensor(x)
 
     return tensorfromscalar
+
+
+@pytorch_funcify.register(Split)
+def pytorch_funcify_Split(op, node, **kwargs):
+    x, dim, split_sizes = node.inputs
+    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
+        dim = int(dim.data)
+        split_sizes = tuple(int(size) for size in split_sizes.data)
+
+        def split_constant_axis_and_sizes(x, *_):
+            return x.split(split_sizes, dim=dim)
+
+        return split_constant_axis_and_sizes
+
+    else:
+
+        def inner_fn(x, dim, split_amounts):
+            return x.split(split_amounts.tolist(), dim=dim.item())
+
+        return inner_fn
6 changes: 6 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -5,12 +5,18 @@
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
     Cast,
+    Invert,
     ScalarOp,
 )
 from pytensor.scalar.loop import ScalarLoop
 from pytensor.scalar.math import Softplus
 
 
+@pytorch_funcify.register(Invert)
+def pytorch_funcify_invert(op, node, **kwargs):
+    return torch.bitwise_not
+
+
 @pytorch_funcify.register(ScalarOp)
 def pytorch_funcify_ScalarOp(op, node, **kwargs):
     """Return pytorch function that implements the same computation as the Scalar Op.
19 changes: 16 additions & 3 deletions pytensor/link/pytorch/dispatch/shape.py
@@ -1,15 +1,28 @@
 import torch
 
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 
 
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
+    _, shape = node.inputs
 
-    return reshape
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape
 
 
 @pytorch_funcify.register(Shape)
15 changes: 15 additions & 0 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -1,3 +1,4 @@
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = op.idx_list
+    x, *idxs = node.inputs
 
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
+
+    # Fallback that will introduce a graph break
     def subtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
3 changes: 3 additions & 0 deletions pytensor/link/pytorch/linker.py
@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
     def jit_compile(self, fn):
         import torch
 
+        # flag that tends to help our graphs
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
         from pytensor.link.pytorch.dispatch import pytorch_typify
 
         class wrapper:
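
For background, capture_dynamic_output_shape_ops is an existing torch._dynamo configuration flag; below is a rough, illustrative sketch (not part of the commit) of what it enables: capturing ops whose output shapes depend on runtime data, like the runtime-sized split fallback above, instead of always falling back with a graph break.

import torch

# illustrative only: without this flag, ops whose output shape depends on
# data (e.g. splitting by runtime sizes) typically force a graph break
torch._dynamo.config.capture_dynamic_output_shape_ops = True


@torch.compile
def split_runtime(x, sizes):
    # sizes is an int64 tensor; .tolist() makes the output shapes
    # data-dependent, mirroring the fallback path in pytorch_funcify_Split
    return torch.split(x, sizes.tolist())


parts = split_runtime(torch.arange(10.0), torch.tensor([3, 3, 4]))
print([tuple(p.shape) for p in parts])  # [(3,), (3,), (4,)]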
50 changes: 50 additions & 0 deletions tests/link/pytorch/test_basic.py
@@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries():
     compare_pytorch_and_py(
         f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
     )
+
+
+rng = np.random.default_rng(42849)
+
+
+@pytest.mark.parametrize(
+    "n_splits, axis, values, sizes",
+    [
+        (
+            0,
+            0,
+            rng.normal(size=20).astype(config.floatX),
+            [],
+        ),
+        (
+            5,
+            0,
+            rng.normal(size=5).astype(config.floatX),
+            rng.multinomial(5, np.ones(5) / 5),
+        ),
+        (
+            5,
+            0,
+            rng.normal(size=10).astype(config.floatX),
+            rng.multinomial(10, np.ones(5) / 5),
+        ),
+        (
+            5,
+            -1,
+            rng.normal(size=(11, 7)).astype(config.floatX),
+            rng.multinomial(7, np.ones(5) / 5),
+        ),
+        (
+            5,
+            -2,
+            rng.normal(size=(11, 7)).astype(config.floatX),
+            rng.multinomial(11, np.ones(5) / 5),
+        ),
+    ],
+)
+def test_Split(n_splits, axis, values, sizes):
+    i = pt.tensor("i", shape=values.shape, dtype=config.floatX)
+    s = pt.vector("s", dtype="int64")
+    g = pt.split(i, s, n_splits, axis=axis)
+    assert len(g) == n_splits
+    if n_splits == 0:
+        return
+    g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g)
+
+    compare_pytorch_and_py(g_fg, [values, sizes])
