fix: structured inputs for CudaGraphsTorchTensorRTModule
keehyuna committed Feb 24, 2025
1 parent 0a46392 commit ebaacff
Showing 1 changed file with 61 additions and 18 deletions.
py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py (61 additions, 18 deletions)
@@ -1,16 +1,48 @@
from __future__ import annotations

import logging
from typing import List, Optional, Sequence, Tuple
from typing import Any, List, Optional, Sequence, Tuple

import torch
import torch_tensorrt
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from torch_tensorrt.dynamo import partitioning

logger = logging.getLogger(__name__)


def _unflatten_inputs(
flattened_inputs: Sequence[torch_tensorrt.Input],
compiled_module: torch.fx.GraphModule,
) -> Tuple[Any, Any]:
"""
Process inputs using tree_unflatten and tree_map to reconstruct inputs
Args:
flattened_inputs: Flattened input tensors to process
compiled_module: The compiled GraphModule containing input specifications
Returns:
Tuple of (args, kwargs) containing reconstructed input tensors
"""

def create_example_tensor(input: Any) -> torch.Tensor:
if isinstance(input, torch_tensorrt.Input):
return input.torch_tensor.cuda()
else:
raise RuntimeError("Input is not a torch_tensorrt.Input")

# Reconstruct the (args, kwargs) structure that was flattened during export
pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
# Apply the tensor creation to the reconstructed structure
processed_inputs = tree_map(create_example_tensor, pytree_inputs)

# Since inputs were originally flattened from (args, kwargs),
# processed_inputs is now that same tuple structure
return processed_inputs[0], processed_inputs[1]


class CudaGraphsTorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
"""This Wrapper runtime module is to record/replay whole cuda graph in sub modules
@@ -42,14 +74,15 @@ def warm_up(self) -> None:
Warm up is necessary to ensure that memory allocations and initializations
are not recorded in cuda graphs
"""

with torch_tensorrt.logging.errors():
with unset_fake_temporarily():
inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
self.compiled_module(*inputs_tensor)
self.compiled_module(*args, **kwargs)
torch.cuda.current_stream().wait_stream(s)

def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -73,7 +106,12 @@ def __del__(self) -> None:
if self.cudagraph:
self.cudagraph.reset()

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
def forward(
self, *args: Any, **kwargs: Any
) -> torch.Tensor | Tuple[torch.Tensor, ...]:
# pytree_inputs = tree_unflatten(self.inputs, self.compiled_module._in_spec)

inputs, spec = tree_flatten((args, kwargs))
cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
if cudagraphs_enabled:
shape_changed = self.validate_input_shapes(inputs)
@@ -94,10 +132,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
for i in inputs
]
assert len(contiguous_inputs) == len(
self.inputs
), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
inputs
), f"Wrong number of inputs, expect {len(inputs)} get {len(contiguous_inputs)}."

for i, _ in enumerate(self.inputs):
for i, _ in enumerate(inputs):
if not contiguous_inputs[i].is_cuda:
logger.warning(
f"Detected input[{i}] is not on a cuda device. "
@@ -112,15 +150,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
)

assert (
contiguous_inputs[i].dtype == self.inputs[i].dtype
), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
contiguous_inputs[i].dtype == inputs[i].dtype
), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."

if need_cudagraphs_record:
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
# Clone is required to avoid re-using user-provided GPU memory
self._input_buffers[i] = contiguous_inputs[i].clone()
else:
self._input_buffers[i].copy_(contiguous_inputs[i])

if need_cudagraphs_record:
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
# Clone is required to avoid re-using user-provided GPU memory
self._input_buffers[i] = contiguous_inputs[i].clone()
else:
self._input_buffers[i].copy_(contiguous_inputs[i])
# Reconstruct the original args and kwargs structure from static input buffers
# using the input specification stored during module compilation
args, kwargs = tree_unflatten(
self._input_buffers, self.compiled_module._in_spec
)

self._caller_stream = torch.cuda.current_stream()
if (
@@ -135,9 +180,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
if need_cudagraphs_record:
self.cudagraph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
self._output_buffers = self.compiled_module(
*self._input_buffers
)
self._output_buffers = self.compiled_module(*args, **kwargs)

self.cudagraph.replay() # type: ignore
self._caller_stream.wait_stream(self._engine_stream)
@@ -154,4 +197,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
if self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None
return self.compiled_module(*inputs)
return self.compiled_module(*args, **kwargs)
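
For readers unfamiliar with the pytree machinery this change leans on: the wrapper now flattens the structured (args, kwargs) into a flat list of tensors plus an input spec, keeps the flat list as static CUDA graph buffers, and unflattens those buffers back into (args, kwargs) before calling the compiled module. Below is a minimal, illustrative sketch of that round-trip, not part of the commit, assuming plain CPU tensors and the same torch.utils._pytree helpers imported above.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Structured inputs: one positional tensor and one keyword tensor.
args = (torch.randn(2, 3),)
kwargs = {"mask": torch.ones(2, 3, dtype=torch.bool)}

# Flatten (args, kwargs) into a flat list of leaves plus a spec describing the structure.
flat_inputs, in_spec = tree_flatten((args, kwargs))
assert len(flat_inputs) == 2

# Stand-ins for the wrapper's static input buffers (clones, so user-provided memory is not reused).
static_buffers = [t.clone() for t in flat_inputs]

# Rebuild the original (args, kwargs) structure from the flat buffers and the spec,
# which is what lets the graph-captured call keep the compiled module's original signature.
rebuilt_args, rebuilt_kwargs = tree_unflatten(static_buffers, in_spec)
assert torch.equal(rebuilt_args[0], args[0])
assert torch.equal(rebuilt_kwargs["mask"], kwargs["mask"])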
