feat: Improve logging throughout the Dynamo path
- Add clear logging at the beginning and end of each phase of compilation
- Reword logging in certain locations for clarity
gs-olive committed Oct 20, 2023
1 parent 4e5b0f6 commit 8df34b8
Showing 5 changed files with 34 additions and 7 deletions.
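
As a rough illustration of the phase-bracketing pattern this commit introduces, the structure of the new calls in compile_module looks roughly like the sketch below. The helpers partition and convert are placeholders rather than the actual functions; the real call sites appear in the per-file diffs that follow.

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def partition(gm):   # placeholder for the fast/global partitioners
        return gm

    def convert(gm):     # placeholder for per-submodule TRT conversion
        return gm

    def compile_module(gm):
        # Announce the start of the partitioning phase
        logger.info("Beginning TensorRT operator Partitioning Phase")
        partitioned = partition(gm)

        # Mark the hand-off from partitioning to conversion
        logger.info(
            "Successfully completed graph partitioning phase. "
            "Beginning the conversion phase."
        )
        return convert(partitioned)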
18 changes: 14 additions & 4 deletions py/torch_tensorrt/dynamo/compile.py
@@ -5,7 +5,6 @@
from typing import Any, List, Optional, Sequence, Set, Tuple, Union

import torch
import torch_tensorrt
from torch.export import ExportedProgram
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum
@@ -42,6 +41,8 @@
to_torch_tensorrt_device,
)

import torch_tensorrt

logger = logging.getLogger(__name__)


@@ -177,9 +178,11 @@ def compile_module(
# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False

logger.info("Beginning TensorRT operator Partitioning Phase")
# If specified, try using the fast partitioner and fall back to the global one on failure
if settings.use_fast_partitioner:
try:
logger.info("Partitioning the graph via the fast partitioner")
partitioned_module = partitioning.fast_partition(
gm,
verbose=settings.debug,
@@ -189,21 +192,27 @@
except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
logger.error(
"Partitioning failed on the subgraph with fast partition. See trace above. "
+ "Retrying with global partition.",
"Retrying with global partition.",
exc_info=True,
)

fast_partitioner_failed = True
settings.use_fast_partitioner = False

if not settings.use_fast_partitioner:
logger.info("Partitioning the graph via the global partitioner")
partitioned_module = partitioning.global_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)

logger.info(
"Successfully completed graph partitioning phase. "
"Beginning the conversion phase."
)

# Store TRT replicas of Torch subgraphs
trt_modules = {}
# Iterate over all components that can be accelerated
@@ -222,14 +231,15 @@
to_torch_device(settings.device),
)

assert submodule_inputs is not None

logger.debug(
"Submodule name: %s\n Input shapes: %s\n %s",
"Converting submodule: %s\n Input shapes: %s\n %s",
str(name),
[input.shape for input in submodule_inputs],
str(submodule.graph),
)

assert submodule_inputs is not None
# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
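
The new messages here (and in the files below) go through standard Python loggers created with logging.getLogger(__name__), so they only appear if logging is configured to emit INFO/DEBUG output. A minimal setup, assuming the default logger hierarchy rooted at torch_tensorrt:

    import logging

    # Surface the INFO-level phase messages; switch the package logger to DEBUG
    # to also see per-submodule input shapes and per-node interpreter output.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("torch_tensorrt.dynamo").setLevel(logging.DEBUG)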
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -256,6 +256,10 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
n.kwargs = kwargs

# run the node
_LOGGER.debug(
f"Running node {self._cur_node_name}, a {self._cur_node.op} node "
f"with target {self._cur_node.target} in the TensorRT Interpreter"
)
trt_node: torch.fx.Node = super().run_node(n)

# remove "_itensor_to_tensor_meta"
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -44,7 +44,6 @@ def get_node_name(node: torch.fx.Node) -> str:
# like the node.meta['source_fn'] attr
pass

_LOGGER.debug(f"Node meta name {node_name}")
return node_name


@@ -1,12 +1,15 @@
from __future__ import annotations

import logging
from typing import Optional, Sequence, Set

import torch
from torch.fx.node import _get_qualified_name
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.utils import get_torch_inputs

logger = logging.getLogger(__name__)


def _extract_downstream_get_nodes(
module_node: torch.fx.Node, output_indices: Set[int]
@@ -62,6 +65,10 @@ def _repair_64bit_input(
torch.float64,
), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"

logger.info(
f"Downcasting a 64-bit input at position {position} of submodule {submodule_name}"
)

# Determine target data type in 32 and 64 bit forms
dtype_64bit = dtype
dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
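
The new INFO message above is emitted just before _repair_64bit_input rewrites a 64-bit input. Conceptually, the downcast it reports corresponds to a cast like the one below (a simplified illustration only, not the graph rewrite the function actually performs):

    import torch

    x64 = torch.arange(4, dtype=torch.int64)
    x32 = x64.to(torch.int32)   # int64 inputs become int32
    # torch.float64 inputs are handled analogously, mapping to torch.float32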
11 changes: 9 additions & 2 deletions py/torch_tensorrt/dynamo/utils.py
@@ -5,12 +5,12 @@
from typing import Any, Callable, Dict, Optional, Sequence, Union

import torch
import torch_tensorrt
from torch_tensorrt._Device import Device
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._defaults import PRECISION

import torch_tensorrt
from packaging import version

logger = logging.getLogger(__name__)
@@ -136,9 +136,16 @@ def prepare_inputs(

return torchtrt_inputs_dict

elif isinstance(inputs, (torch.SymBool, torch.SymFloat, torch.SymInt)):
raise ValueError(
f"Detected Torch symbolic input type {type(inputs)} during input parsing. "
"Symbolic inputs are not currently allowed; please specify dynamic=False "
"if using torch.compile with the Torch-TensorRT backend."
)

else:
raise ValueError(
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
f"Invalid input type {type(inputs)} encountered during Dynamo input parsing. "
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
)

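
The new ValueError covers symbolic scalars (SymInt/SymFloat/SymBool) that reach input parsing when torch.compile traces with dynamic shapes. Following the advice in the message, a caller would disable dynamic shapes roughly as follows; TinyModel is a placeholder module, and the backend is assumed to be registered under the name "torch_tensorrt":

    import torch
    import torch.nn as nn
    import torch_tensorrt  # noqa: F401  (registers the Torch-TensorRT backend with torch.compile)

    class TinyModel(nn.Module):  # stand-in for the user's module
        def forward(self, x):
            return torch.relu(x)

    model = TinyModel().eval().cuda()
    compiled = torch.compile(model, backend="torch_tensorrt", dynamic=False)
    out = compiled(torch.randn(8, 16, device="cuda"))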
