diff --git a/recipes/configs/generation.yaml b/recipes/configs/generation.yaml index c2081a1ed7..0171310aef 100644 --- a/recipes/configs/generation.yaml +++ b/recipes/configs/generation.yaml @@ -1,4 +1,9 @@ -# Config for running the InferenceRecipe in generate.py to generate output from an LLM +# Config for running the InferenceRecipe in generate.py to generate output +# from Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --ignore-patterns "*.safetensors" --hf-token # # To launch, run the following command from root torchtune directory: # tune run generate --config generation diff --git a/recipes/configs/llama3/70B_generation_distributed.yaml b/recipes/configs/llama3/70B_generation_distributed.yaml new file mode 100644 index 0000000000..78c77ba263 --- /dev/null +++ b/recipes/configs/llama3/70B_generation_distributed.yaml @@ -0,0 +1,50 @@ +# Config for running the InferenceRecipe in dev/generate_v2_distributed.py to generate output +# using a Llama3 70B Instruct model +# +# This config assumes that you've run the following command before launching: +# tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token +# +# To launch, run the following command from root torchtune directory: +# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3/70B_generation_distributed + +output_dir: ./ + +# Model arguments +model: + _component_: torchtune.models.llama3.llama3_70b + +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + +# Transform arguments +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model + prompt_template: null + max_seq_len: 8192 + +# Checkpointer +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00030" + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 + +# Device +device: cuda +dtype: bf16 +seed: 1234 +log_level: INFO + +# Generation arguments +prompt: + system: null + user: + text: Tell a joke.
+max_new_tokens: 200 +temperature: 0.6 # 0.8 and 0.6 are popular values to try +top_k: 300 diff --git a/recipes/configs/llama3_1/70B_generation_distributed.yaml b/recipes/configs/llama3_1/70B_generation_distributed.yaml new file mode 100644 index 0000000000..d71a94f8de --- /dev/null +++ b/recipes/configs/llama3_1/70B_generation_distributed.yaml @@ -0,0 +1,50 @@ +# Config for running the InferenceRecipe in dev/generate_v2_distributed.py to generate output +# using a Llama3.1 70B Instruct model +# +# This config assumes that you've run the following command before launching: +# tune download meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir /tmp/Meta-Llama-3.1-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token +# +# To launch, run the following command from root torchtune directory: +# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_1/70B_generation_distributed + +output_dir: ./ + +# Model arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_70b + +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + +# Transform arguments +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model + prompt_template: null + max_seq_len: 8192 + +# Checkpointer +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00030" + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 + +# Device +device: cuda +dtype: bf16 +seed: 1234 +log_level: INFO + +# Generation arguments +prompt: + system: null + user: + text: Tell a joke. +max_new_tokens: 200 +temperature: 0.6 # 0.8 and 0.6 are popular values to try +top_k: 300 diff --git a/recipes/configs/llama3_2_vision/11B_generation_v2.yaml b/recipes/configs/llama3_2_vision/11B_generation_v2.yaml index c78e0e52b6..11fd14f8d1 100644 --- a/recipes/configs/llama3_2_vision/11B_generation_v2.yaml +++ b/recipes/configs/llama3_2_vision/11B_generation_v2.yaml @@ -7,7 +7,7 @@ # To launch, run the following command from root torchtune directory: # tune run dev/generate_v2 --config llama3_2_vision/generation_v2 -output_dir: ./ # Not needed +output_dir: ./ # Model arguments model: diff --git a/recipes/configs/llama3_3/70B_generation_distributed.yaml b/recipes/configs/llama3_3/70B_generation_distributed.yaml new file mode 100644 index 0000000000..d39acf45ad --- /dev/null +++ b/recipes/configs/llama3_3/70B_generation_distributed.yaml @@ -0,0 +1,50 @@ +# Config for running the InferenceRecipe in dev/generate_v2_distributed.py to generate output +# using a Llama3.3 70B Instruct model +# +# This config assumes that you've run the following command before launching: +# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token +# +# To launch, run the following command from root torchtune directory: +# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_3/70B_generation_distributed + +output_dir: ./ + +# Model arguments +model: + _component_: torchtune.models.llama3_3.llama3_3_70b + +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + +# Transform arguments +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model + prompt_template: null + max_seq_len: 8192 + +# Checkpointer +checkpointer: + _component_:
torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00030" + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 + +# Device +device: cuda +dtype: bf16 +seed: 1234 +log_level: INFO + +# Generation arguments +prompt: + system: null + user: + text: Tell a joke. +max_new_tokens: 200 +temperature: 0.6 # 0.8 and 0.6 are popular values to try +top_k: 300 diff --git a/recipes/dev/generate_v2.py b/recipes/dev/generate_v2.py index b94b616e47..66329e8367 100644 --- a/recipes/dev/generate_v2.py +++ b/recipes/dev/generate_v2.py @@ -39,18 +39,22 @@ def __call__(self, prompt: Dict[str, Any]) -> List[Message]: # Iterate through roles and add content for role, content in prompt.items(): - if isinstance(content, str): + if content is None: + continue + elif isinstance(content, str): new_content = [{"type": "text", "content": content}] - else: - assert ( - "image" in content.keys() - ), "Multiple entries per role expect an image key" + elif "image" in content.keys(): image_loc = content["image"] image = load_image(image_loc) new_content = [ {"type": "image", "content": image}, {"type": "text", "content": content["text"]}, ] + else: + assert ( + "text" in content.keys() + ), "Multiple entries per role expect at least a text key" + new_content = [{"type": "text", "content": content["text"]}] messages.append(Message(role=role, content=new_content)) # Finally, add an empty assistant message to kick-start generation @@ -109,12 +113,12 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None: f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec" ) self._logger.info( - f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s" + f"Bandwidth achieved: {model_size * tokens_per_second / (1024**3):.02f} GiB/s" ) if self._device.type != "cpu": torch_device = utils.get_torch_device_namespace() self._logger.info( - f"Max memory allocated: {torch_device.max_memory_allocated() / 1e9:.02f} GB" + f"Max memory allocated: {torch_device.max_memory_allocated() / (1024**3):.02f} GiB" ) @torch.inference_mode() diff --git a/recipes/dev/generate_v2_distributed.py b/recipes/dev/generate_v2_distributed.py new file mode 100644 index 0000000000..48a147bd15 --- /dev/null +++ b/recipes/dev/generate_v2_distributed.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import itertools +import sys +import time +from typing import Any, Dict, List + +import torch +import torch.distributed as dist +from omegaconf import DictConfig, OmegaConf +from torch.distributed.tensor.parallel import parallelize_module + +from torchtune import config, training, utils +from torchtune.data import load_image, Message, padded_collate_tiled_images_and_mask + +from torchtune.generation import sample + +from torchtune.modules.transforms import Transform + + +class SingleTurnYAMLToMessages(Transform): + """ + Converts a single turn conversation in YAML format to a list of messages. + + Expects the YAML to look like: + system: You are a helpful AI assistant. + user: What is the capital of France? + + or if it includes an image: + system: You are a helpful AI assistant. + user: + image: url or path_to_image + text: Describe the image in detail. 
+ """ + + def __call__(self, prompt: Dict[str, Any]) -> List[Message]: + messages = [] + + # Iterate through roles and add content + for role, content in prompt.items(): + if content is None: + continue + elif isinstance(content, str): + new_content = [{"type": "text", "content": content}] + elif "image" in content.keys(): + image_loc = content["image"] + image = load_image(image_loc) + new_content = [ + {"type": "image", "content": image}, + {"type": "text", "content": content["text"]}, + ] + else: + assert ( + "text" in content.keys() + ), "Multiple entries per role expect at least a text key" + new_content = [{"type": "text", "content": content["text"]}] + messages.append(Message(role=role, content=new_content)) + + # Finally, add an empty assistant message to kick-start generation + messages.append(Message(role="assistant", content="")) + return messages + + +class InferenceRecipe: + """ + Recipe for generating tokens from a dense Transformer-based LLM. + This works for text-only generation and image-text generation. + + Supports distributed inference using Tensor Paralellism(TP) for + large models that don't fit on a single GPU. For more information + on TP, see: https://pytorch.org/docs/stable/distributed.tensor.parallel.html. + + This *does not* currently support the following features: + - torch.compile + - quantization through torchao + - batch generation + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device) + self._logger = utils.get_logger(cfg.log_level) + # Set up distributed env + dist.init_process_group(backend="nccl") + _, rank = utils.get_world_size_and_rank() + self._is_rank_zero = rank == 0 + training.set_seed(seed=cfg.seed) + + def setup(self, cfg: DictConfig) -> None: + """Setup the model and transforms.""" + # Load checkpointer and state_dict + _checkpointer = config.instantiate(cfg.checkpointer) + _ckpt_dict = _checkpointer.load_checkpoint() + + # Instantiate model on meta device + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg.model) + + # Set up tensor parallel device mesh + tp_degree = dist.get_world_size() # Using all GPUs for TP + tp_mesh_shape = (tp_degree,) + tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape) + + # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor paralell + training.prepare_mha_for_tp(model, tp_device_mesh) + parallelize_module( + model, + tp_device_mesh, + parallelize_plan=config.instantiate(cfg.parallelize_plan), + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model=model, + full_sd=_ckpt_dict[training.MODEL_KEY], + device=self._device, + strict=True, + cpu_offload=False, + ) + + self.model = model + if self._is_rank_zero: + self._logger.info( + f"Model was initialized with precision {self._dtype} and TP degree {tp_degree}." 
+ ) + + # Instantiate transforms + self.model_transform = config.instantiate(cfg.tokenizer) + self.to_messages = SingleTurnYAMLToMessages() + + def log_metrics(self, total_time: int, tokens_per_second: float) -> None: + """Logs the following metrics: total time for inference, tokens/sec, + bandwidth achieved, and max memory allocated. + + Feel free to modify this function to log additional metrics. + """ + model_size = sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(self.model.parameters(), self.model.buffers()) + ] + ) + self._logger.info( + f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec" + ) + self._logger.info( + f"Bandwidth achieved: {model_size * tokens_per_second / (1024**3):.02f} GiB/s" + ) + if self._device.type != "cpu": + torch_device = utils.get_torch_device_namespace() + self._logger.info( + f"Max memory allocated: {torch_device.max_memory_allocated() / (1024**3):.02f} GiB" + ) + + @torch.inference_mode() + def generate(self, cfg: DictConfig): + """The main entry point for generating tokens from a prompt.""" + # 1. Convert input to messages + messages = self.to_messages(OmegaConf.to_container(cfg.prompt)) + is_multimodal_input = any([m.contains_media for m in messages]) + + # 2. Apply model transform + model_inputs = self.model_transform({"messages": messages}, inference=True) + seq_len = len(model_inputs["tokens"]) + total_response_length = seq_len + cfg.max_new_tokens + + # 3. Setup KV cache + with self._device: + self.model.setup_caches( + batch_size=1, + dtype=self._dtype, + encoder_max_seq_len=( + self.model_transform.image_seq_len if is_multimodal_input else None + ), + decoder_max_seq_len=total_response_length, + ) + + # 4. Pre-allocate causal mask and input_pos + causal_mask = torch.tril( + torch.ones( + size=(total_response_length, total_response_length), + dtype=torch.bool, + device=self._device, + ) + ) + input_pos = torch.arange(total_response_length) + + # 5. Collate to batch size of 1 and tensor-ify + batch = {} + if is_multimodal_input: + batch = padded_collate_tiled_images_and_mask( + [model_inputs], + pad_direction="left", + pad_max_images=1, + pad_max_tiles=self.model_transform.max_num_tiles, + ) + batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] + prompt = batch.pop("tokens").to(self._device) + else: + prompt = torch.tensor( + model_inputs["tokens"], device=self._device + ).unsqueeze(0) + batch["mask"] = causal_mask[None, :seq_len] + batch["input_pos"] = input_pos[None, :seq_len] + utils.batch_to_device(batch, self._device) + + # 6. Prefill step + generated_tokens = [] + t0 = time.perf_counter() + logits = self.model(prompt, **batch)[:, -1] + token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k) + generated_tokens.append(token.item()) + + if is_multimodal_input: + # Don't need image info b/c we only support 1 image and it's been + # processed by the model now + batch.pop("encoder_input") + batch["encoder_mask"] = batch["encoder_mask"][:, -1:] + + # 7. Continue generating + for i in range(cfg.max_new_tokens): + + # Update position and mask for incremental decoding + batch["input_pos"] = input_pos[None, seq_len] + batch["mask"] = causal_mask[None, seq_len, None, :] + + if token.item() in self.model_transform.stop_tokens: + break + + logits = self.model(token, **batch)[:, -1] + token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k) + generated_tokens.append(token.item()) + seq_len += 1 + + t = time.perf_counter() - t0 + + # 8. 
Translate tokens back to text + decoded = self.model_transform.decode(generated_tokens) + if self._is_rank_zero: + self._logger.info(f"\n\n{decoded}\n") + + # 9. Log metrics + tokens_per_second = len(generated_tokens) / t + if self._is_rank_zero: + self.log_metrics(total_time=t, tokens_per_second=tokens_per_second) + + +@config.parse +def main(cfg: DictConfig) -> None: + config.log_config(recipe_name="InferenceRecipe", cfg=cfg) + recipe = InferenceRecipe(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.generate(cfg=cfg) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/torchtune/training/test_distributed.py b/tests/torchtune/training/test_distributed.py index 2e5e16da9a..3fe2dd340d 100644 --- a/tests/torchtune/training/test_distributed.py +++ b/tests/torchtune/training/test_distributed.py @@ -10,20 +10,24 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn from packaging import version from tests.test_utils import gpu_test from torch.distributed import launcher - from torch.distributed._composable.fsdp import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointWrapper, ) +from torch.testing._internal.common_distributed import MultiProcessTestCase from torch.testing._internal.common_fsdp import FSDPTest, MLP from torchao.dtypes.nf4tensor import NF4Tensor from torchtune import modules, training from torchtune.models.llama2._component_builders import lora_llama2 -from torchtune.modules import TransformerSelfAttentionLayer +from torchtune.models.llama3_1._component_builders import llama3_mlp +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE +from torchtune.modules import RMSNorm, TransformerSelfAttentionLayer +from torchtune.modules.attention import MultiHeadAttention from torchtune.modules.peft import ( DoRALinear, get_adapter_params, @@ -379,3 +383,57 @@ def _broadcast_full_state_dict(self, full_sd): result.append(None) torch.distributed.broadcast_object_list(result, src=0) return result[0] + + +class TestTensorParallel(MultiProcessTestCase): + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_prepare_mha_for_tp(self) -> None: + """Test tensor parallelism preparation for multi-head attention.""" + # Create a device mesh for tensor parallelism + mesh = dist.init_device_mesh("cuda", mesh_shape=(2,)) + + # Parameters for TransformerSelfAttentionLayer + embed_dim = 64 + hidden_dim = 64 + num_heads = 4 + num_kv_heads = 4 + max_seq_len = 128 + rope_base = 500000 + head_dim = embed_dim // num_heads + rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=0.0, + ) + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + decoder_layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=1e-5), + mlp_norm=RMSNorm(dim=embed_dim, eps=1e-5), + ) + + orig_num_heads = self_attn.num_heads + orig_num_kv_heads = self_attn.num_kv_heads + orig_embed_dim = self_attn.embed_dim + + # Apply tensor parallelism preparation + decoder_layer =
training.prepare_mha_for_tp(decoder_layer, mesh) + + # Verify that parameters were scaled correctly + assert decoder_layer.attn.num_heads == orig_num_heads // 2 + assert decoder_layer.attn.num_kv_heads == orig_num_kv_heads // 2 + assert decoder_layer.attn.embed_dim == orig_embed_dim // 2 diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 4b504315a9..1c41519712 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -433,6 +433,25 @@ class Recipe: ], supports_distributed=False, ), + Recipe( + name="dev/generate_v2_distributed", + file_path="dev/generate_v2_distributed.py", + configs=[ + Config( + name="llama3/70B_generation_distributed", + file_path="llama3/70B_generation_distributed.yaml", + ), + Config( + name="llama3_1/70B_generation_distributed", + file_path="llama3_1/70B_generation_distributed.yaml", + ), + Config( + name="llama3_3/70B_generation_distributed", + file_path="llama3_3/70B_generation_distributed.yaml", + ), + ], + supports_distributed=True, + ), Recipe( name="dev/early_exit_finetune_distributed", file_path="dev/early_exit_finetune_distributed.py", diff --git a/torchtune/models/llama3/__init__.py b/torchtune/models/llama3/__init__.py index 90de8c286f..5cf4e6b616 100644 --- a/torchtune/models/llama3/__init__.py +++ b/torchtune/models/llama3/__init__.py @@ -15,6 +15,7 @@ qlora_llama3_70b, qlora_llama3_8b, ) +from ._parallelism import base_llama_tp_plan from ._tokenizer import Llama3Tokenizer __all__ = [ @@ -28,4 +29,5 @@ "lora_llama3_70b", "qlora_llama3_8b", "qlora_llama3_70b", + "base_llama_tp_plan", ] diff --git a/torchtune/models/llama3/_parallelism.py b/torchtune/models/llama3/_parallelism.py new file mode 100644 index 0000000000..6046f0e83c --- /dev/null +++ b/torchtune/models/llama3/_parallelism.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +from torch.distributed._tensor import Replicate +from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel +from torch.distributed.tensor.parallel.style import ParallelStyle + + +# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models +BASE_LLAMA_TP_PLAN = { + "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), + "output": ColwiseParallel(output_layouts=Replicate()), + "layers.*.attn.q_proj": ColwiseParallel(), + "layers.*.attn.k_proj": ColwiseParallel(), + "layers.*.attn.v_proj": ColwiseParallel(), + "layers.*.attn.output_proj": RowwiseParallel(), + "layers.*.mlp.w1": ColwiseParallel(), + "layers.*.mlp.w2": RowwiseParallel(), + "layers.*.mlp.w3": ColwiseParallel(), +} + + +def base_llama_tp_plan() -> Dict[str, ParallelStyle]: + """ + Helper function to get the base tensor parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models. + + Returns: + Dict[str, ParallelStyle]: The tensor parallel plan for Llama3 model.
+ """ + return BASE_LLAMA_TP_PLAN diff --git a/torchtune/models/llama3_2_vision/_model_builders.py b/torchtune/models/llama3_2_vision/_model_builders.py index f9da10095a..4f035f92c5 100644 --- a/torchtune/models/llama3_2_vision/_model_builders.py +++ b/torchtune/models/llama3_2_vision/_model_builders.py @@ -20,7 +20,6 @@ from torchtune.models.llama3_2_vision._transform import Llama3VisionTransform from torchtune.modules.model_fusion import DeepFusionModel from torchtune.modules.peft import LORA_ATTN_MODULES -from torchtune.modules.tokenizers import parse_hf_tokenizer_json def llama3_2_vision_transform( diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index 9dd31246c3..d461d84dc4 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -18,6 +18,7 @@ is_distributed, load_from_full_model_state_dict, load_from_full_optimizer_state_dict, + prepare_mha_for_tp, set_torch_num_threads, shard_model, validate_no_params_on_meta_device, @@ -74,6 +75,7 @@ __all__ = [ "get_act_offloading_ctx_manager", + "prepare_mha_for_tp", "apply_selective_activation_checkpointing", "get_dtype", "set_default_dtype", diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 4001db768b..ff959c5f23 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -25,11 +25,14 @@ set_optimizer_state_dict, StateDictOptions, ) +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import ShardingStrategy from torch.nn.modules.module import _IncompatibleKeys from torch.optim import Optimizer from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 from torchtune.modules import TransformerDecoder +from torchtune.modules.attention import MultiHeadAttention +from torchtune.modules.model_fusion import DeepFusionModel from torchtune.modules.peft import get_adapter_state_dict from torchtune.utils import get_device, get_logger from torchtune.utils._logging import deprecated @@ -201,7 +204,7 @@ def load_from_full_model_state_dict( for param in model.parameters() ) meta_sharded_sd = model.state_dict() - # NF4Tensor is not supported in `set_model_state_dict` right now, running with the privious logic right + # NF4Tensor is not supported in `set_model_state_dict` right now, running with the previous logic right # now, would support in the future and remove the following code if _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE and not has_nf4: for param_name in full_sd.keys(): @@ -546,3 +549,64 @@ def shard_model( # Finally shard the entire model to account for any stragglers fully_shard(model, **fsdp_kwargs) + + +def prepare_mha_for_tp( + model: nn.Module, + tp_mesh: DeviceMesh, +) -> nn.Module: + """ + Utility to scale MultiHeadAttention parameters(num_heads, num_kv_heads, embed_dim) across + tensor parallel devices. Each device will handle a portion of the attention computations. + + Args: + model (nn.Module): Model whose attention parameters will be scaled by TP size. + tp_mesh (DeviceMesh): Tensor parallel device mesh. + + Returns: + nn.Module: The model with scaled MultiHeadAttention parameters. + + Raises: + ValueError: If attention heads, kv heads, or embed dimension is not divisible by TP size. 
+ + Examples: + >>> from torchtune.modules import TransformerDecoder + >>> from torch.distributed.device_mesh import DeviceMesh + >>> model = TransformerDecoder( + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + ) + >>> tp_mesh = DeviceMesh("cuda", torch.arange(2)) # 2 GPUs + >>> model = prepare_mha_for_tp(model, tp_mesh) + >>> # Now each GPU has: + >>> # num_heads = 16 (32/2) + >>> # num_kv_heads = 16 (32/2) + >>> # embed_dim = 2048 (4096/2) + """ + # Consider the case of Deep Fusion models + if isinstance(model, DeepFusionModel): + model = model.decoder + tp_size = tp_mesh.size() + for m in list(model.modules()): + if isinstance(m, MultiHeadAttention): + # Adjust attention module to use the local number of heads + if m.num_heads % tp_size != 0: + raise ValueError( + f"Number of attention heads ({m.num_heads}) must be divisible by " + f"tensor parallel size ({tp_size})." + ) + if m.num_kv_heads % tp_size != 0: + raise ValueError( + f"Number of KV heads ({m.num_kv_heads}) must be divisible by " + f"tensor parallel size ({tp_size})." + ) + if m.embed_dim % tp_size != 0: + raise ValueError( + f"Embedding dimension ({m.embed_dim}) must be divisible by " + f"tensor parallel size ({tp_size})." + ) + m.num_heads = m.num_heads // tp_size + m.num_kv_heads = m.num_kv_heads // tp_size + m.embed_dim = m.embed_dim // tp_size + return model
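Taken together, the new helpers are used in a fixed order: training.prepare_mha_for_tp rescales the MultiHeadAttention bookkeeping (num_heads, num_kv_heads, embed_dim) to per-rank values, and parallelize_module then applies the Colwise/Rowwise plan returned by base_llama_tp_plan. The snippet below is a minimal sketch of that flow, not part of this change: the file name tp_sketch.py, the torchrun launch, and the use of the existing llama3_8b builder are assumptions chosen to keep the example small, and checkpoint loading (training.load_from_full_model_state_dict, as done in dev/generate_v2_distributed.py) is omitted.

# tp_sketch.py -- hypothetical example; launch with: torchrun --nproc_per_node 2 tp_sketch.py
import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import parallelize_module

from torchtune import training
from torchtune.models.llama3 import base_llama_tp_plan, llama3_8b

# One process per GPU; the mesh spans all ranks for pure tensor parallelism.
dist.init_process_group(backend="nccl")
tp_mesh = dist.init_device_mesh("cuda", (dist.get_world_size(),))

# Instantiate on meta device, rescale attention attributes to per-rank values,
# then apply the Colwise/Rowwise tensor parallel plan to the module tree.
with torch.device("meta"):
    model = llama3_8b()
model = training.prepare_mha_for_tp(model, tp_mesh)
parallelize_module(model, tp_mesh, parallelize_plan=base_llama_tp_plan())

# Real weights would now be loaded onto each rank via
# training.load_from_full_model_state_dict(...), exactly as the recipe does.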