From 85d03de43160328eaf350e7ec3877d3d7b57da50 Mon Sep 17 00:00:00 2001
From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com>
Date: Tue, 10 Sep 2024 22:52:16 -0700
Subject: [PATCH] [FSDP2] cast scale to float32 in precompute (#835)

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/float8/fsdp_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py
index 45e83b7cf5..7ec60c795b 100644
--- a/torchao/float8/fsdp_utils.py
+++ b/torchao/float8/fsdp_utils.py
@@ -59,7 +59,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return
 
     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
@@ -69,7 +69,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     local_scale_tensor = scale_tensor.to_local()
     for i, float8_linear in enumerate(float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
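
For context, a minimal standalone sketch of the dtype behavior the two changed lines address, using plain local tensors instead of FSDP2 DTensors. The clamp epsilon and the simplified scale formula are illustrative assumptions; only the _foreach_norm call and the final .to(torch.float32) cast mirror the diff.

import math

import torch

# two fake bf16 "weights", standing in for the sharded float8 linear weights
weights = [torch.randn(16, 16, dtype=torch.bfloat16) for _ in range(2)]

# inf-norm is equivalent to max(abs(w)); without the dtype kwarg the result
# keeps the weights' dtype (bfloat16 here), matching the patched call
max_weights = torch._foreach_norm(weights, ord=math.inf)
amax_tensor = torch.stack(max_weights)
print(amax_tensor.dtype)  # torch.bfloat16

# simplified, assumed scale computation (the real one sits between the two
# hunks and is not part of this diff)
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax_tensor, min=1e-12)

# the patched assignment: cast each per-weight scale to float32 when storing it
precomputed_scales = [scale_tensor[i].to(torch.float32) for i in range(len(weights))]
print([s.dtype for s in precomputed_scales])  # [torch.float32, torch.float32]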