From 15df89f890d9934eca73caedadcded9e642dc2cb Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Mon, 24 Feb 2025 13:51:43 +0000
Subject: [PATCH 1/2] fix: structured inputs for CudaGraphsTorchTensorRTModule

---
 .../runtime/_CudaGraphsTorchTensorRTModule.py | 79 ++++++++++++++-----
 1 file changed, 60 insertions(+), 19 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
index e860c5762f..bb6e634853 100644
--- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
@@ -1,16 +1,48 @@
 from __future__ import annotations
 
 import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from torch_tensorrt.dynamo import partitioning
 
 logger = logging.getLogger(__name__)
 
 
+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Process inputs using tree_unflatten and tree_map to reconstruct the inputs
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def convert_input_to_cuda_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Apply the tensor creation to the reconstructed structure
+    processed_inputs = tree_map(convert_input_to_cuda_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]
+
+
 class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
 
@@ -42,14 +74,15 @@ def warm_up(self) -> None:
         Warm up is necessary to ensure that memory allocations and initializations
         are not recorded in cuda graphs
         """
+
         with torch_tensorrt.logging.errors():
             with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
                 s = torch.cuda.Stream()
                 s.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(s):
                     for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
                 torch.cuda.current_stream().wait_stream(s)
 
     def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -73,7 +106,10 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()
 
-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        inputs, _ = tree_flatten((args, kwargs))
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
@@ -81,7 +117,7 @@ def 
forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . if need_cudagraphs_record: if self.cudagraph: self.cudagraph.reset() - self._input_buffers = [None] * len(self.inputs) + self._input_buffers = [None] * len(inputs) self.is_weight_streaming_set = False # Ensure inputs are available in all scopes and cast symbolic integers to Tensors @@ -94,10 +130,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . for i in inputs ] assert len(contiguous_inputs) == len( - self.inputs - ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}." + inputs + ), f"Wrong number of inputs, expect {len(inputs)} get {len(contiguous_inputs)}." - for i, _ in enumerate(self.inputs): + for i, _ in enumerate(inputs): if not contiguous_inputs[i].is_cuda: logger.warning( f"Detected input[{i}] is not on a cuda device. " @@ -112,15 +148,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . ) assert ( - contiguous_inputs[i].dtype == self.inputs[i].dtype - ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}." + contiguous_inputs[i].dtype == inputs[i].dtype + ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}." + + if need_cudagraphs_record: + # If cudagraphs is enabled, this memory is reserved for future cudagraph runs + # Clone is required to avoid re-using user-provided GPU memory + self._input_buffers[i] = contiguous_inputs[i].clone() + else: + self._input_buffers[i].copy_(contiguous_inputs[i]) if need_cudagraphs_record: - # If cudagraphs is enabled, this memory is reserved for future cudagraph runs - # Clone is required to avoid re-using user-provided GPU memory - self._input_buffers[i] = contiguous_inputs[i].clone() - else: - self._input_buffers[i].copy_(contiguous_inputs[i]) + # Reconstruct the original args and kwargs structure from static input buffers + # using the input specification stored during module compilation + args, kwargs = tree_unflatten( + self._input_buffers, self.compiled_module._in_spec + ) self._caller_stream = torch.cuda.current_stream() if ( @@ -135,9 +178,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . if need_cudagraphs_record: self.cudagraph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.cudagraph, stream=self._engine_stream): - self._output_buffers = self.compiled_module( - *self._input_buffers - ) + self._output_buffers = self.compiled_module(*args, **kwargs) self.cudagraph.replay() # type: ignore self._caller_stream.wait_stream(self._engine_stream) @@ -154,4 +195,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
if self.cudagraph: self.cudagraph.reset() self.cudagraph = None - return self.compiled_module(*inputs) + return self.compiled_module(*args, **kwargs) From 21ab1b85b06bd241babb737c90873baa5bc63c67 Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Tue, 25 Feb 2025 05:39:03 +0000 Subject: [PATCH 2/2] chore: Update test case --- .../runtime/test_004_weight_streaming.py | 82 +++++++++++-------- 1 file changed, 50 insertions(+), 32 deletions(-) diff --git a/tests/py/dynamo/runtime/test_004_weight_streaming.py b/tests/py/dynamo/runtime/test_004_weight_streaming.py index 78522388d1..67d69df381 100644 --- a/tests/py/dynamo/runtime/test_004_weight_streaming.py +++ b/tests/py/dynamo/runtime/test_004_weight_streaming.py @@ -6,6 +6,7 @@ import torch_tensorrt as torchtrt from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo.utils import prepare_inputs INPUT_SIZE = (64, 100) @@ -302,45 +303,62 @@ def __init__(self): self.layer2 = torch.nn.Linear(128, 64) self.relu = torch.nn.ReLU() - def forward(self, x): + def forward(self, x, b=None, c=None, d=None, e=[]): out = self.layer1(x) + out = out + b + if c is not None: + out = out * c out = self.relu((out + 2.0) * 0.05) + if d is not None: + out = out - d["value"] + d["value2"] out = self.layer2(out) + for n in e: + out += n return out - inputs = torchtrt.Input( - min_shape=(1, 100), - opt_shape=(64, 100), - max_shape=(128, 100), - dtype=torch.float, - name="x", - ) model = SampleModel().eval().cuda() input_list = [] - input_list.append(torch.randn((8, 100)).cuda()) - input_list.append(torch.randn((12, 100)).cuda()) - input_list.append(torch.randn((12, 100)).cuda()) - input_list.append(torch.randn((8, 100)).cuda()) - input_list.append(torch.randn((8, 100)).cuda()) - - dynamic_shapes = ( - { - 0: torch.export.Dim("batch_size", min=1, max=128), - }, - ) - exp_program = torch.export.export( - model, (input_list[0],), dynamic_shapes=dynamic_shapes - ) - + for batch_size in [8, 12, 12, 8, 8]: + args = [torch.rand((batch_size, 100)).to("cuda")] + kwargs = { + "b": torch.rand((1, 128)).to("cuda"), + "d": { + "value": torch.rand(1).to("cuda"), + "value2": torch.tensor(1.2).to("cuda"), + }, + "e": [torch.rand(1).to("cuda"), torch.rand(1).to("cuda")], + } + input_list.append((args, kwargs)) + + kwarg_torchtrt_input = prepare_inputs(input_list[0][1]) + + compile_spec = { + "inputs": [ + torchtrt.Input( + min_shape=(1, 100), + opt_shape=(64, 100), + max_shape=(128, 100), + dtype=torch.float32, + name="x", + ), + ], + "kwarg_inputs": kwarg_torchtrt_input, + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "min_block_size": 1, + "ir": "dynamo", + "cache_built_engines": False, + "reuse_cached_engines": False, + "use_explicit_typing": True, + "enable_weight_streaming": True, + "torch_executed_ops": {"torch.ops.aten.mul.Tensor"}, + "use_python_runtime": use_python_runtime, + } + exp_program = torchtrt.dynamo.trace(model, **compile_spec) optimized_model = torchtrt.dynamo.compile( exp_program, - inputs, - min_block_size=1, - pass_through_build_failures=True, - use_explicit_typing=True, - enable_weight_streaming=True, - torch_executed_ops={"torch.ops.aten.mul.Tensor"}, - use_python_runtime=use_python_runtime, + **compile_spec, ) # List of tuples representing different configurations for three features: @@ -361,12 +379,12 @@ def test_trt_model(enable_weight_streaming, optimized_model, input_list): for i in range(len(input_list)): if 
enable_weight_streaming and i == 4: weight_streaming_ctx.device_budget = int(streamable_budget * 0.6) - out_list.append(optimized_model(input_list[i])) + out_list.append(optimized_model(*input_list[i][0], **input_list[i][1])) return out_list ref_out_list = [] for i in range(len(input_list)): - ref_out_list.append(model(input_list[i])) + ref_out_list.append(model(*input_list[i][0], **input_list[i][1])) pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs( optimized_model
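For reference, the short sketch below (not part of either patch) illustrates the pytree round-trip that the new `_unflatten_inputs` helper and the reworked `forward` rely on: `(args, kwargs)` is flattened into leaf tensors plus a spec, the leaves can be staged into static buffers, and the original structure is rebuilt with `tree_unflatten`. The tensor shapes and dictionary keys are illustrative assumptions modeled on the updated test, and the CUDA-graph staging is simplified to a plain clone on CPU.

```python
# Minimal sketch (assumed shapes/keys): round-trip structured (args, kwargs)
# through torch.utils._pytree, the mechanism used by this patch series.
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

args = (torch.rand(8, 100),)
kwargs = {
    "b": torch.rand(1, 128),
    "d": {"value": torch.rand(1), "value2": torch.tensor(1.2)},
    "e": [torch.rand(1), torch.rand(1)],
}

# Flatten the nested (args, kwargs) into leaf tensors plus a spec describing
# the nesting; a spec of this kind is what the compiled module stores as _in_spec.
flat_inputs, in_spec = tree_flatten((args, kwargs))

# Stage each leaf into its own buffer (stand-in for the static CUDA-graph
# input buffers; the real runtime clones onto the GPU instead).
static_buffers = [leaf.clone() for leaf in flat_inputs]

# Rebuild the original (args, kwargs) structure from the staged leaves.
rebuilt_args, rebuilt_kwargs = tree_unflatten(static_buffers, in_spec)

assert torch.equal(rebuilt_args[0], args[0])
assert torch.equal(rebuilt_kwargs["d"]["value2"], kwargs["d"]["value2"])
```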