Remove graph breaks for torch.compile() in flash_attention_forward when Llama Model is padding free tuned #33932
A very nice PR and very much welcome!
Let's add general kwargs; #31446 has some commits with that
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_k,
max_length_q=max_length_q,
max_length_k=max_length_k,
Actually something we had planned 😅 cc @gante on having generate unpad the input!
@Cyrilvallez as well if you want to have fun, IMO this can be quite impactful!
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
These are FlashAttention specific. IMO it would make sense to just add them as fa2_kwargs, for example. We can use something like this:
class TextKwargs(TypedDict, total=False):
This way we can potentially add more kwargs without changing the forward!
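For illustration, a minimal sketch of what such a kwargs TypedDict could look like; the class name and fields mirror the arguments discussed above, not the exact definition that was merged:

from typing import Optional

import torch
from typing_extensions import TypedDict  # total=False below makes every key optional


class FlashAttentionKwargs(TypedDict, total=False):
    # Cumulative sequence lengths and max lengths for the packed
    # (padding-free) batch, precomputed outside the model forward.
    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: int
    max_length_k: int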
Thanks for the review and for welcoming this PR.
Nice!
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
**kwargs,
Suggested change:
- **kwargs,
+ **kwargs: Unpack[Fa2Kwargs],
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
**fa2_kwargs: Fa2Kwargs,
Suggested change:
- **fa2_kwargs: Fa2Kwargs,
+ **fa2_kwargs: Unpack[Fa2Kwargs],
Thinking that we can call them flash_attn_kwargs instead, to not depend on versioning!
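For context, a rough sketch of how the suggested typing reads on the attention path. The signature is trimmed and illustrative, not the exact one merged, and the TypedDict is repeated from the sketch above so the snippet stands alone:

from typing import Optional

import torch
from typing_extensions import TypedDict, Unpack


class FlashAttentionKwargs(TypedDict, total=False):
    # repeated from the earlier sketch so this snippet is standalone
    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: int
    max_length_k: int


def _flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    **kwargs: Unpack[FlashAttentionKwargs],
):
    # If the collator precomputed these, no CPU-GPU sync is needed inside the
    # compiled forward; otherwise the packed metadata would have to be derived
    # from position_ids here.
    cu_seq_lens_q = kwargs.get("cu_seq_lens_q")
    max_length_q = kwargs.get("max_length_q", 0)
    # ... dispatch to the varlen flash-attention kernel when these are set ...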
@ArthurZucker Made the necessary changes.
@ArthurZucker Would you mind helping move ahead with the related PR in TRL that supports this PR: huggingface/trl#2158?
Okay! Overall looks good.
@@ -18,7 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Unpack
Suggested change:
- from typing import List, Optional, Tuple, Union, Unpack
+ from typing import List, Optional, Tuple, Union
from ...processing_utils import (
    FlashAttentionKwargs,
)
Suggested change:
- from ...processing_utils import (
-     FlashAttentionKwargs,
- )
+ from ...processing_utils import (
+     FlashAttentionKwargs, Unpack
+ )
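Presumably the motivation for these two suggestions: typing.Unpack was only added to the standard library in Python 3.11, while transformers re-exports Unpack (from typing_extensions) through processing_utils, so importing it from there keeps older Python versions working. A version-agnostic fallback, for illustration only:

# Illustration only: avoiding a dependency on Python >= 3.11 for Unpack.
try:
    from typing import Unpack  # added to the stdlib in Python 3.11
except ImportError:
    from typing_extensions import Unpack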
The CIs should mostly go green with this!
Once we merge I'll ping the TRL team to make sure they don't miss it!
Okay! This LGTM!
You just need to run make fix-copies to make sure CIs go green 🚀
If you have issues with make fix-copies I can take it over if you want!
Sure, that would be helpful, thank you! There seems to be some mismatch in
On it!
I am just waiting on #34283 to be merged!
Thanks @ArthurZucker for the code change to accommodate LossKwargs.
Thanks @Abhishek-TAMU for your contribution! 🚀
Hi @ArthurZucker, do you mind facilitating this? PR: huggingface/trl#2158
Remove graph breaks for torch.compile() in flash_attention_forward when Llama Model is padding free tuned (huggingface#33932)

* fix: fixes for graph breaks
* fix: formatting
* fix: import error
* fix: Add Fa2Kwargs
* fix: PR Changes
* PR changes
* PR changes
* PR changes
* PR changes
* Revert "PR changes" (this reverts commit 39d2868)
* PR changes
* fix: FlashAttentionKwarg
* fix: FlashAttentionKwarg
* PR Changes
* PR Changes
* PR Changes
* PR Changes
* PR Changes
* addition of documentation
* change in _flash_attention_forward
* make fix-copies
* revert make fix-copies
* fix copies
* style
* loss kwargs typing
* style and pull latest changes

Signed-off-by: Abhishek <[email protected]>
Co-authored-by: Arthur Zucker <[email protected]>
Hi @Abhishek-TAMU @ArthurZucker, very nice PR for adding FlashAttentionKwargs.
kwargs are by default optional!
What does this PR do?
This PR removes the call to prepare_fa2_from_position_ids in flash_attention_forward, because it causes a graph break when the torch_compile flag is turned on in the training arguments (for example when using SFTTrainer to perform padding-free tuning of a Llama model). The code in prepare_fa2_from_position_ids incurs a CPU-GPU sync that is unavoidable. Hence cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are now taken from the batch produced by DataCollatorForCompletionOnlyLM, so the call to prepare_fa2_from_position_ids in flash_attention_forward is avoided.

CC: @ani300 @ArthurZucker
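For illustration only (this is a hypothetical helper, not the exact TRL or transformers implementation): because position_ids restart at 0 for every packed sequence, a data collator can derive these values on the CPU before the batch ever reaches the model, so the compiled forward never has to synchronize with the GPU to compute them.

import torch


def cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
    """Hypothetical collator-side helper for padding-free batches.

    position_ids restart at 0 at the start of every packed sequence, so the
    indices where they equal 0 mark the sequence boundaries.
    """
    flat = position_ids.flatten()
    starts = torch.nonzero(flat == 0, as_tuple=True)[0]  # sequence start offsets
    cu_seq_lens = torch.cat(
        [starts, torch.tensor([flat.numel()], dtype=starts.dtype)]
    ).to(torch.int32)  # cumulative boundaries, including the final end offset
    # This tensor -> int conversion happens in the collator (CPU side), i.e.
    # outside the compiled graph, so it does not cause a graph break.
    max_length = int((cu_seq_lens[1:] - cu_seq_lens[:-1]).max())
    return cu_seq_lens, max_length


# A collator could then attach these to the batch it returns, e.g.:
# batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens
# batch["max_length_q"] = batch["max_length_k"] = max_length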
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.