diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 27c39f66a..48eac72ee 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1729,6 +1729,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # else: # param.unsharded_main_grad.add_(param.grad.data) param.unsharded_main_grad = new_unsharded_main_grad_in_fp32 + # Clean up accumulated grads between data batches + self._fsdp_wrapped_module.fp32_grads = [] param.grad = None if not self._require_backward_grad_sync: