Commit

clean up accumulated fp32 grads between data batches
chrisxcai committed May 2, 2024
1 parent 4b5abe2 commit c97bfd9
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1729,6 +1729,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# else:
# param.unsharded_main_grad.add_(param.grad.data)
param.unsharded_main_grad = new_unsharded_main_grad_in_fp32
+ # Clean up accumulated grads between data batches
+ self._fsdp_wrapped_module.fp32_grads = []
param.grad = None

if not self._require_backward_grad_sync:
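For readers skimming the change, here is a toy sketch of why resetting a per-batch list matters. It is not fairscale's implementation: `ToyWrappedModule`, `post_backward`, and `run` are hypothetical names, plain floats stand in for gradient tensors, and the way fairscale actually fills `fp32_grads` is not shown in this diff. The sketch only assumes that entries are appended during backward and folded together in the post-backward step.

```python
from typing import List


class ToyWrappedModule:
    """Hypothetical stand-in for the wrapped module; it only holds the list."""

    def __init__(self) -> None:
        self.fp32_grads: List[float] = []


def post_backward(module: ToyWrappedModule, clear: bool) -> float:
    """Fold the recorded entries into one value, optionally resetting the list."""
    main_grad = sum(module.fp32_grads)
    if clear:
        # Mirrors the line added by the commit: the next batch starts empty.
        module.fp32_grads = []
    return main_grad


def run(batches: List[List[float]], clear: bool) -> List[float]:
    """Process each batch in turn and collect the folded value per batch."""
    module = ToyWrappedModule()
    results = []
    for batch in batches:
        module.fp32_grads.extend(batch)  # entries recorded during backward
        results.append(post_backward(module, clear))
    return results


batches = [[1.0, 2.0], [10.0]]
with_reset = run(batches, clear=True)
without_reset = run(batches, clear=False)
```

Under these assumptions, the run without the reset carries the first batch's entries into the second batch's result, while the run with the reset keeps each batch independent.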
