diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index fb5db527fb..59fa5030b2 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -499,25 +499,6 @@ def aten_ops_softplus(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
-def aten_ops_clip(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.activation.clip(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        alpha=args_bounds_check(args, 1),
-        beta=args_bounds_check(args, 2),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
 def aten_ops_hard_sigmoid(
     ctx: ConversionContext,
@@ -695,6 +676,9 @@ def aten_ops_where(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor)
 def aten_ops_clamp(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
index ac77f790cb..f578351ef2 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
@@ -235,36 +235,6 @@ def softplus_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]
     )
 
 
-def clip(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input_val: TRTTensor,
-    alpha: float,
-    beta: float,
-) -> TRTTensor:
-    operation_type = trt.ActivationType.CLIP
-
-    def clip_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]:
-        def clip_fn(x: float) -> float:
-            return max(alpha, min(beta, x))
-
-        return clip_fn(dyn_range[0]), clip_fn(dyn_range[1])
-
-    return convert_activation(
-        ctx,
-        target,
-        source_ir,
-        name,
-        operation_type,
-        input_val,
-        alpha=alpha,
-        beta=beta,
-        dyn_range_fn=clip_dyn_range_fn,
-    )
-
-
 def hard_sigmoid(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
index 06e07eedb1..a69fca944b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -1,6 +1,5 @@
 from typing import Optional, Union
 
-import numpy as np
 import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
@@ -17,7 +16,6 @@
 )
 from torch_tensorrt.dynamo.conversion.impl.unary import sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
@@ -186,63 +184,21 @@ def clamp(
     source_ir: Optional[SourceIR],
     name: str,
     input_val: TRTTensor,
-    min_val: Optional[float] = None,
-    max_val: Optional[float] = None,
+    min_val: Optional[Union[int, float, TRTTensor]] = None,
+    max_val: Optional[Union[int, float, TRTTensor]] = None,
 ) -> TRTTensor:
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Clamp received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def _add_layer(
-        ctx: ConversionContext,
-        input: TRTTensor,
-        val: float,
-        op: trt.ElementWiseOperation,
-        name: str,
-    ) -> (
-        trt.ILayer
-    ):  # TODO: Simplify and merge implementations, should just be max and min stacked
-        if not len(input.shape):
-            # clamping scalar
-            acc_ops_clamp_trt = get_trt_tensor(
-                ctx,
-                squeeze_left(
-                    np.array(
-                        [val],
-                        dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-                    )
-                ),
-                f"{name}_clamp_{val}",
-            )
-        else:
-            acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
-            acc_ops_clamp_tensor = np.full(
-                acc_ops_clamp_shape,
-                val,
-                dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-            )
-            acc_ops_clamp_trt = ctx.net.add_constant(
-                acc_ops_clamp_shape, acc_ops_clamp_tensor
-            ).get_output(0)
-        layer = ctx.net.add_elementwise(input, acc_ops_clamp_trt, op)
-        return layer
-
+    clamped_val = input_val
     if min_val is not None:
-        clamp_min_layer = _add_layer(
-            ctx, input_val, min_val, trt.ElementWiseOperation.MAX, name
+        clamped_val = impl.elementwise.max(
+            ctx, target, source_ir, f"{name}_max", clamped_val, min_val
         )
-        set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
-        input_val = clamp_min_layer.get_output(0)
+
     if max_val is not None:
-        clamp_max_layer = _add_layer(
-            ctx, input_val, max_val, trt.ElementWiseOperation.MIN, name
+        clamped_val = impl.elementwise.min(
+            ctx, target, source_ir, f"{name}_min", clamped_val, max_val
         )
-        set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
-        input_val = clamp_max_layer.get_output(0)
-
-    return input_val
+
+    return clamped_val
 
 
 def add(
diff --git a/tests/py/dynamo/conversion/test_clamp_aten.py b/tests/py/dynamo/conversion/test_clamp_aten.py
index fcee7bfa3c..0bad9ee350 100644
--- a/tests/py/dynamo/conversion/test_clamp_aten.py
+++ b/tests/py/dynamo/conversion/test_clamp_aten.py
@@ -49,7 +49,7 @@ def forward(self, x):
 
         class TestScalarModule(torch.nn.Module):
             def forward(self, x):
-                y = torch.ops.aten.mean.default(x)
+                y = torch.ops.aten.mean.dim(x, None, True)
                 return torch.ops.aten.clamp.default(y, min, max)
 
         input_specs = [
@@ -63,6 +63,30 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
         self.run_test_with_dynamic_shape(TestScalarModule(), input_specs)
 
+    @parameterized.expand(
+        [
+            param("default", min=-1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)),
+            param("min", min=0.5 * torch.randn(3, 4)),
+            param("max", max=0.5 * torch.randn(3, 4)),
+            param(
+                "minBiggerThanMax", min=1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)
+            ),
+            param("float32Boundary", min=-3.4028234663852886e38 * torch.randn(3, 4)),
+        ]
+    )
+    def test_clamp_tensor(
+        self,
+        test_name,
+        min=None,
+        max=None,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.clamp.Tensor(x, min, max)
+
+        inputs = [torch.randn(3, 4)]
+        self.run_test(TestModule(), inputs)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_clip_aten.py b/tests/py/dynamo/conversion/test_clip_aten.py
index a3819fb4dd..447e2c9e17 100644
--- a/tests/py/dynamo/conversion/test_clip_aten.py
+++ b/tests/py/dynamo/conversion/test_clip_aten.py
@@ -19,11 +19,38 @@ class TestClipConverter(DispatchTestCase):
     def test_clip(self, test_name, min=None, max=None):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten.clamp.default(x, min, max)
+                return torch.ops.aten.clip.default(x, min, max)
 
         inputs = [torch.randn(3, 4)]
         self.run_test(TestModule(), inputs)
 
+    @parameterized.expand(
+        [
+            param(
+                "defaultInt32",
+                min=torch.tensor(-1, dtype=torch.int32),
+                max=torch.tensor(0, dtype=torch.int32),
+            ),
+            param(
+                "defaultFloat32",
+                min=torch.tensor(0.5, dtype=torch.float32),
+                max=torch.tensor(1.0, dtype=torch.float32),
+            ),
+            param(
+                "minBiggerThanMax",
+                min=torch.tensor(1.0, dtype=torch.float32),
+                max=torch.tensor(0, dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_clip_tensor(self, test_name, min=None, max=None):
+        class TestModule(torch.nn.Module):
+            def forward(self, x, min, max):
+                return torch.ops.aten.clip.Tensor(x, min, max)
+
+        inputs = [torch.randn(3, 4), min, max]
+        self.run_test(TestModule(), inputs)
+
     @parameterized.expand(
         [
             param("default", min=-1, max=0),
@@ -37,12 +64,12 @@ def test_clip_with_dynamic_shape_four_dimensions(
     ):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten.clamp.default(x, min, max)
+                return torch.ops.aten.clip.default(x, min, max)
 
         class TestScalarModule(torch.nn.Module):
             def forward(self, x):
-                y = torch.ops.aten.mean.default(x)
-                return torch.ops.aten.clamp.default(y, min, max)
+                y = torch.ops.aten.mean.dim(x, None, True)
+                return torch.ops.aten.clip.default(y, min, max)
 
         input_specs = [
             Input(