From 6de3151370371576d061ae2bb0a45f92dcc73eeb Mon Sep 17 00:00:00 2001
From: Ch0ronomato
Date: Sat, 12 Oct 2024 15:21:10 -0700
Subject: [PATCH] Improve torch elemwise operator

When the scalar op maps to a native torch function (via its
`nfunc_spec`), keep calling that function directly and let torch
broadcast. Otherwise, broadcast the inputs explicitly and lift the
scalar function with one `torch.vmap` per dimension, so Elemwise also
works for scalar ops that have no torch equivalent.
---
 pytensor/link/pytorch/dispatch/elemwise.py | 18 ++++++++++--
 tests/link/pytorch/test_elemwise.py        | 33 ++++++++++++++++++++++
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py
index b1ad5582c5..72f97af1fa 100644
--- a/pytensor/link/pytorch/dispatch/elemwise.py
+++ b/pytensor/link/pytorch/dispatch/elemwise.py
@@ -11,9 +11,21 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
-    def elemwise_fn(*inputs):
-        Elemwise._check_runtime_broadcast(node, inputs)
-        return base_fn(*inputs)
+    if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
+        # torch implements this scalar op natively and broadcasts
+        # for us, so let the native function handle everything.
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            return base_fn(*inputs)
+    else:
+        # No torch equivalent: vectorize the scalar fn with one vmap per dim.
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            broadcast_inputs = torch.broadcast_tensors(*inputs)
+            ufunc = base_fn
+            for _ in range(broadcast_inputs[0].dim()):
+                ufunc = torch.vmap(ufunc)
+            return ufunc(*broadcast_inputs)
 
     return elemwise_fn
 
diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py
index 86089cc921..20c98094c1 100644
--- a/tests/link/pytorch/test_elemwise.py
+++ b/tests/link/pytorch/test_elemwise.py
@@ -1,10 +1,13 @@
 import numpy as np
 import pytest
 
+import pytensor
 import pytensor.tensor as pt
 import pytensor.tensor.math as ptm
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scalar.basic import ScalarOp, get_scalar_type
+from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, tensor3, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
@@ -150,3 +153,33 @@ def test_cast():
         fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
     )
     assert res.dtype == torch.int32
+
+
+def test_vmap_elemwise():
+    from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+
+    class TestOp(ScalarOp):
+        def __init__(self):
+            super().__init__(
+                output_types_preference=lambda *_: [get_scalar_type("float32")]
+            )
+            self.call_shapes = []
+            self.nin = 1
+
+        def perform(self, *_):
+            raise RuntimeError("In perform")
+
+    @pytorch_funcify.register(TestOp)
+    def funcify_test_op(op, node, **kwargs):
+        def relu(row):
+            op.call_shapes.append(row.size())
+            return torch.max(torch.zeros_like(row), row)
+
+        return relu
+
+    x = matrix("x", shape=(2, 3))
+    op = TestOp()
+    f = pytensor.function([x], Elemwise(op)(x), mode="PYTORCH")
+    vals = torch.zeros(2, 3).normal_()
+    np.testing.assert_allclose(f(vals), torch.relu(vals))
+    assert op.call_shapes == [torch.Size([])], op.call_shapes
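
For context, here is a minimal standalone sketch of the fallback path the
else branch implements, assuming only a torch version that ships
`torch.vmap` (2.0+); the names `scalar_relu` and `apply_elemwise` are
illustrative, not part of the patch:

```python
import torch


def scalar_relu(x):
    # Stand-in for a scalar op with no torch nfunc_spec equivalent;
    # it is written as if it only ever sees 0-d tensors.
    return torch.max(torch.zeros_like(x), x)


def apply_elemwise(base_fn, *inputs):
    # Same shape as the else branch above: broadcast all inputs to a
    # common shape, then lift the scalar function over each dimension.
    broadcast_inputs = torch.broadcast_tensors(*inputs)
    ufunc = base_fn
    for _ in range(broadcast_inputs[0].dim()):
        ufunc = torch.vmap(ufunc)
    return ufunc(*broadcast_inputs)


x = torch.randn(2, 3)
assert torch.equal(apply_elemwise(scalar_relu, x), torch.relu(x))
```

Ops with a matching torch function never take this path, since torch
already broadcasts them natively. In the fallback, each `torch.vmap`
application peels off one leading batch dimension, so after `dim()`
applications the innermost call sees 0-d tensors, which is exactly what
the new test asserts with `op.call_shapes == [torch.Size([])]`.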