Skip to content

Commit

Permalink
Remove compiler disable
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Schweer committed Nov 18, 2024
1 parent 9b4059d commit d03bfc7
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 6 deletions.
4 changes: 1 addition & 3 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
# Disable one step inlining to prevent torch from trying to import local functions
# defined in `pytorch_funcify`
return torch.compiler.disable(fgraph_fn, recursive=False)
return fgraph_fn


@pytorch_funcify.register(TensorFromScalar)
Expand Down
2 changes: 0 additions & 2 deletions pytensor/link/pytorch/dispatch/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import torch.compiler

from pytensor.graph import FunctionGraph
from pytensor.link.pytorch.dispatch import pytorch_funcify
Expand All @@ -16,7 +15,6 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
for _ in range(batched_dims):
inner_func = torch.vmap(inner_func)

@torch.compiler.disable(recursive=False)
def batcher(*inputs):
op._check_runtime_broadcast(node, inputs)
# broadcast on batched_dims
Expand Down
1 change: 0 additions & 1 deletion tests/link/pytorch/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def perform(self, *_):

@basic.pytorch_funcify.register(TestOp)
def evaluate_test_op(op, **_):
@torch.compiler.disable(recursive=False)
def func(a, b):
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
return a @ b
Expand Down

0 comments on commit d03bfc7

Please sign in to comment.