From c0b0731d819d621698673e10d33ae585de56f1df Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Fri, 23 Aug 2024 02:59:58 +0900
Subject: [PATCH] Specify output dtype to `torch.float32` in `_foreach_norm` (#727)

one less kernel

Signed-off-by: Masaki Kozuki
---
 torchao/float8/fsdp_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py
index bbb954eca1..81859de4bb 100644
--- a/torchao/float8/fsdp_utils.py
+++ b/torchao/float8/fsdp_utils.py
@@ -59,7 +59,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return
 
     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
@@ -67,7 +67,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    local_scale_tensor = scale_tensor.to_local().to(dtype=torch.float32)
+    local_scale_tensor = scale_tensor.to_local()
     for i, float8_linear in enumerate(float8_linears):
         float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
 
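Below is a minimal standalone sketch (not part of the patch) illustrating the "one less kernel" rationale: asking `torch._foreach_norm` for `float32` output up front removes the separate cast the old code applied to the stacked scales. It assumes a PyTorch build in which `_foreach_norm` accepts the `dtype` keyword, as the patched code does; the tensor list, shapes, and dtypes here are illustrative only.

    import math
    import torch

    # Stand-in for the FSDP weight shards; any low-precision dtype shows the effect.
    weights = [torch.randn(16, 16, dtype=torch.bfloat16) for _ in range(3)]

    # Old path: norms come back in the weights' dtype and must be cast afterwards,
    # which launches an extra elementwise kernel.
    max_weights_old = torch._foreach_norm(weights, ord=math.inf)
    scales_old = torch.stack(max_weights_old).to(dtype=torch.float32)

    # New path: norms are produced directly in float32, so no trailing cast is needed.
    max_weights_new = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)
    scales_new = torch.stack(max_weights_new)

    assert scales_old.dtype == scales_new.dtype == torch.float32

Because the stacked amax tensor is now already float32, the scale tensor derived from it is float32 as well, which is why the second hunk can drop the trailing `.to(dtype=torch.float32)` after `to_local()`.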