Skip to content

Commit

Permalink
chore: Test update for int type dynamic shape input
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Jun 3, 2024
1 parent 610057c commit 89d4a98
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
22 changes: 19 additions & 3 deletions py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,25 @@ def example_tensor(
)

if isinstance(self.shape, dict):
return torch.rand(self.shape[optimization_profile_field]).to(
dtype=self.dtype.to(torch.dtype, use_default=True)
)
if (
self.dtype == dtype.u8
or self.dtype == dtype.i8
or self.dtype == dtype.i32
or self.dtype == dtype.i64
):
type = self.dtype.to(torch.dtype, use_default=True)
min_value = torch.iinfo(type).min
max_value = torch.iinfo(type).max
return torch.randint(
min_value,
max_value,
self.shape[optimization_profile_field],
dtype=type,
)
else:
return torch.rand(self.shape[optimization_profile_field]).to(
dtype=self.dtype.to(torch.dtype, use_default=True)
)
else:
raise RuntimeError(
f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
Expand Down
14 changes: 12 additions & 2 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def run_test_with_dynamic_shape(
input_specs,
rtol=1e-03,
atol=1e-03,
output_dtypes=None,
check_dtype=True,
use_dynamo_tracer=False,
enable_passes=False,
):
Expand All @@ -355,13 +355,23 @@ def run_test_with_dynamic_shape(
# We replicate this behavior here
compilation_settings = CompilationSettings(truncate_double=True)

output_dtypes = None
if check_dtype:
output_dtypes = infer_module_output_dtypes(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)

interp = TRTInterpreter(
mod,
input_specs,
output_dtypes=output_dtypes,
compilation_settings=compilation_settings,
)

# Since the lowering is based on optimal shape. We need to test with
# different shape(for ex. max shape) for testing dynamic shape
inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol, check_dtype)

0 comments on commit 89d4a98

Please sign in to comment.