diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 98d34b5f94..29303063f6 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -258,7 +258,7 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
     def setup(self, cfg: DictConfig) -> None:
         """
         Setup the recipe. This includes training state (if resume_from_checkpoint is True),
-        model, tokenizer, loss, optimizer, sampler, and dataloader.
+        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
         """
         if self._is_rank_zero:
             self._metric_logger = config.instantiate(cfg.metric_logger)
@@ -332,6 +332,13 @@ def setup(self, cfg: DictConfig) -> None:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch
 
+        # Set up the lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
         # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
         # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
         self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
@@ -341,6 +348,55 @@ def setup(self, cfg: DictConfig) -> None:
             (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
         )
 
+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: Optional[DictConfig],
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optional[torch.optim.lr_scheduler.LRScheduler]:
+        """
+        Set up the learning rate scheduler based on the provided configuration.
+        Supports both the standard single-optimizer case and the optimizer-in-backward case.
+
+        Args:
+            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
+            num_training_steps (int): The total number of training steps.
+            last_epoch (int): The index of the last completed step, used to resume the schedule.
+
+        Returns:
+            lr_scheduler (Optional[torch.optim.lr_scheduler.LRScheduler]): The scheduler, or None.
+        """
+        if cfg_lr_scheduler is None:
+            if self._is_rank_zero:
+                log.info(
+                    "No learning rate scheduler configured. Using constant learning rate."
+                )
+            return None
+
+        if self._optimizer_in_bwd:
+            # One optimizer per parameter here; use the first as the representative
+            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
+        else:
+            # Standard case: use the single optimizer
+            optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        if self._optimizer_in_bwd:
+            # Register the scheduler with the wrapper for the optimizer-in-backward case
+            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
+
+        if self._is_rank_zero:
+            log.info("Learning rate scheduler is initialized.")
+
+        return lr_scheduler
+
     def _setup_profiler(
         self, cfg_profiler: Optional[DictConfig] = None
     ) -> Union[torch.profiler.profile, DummyProfiler]:
@@ -818,6 +874,10 @@ def train(self) -> None:
                     # Update the number of steps when the weights are updated
                     self.global_step += 1
 
+                    # Step the learning rate scheduler once per optimizer step
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()
+
                     loss_to_log = running_loss.item() / num_tokens
                     pbar.update(1)
                     pbar.set_description(
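
For context on what `config.instantiate(cfg_lr_scheduler, optimizer, ...)` ends up doing: the config's `_component_` is called with the optimizer as the first positional argument plus the keyword arguments shown in the hunk, and `last_epoch` is what lets a resumed run pick up the schedule mid-training. Below is a self-contained sketch of that flow, with a hand-rolled warmup-plus-cosine schedule standing in for the real component; the model, optimizer, and hyperparameters are illustrative assumptions, not recipe code.

```python
# Standalone sketch (not recipe code) of what _setup_lr_scheduler amounts to
# once config.instantiate resolves the _component_ field.
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    last_epoch: int = -1,
) -> LambdaLR:
    """Linear warmup followed by cosine decay, evaluated per scheduler step."""

    def lr_lambda(step: int) -> float:
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / max(
            1, num_training_steps - num_warmup_steps
        )
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


model = torch.nn.Linear(8, 8)  # illustrative stand-in for the fine-tuned model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Fresh run: last_epoch=-1. On resume, the recipe passes global_step - 1 so
# the schedule continues exactly where the previous run stopped.
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000, last_epoch=-1
)

for _ in range(5):
    optimizer.step()
    scheduler.step()  # mirrors the new per-step call in train()
```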
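
On the `set_lr_scheduler` call: `train()` steps `self._lr_scheduler` directly, so in the optimizer-in-backward case the wrapper presumably has to propagate the representative optimizer's new LR to all the per-parameter optimizers. The following is a hedged sketch of that idea under that assumption; `InBwdWrapperSketch` is a hypothetical stand-in, not torchtune's actual `OptimizerInBackwardWrapper`.

```python
# Hypothetical sketch of the optimizer-in-backward concern: a scheduler
# attached to one representative optimizer only updates that optimizer,
# so registration wraps .step() to fan the new LR out to the rest.
from typing import Dict, Optional

import torch
from torch.optim.lr_scheduler import LRScheduler


class InBwdWrapperSketch:
    """Illustrative stand-in for an optimizer-in-backward wrapper."""

    def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]) -> None:
        self.optim_map = optim_map
        self.lr_scheduler: Optional[LRScheduler] = None

    def set_lr_scheduler(self, lr_scheduler: LRScheduler) -> None:
        self.lr_scheduler = lr_scheduler
        original_step = lr_scheduler.step

        def step_and_sync(*args, **kwargs):
            # Advance the schedule, then mirror the representative
            # optimizer's LR onto every per-parameter optimizer.
            original_step(*args, **kwargs)
            new_lr = lr_scheduler.get_last_lr()[0]
            for opt in self.optim_map.values():
                for group in opt.param_groups:
                    group["lr"] = new_lr

        lr_scheduler.step = step_and_sync


params = [torch.nn.Parameter(torch.zeros(2)) for _ in range(3)]
optim_map = {f"p{i}": torch.optim.SGD([p], lr=0.1) for i, p in enumerate(params)}
wrapper = InBwdWrapperSketch(optim_map)
rep = next(iter(wrapper.optim_map.values()))  # same "first optimizer" trick
wrapper.set_lr_scheduler(torch.optim.lr_scheduler.LambdaLR(rep, lambda s: 0.5**s))
wrapper.lr_scheduler.step()  # all three optimizers now share the decayed LR
```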