From 89d4a98f10b914887684ac191ee0827d0de27f00 Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Mon, 3 Jun 2024 10:28:42 +0000
Subject: [PATCH] chore: Test update for int type dynamic shape input

---
 py/torch_tensorrt/_Input.py           | 22 +++++++++++++++++++---
 tests/py/dynamo/conversion/harness.py | 14 ++++++++++++--
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 32f19ce1f0..b898bb4ff0 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -364,9 +364,25 @@ def example_tensor(
             )
 
         if isinstance(self.shape, dict):
-            return torch.rand(self.shape[optimization_profile_field]).to(
-                dtype=self.dtype.to(torch.dtype, use_default=True)
-            )
+            if (
+                self.dtype == dtype.u8
+                or self.dtype == dtype.i8
+                or self.dtype == dtype.i32
+                or self.dtype == dtype.i64
+            ):
+                type = self.dtype.to(torch.dtype, use_default=True)
+                min_value = torch.iinfo(type).min
+                max_value = torch.iinfo(type).max
+                return torch.randint(
+                    min_value,
+                    max_value,
+                    self.shape[optimization_profile_field],
+                    dtype=type,
+                )
+            else:
+                return torch.rand(self.shape[optimization_profile_field]).to(
+                    dtype=self.dtype.to(torch.dtype, use_default=True)
+                )
         else:
             raise RuntimeError(
                 f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index 8f55ce3fb6..aa51ceacef 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -338,7 +338,7 @@ def run_test_with_dynamic_shape(
         input_specs,
         rtol=1e-03,
         atol=1e-03,
-        output_dtypes=None,
+        check_dtype=True,
         use_dynamo_tracer=False,
         enable_passes=False,
     ):
@@ -355,13 +355,23 @@ def run_test_with_dynamic_shape(
         # We replicate this behavior here
         compilation_settings = CompilationSettings(truncate_double=True)
 
+        output_dtypes = None
+        if check_dtype:
+            output_dtypes = infer_module_output_dtypes(
+                mod,
+                input_specs,
+                compilation_settings.device,
+                truncate_double=compilation_settings.truncate_double,
+            )
+
         interp = TRTInterpreter(
             mod,
             input_specs,
             output_dtypes=output_dtypes,
             compilation_settings=compilation_settings,
         )
+
         # Since the lowering is based on optimal shape. We need to test with
         # different shape(for ex. max shape) for testing dynamic shape
         inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
-        super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
+        super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol, check_dtype)
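
A minimal usage sketch of the new example_tensor() behavior, assuming a torch_tensorrt build with this patch applied; the Input shapes and the int32 dtype below are illustrative only and are not taken from the patch:

    import torch
    import torch_tensorrt

    # A dynamic-shape input spec declared with an integer dtype
    # (shapes are hypothetical, chosen just for this sketch).
    spec = torch_tensorrt.Input(
        min_shape=(1, 4),
        opt_shape=(4, 4),
        max_shape=(8, 4),
        dtype=torch.int32,
    )

    # With this patch, example_tensor() draws from torch.randint for
    # u8/i8/i32/i64 inputs instead of casting torch.rand(), so the example
    # tensor keeps its integer dtype. The test harness relies on this when it
    # builds inputs via spec.example_tensor("max_shape") with check_dtype=True.
    example = spec.example_tensor("max_shape")
    print(example.dtype, example.shape)  # expected: torch.int32 torch.Size([8, 4])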