Faster intermediate checkpoints with DCP async save in TorchTune (#2006)
Co-authored-by: Saurabh Mishra <[email protected]>
saumishr and Saurabh Mishra authored Dec 13, 2024
1 parent 096881d commit c2c6f4a
Showing 19 changed files with 1,178 additions and 179 deletions.
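For context: this change routes intermediate checkpoints through PyTorch Distributed Checkpointing (DCP), whose async save stages the state dict and then completes the write from a background thread, so the training loop only pays for staging instead of the full serialization and disk write. A minimal sketch of the underlying API, independent of TorchTune (a hedged illustration: the names come from torch.distributed.checkpoint, the path is made up, and async_save needs a recent PyTorch, roughly 2.4+):

    import torch
    import torch.distributed.checkpoint as dcp

    # Stand-in model/optimizer; in the recipes these are the FSDP-wrapped
    # model and its optimizer.
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters())

    # async_save stages the state dict, then writes in the background and
    # returns a Future; training is only blocked for the staging step.
    state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}
    future = dcp.async_save(state_dict, checkpoint_id="/tmp/ckpt/epoch_0")  # illustrative path

    # ... keep training while the write completes ...

    future.result()  # block only when the checkpoint must be durable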
157 changes: 50 additions & 107 deletions recipes/full_finetune_distributed.py
@@ -26,6 +26,10 @@
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
 from torchtune.training.lr_schedulers import get_lr

 from tqdm import tqdm
@@ -138,9 +142,11 @@ def __init__(self, cfg: DictConfig) -> None:

         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+        self._checkpoint_client = CheckpointClient(cfg)

         # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
         if self._optimizer_in_bwd:
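The flag wiring above makes async checkpointing opt-in. A hedged sketch of how these flags come out of the config (torchtune recipes take an OmegaConf DictConfig, as the signatures in this diff show; the values here are illustrative):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "resume_from_checkpoint": False,
        "enable_async_checkpointing": True,  # new flag introduced by this commit
        "optimizer_in_bwd": False,
    })
    assert cfg.get("enable_async_checkpointing", False) is True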
@@ -189,21 +195,6 @@ def __init__(self, cfg: DictConfig) -> None:
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0

-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. If resume_from_checkpoint
-        is True, this also includes the recipe state.
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        if self._resume_from_checkpoint:
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
-
     def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
         """
         Updates the recipe state from checkpoint.
@@ -255,7 +246,8 @@ def setup(self, cfg: DictConfig) -> None:
         # log config with parameter override
         self._metric_logger.log_config(cfg)

-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        # Load the base model
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

         self._compile = cfg.get("compile", False)
         self._model = self._setup_model(
@@ -276,11 +268,36 @@ def setup(self, cfg: DictConfig) -> None:
             optimizer_in_bwd=self._optimizer_in_bwd,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                if self._resume_from_checkpoint
+                if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )

+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                try:
+                    checkpoint_dict = (
+                        self._checkpoint_client.load_distributed_checkpoint(
+                            self._model,
+                            (
+                                self._optim_ckpt_wrapper
+                                if self._optimizer_in_bwd
+                                else self._optimizer
+                            ),
+                        )
+                    )
+                except Exception as e:
+                    log.warning(
+                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
+                    )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)

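For reference, the distributed resume path added above ultimately sits on DCP's load, which restores each rank's shards in place. A minimal sketch outside TorchTune, assuming the illustrative path from the save example and PyTorch's torch.distributed.checkpoint.state_dict helpers:

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    model = torch.nn.Linear(8, 8)  # stand-in model/optimizer, as before
    optimizer = torch.optim.AdamW(model.parameters())

    # Build state-dict views, load the checkpoint into them in place, then
    # push the restored values back into the live objects.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    state_dict = {"model": model_sd, "optim": optim_sd}
    dcp.load(state_dict, checkpoint_id="/tmp/ckpt/epoch_0")  # illustrative path
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state_dict["model"],
        optim_state_dict=state_dict["optim"],
    )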
@@ -547,6 +564,7 @@ def _setup_model(
             log,
             f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
         )
+
         if self._is_rank_zero:
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
@@ -661,95 +679,6 @@ def _setup_data(

         return sampler, dataloader

-    def save_checkpoint(
-        self,
-        epoch: int,
-    ) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Model weights with key training.MODEL_KEY
-        - Relevant recipe state if training is not complete
-        Checkpointer will save the model weights and recipe state in
-        different checkpoint files. To correctly resume training from an intermediate checkpoint,
-        the model weights and recipe state must be provided.
-        """
-        # final dict passed onto the checkpointer
-        checkpoint_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-
-        utils.log_rank_zero(
-            log,
-            "Saving checkpoint. This may take some time. Retrieving full model state dict...",
-        )
-        start = time.perf_counter()
-
-        # To prevent GPU memory from spiking during checkpoint save,
-        # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.gather_cpu_state_dict(
-            self._model.state_dict(),
-            self._is_rank_zero,
-            device=self._device,
-        )
-
-        utils.log_rank_zero(
-            log,
-            f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
-        )
-
-        if intermediate_checkpoint:
-            start = time.perf_counter()
-            utils.log_rank_zero(log, "Getting optimizer state dict...")
-            if not self._optimizer_in_bwd:
-                opt_state_dict = training.get_full_optimizer_state_dict(
-                    self._optimizer,
-                    self._is_rank_zero,
-                    device=self._device,
-                )
-            else:
-                opt_state_dict = {}
-                for param, opt in self._optim_ckpt_wrapper.optim_map.items():
-                    opt_state_dict[param] = training.get_full_optimizer_state_dict(
-                        opt, self._is_rank_zero, device=self._device
-                    )
-            utils.log_rank_zero(
-                log,
-                f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
-            )
-        else:
-            opt_state_dict = None
-
-        # Now that we have the model and opt state dict, create the actual checkpoint dict
-        # to be sent to the checkpointer and ultimately written to file
-
-        if self._is_rank_zero:
-            start = time.perf_counter()
-            checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})
-
-            # if training is in-progress, checkpoint the optimizer state and recipe state
-            # as well.
-            if intermediate_checkpoint:
-                checkpoint_dict.update(
-                    {
-                        training.OPT_KEY: opt_state_dict,
-                        training.SEED_KEY: self.seed,
-                        training.EPOCHS_KEY: self.epochs_run,
-                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                    }
-                )
-
-            self._checkpointer.save_checkpoint(
-                checkpoint_dict,
-                epoch=epoch,
-                intermediate_checkpoint=intermediate_checkpoint,
-            )
-            log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
-
-        torch.distributed.barrier()
-
     def train(self) -> None:
         """
         The core training loop.
@@ -922,7 +851,21 @@ def train(self) -> None:
                 self._profiler.step()

             self.epochs_run += 1
-            self.save_checkpoint(epoch=curr_epoch)
+            self._checkpoint_client.save_checkpoint(
+                model=self._model,
+                optimizer=(
+                    self._optimizer
+                    if not self._optimizer_in_bwd
+                    else self._optim_ckpt_wrapper
+                ),
+                training_progress=TrainingProgress(
+                    seed=self.seed,
+                    epochs_run=self.epochs_run,
+                    total_epochs=self.total_epochs,
+                    max_steps_per_epoch=self.max_steps_per_epoch,
+                ),
+                epoch=curr_epoch,
+            )

         self._profiler.stop()
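The net effect in this recipe: roughly ninety lines of rank-0 gather-and-save logic become a single CheckpointClient.save_checkpoint call carrying the model, the optimizer (or its in-backward wrapper), and a TrainingProgress record. A rough sketch of the dispatch such a client performs, as a hypothetical simplification for illustration, not the actual _checkpoint_client.py implementation:

    from dataclasses import dataclass

    @dataclass
    class TrainingProgress:
        # Fields mirror the recipe-state keys the old save_checkpoint wrote.
        seed: int
        epochs_run: int
        total_epochs: int
        max_steps_per_epoch: int

    class CheckpointClient:
        # Hypothetical simplification: route intermediate checkpoints through
        # an async distributed save when enabled, otherwise fall back to the
        # synchronous consolidated rank-0 save the recipes used before.
        def __init__(self, cfg: dict):
            self._enable_async = cfg.get("enable_async_checkpointing", False)

        def save_checkpoint(self, model, optimizer, training_progress, epoch):
            is_intermediate = epoch + 1 < training_progress.total_epochs
            if self._enable_async and is_intermediate:
                self._save_async_distributed(model, optimizer, training_progress, epoch)
            else:
                self._save_sync_consolidated(model, optimizer, training_progress, epoch)

        def _save_async_distributed(self, model, optimizer, progress, epoch):
            ...  # e.g. torch.distributed.checkpoint.async_save, as sketched above

        def _save_sync_consolidated(self, model, optimizer, progress, epoch):
            ...  # gather full state dicts on rank 0, then write to disk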
2 changes: 1 addition & 1 deletion recipes/full_finetune_single_device.py
@@ -197,7 +197,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
2 changes: 1 addition & 1 deletion recipes/knowledge_distillation_distributed.py
@@ -149,7 +149,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
2 changes: 1 addition & 1 deletion recipes/knowledge_distillation_single_device.py
@@ -147,7 +147,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
2 changes: 1 addition & 1 deletion recipes/lora_dpo_distributed.py
@@ -188,7 +188,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
2 changes: 1 addition & 1 deletion recipes/lora_dpo_single_device.py
@@ -145,7 +145,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
        )
         checkpoint_dict = self._checkpointer.load_checkpoint()
2 changes: 1 addition & 1 deletion recipes/lora_finetune_distributed.py
@@ -197,7 +197,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
2 changes: 1 addition & 1 deletion recipes/lora_finetune_single_device.py
@@ -188,7 +188,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
8 changes: 4 additions & 4 deletions recipes/ppo_full_finetune_single_device.py
@@ -377,22 +377,22 @@ def _setup_checkpointers(

         policy_checkpointer = config.instantiate(
             policy_cfg,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )

         ref_policy_checkpointer = config.instantiate(
             ref_policy_cfg,
-            resume_from_checkpoint=False,
+            should_load_recipe_state=False,
         )

         value_checkpointer = config.instantiate(
             value_cfg,
-            resume_from_checkpoint=False,
+            should_load_recipe_state=False,
         )

         reward_checkpointer = config.instantiate(
             reward_cfg,
-            resume_from_checkpoint=False,
+            should_load_recipe_state=False,
         )

         return (
2 changes: 1 addition & 1 deletion recipes/qat_distributed.py
@@ -209,7 +209,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
2 changes: 1 addition & 1 deletion recipes/qat_lora_finetune_distributed.py
@@ -213,7 +213,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
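Across the remaining recipes the change is mechanical: the checkpointer keyword resume_from_checkpoint becomes should_load_recipe_state, separating "resume training" (a recipe-level flag) from "also load the saved recipe state" (a checkpointer concern). A minimal sketch of the resulting contract, using a hypothetical simplified checkpointer (only the keyword comes from this diff; everything else is illustrative):

    class Checkpointer:
        # Hypothetical stand-in for torchtune's checkpointer classes.
        def __init__(self, checkpoint_dir: str, should_load_recipe_state: bool = False):
            self._checkpoint_dir = checkpoint_dir
            self._should_load_recipe_state = should_load_recipe_state

        def load_checkpoint(self) -> dict:
            ckpt = {"model": {}}  # model weights are always loaded
            if self._should_load_recipe_state:
                # Recipe state (seed, epochs run, optimizer state, ...) is
                # loaded only when explicitly requested.
                ckpt["recipe_state"] = {}
            return ckpt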