TensorRT-LLM import fix and aot_joint_export specify as explicit setting in dynamo.compile
apbose committed Feb 25, 2025
1 parent 7ab637e commit 3e38e87
Showing 2 changed files with 92 additions and 46 deletions.
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -96,6 +96,7 @@ def cross_compile_for_windows(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -169,6 +170,7 @@ def cross_compile_for_windows(
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
use_aot_joint_export (bool): Use aot_export_joint_simple; otherwise the backend is wrapped with aot_autograd, which is required for distributed tensors
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module; when run, it will execute via TensorRT
@@ -326,6 +328,7 @@ def cross_compile_for_windows(
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": True,
"enable_weight_streaming": enable_weight_streaming,
"use_aot_joint_export": use_aot_joint_export,
}

# Disable the following settings, as they are not supported for the cross compilation for windows feature
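A usage sketch (not part of this commit): assuming cross_compile_for_windows is importable from torch_tensorrt.dynamo and accepts inputs the way compile does, the new flag is forwarded like any other compilation setting; the model and inputs below are hypothetical.

import torch
import torch_tensorrt

# Hypothetical toy model, for illustration only
model = torch.nn.Linear(16, 4).eval().cuda()
inputs = [torch.randn(2, 16).cuda()]
exp_program = torch.export.export(model, tuple(inputs))

# Build on Linux for inference on Windows; use_aot_joint_export defaults
# to _defaults.USE_AOT_JOINT_EXPORT and is passed through to
# CompilationSettings along with the other options.
trt_gm = torch_tensorrt.dynamo.cross_compile_for_windows(
    exp_program,
    inputs=inputs,
    use_aot_joint_export=True,
)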
@@ -413,6 +416,7 @@ def compile(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -488,6 +492,7 @@ def compile(
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
use_aot_joint_export (bool): Use aot_export_joint_simple; otherwise the backend is wrapped with aot_autograd, which is required for distributed tensors
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module; when run, it will execute via TensorRT
@@ -662,6 +667,7 @@ def compile(
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": False,
"enable_weight_streaming": enable_weight_streaming,
"use_aot_joint_export": use_aot_joint_export,
}

settings = CompilationSettings(**compilation_options)
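For the main entry point, a minimal sketch of how the new setting reaches dynamo.compile; the model and inputs are again hypothetical, and use_aot_joint_export=False is the case the docstring calls out for distributed tensors.

import torch
import torch_tensorrt

model = torch.nn.Linear(16, 4).eval().cuda()  # hypothetical example model
inputs = [torch.randn(2, 16).cuda()]
exp_program = torch.export.export(model, tuple(inputs))

# With use_aot_joint_export=False, the backend is wrapped with aot_autograd
# instead of using aot_export_joint_simple, which the docstring notes is
# required for distributed tensors.
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    use_aot_joint_export=False,
)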
@@ -950,6 +956,7 @@ def convert_exported_program_to_serialized_trt_engine(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1013,6 +1020,7 @@ def convert_exported_program_to_serialized_trt_engine(
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
use_aot_joint_export (bool): Use aot_export_joint_simple; otherwise the backend is wrapped with aot_autograd, which is required for distributed tensors
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
@@ -1129,6 +1137,7 @@ def convert_exported_program_to_serialized_trt_engine(
"strip_engine_weights": strip_engine_weights,
"immutable_weights": immutable_weights,
"enable_weight_streaming": enable_weight_streaming,
"use_aot_joint_export": use_aot_joint_export,
}

settings = CompilationSettings(**compilation_options)
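The serialized-engine path takes the same flag; a short sketch, assuming the function is re-exported under torch_tensorrt.dynamo and that the output file name is up to the caller.

import torch
import torch_tensorrt

model = torch.nn.Linear(16, 4).eval().cuda()  # illustrative model
inputs = [torch.randn(2, 16).cuda()]
exp_program = torch.export.export(model, tuple(inputs))

# Returns raw serialized engine bytes; use_aot_joint_export is forwarded
# into CompilationSettings just as in compile().
engine_bytes = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exp_program,
    inputs=inputs,
    use_aot_joint_export=True,
)
with open("model.engine", "wb") as f:  # hypothetical output path
    f.write(engine_bytes)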
129 changes: 83 additions & 46 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -3,6 +3,8 @@
import functools
import logging
import os
import subprocess
import sys
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
@@ -12,6 +14,7 @@
from torch.fx.node import Argument, Target
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt import _enums
from torch_tensorrt._enums import Platform
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -930,57 +933,91 @@ def load_tensorrt_llm() -> bool:
Returns:
bool: True if the plugin was successfully loaded and initialized, False otherwise.
"""
    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
    if not plugin_lib_path:
        _LOGGER.warning(
            "Please set TRTLLM_PLUGINS_PATH to the path of libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops, or set USE_TRTLLM_PLUGINS to download the shared library",
        )
        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
            "1",
            "true",
            "yes",
            "on",
        )
        if not use_trtllm_plugin:
            _LOGGER.warning(
                "Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled to download the shared library"
            )
            return False
        else:
            py_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
            platform = Platform.current_platform()
            if platform == Platform.LINUX_X86_64:
                platform = "linux_x86_64"
            elif platform == Platform.LINUX_AARCH64:
                platform = "linux_aarch64"

            if py_version not in ("cp310", "cp312"):
                _LOGGER.warning(
                    "No available wheel for python versions other than py3.10 and py3.12"
                )
            if py_version == "cp310" and platform == "linux_aarch64":
                _LOGGER.warning("No available wheel for python3.10 with Linux aarch64")

            base_url = "https://pypi.nvidia.com/tensorrt-llm/"
            file_name = (
                f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
            )
            download_url = base_url + file_name
            cmd = ["wget", download_url]
            subprocess.run(cmd)
            if os.path.exists(file_name):
                _LOGGER.info("TensorRT-LLM wheel download completed")
            import zipfile

            with zipfile.ZipFile(file_name, "r") as zip_ref:
                # Extract to a folder named 'tensorrt_llm'
                zip_ref.extractall("./tensorrt_llm")
            plugin_lib_path = "./tensorrt_llm/libnvinfer_plugin_tensorrt_llm.so"

    try:
        # Load the shared library
        handle = ctypes.CDLL(plugin_lib_path)
        _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
    except OSError as e_os_error:
        _LOGGER.error(
            f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
            f"Ensure the path is correct and the library is compatible",
            exc_info=e_os_error,
        )
        return False

    try:
        # Configure plugin initialization arguments
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as e_plugin_unavailable:
        _LOGGER.warning(
            "Unable to initialize the TensorRT-LLM plugin library",
            exc_info=e_plugin_unavailable,
        )
        return False

    try:
        # Initialize the plugin
        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
        if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
            _LOGGER.info("TensorRT-LLM plugin successfully initialized")
            return True
        else:
            _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
            return False
    except Exception as e_initialization_error:
        _LOGGER.warning(
            "Exception occurred during TensorRT-LLM plugin library initialization",
            exc_info=e_initialization_error,
        )
        return False
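A small usage sketch for the loader above, assuming the environment variables are set as the warnings describe; the plugin path is hypothetical.

import os

# Option 1 (assumed): point directly at an existing shared library
os.environ["TRTLLM_PLUGINS_PATH"] = "/opt/trtllm/libnvinfer_plugin_tensorrt_llm.so"
# Option 2 (assumed): let load_tensorrt_llm download the wheel instead
# os.environ["USE_TRTLLM_PLUGINS"] = "1"

from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm

if load_tensorrt_llm():
    print("TensorRT-LLM plugins initialized; torch.distributed converters are usable")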
