Aten scatter converter (#2664)

apbose authored and laikhtewari committed May 24, 2024
1 parent 10698e2 commit 35b5d03
Showing 4 changed files with 274 additions and 10 deletions.
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -691,6 +691,26 @@ def aten_ops_clamp(
    )


@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
        2: (TRTTensor,),
    }
)
def aten_ops_scatter(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.select.scatter(
        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
    )


@dynamo_tensorrt_converter(torch.ops.aten.select.int)
def aten_ops_select(
    ctx: ConversionContext,
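For context on what the two registered overloads compute: scatter.src takes its scattered values from a tensor shaped like index, while scatter.value broadcasts a single scalar. A minimal eager-mode sketch of the semantics the converter must reproduce (standalone illustration, not part of this diff):

import torch

x = torch.zeros(3, 5, dtype=torch.int64)
index = torch.tensor([[0, 1, 2, 0]])

# scatter.src: with dim=0, out[index[i][j]][j] = src[i][j]
src = torch.tensor([[10, 20, 30, 40]])
print(torch.ops.aten.scatter.src(x, 0, index, src))

# scatter.value: every indexed position receives the same scalar
print(torch.ops.aten.scatter.value(x, 0, index, 7))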
38 changes: 38 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -390,3 +390,41 @@ def index_select(
    set_layer_name(gather_layer, target, f"{name}_gather", source_ir)

    return gather_layer.get_output(0)


def scatter(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: Union[TRTTensor, np.ndarray, torch.Tensor],
    src: Union[TRTTensor, int, float],
) -> TRTTensor:
    input_shape = input.shape
    index_shape = index.shape
    index_shape_list = list(index_shape)
    if index.dtype == trt.int64:
        index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
    dim = get_positive_dim(dim, len(input_shape))
    src_tensor = src
    # scatter.value: broadcast the scalar to the index shape, then cast to the input dtype
    if isinstance(src, (int, float)):
        src_tensor = get_trt_tensor(
            ctx, src * np.ones(index_shape_list), name + "_value_tensor"
        )
        src_tensor = cast_trt_tensor(
            ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
        )
    # scatter.src: promote numpy/torch tensors to TRT tensors; TRTTensors pass through
    elif not isinstance(src, TRTTensor):
        src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")

    scatter_layer = ctx.net.add_scatter(
        input, index, src_tensor, trt.ScatterMode.ELEMENT
    )
    scatter_layer.axis = dim
    set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
    return scatter_layer.get_output(0)
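Both overloads land on a single IScatterLayer in ELEMENT mode, which follows the same update rule as ONNX ScatterElements / torch.scatter. A rough NumPy reference for that rule, useful for checking the mapping (illustrative sketch only, not code from this commit):

import numpy as np

def scatter_elements(data, indices, updates, axis):
    # ELEMENT-mode rule: copy data, then for every position i in `indices`,
    # overwrite data at i with its `axis` coordinate replaced by indices[i]
    out = data.copy()
    for i in np.ndindex(indices.shape):
        target = list(i)
        target[axis] = indices[i]
        out[tuple(target)] = updates[i]
    return out

data = np.zeros((3, 5), dtype=np.int32)
indices = np.array([[0, 1, 2, 0]])
updates = np.array([[1, 2, 3, 4]], dtype=np.int32)
print(scatter_elements(data, indices, updates, axis=0))
# [[1 0 0 4 0]
#  [0 2 0 0 0]
#  [0 0 3 0 0]]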
49 changes: 39 additions & 10 deletions tests/py/dynamo/conversion/harness.py
@@ -1,5 +1,3 @@
-# type: ignore
-
 import logging
 import time
 import unittest
@@ -50,16 +48,20 @@ def setUp(self):
     def run_test(
         self,
         mod,
-        inputs,
+        fx_inputs,
+        trt_interpreter_inputs,
         interpreter,
         rtol,
         atol,
         check_dtype=True,
     ):
         with torch.no_grad():
-            cuda_inputs = []
-            for i in inputs:
-                cuda_inputs.append(i.cuda())
+            cuda_fx_inputs = []
+            cuda_trt_inputs = []
+            for i in trt_interpreter_inputs:
+                cuda_trt_inputs.append(i.cuda())
+            for i in fx_inputs:
+                cuda_fx_inputs.append(i.cuda())

             mod.eval()
             start = time.perf_counter()
@@ -73,13 +75,13 @@ def run_test(
             )

             mod = mod.cuda()
-            ref_outputs = mod(*cuda_inputs)
+            ref_outputs = mod(*cuda_fx_inputs)

             torch.cuda.synchronize()
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
             start_event.record()
-            outputs = trt_mod(*cuda_inputs)
+            outputs = trt_mod(*cuda_trt_inputs)
             end_event.record()
             torch.cuda.synchronize()
             _LOGGER.info(
@@ -220,6 +222,7 @@ def run_test(
         check_dtype=True,
         use_dynamo_tracer=False,
         enable_passes=False,
+        int32_reqd=False,
     ):
         mod.eval()
         mod = self.generate_graph(
@@ -237,6 +240,30 @@
             debug=True,
         )

+        num_inputs = len(inputs)
+        trt_inputs = inputs
+        dtype_to_change = []
+        if int32_reqd:
+            dtype_to_change = [torch.int64, torch.float64]
+        else:
+            dtype_to_change = [
+                torch.float64,
+            ]
+        for num_input in range(num_inputs):
+            input = inputs[num_input]
+            if input.dtype in dtype_to_change:
+                dtype_32bit = (
+                    torch.float32 if (input.dtype == torch.float64) else torch.int32
+                )
+                trt_inputs = (
+                    list(trt_inputs[:num_input])
+                    + [
+                        input.to(dtype_32bit),
+                    ]
+                    + list(trt_inputs[num_input + 1 :])
+                )
+
+        trt_input_specs = [Input.from_tensor(i) for i in trt_inputs]
+        input_specs = [Input.from_tensor(i) for i in inputs]
+
         output_dtypes = None
@@ -254,13 +281,15 @@

         interp = TRTInterpreter(
             mod,
-            input_specs,
+            trt_input_specs,
             output_dtypes=output_dtypes,
             compilation_settings=compilation_settings,
         )

         super().run_test(
             mod,
             inputs,
+            trt_inputs,
             interp,
             rtol,
             atol,
@@ -335,4 +364,4 @@ def run_test_with_dynamic_shape(
         # Since the lowering is based on the optimal shape, we need to test with
         # a different shape (e.g. the max shape) to exercise dynamic shapes
         inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
-        super().run_test(mod, inputs_max, interp, rtol, atol)
+        super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
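The int32_reqd plumbing above builds a narrowed copy of the inputs for the TRT interpreter while the eager reference run keeps the originals, presumably because the TensorRT path could not accept 64-bit inputs at the time. The same narrowing rule as a standalone sketch (narrow_for_trt is a hypothetical helper name, not part of the harness):

import torch

def narrow_for_trt(tensors, int32_reqd=False):
    # 64-bit dtypes the TRT interpreter inputs should not carry
    wide_to_32bit = {torch.float64: torch.float32}
    if int32_reqd:
        wide_to_32bit[torch.int64] = torch.int32
    return [
        t.to(wide_to_32bit[t.dtype]) if t.dtype in wide_to_32bit else t
        for t in tensors
    ]

inputs = [torch.zeros(3, 5, dtype=torch.int64), torch.ones(2, dtype=torch.float64)]
print([t.dtype for t in narrow_for_trt(inputs, int32_reqd=True)])
# [torch.int32, torch.float32]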
177 changes: 177 additions & 0 deletions tests/py/dynamo/conversion/test_scatter_aten.py
@@ -0,0 +1,177 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestScatterValueConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "scatter_zero_dim_indexOne_constant_value",
                0,
                torch.tensor([[0, 1, 2, 0]]),
                1,
            ),
            (
                "scatter_zero_dim_indexTwo_constant_value",
                0,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                1,
            ),
            (
                "scatter_one_dim_indexOne_constant_value",
                1,
                torch.tensor([[0, 1, 2, 0]]),
                1,
            ),
            (
                "scatter_one_dim_indexTwo_constant_value",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                1,
            ),
        ]
    )
    def test_scatter_index_constant(self, _, dim, index, value):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                return torch.ops.aten.scatter.value(input, dim, index, value)

        input = torch.zeros(3, 5, dtype=torch.int32)
        inputs = [input]
        self.run_test(TestModule(), inputs, int32_reqd=True)

    @parameterized.expand(
        [
            ("scatter_zero_dim_indexOne_value", 0, torch.tensor([[0, 1, 2, 0]]), 1),
            (
                "scatter_zero_dim_indexTwo_value",
                0,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                1,
            ),
            ("scatter_one_dim_indexOne_value", 1, torch.tensor([[0, 1, 2, 0]]), 1),
            (
                "scatter_one_dim_indexTwo_value",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                1,
            ),
        ]
    )
    def test_scatter_index_input(self, _, dim, index, value):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input, index):
                return torch.ops.aten.scatter.value(input, dim, index, value)

        input = torch.zeros(3, 5, dtype=torch.int32)
        inputs = [input, index]
        self.run_test(TestModule(), inputs, int32_reqd=True)


class TestScatterSrcConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "scatter_zero_dim_indexOne_src",
                0,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
            ),
            (
                "scatter_zero_dim_indexTwo_src",
                0,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexOne_src",
                1,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexTwo_src",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexOne_constant_src",
                1,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexTwo_constant_src",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
            ),
        ]
    )
    def test_scatter_index_constant(self, _, dim, index, src):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                return torch.ops.aten.scatter.src(input, dim, index, src)

        input = torch.zeros(3, 5, dtype=torch.int32)
        inputs = [input]
        self.run_test(TestModule(), inputs, int32_reqd=True)

    @parameterized.expand(
        [
            (
                "scatter_zero_dim_indexOne_constant_src",
                0,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
            ),
            (
                "scatter_zero_dim_indexTwo_constant_src",
                0,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexOne_constant_src",
                1,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexTwo_constant_src",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
            ),
        ]
    )
    def test_scatter_index_input(self, _, dim, index, src):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input, index):
                return torch.ops.aten.scatter.src(input, dim, index, src)

        input = torch.zeros(3, 5, dtype=torch.int32)
        inputs = [input, index]
        self.run_test(TestModule(), inputs, int32_reqd=True)


if __name__ == "__main__":
    run_tests()
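As a numeric sanity check, the first scatter.value case above corresponds to this eager computation, which is what the TRT engine output gets compared against:

import torch

x = torch.zeros(3, 5, dtype=torch.int32)
index = torch.tensor([[0, 1, 2, 0]])
print(torch.ops.aten.scatter.value(x, 0, index, 1))
# tensor([[1, 0, 0, 1, 0],
#         [0, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0]], dtype=torch.int32)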
