chore: add dynamic shapes section in the resnet tutorial #2904

Merged · 7 commits · Jun 14, 2024

185 changes: 40 additions & 145 deletions docsrc/user_guide/dynamic_shapes.rst
@@ -5,6 +5,10 @@ Dynamic shapes with Torch-TensorRT

By default, you can run a PyTorch model with varied input shapes, and the output shapes are determined eagerly.
However, Torch-TensorRT is an AOT compiler, which requires some prior information about the input shapes to compile and optimize the model.

Dynamic shapes using torch.export (AOT)
----------------------------------------

In the case of dynamic input shapes, we must provide the (min_shape, opt_shape, max_shape) arguments so that the model can be optimized for
this range of input shapes. An example usage of static and dynamic shapes is as follows.
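
A minimal sketch of both modes (the ``MyModel`` module and the shapes here are illustrative assumptions, using the default ``ir=dynamo``):

.. code-block:: python

    import torch
    import torch_tensorrt

    class MyModel(torch.nn.Module):  # hypothetical example module
        def forward(self, x):
            return torch.relu(x)

    model = MyModel().eval().cuda()

    # Static shapes: passing a concrete tensor fixes the engine to that shape
    trt_static = torch_tensorrt.compile(
        model, inputs=[torch.randn((1, 3, 224, 224)).cuda()]
    )

    # Dynamic shapes: provide a (min_shape, opt_shape, max_shape) range per input
    dyn_input = torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(4, 3, 224, 224),
        max_shape=(8, 3, 224, 224),
        dtype=torch.float32,
    )
    trt_dynamic = torch_tensorrt.compile(model, inputs=[dyn_input])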

@@ -30,168 +34,57 @@ Under the hood

There are two phases of compilation when we use the ``torch_tensorrt.compile`` API with ``ir=dynamo`` (default).

- torch_tensorrt.dynamo.trace (which uses torch.export to trace the graph with the given inputs)

We use the ``torch.export.export()`` API to trace and export a PyTorch module into a ``torch.export.ExportedProgram``. In the case of
dynamic shaped inputs, the ``(min_shape, opt_shape, max_shape)`` range provided via the ``torch_tensorrt.Input`` API is used to construct ``torch.export.Dim`` objects,
which are used in the ``dynamic_shapes`` argument of the export API (see the sketch after this list).
Please take a look at the ``_tracer.py`` file to understand how this works under the hood.

- torch_tensorrt.dynamo.compile (which compiles a ``torch.export.ExportedProgram`` object using TensorRT)

In the conversion to TensorRT, the graph already carries the dynamic shape information in the nodes' metadata, which is used during the engine building phase.
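
As a rough sketch of the trace phase's range-to-``Dim`` translation (the module, names, and shapes below are illustrative, not the actual ``_tracer.py`` internals):

.. code-block:: python

    import torch

    class MyModel(torch.nn.Module):  # hypothetical example module
        def forward(self, x):
            return torch.relu(x)

    model = MyModel().eval()

    # For an Input with min_shape=(1, 3, 224, 224), opt_shape=(4, 3, 224, 224)
    # and max_shape=(8, 3, 224, 224), only dimension 0 varies, so a single Dim
    # spanning [min, max] is constructed and handed to torch.export.export.
    batch = torch.export.Dim("batch", min=1, max=8)
    exp_program = torch.export.export(
        model, (torch.randn(4, 3, 224, 224),), dynamic_shapes=({0: batch},)
    )

    # After export, each node's metadata records symbolic shapes; this is the
    # information the conversion phase reads during engine building.
    for node in exp_program.graph.nodes:
        if node.op == "placeholder":
            print(node.name, node.meta["val"].shape)  # e.g. torch.Size([s0, 3, 224, 224])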

Custom Dynamic Shape Constraints
--------------------------------

Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
Torch-TensorRT attempts to automatically set the constraints during ``torch.export`` tracing by constructing
``torch.export.Dim`` objects from the provided dynamic dimensions. Sometimes we might need to set additional constraints, and Torchdynamo errors out if we don't specify them.
If you have to set any custom constraints for your model (by using ``torch.export.Dim``), we recommend exporting your program first and then compiling with Torch-TensorRT.
Please refer to this `documentation <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes>`_ to export a PyTorch module with dynamic shapes.
Here's a simple example that exports a matmul layer with some restrictions on its dynamic dimensions.

.. code-block:: python

    import torch
    import torch_tensorrt

    class MatMul(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, query, key):
            attn_weight = torch.matmul(query, key.transpose(-1, -2))
            return attn_weight

    model = MatMul().eval().cuda()
    inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
    seq_len = torch.export.Dim("seq_len", min=1, max=10)
    dynamic_shapes = ({2: seq_len}, {2: seq_len})
    # Export the model first with custom dynamic shape constraints
    exp_program = torch.export.export(model, tuple(inputs), dynamic_shapes=dynamic_shapes)
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
    # Run inference
    trt_gm(*inputs)
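
Exporting first also lets you express relationships that the automatic path cannot, such as reusing the same ``torch.export.Dim`` object across inputs to tie their dimensions together (for example, BERT-style models where two inputs must share a sequence length).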


Dynamic shapes using torch.compile (JIT)
----------------------------------------

``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
configured to TensorRT. In the case of ``ir=torch_compile``, users can provide dynamic shape information for the inputs using the ``torch._dynamo.mark_dynamic`` API (https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html)
to avoid recompilation of TensorRT engines.

.. code-block:: python

@@ -200,10 +93,12 @@

    model = MyModel().eval().cuda()
    inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
    # This indicates dimension 0 is dynamic and its range of values is [1, 8]
    torch._dynamo.mark_dynamic(inputs, 0, min=1, max=8)
    trt_gm = torch.compile(model, backend="tensorrt")
    # Compilation happens when you call the model
    trt_gm(inputs)

    # No recompilation of the TRT engine for the modified batch size
    inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm(inputs_bs2)
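
In short, the AOT path fixes the engine's dynamic profile ahead of time through the ``(min_shape, opt_shape, max_shape)`` range on ``torch_tensorrt.Input``, while in the JIT path ``torch._dynamo.mark_dynamic`` plays the equivalent role on the example input tensors.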
56 changes: 41 additions & 15 deletions examples/dynamo/torch_compile_resnet_example.py
@@ -75,19 +75,45 @@
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)

# %%
# Avoid recompilation by specifying dynamic shapes before Torch-TRT compilation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# The following code illustrates the workflow using ir=torch_compile (which uses torch.compile under the hood)
inputs_bs8 = torch.randn((8, 3, 224, 224)).half().to("cuda")
# This indicates dimension 0 of inputs_bs8 is dynamic and its range of values is [2, 16]
torch._dynamo.mark_dynamic(inputs_bs8, 0, min=2, max=16)
optimized_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=inputs_bs8,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    min_block_size=min_block_size,
    torch_executed_ops=torch_executed_ops,
)
outputs_bs8 = optimized_model(inputs_bs8)

# No recompilation happens for batch size = 12
inputs_bs12 = torch.randn((12, 3, 224, 224)).half().to("cuda")
outputs_bs12 = optimized_model(inputs_bs12)

# The following code illustrates the workflow using ir=dynamo (which uses torch.export APIs under the hood)
# Dynamic shapes for any inputs are specified using the torch_tensorrt.Input API
compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(8, 3, 224, 224),
            max_shape=(16, 3, 224, 224),
            dtype=torch.half,
        )
    ],
    "enabled_precisions": enabled_precisions,
    "ir": "dynamo",
}
trt_model = torch_tensorrt.compile(model, **compile_spec)

# No recompilation happens for batch size = 12
inputs_bs12 = torch.randn((12, 3, 224, 224)).half().to("cuda")
outputs_bs12 = trt_model(inputs_bs12)
8 changes: 5 additions & 3 deletions examples/dynamo/vgg16_fp8_ptq.py
@@ -1,10 +1,10 @@
"""
.. _vgg16_fp8_ptq:

Deploy Quantized Models using Torch-TensorRT
======================================================

Here we demonstrate how to deploy a model quantized to FP8 using the Dynamo frontend of Torch-TensorRT.
"""

# %%
@@ -100,7 +100,7 @@ def vgg16(num_classes=1000, init_weights=False):


PARSER = argparse.ArgumentParser(
description="Load pre-trained VGG model and then tune with FP8 and PTQ"
description="Load pre-trained VGG model and then tune with FP8 and PTQ. For having a pre-trained VGG model, please refer to https://github.com/pytorch/TensorRT/tree/main/examples/int8/training/vgg16"
)
PARSER.add_argument(
"--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
@@ -226,6 +226,8 @@ def calibrate_loop(model):
min_block_size=1,
debug=False,
)
# You can also use the torch.compile path to compile the model with Torch-TensorRT:
# trt_model = torch.compile(model, backend="tensorrt")

# Inference compiled Torch-TensorRT model over the testing dataset
total = 0