From 4431a28659723f19b947f5f98efbb9a2729d8774 Mon Sep 17 00:00:00 2001
From: Ian Schweer
Date: Sat, 16 Nov 2024 20:55:56 -0800
Subject: [PATCH 1/2] Track generated torch files for torch compiler

---
 pytensor/link/pytorch/dispatch/basic.py     | 11 ++--
 pytensor/link/pytorch/dispatch/blockwise.py |  6 +-
 pytensor/link/pytorch/linker.py             | 64 ++++++++++++++++++++-
 pytensor/link/utils.py                      |  9 ++-
 tests/link/pytorch/test_basic.py            | 33 ++++++++++-
 tests/link/pytorch/test_blockwise.py        |  1 -
 6 files changed, 110 insertions(+), 14 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index e0aa80e18b..11e1d6c63a 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -54,14 +54,16 @@ def pytorch_funcify_FunctionGraph(
     fgraph,
     node=None,
     fgraph_name="pytorch_funcified_fgraph",
+    conversion_func=pytorch_funcify,
     **kwargs,
 ):
+    built_kwargs = {"conversion_func": conversion_func, **kwargs}
     return fgraph_to_python(
         fgraph,
-        pytorch_funcify,
+        conversion_func,
         type_conversion_fn=pytorch_typify,
         fgraph_name=fgraph_name,
-        **kwargs,
+        **built_kwargs,
     )


@@ -173,11 +175,8 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):

     # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
-
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
-    # Disable one step inlining to prevent torch from trying to import local functions
-    # defined in `pytorch_funcify`
-    return torch.compiler.disable(fgraph_fn, recursive=False)
+    return fgraph_fn


 @pytorch_funcify.register(TensorFromScalar)
diff --git a/pytensor/link/pytorch/dispatch/blockwise.py b/pytensor/link/pytorch/dispatch/blockwise.py
index 524e706633..0681d32a8e 100644
--- a/pytensor/link/pytorch/dispatch/blockwise.py
+++ b/pytensor/link/pytorch/dispatch/blockwise.py
@@ -1,5 +1,4 @@
 import torch
-import torch.compiler

 from pytensor.graph import FunctionGraph
 from pytensor.link.pytorch.dispatch import pytorch_funcify
@@ -11,12 +10,13 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
     batched_dims = op.batch_ndim(node)
     core_node = op._create_dummy_core_node(node.inputs)
     core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
-    inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1)
+    inner_func = pytorch_funcify(
+        core_fgraph, squeeze_output=len(node.outputs) == 1, **kwargs
+    )

     for _ in range(batched_dims):
         inner_func = torch.vmap(inner_func)

-    @torch.compiler.disable(recursive=False)
     def batcher(*inputs):
         op._check_runtime_broadcast(node, inputs)
         # broadcast on batched_dims
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
index 035d654c83..ec26fd252f 100644
--- a/pytensor/link/pytorch/linker.py
+++ b/pytensor/link/pytorch/linker.py
@@ -1,12 +1,18 @@
+import copy
 from typing import Any

 from pytensor.graph.basic import Variable
 from pytensor.link.basic import JITLinker
+from pytensor.link.utils import unique_name_generator


 class PytorchLinker(JITLinker):
     """A `Linker` that compiles NumPy-based operations using torch.compile."""

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.gen_functors = []
+
     def input_filter(self, inp: Any) -> Any:
         from pytensor.link.pytorch.dispatch import pytorch_typify

@@ -18,14 +24,68 @@ def output_filter(self, var: Variable, out: Any) -> Any:
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.pytorch.dispatch import pytorch_funcify

+        # We want to have globally unique names
+        # across the entire pytensor graph, not
+        # just the subgraph
+        generator = unique_name_generator(["torch_linker"])
+
+        # Ensure that torch is aware of the generated
+        # code so we can compile without graph breaks
+        def conversion_func_register(*args, **kwargs):
+            functor = pytorch_funcify(*args, **kwargs)
+            name = kwargs["unique_name"](functor)
+            self.gen_functors.append((f"_{name}", functor))
+            return functor
+
+        built_kwargs = {
+            "unique_name": generator,
+            "conversion_func": conversion_func_register,
+            **kwargs,
+        }
         return pytorch_funcify(
-            fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
+            fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
         )

     def jit_compile(self, fn):
         import torch

-        return torch.compile(fn)
+        class wrapper:
+            """
+            torch.compile fails when it has to resolve some of the callables
+            returned from dispatch if they are only reachable through closures
+            that may already have been garbage-collected. To avoid leaking
+            those callables, this class holds on to them and temporarily
+            exposes them at the module location the generated code expects.
+
+            https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
+            """
+
+            def __init__(self, fn, gen_functors):
+                self.fn = torch.compile(fn)
+                self.gen_functors = copy.copy(gen_functors)
+
+            def __call__(self, *args, **kwargs):
+                import pytensor.link.utils
+
+                # set attrs
+                for n, fn in self.gen_functors:
+                    setattr(pytensor.link.utils, n[1:], fn)
+
+                res = self.fn(*args, **kwargs)
+
+                # unset attrs
+                for n, _ in self.gen_functors:
+                    if getattr(pytensor.link.utils, n[1:], False):
+                        delattr(pytensor.link.utils, n[1:])
+
+                return res
+
+            def __del__(self):
+                del self.gen_functors
+
+        res = wrapper(fn, self.gen_functors)
+        self.gen_functors = []
+        return res

     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py
index 69c36f160d..9cbc3838dd 100644
--- a/pytensor/link/utils.py
+++ b/pytensor/link/utils.py
@@ -675,6 +675,7 @@ def fgraph_to_python(
     local_env: dict[Any, Any] | None = None,
     get_name_for_object: Callable[[Any], str] = get_name_for_object,
     squeeze_output: bool = False,
+    unique_name: Callable | None = None,
     **kwargs,
 ) -> Callable:
     """Convert a `FunctionGraph` into a regular Python function.
@@ -706,6 +707,8 @@ def fgraph_to_python(
     get_name_for_object
         A function used to provide names for the objects referenced within the
         generated function.
+    unique_name
+        A function used to generate unique names for the generated code.
     squeeze_output
         If the `FunctionGraph` has only one output and this option is ``True``,
         return the single output instead of a tuple with the output.
@@ -719,7 +722,11 @@ def fgraph_to_python( if storage_map is None: storage_map = {} - unique_name = unique_name_generator([fgraph_name]) + if not unique_name: + unique_name = unique_name_generator([fgraph_name]) + + # make sure we plumb this through + kwargs["unique_name"] = unique_name if global_env is None: global_env = {} diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index b67c3d9377..25827d23f9 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -22,6 +22,7 @@ torch = pytest.importorskip("torch") +torch_dispatch = pytest.importorskip("pytensor.link.pytorch.dispatch.basic") optimizer = RewriteDatabaseQuery( @@ -335,7 +336,7 @@ def test_pytorch_OpFromGraph(): ofg_2 = OpFromGraph([x, y], [x * y, x - y]) o1, o2 = ofg_2(y, z) - out = ofg_1(x, o1) + o2 + out = ofg_1(x, o1) / o2 xv = np.ones((2, 2), dtype=config.floatX) yv = np.ones((2, 2), dtype=config.floatX) * 3 @@ -343,3 +344,33 @@ def test_pytorch_OpFromGraph(): f = FunctionGraph([x, y, z], [out]) compare_pytorch_and_py(f, [xv, yv, zv]) + + +def test_pytorch_link_references(): + import pytensor.link.utils as m + + class BasicOp(Op): + def __init__(self): + super().__init__() + + def make_node(self, *x): + return Apply(self, list(x), [xi.type() for xi in x]) + + def perform(self, *_): + raise RuntimeError("In perform") + + @torch_dispatch.pytorch_funcify.register(BasicOp) + def fn(op, node, **kwargs): + def inner_fn(x): + assert "inner_fn" in dir(m), "not available during dispatch" + return x + + return inner_fn + + x = vector("x") + op = BasicOp() + out = op(x) + + f = function([x], out, mode="PYTORCH") + f(torch.ones(3)) + assert "inner_fn" not in dir(m), "function call reference leaked" diff --git a/tests/link/pytorch/test_blockwise.py b/tests/link/pytorch/test_blockwise.py index 75f207e544..762f9b985e 100644 --- a/tests/link/pytorch/test_blockwise.py +++ b/tests/link/pytorch/test_blockwise.py @@ -29,7 +29,6 @@ def perform(self, *_): @basic.pytorch_funcify.register(TestOp) def evaluate_test_op(op, **_): - @torch.compiler.disable(recursive=False) def func(a, b): op.call_shapes.extend(map(torch.Tensor.size, [a, b])) return a @ b From aa6aac269015fb3fadfa3458bee892e5ea71d196 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 24 Nov 2024 17:23:55 -0800 Subject: [PATCH 2/2] Fix test warning --- tests/link/pytorch/test_blockwise.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/link/pytorch/test_blockwise.py b/tests/link/pytorch/test_blockwise.py index 762f9b985e..d0678fd2c4 100644 --- a/tests/link/pytorch/test_blockwise.py +++ b/tests/link/pytorch/test_blockwise.py @@ -12,7 +12,7 @@ basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic") -class TestOp(Op): +class BatchedTestOp(Op): gufunc_signature = "(m,n),(n,p)->(m,p)" def __init__(self, final_shape): @@ -27,7 +27,7 @@ def perform(self, *_): raise RuntimeError("In perform") -@basic.pytorch_funcify.register(TestOp) +@basic.pytorch_funcify.register(BatchedTestOp) def evaluate_test_op(op, **_): def func(a, b): op.call_shapes.extend(map(torch.Tensor.size, [a, b])) @@ -42,7 +42,7 @@ def test_blockwise_broadcast(): x = pt.tensor4("x", shape=(5, 1, 2, 3)) y = pt.tensor3("y", shape=(3, 3, 2)) - op = TestOp((2, 2)) + op = BatchedTestOp((2, 2)) z = Blockwise(op)(x, y) f = pytensor.function([x, y], z, mode="PYTORCH")
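
The jit_compile change in PATCH 1/2 is easier to follow in isolation: generated callables are published as attributes of a module for the duration of a torch.compile'd call and removed afterwards, so the compiled graph resolves them through a stable module-level name instead of a closure that may have been garbage-collected. The sketch below illustrates that pattern only; `registry`, `scale`, `model`, and `CompiledWithRegistry` are hypothetical stand-ins for `pytensor.link.utils`, a functor produced by `pytorch_funcify`, the funcified graph, and the linker's `wrapper` class, not pytensor's actual API.

    import types

    import torch

    # Hypothetical stand-in for ``pytensor.link.utils``: the module where the
    # generated functors are published so compiled code can resolve them by name.
    registry = types.ModuleType("generated_fn_registry")


    def scale(x):
        # Stand-in for a functor produced by a ``pytorch_funcify`` dispatch.
        return 2 * x


    def model(x):
        # The generated code reaches the functor through a module attribute,
        # not through a closure that torch.compile might fail to resolve.
        return registry.scale(x) + 1


    class CompiledWithRegistry:
        """Minimal analogue of the linker's ``wrapper``: publish the functors
        around every call, then clean them up again."""

        def __init__(self, fn, functors):
            self.fn = torch.compile(fn)
            self.functors = dict(functors)

        def __call__(self, *args):
            # Set the attributes the compiled graph expects to find.
            for name, functor in self.functors.items():
                setattr(registry, name, functor)
            try:
                return self.fn(*args)
            finally:
                # Remove them again so nothing leaks between calls.
                for name in self.functors:
                    if hasattr(registry, name):
                        delattr(registry, name)


    compiled = CompiledWithRegistry(model, {"scale": scale})
    print(compiled(torch.ones(3)))     # tensor([3., 3., 3.])
    print(hasattr(registry, "scale"))  # False: the attribute is gone after the call

This is the same register-then-clean-up discipline that `test_pytorch_link_references` exercises: the registered attribute must be visible while the compiled function runs and absent once the call returns.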