From ebaacff7a01626a100337017401959995e8232b0 Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Mon, 24 Feb 2025 13:51:43 +0000
Subject: [PATCH] fix: structured inputs for CudaGraphsTorchTensorRTModule

---
 .../runtime/_CudaGraphsTorchTensorRTModule.py | 79 ++++++++++++++-----
 1 file changed, 61 insertions(+), 18 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
index e860c5762f..adf0fa79df 100644
--- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
@@ -1,16 +1,48 @@
 from __future__ import annotations
 
 import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from torch_tensorrt.dynamo import partitioning
 
 logger = logging.getLogger(__name__)
 
 
+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Process inputs using tree_unflatten and tree_map to reconstruct inputs
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def create_example_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Apply the tensor creation to the reconstructed structure
+    processed_inputs = tree_map(create_example_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]
+
+
 class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
 
@@ -42,14 +74,15 @@ def warm_up(self) -> None:
         Warm up is necessary to ensure that memory allocations and initializations
         are not recorded in cuda graphs
         """
+
         with torch_tensorrt.logging.errors():
             with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
                 s = torch.cuda.Stream()
                 s.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(s):
                     for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
                 torch.cuda.current_stream().wait_stream(s)
 
     def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -73,7 +106,12 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()
 
-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        # pytree_inputs = tree_unflatten(self.inputs, self.compiled_module._in_spec)
+
+        inputs, spec = tree_flatten((args, kwargs))
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
@@ -94,10 +132,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                 for i in inputs
             ]
             assert len(contiguous_inputs) == len(
-                self.inputs
-            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+                inputs
+            ), f"Wrong number of inputs, expect {len(inputs)} get {len(contiguous_inputs)}."
 
-            for i, _ in enumerate(self.inputs):
+            for i, _ in enumerate(inputs):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
                         f"Detected input[{i}] is not on a cuda device. "
@@ -112,15 +150,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                     )
 
                 assert (
-                    contiguous_inputs[i].dtype == self.inputs[i].dtype
-                ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+
+                if need_cudagraphs_record:
+                    # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
+                    # Clone is required to avoid re-using user-provided GPU memory
+                    self._input_buffers[i] = contiguous_inputs[i].clone()
+                else:
+                    self._input_buffers[i].copy_(contiguous_inputs[i])
 
             if need_cudagraphs_record:
-                # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
-                # Clone is required to avoid re-using user-provided GPU memory
-                self._input_buffers[i] = contiguous_inputs[i].clone()
-            else:
-                self._input_buffers[i].copy_(contiguous_inputs[i])
+                # Reconstruct the original args and kwargs structure from static input buffers
+                # using the input specification stored during module compilation
+                args, kwargs = tree_unflatten(
+                    self._input_buffers, self.compiled_module._in_spec
+                )
 
             self._caller_stream = torch.cuda.current_stream()
             if (
@@ -135,9 +180,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                 if need_cudagraphs_record:
                     self.cudagraph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
-                        self._output_buffers = self.compiled_module(
-                            *self._input_buffers
-                        )
+                        self._output_buffers = self.compiled_module(*args, **kwargs)
 
                 self.cudagraph.replay()  # type: ignore
             self._caller_stream.wait_stream(self._engine_stream)
@@ -154,4 +197,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
         if self.cudagraph:
             self.cudagraph.reset()
             self.cudagraph = None
-        return self.compiled_module(*inputs)
+        return self.compiled_module(*args, **kwargs)
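
Note (illustration, not part of the patch): the pytree round-trip that
_unflatten_inputs and the new forward path build on can be exercised in
isolation. A minimal sketch, assuming plain CPU tensors in place of the
CUDA input buffers and a locally built spec in place of the
compiled_module._in_spec the patch reads:

    import torch
    from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

    # A structured call: positional args plus keyword args, as one pytree.
    args = (torch.ones(2, 3),)
    kwargs = {"mask": torch.zeros(2, 3)}

    # Flatten (args, kwargs) into a list of tensor leaves plus a TreeSpec
    # that records the nesting, analogous to the spec stored at export time.
    leaves, in_spec = tree_flatten((args, kwargs))

    # Static buffers: cloned once at record time, copy_()-ed into on replay.
    input_buffers = [leaf.clone() for leaf in leaves]

    # Rebuild the original (args, kwargs) structure around the buffers;
    # tree_map can then post-process every leaf (the patch uses it to move
    # torch_tensorrt.Input examples onto the GPU).
    rebuilt_args, rebuilt_kwargs = tree_unflatten(input_buffers, in_spec)
    rebuilt_args, rebuilt_kwargs = tree_map(
        lambda t: t.contiguous(), (rebuilt_args, rebuilt_kwargs)
    )

    assert torch.equal(rebuilt_args[0], args[0])
    assert torch.equal(rebuilt_kwargs["mask"], kwargs["mask"])

Because the spec captures where each leaf sits in the nesting, keyword
inputs survive the flatten/unflatten round-trip in a deterministic order,
which is exactly what lets the wrapper treat kwargs and positional args
uniformly as a flat buffer list.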
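The record/replay discipline in forward (clone user tensors into static
buffers when recording, copy_() into those same buffers on later calls)
can likewise be shown without TensorRT. A minimal sketch, assuming a CUDA
device is available; GraphedFn is a hypothetical helper for this note, not
a torch_tensorrt API:

    import torch

    class GraphedFn:
        """Record f once on static buffers, then replay with fresh data."""

        def __init__(self, f):
            self.f = f
            self.graph = None

        def __call__(self, x):
            if self.graph is None:
                # Record path: clone so the graph owns its input memory
                # instead of aliasing the caller's tensor.
                self.static_in = x.clone()
                # Warm up on a side stream so allocations and lazy init
                # are not captured into the graph.
                s = torch.cuda.Stream()
                s.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(s):
                    for _ in range(3):
                        self.f(self.static_in)
                torch.cuda.current_stream().wait_stream(s)
                self.graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(self.graph):
                    self.static_out = self.f(self.static_in)
            else:
                # Replay path: refresh the same addresses the graph reads.
                self.static_in.copy_(x)
            self.graph.replay()
            return self.static_out

    g = GraphedFn(lambda t: t * 2 + 1)
    x = torch.ones(4, device="cuda")
    print(g(x))      # records on the first call, then replays
    print(g(x + 3))  # replays; new data flows through the static buffer

Returning static_out directly mirrors the wrapper's output buffers:
results alias graph-owned memory, so callers must consume or copy them
before the next replay overwrites them.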