
Remove graph breaks for torch.compile() in flash_attention_forward when Llama model is padding-free tuned #33932

Merged (38 commits) on Oct 24, 2024

Conversation

@Abhishek-TAMU (Contributor) commented Oct 3, 2024

What does this PR do?

This PR removes the call to prepare_fa2_from_position_ids in flash_attention_forward, because it causes a graph break when the torch_compile flag is enabled in TrainingArguments used with SFTTrainer for padding-free tuning of a Llama model. The code in prepare_fa2_from_position_ids incurs an unavoidable CPU-GPU sync.
With this PR, cu_seq_lens_q, cu_seq_lens_k, max_length_k, and max_length_q are instead taken from the batch produced by DataCollatorForCompletionOnlyLM, which avoids the call to prepare_fa2_from_position_ids in flash_attention_forward.
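For illustration, a minimal sketch of the idea (not the actual collator code; key names follow the arguments discussed in this PR, everything else is a simplified assumption):

import torch

# Hypothetical example: two sequences of lengths 3 and 5 packed into a single
# padding-free row, the way a padding-free collator would produce it.
seq_lens = torch.tensor([3, 5], dtype=torch.int32)

# position_ids restart at 0 for every packed sequence.
position_ids = torch.cat([torch.arange(int(n)) for n in seq_lens]).unsqueeze(0)

# Cumulative sequence lengths and the max length are computed on the host by
# the collator, so flash_attention_forward no longer has to derive them from
# position_ids (which forces a CPU-GPU sync and breaks the compiled graph).
cu_seq_lens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))
max_length = int(seq_lens.max())

fa_kwargs = {
    "cu_seq_lens_q": cu_seq_lens,  # tensor([0, 3, 8], dtype=torch.int32)
    "cu_seq_lens_k": cu_seq_lens,
    "max_length_q": max_length,    # plain Python int, already on the host
    "max_length_k": max_length,
}
print(position_ids, fa_kwargs)

Because these values arrive as part of the batch, the FlashAttention path can consume them directly with no device-to-host transfer inside the compiled region.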

CC: @ani300 @ArthurZucker

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker (Collaborator) left a comment

A very nice PR and very much welcome!
Let's add general kwargs, #31446 has some commits with that

Comment on lines 982 to 985
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_k,
max_length_q=max_length_q,
max_length_k=max_length_k,

Actually something we had planned 😅 cc @gante on generate unpadding the input!


@Cyrilvallez as well, if you want to have fun; IMO this can be quite impactful!

Comment on lines 1181 to 1184
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,

These are FlashAttention-specific. IMO it would make sense to just add them as fa2_kwargs, for example. We can use something like this:

class TextKwargs(TypedDict, total=False):

@ArthurZucker (Collaborator)

This way we can potentially add more kwargs without changing the forward!
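For context, a hedged sketch of what such a kwargs TypedDict could look like (field names mirror the arguments above; the class and function names here are illustrative, not the exact code merged in this PR):

from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack  # typing.Unpack only exists on Python >= 3.11


class FlashAttentionKwargs(TypedDict, total=False):
    """Keys a FlashAttention-2 varlen call understands; all optional."""

    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: int
    max_length_k: int


def attention_forward(hidden_states: torch.Tensor, **kwargs: Unpack[FlashAttentionKwargs]):
    # The TypedDict documents (and type-checks) which keys may be passed,
    # while the runtime signature stays a plain **kwargs, so new keys can be
    # added later without touching every forward signature.
    cu_seq_lens_q = kwargs.get("cu_seq_lens_q")
    return hidden_states, cu_seq_lens_q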

@Abhishek-TAMU (Contributor, Author)

Thanks for the review and for welcoming this PR.
The suggested changes have been made, @ArthurZucker.

@ArthurZucker (Collaborator) left a comment

Nice!

cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
**kwargs,

Suggested change
**kwargs,
**kwargs: Unpack[Fa2Kwargs],

cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
**fa2_kwargs: Fa2Kwargs,

Suggested change
**fa2_kwargs: Fa2Kwargs,
**fa2_kwargs: Unpack[Fa2Kwargs],

@ArthurZucker (Collaborator) left a comment

Thinking that we can call them flash_attn_kwargs instead, so the name does not depend on versioning!

@Abhishek-TAMU (Contributor, Author)

@ArthurZucker Made the necessary changes.
Feel free to suggest further changes if any are required. Thanks!

@Abhishek-TAMU (Contributor, Author)

@ArthurZucker Would you mind helping move the related TRL PR, which builds on this one, forward: huggingface/trl#2158?

@ArthurZucker (Collaborator)

Okay! Overall looks good.

  1. We need to protect the import of Unpack (see the sketch after this list).
  2. Let's add a small Python snippet to the documentation showing how to use this!
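For illustration, one common way to guard the import (a sketch assuming typing_extensions is available as a fallback; not necessarily the exact guard used in the merged code):

import sys

# typing.Unpack only exists from Python 3.11 on; older interpreters need the
# backport from typing_extensions.
if sys.version_info >= (3, 11):
    from typing import Unpack
else:
    from typing_extensions import Unpack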

@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Unpack

Suggested change
from typing import List, Optional, Tuple, Union, Unpack
from typing import List, Optional, Tuple, Union

Comment on lines 51 to 53
from ...processing_utils import (
FlashAttentionKwargs,
)

Suggested change
from ...processing_utils import (
FlashAttentionKwargs,
)
from ...processing_utils import (
FlashAttentionKwargs, Unpack
)

@ArthurZucker (Collaborator)

The CIs should mostly go green with this!
Then you will have to run make fix-copies, which will propagate the changes!

@ArthurZucker (Collaborator)

Once we merge I'll ping TRL team to make sure they don't miss it!

@ArthurZucker (Collaborator) left a comment

Okay! This LGTM!
You just need to run make fix-copies to make sure the CIs go green 🚀

@ArthurZucker (Collaborator)

If you have issues with make fix-copies, I can take it over if you want!

@Abhishek-TAMU (Contributor, Author)

Abhishek-TAMU commented Oct 22, 2024

Sure, that would be helpful, thank you! There seems to be some mismatch in src/transformers/models/glm/modeling_glm.py.

@ArthurZucker (Collaborator)

On it!

@ArthurZucker (Collaborator)

I am just waiting on #34283 to be merged!

@ArthurZucker (Collaborator)

The new helper will be this (with loss kwargs!): [screenshot of the new helper]
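For what it's worth, a hypothetical reconstruction of what a helper like this might look like (the class names and the num_items_in_batch field are assumptions, echoing the TypedDict sketch earlier in this thread):

from typing import TypedDict

from typing_extensions import Unpack  # backport; typing.Unpack needs Python >= 3.11


class FlashAttentionKwargs(TypedDict, total=False):
    # Abbreviated here; see the earlier sketch for the full field list.
    max_length_q: int
    max_length_k: int


class LossKwargs(TypedDict, total=False):
    # Assumed field: lets the loss be normalized over the true token count.
    num_items_in_batch: int


# One bundle so a causal-LM forward can accept every extra kwarg at once.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


def causal_lm_forward(input_ids, **kwargs: Unpack[KwargsForCausalLM]):
    # The pieces are routed onward: FlashAttention keys to the attention
    # layers, loss keys to the loss function.
    ...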

@Abhishek-TAMU (Contributor, Author)

Thanks @ArthurZucker for the code change to accommodate LossKwargs.

@ArthurZucker (Collaborator)

Thanks @Abhishek-TAMU for your contribution! 🚀

@ArthurZucker merged commit 65753d6 into huggingface:main on Oct 24, 2024
21 of 25 checks passed
@Abhishek-TAMU (Contributor, Author)

Abhishek-TAMU commented Oct 31, 2024

Once we merge I'll ping TRL team to make sure they don't miss it

Hi @ArthurZucker, do you mind facilitating this? PR: huggingface/trl#2158

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…en Lllama Model is padding free tuned (huggingface#33932)

* fix: fixes for graph breaks

Signed-off-by: Abhishek <[email protected]>

* fix: formatting

Signed-off-by: Abhishek <[email protected]>

* fix: import error

Signed-off-by: Abhishek <[email protected]>

* fix: Add Fa2Kwargs

Signed-off-by: Abhishek <[email protected]>

* fix: PR Changes

Signed-off-by: Abhishek <[email protected]>

* PR changes

Signed-off-by: Abhishek <[email protected]>

* PR changes

Signed-off-by: Abhishek <[email protected]>

* PR changes

Signed-off-by: Abhishek <[email protected]>

* PR changes

Signed-off-by: Abhishek <[email protected]>

* Revert "PR changes"

This reverts commit 39d2868.

* PR changes

Signed-off-by: Abhishek <[email protected]>

* fix: FlashAttentionKwarg

Signed-off-by: Abhishek <[email protected]>

* fix: FlashAttentionKwarg

Signed-off-by: Abhishek <[email protected]>

* PR Changes

Signed-off-by: Abhishek <[email protected]>

* PR Changes

Signed-off-by: Abhishek <[email protected]>

* PR Changes

Signed-off-by: Abhishek <[email protected]>

* PR Changes

Signed-off-by: Abhishek <[email protected]>

* PR Changes

Signed-off-by: Abhishek <[email protected]>

* addition of documentation

Signed-off-by: Abhishek <[email protected]>

* change in _flash_attention_forward

Signed-off-by: Abhishek <[email protected]>

* make fix-copies

Signed-off-by: Abhishek <[email protected]>

* revert make fix-copies

Signed-off-by: Abhishek <[email protected]>

* fix copies

* style

* loss kwargs typing

* style and pull latest changes

---------

Signed-off-by: Abhishek <[email protected]>
Co-authored-by: Arthur Zucker <[email protected]>
@ma787639046 (Contributor)

Hi @Abhishek-TAMU @ArthurZucker, very nice PR for adding FlashAttentionKwargs.
In modeling_llama.py#L959, I noticed that **flash_attn_kwargs is added to the inputs of decoder_layer only when gradient checkpointing is not used. Could you please also pass flash_attn_kwargs when gradient checkpointing is used, in the if branch above at line 938? If the checkpointing function does not accept kwargs, could we make all FlashAttentionKwargs fields optional inputs of LlamaDecoderLayer.forward?

@ArthurZucker (Collaborator)

kwargs are optional by default!
Checkpoint indeed does not support kwargs, at least the way we have formulated it. This will be fixed by #34987.
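To make the point concrete, a minimal hedged sketch with a dummy layer (hypothetical names, not the actual modeling_llama.py code): in recent PyTorch, the non-reentrant torch.utils.checkpoint variant forwards keyword arguments, while the reentrant variant does not accept extra keyword arguments, which is roughly why the checkpointed branch cannot simply splat **flash_attn_kwargs as currently formulated.

import torch
from torch.utils.checkpoint import checkpoint


def dummy_decoder_layer(hidden_states, **flash_attn_kwargs):
    # Stand-in for a decoder layer forward; flash_attn_kwargs would carry
    # cu_seq_lens_q / max_length_q and friends in the real model.
    scale = float(flash_attn_kwargs.get("max_length_q", 1))
    return hidden_states * scale


hidden_states = torch.randn(1, 8, 16, requires_grad=True)
flash_attn_kwargs = {"max_length_q": 8, "max_length_k": 8}

# Plain call (the non-checkpointed branch): kwargs are forwarded directly.
out = dummy_decoder_layer(hidden_states, **flash_attn_kwargs)

# Checkpointed call: the non-reentrant implementation accepts and forwards
# keyword arguments; the reentrant one (use_reentrant=True) rejects them.
out_ckpt = checkpoint(
    dummy_decoder_layer, hidden_states, use_reentrant=False, **flash_attn_kwargs
)
out_ckpt.sum().backward()
print(torch.allclose(out, out_ckpt), hidden_states.grad.shape)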
