
don't auto-recompute attention or linear #1648

Merged: 6 commits into main, Jan 16, 2025
Conversation

@t-vi (Collaborator) commented Jan 16, 2025

Fixes: #1646

Thank you @kshitij12345 for the detailed issue.

@t-vi requested review from mruberry and lantiga as code owners on January 16, 2025, 10:08
@kshitij12345 (Collaborator) left a comment

Overall looks good, I just have one question. Also, we should add a test with a simple 2-layer model to verify that SDPA or linear is not recomputed in the backward pass.

Thank you @t-vi
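
(A minimal sketch of the kind of test suggested above, not the actual test added in this PR. It assumes Thunder's `thunder.jit` and `thunder.last_backward_traces` utilities behave as in other Thunder tests, and the string checks on the trace are illustrative.)

```python
# Hypothetical sketch: jit a tiny 2-layer model and assert that neither
# scaled_dot_product_attention nor linear appears in the backward trace,
# i.e. their outputs are saved from forward rather than recomputed.
import torch
import thunder


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 16)
        self.fc2 = torch.nn.Linear(16, 16)

    def forward(self, x):
        q = k = v = self.fc1(x)
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return self.fc2(attn)


def test_sdpa_linear_not_recomputed_in_backward():
    model = TinyModel()
    jmodel = thunder.jit(model)
    x = torch.randn(2, 4, 16, requires_grad=True)
    jmodel(x).sum().backward()

    # Inspect the final backward trace; the accessor name and the
    # symbol-name checks below are assumptions for illustration only.
    bwd_trace = thunder.last_backward_traces(jmodel)[-1]
    trace_str = str(bwd_trace)
    assert "scaled_dot_product_attention" not in trace_str
    assert "linear" not in trace_str
```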

Review thread on thunder/torch/__init__.py (resolved)
@t-vi enabled auto-merge (squash) on January 16, 2025, 12:13
@kshitij12345 (Collaborator) left a comment

LGTM, thank you @t-vi

Review thread on thunder/tests/test_networks.py (resolved)
@IvanYashchuk (Collaborator) commented
That's a good quick fix, but can we revert to the previous default behavior: don't do any recomputation except for fused operations? The current logic doesn't use the information on whether the operation will be fused. To override the default rule we can propagate torch.utils.checkpoint tags or have opt-in automatic passes.
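
(For reference, the torch.utils.checkpoint mechanism mentioned here is PyTorch's way for a user to explicitly opt into recomputation for a region. A minimal PyTorch example, independent of Thunder, shown only to illustrate the kind of tag that could be propagated:)

```python
# Minimal illustration of torch.utils.checkpoint: the checkpointed block
# discards its intermediates during forward and recomputes them during
# backward. This is explicit, user-requested recomputation, as opposed to
# the automatic recomputation of SDPA/linear that this PR turns off.
import torch
from torch.utils.checkpoint import checkpoint


def attention_block(x):
    return torch.nn.functional.scaled_dot_product_attention(x, x, x)


x = torch.randn(2, 4, 16, requires_grad=True)
out = checkpoint(attention_block, x, use_reentrant=False)
out.sum().backward()
```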

@t-vi (Collaborator, Author) commented Jan 16, 2025

> That's a good quick fix, but can we revert to the previous default behavior: don't do any recomputation except for fused operations? The current logic doesn't use the information on whether the operation will be fused. To override the default rule we can propagate torch.utils.checkpoint tags or have opt-in automatic passes.

Maybe. I had that as one of the options in the issue; we went with this for now.

To my mind there are multiple parts:

  • we want to be sure that we don't cause memory regressions,
  • I think the rematerialization for forward and backward eventually needs to work without creating the joint trace.

@t-vi merged commit ef06bd0 into main on Jan 16, 2025
49 checks passed
@t-vi deleted the tom/dont_recompute_sdpa_linear branch on January 16, 2025, 14:44

Successfully merging this pull request may close these issues:
Perf Regression: SDPA is recomputed in backward