Show return values type_str in printing unpack_sequence (to have the information in backward trace like forward trace) #1635

Closed
crcrpar opened this issue Jan 10, 2025 · 3 comments · Fixed by #1672
Labels
enhancement New feature or request

@crcrpar
Collaborator

crcrpar commented Jan 10, 2025

🚀 Feature

Currently, a backward trace would look like this:

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  t3, = cotangents
  t194, input_fp8, = C0
  # C1 (empty sequence)
  ...

It would be a bit clearer if each value had its type_str printed close by, e.g.

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  # t3: cuda:0 f32[...]
  t3, = cotangents
  # t194: cuda:0 f32[...], ...
  t194, input_fp8, = C0
  # C1 (empty sequence)
  ...

Motivation

Every forward trace (and the computation trace, if autograd is disabled) carries type information for its args, because each of them gets its own unpack prim. Backward traces do not seem to.
When reading a backward trace on its own, without the corresponding forward trace alongside, I feel such metadata would help in creating an executable script from that backward trace.

Pitch

Alternatives

Additional context

@crcrpar crcrpar added the enhancement New feature or request label Jan 10, 2025
@mruberry
Collaborator

fyi @t-vi, @IvanYashchuk if you're looking at fwd+bwd

@t-vi
Collaborator

t-vi commented Jan 10, 2025

Yeah, I actually started adding some additional header printing in #1634 and it would be cool to look at what bits of the default printer are useful.

Note that this is probably "just" unpack_sequence that needs to do it; maybe I'll just add it to #1634.

@t-vi t-vi changed the title Show arguments' type_str in backward trace like forward trace Show return values type_str in backward trace like forward trace Jan 10, 2025
@t-vi t-vi changed the title Show return values type_str in backward trace like forward trace Show return values type_str in printing unpack_sequence (to have the information in backward trace like forward trace) Jan 10, 2025
@t-vi t-vi self-assigned this Jan 10, 2025
@kshitij12345
Collaborator

I was looking at this, and this patch to unpack_sequence seems to do the trick. Hope this is helpful.

diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index 359a83a2..4d3d1472 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -872,6 +872,10 @@ def unpack_sequence_printer(
     parts.append(f"= {call_str}")
 
     lines = _make_parts_into_line_or_lines(parts)
+    for out in out_printables:
+        details = _make_parts_into_line_or_lines([f"# {codeutils.prettyprint(out, with_type=True)}"])
+        lines.extend(details)
+
     return lines
 

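To see the effect of the patch in isolation, here is a standalone sketch of the same idea: print the unpack line, then one `# name: "type"` comment per unpacked value. `TensorInfo` and `print_unpack_sequence` are hypothetical names used only for illustration; they are not Thunder's `unpack_sequence_printer` or its proxy classes, and the type strings simply follow the `cpu f32[3]` format used in this thread.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class TensorInfo:
    """Hypothetical stand-in for a trace proxy: a name plus its printed type string."""
    name: str
    type_str: str  # e.g. "cpu f32[3]"


def print_unpack_sequence(outputs: Sequence[TensorInfo], source: str) -> List[str]:
    """Return the unpack line followed by one type comment per unpacked value."""
    if not outputs:
        # Empty collections keep the existing one-line comment form.
        return [f"# {source} (empty sequence)"]
    # Trailing comma after every name, matching "t0, t1, t3, = C0" in the traces.
    lhs = "".join(f"{out.name}, " for out in outputs)
    lines = [f"{lhs}= {source}"]
    for out in outputs:
        lines.append(f'# {out.name}: "{out.type_str}"')
    return lines


if __name__ == "__main__":
    saved = [TensorInfo("t0", "cpu f32[3]"), TensorInfo("t1", "cpu f32[3]")]
    for line in print_unpack_sequence(saved, "C0"):
        print("  " + line)
```

Appending the comments after the unpack line, as the patch does, keeps each comment directly below the names it describes; the original proposal placed them before the line, and either ordering carries the same information.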
Example Backward Trace

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  # C0: "Collection"
  # C1: "Collection"
  t10, = cotangents
  # t10: "cpu f32[3]"
  t0, t1, t3, = C0
  # t0: "cpu f32[3]"
  # t1: "cpu f32[3]"
  # t3: "cpu f32[3]"
  # C1 (empty sequence)
  t24 = ltorch.true_divide(t10, t3)  # t24: "cpu f32[3]"
    # t24 = prims.div(t10, t3)  # t24: "cpu f32[3]"
  ...
  return (t30,)
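The motivation for this issue is turning a standalone backward trace into an executable script. As a rough sketch of why the new comments help, the hypothetical helper below (not part of Thunder) pulls the name, device, dtype, and shape back out of a comment such as `# t10: "cpu f32[3]"`, which is roughly the information needed to allocate placeholder inputs. It assumes the `"<device> <dtype>[<dims>]"` layout shown in the example trace above and only handles concrete integer dims.

```python
import re
from typing import Optional, Tuple

# Matches type comments such as  # t10: "cpu f32[3]"  or  # t3: "cuda:0 bf16[2, 4]".
# Assumes the "<device> <dtype>[<dims>]" layout seen in the example trace.
_TYPE_COMMENT = re.compile(r'#\s*(\w+):\s*"(\S+)\s+([a-z]+\d*)\[([^\]]*)\]"')


def parse_type_comment(line: str) -> Optional[Tuple[str, str, str, Tuple[int, ...]]]:
    """Return (name, device, dtype, shape) for a tensor type comment, else None."""
    match = _TYPE_COMMENT.search(line)
    if match is None:
        # e.g. '# C0: "Collection"' carries no tensor metadata.
        return None
    name, device, dtype, dims = match.groups()
    try:
        shape = tuple(int(d) for d in dims.split(",") if d.strip())
    except ValueError:
        # Non-integer (e.g. symbolic) dims are out of scope for this sketch.
        return None
    return name, device, dtype, shape


if __name__ == "__main__":
    trace_lines = [
        '  # C0: "Collection"',
        '  # t10: "cpu f32[3]"',
        '  t24 = ltorch.true_divide(t10, t3)  # t24: "cpu f32[3]"',
    ]
    for line in trace_lines:
        print(parse_type_comment(line))
```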
