Commit
Specify output dtype to torch.float32 in _foreach_norm (#727)
one less kernel

Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar authored Aug 22, 2024
1 parent 8002099 commit c0b0731
Showing 1 changed file with 2 additions and 2 deletions.
torchao/float8/fsdp_utils.py: 2 additions & 2 deletions
@@ -59,15 +59,15 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return
 
     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    local_scale_tensor = scale_tensor.to_local().to(dtype=torch.float32)
+    local_scale_tensor = scale_tensor.to_local()
     for i, float8_linear in enumerate(float8_linears):
         float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
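
For context, a minimal standalone sketch of why the dtype argument saves a kernel launch ("one less kernel"). This uses plain tensors rather than the DTensor weights in fsdp_utils.py, and the weight list below is illustrative, not from torchao; it also assumes a PyTorch version whose torch._foreach_norm accepts a dtype keyword, as this commit relies on.

    import math
    import torch

    # Hypothetical stand-ins for the high-precision weights; in torchao these are DTensors.
    weights = [torch.randn(16, 16, dtype=torch.bfloat16) for _ in range(3)]

    # Before: the per-tensor norms come back in the weights' dtype, so a separate
    # cast kernel is needed afterwards to obtain float32 scales.
    max_weights = torch._foreach_norm(weights, ord=math.inf)
    scales = torch.stack(max_weights).to(dtype=torch.float32)

    # After: requesting float32 output from _foreach_norm up front makes the
    # trailing .to(dtype=torch.float32) cast unnecessary.
    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)
    scales = torch.stack(max_weights)
    assert scales.dtype == torch.float32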

