[PyTorch] Miscellaneous fixes for FP8 DPA module #804
Conversation
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
add hook to ignore missing fused_attn._extra_states if training from old checkpoints Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
I looked at this further after our sync, looks like
Added comment
We do use
Signed-off-by: Charlene Yang <[email protected]>
remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group Signed-off-by: Charlene Yang <[email protected]>
/te-ci pytorch
With #575, the amax reduction is handled in the
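(For context, a minimal sketch of how the amax reduction is scoped by the `fp8_group` passed to `fp8_autocast`, rather than by a `tp_group` stored on the attention module. The process-group setup and layer sizes below are illustrative, not taken from this PR.)

```python
# Illustrative only: amax/scale statistics for FP8 are reduced across the
# process group passed as fp8_group to fp8_autocast. Launch with torchrun
# so torch.distributed is initialized on every rank.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

# Hypothetical group covering all ranks that share FP8 scaling state.
amax_reduction_group = torch.distributed.new_group()

recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)
layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                            num_attention_heads=16).cuda()
inp = torch.randn(128, 4, 1024, device="cuda")  # [seq, batch, hidden]

# Amax reduction happens over fp8_group when the autocast region exits.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=amax_reduction_group):
    out = layer(inp)
```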
@cyanguwa Regarding checkpoints compatibility: not requiring
move core_attention.fused_attention._extra_state to core_attention._extra_state Signed-off-by: Charlene Yang <[email protected]>
@mikolajblaz I've moved
Signed-off-by: cyanguwa <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
remove previous attempts to move core_attention.fused_attention to core_attention; keep the test Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
@ksivaman could you please help take another look?
/te-ci pytorch
I had some discussion with mikolajblaz offline and we decided not to pursue the move from
Signed-off-by: Charlene Yang <[email protected]>
/te-ci pytorch
LGTM
/te-ci pytorch
LGTM
* initialize tp_group for FP8 DPA
* fix cuDNN version in unit tests for cuDNN v9
* add hook to ignore missing fused_attn._extra_states if training from old checkpoints
* remove test and redundant implementation from last commit
* remove warning message and replace with docstring
* remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group
* move core_attention.fused_attention._extra_state to core_attention._extra_state
* simplify post_state_dict_hooks between FU and DPA
* add temporary test
* remove previous attempts to move core_attention.fused_attention to core_attention; keep the test
* remove the test
* disable pylint self arg for hook which is required by hook

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
`FusedAttention` has been subclassed with `TEBaseModule`, and an `_extra_state` entry has been added to the module's `state_dict`. `_extra_state` contains FP8 metadata, but due to the subclassing, the addition of `_extra_state` to `state_dict` happens regardless of FP8 training or F16 training. This PR allows users to load older checkpoints (which do not have `_extra_state` for `FusedAttention`), as well as save and load new checkpoints as usual (which will contain `_extra_state` for `FusedAttention`).
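To illustrate the checkpoint-compatibility mechanism described above, here is a minimal, self-contained sketch (not the exact implementation in this PR): a `load_state_dict` post-hook that drops missing `*_extra_state` entries from the reported missing keys, so checkpoints saved before the module carried an `_extra_state` still load with `strict=True`. The stand-in module and hook names below are hypothetical.

```python
import torch
import torch.nn as nn

class FusedAttentionLike(nn.Module):
    """Stand-in for a module whose state_dict now includes an _extra_state entry."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def get_extra_state(self):
        # Placeholder for FP8 metadata; serialized into state_dict as "_extra_state".
        return {"fp8_meta": None}

    def set_extra_state(self, state):
        # A real module would restore its FP8 metadata here.
        pass

def ignore_missing_extra_state(module, incompatible_keys):
    # Remove *_extra_state entries from missing_keys so checkpoints that predate
    # _extra_state load cleanly even with strict=True.
    for key in list(incompatible_keys.missing_keys):
        if key.endswith("_extra_state"):
            incompatible_keys.missing_keys.remove(key)

model = FusedAttentionLike()
model.register_load_state_dict_post_hook(ignore_missing_extra_state)

# An "old" checkpoint saved before _extra_state existed: only the Linear params.
old_ckpt = {"proj.weight": torch.zeros(8, 8), "proj.bias": torch.zeros(8)}
model.load_state_dict(old_ckpt, strict=True)  # loads without a missing-key error
```

New checkpoints still round-trip as usual: `state_dict()` includes `_extra_state`, and loading it goes through `set_extra_state`.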