From efd9bb94fdbee8dba55e818f3d51949ea5f81f11 Mon Sep 17 00:00:00 2001
From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com>
Date: Wed, 25 Sep 2024 19:59:19 -0700
Subject: [PATCH] [float8] all-reduce amax on dp mesh instead of global pg
 (#933)

* [float8] all-reduce amax on dp mesh instead of global pg

* linter

* improve comments

* move hp tensor inside if

* linter

* linter

* linter

* linter

* linter
---
 test/float8/test_fsdp2/test_fsdp2.py   | 32 +++++++++++++++++++++++++-
 torchao/float8/float8_scaling_utils.py |  3 ++-
 torchao/float8/float8_utils.py         | 14 +++++++----
 torchao/float8/fsdp_utils.py           |  1 +
 4 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py
index ecde051e36..1ad5586513 100644
--- a/test/float8/test_fsdp2/test_fsdp2.py
+++ b/test/float8/test_fsdp2/test_fsdp2.py
@@ -17,10 +17,12 @@ import torch.nn as nn
 from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
 from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import DTensor, init_device_mesh
+from torchao.float8.float8_tensor import GemmInputRole
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
@@ -293,6 +295,34 @@ def _get_curr_active_memory_mb(self) -> int:
         return round(mem_stats["active_bytes.all.current"] / 1e6)
 
 
+class Test2DParallelMultiThread(FSDPTestMultiThread, TestFloat8Common):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def test_amax_allreduce_device_mesh(self):
+        dp_size = 2
+        pp_size = self.world_size // dp_size
+        global_mesh = init_device_mesh("cuda", (pp_size, dp_size), mesh_dim_names=("pp", "dp"))
+        dp_mesh = global_mesh["dp"]
+        pp_mesh = global_mesh["pp"]
+
+        if self.rank in [0, 1]:
+            # ranks 0 and 1 form the 1st stage in the pipeline;
+            # ranks 2 and 3 do nothing but wait for the 1st stage
+            torch.manual_seed(42 + self.rank)
+            hp_tensor = torch.randn(768, 32, device="cuda")
+            float8_tensor = hp_tensor_to_float8_dynamic(
+                hp_tensor,
+                torch.float8_e4m3fn,
+                Float8LinearConfig(
+                    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+                ),
+                gemm_input_role=GemmInputRole.WEIGHT,
+                reduce_amax=True,
+                device_mesh=dp_mesh,
+            )
+
 class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     @property
     def world_size(self) -> int:
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
index d2ae896320..e9e1951763 100644
--- a/torchao/float8/float8_scaling_utils.py
+++ b/torchao/float8/float8_scaling_utils.py
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    device_mesh=None,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -52,7 +53,7 @@ def hp_tensor_to_float8_dynamic(
     """
     if tensor_already_casted_to_fp8(hp_tensor):
         return hp_tensor
-    scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
+    scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax, device_mesh)
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
         scale,
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 535c870890..d8ad315f16 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -100,23 +100,29 @@ def amax_history_to_scale_stack(
 
 
 @torch.no_grad()
-def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
+def tensor_to_amax(
+    x: torch.Tensor, reduce_amax: bool = False, device_mesh=None
+) -> torch.Tensor:
     amax = torch.max(torch.abs(x))
 
     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
     # happen elsewhere.
     if reduce_amax and dist.is_initialized():
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
+        pg = device_mesh.get_group() if device_mesh is not None else None
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)
 
     return amax
 
 
 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    reduce_amax: bool = False,
+    device_mesh=None,
 ) -> torch.Tensor:
-    amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+    amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh)
     return amax_to_scale(amax, float8_dtype, x.dtype)
diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py
index 19386d932b..d3c0f73c6c 100644
--- a/torchao/float8/fsdp_utils.py
+++ b/torchao/float8/fsdp_utils.py
@@ -216,6 +216,7 @@ def fsdp_pre_all_gather(self, mesh):
             self._linear_mm_config,
             reduce_amax=True,
             gemm_input_role=GemmInputRole.WEIGHT,
+            device_mesh=mesh,
         )
         return (float8_tensor._data,), (float8_tensor._scale,)
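
Editor's note (appendix, not part of the patch): the fix matters under 2-D
parallelism because only a rank's data-parallel peers shard the same weight,
while ranks on other pipeline stages may never enter the float8 cast at all.
An all-reduce on the default (global) process group would therefore both mix
amax values from unrelated shards and hang waiting for idle stages;
`device_mesh.get_group()` restricts the collective to the dp subgroup. Below
is a minimal, hypothetical sketch of that pattern; the script name, function
name, mesh shape, and tensor sizes are illustrative and not taken from the
patch.

    # amax_on_dp_mesh.py -- illustrative sketch, not torchao code.
    # Run with: torchrun --nproc_per_node=4 amax_on_dp_mesh.py
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    def dp_local_amax(x: torch.Tensor, dp_mesh) -> torch.Tensor:
        # Per-rank max(|x|), MAX-reduced across dp peers only.
        amax = torch.max(torch.abs(x))
        # group=None would target the default (global) pg, forcing every
        # rank in the job -- including idle pipeline stages -- to join.
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_mesh.get_group())
        return amax

    if __name__ == "__main__":
        # 2 pipeline stages x 2 data-parallel replicas on 4 GPUs;
        # the dp groups are {0, 1} and {2, 3}.
        mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
        rank = dist.get_rank()
        if rank in (0, 1):  # only the 1st pipeline stage casts a weight
            torch.manual_seed(42 + rank)
            w = torch.randn(768, 32, device="cuda")
            print(rank, dp_local_amax(w, mesh["dp"]))

In the patch itself the same effect is achieved by threading the mesh from
`fsdp_pre_all_gather(self, mesh)` through `hp_tensor_to_float8_dynamic` and
`tensor_to_scale` into `tensor_to_amax`, which falls back to the default
group when `device_mesh` is None.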