[FSDP2] cast scale to float32 in precompute #835
Conversation
🔗 Helpful links: artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/835
✅ No failures as of commit 7ae2e9e with merge base e2dad4a (Dr. CI, updated every 15 minutes).
@@ -69,7 +69,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     local_scale_tensor = scale_tensor.to_local()
     for i, float8_linear in enumerate(float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)
precompute should align with Float8Linear:

ao/torchao/float8/float8_utils.py, line 55 at 144445a:

    return res.to(torch.float32)
cc @crcrpar
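For context, a minimal sketch of the dtype alignment this cast is after (illustrative only; `torch.float8_e4m3fn` and the simplified scale math are assumptions here, not the actual torchao code):

```python
import torch

# The scale math stays in bf16 (matching the per-tensor float8 compute path),
# but the stored precomputed scale is cast to float32 so its dtype matches
# what the Float8Linear helper returns (`return res.to(torch.float32)` above).
amax = torch.randn(16, dtype=torch.bfloat16).abs().max()   # bf16 amax
scale = torch.finfo(torch.float8_e4m3fn).max / amax        # still bf16
precomputed_scale = scale.to(torch.float32)                # what this PR stores

assert precomputed_scale.dtype == torch.float32
```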
@@ -59,7 +59,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return

     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
wondering why `dtype=torch.float32` was removed here / why does it matter?
this line mimics float8 compute, where we take max(abs(w)) without upcasting:
https://github.com/pytorch/ao/blob/144445a7a8e988059555421caafa17e0c1678053/torchao/float8/float8_utils.py#L101-L102C28
Not sure if it improves ground-truth numerics, but at least it brings the precompute numerics on par with float8 compute.
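A small sketch of that point, assuming bf16 weights (random values; only the dtype and the inf-norm/abs-max equivalence matter here):

```python
import math
import torch

# Made-up bf16 weight shards.
weights = [torch.randn(128, 128, dtype=torch.bfloat16) for _ in range(2)]

# Per-tensor float8 compute effectively takes max(abs(w)) on the bf16
# weight itself, so the amax stays in bf16.
amax_compute = [w.abs().max() for w in weights]

# The precompute path after this PR does the same reduction batched;
# without dtype=torch.float32 the results also stay in bf16 and match
# the per-tensor values exactly (inf-norm == max(abs(w))).
amax_precompute = torch._foreach_norm(weights, ord=math.inf)

for a, b in zip(amax_compute, amax_precompute):
    assert a.dtype == b.dtype == torch.bfloat16
    assert torch.equal(a, b)
```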
ao/test/float8/test_fsdp2/test_fsdp2.py, lines 124 to 132 at 144445a:

    module = self.init_transformer(weight_tying=weight_tying, dtype=dtype)
    ref_module = copy.deepcopy(module)
    float8_linear_config1 = Float8LinearConfig(
        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
    )
    convert_to_float8_training(
        ref_module,
        config=float8_linear_config1,
    )
It might be helpful if you could provide some explanation of which tensors' dtypes changed before/after the revert.
I would mainly suggest not phrasing it as regressing numerics. It broke tests, so there was a test regression. However, this is only because there was a numeric mismatch: the precompute code path did some computations in fp32 where the no-precompute code path does them in bf16. The current change in the PR makes the two paths the same, but it makes the precompute path less accurate than before.
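For concreteness, a hedged sketch of the kind of mismatch described above (illustrative only; this is not the failing test, and `torch.float8_e4m3fn` plus the simplified scale math are assumptions):

```python
import torch

torch.manual_seed(0)
w = torch.randn(512, 512, dtype=torch.bfloat16)

# Both paths see the same amax value; only the precision of the follow-up
# math differs.
amax_bf16 = w.abs().max()                # bf16, as in the no-precompute path
amax_fp32 = amax_bf16.to(torch.float32)  # fp32, as in the pre-revert precompute path

f8_max = torch.finfo(torch.float8_e4m3fn).max
scale_bf16_math = (f8_max / amax_bf16).to(torch.float32)  # division done in bf16
scale_fp32_math = f8_max / amax_fp32                      # division done in fp32

# The low mantissa bits can differ, which is enough to fail a strict
# parity unit test even though neither value is wrong.
print(scale_bf16_math.item(), scale_fp32_math.item())
```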
Good point. I modified the PR description to focus on fixing the unit test. Totally get your point that the true numerics might be regressing.
Revert a recent PR that breaks unit tests: #727.
We can revisit whether we should apply fp32 upcasting consistently across float8 compute and precompute.
The unit test failed on my devgpu, but I'm not sure why our CI did not catch it; maybe because there is no H100 in CI?

    pytest -s test/float8/test_fsdp2/test_fsdp2.py -k test_transformer_parity
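For reference, a hedged sketch of what applying fp32 upcasting consistently could look like (hypothetical helpers, not the torchao implementation): both code paths would upcast before the reduction rather than only one of them.

```python
import math
import torch

def amax_compute_fp32(w: torch.Tensor) -> torch.Tensor:
    # hypothetical per-tensor float8 compute path with upcasting
    return w.to(torch.float32).abs().max()

def amax_precompute_fp32(weights: list) -> list:
    # hypothetical precompute path with the same upcasting
    return torch._foreach_norm([w.to(torch.float32) for w in weights], ord=math.inf)
```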