feat: support aten.clamp.Tensor and update aten.clamp.default dynamo converters (#2522)
zewenli98 authored Dec 28, 2023
1 parent 128dd65 commit 088900d
Showing 5 changed files with 68 additions and 107 deletions.
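For context, `aten.clamp.default` takes scalar (Python number) bounds, while `aten.clamp.Tensor`, the overload gaining converter support here, takes tensor bounds that broadcast against the input. The eager-mode snippet below is illustrative only and is not part of the diff:

```python
import torch

x = torch.randn(3, 4)

# Scalar bounds: handled by the existing aten.clamp.default converter.
y_scalar = torch.ops.aten.clamp.default(x, -0.5, 0.5)

# Tensor bounds (broadcastable against x): the overload this commit adds
# converter support for.
lo = torch.full((3, 4), -0.5)
hi = torch.full((3, 4), 0.5)
y_tensor = torch.ops.aten.clamp.Tensor(x, lo, hi)

# With constant bounds of the same value, the two overloads agree.
assert torch.equal(y_scalar, y_tensor)
```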
22 changes: 3 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -499,25 +499,6 @@ def aten_ops_softplus(
)


@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
def aten_ops_clip(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.activation.clip(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
alpha=args_bounds_check(args, 1),
beta=args_bounds_check(args, 2),
)


@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
def aten_ops_hard_sigmoid(
ctx: ConversionContext,
@@ -695,6 +676,9 @@ def aten_ops_where(


@dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor)
def aten_ops_clamp(
ctx: ConversionContext,
target: Target,
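Taken together, these hunks retire the standalone `aten_ops_clip` converter and register all four overloads (clamp.default, clamp.Tensor, clip.default, clip.Tensor) on the single `aten_ops_clamp` entry point. A sketch of the consolidated converter follows, assuming the module's existing imports; the decorators and signature are taken from the diff, while the body (truncated in this view) is an assumption modeled on the removed `aten_ops_clip` above:

```python
@dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor)
def aten_ops_clamp(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Assumed body: dispatch every overload to the shared elementwise clamp
    # implementation, picking up optional min/max positionally when present.
    return impl.elementwise.clamp(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input_val=args[0],
        min_val=args_bounds_check(args, 1),
        max_val=args_bounds_check(args, 2),
    )
```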
30 changes: 0 additions & 30 deletions py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
@@ -235,36 +235,6 @@ def softplus_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]
)


def clip(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
alpha: float,
beta: float,
) -> TRTTensor:
operation_type = trt.ActivationType.CLIP

def clip_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]:
def clip_fn(x: float) -> float:
return max(alpha, min(beta, x))

return clip_fn(dyn_range[0]), clip_fn(dyn_range[1])

return convert_activation(
ctx,
target,
source_ir,
name,
operation_type,
input_val,
alpha=alpha,
beta=beta,
dyn_range_fn=clip_dyn_range_fn,
)


def hard_sigmoid(
ctx: ConversionContext,
target: Target,
62 changes: 9 additions & 53 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -1,6 +1,5 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
@@ -17,7 +16,6 @@
)
from torch_tensorrt.dynamo.conversion.impl.unary import sign
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

@@ -186,63 +184,21 @@ def clamp(
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
min_val: Optional[float] = None,
max_val: Optional[float] = None,
min_val: Optional[Union[int, float, TRTTensor]] = None,
max_val: Optional[Union[int, float, TRTTensor]] = None,
) -> TRTTensor:
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"Clamp received input {input_val} that is not part "
"of the TensorRT region!"
)

def _add_layer(
ctx: ConversionContext,
input: TRTTensor,
val: float,
op: trt.ElementWiseOperation,
name: str,
) -> (
trt.ILayer
): # TODO: Simplify and merge implementations, should just be max and min stacked
if not len(input.shape):
# clamping scalar
acc_ops_clamp_trt = get_trt_tensor(
ctx,
squeeze_left(
np.array(
[val],
dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
)
),
f"{name}_clamp_{val}",
)
else:
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
acc_ops_clamp_tensor = np.full(
acc_ops_clamp_shape,
val,
dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
)
acc_ops_clamp_trt = ctx.net.add_constant(
acc_ops_clamp_shape, acc_ops_clamp_tensor
).get_output(0)
layer = ctx.net.add_elementwise(input, acc_ops_clamp_trt, op)
return layer

clamped_val = input_val
if min_val is not None:
clamp_min_layer = _add_layer(
ctx, input_val, min_val, trt.ElementWiseOperation.MAX, name
clamped_val = impl.elementwise.max(
ctx, target, source_ir, f"{name}_max", clamped_val, min_val
)
set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
input_val = clamp_min_layer.get_output(0)

if max_val is not None:
clamp_max_layer = _add_layer(
ctx, input_val, max_val, trt.ElementWiseOperation.MIN, name
clamped_val = impl.elementwise.min(
ctx, target, source_ir, f"{name}_min", clamped_val, max_val
)
set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
input_val = clamp_max_layer.get_output(0)

return input_val
return clamped_val


def add(
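Reassembled from the added lines above, the rewritten `clamp` reduces to stacking the existing elementwise max and min converters, which is what lets it accept scalar or tensor-valued bounds. The sketch below is a reconstruction from the visible hunks, not a verbatim copy; lines hidden by the truncated view (such as the original `TRTTensor` input check) may or may not remain in the actual file:

```python
def clamp(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    min_val: Optional[Union[int, float, TRTTensor]] = None,
    max_val: Optional[Union[int, float, TRTTensor]] = None,
) -> TRTTensor:
    clamped_val = input_val
    # Lower bound: elementwise max against a scalar or broadcastable tensor.
    if min_val is not None:
        clamped_val = impl.elementwise.max(
            ctx, target, source_ir, f"{name}_max", clamped_val, min_val
        )
    # Upper bound: elementwise min against a scalar or broadcastable tensor.
    if max_val is not None:
        clamped_val = impl.elementwise.min(
            ctx, target, source_ir, f"{name}_min", clamped_val, max_val
        )
    return clamped_val
```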
26 changes: 25 additions & 1 deletion tests/py/dynamo/conversion/test_clamp_aten.py
@@ -49,7 +49,7 @@ def forward(self, x):

class TestScalarModule(torch.nn.Module):
def forward(self, x):
y = torch.ops.aten.mean.default(x)
y = torch.ops.aten.mean.dim(x, None, True)
return torch.ops.aten.clamp.default(y, min, max)

input_specs = [
@@ -63,6 +63,30 @@ def forward(self, x):
self.run_test_with_dynamic_shape(TestModule(), input_specs)
self.run_test_with_dynamic_shape(TestScalarModule(), input_specs)

@parameterized.expand(
[
param("default", min=-1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)),
param("min", min=0.5 * torch.randn(3, 4)),
param("max", max=0.5 * torch.randn(3, 4)),
param(
"minBiggerThanMax", min=1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)
),
param("float32Boundary", min=-3.4028234663852886e38 * torch.randn(3, 4)),
]
)
def test_clamp_tensor(
self,
test_name,
min=None,
max=None,
):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.clamp.Tensor(x, min, max)

inputs = [torch.randn(3, 4)]
self.run_test(TestModule(), inputs)


if __name__ == "__main__":
run_tests()
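The new `test_clamp_tensor` cases, including `minBiggerThanMax`, are consistent with the max-then-min lowering sketched earlier. As an illustrative eager-mode sanity check (not part of the test suite):

```python
import torch

x = torch.randn(3, 4)
lo = torch.randn(3, 4)
hi = torch.randn(3, 4)  # may be smaller than lo for some elements

# clamp with tensor bounds behaves like max followed by min, so wherever
# lo > hi the result collapses to hi, matching the "minBiggerThanMax" case.
expected = torch.minimum(torch.maximum(x, lo), hi)
actual = torch.ops.aten.clamp.Tensor(x, lo, hi)
torch.testing.assert_close(actual, expected)
```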
35 changes: 31 additions & 4 deletions tests/py/dynamo/conversion/test_clip_aten.py
@@ -19,11 +19,38 @@ class TestClipConverter(DispatchTestCase):
def test_clip(self, test_name, min=None, max=None):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.clamp.default(x, min, max)
return torch.ops.aten.clip.default(x, min, max)

inputs = [torch.randn(3, 4)]
self.run_test(TestModule(), inputs)

@parameterized.expand(
[
param(
"defaultInt32",
min=torch.tensor(-1, dtype=torch.int32),
max=torch.tensor(0, dtype=torch.int32),
),
param(
"defaultFloat32",
min=torch.tensor(0.5, dtype=torch.float32),
max=torch.tensor(1.0, dtype=torch.float32),
),
param(
"minBiggerThanMax",
min=torch.tensor(1.0, dtype=torch.float32),
max=torch.tensor(0, dtype=torch.int32),
),
]
)
def test_clip(self, test_name, min=None, max=None):
class TestModule(torch.nn.Module):
def forward(self, x, min, max):
return torch.ops.aten.clip.Tensor(x, min, max)

inputs = [torch.randn(3, 4), min, max]
self.run_test(TestModule(), inputs)

@parameterized.expand(
[
param("default", min=-1, max=0),
@@ -37,12 +64,12 @@ def test_clip_with_dynamic_shape_four_dimensions(
):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.clamp.default(x, min, max)
return torch.ops.aten.clip.default(x, min, max)

class TestScalarModule(torch.nn.Module):
def forward(self, x):
y = torch.ops.aten.mean.default(x)
return torch.ops.aten.clamp.default(y, min, max)
y = torch.ops.aten.mean.dim(x, None, True)
return torch.ops.aten.clip.default(y, min, max)

input_specs = [
Input(
