FSDP full-precision param_dtype training with PyTorch < 2.0 triggers FSDP assertion error
#18277
Bug description
When FSDP training with full-precision `param_dtype`s (the `16-mixed`, `bf16-mixed`, and `32-true` configurations) on PyTorch < 2.0, FSDP training encounters the assertion error shown below. This is because FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision` property, while `FSDPPrecisionPlugin` currently sets the default `param_dtype` to `torch.float32` when training in full precision.

I'll be submitting a PR shortly that sets the `MixedPrecision` `param_dtype` to `None` when FSDP training with full-precision `param_dtype`s on PyTorch < 2.0. Because there is substantial overlap with #18230, I'll be including a fix for that issue, including the `lightning_module_state_dict` patch, as well.
What version are you seeing the problem on?
master
How to reproduce the bug
To reproduce an example of the issue, run `./tests/tests_pytorch/strategies/test_fsdp.py::test_configure_model[32-true-expected_dtype0]` without `fast_dev_run` enabled and after patching `lightning_module_state_dict` to allow the FSDP 1.x test to proceed:

https://github.com/Lightning-AI/lightning/blob/c83774a1093fab53fef02ae2b824dd85ee21af0a/src/lightning/pytorch/strategies/fsdp.py#L171-L179
Patch the above with:
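(The exact patch is not reproduced here. As a rough, hypothetical stand-in, a test-side monkeypatch along these lines stubs out the full-state-dict gathering so the FSDP 1.x test reaches the precision code path; it is not the fix referenced for #18230, and the actual patch may differ.)

```python
from typing import Any, Dict

import pytest

from lightning.pytorch.strategies.fsdp import FSDPStrategy


def _skip_full_state_dict(self: FSDPStrategy) -> Dict[str, Any]:
    # Stand-in for `lightning_module_state_dict`: skip the FULL_STATE_DICT gathering
    # that breaks on FSDP 1.x (see #18230) so the run continues to the precision assertion.
    return {}


@pytest.fixture()
def patched_fsdp_state_dict(monkeypatch):
    monkeypatch.setattr(FSDPStrategy, "lightning_module_state_dict", _skip_full_state_dict)
```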
Error messages and logs
An example of the produced errors:
Environment
More info
No response
cc @awaelchli @carmocca