Track generated functions for torch compile #1094

Merged · 2 commits · Nov 25, 2024
11 changes: 5 additions & 6 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -54,14 +54,16 @@ def pytorch_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="pytorch_funcified_fgraph",
conversion_func=pytorch_funcify,
**kwargs,
):
built_kwargs = {"conversion_func": conversion_func, **kwargs}
return fgraph_to_python(
fgraph,
pytorch_funcify,
conversion_func,
type_conversion_fn=pytorch_typify,
fgraph_name=fgraph_name,
**kwargs,
**built_kwargs,
)


@@ -173,11 +175,8 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):

# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)

fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
# Disable one step inlining to prevent torch from trying to import local functions
# defined in `pytorch_funcify`
return torch.compiler.disable(fgraph_fn, recursive=False)
return fgraph_fn


@pytorch_funcify.register(TensorFromScalar)
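With this change, `pytorch_funcify_FunctionGraph` takes a `conversion_func` argument, so a caller can intercept every functor produced while a graph is converted. A minimal sketch of such a hook, assuming an illustrative `recorded_functors` list and function name that are not part of the PR:

```python
from pytensor.link.pytorch.dispatch import pytorch_funcify

# Illustrative only: collect every functor produced while converting a graph.
recorded_functors = []


def recording_conversion_func(*args, **kwargs):
    functor = pytorch_funcify(*args, **kwargs)
    recorded_functors.append(functor)
    return functor


# Hooked in roughly the way the linker does below:
# pytorch_funcify(fgraph, conversion_func=recording_conversion_func, ...)
```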
6 changes: 3 additions & 3 deletions pytensor/link/pytorch/dispatch/blockwise.py
@@ -1,5 +1,4 @@
import torch
import torch.compiler

from pytensor.graph import FunctionGraph
from pytensor.link.pytorch.dispatch import pytorch_funcify
@@ -11,12 +10,13 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
batched_dims = op.batch_ndim(node)
core_node = op._create_dummy_core_node(node.inputs)
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1)
inner_func = pytorch_funcify(
core_fgraph, squeeze_output=len(node.outputs) == 1, **kwargs
)

for _ in range(batched_dims):
inner_func = torch.vmap(inner_func)

@torch.compiler.disable(recursive=False)
def batcher(*inputs):
op._check_runtime_broadcast(node, inputs)
# broadcast on batched_dims
64 changes: 62 additions & 2 deletions pytensor/link/pytorch/linker.py
@@ -1,12 +1,18 @@
import copy
from typing import Any

from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker
from pytensor.link.utils import unique_name_generator


class PytorchLinker(JITLinker):
"""A `Linker` that compiles NumPy-based operations using torch.compile."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gen_functors = []

def input_filter(self, inp: Any) -> Any:
from pytensor.link.pytorch.dispatch import pytorch_typify

@@ -18,14 +24,68 @@ def output_filter(self, var: Variable, out: Any) -> Any:
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify

# We want to have globally unique names
# across the entire pytensor graph, not
# just the subgraph
generator = unique_name_generator(["torch_linker"])

# Ensure that torch is aware of the generated
# code so we can compile without graph breaks
def conversion_func_register(*args, **kwargs):
functor = pytorch_funcify(*args, **kwargs)
name = kwargs["unique_name"](functor)
self.gen_functors.append((f"_{name}", functor))
return functor

built_kwargs = {
"unique_name": generator,
"conversion_func": conversion_func_register,
**kwargs,
}
return pytorch_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
)

def jit_compile(self, fn):
import torch

return torch.compile(fn)
class wrapper:
"""
Pytorch would fail compiling our method when trying
to resolve some of the methods returned from dispatch
calls. We want to be careful to not leak the methods,
so this class just holds them and provisions the expected
location accordingly

https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
"""

def __init__(self, fn, gen_functors):
self.fn = torch.compile(fn)
self.gen_functors = copy.copy(gen_functors)

def __call__(self, *args, **kwargs):
import pytensor.link.utils
Contributor Author (review comment): Note, there is no way this is threadsafe.


# set attrs
for n, fn in self.gen_functors:
setattr(pytensor.link.utils, n[1:], fn)

res = self.fn(*args, **kwargs)

# unset attrs
for n, _ in self.gen_functors:
if getattr(pytensor.link.utils, n[1:], False):
delattr(pytensor.link.utils, n[1:])

return res

def __del__(self):
del self.gen_functors

res = wrapper(fn, self.gen_functors)
self.gen_functors = []
return res

def create_thunk_inputs(self, storage_map):
thunk_inputs = []
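The `wrapper` above temporarily publishes the generated functors as attributes of `pytensor.link.utils` so that `torch.compile` can resolve them, then removes them after the call. A stripped-down sketch of that round-trip, using a throwaway module and hypothetical names rather than the real linker state:

```python
import types

# Hypothetical stand-in for pytensor.link.utils, used only in this sketch.
host_module = types.ModuleType("host_module")


def call_with_registered(compiled_fn, gen_functors, *args, **kwargs):
    # Publish each generated functor where the compiled code expects to find it.
    # Names are stored as "_<name>", so strip the leading underscore.
    for name, functor in gen_functors:
        setattr(host_module, name[1:], functor)
    try:
        return compiled_fn(*args, **kwargs)
    finally:
        # Remove the attributes again so references do not leak between calls
        # (as the review comment notes, this is not threadsafe).
        for name, _ in gen_functors:
            if hasattr(host_module, name[1:]):
                delattr(host_module, name[1:])
```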
9 changes: 8 additions & 1 deletion pytensor/link/utils.py
@@ -675,6 +675,7 @@ def fgraph_to_python(
local_env: dict[Any, Any] | None = None,
get_name_for_object: Callable[[Any], str] = get_name_for_object,
squeeze_output: bool = False,
unique_name: Callable | None = None,
**kwargs,
) -> Callable:
"""Convert a `FunctionGraph` into a regular Python function.
@@ -706,6 +707,8 @@
get_name_for_object
get_name_for_object
A function used to provide names for the objects referenced within the
generated function.
unique_name
A function used to generate unique names for the generated code
squeeze_output
If the `FunctionGraph` has only one output and this option is
``True``, return the single output instead of a tuple with the output.
@@ -719,7 +722,11 @@
if storage_map is None:
storage_map = {}

unique_name = unique_name_generator([fgraph_name])
if not unique_name:
unique_name = unique_name_generator([fgraph_name])

# make sure we plumb this through
kwargs["unique_name"] = unique_name

if global_env is None:
global_env = {}
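`fgraph_to_python` now accepts an optional `unique_name` callable and plumbs it through `kwargs`, so nested conversions reuse the caller's generator instead of creating a fresh, fgraph-local one. A brief sketch of the sharing pattern (the call itself is abbreviated):

```python
from pytensor.link.utils import unique_name_generator

# One generator shared across the whole PyTensor graph, mirroring the linker.
shared_names = unique_name_generator(["torch_linker"])

# Passing it explicitly means nested fgraph_to_python calls see the same
# generator via kwargs["unique_name"], so generated function names stay
# globally unique:
# fgraph_to_python(fgraph, conversion_func, unique_name=shared_names, ...)
```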
33 changes: 32 additions & 1 deletion tests/link/pytorch/test_basic.py
@@ -22,6 +22,7 @@


torch = pytest.importorskip("torch")
torch_dispatch = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")


optimizer = RewriteDatabaseQuery(
@@ -335,11 +336,41 @@ def test_pytorch_OpFromGraph():
ofg_2 = OpFromGraph([x, y], [x * y, x - y])

o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2
out = ofg_1(x, o1) / o2

xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5

f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])


def test_pytorch_link_references():
import pytensor.link.utils as m

class BasicOp(Op):
def __init__(self):
super().__init__()

def make_node(self, *x):
return Apply(self, list(x), [xi.type() for xi in x])

def perform(self, *_):
raise RuntimeError("In perform")

@torch_dispatch.pytorch_funcify.register(BasicOp)
def fn(op, node, **kwargs):
def inner_fn(x):
assert "inner_fn" in dir(m), "not available during dispatch"
return x

return inner_fn

x = vector("x")
op = BasicOp()
out = op(x)

f = function([x], out, mode="PYTORCH")
f(torch.ones(3))
assert "inner_fn" not in dir(m), "function call reference leaked"
7 changes: 3 additions & 4 deletions tests/link/pytorch/test_blockwise.py
@@ -12,7 +12,7 @@
basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")


class TestOp(Op):
class BatchedTestOp(Op):
gufunc_signature = "(m,n),(n,p)->(m,p)"

def __init__(self, final_shape):
@@ -27,9 +27,8 @@ def perform(self, *_):
raise RuntimeError("In perform")


@basic.pytorch_funcify.register(TestOp)
@basic.pytorch_funcify.register(BatchedTestOp)
def evaluate_test_op(op, **_):
@torch.compiler.disable(recursive=False)
def func(a, b):
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
return a @ b
@@ -43,7 +42,7 @@ def test_blockwise_broadcast():

x = pt.tensor4("x", shape=(5, 1, 2, 3))
y = pt.tensor3("y", shape=(3, 3, 2))
op = TestOp((2, 2))
op = BatchedTestOp((2, 2))
z = Blockwise(op)(x, y)

f = pytensor.function([x, y], z, mode="PYTORCH")