Skip to content

Commit

Permalink
feat: support aten.atan2 converter (#2689)
Browse files Browse the repository at this point in the history
  • Loading branch information
chohk88 authored Apr 12, 2024
1 parent 7d30714 commit 821ff91
Show file tree
Hide file tree
Showing 3 changed files with 334 additions and 1 deletion.
24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,30 @@ def aten_ops_atanh(
)


@dynamo_tensorrt_converter(torch.ops.aten.atan2.default)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
        1: (TRTTensor,),
    }
)
def aten_ops_atan2(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.atan2.default`` to its TensorRT elementwise implementation.

    The ``enforce_tensor_types`` decorator guarantees both operands arrive
    as TRT tensors, so they are forwarded to the impl unchanged.
    """
    dividend, divisor = args[0], args[1]
    return impl.elementwise.atan2(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        dividend,
        divisor,
    )


@dynamo_tensorrt_converter(torch.ops.aten.ceil.default)
def aten_ops_ceil(
ctx: ConversionContext,
Expand Down
179 changes: 178 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
Expand All @@ -9,13 +10,15 @@
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_int_int_div_trt_tensor,
cast_int_or_float_to_bool,
cast_trt_tensor,
get_trt_tensor,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
convert_binary_elementwise,
)
from torch_tensorrt.dynamo.conversion.impl.unary import sign
from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.fx.converters.converter_utils import broadcast
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

Expand Down Expand Up @@ -213,6 +216,180 @@ def remainder(
return fmod2_value


def atan2(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    other: TRTTensor,
) -> TRTTensor:
    """
    Perform atan2 operation on Tensor, calculating the arctangent of the quotient of input tensors.

    atan2(x,y) = atan(x/y)      if y > 0,
               = atan(x/y) + π  if x ≥ 0 and y < 0,
               = atan(x/y) - π  if x < 0 and y < 0,
               = π/2            if x > 0 and y = 0,
               = -π/2           if x < 0 and y = 0,
               = 0              if x = 0 and y = 0
    Args:
        ctx: ConversionContext.
        target: node target
        source_ir (SourceIR): Source IR calling the function.
        name: namespace for the op
        input: Tensor or constant representing the dividend (x).
        other: Tensor or constant representing the divisor (y).
    Returns:
        A TensorRT tensor representing the result of the atan2 operation.
    """
    # Full double-precision π from numpy rather than a hand-typed literal.
    pi_value = float(np.pi)
    pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi")

    # atan/div below operate on floats; integer inputs must be cast first.
    if isinstance(input, TRTTensor):
        input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input")
    if isinstance(other, TRTTensor):
        other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other")

    input, other = broadcast(ctx.net, input, other, f"{name}_input", f"{name}_other")

    # Boolean predicate tensors over the broadcast shape.
    x_zero = eq(ctx, target, source_ir, f"{name}_x_zero", input, 0)
    y_zero = eq(ctx, target, source_ir, f"{name}_y_zero", other, 0)

    x_positive = gt(ctx, target, source_ir, f"{name}_x_positive", input, 0)
    x_zero_positive = ge(ctx, target, source_ir, f"{name}_x_zero_positive", input, 0)
    x_negative = lt(ctx, target, source_ir, f"{name}_x_negative", input, 0)
    y_positive = gt(ctx, target, source_ir, f"{name}_y_positive", other, 0)
    y_negative = lt(ctx, target, source_ir, f"{name}_y_negative", other, 0)

    # Base value atan(x/y).  Division by zero in the y == 0 lanes is harmless:
    # those lanes are overwritten by the axis-case selects further down.
    input_div_other = div(
        ctx, target, source_ir, f"{name}_input_div_other", input, other
    )
    atan_val = atan(ctx, target, source_ir, f"{name}_atan", input_div_other)

    # atan(x/y)+π if x≥0 and y<0,
    atan_add_pi = add(
        ctx, target, source_ir, f"{name}_atan_add_pi", atan_val, pi_tensor
    )

    # atan(x/y)-π if x<0 and y<0,
    atan_sub_pi = sub(
        ctx, target, source_ir, f"{name}_atan_sub_pi", atan_val, pi_tensor
    )

    # atan(x/y)+π if x≥0 and y<0,
    atan_corrected = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_atan_corrected",
        atan_add_pi,
        atan_val,
        logical_and(
            ctx,
            target,
            source_ir,
            f"{name}_x_zero_positive_and_y_negative",
            x_zero_positive,
            y_negative,
        ),
    )

    # atan(x/y)-π if x<0 and y<0,
    atan_corrected_2 = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_atan_corrected_2",
        atan_sub_pi,
        atan_corrected,
        logical_and(
            ctx,
            target,
            source_ir,
            f"{name}_x_negative_and_y_negative",
            x_negative,
            y_negative,
        ),
    )

    # atan(x/y) if y>0
    atan_output = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_atan_output",
        atan_val,
        atan_corrected_2,
        y_positive,
    )

    # Constants for the on-axis cases.
    # NOTE(review): materializing np.ones(input.shape) assumes a fully static
    # shape here — dynamic-shape inputs would need a runtime-shaped constant.
    pi_over_2_tensor = get_trt_tensor(
        ctx,
        (pi_value / 2) * np.ones(input.shape, dtype=np.float32),
        f"{name}_pi_over_2_tensor",
        dtype=trt.float32,
    )
    minus_pi_over_2_tensor = get_trt_tensor(
        ctx,
        (-pi_value / 2) * np.ones(input.shape, dtype=np.float32),
        f"{name}_minus_pi_over_2_tensor",
        dtype=trt.float32,
    )
    zero_tensor = get_trt_tensor(
        ctx,
        np.zeros(input.shape, dtype=np.float32),
        f"{name}_zero_tensor",
        dtype=trt.float32,
    )

    # π/2 if x>0 and y=0,
    # (layer name fixed: condition is x_positive ∧ y_zero, not x_zero ∧ y_positive)
    pi_over_2_output = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_pi_over_2_output",
        pi_over_2_tensor,
        atan_output,
        logical_and(
            ctx, target, source_ir, f"{name}_x_positive_and_y_zero", x_positive, y_zero
        ),
    )

    # -π/2 if x<0 and y=0,
    # (layer name fixed: condition is x_negative ∧ y_zero, not x_zero ∧ y_negative)
    minus_pi_over_2_output = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_minus_pi_over_2_output",
        minus_pi_over_2_tensor,
        pi_over_2_output,
        logical_and(
            ctx, target, source_ir, f"{name}_x_negative_and_y_zero", x_negative, y_zero
        ),
    )

    # 0 if x=0 and y=0,
    zero_output = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_zero_output",
        zero_tensor,
        minus_pi_over_2_output,
        logical_and(
            ctx, target, source_ir, f"{name}_x_zero_and_y_zero", y_zero, x_zero
        ),
    )

    return zero_output


def clamp(
ctx: ConversionContext,
target: Target,
Expand Down
132 changes: 132 additions & 0 deletions tests/py/dynamo/conversion/test_atan2_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestAtan2Converter(DispatchTestCase):
    """Conversion tests for the ``aten.atan2.default`` TensorRT converter."""

    @parameterized.expand(
        [
            ((10,), torch.float),
            ((1, 20), torch.float),
            ((2, 3, 4), torch.float),
            ((2, 3, 4, 5), torch.float),
        ]
    )
    def test_atan2_lhs_const(self, input_shape, dtype):
        """lhs is a broadcastable single-element tensor, rhs a full tensor.

        (Fix: the scalar/full-tensor inputs were previously swapped relative
        to the test name.)
        """

        class Atan2(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.atan2.default(lhs_val, rhs_val)

        inputs = [
            torch.rand(1),
            torch.randn(input_shape, dtype=dtype),
        ]

        self.run_test(
            Atan2(),
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), torch.float),
            ((1, 20), torch.float),
            ((2, 3, 4), torch.float),
            ((2, 3, 4, 5), torch.float),
        ]
    )
    def test_atan2_rhs_const(self, input_shape, dtype):
        """rhs is a broadcastable single-element tensor, lhs a full tensor.

        (Fix: the scalar/full-tensor inputs were previously swapped relative
        to the test name.)
        """

        class Atan2(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.atan2.default(lhs_val, rhs_val)

        inputs = [
            torch.randn(input_shape, dtype=dtype),
            torch.rand(1),
        ]

        self.run_test(
            Atan2(),
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), torch.float),
            ((1, 20), torch.float),
            ((2, 3, 4), torch.float),
            ((2, 3, 4, 5), torch.float),
        ]
    )
    def test_atan2_float(self, input_shape, dtype):
        """Same-shape float tensors on both sides."""

        class Atan2(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.atan2.default(lhs_val, rhs_val)

        inputs = [
            torch.randn(input_shape, dtype=dtype),
            torch.randn(input_shape, dtype=dtype),
        ]

        self.run_test(
            Atan2(),
            inputs,
        )

    @parameterized.expand(
        [
            ((50,), torch.int, -5, 5),
            ((1, 20), torch.int32, -5, 5),
            ((2, 3, 4), torch.int, -5, 5),
        ]
    )
    def test_atan2_int(self, input_shape, dtype, low, high):
        """Integer inputs — exercises the float cast in the converter impl."""

        class Atan2(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.atan2.default(lhs_val, rhs_val)

        inputs = [
            torch.randint(low, high, input_shape, dtype=dtype),
            torch.randint(low, high, input_shape, dtype=dtype),
        ]
        self.run_test(
            Atan2(),
            inputs,
        )

    @parameterized.expand(
        [
            (torch.float, 0.0, 0.0),
            (torch.float, 0.0, torch.rand(1)),
            (torch.float, torch.rand(1), 0.0),
            (torch.int, 0, 0),
            (torch.int, 0, torch.randint(-5, 5, (1,))),
            (torch.int, torch.randint(1, 10, (1,)), 0),
        ]
    )
    def test_atan2_zero(self, dtype, x_val, y_val):
        """On-axis cases: x and/or y exactly zero (the special-cased branches)."""

        class Atan2(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.atan2.default(lhs_val, rhs_val)

        # Parameterized values may arrive as 1-element tensors; unwrap to scalars.
        if isinstance(x_val, torch.Tensor):
            x_val = x_val.item()
        if isinstance(y_val, torch.Tensor):
            y_val = y_val.item()

        inputs = [
            torch.tensor([x_val], dtype=dtype),
            torch.tensor([y_val], dtype=dtype),
        ]

        self.run_test(
            Atan2(),
            inputs,
        )


# Allow running this test file directly, outside the pytest collection harness.
if __name__ == "__main__":
    run_tests()

0 comments on commit 821ff91

Please sign in to comment.