Tensor parallel documentation #3359

Open · wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -68,6 +68,7 @@ Tutorials
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`
* :ref:`pre_allocated_output_example`
* :ref:`tensor_parallel_llama3`

.. toctree::
:caption: Tutorials
@@ -87,6 +88,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/weight_streaming_example
tutorials/_rendered_examples/dynamo/pre_allocated_output_example
tutorials/_rendered_examples/distributed_inference/tensor_parallel_llama3

Dynamo Frontend
----------------
50 changes: 0 additions & 50 deletions examples/distributed_inference/README.md

This file was deleted.

83 changes: 83 additions & 0 deletions examples/distributed_inference/README.rst
@@ -0,0 +1,83 @@
.. _tensor_parallel_llama3:

Torch-TensorRT Parallelism for Distributed Inference
====================================================

Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend.

Data Parallel Distributed Inference based on `Accelerate <https://huggingface.co/docs/accelerate/usage_guides/distributed_inference>`_
-----------------------------------------------------------------------------------------------------------------------------------------

Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend.
In this case, the entire model is loaded onto each GPU, and each device processes a different chunk of the input batch.

See the following examples for more details:

- `data_parallel_gpt2.py <https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_gpt2.py>`_
- `data_parallel_stable_diffusion.py <https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_stable_diffusion.py>`_
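
The pattern used in those examples looks roughly like the sketch below. It is only an illustration and assumes Hugging Face Accelerate and Transformers are installed; the model name, prompts, and compile settings are placeholders rather than the exact code from the linked scripts.

.. code-block:: python

    import torch
    import torch_tensorrt  # registers the "torch_tensorrt" backend for torch.compile
    from accelerate import PartialState
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Each process owns one GPU and holds a full copy of the model.
    state = PartialState()
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval().to(state.device)
    model.forward = torch.compile(model.forward, backend="torch_tensorrt", dynamic=False)

    # Accelerate splits the batch of prompts across the processes.
    prompts = ["GPT2 is", "TensorRT is", "Distributed inference is", "Tensor parallelism is"]
    with state.split_between_processes(prompts) as local_prompts:
        for prompt in local_prompts:
            inputs = tokenizer(prompt, return_tensors="pt").to(state.device)
            outputs = model.generate(**inputs, max_new_tokens=16)
            print(f"rank {state.process_index}: {tokenizer.decode(outputs[0])}")

Such a script is typically launched with ``accelerate launch --num_processes 2 script.py`` or ``torchrun --nproc_per_node=2 script.py``.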

Tensor Parallel Distributed Inference
--------------------------------------

Here, we use `torch.distributed` as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.

.. code-block:: bash

torchrun --nproc_per_node=2 tensor_parallel_llama2.py
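
To make "properly sharded" concrete, the sketch below shards a toy MLP with PyTorch's tensor-parallel API and compiles it with the ``torch_tensorrt`` backend of ``torch.compile``. It is a simplified illustration, not the shipped example, which defines its own model and parallelization plan.

.. code-block:: python

    import os

    import torch
    import torch.nn as nn
    import torch_tensorrt  # registers the "torch_tensorrt" backend for torch.compile
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class ToyMLP(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.w1 = nn.Linear(1024, 4096, bias=False)
            self.w2 = nn.Linear(4096, 1024, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(torch.relu(self.w1(x)))


    # torchrun sets WORLD_SIZE/RANK/LOCAL_RANK; one GPU per process is assumed.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

    # Shard w1 column-wise and w2 row-wise so the intermediate activation stays sharded.
    model = parallelize_module(
        ToyMLP().to("cuda"),
        mesh,
        {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )

    compiled = torch.compile(model, backend="torch_tensorrt", dynamic=False)
    out = compiled(torch.randn(8, 1024, device="cuda"))

Launch it with ``torchrun --nproc_per_node=2 your_script.py``. The communication ops that the sharding inserts (`all_reduce`, `all_gather`) are exactly what the NCCL plugin discussed in the next section handles without graph breaks.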

Tensor Parallel Distributed Inference on a Simple Model using NCCL Ops Plugin
------------------------------------------------------------------------------

We use `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ to shard the model with Tensor parallelism.
The distributed operations (`all_gather` and `all_reduce`) are then expressed as TensorRT-LLM plugins to avoid graph breaks during Torch-TensorRT compilation.
The `converters for these operators <https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py#L25-L55>`_ are already available in Torch-TensorRT.
The functional implementation of ops is imported from the `tensorrt_llm` package (specifically, `libnvinfer_plugin_tensorrt_llm.so` is required).

We have two options:

Option 1: Install TensorRT-LLM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[Review comment, Collaborator]: Let's only recommend Option 2 at this point, with the fetching tool you are making.

Follow the instructions to `install TensorRT-LLM <https://nvidia.github.io/TensorRT-LLM/installation/linux.html>`_.
Please note that before installing TensorRT-LLM, you need to install the MPI development packages (see the sketch after this list):

1. ``apt install libmpich-dev``
2. ``apt install libopenmpi-dev``
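
For reference, a typical command sequence looks like the sketch below; treat the linked installation guide as the source of truth for the exact package name, version, and index URL.

.. code-block:: bash

    # Illustrative only -- follow the linked TensorRT-LLM guide for the authoritative steps.
    apt install -y libmpich-dev libopenmpi-dev
    pip3 install tensorrt_llm==0.16.0 --extra-index-url https://pypi.nvidia.com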

After a successful installation, verify that the package imports cleanly:

.. code-block:: python

    import torch_tensorrt

The import might fail if `tensorrt_llm` overrides `torch_tensorrt` dependencies.
If the default installation fails due to issues such as library version mismatches or Python incompatibility, or if you prefer not to install `tensorrt_llm` and its dependencies, use Option 2 instead.

Option 2: Link the TensorRT-LLM Plugin Directly
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Alternatively, you can load `libnvinfer_plugin_tensorrt_llm.so` manually (a sketch of these steps follows the list):

1. Download the `tensorrt_llm-0.16.0 <https://pypi.nvidia.com/tensorrt-llm/tensorrt_llm-0.16.0-cp310-cp310-linux_x86_64.whl#sha256=f86c6b89647802f49b26b4f6e40824701da14c0f053dbda3e1e7a8709d6939c7>`_ wheel file from NVIDIA's Python index.
2. Extract the wheel file and locate `libnvinfer_plugin_tensorrt_llm.so` under the `tensorrt_llm/libs` directory.
3. Set the environment variable `TRTLLM_PLUGINS_PATH` to the extracted path before the `initialize_distributed_env() <https://github.com/pytorch/TensorRT/blob/54e36dbafe567c75f36b3edb22d6f49d4278c12a/examples/distributed_inference/tensor_parallel_initialize_dist.py#L45>`_ call.
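
A sketch of these steps is shown below. The download directory and wheel filename are illustrative, and whether `TRTLLM_PLUGINS_PATH` should hold the `libs` directory or the full path to the `.so` is determined by `initialize_distributed_env()`; the full path is assumed here.

.. code-block:: bash

    # Illustrative paths; adjust the wheel name for your Python version and platform.
    pip3 download tensorrt_llm==0.16.0 --no-deps \
        --extra-index-url https://pypi.nvidia.com -d /tmp/trtllm
    unzip /tmp/trtllm/tensorrt_llm-0.16.0-*.whl -d /tmp/trtllm/unpacked
    export TRTLLM_PLUGINS_PATH=/tmp/trtllm/unpacked/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so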

After configuring TensorRT-LLM or the TensorRT-LLM plugin library path, run the following command, which shards a simple model with tensor parallelism and compiles it with Torch-TensorRT:

.. code-block:: bash

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py

We also provide a tensor parallel compilation example for a more advanced model, `Llama-3`. Run the following command:

.. code-block:: bash

mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py

Tutorials
-----------------------------------------
* :ref:`tensor_parallel_llama3`: Illustration of distributed inference on multiple devices with the Torch-TensorRT backend.
4 changes: 4 additions & 0 deletions examples/distributed_inference/llama3_model.py
@@ -1,3 +1,7 @@
"""
This file contains the Llama3 model example used for tensor parallel distribution
"""

# Taken and modified from PyTorch Lightning:
# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning

examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -1,3 +1,9 @@
"""
This script contains utility functions for Tensor Parallelism
using Torch-TensorRT. It sets up the necessary communication protocols,
environments and partitions the model across multiple GPUs.
"""

import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
92 changes: 88 additions & 4 deletions examples/distributed_inference/tensor_parallel_llama3.py
@@ -1,10 +1,25 @@
# Taken and modified from PyTorch Lightning:
# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
"""
.. _tensor_parallel_llama3:

Torch distributed example for the Llama3-7B model
==================================================

As model sizes grow, models with billions of parameters are trained on many GPUs, and regular data parallel training is no longer feasible. In this example, we illustrate inference of the Llama3-7B model with the Torch-TensorRT backend, split across multiple GPUs using a form of model parallelism called tensor parallelism. We make use of the PyTorch distributed tensor parallel APIs. Please refer to these tutorials: https://pytorch.org/tutorials/intermediate/TP_tutorial.html and https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning?section=featured
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import logging
import os
import time

import torch
import torch_tensorrt

# PyTorch tensor parallel APIs offer a set of module-level primitives (ParallelStyle) to configure the sharding of tensors in each layer of the model.
# ParallelTransformer creates the parallelize_plan for the FeedForward layer of the model.
from llama3_model import ModelArgs, ParallelTransformer
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
@@ -14,11 +29,26 @@
checkpoint_wrapper,
)

# %%
# Initialize the distributed environment
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The following steps are performed:
#
# - Initialize the communicators and the distributed environment
# - Set the path to the `TRT-LLM` plugin `.so` file, which is required for the NCCL operations in the Torch-TRT backend.
# - Initialize the logger:
#
# - Example: In a 2-GPU setup, the log files will be:
# - `./tensor_parallel_llama3_0.log`
# - `./tensor_parallel_llama3_1.log`
#
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_llama3"
)
# This import must come after the TRT-LLM plugin .so path has been initialized
import tensorrt_llm

# %%
# Model initialization with torch distributed parallel plan
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
@@ -36,7 +66,59 @@
)

with torch.no_grad():
# The parallelization plan is:
# plan = {
# "attention": PrepareModuleInput(
# input_layouts=(Shard(1), None),
# desired_input_layouts=(Replicate(), None),
# ),
# "attention.wq": ColwiseParallel(),
# "attention.wk": ColwiseParallel(),
# "attention.wv": ColwiseParallel(),
# "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
# "attention_norm": SequenceParallel(),
# "feed_forward": PrepareModuleInput(
# input_layouts=(Shard(1),),
# desired_input_layouts=(Replicate(),),
# ),
# "feed_forward.w1": ColwiseParallel(),
# "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
# "feed_forward.w3": ColwiseParallel(),
# "ffn_norm": SequenceParallel(),
# }

model = ParallelTransformer(model_args, device_mesh)

# %%
# Model inference with Torch-TensorRT backend
# -------------------------------------------
# When we compile the distributed model using the **Torch-TensorRT** backend, PyTorch's distributed libraries:
#
# - Create the **sharded model** across multiple GPUs.
# - Use **communicator operations** to ensure proper communication.
#
# The following components manage different aspects of parallelism:
#
# - **`ColwiseParallel`** and **`RowwiseParallel`**:
#   - Shard the attention and feed-forward linear layers in a **column-wise** or **row-wise** fashion.
#
# - **`SequenceParallel`**:
# - Performs **sharded computations** of the normalization layer.
#
# - **`PrepareModuleInput`**:
# - Configures the model input with proper **communication operations**.
#
# **NCCL Operations in TensorRT-LLM:**
#
# - The **TensorRT-LLM NCCL plugins** handle distributed backend NCCL operations, preventing **graph breaks**.
# - Depending on the **DTensor sharding layout**, proper **communication operations** are required to transform the DTensor layout.
#
# **Common NCCL Operations Used:**
#
# - `allreduce`
# - `allgather`
# - `reduce_scatter`
#
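# A deterministic sample input is generated below, and the eager-mode output is kept as the reference for checking the compiled model's accuracy.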
torch.manual_seed(0)
inp = torch.randint(32000, (8, 256), device="cuda")
python_result = model(inp)
@@ -62,9 +144,11 @@
output = model(inp)
end = time.time()
if i == 0:
# Log the compilation time
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
# Log the inference time
logger.info(f"Inference time is {end-start}")
examples/distributed_inference/tensor_parallel_simple_example.py
@@ -1,3 +1,7 @@
"""
This file contains the Tensor parallel simple model example used for tensor parallel distribution
"""

import time

import tensorrt as trt
@@ -15,7 +19,6 @@
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
)
import tensorrt_llm

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py