Commit

README.rst for examples/distributed_inference/tensor_parallel_llama3
apbose committed Feb 14, 2025
1 parent 67115d5 commit 61f57c0
Showing 4 changed files with 80 additions and 49 deletions.
2 changes: 1 addition & 1 deletion docsrc/index.rst
@@ -68,7 +68,7 @@ Tutorials
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`
* :ref:`pre_allocated_output_example`
* :ref:`tensor_parallel_llama`
* :ref:`tensor_parallel_llama3`

.. toctree::
:caption: Tutorials
47 changes: 0 additions & 47 deletions examples/distributed_inference/README.md

This file was deleted.

77 changes: 77 additions & 0 deletions examples/distributed_inference/README.rst
@@ -0,0 +1,77 @@
Torch-TensorRT Parallelism for Distributed Inference
====================================================

Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend.

Data Parallel Distributed Inference based on `Accelerate <https://huggingface.co/docs/accelerate/usage_guides/distributed_inference>`_
------------------------------------------------------------------------------------------------------------------------------------------------

Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend.
In this case, the entire model is loaded onto each GPU, and a different chunk of the input batch is processed on each device.

See the following examples for more details:

- `data_parallel_gpt2.py <https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_gpt2.py>`_
- `data_parallel_stable_diffusion.py <https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_stable_diffusion.py>`_
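
The pattern looks roughly as follows. This is a minimal sketch assuming a GPT-2 model from ``transformers``; the prompts and compile options are illustrative rather than the exact settings of the linked scripts:

.. code-block:: python

    import torch
    import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" torch.compile backend)
    from accelerate import PartialState
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # One process per GPU; Accelerate assigns each process its device.
    distributed_state = PartialState()

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval().to(distributed_state.device)

    # The full model is replicated on every GPU; only the inputs are split.
    model.forward = torch.compile(
        model.forward,
        backend="torch_tensorrt",
        options={"enabled_precisions": {torch.float16}},
        dynamic=False,
    )

    prompts = ["Distributed inference lets us", "Torch-TensorRT accelerates"]
    with distributed_state.split_between_processes(prompts) as local_prompts:
        inputs = tokenizer(local_prompts, return_tensors="pt", padding=True).to(
            distributed_state.device
        )
        outputs = model.generate(**inputs, max_new_tokens=32)
        print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

Launched with, for example, ``accelerate launch --num_processes=2`` (or an equivalent ``torchrun`` invocation), each process generates completions for its own share of the prompts.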

Tensor Parallel Distributed Inference
--------------------------------------

Here, we use ``torch.distributed`` as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.

.. code-block:: bash

    torchrun --nproc_per_node=2 tensor_parallel_llama2.py
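
In outline, such a script shards the model with the PyTorch tensor parallel APIs and then compiles the sharded module with the Torch-TensorRT backend. The toy module and sharding plan below are illustrative stand-ins for the Llama layers used in ``tensor_parallel_llama2.py``:

.. code-block:: python

    import os

    import torch
    import torch.nn as nn
    import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" torch.compile backend)
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class FeedForward(nn.Module):
        """Toy two-layer MLP standing in for a transformer block."""

        def __init__(self, dim: int = 1024, hidden: int = 4096) -> None:
            super().__init__()
            self.w_in = nn.Linear(dim, hidden, bias=False)
            self.w_out = nn.Linear(hidden, dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w_out(torch.relu(self.w_in(x)))


    # torchrun sets WORLD_SIZE/LOCAL_RANK; one GPU per process.
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    mesh = init_device_mesh("cuda", (world_size,))

    model = FeedForward().eval().cuda()

    # Shard the first linear column-wise and the second row-wise, so a single
    # all-reduce recombines the partial results.
    model = parallelize_module(
        model,
        mesh,
        {"w_in": ColwiseParallel(), "w_out": RowwiseParallel()},
    )

    # Compile the already-sharded module with the Torch-TensorRT backend.
    model = torch.compile(model, backend="torch_tensorrt", dynamic=False)

    x = torch.randn(2, 1024, device="cuda")
    with torch.no_grad():
        out = model(x)

Run with the ``torchrun`` command above; each rank then holds only its shard of the model weights.
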
Tensor Parallel Distributed Inference on a Simple Model using NCCL Ops Plugin
------------------------------------------------------------------------------

We use `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ to shard the model with tensor parallelism.
The distributed operations (``all_gather`` and ``all_reduce``) are then expressed as TensorRT-LLM plugins to avoid graph breaks during Torch-TensorRT compilation.
The `converters for these operators <https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py#L25-L55>`_ are already available in Torch-TensorRT.
The functional implementations of these ops are imported from the ``tensorrt_llm`` package (specifically, ``libnvinfer_plugin_tensorrt_llm.so`` is required).

We have two options:

Option 1: Install TensorRT-LLM
-------------------------------

Follow the instructions to `install TensorRT-LLM <https://nvidia.github.io/TensorRT-LLM/installation/linux.html>`_.

If the default installation fails due to issues like library version mismatches or Python compatibility, consider using Option 2.
After a successful installation, test by running:

.. code-block:: python

    import torch_tensorrt

to ensure it works without errors.

The import might fail if ``tensorrt_llm`` overrides ``torch_tensorrt`` dependencies.
Option 2 is preferable if you do not wish to install ``tensorrt_llm`` and its dependencies.

Option 2: Link the TensorRT-LLM Plugin Library Directly
---------------------------------------------------------

Alternatively, you can load ``libnvinfer_plugin_tensorrt_llm.so`` manually:

1. Download the `tensorrt_llm-0.16.0 <https://pypi.nvidia.com/tensorrt-llm/tensorrt_llm-0.16.0-cp310-cp310-linux_x86_64.whl#sha256=f86c6b89647802f49b26b4f6e40824701da14c0f053dbda3e1e7a8709d6939c7>`_ wheel file from NVIDIA's Python index.
2. Extract the wheel file to a directory and locate ``libnvinfer_plugin_tensorrt_llm.so`` under the ``tensorrt_llm/libs`` directory.
3. Set the environment variable ``TRTLLM_PLUGINS_PATH`` to the extracted path so that it is picked up by the `initialize_distributed_env() <https://github.com/pytorch/TensorRT/blob/54e36dbafe567c75f36b3edb22d6f49d4278c12a/examples/distributed_inference/tensor_parallel_initialize_dist.py#L45>`_ call, as sketched below.
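
A minimal sketch of these steps in Python (the paths are placeholders, and pointing ``TRTLLM_PLUGINS_PATH`` at the ``.so`` file itself is an assumption; adjust to match your setup):

.. code-block:: python

    import os
    import zipfile
    from urllib.request import urlretrieve

    # Placeholder locations; adjust to your environment.
    WHEEL_URL = (
        "https://pypi.nvidia.com/tensorrt-llm/"
        "tensorrt_llm-0.16.0-cp310-cp310-linux_x86_64.whl"
    )
    wheel_path = "/tmp/tensorrt_llm-0.16.0-cp310-cp310-linux_x86_64.whl"
    extract_dir = "/tmp/tensorrt_llm_wheel"

    # 1. Download the wheel (a wheel is a zip archive).
    urlretrieve(WHEEL_URL, wheel_path)

    # 2. Extract it and locate the plugin library.
    with zipfile.ZipFile(wheel_path) as whl:
        whl.extractall(extract_dir)
    plugin_lib = os.path.join(
        extract_dir, "tensorrt_llm", "libs", "libnvinfer_plugin_tensorrt_llm.so"
    )

    # 3. Expose the path before initialize_distributed_env() is called
    #    (assumed here to point at the .so itself).
    os.environ["TRTLLM_PLUGINS_PATH"] = plugin_lib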

After configuring TensorRT-LLM or the TensorRT-LLM plugin library path, run the following command to shard a simple model with tensor parallelism and compile it with Torch-TensorRT:

.. code-block:: bash

    mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py

We also provide a tensor parallelism compilation example for a more advanced model, Llama-3. Run the following command:

.. code-block:: bash

    mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py

Tutorials
-----------------------------------------
* :ref:`tensor_parallel_llama3`: Illustration of distributed inference on multiple devices with the Torch-TensorRT backend.
3 changes: 2 additions & 1 deletion examples/distributed_inference/tensor_parallel_llama3.py
@@ -1,5 +1,5 @@
"""
.. _tensor_parallel_llama:
.. _tensor_parallel_llama3:
Torch distributed example for llama3-7B model
======================================================
@@ -16,6 +16,7 @@
import time

import torch
import torch_tensorrt

# %%
# PyTorch Tensor Parallel APIs offer a set of module-level primitives (ParallelStyle) to configure the sharding of tensors in each layer of the model
