diff --git a/src/olmo_core/train/train_module/__init__.py b/src/olmo_core/train/train_module/__init__.py index 20ad8b13..f45fa6ec 100644 --- a/src/olmo_core/train/train_module/__init__.py +++ b/src/olmo_core/train/train_module/__init__.py @@ -9,11 +9,15 @@ TransformerActivationCheckpointingMode, TransformerDataParallelConfig, TransformerDataParallelWrappingStrategy, - TransformerPipelineParallelConfig, TransformerTensorParallelConfig, TransformerTrainModule, TransformerTrainModuleConfig, ) +from .transformer_pipeline import ( + TransformerPipelineParallelConfig, + TransformerPipelineTrainModule, + TransformerPipelineTrainModuleConfig, +) __all__ = [ "TrainModule", @@ -22,6 +26,8 @@ "BasicTrainModule", "TransformerTrainModule", "TransformerTrainModuleConfig", + "TransformerPipelineTrainModule", + "TransformerPipelineTrainModuleConfig", "TransformerActivationCheckpointingConfig", "TransformerActivationCheckpointingMode", "TransformerDataParallelConfig", diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 9b8d7f7b..1d0e6b66 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -1,19 +1,14 @@ import contextlib -import copy import logging -import math from dataclasses import dataclass, replace -from functools import partial from typing import Any, Dict, Generator, List, Optional, Tuple, cast import torch import torch.distributed as dist import torch.distributed.checkpoint.state_dict as dist_cp_sd import torch.nn as nn -from torch.distributed import DeviceMesh from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.pipelining import PipelineStage from torch.distributed.tensor import DTensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -24,13 +19,10 @@ from olmo_core.distributed.parallel import ( DataParallelConfig, DataParallelType, - PipelineParallelConfig, - PipelineSchedule, TensorParallelConfig, build_device_mesh, get_dp_mesh, get_dp_process_group, - get_pp_mesh, get_tp_mesh, ) from olmo_core.distributed.utils import get_local_tensor, get_world_size @@ -53,7 +45,7 @@ from olmo_core.utils import gc_cuda, get_default_device, mark_dynamic, move_to_device from ..common import ReduceType, reshape_inputs_for_loss -from .train_module import EvalBatchSizeUnit, EvalBatchSpec, TrainModule +from .train_module import EvalBatchSpec, TrainModule log = logging.getLogger(__name__) @@ -79,118 +71,6 @@ class TransformerTensorParallelConfig(TensorParallelConfig): """ -@dataclass -class TransformerPipelineParallelConfig(PipelineParallelConfig): - """ - Transformer-specific pipeline parallel config. - """ - - split_points: Optional[List[int]] = None - """ - A list of unique, increasing block indices that define how to split the model into stages. - - For example, ``split_points = [0, 2]`` with a 4-layer model means the model will be split into - 3 stages, with the first containing just the embedding, the second containing blocks 0 and 1, - and the third containing blocks 2 and 3 and the language modeling head. - - If not specified the split points are determined automatically based on the schedule type. 
- """ - - def get_split_points(self, n_layers: int) -> List[int]: - if self.split_points is not None: - return self.split_points - - # Multi-stage schedules support more than 2 stages per rank, but this is the default if - # no pipeline split is specified. - num_stages_per_rank = 1 if self.schedule.is_single_stage else 2 - total_stages = self.degree * num_stages_per_rank - num_layers = n_layers - if total_stages > num_layers: - raise OLMoConfigurationError("Total stages cannot be greater than the number of layers") - - base_interval = num_layers // total_stages - extra_layers = num_layers % total_stages - - splits: List[int] = [] - current_layer = 0 - for i in range(total_stages - 1): - if i == 0: - current_layer += base_interval - else: - # Middle stages get an extra layer if there are any remaining - if extra_layers > 0: - current_layer += base_interval + 1 - extra_layers -= 1 - else: - current_layer += base_interval - splits.append(current_layer) - log.info(f"Auto generated pipeline split points will be {splits}") - return splits - - def split_model( - self, model: Transformer, *, pp_mesh: DeviceMesh, device: torch.device - ) -> Tuple[List[PipelineStage], List[Transformer]]: - split_points = self.get_split_points(model.n_layers) - pp_rank = pp_mesh.get_local_rank() - - def build_stage( - stage_idx: int, - start_layer: Optional[int], - stop_layer: Optional[int], - is_first: bool = False, - is_last: bool = False, - ) -> Tuple[PipelineStage, Transformer]: - model_chunk = copy.deepcopy(model) - if not is_first: - model_chunk.embeddings = None # type: ignore - - drop_layers = start_layer is not None - for block_idx in range(model.n_layers): - # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) - if block_idx == start_layer: - drop_layers = False - if block_idx == stop_layer: - drop_layers = True - if drop_layers: - del model_chunk.blocks[str(block_idx)] - - if not is_last: - model_chunk.lm_head = None # type: ignore - - stage = PipelineStage( - model_chunk, - stage_idx, - num_stages, - device, - group=pp_mesh.get_group("pp"), - ) - return stage, model_chunk - - num_stages = len(split_points) + 1 - stage_idx = pp_rank - - stages = [] - models = [] - for stage_idx in self.stage_ids_this_rank(pp_rank, num_stages, style="loop"): - start_layer = split_points[stage_idx - 1] if stage_idx > 0 else None - stop_layer = split_points[stage_idx] if stage_idx < num_stages - 1 else None - stage, model_chunk = build_stage( - stage_idx, - start_layer, - stop_layer, - is_first=stage_idx == 0, - is_last=stage_idx == num_stages - 1, - ) - log.info( - f"PP rank {pp_rank} is building stage {stage_idx} with start layer " - f"{start_layer}, stop layer {stop_layer}: {model_chunk}" - ) - stages.append(stage) - models.append(model_chunk) - - return stages, models - - @beta_feature @dataclass class TransformerActivationCheckpointingConfig(Config): @@ -255,7 +135,6 @@ class TransformerTrainModuleConfig(Config): float8_config: Optional[Float8Config] = None dp_config: Optional[TransformerDataParallelConfig] = None tp_config: Optional[TransformerTensorParallelConfig] = None - pp_config: Optional[TransformerPipelineParallelConfig] = None ac_config: Optional[TransformerActivationCheckpointingConfig] = None # Loss function settings. @@ -322,7 +201,6 @@ class TransformerTrainModule(TrainModule): :param float8_config: Float8 configuration for the model. :param dp_config: Data parallel configuration for the model. :param tp_config: Tensor parallel configuration for the model. 
- :param pp_config: Pipeline parallel configuration for the model. :param ac_config: Activation checkpointing configuration for the model. :param compile_loss: Compile the loss function. This can provide a small speedup while also reducing GPU memory usage, especially when using Z-loss. @@ -358,7 +236,6 @@ def __init__( float8_config: Optional[Float8Config] = None, dp_config: Optional[TransformerDataParallelConfig] = None, tp_config: Optional[TransformerTensorParallelConfig] = None, - pp_config: Optional[TransformerPipelineParallelConfig] = None, ac_config: Optional[TransformerActivationCheckpointingConfig] = None, compile_loss: bool = False, fused_loss: bool = False, @@ -385,7 +262,7 @@ def __init__( self.device = device or get_default_device() self.world_mesh = build_device_mesh( - dp=dp_config, tp=tp_config, pp=pp_config, device_type=self.device.type + dp=dp_config, tp=tp_config, device_type=self.device.type ) log.info(f"Data parallel world size = {get_world_size(self.dp_process_group):,d}") @@ -403,41 +280,24 @@ def __init__( float8_config.compile = compile_model self.float8_handler = float8_config.build() - self._pp_config = pp_config - # We'll initialize this lazily when the trainer is attached, since we need to know - # the global batch size in order to determine the number of pipeline micro-batches. - self._train_pp_schedule: Optional[PipelineSchedule] = None - self._eval_pp_schedule: Optional[PipelineSchedule] = None - self._pp_stages: Optional[List[PipelineStage]] = None - - self.model_parts: List[Transformer] = [] - if pp_config is not None: - pp_mesh = get_pp_mesh(self.world_mesh) - assert pp_mesh is not None - stages, model_parts = pp_config.split_model(model, pp_mesh=pp_mesh, device=self.device) - self._pp_stages = stages - self.model_parts = model_parts - else: - self.model_parts = [model] + self.model = model # Maybe convert linear layers to FP8 linear. if self.float8_handler is not None and self.float8_handler.enabled: - for model in self.model_parts: - self.float8_handler.convert_to_float8_training( - model, modules_to_ignore={"lm_head.w_out"} - ) + self.float8_handler.convert_to_float8_training( + self.model, modules_to_ignore={"lm_head.w_out"} + ) log.info("Swapped linear layers to Float8 linear layers") # Maybe apply tensor parallelism. if tp_config is not None: tp_mesh = get_tp_mesh(self.world_mesh) assert tp_mesh is not None - for model in self.model_parts: - model.apply_tp( - tp_mesh, - float8_enabled=float8_enabled, - loss_parallel=False, # TODO (epwalsh): figure out if this will work w/ z-loss - ) + self.model.apply_tp( + tp_mesh, + float8_enabled=float8_enabled, + loss_parallel=False, # TODO (epwalsh): figure out if this will work w/ z-loss + ) tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" @@ -445,19 +305,17 @@ def __init__( # Maybe apply activation checkpointing. if ac_config is not None: - for model in self.model_parts: - model.apply_activation_checkpointing( - ac_config.mode, - block_interval=ac_config.block_interval, - modules=ac_config.modules, - ) + self.model.apply_activation_checkpointing( + ac_config.mode, + block_interval=ac_config.block_interval, + modules=ac_config.modules, + ) log.info(f"Applied '{ac_config.mode}' activation checkpointing to the model") # Maybe compile. 
if compile_model: if torch.cuda.is_available(): - for model in self.model_parts: - model.apply_compile() + self.model.apply_compile() log.info("Applied torch.compile() to the model") else: log.warning("Skipping model compilation since CUDA is not available") @@ -467,38 +325,29 @@ def __init__( if dp_config is not None: dp_mesh = get_dp_mesh(self.world_mesh) if dp_config.name in (DataParallelType.fsdp, DataParallelType.hsdp): - for model in self.model_parts: - model.apply_fsdp( - dp_mesh=dp_mesh, - param_dtype=dp_config.param_dtype.as_pt() - if dp_config.param_dtype is not None - else None, - reduce_dtype=dp_config.reduce_dtype.as_pt(), - wrapping_strategy=dp_config.wrapping_strategy, - pp_enabled=self.pp_enabled, - ) + self.model.apply_fsdp( + dp_mesh=dp_mesh, + param_dtype=dp_config.param_dtype.as_pt() + if dp_config.param_dtype is not None + else None, + reduce_dtype=dp_config.reduce_dtype.as_pt(), + wrapping_strategy=dp_config.wrapping_strategy, + pp_enabled=False, + ) log.info("Applied FSDP to the model") elif dp_config.name == DataParallelType.ddp: - for model in self.model_parts: - model.apply_ddp(dp_mesh=dp_mesh, compile_enabled=compile_model) + self.model.apply_ddp(dp_mesh=dp_mesh, compile_enabled=compile_model) log.info("Applied DDP to the model") else: raise NotImplementedError(dp_config.name) # Materialize and init parameters. - for model in self.model_parts: - log.info("Initializing model weights...") - model.init_weights(max_seq_len=max_sequence_length, device=self.device) + log.info("Initializing model weights...") + self.model.init_weights(max_seq_len=max_sequence_length, device=self.device) # Build optimizer(s). - log.info("Building optimizer(s)...") - self.optimizers: List[Optimizer] = [ - optim.build(model, strict=not self.pp_enabled) for model in self.model_parts - ] - if self.pp_enabled and isinstance(self.optimizers[0], SkipStepOptimizer): - raise NotImplementedError( - "Pipeline parallelism with a SkipStepOptimizer is currently not supported" - ) + log.info("Building optimizer...") + self.optim: Optimizer = optim.build(self.model, strict=True) self.rank_microbatch_size = rank_microbatch_size self.max_sequence_length = max_sequence_length @@ -510,25 +359,14 @@ def __init__( flatten_optimizer_state_dict=True, cpu_offload=True ) self.state_dict_load_opts = state_dict_load_opts or dist_cp_sd.StateDictOptions( - flatten_optimizer_state_dict=True, strict=not self.pp_enabled + flatten_optimizer_state_dict=True, strict=True ) self.load_key_mapping = load_key_mapping self.label_ignore_index = label_ignore_index self.moe_handler: Optional[MoEHandler] = None - for model in self.model_parts: - if MoEHandler.has_moe(model): - self.moe_handler = MoEHandler(model=model) - if self.pp_enabled: - # TODO (epwalsh): need to figure out how to handle the internal MoE losses correctly. 
- raise NotImplementedError( - "Pipeline parallelism with MoE's is currently not supported" - ) - break - - self._batch_num_tokens_for_loss: Optional[torch.Tensor] = None - self._ce_batch_loss: Optional[torch.Tensor] = None - self._z_batch_loss: Optional[torch.Tensor] = None + if MoEHandler.has_moe(self.model): + self.moe_handler = MoEHandler(model=self.model) @property def dp_process_group(self) -> Optional[dist.ProcessGroup]: @@ -536,51 +374,13 @@ def dp_process_group(self) -> Optional[dist.ProcessGroup]: @property def eval_batch_spec(self) -> EvalBatchSpec: - if not self.pp_enabled: - return EvalBatchSpec( - self.rank_microbatch_size, max_sequence_length=self.max_sequence_length - ) - else: - # Determine the number of micro-batches. - rank_batch_size = self.trainer.global_batch_size // get_world_size( - self.trainer.dp_process_group - ) - rank_batch_size_instances = rank_batch_size // self.max_sequence_length - return EvalBatchSpec( - rank_batch_size=rank_batch_size_instances, - batch_size_unit=EvalBatchSizeUnit.instances, - max_sequence_length=self.max_sequence_length, - fixed_sequence_length=True, - ) - - @property - def logits_dtype(self) -> torch.dtype: - if self.autocast_precision is not None: - return self.autocast_precision - elif self._dp_config is not None and self._dp_config.param_dtype is not None: - return self._dp_config.param_dtype.as_pt() - else: - for param in self.model_parts[0].parameters(): - return param.dtype - raise RuntimeError("Should not get here") - - @property - def pp_enabled(self) -> bool: - return self._pp_config is not None - - @property - def train_pp_schedule(self) -> Optional[PipelineSchedule]: - self.trainer # make sure trainer has been attached before trying to access this - return self._train_pp_schedule - - @property - def eval_pp_schedule(self) -> Optional[PipelineSchedule]: - self.trainer # make sure trainer has been attached before trying to access this - return self._eval_pp_schedule - - def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - assert self._batch_num_tokens_for_loss is not None + return EvalBatchSpec( + self.rank_microbatch_size, max_sequence_length=self.max_sequence_length + ) + def loss_fn( + self, logits: torch.Tensor, labels: torch.Tensor, batch_num_tokens_for_loss: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: logits_for_loss, labels_for_loss = reshape_inputs_for_loss(logits, labels) # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' @@ -595,27 +395,20 @@ def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: z_loss_multiplier=self.z_loss_multiplier or 1e-4, ) - ce_loss.div_(self._batch_num_tokens_for_loss) + ce_loss.div_(batch_num_tokens_for_loss) if z_loss is not None: - z_loss.div_(self._batch_num_tokens_for_loss) + z_loss.div_(batch_num_tokens_for_loss) # Get loss to optimize for. loss = ce_loss if z_loss is not None: loss += z_loss - # Update overall CE batch loss. - if self._ce_batch_loss is None: - self._ce_batch_loss = move_to_device(torch.tensor(0.0), self.device) - self._ce_batch_loss += get_local_tensor(ce_loss.detach()) - - # Update overall Z batch loss. 
- if z_loss is not None: - if self._z_batch_loss is None: - self._z_batch_loss = move_to_device(torch.tensor(0.0), self.device) - self._z_batch_loss += get_local_tensor(z_loss.detach()) - - return loss + return ( + loss, + get_local_tensor(ce_loss.detach()), + None if z_loss is None else get_local_tensor(z_loss.detach()), + ) def eval_loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: logits_for_loss, labels_for_loss = reshape_inputs_for_loss(logits, labels) @@ -636,35 +429,6 @@ def on_attach(self): f"micro-batch size ({self.rank_microbatch_size:,d}) x DP world size ({dp_ws})" ) - # Maybe initialize pipeline schedule. - if self._pp_config is not None: - assert self._train_pp_schedule is None # make sure we don't initialize this twice - assert self._pp_stages is not None - pp_mesh = get_pp_mesh(self.world_mesh) - assert pp_mesh is not None - - # Determine the number of micro-batches. - rank_batch_size = self.trainer.global_batch_size // dp_ws - num_micro_batches = rank_batch_size // self.rank_microbatch_size - - self._train_pp_schedule = PipelineSchedule( - model_parts=self.model_parts, # type: ignore[arg-type] - stages=self._pp_stages, - pp_mesh=pp_mesh, - schedule_name=self._pp_config.schedule, - loss_fn=self.loss_fn, - n_microbatches=num_micro_batches, - ) - self._eval_pp_schedule = PipelineSchedule( - model_parts=self.model_parts, # type: ignore[arg-type] - stages=self._pp_stages, - pp_mesh=pp_mesh, - schedule_name=self._pp_config.schedule, - n_microbatches=num_micro_batches, - # NOTE: can't pass this here or the schedule will attempt backward pass - # loss_fn=self.eval_loss_fn, - ) - def state_dict(self) -> Dict[str, Any]: return self._get_state_dict(self.state_dict_save_opts) @@ -712,26 +476,24 @@ def state_dict_to_save(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.load_key_mapping is not None: _swap_param_keys(state_dict, self.load_key_mapping, reverse=True, quiet=True) - for model, optim in zip(self.model_parts, self.optimizers): - dist_cp_sd.set_model_state_dict( - model, - state_dict["model"], + dist_cp_sd.set_model_state_dict( + self.model, + state_dict["model"], + options=self.state_dict_load_opts, + ) + gc_cuda() + if "optim" in state_dict: + dist_cp_sd.set_optimizer_state_dict( + self.model, + self.optim, + state_dict["optim"], options=self.state_dict_load_opts, ) gc_cuda() - if "optim" in state_dict: - dist_cp_sd.set_optimizer_state_dict( - model, - optim, - state_dict["optim"], - options=self.state_dict_load_opts, - ) - gc_cuda() def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): # Set model to train mode if it isn't already. - for model in self.model_parts: - model.train() + self.model.train() # Move tensors to the right device. batch = move_to_device(batch, self.device) @@ -741,39 +503,49 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): batch["labels"] = get_labels(batch, label_ignore_index=self.label_ignore_index) # Calculate how many tokens are going to be used in the loss. - self._batch_num_tokens_for_loss = (batch["labels"] != self.label_ignore_index).sum() + batch_num_tokens_for_loss = (batch["labels"] != self.label_ignore_index).sum() - if not self.pp_enabled: - # Split into micro-batches. - if self.rank_microbatch_size < (seq_len := batch["input_ids"].shape[1]): - raise RuntimeError( - f"Microbatch size ({self.rank_microbatch_size}) is too small relative to sequence length ({seq_len})" + # Update overall CE batch loss. 
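+        # (Loss totals are now accumulated in local tensors across micro-batches instead of
+        # being stored on `self`, and are normalized by the DP world size when recorded below.)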
+ ce_batch_loss = move_to_device(torch.tensor(0.0), self.device) + + # Update overall Z batch loss. + z_batch_loss: Optional[torch.Tensor] = None + if self.z_loss_multiplier is not None: + z_batch_loss = move_to_device(torch.tensor(0.0), self.device) + + # Split into micro-batches. + if self.rank_microbatch_size < (seq_len := batch["input_ids"].shape[1]): + raise RuntimeError( + f"Microbatch size ({self.rank_microbatch_size}) is too small relative to sequence length ({seq_len})" + ) + micro_batches = split_batch(batch, self.rank_microbatch_size // seq_len) + num_micro_batches = len(micro_batches) + + # Train one micro-batch at a time. + for micro_batch_idx, micro_batch in enumerate(micro_batches): + with self._train_microbatch_context(micro_batch_idx, num_micro_batches): + # Run forward pass. + logits = self.model_forward(micro_batch) + loss, ce_loss, z_loss = self.loss_fn( + logits, micro_batch["labels"], batch_num_tokens_for_loss ) - micro_batches = split_batch(batch, self.rank_microbatch_size // seq_len) - num_micro_batches = len(micro_batches) - - # Train one micro-batch at a time. - for micro_batch_idx, micro_batch in enumerate(micro_batches): - with self._train_microbatch_context(micro_batch_idx, num_micro_batches): - # Run forward pass. - logits, loss = self.model_forward(micro_batch, labels=micro_batch["labels"]) - del logits - assert loss is not None - - # Maybe add MoE losses. - if self.moe_handler is not None: - moe_loss = self.moe_handler.get_combined_loss( - batch=batch, micro_batch=micro_batch - ) - if moe_loss is not None: - loss += moe_loss - - # Run backward pass. - loss.backward() - else: - # Run pipeline schedule. - logits, loss = self.model_forward(batch, labels=batch["labels"]) - del logits, loss # pipeline schedule has already handled backward pass + del logits + + ce_batch_loss += ce_loss + if z_batch_loss is not None: + assert z_loss is not None + z_batch_loss += z_loss + + # Maybe add MoE losses. + if self.moe_handler is not None: + moe_loss = self.moe_handler.get_combined_loss( + batch=batch, micro_batch=micro_batch + ) + if moe_loss is not None: + loss += moe_loss + + # Run backward pass. + loss.backward() del batch # In case this helps with memory utilization. @@ -782,24 +554,16 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): return # Record loss metrics. - # NOTE: losses could be none for pipeline parallelism if rank doesn't have the final stage. 
- if self._ce_batch_loss is None: - self.record_ce_loss(0.0, ReduceType.sum) - else: - self.record_ce_loss( - self._ce_batch_loss / get_world_size(self.dp_process_group), ReduceType.sum - ) + self.record_ce_loss(ce_batch_loss / get_world_size(self.dp_process_group), ReduceType.sum) if self.z_loss_multiplier is not None: - if self._z_batch_loss is None: - self.record_metric("Z loss", 0.0, ReduceType.sum, namespace="train") - else: - self.record_metric( - "Z loss", - self._z_batch_loss / get_world_size(self.dp_process_group), - ReduceType.sum, - namespace="train", - ) + assert z_batch_loss is not None + self.record_metric( + "Z loss", + z_batch_loss / get_world_size(self.dp_process_group), + ReduceType.sum, + namespace="train", + ) if self.moe_handler is not None: if (moe_lb_loss := self.moe_handler.get_lb_loss()) is not None: @@ -807,31 +571,24 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): if (moe_z_loss := self.moe_handler.get_z_loss()) is not None: self.record_metric("router Z loss", moe_z_loss, namespace="train") - for optim in self.optimizers: - if isinstance(optim, SkipStepOptimizer): - assert self._ce_batch_loss is not None - optim.latest_loss = self._ce_batch_loss + if isinstance(self.optim, SkipStepOptimizer): + self.optim.latest_loss = ce_batch_loss # Lastly, clear internal loss buffers. self._clear_loss_buffers() def eval_batch( self, batch: Dict[str, Any], labels: Optional[torch.Tensor] = None - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: batch = move_to_device(batch, self.device) - for model in self.model_parts: - model.eval() + self.model.eval() with torch.no_grad(): - logits, loss = self.model_forward(batch, labels=labels, training=False) - - if self.pp_enabled: - assert self.eval_pp_schedule is not None - if self.eval_pp_schedule.is_last_stage: - assert logits is not None - if labels is not None: - assert loss is not None + logits = self.model_forward(batch) + loss: Optional[torch.Tensor] = None + if labels is not None: + loss = self.eval_loss_fn(logits, labels) self._clear_loss_buffers() @@ -845,77 +602,70 @@ def optim_step(self): self.trainer.record_metric( "total grad norm", grad_norm, reduce_type=None, namespace="optim" ) - for optim in self.optimizers: - if isinstance(optim, SkipStepOptimizer): - optim.latest_grad_norm = grad_norm + if isinstance(self.optim, SkipStepOptimizer): + self.optim.latest_grad_norm = grad_norm # Sync Float8 AMAXs (argmax of abs(max)) and scales. if self.float8_handler is not None: - for model in self.model_parts: - self.float8_handler.sync_float8_amax_and_scale_history(model) + self.float8_handler.sync_float8_amax_and_scale_history(self.model) # Maybe adjust learning rate. if self.scheduler is not None: - for optim in self.optimizers: - for group_idx, group in enumerate(optim.param_groups): - if (lr_field := self.scheduler.lr_field) not in group and ( - initial_lr_field := self.scheduler.initial_lr_field - ) not in group: - group_fields_list = "\n - ".join( - [f"{k}: {v}" for k, v in group.items() if k != "params"] - ) - raise RuntimeError( - f"learning rate field '{lr_field}' and initial learning rate field " - f"'{initial_lr_field}' not found in optimizer param group {group_idx} " - f"with {len(group['params'])} parameter(s):\n" - f" - {group_fields_list}" - ) - - # Ensure 'initial_lr' is set. - if group.get(self.scheduler.initial_lr_field) is None: - group[self.scheduler.initial_lr_field] = group["lr"] - - # Set new LR. 
- new_lr = self.scheduler.get_lr( - group[self.scheduler.initial_lr_field], - self.trainer.global_step, - self.trainer.max_steps, + for group_idx, group in enumerate(self.optim.param_groups): + if (lr_field := self.scheduler.lr_field) not in group and ( + initial_lr_field := self.scheduler.initial_lr_field + ) not in group: + group_fields_list = "\n - ".join( + [f"{k}: {v}" for k, v in group.items() if k != "params"] + ) + raise RuntimeError( + f"learning rate field '{lr_field}' and initial learning rate field " + f"'{initial_lr_field}' not found in optimizer param group {group_idx} " + f"with {len(group['params'])} parameter(s):\n" + f" - {group_fields_list}" ) - if isinstance(current_lr := group.get(self.scheduler.lr_field), torch.Tensor): - current_lr.fill_(new_lr) - else: - group[self.scheduler.lr_field] = new_lr + # Ensure 'initial_lr' is set. + if group.get(self.scheduler.initial_lr_field) is None: + group[self.scheduler.initial_lr_field] = group["lr"] - self.trainer.record_metric( - f"LR (group {group_idx})", group[self.scheduler.lr_field], namespace="optim" - ) + # Set new LR. + new_lr = self.scheduler.get_lr( + group[self.scheduler.initial_lr_field], + self.trainer.global_step, + self.trainer.max_steps, + ) + + if isinstance(current_lr := group.get(self.scheduler.lr_field), torch.Tensor): + current_lr.fill_(new_lr) + else: + group[self.scheduler.lr_field] = new_lr + + self.trainer.record_metric( + f"LR (group {group_idx})", group[self.scheduler.lr_field], namespace="optim" + ) # Step optimizer. - for optim in self.optimizers: - optim.step() - if isinstance(optim, SkipStepOptimizer): - self.record_metric("step skipped", optim.step_skipped, namespace="optim") + self.optim.step() + if isinstance(self.optim, SkipStepOptimizer): + self.record_metric("step skipped", self.optim.step_skipped, namespace="optim") # Maybe re-normalize matrices for nGPT-type models. # NOTE: sometimes 'isinstance' checks fail when the model is wrapped in some way. - for model in self.model_parts: - if isinstance(model, NormalizedTransformer) or hasattr(model, "normalize_matrices"): - cast(NormalizedTransformer, model).normalize_matrices() + if isinstance(self.model, NormalizedTransformer) or hasattr( + self.model, "normalize_matrices" + ): + cast(NormalizedTransformer, self.model).normalize_matrices() # Calculate Float8 dynamic AMAX/scale for all parameters. # For FSDP2 this issues a single all-reduce for all parameters at once. if self.float8_handler is not None: - for model in self.model_parts: - self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(model) + self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model) def zero_grads(self): - for optim in self.optimizers: - optim.zero_grad(set_to_none=True) + self.optim.zero_grad(set_to_none=True) - def model_forward( - self, batch: Dict[str, Any], labels: Optional[torch.Tensor] = None, training: bool = True - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + def model_forward(self, batch: Dict[str, Any]) -> torch.Tensor: """ Run a forward pass on a micro-batch, returning the logits and potentially the loss. 
""" @@ -929,50 +679,29 @@ def model_forward( if "doc_lens" in batch: mark_dynamic(batch["doc_lens"], (0, 1)) - if not self.pp_enabled: - # shape: (batch_size, seq_len, vocab_size) - logits = self.model_parts[0]( - input_ids=batch["input_ids"], - # attention_mask=micro_batch.get("attention_mask"), - # attention_bias=micro_batch.get("attention_bias"), - doc_lens=batch.get("doc_lens"), - max_doc_lens=batch.get("max_doc_lens"), - ) - loss: Optional[torch.Tensor] = None - if labels is not None: - loss_fn = self.loss_fn if training else self.eval_loss_fn - loss = loss_fn(logits, labels) - return logits, loss - else: - schedule = self.train_pp_schedule if training else self.eval_pp_schedule - assert schedule is not None - # shape: (batch_size, seq_len, vocab_size), (1,) - logits, loss = schedule.step( - input_ids=batch["input_ids"], - # attention_mask=micro_batch.get("attention_mask"), - # attention_bias=micro_batch.get("attention_bias"), - target=labels, - doc_lens=batch.get("doc_lens"), - max_doc_lens=batch.get("max_doc_lens"), - ) - if schedule.is_last_stage: - assert logits is not None - if not training and logits is not None and labels is not None and loss is None: - loss = self.eval_loss_fn(logits, labels) - return logits, loss + # Run model forward, get logits. + # shape: (batch_size, seq_len, vocab_size) + logits = self.model( + input_ids=batch["input_ids"], + # attention_mask=micro_batch.get("attention_mask"), + # attention_bias=micro_batch.get("attention_bias"), + doc_lens=batch.get("doc_lens"), + max_doc_lens=batch.get("max_doc_lens"), + ) + + return logits def num_flops_per_token(self, seq_len: int) -> int: - return self.model_parts[0].num_flops_per_token(seq_len) + return self.model.num_flops_per_token(seq_len) @contextlib.contextmanager def _train_microbatch_context( self, micro_batch_idx: int, num_micro_batches: int ) -> Generator[None, None, None]: with contextlib.ExitStack() as stack: - for model in self.model_parts: - if isinstance(model, DDP) and micro_batch_idx != num_micro_batches - 1: - # For DDP, only sync gradients on the final micro batch. - stack.enter_context(model.no_sync()) + if isinstance(self.model, DDP) and micro_batch_idx != num_micro_batches - 1: + # For DDP, only sync gradients on the final micro batch. 
+ stack.enter_context(self.model.no_sync()) yield @contextlib.contextmanager @@ -983,41 +712,26 @@ def _model_forward_context(self) -> Generator[None, None, None]: yield def _clear_loss_buffers(self): - self._batch_num_tokens_for_loss = None - self._ce_batch_loss = None - self._z_batch_loss = None if self.moe_handler is not None: self.moe_handler.clear_loss_buffers() def _get_state_dict(self, sd_options: dist_cp_sd.StateDictOptions) -> Dict[str, Any]: return { - "model": { - k: v - for sd in map( - partial(dist_cp_sd.get_model_state_dict, options=sd_options), self.model_parts - ) - for k, v in sd.items() - }, - "optim": { - k: v - for sd in map( - partial(dist_cp_sd.get_optimizer_state_dict, options=sd_options), - self.model_parts, - self.optimizers, - ) - for k, v in sd.items() - }, + "model": dist_cp_sd.get_model_state_dict(self.model, options=sd_options), + "optim": dist_cp_sd.get_optimizer_state_dict( + self.model, self.optim, options=sd_options + ), } def _clip_grad_norm( self, max_grad_norm: float, norm_type: float = 2.0, foreach: Optional[bool] = None ) -> torch.Tensor: - if not self.pp_enabled and isinstance(self.model_parts[0], FSDP): - return self.model_parts[0].clip_grad_norm_(max_grad_norm) + if isinstance(self.model, FSDP): + return self.model.clip_grad_norm_(max_grad_norm) # Adapted from https://github.com/pytorch/torchtitan/blob/2a4437014e66bcf88a3f0419b816266e6326d539/torchtitan/utils.py#L348 - parameters = [p for m in self.model_parts for p in m.parameters()] + parameters = [p for p in self.model.parameters()] grads = [p.grad for p in parameters if p.grad is not None] total_norm = nn.utils.get_total_norm( @@ -1035,14 +749,5 @@ def _clip_grad_norm( # If only using PP, total_norm will be a local tensor. total_norm = total_norm.full_tensor() - if self.train_pp_schedule is not None: - pp_mesh = self.train_pp_schedule.pp_mesh - if math.isinf(norm_type): - dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) - else: - total_norm **= norm_type - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) - total_norm **= 1.0 / norm_type - torch.nn.utils.clip_grads_with_norm_(parameters, max_grad_norm, total_norm, foreach=foreach) return total_norm diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py new file mode 100644 index 00000000..85327a74 --- /dev/null +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -0,0 +1,908 @@ +import contextlib +import copy +import logging +import math +from dataclasses import dataclass, replace +from functools import partial +from typing import Any, Dict, Generator, List, Optional, Tuple, cast + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint.state_dict as dist_cp_sd +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.pipelining import PipelineStage +from torch.distributed.tensor import DTensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer + +from olmo_core.config import Config, DType +from olmo_core.data.utils import get_labels +from olmo_core.distributed.checkpoint import _swap_param_keys +from olmo_core.distributed.parallel import ( + DataParallelType, + PipelineParallelConfig, + PipelineSchedule, + build_device_mesh, + get_dp_mesh, + get_dp_process_group, + get_pp_mesh, + get_tp_mesh, +) +from olmo_core.distributed.utils import 
get_local_tensor, get_world_size +from olmo_core.exceptions import OLMoConfigurationError +from olmo_core.float8 import Float8Config, Float8Handler +from olmo_core.nn.functional.cross_entropy_loss import ( + cross_entropy_loss, + fused_cross_entropy_loss, +) +from olmo_core.nn.moe import MoEHandler +from olmo_core.nn.transformer import NormalizedTransformer, Transformer +from olmo_core.optim import OptimConfig, SkipStepOptimizer +from olmo_core.optim.scheduler import Scheduler +from olmo_core.utils import gc_cuda, get_default_device, mark_dynamic, move_to_device + +from ..common import ReduceType, reshape_inputs_for_loss +from .train_module import EvalBatchSizeUnit, EvalBatchSpec, TrainModule +from .transformer import ( + TransformerActivationCheckpointingConfig, + TransformerDataParallelConfig, + TransformerTensorParallelConfig, +) + +log = logging.getLogger(__name__) + + +@dataclass +class TransformerPipelineParallelConfig(PipelineParallelConfig): + """ + Transformer-specific pipeline parallel config. + """ + + split_points: Optional[List[int]] = None + """ + A list of unique, increasing block indices that define how to split the model into stages. + + For example, ``split_points = [0, 2]`` with a 4-layer model means the model will be split into + 3 stages, with the first containing just the embedding, the second containing blocks 0 and 1, + and the third containing blocks 2 and 3 and the language modeling head. + + If not specified the split points are determined automatically based on the schedule type. + """ + + def get_split_points(self, n_layers: int) -> List[int]: + if self.split_points is not None: + return self.split_points + + # Multi-stage schedules support more than 2 stages per rank, but this is the default if + # no pipeline split is specified. 
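+        # As a hedged illustration with made-up sizes (not defaults): a 32-layer model with
+        # degree=4 and a multi-stage schedule gives total_stages = 8, base_interval = 4,
+        # extra_layers = 0, and the auto-generated split points below come out to
+        # [4, 8, 12, 16, 20, 24, 28], i.e. four blocks per stage.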
+ num_stages_per_rank = 1 if self.schedule.is_single_stage else 2 + total_stages = self.degree * num_stages_per_rank + num_layers = n_layers + if total_stages > num_layers: + raise OLMoConfigurationError("Total stages cannot be greater than the number of layers") + + base_interval = num_layers // total_stages + extra_layers = num_layers % total_stages + + splits: List[int] = [] + current_layer = 0 + for i in range(total_stages - 1): + if i == 0: + current_layer += base_interval + else: + # Middle stages get an extra layer if there are any remaining + if extra_layers > 0: + current_layer += base_interval + 1 + extra_layers -= 1 + else: + current_layer += base_interval + splits.append(current_layer) + log.info(f"Auto generated pipeline split points will be {splits}") + return splits + + def split_model( + self, model: Transformer, *, pp_mesh: DeviceMesh, device: torch.device + ) -> Tuple[List[PipelineStage], List[Transformer]]: + split_points = self.get_split_points(model.n_layers) + pp_rank = pp_mesh.get_local_rank() + + def build_stage( + stage_idx: int, + start_layer: Optional[int], + stop_layer: Optional[int], + is_first: bool = False, + is_last: bool = False, + ) -> Tuple[PipelineStage, Transformer]: + model_chunk = copy.deepcopy(model) + if not is_first: + model_chunk.embeddings = None # type: ignore + + drop_layers = start_layer is not None + for block_idx in range(model.n_layers): + # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) + if block_idx == start_layer: + drop_layers = False + if block_idx == stop_layer: + drop_layers = True + if drop_layers: + del model_chunk.blocks[str(block_idx)] + + if not is_last: + model_chunk.lm_head = None # type: ignore + + stage = PipelineStage( + model_chunk, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model_chunk + + num_stages = len(split_points) + 1 + stage_idx = pp_rank + + stages = [] + models = [] + for stage_idx in self.stage_ids_this_rank(pp_rank, num_stages, style="loop"): + start_layer = split_points[stage_idx - 1] if stage_idx > 0 else None + stop_layer = split_points[stage_idx] if stage_idx < num_stages - 1 else None + stage, model_chunk = build_stage( + stage_idx, + start_layer, + stop_layer, + is_first=stage_idx == 0, + is_last=stage_idx == num_stages - 1, + ) + log.info( + f"PP rank {pp_rank} is building stage {stage_idx} with start layer " + f"{start_layer}, stop layer {stop_layer}: {model_chunk}" + ) + stages.append(stage) + models.append(model_chunk) + + return stages, models + + +@dataclass +class TransformerPipelineTrainModuleConfig(Config): + """ + A configuration class for building :class:`TransformerTrainModule` instances. + + .. seealso:: + See the :class:`TransformerTrainModule` documentation for a description of the fields. + """ + + rank_microbatch_size: int + max_sequence_length: int + pp_config: TransformerPipelineParallelConfig + + # Optimizer settings. + + optim: OptimConfig + max_grad_norm: Optional[float] = None + scheduler: Optional[Scheduler] = None + + # Model settings. + + compile_model: bool = False + float8_config: Optional[Float8Config] = None + dp_config: Optional[TransformerDataParallelConfig] = None + tp_config: Optional[TransformerTensorParallelConfig] = None + ac_config: Optional[TransformerActivationCheckpointingConfig] = None + + # Loss function settings. + + fused_loss: bool = False + compile_loss: bool = False + z_loss_multiplier: Optional[float] = None + + # Checkpoint settings. 
+ + state_dict_save_opts: Optional[Dict[str, Any]] = None + state_dict_load_opts: Optional[Dict[str, Any]] = None + load_key_mapping: Optional[Dict[str, str]] = None + + # Other train settings. + + autocast_precision: Optional[DType] = None + label_ignore_index: int = -100 + + def build( + self, + model: Transformer, + device: Optional[torch.device] = None, + ) -> "TransformerPipelineTrainModule": + """ + Build the corresponding :class:`TransformerPipelineTrainModule`. + + :param model: The :class:`~olmo_core.nn.transformer.Transformer` model to train. + :param device: The device to train on. + """ + kwargs = self.as_dict(exclude_none=True, recurse=False) + if (autocast_precision := kwargs.pop("autocast_precision", None)) is not None: + kwargs["autocast_precision"] = cast(DType, autocast_precision).as_pt() + if (state_dict_save_opts := kwargs.pop("state_dict_save_opts", None)) is not None: + kwargs["state_dict_save_opts"] = dist_cp_sd.StateDictOptions(**state_dict_save_opts) + if (state_dict_load_opts := kwargs.pop("state_dict_load_opts", None)) is not None: + kwargs["state_dict_load_opts"] = dist_cp_sd.StateDictOptions(**state_dict_load_opts) + return TransformerPipelineTrainModule( + model=model, + device=device, + **kwargs, + ) + + +class TransformerPipelineTrainModule(TrainModule): + """ + A pipeline-parallel :class:`TrainModule` for most :class:`~olmo_core.nn.transformer.Transformer` model + implementation provided by this library. + + .. tip:: + Use the :class:`TransformerPipelineTrainModuleConfig` to easily configure and build + :class:`TransformerPipelineTrainModule` instances. + + :param model: The :class:`~olmo_core.nn.transformer.Transformer` model to train. + :param optim: The corresponding optimizer config. + :param rank_microbatch_size: The microbatch size *in tokens* per rank, + i.e. the number of tokens to process at a time from each rank. + + .. note:: This must evenly divide into the global batch size by a factor of the data + parallel world size. If this is less than the global batch divided by the data + parallel world size then gradient accumulation is used. + :param max_sequence_length: The maximum expected sequence length during training and evaluation. + :param compile_model: Whether to compile to the model. + :param float8_config: Float8 configuration for the model. + :param dp_config: Data parallel configuration for the model. + :param tp_config: Tensor parallel configuration for the model. + :param pp_config: Pipeline parallel configuration for the model. + :param ac_config: Activation checkpointing configuration for the model. + :param compile_loss: Compile the loss function. This can provide a small speedup while also + reducing GPU memory usage, especially when using Z-loss. + + .. important:: + This is incompatible with ``fused_loss=True``. + :param fused_loss: Use the fused cross-entropy loss function (:func:`~olmo_core.nn.functional.fused_cross_entropy_loss`) + instead the PyTorch built-in. This can help reduce GPU memory usage when ``compile_loss=False``. + Relative performance will depend on the input sizes. + + .. important:: + This is incompatible with ``compile_loss=True``. + :param z_loss_multiplier: Use Z-loss with this multiplier. + :param autocast_precision: Enable AMP with this data type. + :param max_grad_norm: Clip gradient norms to this value. + :param scheduler: Optional learning rate scheduler for the optimizer. + :param device: The device to train on. 
+ :param state_dict_save_opts: Can be used to override the state dict options used + when saving a checkpoint. + :param state_dict_load_opts: Can be used to override the state dict options used + when loading a checkpoint. + :param load_key_mapping: Can be used to load a checkpoint where certain parameter have different names. + This dictionary should map current keys to keys in the checkpoint to be loaded. + """ + + def __init__( + self, + model: Transformer, + optim: OptimConfig, + rank_microbatch_size: int, + max_sequence_length: int, + pp_config: TransformerPipelineParallelConfig, + compile_model: bool = False, + float8_config: Optional[Float8Config] = None, + dp_config: Optional[TransformerDataParallelConfig] = None, + tp_config: Optional[TransformerTensorParallelConfig] = None, + ac_config: Optional[TransformerActivationCheckpointingConfig] = None, + compile_loss: bool = False, + fused_loss: bool = False, + z_loss_multiplier: Optional[float] = None, + autocast_precision: Optional[torch.dtype] = None, + max_grad_norm: Optional[float] = None, + scheduler: Optional[Scheduler] = None, + device: Optional[torch.device] = None, + state_dict_save_opts: Optional[dist_cp_sd.StateDictOptions] = None, + state_dict_load_opts: Optional[dist_cp_sd.StateDictOptions] = None, + load_key_mapping: Optional[Dict[str, str]] = None, + label_ignore_index: int = -100, + ): + super().__init__() + + # Validate some options. + if fused_loss and compile_loss: + raise OLMoConfigurationError("'fused_loss' is not compatible with 'compile_loss'") + if rank_microbatch_size % max_sequence_length != 0: + raise OLMoConfigurationError( + f"'rank_microbatch_size' ({rank_microbatch_size:,d} tokens) must be divisible by " + f"'max_sequence_length' ({max_sequence_length:,d} tokens)" + ) + + self.device = device or get_default_device() + self.world_mesh = build_device_mesh( + dp=dp_config, tp=tp_config, pp=pp_config, device_type=self.device.type + ) + log.info(f"Data parallel world size = {get_world_size(self.dp_process_group):,d}") + + self.base_loss_fn = fused_cross_entropy_loss if fused_loss else cross_entropy_loss + if compile_loss: + if torch.cuda.is_available(): + self.base_loss_fn = torch.compile(self.base_loss_fn) + else: + log.warning("Skipping loss compilation since CUDA is not available") + + self.float8_handler: Optional[Float8Handler] = None + float8_enabled = False + if float8_config is not None: + float8_enabled = float8_config.enabled + float8_config.compile = compile_model + self.float8_handler = float8_config.build() + + self._pp_config = pp_config + # We'll initialize this lazily when the trainer is attached, since we need to know + # the global batch size in order to determine the number of pipeline micro-batches. + self._train_pp_schedule: Optional[PipelineSchedule] = None + self._eval_pp_schedule: Optional[PipelineSchedule] = None + self._pp_stages: Optional[List[PipelineStage]] = None + + self.model_parts: List[Transformer] = [] + pp_mesh = get_pp_mesh(self.world_mesh) + assert pp_mesh is not None + stages, model_parts = pp_config.split_model(model, pp_mesh=pp_mesh, device=self.device) + self._pp_stages = stages + self.model_parts = model_parts + + # Maybe convert linear layers to FP8 linear. + if self.float8_handler is not None and self.float8_handler.enabled: + for model in self.model_parts: + self.float8_handler.convert_to_float8_training( + model, modules_to_ignore={"lm_head.w_out"} + ) + log.info("Swapped linear layers to Float8 linear layers") + + # Maybe apply tensor parallelism. 
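+        # (This and the transforms below are applied to every pipeline stage's model chunk
+        # owned by this rank, i.e. each entry of `self.model_parts`.)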
+ if tp_config is not None: + tp_mesh = get_tp_mesh(self.world_mesh) + assert tp_mesh is not None + for model in self.model_parts: + model.apply_tp( + tp_mesh, + float8_enabled=float8_enabled, + loss_parallel=False, # TODO (epwalsh): figure out if this will work w/ z-loss + ) + tp_config.maybe_enable_async_tp(tp_mesh) + log.info( + f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" + ) + + # Maybe apply activation checkpointing. + if ac_config is not None: + for model in self.model_parts: + model.apply_activation_checkpointing( + ac_config.mode, + block_interval=ac_config.block_interval, + modules=ac_config.modules, + ) + log.info(f"Applied '{ac_config.mode}' activation checkpointing to the model") + + # Maybe compile. + if compile_model: + if torch.cuda.is_available(): + for model in self.model_parts: + model.apply_compile() + log.info("Applied torch.compile() to the model") + else: + log.warning("Skipping model compilation since CUDA is not available") + + # Maybe shard/replicate according to data parallel config. + self._dp_config = dp_config + if dp_config is not None: + dp_mesh = get_dp_mesh(self.world_mesh) + if dp_config.name in (DataParallelType.fsdp, DataParallelType.hsdp): + for model in self.model_parts: + model.apply_fsdp( + dp_mesh=dp_mesh, + param_dtype=dp_config.param_dtype.as_pt() + if dp_config.param_dtype is not None + else None, + reduce_dtype=dp_config.reduce_dtype.as_pt(), + wrapping_strategy=dp_config.wrapping_strategy, + pp_enabled=True, + ) + log.info("Applied FSDP to the model") + elif dp_config.name == DataParallelType.ddp: + for model in self.model_parts: + model.apply_ddp(dp_mesh=dp_mesh, compile_enabled=compile_model) + log.info("Applied DDP to the model") + else: + raise NotImplementedError(dp_config.name) + + # Materialize and init parameters. + for model in self.model_parts: + log.info("Initializing model weights...") + model.init_weights(max_seq_len=max_sequence_length, device=self.device) + + # Build optimizer(s). + log.info("Building optimizer(s)...") + self.optimizers: List[Optimizer] = [ + optim.build(model, strict=False) for model in self.model_parts + ] + if isinstance(self.optimizers[0], SkipStepOptimizer): + raise NotImplementedError( + "Pipeline parallelism with a SkipStepOptimizer is currently not supported" + ) + + self.rank_microbatch_size = rank_microbatch_size + self.max_sequence_length = max_sequence_length + self.z_loss_multiplier = z_loss_multiplier + self.autocast_precision = autocast_precision + self.max_grad_norm = max_grad_norm + self.scheduler = scheduler + self.state_dict_save_opts = state_dict_save_opts or dist_cp_sd.StateDictOptions( + flatten_optimizer_state_dict=True, cpu_offload=True + ) + self.state_dict_load_opts = state_dict_load_opts or dist_cp_sd.StateDictOptions( + flatten_optimizer_state_dict=True, strict=False + ) + self.load_key_mapping = load_key_mapping + self.label_ignore_index = label_ignore_index + + self.moe_handler: Optional[MoEHandler] = None + for model in self.model_parts: + if MoEHandler.has_moe(model): + # TODO (epwalsh): need to figure out how to handle the internal MoE losses correctly. 
+ raise NotImplementedError( + "Pipeline parallelism with MoE's is currently not supported" + ) + + self._batch_num_tokens_for_loss: Optional[torch.Tensor] = None + self._ce_batch_loss: Optional[torch.Tensor] = None + self._z_batch_loss: Optional[torch.Tensor] = None + + @property + def dp_process_group(self) -> Optional[dist.ProcessGroup]: + return get_dp_process_group(self.world_mesh) + + @property + def eval_batch_spec(self) -> EvalBatchSpec: + # Determine the number of micro-batches. + rank_batch_size = self.trainer.global_batch_size // get_world_size( + self.trainer.dp_process_group + ) + rank_batch_size_instances = rank_batch_size // self.max_sequence_length + return EvalBatchSpec( + rank_batch_size=rank_batch_size_instances, + batch_size_unit=EvalBatchSizeUnit.instances, + max_sequence_length=self.max_sequence_length, + fixed_sequence_length=True, + ) + + @property + def train_pp_schedule(self) -> PipelineSchedule: + self.trainer # make sure trainer has been attached before trying to access this + assert self._train_pp_schedule is not None + return self._train_pp_schedule + + @property + def eval_pp_schedule(self) -> PipelineSchedule: + self.trainer # make sure trainer has been attached before trying to access this + assert self._eval_pp_schedule is not None + return self._eval_pp_schedule + + def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + assert self._batch_num_tokens_for_loss is not None + + logits_for_loss, labels_for_loss = reshape_inputs_for_loss(logits, labels) + + # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' + # (the total number of tokens used in the loss across the whole batch, not just the micro batch) + # to avoid biasing the loss in the case where micro-batches might not be the same size. + ce_loss, z_loss = self.base_loss_fn( + logits_for_loss, + labels_for_loss, + ignore_index=self.label_ignore_index, + reduction="sum", + compute_z_loss=self.z_loss_multiplier is not None, + z_loss_multiplier=self.z_loss_multiplier or 1e-4, + ) + + ce_loss.div_(self._batch_num_tokens_for_loss) + if z_loss is not None: + z_loss.div_(self._batch_num_tokens_for_loss) + + # Get loss to optimize for. + loss = ce_loss + if z_loss is not None: + loss += z_loss + + # Update overall CE batch loss. + if self._ce_batch_loss is None: + self._ce_batch_loss = move_to_device(torch.tensor(0.0), self.device) + self._ce_batch_loss += get_local_tensor(ce_loss.detach()) + + # Update overall Z batch loss. + if z_loss is not None: + if self._z_batch_loss is None: + self._z_batch_loss = move_to_device(torch.tensor(0.0), self.device) + self._z_batch_loss += get_local_tensor(z_loss.detach()) + + return loss + + def eval_loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + logits_for_loss, labels_for_loss = reshape_inputs_for_loss(logits, labels) + ce_loss, _ = self.base_loss_fn( + logits_for_loss, + labels_for_loss, + ignore_index=self.label_ignore_index, + reduction="none", + ) + return ce_loss.view(logits.shape[0], -1) + + def on_attach(self): + # Validate batch size. + dp_ws = get_world_size(self.trainer.dp_process_group) + if self.trainer.global_batch_size % (self.rank_microbatch_size * dp_ws) != 0: + raise OLMoConfigurationError( + f"global batch size ({self.trainer.global_batch_size:,d}) must be divisible by " + f"micro-batch size ({self.rank_microbatch_size:,d}) x DP world size ({dp_ws})" + ) + + # Initialize pipeline schedule. 
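+        # Hedged example with made-up sizes: a global batch of 1,048,576 tokens, dp_ws = 8,
+        # and rank_microbatch_size = 16,384 tokens give rank_batch_size = 131,072 tokens and
+        # num_micro_batches = 8 for the schedules built below.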
+ assert self._train_pp_schedule is None # make sure we don't initialize this twice + assert self._pp_stages is not None + pp_mesh = get_pp_mesh(self.world_mesh) + assert pp_mesh is not None + + # Determine the number of micro-batches. + rank_batch_size = self.trainer.global_batch_size // dp_ws + num_micro_batches = rank_batch_size // self.rank_microbatch_size + + self._train_pp_schedule = PipelineSchedule( + model_parts=self.model_parts, # type: ignore[arg-type] + stages=self._pp_stages, + pp_mesh=pp_mesh, + schedule_name=self._pp_config.schedule, + loss_fn=self.loss_fn, + n_microbatches=num_micro_batches, + ) + self._eval_pp_schedule = PipelineSchedule( + model_parts=self.model_parts, # type: ignore[arg-type] + stages=self._pp_stages, + pp_mesh=pp_mesh, + schedule_name=self._pp_config.schedule, + n_microbatches=num_micro_batches, + # NOTE: can't pass this here or the schedule will attempt backward pass + # loss_fn=self.eval_loss_fn, + ) + + def state_dict(self) -> Dict[str, Any]: + return self._get_state_dict(self.state_dict_save_opts) + + def state_dict_to_load(self, metadata: Metadata) -> Dict[str, Any]: + load_opts = self.state_dict_load_opts + + if "optim.param_groups.0.params" in metadata.state_dict_metadata: + # unflattened optimizer state + if load_opts.flatten_optimizer_state_dict: + log.warning( + "Loading checkpoint with an unflattened optimizer state even though " + "'flatten_optimizer_state_dict=True' in train module's 'state_dict_load_opts', " + "automatically switching to 'flatten_optimizer_state_dict=False'." + ) + load_opts = replace(load_opts, flatten_optimizer_state_dict=False) + else: + # flattened optimizer state + if not load_opts.flatten_optimizer_state_dict: + log.warning( + "Loading checkpoint with a flattened optimizer state even though " + "'flatten_optimizer_state_dict=False' in train module's 'state_dict_load_opts', " + "automatically switching to 'flatten_optimizer_state_dict=True'." + ) + load_opts = replace(load_opts, flatten_optimizer_state_dict=True) + + state_dict = self._get_state_dict(load_opts) + if self.load_key_mapping is not None: + _swap_param_keys(state_dict, self.load_key_mapping, metadata=metadata) + + has_optim_state: bool = False + for key in metadata.state_dict_metadata.keys(): + if key.startswith("optim."): + has_optim_state = True + break + + if not has_optim_state: + del state_dict["optim"] + log.warning("No optimizer state found in checkpoint") + + return state_dict + + def state_dict_to_save(self) -> Dict[str, Any]: + return self._get_state_dict(self.state_dict_save_opts) + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + if self.load_key_mapping is not None: + _swap_param_keys(state_dict, self.load_key_mapping, reverse=True, quiet=True) + for model, optim in zip(self.model_parts, self.optimizers): + dist_cp_sd.set_model_state_dict( + model, + state_dict["model"], + options=self.state_dict_load_opts, + ) + gc_cuda() + if "optim" in state_dict: + dist_cp_sd.set_optimizer_state_dict( + model, + optim, + state_dict["optim"], + options=self.state_dict_load_opts, + ) + gc_cuda() + + def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): + # Set model to train mode if it isn't already. + for model in self.model_parts: + model.train() + + # Move tensors to the right device. + batch = move_to_device(batch, self.device) + + # Generate labels. 
+        if "labels" not in batch:
+            batch["labels"] = get_labels(batch, label_ignore_index=self.label_ignore_index)
+
+        # Calculate how many tokens are going to be used in the loss.
+        self._batch_num_tokens_for_loss = (batch["labels"] != self.label_ignore_index).sum()
+
+        # Run pipeline schedule.
+        logits, loss = self.model_forward(batch, labels=batch["labels"])
+        del logits, loss  # pipeline schedule has already handled backward pass
+
+        del batch  # In case this helps with memory utilization.
+
+        if dry_run:
+            self._clear_loss_buffers()
+            return
+
+        # Record loss metrics.
+        # NOTE: losses could be None for pipeline parallelism if this rank doesn't have the final stage.
+        if self._ce_batch_loss is None:
+            self.record_ce_loss(0.0, ReduceType.sum)
+        else:
+            self.record_ce_loss(
+                self._ce_batch_loss / get_world_size(self.dp_process_group), ReduceType.sum
+            )
+
+        if self.z_loss_multiplier is not None:
+            if self._z_batch_loss is None:
+                self.record_metric("Z loss", 0.0, ReduceType.sum, namespace="train")
+            else:
+                self.record_metric(
+                    "Z loss",
+                    self._z_batch_loss / get_world_size(self.dp_process_group),
+                    ReduceType.sum,
+                    namespace="train",
+                )
+
+        if self.moe_handler is not None:
+            if (moe_lb_loss := self.moe_handler.get_lb_loss()) is not None:
+                self.record_metric("load balancing loss", moe_lb_loss, namespace="train")
+            if (moe_z_loss := self.moe_handler.get_z_loss()) is not None:
+                self.record_metric("router Z loss", moe_z_loss, namespace="train")
+
+        for optim in self.optimizers:
+            if isinstance(optim, SkipStepOptimizer):
+                assert self._ce_batch_loss is not None
+                optim.latest_loss = self._ce_batch_loss
+
+        # Lastly, clear internal loss buffers.
+        self._clear_loss_buffers()
+
+    def eval_batch(
+        self, batch: Dict[str, Any], labels: Optional[torch.Tensor] = None
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        batch = move_to_device(batch, self.device)
+
+        for model in self.model_parts:
+            model.eval()
+
+        with torch.no_grad():
+            logits, loss = self.model_forward(batch, labels=labels, training=False)
+
+        if self.eval_pp_schedule.is_last_stage:
+            assert logits is not None
+            if labels is not None:
+                assert loss is not None
+
+        self._clear_loss_buffers()
+
+        return logits, loss
+
+    def optim_step(self):
+        # Maybe clip gradients.
+        if self.max_grad_norm is not None:
+            grad_norm = self._clip_grad_norm(self.max_grad_norm)
+            # NOTE: grad norm is already reduced over ranks, so we set `reduce_type` to `None`.
+            self.trainer.record_metric(
+                "total grad norm", grad_norm, reduce_type=None, namespace="optim"
+            )
+            for optim in self.optimizers:
+                if isinstance(optim, SkipStepOptimizer):
+                    optim.latest_grad_norm = grad_norm
+
+        # Sync Float8 AMAXs (max of absolute values) and scales.
+        if self.float8_handler is not None:
+            for model in self.model_parts:
+                self.float8_handler.sync_float8_amax_and_scale_history(model)
+
+        # Maybe adjust learning rate.
+        if self.scheduler is not None:
+            for optim in self.optimizers:
+                for group_idx, group in enumerate(optim.param_groups):
+                    if (lr_field := self.scheduler.lr_field) not in group and (
+                        initial_lr_field := self.scheduler.initial_lr_field
+                    ) not in group:
+                        group_fields_list = "\n - ".join(
+                            [f"{k}: {v}" for k, v in group.items() if k != "params"]
+                        )
+                        raise RuntimeError(
+                            f"learning rate field '{lr_field}' and initial learning rate field "
+                            f"'{initial_lr_field}' not found in optimizer param group {group_idx} "
+                            f"with {len(group['params'])} parameter(s):\n"
+                            f" - {group_fields_list}"
+                        )
+
+                    # Ensure 'initial_lr' is set.
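+                    # The group's current LR is recorded as its initial LR the first time
+                    # through; the scheduler then derives each step's LR from that initial
+                    # value and the trainer's global step. Tensor-valued LR fields are
+                    # updated in place rather than replaced.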
+                    if group.get(self.scheduler.initial_lr_field) is None:
+                        group[self.scheduler.initial_lr_field] = group["lr"]
+
+                    # Set new LR.
+                    new_lr = self.scheduler.get_lr(
+                        group[self.scheduler.initial_lr_field],
+                        self.trainer.global_step,
+                        self.trainer.max_steps,
+                    )
+
+                    if isinstance(current_lr := group.get(self.scheduler.lr_field), torch.Tensor):
+                        current_lr.fill_(new_lr)
+                    else:
+                        group[self.scheduler.lr_field] = new_lr
+
+                    self.trainer.record_metric(
+                        f"LR (group {group_idx})", group[self.scheduler.lr_field], namespace="optim"
+                    )
+
+        # Step optimizer.
+        for optim in self.optimizers:
+            optim.step()
+            if isinstance(optim, SkipStepOptimizer):
+                self.record_metric("step skipped", optim.step_skipped, namespace="optim")
+
+        # Maybe re-normalize matrices for nGPT-type models.
+        # NOTE: sometimes 'isinstance' checks fail when the model is wrapped in some way.
+        for model in self.model_parts:
+            if isinstance(model, NormalizedTransformer) or hasattr(model, "normalize_matrices"):
+                cast(NormalizedTransformer, model).normalize_matrices()
+
+        # Calculate Float8 dynamic AMAX/scale for all parameters.
+        # For FSDP2 this issues a single all-reduce for all parameters at once.
+        if self.float8_handler is not None:
+            for model in self.model_parts:
+                self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
+
+    def zero_grads(self):
+        for optim in self.optimizers:
+            optim.zero_grad(set_to_none=True)
+
+    def model_forward(
+        self, batch: Dict[str, Any], labels: Optional[torch.Tensor] = None, training: bool = True
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Run the pipeline schedule over a full rank batch, returning the logits and potentially the loss.
+        """
+        with self._model_forward_context():
+            # NOTE: Input sizes might be dynamic, e.g. when training with variable sequence lengths
+            # or during an eval loop, so we mark them as dynamic for torch.compile up-front to avoid
+            # recompiling later.
+            # In theory this could harm performance a bit when input sizes are actually static,
+            # but so far I haven't noticed any dip in throughput with the models I've tested.
+            mark_dynamic(batch["input_ids"], (0, 1))
+            if "doc_lens" in batch:
+                mark_dynamic(batch["doc_lens"], (0, 1))
+
+            schedule = self.train_pp_schedule if training else self.eval_pp_schedule
+            # shape: (batch_size, seq_len, vocab_size), (1,)
+            logits, loss = schedule.step(
+                input_ids=batch["input_ids"],
+                # attention_mask=micro_batch.get("attention_mask"),
+                # attention_bias=micro_batch.get("attention_bias"),
+                target=labels,
+                doc_lens=batch.get("doc_lens"),
+                max_doc_lens=batch.get("max_doc_lens"),
+            )
+            if schedule.is_last_stage:
+                assert logits is not None
+                if not training and logits is not None and labels is not None and loss is None:
+                    loss = self.eval_loss_fn(logits, labels)
+            return logits, loss
+
+    def num_flops_per_token(self, seq_len: int) -> int:
+        return self.model_parts[0].num_flops_per_token(seq_len)
+
+    @contextlib.contextmanager
+    def _train_microbatch_context(
+        self, micro_batch_idx: int, num_micro_batches: int
+    ) -> Generator[None, None, None]:
+        with contextlib.ExitStack() as stack:
+            for model in self.model_parts:
+                if isinstance(model, DDP) and micro_batch_idx != num_micro_batches - 1:
+                    # For DDP, only sync gradients on the final micro-batch.
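+                    # 'no_sync()' suppresses DDP's gradient all-reduce for this micro-batch,
+                    # so gradients accumulate locally and are only synchronized on the last one.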
+                    stack.enter_context(model.no_sync())
+            yield
+
+    @contextlib.contextmanager
+    def _model_forward_context(self) -> Generator[None, None, None]:
+        with contextlib.ExitStack() as stack:
+            if self.autocast_precision is not None:
+                stack.enter_context(torch.autocast(self.device.type, dtype=self.autocast_precision))
+            yield
+
+    def _clear_loss_buffers(self):
+        self._batch_num_tokens_for_loss = None
+        self._ce_batch_loss = None
+        self._z_batch_loss = None
+        if self.moe_handler is not None:
+            self.moe_handler.clear_loss_buffers()
+
+    def _get_state_dict(self, sd_options: dist_cp_sd.StateDictOptions) -> Dict[str, Any]:
+        return {
+            "model": {
+                k: v
+                for sd in map(
+                    partial(dist_cp_sd.get_model_state_dict, options=sd_options), self.model_parts
+                )
+                for k, v in sd.items()
+            },
+            "optim": {
+                k: v
+                for sd in map(
+                    partial(dist_cp_sd.get_optimizer_state_dict, options=sd_options),
+                    self.model_parts,
+                    self.optimizers,
+                )
+                for k, v in sd.items()
+            },
+        }
+
+    def _clip_grad_norm(
+        self, max_grad_norm: float, norm_type: float = 2.0, foreach: Optional[bool] = None
+    ) -> torch.Tensor:
+        # Adapted from https://github.com/pytorch/torchtitan/blob/2a4437014e66bcf88a3f0419b816266e6326d539/torchtitan/utils.py#L348
+
+        parameters = [p for m in self.model_parts for p in m.parameters()]
+        grads = [p.grad for p in parameters if p.grad is not None]
+
+        total_norm = nn.utils.get_total_norm(
+            grads, norm_type=norm_type, error_if_nonfinite=False, foreach=foreach
+        )
+
+        # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
+        # We can simply reduce the DTensor to get the total norm in this tensor's process group
+        # and then convert it to a local tensor.
+        # NOTE: It has two purposes:
+        # 1. to make sure the total norm is computed correctly when PP is used (see below)
+        # 2. to return a reduced total_norm tensor whose .item() would return the correct value
+        if isinstance(total_norm, DTensor):
+            # Will reach here if any non-PP parallelism is used.
+            # If only using PP, total_norm will be a local tensor.
+            total_norm = total_norm.full_tensor()
+
+        pp_mesh = self.train_pp_schedule.pp_mesh
+        if math.isinf(norm_type):
+            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
+        else:
+            total_norm **= norm_type
+            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
+            total_norm **= 1.0 / norm_type
+
+        torch.nn.utils.clip_grads_with_norm_(parameters, max_grad_norm, total_norm, foreach=foreach)
+        return total_norm
diff --git a/src/test/train/train_module/transformer_test.py b/src/test/train/train_module/transformer_pipeline_test.py
similarity index 85%
rename from src/test/train/train_module/transformer_test.py
rename to src/test/train/train_module/transformer_pipeline_test.py
index 00c78edd..6ed60942 100644
--- a/src/test/train/train_module/transformer_test.py
+++ b/src/test/train/train_module/transformer_pipeline_test.py
@@ -1,5 +1,7 @@
 from olmo_core.distributed.parallel import PipelineScheduleType
-from olmo_core.train.train_module.transformer import TransformerPipelineParallelConfig
+from olmo_core.train.train_module.transformer_pipeline import (
+    TransformerPipelineParallelConfig,
+)
 
 
 def test_generate_pipeline_split_points():