Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add information about unpacked element from prims.unpack_sequence #1672

Merged
merged 2 commits into from
Jan 22, 2025

Conversation

kshitij12345
Copy link
Collaborator

Fixes #1635

Example

import thunder
import torch

def fn(x):
    return x.sin()

jfn = thunder.jit(fn)
jfn(torch.randn(3, requires_grad=True))

backward_trc = thunder.last_backward_traces(jfn)[-1]
print(backward_trc)

Backward before PR

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  x, = C0
  clear_mutable_collection(C0)
  del C0
  bw_t6 = torch.cos(x)  # bw_t6: "cpu f32[3]"
    # bw_t6 = ltorch.cos(x)  # bw_t6: "cpu f32[3]"
      # bw_t6 = prims.cos(x)  # bw_t6: "cpu f32[3]"
  del x
  bw_t7 = torch.mul(t2, bw_t6)  # bw_t7: "cpu f32[3]"
    # bw_t7 = ltorch.mul(t2, bw_t6)  # bw_t7: "cpu f32[3]"
      # bw_t7 = prims.mul(t2, bw_t6)  # bw_t7: "cpu f32[3]"
  del t2, bw_t6
  return (bw_t7,)

Backward after PR

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  # C0: "Collection"
  # None
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  # t2: "cpu f32[3]"
  clear_mutable_collection(cotangents)
  del cotangents
  x, = C0
  # x: "cpu f32[3]"
  clear_mutable_collection(C0)
  del C0
  bw_t6 = torch.cos(x)  # bw_t6: "cpu f32[3]"
    # bw_t6 = ltorch.cos(x)  # bw_t6: "cpu f32[3]"
      # bw_t6 = prims.cos(x)  # bw_t6: "cpu f32[3]"
  del x
  bw_t7 = torch.mul(t2, bw_t6)  # bw_t7: "cpu f32[3]"
    # bw_t7 = ltorch.mul(t2, bw_t6)  # bw_t7: "cpu f32[3]"
      # bw_t7 = prims.mul(t2, bw_t6)  # bw_t7: "cpu f32[3]"
  del t2, bw_t6
  return (bw_t7,)

@kshitij12345 kshitij12345 changed the title add information about unpacked element from prims.unpack_sequence Add information about unpacked element from prims.unpack_sequence Jan 21, 2025
@kshitij12345 kshitij12345 requested a review from crcrpar January 21, 2025 18:13
@kshitij12345 kshitij12345 marked this pull request as ready for review January 21, 2025 18:14
Copy link
Collaborator

@t-vi t-vi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thank you @kshitij12345

@t-vi t-vi enabled auto-merge (squash) January 21, 2025 18:43
@riccardofelluga
Copy link
Collaborator

Just a curiosity, why the info is displayed in the next row instead of on the side?

@t-vi
Copy link
Collaborator

t-vi commented Jan 22, 2025

Printing is a bit tricky for the sequence unpack, as it can be 100s of results.

@t-vi t-vi merged commit 0d8b6c9 into Lightning-AI:main Jan 22, 2025
49 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants