Add LR Scheduler to full finetune distributed #2017

Merged: 1 commit, merged Nov 20, 2024

62 changes: 61 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -258,7 +258,7 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
def setup(self, cfg: DictConfig) -> None:
"""
Setup the recipe. This includes training state (if resume_from_checkpoint is True),
model, tokenizer, loss, optimizer, sampler, and dataloader.
model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
"""
if self._is_rank_zero:
self._metric_logger = config.instantiate(cfg.metric_logger)
@@ -332,6 +332,13 @@ def setup(self, cfg: DictConfig) -> None:
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch

# Setup lr scheduler
self._lr_scheduler = self._setup_lr_scheduler(
cfg_lr_scheduler=cfg.get("lr_scheduler", None),
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)

# Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
# if cfg is missing profiler key or if `cfg.profiler.enabled = False`
self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
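
The scheduler is optional: `cfg.get("lr_scheduler", None)` means existing configs without an `lr_scheduler` entry keep a constant learning rate. Below is a minimal sketch of how such an entry could be supplied; the component path and `num_warmup_steps` value are illustrative assumptions, not values taken from this PR.

# Sketch only: the "_component_" path and num_warmup_steps below are assumptions
# for illustration, not part of this diff.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "lr_scheduler": {
            # Any factory accepting (optimizer, num_training_steps=..., last_epoch=...)
            # can be plugged in here; a cosine-with-warmup helper is a typical choice.
            "_component_": "torchtune.training.get_cosine_schedule_with_warmup",
            "num_warmup_steps": 100,
        }
    }
)

print(cfg.get("lr_scheduler", None))  # this DictConfig is what reaches _setup_lr_scheduler
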
@@ -341,6 +348,55 @@ def setup(self, cfg: DictConfig) -> None:
(cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
)

def _setup_lr_scheduler(
self,
cfg_lr_scheduler: Optional[DictConfig],
num_training_steps: int,
last_epoch: int,
) -> Optional[Optimizer]:
"""
Set up the learning rate scheduler based on the provided configuration.
It supports both the standard single-optimizer setup and the optimizer-in-backward case.

Args:
cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
num_training_steps (int): The total number of training steps.
last_epoch (int): The index of the last epoch.

Returns:
lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
"""
if cfg_lr_scheduler is None:
if self._is_rank_zero:
log.info(
"No learning rate scheduler configured. Using constant learning rate."
)
return None

if self._optimizer_in_bwd:
# Use the first optimizer from the wrapper as the representative whose learning rate the scheduler drives
optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
else:
# Standard case: use the single optimizer
optimizer = self._optimizer

# Instantiate the learning rate scheduler
lr_scheduler = config.instantiate(
cfg_lr_scheduler,
optimizer,
num_training_steps=num_training_steps,
last_epoch=last_epoch,
)

if self._optimizer_in_bwd:
# Register the scheduler with the wrapper for the optimizer-in-backward case
self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)

if self._is_rank_zero:
log.info("Learning rate scheduler is initialized.")

return lr_scheduler

def _setup_profiler(
self, cfg_profiler: Optional[DictConfig] = None
) -> Union[torch.profiler.profile, DummyProfiler]:
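
With optimizer_in_bwd enabled, every parameter has its own optimizer, so the method above binds the scheduler to one representative optimizer and then hands it to the wrapper via `set_lr_scheduler` so the wrapper can keep the remaining optimizers on the same schedule. A minimal sketch of that propagation idea, using a hypothetical wrapper class rather than the recipe's actual `_optim_ckpt_wrapper`:

# Hypothetical sketch, not the recipe's actual wrapper: one scheduler steps a
# representative optimizer and its learning rate is mirrored onto every
# per-parameter optimizer.
from typing import Dict, Optional

import torch

class SketchOptimInBwdWrapper:
    def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]):
        self.optim_map = optim_map
        self.lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None

    def set_lr_scheduler(self, scheduler: torch.optim.lr_scheduler.LRScheduler) -> None:
        # Keep a handle to the scheduler so learning rate changes can be broadcast.
        self.lr_scheduler = scheduler

    def step_lr_scheduler(self) -> None:
        # Advance the schedule once, then copy the new learning rate to every optimizer.
        self.lr_scheduler.step()
        new_lr = self.lr_scheduler.get_last_lr()[0]
        for opt in self.optim_map.values():
            for group in opt.param_groups:
                group["lr"] = new_lr
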
@@ -818,6 +874,10 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

# Step the learning rate scheduler
if self._lr_scheduler is not None:
self._lr_scheduler.step()

loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
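
The scheduler is stepped immediately after global_step is incremented, so it advances once per weight update rather than once per forward pass when gradient accumulation is used. This matches the num_training_steps = total_epochs * steps_per_epoch and last_epoch = global_step - 1 values passed at setup, and lets a run resumed from a checkpoint pick the schedule up at the right step. A small self-contained illustration of the per-step pattern, using a stock PyTorch scheduler as a stand-in for whatever the config specifies:

# Stand-in illustration: LinearLR is an assumption here, not the scheduler this PR
# wires in. The learning rate is advanced once per optimizer step, after the
# weights are updated, mirroring the order in the recipe's train loop.
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
num_training_steps = 8  # plays the role of total_epochs * steps_per_epoch

scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.0, total_iters=num_training_steps
)

for global_step in range(1, num_training_steps + 1):
    optimizer.step()   # weight update first
    scheduler.step()   # then advance the schedule once per optimizer step
    print(global_step, scheduler.get_last_lr()[0])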