diff --git a/docs/source/recipes/dpo.rst b/docs/source/recipes/dpo.rst index c4854ef81e..efe2a0e126 100644 --- a/docs/source/recipes/dpo.rst +++ b/docs/source/recipes/dpo.rst @@ -13,7 +13,7 @@ To see the best results when using this recipe, it may be helpful to first fine- on-distribution for the domain you're interested in. To do this, check out our other fine-tuning recipes in the :ref:`recipe overview ` which support a variety of SFT paradigms. -After supervised fine-tuning, here is an example of DPO with Llama 3.1 8B: +After supervised fine-tuning, here is an example of using either LoRA-based finetuning, or full-finetuning Llama 3.1 8B with DPO: .. note:: @@ -27,12 +27,15 @@ After supervised fine-tuning, here is an example of DPO with Llama 3.1 8B: --ignore-patterns "original/consolidated.00.pth" --HF_TOKEN - # run on a single device + # run lora dpo on a single device tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device - # run on two gpus + # run lora dpo on two gpus tune run --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo + # run full dpo on four gpus + tune run --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo + It's easy to get started with this recipe with your dataset of choice, including custom local datasets, and datasets from Hugging Face. Check out our primer on :ref:`preference datasets ` to see how to do this. diff --git a/recipes/configs/llama3_1/8B_full_dpo.yaml b/recipes/configs/llama3_1/8B_full_dpo.yaml new file mode 100644 index 0000000000..3dcdfdd46b --- /dev/null +++ b/recipes/configs/llama3_1/8B_full_dpo.yaml @@ -0,0 +1,98 @@ +# Config for multi-device full DPO alignment in full_dpo_distributed.py +# using a Llama3.1 8B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo checkpointer.checkpoint_dir= +# + +output_dir: /tmp/torchtune/llama3_1_8B/full_dpo # /tmp may be deleted by your system. Change it to your preference. + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: 1024 # higher increases memory + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 +resume_from_checkpoint: False + +# The ref_checkpointer should always point to the original weights. 
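+# DPO trains the policy against a frozen reference model: the weights loaded here are
+# put in eval mode with gradients disabled and are never updated during training.
+# If you continue from your own SFT run, point this at that checkpoint instead, e.g.:
+#   tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo ref_checkpointer.checkpoint_dir=<your/sft/checkpoint/dir>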
+ref_checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.stack_exchange_paired_dataset +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.05 + lr: 2e-5 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 20 + +loss: + _component_: torchtune.rlhf.loss.DPOLoss + beta: 0.05 + label_smoothing: 0 + +# Training +epochs: 1 +max_steps_per_epoch: 1000 +gradient_accumulation_steps: 8 # Use to increase effective batch size +compile: False # torch.compile the model + loss, True increases speed + decreases memory + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Environment +device: cuda +dtype: bf16 + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory diff --git a/recipes/full_dpo_distributed.py b/recipes/full_dpo_distributed.py new file mode 100644 index 0000000000..5d462aa473 --- /dev/null +++ b/recipes/full_dpo_distributed.py @@ -0,0 +1,1094 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig +from torch import nn +from torch.distributed import destroy_process_group, init_process_group +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, rlhf, training, utils +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo +from torchtune.datasets import ConcatDataset +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training.lr_schedulers import get_lr +from torchtune.utils import get_world_size_and_rank +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class FullDPORecipeDistributed(FTRecipeInterface): + """ + Full DPO finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports + distributed training and can be run on a single node (1 to 8 GPUs). + + Features: + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. + + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` + flag. 
Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + The following losses are supported in this recipe: + - :class:`~torchtune.modules.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO). 
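+      The loss compares the log-probabilities of the chosen and rejected responses under the
+      policy against the same log-probabilities under the frozen reference model.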
+ - :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO). + + + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + if self._log_peak_memory_stats and self._device.type != "cuda": + log.info( + "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + ) + self._log_peak_memory_stats = False + + _, rank = get_world_size_and_rank() + self._is_rank_zero = rank == 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + # Optimizer in backward is not compatible with gradient accumulation or gradient clipping + if self._optimizer_in_bwd: + if self._clip_grad_norm is not None: + raise RuntimeError( + "Gradient clipping is not supported with optimizer in bwd." + "Please set clip_grad_norm=None, or optimizer_in_bwd=False." + ) + if self._gradient_accumulation_steps > 1: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further.", + ) + + # These attributes constitute the recipe state and are updated by ``load_checkpoint`` + # when ``resume_from_checkpoint`` is ``True`` + self.seed = training.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + def _load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. 
If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + should_load_recipe_state=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _load_ref_checkpoint(self, cfg_ref_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the reference model checkpoint state from file. + """ + _ref_checkpointer = config.instantiate( + cfg_ref_checkpointer, should_load_recipe_state=False + ) + checkpoint_dict = _ref_checkpointer.load_checkpoint() + return checkpoint_dict[training.MODEL_KEY] + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. 
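+        Note that this recipe materializes two full copies of the model weights: the trainable
+        policy and the frozen reference model used by DPO, both sharded with FSDP. Expect
+        roughly twice the model memory of a standard full finetune, with optimizer state kept
+        for the policy only.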
+ """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + # Load the base model + checkpoint_dict = self._load_checkpoint(cfg.checkpointer) + ref_checkpoint_dict = self._load_ref_checkpoint(cfg.ref_checkpointer) + + self._compile = cfg.get("compile", False) + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + model_state_dict=checkpoint_dict[training.MODEL_KEY], + ) + + # TODO (@SalmanMohammadi) investigate TP for ref model + self._ref_model = self._setup_reference_model( + cfg_model=cfg.model, + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + model_state_dict=ref_checkpoint_dict, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=self._optimizer_in_bwd, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + + if self._is_rank_zero: + log.info("Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after all of these are initialized + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. This value is used for logging and tracking + # training state. 
The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. 
To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + """ + + utils.log_rank_zero( + log, + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...", + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, + model_state_dict, + self._device, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + utils.log_rank_zero( + log, + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs", + ) + + if self._is_rank_zero: + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs + # between ref policy and current policy + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + warn( + f"Dropout found in {module}. This is likely to cause issues during training. Disabling." + ) + module.p = 0 + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_reference_model( + self, + cfg_model: DictConfig, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, + ) -> nn.Module: + """ + Similar to `self._setup_model`: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + + Additionally, since the reference model is inference-only, we omit some training-specific + optimizations. + """ + + utils.log_rank_zero( + log, + "FSDP is enabled. 
Instantiating reference model and loading checkpoint on Rank 0 ...", + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, + model_state_dict, + self._device, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + utils.log_rank_zero( + log, + f"Instantiating reference model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs", + ) + + if self._is_rank_zero: + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs + # between ref policy and current policy + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + warn( + f"Dropout found in {module}. This is likely to cause issues during training. Disabling." + ) + module.p = 0 + + for p in model.parameters(): + p.requires_grad = False + + model.eval() + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optional[Optimizer]: + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + param: config.instantiate(cfg_optimizer, [param]) + for param in self._model.parameters() + } + + # Register optimizer step hooks on the model to run optimizer in backward. + training.register_optim_in_bwd_hooks( + model=self._model, optim_dict=optim_dict + ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + model=self._model, optim_dict=optim_dict + ) + # Load optimizer states for each param. If optimizer states are being restored in an optimizer in + # backward run, these need to have been saved with the same setting. Cannot restore from runs that + # did not use optimizer in backward. + if opt_state_dict is not None: + for param in opt_state_dict.keys(): + try: + training.load_from_full_optimizer_state_dict( + self._model, + self._optim_ckpt_wrapper.state_dict()[param], + opt_state_dict[param], + self._device, + ) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." 
+ ) from e + utils.log_rank_zero(log, "In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + self._model, + optimizer, + opt_state_dict, + self._device, + ) + + utils.log_rank_zero(log, "Optimizer and loss are initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + if self._is_rank_zero: + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ) + + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + # dropping last avoids shape issues with compile + flex attention + drop_last=True, + collate_fn=partial( + padded_collate_dpo, + padding_idx=self._tokenizer.pad_id, + ignore_idx=CROSS_ENTROPY_IGNORE_IDX, + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Model weights with key training.MODEL_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the model weights and recipe state in + different checkpoint files. To correctly resume training from an intermediate checkpoint, + the model weights and recipe state must be provided. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + + if self._is_rank_zero: + log.info( + "Saving checkpoint. This may take some time. Retrieving full model state dict..." 
+ ) + start = time.perf_counter() + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + cpu_state_dict = training.gather_cpu_state_dict( + self._model, + self._is_rank_zero, + device=self._device, + ) + + if self._is_rank_zero: + log.info( + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" + ) + + if intermediate_checkpoint: + start = time.perf_counter() + utils.log_rank_zero(log, "Getting optimizer state dict...") + if not self._optimizer_in_bwd: + opt_state_dict = training.get_full_optimizer_state_dict( + self._model, + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + else: + opt_state_dict = {} + for param, opt in self._optim_ckpt_wrapper.optim_map.items(): + opt_state_dict[param] = training.get_full_optimizer_state_dict( + self._model, opt, self._is_rank_zero, device=self._device + ) + utils.log_rank_zero( + log, + f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs", + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + + if self._is_rank_zero: + start = time.perf_counter() + checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + ) + log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") + + torch.distributed.barrier() + + def concatenated_forward( + self, + model: nn.Module, + batch: Tuple[torch.Tensor, torch.Tensor], + activations_handling: Optional[bool] = True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Run forward pass of the model with chosen and rejected samples concatenated. + + Args: + model (nn.Module): The model to be used for the forward pass. + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels. + + Returns: + Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits. + """ + concatenated_input_ids, concatenated_labels = batch + concatenated_input_ids = concatenated_input_ids.to(self._device) + concatenated_labels = concatenated_labels.to(self._device) + + # formed by concatenating an equal number of "chosen" and "rejected". + len_chosen = concatenated_input_ids.shape[0] // 2 + + if activations_handling: + with self.activations_handling_ctx: + all_logits = model(concatenated_input_ids) + else: + all_logits = model(concatenated_input_ids) + + chosen_log_probs = rlhf.get_batch_log_probs( + all_logits[:len_chosen], + concatenated_labels[:len_chosen], + return_average_logprobs=False, + ) + + rejected_log_probs = rlhf.get_batch_log_probs( + all_logits[len_chosen:], + concatenated_labels[len_chosen:], + return_average_logprobs=False, + ) + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits) + + def train(self) -> None: + """ + The core training loop. 
Supports training on subsets of the dataset using the + ``max_steps_per_epoch``. + """ + # clean up before training begins + training.cleanup_before_training() + + world_size, rank = get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + + # Running metrics + running_loss = 0 + running_metrics = { + "rewards/chosen": 0, + "rewards/rejected": 0, + "rewards/accuracies": 0, + "log_probs/chosen": 0, + "log_probs/rejected": 0, + "logits/chosen": 0, + "logits/rejected": 0, + } + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # batch is input_ids, labels + num_tokens += torch.tensor(batch[0].numel()) + ( + policy_chosen_log_probs, + policy_rejected_log_probs, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(self._model, batch) + + policy_chosen_logits_mean = policy_chosen_logits.detach().mean() + policy_rejected_logits_mean = policy_rejected_logits.detach().mean() + + # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging + del policy_chosen_logits, policy_rejected_logits + + with torch.no_grad(): + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + reference_chosen_logits, + reference_rejected_logits, + ) = self.concatenated_forward( + self._ref_model, batch, activations_handling=False + ) + + del reference_chosen_logits, reference_rejected_logits + + loss, chosen_rewards, rejected_rewards = self._loss_fn( + policy_chosen_log_probs, + policy_rejected_log_probs, + reference_chosen_log_probs, + reference_rejected_log_probs, + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + loss = loss.mean() + + loss = loss / self._gradient_accumulation_steps + + # Update running metrics + running_loss += loss + scaling_factor = ( + 1 / self._gradient_accumulation_steps + ) # to average out between grad_acc steps + running_metrics["rewards/chosen"] += ( + scaling_factor * chosen_rewards.mean() + ) + running_metrics["rewards/rejected"] += ( + scaling_factor * rejected_rewards.mean() + ) + running_metrics["rewards/accuracies"] += ( + scaling_factor * reward_accuracies.mean() + ) + running_metrics["log_probs/chosen"] += ( + scaling_factor * policy_chosen_log_probs.detach().mean() + ) + running_metrics["log_probs/rejected"] += ( + scaling_factor * policy_rejected_log_probs.detach().mean() + ) + running_metrics["logits/chosen"] += ( + scaling_factor * policy_chosen_logits_mean + ) + running_metrics["logits/rejected"] += ( + scaling_factor * policy_rejected_logits_mean + ) + + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + # Accumulate running metrics across all devices + torch.distributed.all_reduce( + running_loss, op=torch.distributed.ReduceOp.AVG + ) + torch.distributed.all_reduce(num_tokens) + + for key in running_metrics: + torch.distributed.all_reduce( + 
running_metrics[key], op=torch.distributed.ReduceOp.AVG + ) + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ).full_tensor() + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Update the number of steps when the weights are updated + self.global_step += 1 + # Step the learning rate scheduler + if self._lr_scheduler is not None: + self._lr_scheduler.step() + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": get_lr( + ( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + ), + "tokens_per_second_per_gpu": num_tokens + / (time_per_step * world_size), + "rewards/chosen": running_metrics["rewards/chosen"].cpu(), + "rewards/rejected": running_metrics[ + "rewards/rejected" + ].cpu(), + "rewards/accuracies": running_metrics[ + "rewards/accuracies" + ].cpu(), + "rewards/margins": ( + running_metrics["rewards/chosen"] + - running_metrics["rewards/rejected"] + ).cpu(), + "log_probs/chosen": running_metrics[ + "log_probs/chosen" + ].cpu(), + "log_probs/rejected": running_metrics[ + "log_probs/rejected" + ].cpu(), + "logits/chosen": running_metrics["logits/chosen"].cpu(), + "logits/rejected": running_metrics["logits/rejected"].cpu(), + } + if self._log_peak_memory_stats: + log_dict.update( + training.get_memory_stats(device=self._device) + ) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + running_metrics = {key: 0 for key in running_metrics} + num_tokens = 0 + + t0 = time.perf_counter() + + # Step profiler + # Note that this is called within gradient accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not training.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + + init_process_group("cuda:nccl,cpu:gloo") + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. 
This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + + config.log_config(recipe_name="FullDPORecipeDistributed", cfg=cfg) + + recipe = FullDPORecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index d54adc2cf4..396b599067 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -651,14 +651,25 @@ def train(self) -> None: # clean up before training begins training.cleanup_before_training() - _, rank = utils.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() # Initialize tokens count and running loss (for grad accumulation) t0 = time.perf_counter() + + # Running metrics running_loss = 0 + running_metrics = { + "rewards/chosen": 0, + "rewards/rejected": 0, + "rewards/accuracies": 0, + "log_probs/chosen": 0, + "log_probs/rejected": 0, + "logits/chosen": 0, + "logits/rejected": 0, + } num_tokens = 0 # self.epochs_run should be non-zero when we're resuming from a checkpoint @@ -678,7 +689,7 @@ def train(self) -> None: break # batch is input_ids, labels - num_tokens += batch[0].numel() + num_tokens += torch.tensor(batch[0].numel()) ( policy_chosen_log_probs, @@ -706,16 +717,52 @@ def train(self) -> None: reference_chosen_log_probs, reference_rejected_log_probs, ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() loss = loss.mean() - reward_accuracies = (chosen_rewards > rejected_rewards).float() loss = loss / self._gradient_accumulation_steps + + # Update running metrics running_loss += loss + scaling_factor = ( + 1 / self._gradient_accumulation_steps + ) # to average out between grad_acc steps + running_metrics["rewards/chosen"] += ( + scaling_factor * chosen_rewards.mean() + ) + running_metrics["rewards/rejected"] += ( + scaling_factor * rejected_rewards.mean() + ) + running_metrics["rewards/accuracies"] += ( + scaling_factor * reward_accuracies.mean() + ) + running_metrics["log_probs/chosen"] += ( + scaling_factor * policy_chosen_log_probs.detach().mean() + ) + running_metrics["log_probs/rejected"] += ( + scaling_factor * policy_rejected_log_probs.detach().mean() + ) + running_metrics["logits/chosen"] += ( + scaling_factor * policy_chosen_logits_mean + ) + running_metrics["logits/rejected"] += ( + scaling_factor * policy_rejected_logits_mean + ) + loss.backward() # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + # Accumulate running metrics across all devices + torch.distributed.all_reduce(running_loss) + torch.distributed.all_reduce(num_tokens) + + for key in running_metrics: + torch.distributed.all_reduce( + running_metrics[key], op=torch.distributed.ReduceOp.AVG + ) + self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self._lr_scheduler.step() @@ -738,21 +785,27 @@ def train(self) -> None: log_dict = { "loss": loss_to_log, "lr": self._optimizer.param_groups[0]["lr"], - "tokens_per_second_per_gpu": num_tokens / time_per_step, - "rewards/chosen": chosen_rewards.mean().cpu(), - "rewards/rejected": rejected_rewards.mean().cpu(), - "rewards/accuracies": reward_accuracies.mean().cpu(), - "rewards/margins": (chosen_rewards - rejected_rewards) - .mean() - .cpu(), - "log_probs/rejected": policy_rejected_log_probs.detach() - .mean() - .cpu(), - "log_probs/chosen": 
policy_chosen_log_probs.detach() - .mean() - .cpu(), - "logits/rejected": policy_rejected_logits_mean.cpu(), - "logits/chosen": policy_chosen_logits_mean.cpu(), + "tokens_per_second_per_gpu": num_tokens + / (time_per_step * world_size), + "rewards/chosen": running_metrics["rewards/chosen"].cpu(), + "rewards/rejected": running_metrics[ + "rewards/rejected" + ].cpu(), + "rewards/accuracies": running_metrics[ + "rewards/accuracies" + ].cpu(), + "rewards/margins": ( + running_metrics["rewards/chosen"] + - running_metrics["rewards/rejected"] + ).cpu(), + "log_probs/chosen": running_metrics[ + "log_probs/chosen" + ].cpu(), + "log_probs/rejected": running_metrics[ + "log_probs/rejected" + ].cpu(), + "logits/chosen": running_metrics["logits/chosen"].cpu(), + "logits/rejected": running_metrics["logits/rejected"].cpu(), } if self._log_peak_memory_stats: log_dict.update( @@ -765,7 +818,9 @@ def train(self) -> None: # Reset running stats for the next step running_loss = 0 + running_metrics = {key: 0 for key in running_metrics} num_tokens = 0 + t0 = time.perf_counter() self.epochs_run += 1 diff --git a/tests/recipes/test_full_dpo_distributed.py b/tests/recipes/test_full_dpo_distributed.py new file mode 100644 index 0000000000..c3efc5e3fc --- /dev/null +++ b/tests/recipes/test_full_dpo_distributed.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from tests.common import TUNE_PATH +from tests.recipes.utils import ( + dummy_stack_exchange_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + gpu_test, + TOKENIZER_PATHS, +) + + +class TestFullDPODistributedRecipe: + def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): + return [ + "batch_size=1", + "device=cuda", + "enable_activation_checkpointing=True", + "enable_activation_offloading=True", + f"dtype={dtype_str}", + "dataset.train_on_input=False", + "seed=9", + f"epochs={epochs}", + "max_steps_per_epoch=2", + "optimizer=torch.optim.AdamW", + "optimizer.lr=2e-6", + "log_every_n_steps=1", + "gradient_accumulation_steps=4", + "clip_grad_norm=100", + "tokenizer.max_seq_len=256", + ] + dummy_stack_exchange_dataset_config() + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + def test_training_state_on_resume(self, tmpdir, monkeypatch): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + Unlike `tests.recipes.test_lora_finetune_single_device`, this test does not use pre-computed loss + values to benchmark against. This test just ensures the loss values are identical when resuming. + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) + + # Config file needed for model conversion. 
+ # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed \ + --config llama3_1/8B_full_dpo \ + output_dir={tmpdir} \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + ref_checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + ref_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_checkpointer.checkpoint_files=[{ckpt_path}]\ + ref_checkpointer.output_dir={tmpdir} \ + ref_checkpointer.model_type=LLAMA3 \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + tokenizer.max_seq_len=256 \ + metric_logger.filename={log_file} \ + enable_activation_checkpointing=True \ + enable_activation_offloading=True \ + batch_size=1 \ + gradient_accumulation_steps=4 + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3"] + + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd_1) + # with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + expected_loss_values = get_loss_values_from_metric_logger(log_file) + + resumed_log_dir = (tmpdir / "resumed/").mkdir() + resumed_log_file = gen_log_file_name(resumed_log_dir) + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed \ + --config llama3_1/8B_full_dpo \ + output_dir={tmpdir} \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + ref_checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + ref_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_checkpointer.checkpoint_files=[{ckpt_path}]\ + ref_checkpointer.output_dir={tmpdir} \ + ref_checkpointer.model_type=LLAMA3 \ + resume_from_checkpoint=True \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + tokenizer.max_seq_len=256 \ + metric_logger.filename={resumed_log_file} \ + enable_activation_checkpointing=True \ + enable_activation_offloading=True \ + batch_size=1 \ + gradient_accumulation_steps=4 + """.split() + cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Second epoch only + resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file) + + torch.testing.assert_close( + resumed_loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 1c41519712..4536b7c33c 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -325,6 +325,17 @@ class Recipe: ], supports_distributed=True, ), + Recipe( + name="full_dpo_distributed", + file_path="full_dpo_distributed.py", + configs=[ + Config( + name="llama3_1/8B_full_dpo", + file_path="llama3_1/8B_full_dpo.yaml", + ), + ], + supports_distributed=True, + ), Recipe( name="ppo_full_finetune_single_device", file_path="ppo_full_finetune_single_device.py",