From c81733f1032a56a817b594c8971a738108ded7d0 Mon Sep 17 00:00:00 2001
From: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Date: Wed, 1 May 2024 20:41:59 -0700
Subject: [PATCH] [PyTorch] Miscellaneous fixes for FP8 DPA module (#804)

* initialize tp_group for FP8 DPA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cuDNN version in unit tests for cuDNN v9

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add hook to ignore missing fused_attn._extra_states if training from old checkpoints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove test and redundant implementation from last commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove warning message and replace with docstring

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move core_attention.fused_attention._extra_state to core_attention._extra_state

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify post_state_dict_hooks between FU and DPA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add temporary test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove previous attempts to move core_attention.fused_attention to core_attention; keep the test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove the test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable pylint self arg for hook which is required by hook

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
---
 tests/pytorch/fused_attn/test_fused_attn.py |  3 ++-
 transformer_engine/pytorch/attention.py     | 12 ++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 40cfdd34b7..caba385d46 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -70,7 +70,8 @@ def reset_global_fp8_state():
 def _cudnn_version() -> Tuple[int, int, int]:
     """Runtime cuDNN version (major, minor, patch)"""
     encoded_version = ext.get_cudnn_version()
-    major, encoded_version = divmod(encoded_version, 1000)
+    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
+    major, encoded_version = divmod(encoded_version, major_version_magnitude)
     minor, patch = divmod(encoded_version, 100)
     return (major, minor, patch)
 
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 3bf4598fc1..af6c151cab 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -2711,6 +2711,17 @@ def __init__(
         if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
             os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
 
+        def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
+            """
+            Temporarily remove fused_attention._extra_state as a missing key
+            when loading older TransformerEngine checkpoints. Will phase out
+            this hook in TransformerEngine 2.0.
+            """
+            for key in incompatible_keys.missing_keys:
+                if 'fused_attention._extra_state' in key:
+                    incompatible_keys.missing_keys.remove(key)
+        self.register_load_state_dict_post_hook(remove_extra_states_check)
+
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
@@ -3063,6 +3074,7 @@ def __init__(
             layer_number=layer_number,
             deterministic=self.deterministic,
             **attn_kwargs)
+
         self.unfused_attention = UnfusedDotProductAttention(
             norm_factor, **attn_kwargs, layer_number=layer_number)