From dbaff7e138d5e3362cf8668f04009c579fc14115 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 2 Dec 2021 02:51:09 +0300 Subject: [PATCH] Hotfix: offload_optimizer in load_state_from_peers --- hivemind/optim/experimental/state_averager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hivemind/optim/experimental/state_averager.py b/hivemind/optim/experimental/state_averager.py index fb204de82..a75f6006b 100644 --- a/hivemind/optim/experimental/state_averager.py +++ b/hivemind/optim/experimental/state_averager.py @@ -631,7 +631,8 @@ def load_state_from_peers(self, **kwargs): Attempt to download the latest optimizer state from peers and update trainer parameters/statistics. :returns: whether or the averager succeeded in loading parameters """ - main_parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors)) + opt_parameters = tuple(param for param_group in self.optimizer.param_groups for param in param_group["params"]) + main_parameters_and_extras = tuple(chain(opt_parameters, self.extra_tensors)) num_parameters_and_extras = len(main_parameters_and_extras) loaded_state = super().load_state_from_peers(**kwargs) @@ -661,6 +662,8 @@ def load_state_from_peers(self, **kwargs): if self.offload_optimizer: self._apply_optimizer_parameters_() + if not self.reuse_tensors: + self._load_local_tensors_into_averager_() self.local_epoch = metadata["epoch"] self._update_scheduler()