[float8] all-reduce amax on dp mesh instead of global pg #933

Merged: 9 commits, Sep 26, 2024
test/float8/test_fsdp2/test_fsdp2.py: 31 additions & 1 deletion
@@ -17,10 +17,12 @@
 import torch.nn as nn
 from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
 from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import DTensor, init_device_mesh
+from torchao.float8.float8_tensor import GemmInputRole
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
@@ -293,6 +295,34 @@ def _get_curr_active_memory_mb(self) -> int:
         return round(mem_stats["active_bytes.all.current"] / 1e6)
 
 
+class Test2DParallelMultiThread(FSDPTestMultiThread, TestFloat8Common):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def test_amax_allreduce_device_mesh(self):
+        dp_size = 2
+        pp_size = self.world_size // dp_size
+        global_mesh = init_device_mesh("cuda", (pp_size, dp_size), mesh_dim_names=("pp", "dp"))
+        dp_mesh = global_mesh["dp"]
+        pp_mesh = global_mesh["pp"]
+
+        if self.rank in [0, 1]:
+            # ranks 0 and 1 are the 1st stage in the pipeline
+            # ranks 2 and 3 do nothing but wait for the 1st stage
+            torch.manual_seed(42 + self.rank)
+            hp_tensor = torch.randn(768, 32, device="cuda")
+            float8_tensor = hp_tensor_to_float8_dynamic(
+                hp_tensor,
+                torch.float8_e4m3fn,
+                Float8LinearConfig(
+                    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+                ),
+                gemm_input_role=GemmInputRole.WEIGHT,
+                reduce_amax=True,
+                device_mesh=dp_mesh
+            )
+

PR author (on `device_mesh=dp_mesh`): without this PR, the unit test throws an NCCL timeout error because ranks 2 and 3 do not participate in the amax all-reduce.
 class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     @property
     def world_size(self) -> int:
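For readers unfamiliar with the failure mode described in the comment above, here is a minimal standalone sketch (not part of the PR) of the scoping idea: a collective issued on a sub-mesh's process group only needs the ranks of that subgroup, whereas a collective on the default global group blocks until every rank joins, which is what produced the NCCL timeout. It assumes a 4-GPU torchrun launch; the tensor shape and variable names are illustrative.

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import init_device_mesh

# 2 x 2 mesh over 4 ranks, same layout as the test above
global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
dp_group = global_mesh["dp"].get_group()

amax = torch.randn(16, device="cuda").abs().max()
# reduces only across the 2 ranks of this rank's dp group; an all-reduce on the
# default (global) group would instead wait for all 4 ranks to participate
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
```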
torchao/float8/float8_scaling_utils.py: 2 additions & 1 deletion
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    device_mesh = None,
PR author (on `device_mesh = None,`): DeviceMesh only exists in newer PyTorch (2.4+). I tried to annotate device_mesh with the string "torch.distributed.DeviceMesh", but some CI jobs run older PyTorch (2.2) and complain about not being able to find DeviceMesh, so I gave up on the type hint. Hopefully the name device_mesh itself is a strong enough indicator that a DeviceMesh is expected.

vkuzo (Contributor), Sep 25, 2024: For now, we intend to only test float8 code on the latest nightly to move fast and minimize maintenance load. Do you know which CI is giving you issues? We have disabled CI on older PT versions per test file using version checks.

PR author: I believe it's the following one. My guess is that the function signature is checked at import time: for torch==2.2.2, if one of the tests imports float8_utils.py or float8_scaling_utils.py, it throws an error.

[test (CUDA 2.2.2, linux.g5.12xlarge.nvidia.gpu, torch==2.2.2 "numpy<2" , cuda, 12.1) / linux-job](https://github.com/pytorch/ao/actions/runs/11023026757/job/30613550360?pr=933#logs)

 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -52,7 +53,7 @@ def hp_tensor_to_float8_dynamic(
"""
if tensor_already_casted_to_fp8(hp_tensor):
return hp_tensor
scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax, device_mesh)
return hp_tensor_and_scale_to_float8(
hp_tensor,
scale,
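As an aside on the type-hint thread above: one generic pattern for annotating a parameter whose type may not exist in older PyTorch is a typing.TYPE_CHECKING-guarded import plus a string annotation, so nothing is resolved at import time. The sketch below only illustrates that pattern; it is not what the PR ended up doing, and the helper name is made up.

```python
from typing import TYPE_CHECKING, Optional

import torch.distributed as dist

if TYPE_CHECKING:
    # seen only by type checkers, never executed, so the module still imports
    # on torch builds where DeviceMesh is unavailable
    from torch.distributed.device_mesh import DeviceMesh


def reduce_amax_over_mesh(amax, device_mesh: "Optional[DeviceMesh]" = None):
    """Hypothetical helper: max-reduce amax over the mesh's group (or globally)."""
    if dist.is_initialized():
        pg = device_mesh.get_group() if device_mesh is not None else None
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)
    return amax
```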
torchao/float8/float8_utils.py: 10 additions & 4 deletions
@@ -98,23 +98,29 @@ def amax_history_to_scale_stack(


 @torch.no_grad()
-def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
+def tensor_to_amax(
+    x: torch.Tensor, reduce_amax: bool = False, device_mesh=None
+) -> torch.Tensor:
     amax = torch.max(torch.abs(x))
 
     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
     # happen elsewhere.
     if reduce_amax and dist.is_initialized():
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
+        pg = device_mesh.get_group() if device_mesh is not None else None
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)
 
     return amax
 
 
 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    reduce_amax: bool = False,
+    device_mesh=None,
 ) -> torch.Tensor:
-    amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+    amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh)
     return amax_to_scale(amax, float8_dtype, x.dtype)


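Taken together, the new device_mesh argument threads through tensor_to_scale into tensor_to_amax, where it selects the process group for the max all-reduce. A rough usage sketch, assuming a 4-GPU torchrun launch and the signatures shown in this diff:

```python
import torch
from torch.distributed._tensor import init_device_mesh
from torchao.float8.float8_utils import tensor_to_amax, tensor_to_scale

global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
dp_mesh = global_mesh["dp"]

w = torch.randn(768, 32, device="cuda")
# amax is reduced only over the ranks of this rank's dp group
amax = tensor_to_amax(w, reduce_amax=True, device_mesh=dp_mesh)
scale = tensor_to_scale(w, torch.float8_e4m3fn, reduce_amax=True, device_mesh=dp_mesh)
```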
torchao/float8/fsdp_utils.py: 1 addition & 0 deletions
@@ -216,6 +216,7 @@ def fsdp_pre_all_gather(self, mesh):
             self._linear_mm_config,
             reduce_amax=True,
             gemm_input_role=GemmInputRole.WEIGHT,
+            device_mesh=mesh,
         )
         return (float8_tensor._data,), (float8_tensor._scale,)

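The fsdp_utils.py change is where the new argument becomes active in training: FSDP2 passes the mesh it was sharded with into fsdp_pre_all_gather, and that mesh now scopes the weight-cast amax all-reduce. A rough end-to-end sketch, assuming the fully_shard mesh kwarg and the enable_fsdp_float8_all_gather config flag behave as in this repo's FSDP2 tests (treat the exact spelling of these names as illustrative):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import init_device_mesh
from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
dp_mesh = global_mesh["dp"]

model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32)).cuda()
# enable the float8 all-gather path so weights become WeightWithDynamicFloat8CastTensor
convert_to_float8_training(model, config=Float8LinearConfig(enable_fsdp_float8_all_gather=True))
# shard over the dp sub-mesh; this is the `mesh` that fsdp_pre_all_gather receives,
# so the amax all-reduce stays on the dp group instead of the global process group
fully_shard(model, mesh=dp_mesh)
```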