diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index b898bb4ff0..e2a48b64a7 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -364,18 +364,11 @@ def example_tensor( ) if isinstance(self.shape, dict): - if ( - self.dtype == dtype.u8 - or self.dtype == dtype.i8 - or self.dtype == dtype.i32 - or self.dtype == dtype.i64 - ): + if self.dtype in [dtype.u8, dtype.i8, dtype.i32, dtype.i64]: type = self.dtype.to(torch.dtype, use_default=True) - min_value = torch.iinfo(type).min - max_value = torch.iinfo(type).max return torch.randint( - min_value, - max_value, + torch.iinfo(type).min, + torch.iinfo(type).max, self.shape[optimization_profile_field], dtype=type, )