remove old-style fsdp/ddp #1640
Conversation
d19a76d to a477a7c
I suspect the failure of the fsdp no_sync parity tests is caused by a mismatch between the parameters that `_sync_grads` collects and the ones to which `_stash_grad_for_fsdp_prim_impl` attaches the unsharded grads.
lightning-thunder/thunder/distributed/__init__.py, lines 157 to 159 in 5e18c2e:

```python
params_with_grad = tuple(filter(lambda p: hasattr(p, "_thunder_fsdp_unsharded_grad"), module.parameters()))
if not params_with_grad:
    return
```

lightning-thunder/thunder/executors/torchex.py, lines 2116 to 2128 in 5e18c2e:

```python
def _stash_grad_for_fsdp_prim_impl(
    grad: torch.Tensor,
    param_fqn: str,
    compile_data: CompileData,
) -> None:
    grad_name = "_thunder_fsdp_unsharded_grad"
    param = compile_data.fn.get_parameter(param_fqn)
    if torch.is_tensor(unsharded_grad := getattr(param, grad_name, None)):
        unsharded_grad += grad
    else:
        setattr(param, grad_name, grad)
    return grad
```
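A minimal debugging sketch (not part of this PR) that could check the hypothesis, assuming `module` is the fsdp-transformed `torch.nn.Module` inspected after a backward pass under `no_sync`: it lists the parameters that the `_sync_grads` filter above would skip because no unsharded grad was ever stashed on them. The helper name `report_missing_unsharded_grads` is hypothetical.

```python
import torch.nn as nn

def report_missing_unsharded_grads(module: nn.Module) -> set[str]:
    """Return the FQNs of grad-requiring parameters that never received a
    stashed "_thunder_fsdp_unsharded_grad" attribute; a non-empty result
    would point to the mismatch described above."""
    stashed = {
        name
        for name, p in module.named_parameters()
        if hasattr(p, "_thunder_fsdp_unsharded_grad")
    }
    expected = {name for name, p in module.named_parameters() if p.requires_grad}
    # Parameters in `expected` but not in `stashed` would be silently skipped
    # by _sync_grads, since its filter only keeps params carrying the attribute.
    return expected - stashed
```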
…ter sharded and sync'ed grads (#1643)
Stamped!
Co-authored-by: Masaki Kozuki <[email protected]>
cc: @crcrpar
Plan:
- Replace with transform-aware module functions: removed as not (currently) applicable --> we are not re-entrant.
- Fixed