From d03bfc7044e484e1254d9b60c0be7019097a99b9 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Mon, 18 Nov 2024 12:30:34 -0800 Subject: [PATCH] Remove compiler disable --- pytensor/link/pytorch/dispatch/basic.py | 4 +--- pytensor/link/pytorch/dispatch/blockwise.py | 2 -- tests/link/pytorch/test_blockwise.py | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index fe00975a8f..e9c7ad4a56 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -181,9 +181,7 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs): # Apply inner rewrites PYTORCH.optimizer(op.fgraph) fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) - # Disable one step inlining to prevent torch from trying to import local functions - # defined in `pytorch_funcify` - return torch.compiler.disable(fgraph_fn, recursive=False) + return fgraph_fn @pytorch_funcify.register(TensorFromScalar) diff --git a/pytensor/link/pytorch/dispatch/blockwise.py b/pytensor/link/pytorch/dispatch/blockwise.py index 524e706633..26568f5836 100644 --- a/pytensor/link/pytorch/dispatch/blockwise.py +++ b/pytensor/link/pytorch/dispatch/blockwise.py @@ -1,5 +1,4 @@ import torch -import torch.compiler from pytensor.graph import FunctionGraph from pytensor.link.pytorch.dispatch import pytorch_funcify @@ -16,7 +15,6 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): for _ in range(batched_dims): inner_func = torch.vmap(inner_func) - @torch.compiler.disable(recursive=False) def batcher(*inputs): op._check_runtime_broadcast(node, inputs) # broadcast on batched_dims diff --git a/tests/link/pytorch/test_blockwise.py b/tests/link/pytorch/test_blockwise.py index 75f207e544..762f9b985e 100644 --- a/tests/link/pytorch/test_blockwise.py +++ b/tests/link/pytorch/test_blockwise.py @@ -29,7 +29,6 @@ def perform(self, *_): @basic.pytorch_funcify.register(TestOp) def evaluate_test_op(op, **_): - @torch.compiler.disable(recursive=False) def func(a, b): op.call_shapes.extend(map(torch.Tensor.size, [a, b])) return a @ b