fix: structured inputs for CudaGraphsTorchTensorRTModule #3407

Open · wants to merge 2 commits into base: main
79 changes: 60 additions & 19 deletions py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
@@ -1,16 +1,48 @@
from __future__ import annotations

import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple

import torch
import torch_tensorrt
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from torch_tensorrt.dynamo import partitioning

logger = logging.getLogger(__name__)


+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Process inputs using tree_unflatten and tree_map to reconstruct inputs
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def convert_input_to_cuda_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Apply the tensor creation to the reconstructed structure
+    processed_inputs = tree_map(convert_input_to_cuda_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]


class CudaGraphsTorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
"""This Wrapper runtime module is to record/replay whole cuda graph in sub modules

@@ -42,14 +74,15 @@ def warm_up(self) -> None:
Warm up is necessary to ensure that memory allocations and initializations
are not recorded in cuda graphs
"""

with torch_tensorrt.logging.errors():
with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
torch.cuda.current_stream().wait_stream(s)

def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -73,15 +106,18 @@ def __del__(self) -> None:
if self.cudagraph:
self.cudagraph.reset()

-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        inputs, _ = tree_flatten((args, kwargs))
cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
if cudagraphs_enabled:
shape_changed = self.validate_input_shapes(inputs)
need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
if need_cudagraphs_record:
if self.cudagraph:
self.cudagraph.reset()
-                self._input_buffers = [None] * len(self.inputs)
+                self._input_buffers = [None] * len(inputs)

self.is_weight_streaming_set = False
# Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -94,10 +130,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
for i in inputs
]
assert len(contiguous_inputs) == len(
-                self.inputs
-            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+                inputs
+            ), f"Wrong number of inputs, expect {len(inputs)} get {len(contiguous_inputs)}."

-            for i, _ in enumerate(self.inputs):
+            for i, _ in enumerate(inputs):
if not contiguous_inputs[i].is_cuda:
logger.warning(
f"Detected input[{i}] is not on a cuda device. "
@@ -112,15 +148,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
)

assert (
-                    contiguous_inputs[i].dtype == self.inputs[i].dtype
-                ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."

-                if need_cudagraphs_record:
-                    # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
-                    # Clone is required to avoid re-using user-provided GPU memory
-                    self._input_buffers[i] = contiguous_inputs[i].clone()
-                else:
-                    self._input_buffers[i].copy_(contiguous_inputs[i])

+                if need_cudagraphs_record:
+                    # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
+                    # Clone is required to avoid re-using user-provided GPU memory
+                    self._input_buffers[i] = contiguous_inputs[i].clone()
+                else:
+                    self._input_buffers[i].copy_(contiguous_inputs[i])
+            # Reconstruct the original args and kwargs structure from static input buffers
+            # using the input specification stored during module compilation
+            args, kwargs = tree_unflatten(
+                self._input_buffers, self.compiled_module._in_spec
+            )

self._caller_stream = torch.cuda.current_stream()
if (
@@ -135,9 +178,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
if need_cudagraphs_record:
self.cudagraph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
-                    self._output_buffers = self.compiled_module(
-                        *self._input_buffers
-                    )
+                    self._output_buffers = self.compiled_module(*args, **kwargs)

self.cudagraph.replay() # type: ignore
self._caller_stream.wait_stream(self._engine_stream)
@@ -154,4 +195,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
if self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None
-        return self.compiled_module(*inputs)
+        return self.compiled_module(*args, **kwargs)
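For context on the mechanism this diff relies on: torch.export records a TreeSpec (stored on the GraphModule as _in_spec) describing how the original (args, kwargs) were flattened. Below is a minimal, illustrative sketch of that round-trip using only public torch.utils._pytree APIs; the tensor values are arbitrary stand-ins, not taken from the PR.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

args = (torch.randn(8, 100),)
kwargs = {"b": torch.randn(1, 128), "e": [torch.randn(1), torch.randn(1)]}

# Flatten the structured (args, kwargs) into a flat list of tensor leaves
# plus a TreeSpec describing the container structure.
leaves, in_spec = tree_flatten((args, kwargs))

# tree_unflatten restores the exact (args, kwargs) nesting from the leaves;
# forward() above does the same with its static input buffers before replay.
new_args, new_kwargs = tree_unflatten(leaves, in_spec)
assert torch.equal(new_args[0], args[0])
assert torch.equal(new_kwargs["e"][1], kwargs["e"][1])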
82 changes: 50 additions & 32 deletions tests/py/dynamo/runtime/test_004_weight_streaming.py
@@ -6,6 +6,7 @@
import torch_tensorrt as torchtrt
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase, run_tests
+from torch_tensorrt.dynamo.utils import prepare_inputs

INPUT_SIZE = (64, 100)

@@ -302,45 +303,62 @@ def __init__(self):
self.layer2 = torch.nn.Linear(128, 64)
self.relu = torch.nn.ReLU()

-            def forward(self, x):
+            def forward(self, x, b=None, c=None, d=None, e=[]):
                out = self.layer1(x)
+                out = out + b
+                if c is not None:
+                    out = out * c
                out = self.relu((out + 2.0) * 0.05)
+                if d is not None:
+                    out = out - d["value"] + d["value2"]
                out = self.layer2(out)
+                for n in e:
+                    out += n
                return out

-        inputs = torchtrt.Input(
-            min_shape=(1, 100),
-            opt_shape=(64, 100),
-            max_shape=(128, 100),
-            dtype=torch.float,
-            name="x",
-        )
model = SampleModel().eval().cuda()
input_list = []
-        input_list.append(torch.randn((8, 100)).cuda())
-        input_list.append(torch.randn((12, 100)).cuda())
-        input_list.append(torch.randn((12, 100)).cuda())
-        input_list.append(torch.randn((8, 100)).cuda())
-        input_list.append(torch.randn((8, 100)).cuda())

-        dynamic_shapes = (
-            {
-                0: torch.export.Dim("batch_size", min=1, max=128),
-            },
-        )
-        exp_program = torch.export.export(
-            model, (input_list[0],), dynamic_shapes=dynamic_shapes
-        )

+        for batch_size in [8, 12, 12, 8, 8]:
+            args = [torch.rand((batch_size, 100)).to("cuda")]
+            kwargs = {
+                "b": torch.rand((1, 128)).to("cuda"),
+                "d": {
+                    "value": torch.rand(1).to("cuda"),
+                    "value2": torch.tensor(1.2).to("cuda"),
+                },
+                "e": [torch.rand(1).to("cuda"), torch.rand(1).to("cuda")],
+            }
+            input_list.append((args, kwargs))

+        kwarg_torchtrt_input = prepare_inputs(input_list[0][1])

+        compile_spec = {
+            "inputs": [
+                torchtrt.Input(
+                    min_shape=(1, 100),
+                    opt_shape=(64, 100),
+                    max_shape=(128, 100),
+                    dtype=torch.float32,
+                    name="x",
+                ),
+            ],
+            "kwarg_inputs": kwarg_torchtrt_input,
+            "device": torchtrt.Device("cuda:0"),
+            "enabled_precisions": {torch.float},
+            "pass_through_build_failures": True,
+            "min_block_size": 1,
+            "ir": "dynamo",
+            "cache_built_engines": False,
+            "reuse_cached_engines": False,
+            "use_explicit_typing": True,
+            "enable_weight_streaming": True,
+            "torch_executed_ops": {"torch.ops.aten.mul.Tensor"},
+            "use_python_runtime": use_python_runtime,
+        }
+        exp_program = torchtrt.dynamo.trace(model, **compile_spec)
optimized_model = torchtrt.dynamo.compile(
exp_program,
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-            use_explicit_typing=True,
-            enable_weight_streaming=True,
-            torch_executed_ops={"torch.ops.aten.mul.Tensor"},
-            use_python_runtime=use_python_runtime,
+            **compile_spec,
)

# List of tuples representing different configurations for three features:
@@ -361,12 +379,12 @@ def test_trt_model(enable_weight_streaming, optimized_model, input_list):
for i in range(len(input_list)):
if enable_weight_streaming and i == 4:
weight_streaming_ctx.device_budget = int(streamable_budget * 0.6)
-                out_list.append(optimized_model(input_list[i]))
+                out_list.append(optimized_model(*input_list[i][0], **input_list[i][1]))
return out_list

ref_out_list = []
for i in range(len(input_list)):
-            ref_out_list.append(model(input_list[i]))
+            ref_out_list.append(model(*input_list[i][0], **input_list[i][1]))

pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs(
optimized_model
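To see the new calling convention end to end, here is a hedged usage sketch (not part of the diff). It assumes model and optimized_model were built as in the test above, and uses the torch_tensorrt.runtime.enable_cudagraphs context manager, the documented way to wrap a compiled module for CUDA graph capture and replay.

import torch
import torch_tensorrt as torchtrt

args = [torch.rand((8, 100), device="cuda")]
kwargs = {
    "b": torch.rand((1, 128), device="cuda"),
    "d": {
        "value": torch.rand(1, device="cuda"),
        "value2": torch.tensor(1.2, device="cuda"),
    },
    "e": [torch.rand(1, device="cuda"), torch.rand(1, device="cuda")],
}

# The wrapper flattens (args, kwargs) into a leaf list, records a CUDA graph
# on the first call (or when input shapes change), and replays it afterwards.
with torchtrt.runtime.enable_cudagraphs(optimized_model) as cudagraphs_module:
    out = cudagraphs_module(*args, **kwargs)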