diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 35afd6c5d9a..5aea696b492 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -446,16 +446,13 @@ def _prepare_one(self, obj, first_pass=False):
         if isinstance(obj, torch.utils.data.DataLoader):
             return self.prepare_data_loader(obj)
         elif isinstance(obj, torch.nn.Module):
-            self._models.append(obj)
             return self.prepare_model(obj)
         elif isinstance(obj, torch.optim.Optimizer):
             optimizer = self.prepare_optimizer(obj)
-            self._optimizers.append(optimizer)
             return optimizer
         # Second pass of preparation: LR scheduler (which need the full list of optimizers)
         elif isinstance(obj, torch.optim.lr_scheduler._LRScheduler):
             scheduler = self.prepare_scheduler(obj)
-            self._schedulers.append(scheduler)
             return scheduler
         # Return the unprocessed object if previous criteria was not met
         return obj
@@ -570,6 +567,7 @@ def prepare(self, *args):
         return result if len(result) > 1 else result[0]

     def prepare_model(self, model):
+        self._models.append(model)
         if self.device_placement and self.distributed_type != DistributedType.FSDP:
             model = model.to(self.device)
         if self.distributed_type == DistributedType.MULTI_GPU:
@@ -782,7 +780,9 @@ def prepare_data_loader(self, data_loader):
         )

     def prepare_optimizer(self, optimizer):
-        return AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler)
+        optimizer = AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler)
+        self._optimizers.append(optimizer)
+        return optimizer

     def prepare_scheduler(self, scheduler):
         # We try to find the optimizer associated with `scheduler`, the default is the full list.
@@ -791,13 +791,14 @@ def prepare_scheduler(self, scheduler):
             if getattr(scheduler, "optimizer", None) == opt.optimizer:
                 optimizer = opt
                 break
-
-        return AcceleratedScheduler(
+        scheduler = AcceleratedScheduler(
             scheduler,
             optimizer,
             step_with_optimizer=self.step_scheduler_with_optimizer,
             split_batches=self.split_batches,
         )
+        self._schedulers.append(scheduler)
+        return scheduler

     def backward(self, loss, **kwargs):
         """
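
A minimal sketch of the behavior this patch enables, not part of the diff itself: because the bookkeeping moves from `_prepare_one` into `prepare_model`, `prepare_optimizer`, and `prepare_scheduler`, objects prepared by calling those methods directly are now tracked in `_models`, `_optimizers`, and `_schedulers` the same way as objects passed through `prepare(...)`. The script below assumes a single-process CPU run with the default `Accelerator()` constructor; the private attribute names come from the patch.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Calling the individual prepare_* methods directly, without going
# through accelerator.prepare(...).
model = accelerator.prepare_model(model)
optimizer = accelerator.prepare_optimizer(optimizer)
scheduler = accelerator.prepare_scheduler(scheduler)

# With this patch, the wrappers are registered inside prepare_optimizer /
# prepare_scheduler, and prepare_model records the module it receives.
assert optimizer in accelerator._optimizers   # AcceleratedOptimizer is tracked
assert scheduler in accelerator._schedulers   # AcceleratedScheduler is tracked
assert len(accelerator._models) == 1          # the module was registered by prepare_model
```

Before this change, the three lists were only populated inside `_prepare_one`, so the assertions above would fail for directly prepared objects.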