diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index bb58ca57fef..cd74b68e24d 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -674,7 +674,15 @@ def prepare(self, *args):
 
         return result if len(result) > 1 else result[0]
 
-    def prepare_model(self, model):
+    def prepare_model(self, model: torch.nn.Module):
+        """
+        Prepares a PyTorch model for training in any distributed setup. It is recommended to use
+        [`Accelerator.prepare`] instead.
+
+        Args:
+            model (`torch.nn.Module`):
+                A PyTorch model to prepare
+        """
         self._models.append(model)
         if self.device_placement and self.distributed_type != DistributedType.FSDP:
             model = model.to(self.device)
@@ -886,7 +894,15 @@ def _prepare_deepspeed(self, *args):
             )
         return tuple(result)
 
-    def prepare_data_loader(self, data_loader):
+    def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader):
+        """
+        Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use
+        [`Accelerator.prepare`] instead.
+
+        Args:
+            data_loader (`torch.utils.data.DataLoader`):
+                A vanilla PyTorch DataLoader to prepare
+        """
         return prepare_data_loader(
             data_loader,
             self.device,
@@ -898,12 +914,28 @@ def prepare_data_loader(self, data_loader):
             dispatch_batches=self.dispatch_batches,
         )
 
-    def prepare_optimizer(self, optimizer):
+    def prepare_optimizer(self, optimizer: torch.optim.Optimizer):
+        """
+        Prepares a PyTorch Optimizer for training in any distributed setup. It is recommended to use
+        [`Accelerator.prepare`] instead.
+
+        Args:
+            optimizer (`torch.optim.Optimizer`):
+                A vanilla PyTorch optimizer to prepare
+        """
         optimizer = AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler)
         self._optimizers.append(optimizer)
         return optimizer
 
-    def prepare_scheduler(self, scheduler):
+    def prepare_scheduler(self, scheduler: torch.optim.lr_scheduler._LRScheduler):
+        """
+        Prepares a PyTorch Scheduler for training in any distributed setup. It is recommended to use
+        [`Accelerator.prepare`] instead.
+
+        Args:
+            scheduler (`torch.optim.lr_scheduler._LRScheduler`):
+                A vanilla PyTorch scheduler to prepare
+        """
         # We try to find the optimizer associated with `scheduler`, the default is the full list.
         optimizer = self._optimizers
         for opt in self._optimizers:
@@ -1133,7 +1165,7 @@ def init_trackers(self, project_name: str, config: Optional[dict] = None, init_k
                 Optional starting configuration to be logged.
             init_kwargs (`dict`, *optional*):
                 A nested dictionary of kwargs to be passed to a specific tracker's `__init__` function. Should be
-                formatted like this:
+                formatted like so:
                 ```python
                 {"wandb": {"tags": ["tag_a", "tag_b"]}}
                 ```
@@ -1182,7 +1214,7 @@ def log(self, values: dict, step: Optional[int] = None, log_kwargs: Optional[dic
                 The run step. If included, the log will be affiliated with this step.
             log_kwargs (`dict`, *optional*):
                 A nested dictionary of kwargs to be passed to a specific tracker's `log` function. Should be formatted
-                like this:
+                like so:
                 ```python
                 {"wandb": {"tags": ["tag_a", "tag_b"]}}
                 ```
@@ -1193,7 +1225,8 @@ def log(self, values: dict, step: Optional[int] = None, log_kwargs: Optional[dic
     @on_main_process
     def end_training(self):
         """
-        Runs any special end training behaviors, such as stopping trackers on the main process only.
+        Runs any special end training behaviors, such as stopping trackers on the main process only. Should always be
+        called at the end of your script if using experiment tracking.
         """
         for tracker in self.trackers:
             tracker.finish()
@@ -1383,6 +1416,15 @@ def _get_devices(self, *args):
         return (model_device, optimizer_device)
 
     def get_state_dict(self, model, unwrap=True):
+        """
+        Returns the state dictionary of a model sent through [`Accelerator.prepare`] in full precision
+
+        Args:
+            model (`torch.nn.Module`):
+                A PyTorch model sent through [`Accelerator.prepare`]
+            unwrap (`bool`, *optional*, defaults to True):
+                Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict
+        """
         is_zero_3 = False
         if self.distributed_type == DistributedType.DEEPSPEED:
            is_zero_3 = self.deepspeed_config["zero_optimization"]["stage"] == 3