
[PyTorch] Miscellaneous fixes for FP8 DPA module #804

Merged 23 commits into NVIDIA:main on May 2, 2024

Conversation

@cyanguwa (Collaborator) commented Apr 24, 2024

This PR

  • fixes cuDNN version extraction in the unit tests, accounting for the change in version encoding between cuDNN pre-9.0 and 9.0+ (see the sketch after this list), and
  • improves compatibility with older checkpoints (pre-TE 1.6). Since TE 1.6, FusedAttention has subclassed TEBaseModule, which adds an _extra_state entry to the module's state_dict. _extra_state holds FP8 metadata, but because of the subclassing it is added to the state_dict regardless of whether the module is trained in FP8 or in FP16. This PR lets users load older checkpoints (which have no _extra_state for FusedAttention) while still saving and loading new checkpoints as usual (which do contain _extra_state for FusedAttention).
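
As an illustration of the first bullet, here is a minimal sketch (not the PR's actual helper) of decoding the integer returned by `torch.backends.cudnn.version()` under the two encodings; the 90000 threshold is an assumption based on the five-digit encoding introduced in cuDNN 9.0.

```python
# Hedged sketch: decode the integer version reported by PyTorch's cuDNN binding.
# cuDNN changed its version encoding in 9.0:
#   pre-9.0: major * 1000  + minor * 100 + patch   (e.g. 8902  -> 8.9.2)
#   9.0+   : major * 10000 + minor * 100 + patch   (e.g. 90100 -> 9.1.0)
import torch


def cudnn_version_tuple():
    """Return the cuDNN version as a (major, minor, patch) tuple."""
    v = torch.backends.cudnn.version()
    if v is None:  # cuDNN not available
        return (0, 0, 0)
    if v >= 90000:  # assumption: five-digit encoding implies cuDNN >= 9.0
        major, rest = divmod(v, 10000)
    else:
        major, rest = divmod(v, 1000)
    minor, patch = divmod(rest, 100)
    return (major, minor, patch)


print(cudnn_version_tuple())  # e.g. (8, 9, 2) or (9, 1, 0)
```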

@ksivaman (Member) commented:

I looked at this further after our sync; it looks like tp_size/tp_group aren't used at all at the DPA/FusedAttention level. Can we simply remove or deprecate them? @cyanguwa

@ksivaman (Member) left a comment

Added comment

@cyanguwa (Collaborator, Author) commented Apr 24, 2024:

> I looked at this further after our sync; it looks like tp_size/tp_group aren't used at all at the DPA/FusedAttention level. Can we simply remove or deprecate them? @cyanguwa

We do use tp_group_initialized in prepare_forward. Also, if we don't keep track of tp_size/tp_group, how do we manage amax reduction for TP groups, or does fp8.py already handle that? @ksivaman

cyanguwa requested a review from ksivaman on April 24, 2024 at 20:13
@cyanguwa (Collaborator, Author) commented:

/te-ci pytorch

@ksivaman (Member) commented:

With #575, the amax reduction is handled in the reduce_and_update_fp8_tensors function using the fp8_group passed into the autocast. So we don't store the tensor parallel group to handle it separately.
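
A hedged usage sketch of what this means on the application side: the amax reduction group is supplied to the autocast rather than stored in the attention module. The model, input, and group construction below are placeholders; only `fp8_autocast` with its `fp8_recipe`/`fp8_group` arguments and `DelayedScaling` are assumed TE APIs.

```python
# Hedged sketch: pass the amax-reduction group to fp8_autocast instead of
# storing tp_size/tp_group inside DPA/FusedAttention.
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Example: reduce amaxes over all ranks (could also be a DP x TP subgroup).
amax_group = dist.new_group(ranks=list(range(dist.get_world_size())))

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(), fp8_group=amax_group):
    out = model(inp)  # amax/scale reduction uses amax_group, not a module-level tp_group
```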

@mikolajblaz commented:

@cyanguwa Regarding checkpoint compatibility: not requiring _extra_state in the state dict is good, although I believe it doesn't solve the problem on the application side when switching from one attention implementation to another. Would such interoperability of attention layers be possible? It would require _extra_state to live at the same level as the default attention implementation.

@cyanguwa (Collaborator, Author) commented:

@mikolajblaz I've moved core_attention.fused_attention._extra_state to core_attention._extra_state in b94a1ee. Let me know if you have any thoughts/comments. Thanks.

cyanguwa requested a review from ptrendx on April 30, 2024 at 21:45
@cyanguwa (Collaborator, Author) commented:

@ksivaman could you please help take another look?

@cyanguwa (Collaborator, Author) commented:

/te-ci pytorch

@cyanguwa (Collaborator, Author) commented:

I had some discussion with mikolajblaz offline and we decided not to pursue the move from core_attention.fused_attention._extra_state to core_attention._extra_state. The possible solutions all look unclean and may not even guarantee that checkpoints load correctly. PyTorch relies heavily on the module structure, and FusedAttention is currently a submodule of DotProductAttention, so it is very hard to manipulate the state_dict()/load_state_dict() calls to get around this hierarchy. We will consider this another time.

https://github.com/pytorch/pytorch/blob/74b7c56517f97c5d813620da9a479417a564e8b4/torch/nn/modules/module.py#L2164
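
For readers following the checkpoint-compatibility discussion, here is a minimal, hedged sketch of the general mechanism such a fix can rely on (not the PR's exact code): a load_state_dict post-hook that drops missing `_extra_state` keys so pre-TE-1.6 checkpoints still load under strict loading. `register_load_state_dict_post_hook` is a standard PyTorch API; the module path in the usage comment is hypothetical.

```python
# Hedged sketch: ignore missing "_extra_state" entries when loading old checkpoints.
def _ignore_missing_extra_state(module, incompatible_keys):
    # Mutating these lists in place changes what load_state_dict() reports/raises,
    # so strict loading of checkpoints without FP8 metadata no longer fails.
    for key in list(incompatible_keys.missing_keys):
        if key.endswith("_extra_state"):
            incompatible_keys.missing_keys.remove(key)


# Hypothetical usage: register on the submodule that owns the FP8 metadata, since
# PyTorch resolves state_dict keys through the module hierarchy (see the link above).
# core_attention.fused_attention.register_load_state_dict_post_hook(_ignore_missing_extra_state)
```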

@cyanguwa (Collaborator, Author) commented May 1, 2024:

/te-ci pytorch

@ptrendx (Member) left a comment

LGTM

@cyanguwa (Collaborator, Author) commented May 1, 2024:

/te-ci pytorch

@ksivaman (Member) left a comment

LGTM

ksivaman merged commit 6459fd8 into NVIDIA:main on May 2, 2024
18 of 20 checks passed
ksivaman pushed a commit that referenced this pull request May 2, 2024

* initialize tp_group for FP8 DPA
* fix cuDNN version in unit tests for cuDNN v9
* add hook to ignore missing fused_attn._extra_states if training from old checkpoints
* remove test and redundant implementation from last commit
* remove warning message and replace with docstring
* remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group
* move core_attention.fused_attention._extra_state to core_attention._extra_state
* simplify post_state_dict_hooks between FU and DPA
* add temporary test
* remove previous attempts to move core_attention.fused_attention to core_attention; keep the test
* remove the test
* disable the pylint self-argument check for the hook, which requires self

---------

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
ksivaman added the 1.6.0 label May 2, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 16, 2024 (same commit message as above, with an additional sign-off from Pawel Gadzinski <[email protected]>)
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 23, 2024 (same commit message as above, with an additional sign-off from Pawel Gadzinski <[email protected]>)