Add LR Scheduler to full finetune distributed (#2017)
parthsarthi03 authored Nov 20, 2024
1 parent a4a74a0 commit fcd400f
Showing 1 changed file with 61 additions and 1 deletion.
recipes/full_finetune_distributed.py (62 changes: 61 additions & 1 deletion)
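
The scheduler added here is entirely config-driven: `setup()` now looks for an optional `lr_scheduler` block in the recipe config and falls back to a constant learning rate when it is absent. A minimal sketch of what such a block might look like, built with OmegaConf to stay in Python; the component path and warmup value are illustrative assumptions, not part of this commit:

from omegaconf import OmegaConf

# Hypothetical recipe-config fragment enabling the new scheduler.
# The component path assumes torchtune's bundled cosine-with-warmup helper;
# check your torchtune version for the exact path.
cfg = OmegaConf.create(
    {
        "lr_scheduler": {
            "_component_": "torchtune.modules.get_cosine_schedule_with_warmup",
            "num_warmup_steps": 100,
        }
    }
)

# cfg.get("lr_scheduler", None) is what the recipe passes to _setup_lr_scheduler;
# omitting the block entirely keeps the old constant-LR behavior.
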
@@ -256,7 +256,7 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
     def setup(self, cfg: DictConfig) -> None:
         """
         Setup the recipe. This includes training state (if resume_from_checkpoint is True),
-        model, tokenizer, loss, optimizer, sampler, and dataloader.
+        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
         """
         if self._is_rank_zero:
             self._metric_logger = config.instantiate(cfg.metric_logger)
@@ -329,6 +329,13 @@ def setup(self, cfg: DictConfig) -> None:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch

+        # Setup lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
         # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
         # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
         self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
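
In the hunk above, the scheduler horizon comes from the recipe's own bookkeeping: `num_training_steps` is the full step budget across all epochs, and `last_epoch` is set to `global_step - 1` so a run resumed from a checkpoint fast-forwards the schedule instead of restarting it. A small worked example with made-up numbers (not values from the commit):

# Illustrative values only.
total_epochs = 3
steps_per_epoch = 500      # e.g. len(dataloader) // gradient_accumulation_steps
epochs_run = 1             # restored from the checkpoint when resuming

num_training_steps = total_epochs * steps_per_epoch    # 1500
global_step = epochs_run * steps_per_epoch             # 750
last_epoch = global_step - 1                           # 749: the schedule resumes mid-run

print(num_training_steps, global_step, last_epoch)
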
@@ -338,6 +345,55 @@ def setup(self, cfg: DictConfig) -> None:
             (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
         )

+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: Optional[DictConfig],
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optional[Optimizer]:
+        """
+        Set up the learning rate scheduler based on the provided configuration.
+        It supports both standard optimization and optimizer-in-backward cases.
+
+        Args:
+            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
+            num_training_steps (int): The total number of training steps.
+            last_epoch (int): The index of the last epoch.
+
+        Returns:
+            lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
+        """
+        if cfg_lr_scheduler is None:
+            if self._is_rank_zero:
+                log.info(
+                    "No learning rate scheduler configured. Using constant learning rate."
+                )
+            return None
+
+        if self._optimizer_in_bwd:
+            # Use the first optimizer from the wrapper to represent the learning rate
+            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
+        else:
+            # Standard case: use the single optimizer
+            optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        if self._optimizer_in_bwd:
+            # Modify the scheduler for optimizer_in_bwd case
+            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
+
+        if self._is_rank_zero:
+            log.info("Learning rate scheduler is initialized.")
+
+        return lr_scheduler
+
     def _setup_profiler(
         self, cfg_profiler: Optional[DictConfig] = None
     ) -> Union[torch.profiler.profile, DummyProfiler]:
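
To see what `_setup_lr_scheduler` ultimately produces, here is a minimal plain-PyTorch sketch of the same idea without torchtune's `config.instantiate` machinery; the model, learning rate, warmup length, and lambda are placeholders, not the recipe's actual values:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)                   # stand-in for the fine-tuned model
optimizer = AdamW(model.parameters(), lr=2e-5)

num_training_steps = 1500
warmup_steps = 100

def lr_lambda(step: int) -> float:
    # Linear warmup, then linear decay to zero over the remaining steps.
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    return max(0.0, (num_training_steps - step) / (num_training_steps - warmup_steps))

# last_epoch=-1 starts a fresh schedule; a resumed run would pass global_step - 1.
lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

For the optimizer-in-backward path, the recipe builds the scheduler against one representative optimizer (the first entry in the wrapper's `optim_map`) and then registers it with the wrapper via `set_lr_scheduler`, so the single schedule is coordinated with the per-parameter optimizers rather than one optimizer object.
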
@@ -813,6 +869,10 @@ def train(self) -> None:
                     # Update the number of steps when the weights are updated
                     self.global_step += 1

+                    # Step the learning rate scheduler
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()
+
                     loss_to_log = running_loss.item() / num_tokens
                     pbar.update(1)
                     pbar.set_description(
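
The final hunk steps the scheduler once per optimizer update, immediately after `global_step` is incremented, so the schedule advances per weight update rather than per epoch (and is simply skipped when no scheduler is configured). A stripped-down loop showing that ordering; gradient accumulation, FSDP, and logging are omitted, and the model and data are placeholders:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)
optimizer = AdamW(model.parameters(), lr=2e-5)
lr_scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / 100))

global_step = 0
for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).sum()    # placeholder forward pass and loss
    loss.backward()
    optimizer.step()
    global_step += 1
    if lr_scheduler is not None:             # mirrors the None-check in the recipe
        lr_scheduler.step()                  # one scheduler step per weight update
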
