diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d3cf6709ef..9592bc1fd5 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Collection, Optional, Union +from typing import Collection, Optional, Set, Union from torch.fx.node import Target from torch_tensorrt._Device import Device @@ -71,7 +71,7 @@ class CompilationSettings: hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) """ - enabled_precisions: dtype = field(default_factory=lambda: ENABLED_PRECISIONS) + enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) debug: bool = DEBUG workspace_size: int = WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE