Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new dynamo.last_traces entry point. #1614

Merged
merged 14 commits into from
Jan 21, 2025
54 changes: 52 additions & 2 deletions thunder/dynamo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,5 +143,55 @@ def thunderfx(fn: Callable, /, **kwargs) -> Callable:

backend = ThunderCompiler(**thunder_options)
compiled = torch.compile(fn, backend=backend, **torch_compile_options)
compiled._backend = backend
return compiled

# We return this object instead of just the raw `compiled` Callable so that
# we have a place to hang the `last_*traces` properties.
class CompiledObject:
def __init__(self, be, func: Callable):
self._backend = backend
self._func = func

def __call__(self, *args, **kwargs):
return self._func(*args, **kwargs)

@property
def last_traces(self) -> [Trace]:
tfogal marked this conversation as resolved.
Show resolved Hide resolved
"""
Get the Thunder traces for all the forward subgraphs of a ThunderFX
callable.

.. note:: The object must have been invoked before calling this
tfogal marked this conversation as resolved.
Show resolved Hide resolved
function.
"""
rv: [Trace] = []
if not self._backend.subgraph_infos:
warnings.warn("Must invoke the function before using last_traces")
for sinfo in self._backend.subgraph_infos:
for th_fqn in sinfo.thunder_compiled_fns:
trcs = thunder.last_traces(th_fqn)
if trcs != []:
rv.append(trcs[-1])
del trcs
return rv

@property
def last_backward_traces(self) -> [Trace]:
"""
Get the Thunder traces for all the backward subgraphs of a
ThunderFX callable.

.. note:: The object must have been invoked before calling this
tfogal marked this conversation as resolved.
Show resolved Hide resolved
function.
"""
rv: [Trace] = []
if not self._backend.subgraph_infos:
warnings.warn("last_backward_traces used before function invoked")
for sinfo in self._backend.subgraph_infos:
for th_fqn in sinfo.thunder_compiled_fns:
trcs_bw = thunder.last_backward_traces(th_fqn)
if trcs_bw != []:
rv.append(trcs_bw[-1])
return rv

c = CompiledObject(backend, compiled)
return c
20 changes: 20 additions & 0 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,26 @@ def foo(x):
assert any(bsym.sym.id == "nvtx_range_push" for bsym in trc.bound_symbols)


def test_thunderfx_last_traces():
def foo(x):
return torch.sin(x) + torch.cos(x)

x = torch.randn((4, 4), requires_grad=True)
cfoo = thunderfx(foo)
cfoo(x)
assert cfoo.last_traces != []
assert cfoo.last_backward_traces != []

# Call it w/o invoking the function first.
dfoo = thunderfx(foo)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
assert dfoo.last_traces == []
assert "Must invoke" in str(w[0].message)
assert dfoo.last_backward_traces == []
assert "before function invoked" in str(w[1].message)


def test_get_example_input_tensor_metadata():
from thunder.dynamo.utils import _get_example_input_tensor_metadata
from torch._subclasses.fake_tensor import FakeTensorMode
Expand Down
Loading