Skip to content

Commit

Permalink
feat: support aten.index_select converter (#2710)
Browse files Browse the repository at this point in the history
  • Loading branch information
chohk88 authored Apr 12, 2024
1 parent 821ff91 commit cec3835
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 6 deletions.
25 changes: 25 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2829,3 +2829,28 @@ def aten_ops_roll(
args[1],
args_bounds_check(args, 2, []),
)


@dynamo_tensorrt_converter(torch.ops.aten.index_select.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
2: (TRTTensor,),
}
)
def aten_ops_index_select(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.index_select(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
)
30 changes: 24 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def index(
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
# If any is not this flag will be set to False
_LOGGER.debug(
f"Determining whether aten.index constant-index optimization can be invoked"
"Determining whether aten.index constant-index optimization can be invoked"
)
is_numpy = all(
isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
Expand Down Expand Up @@ -123,7 +123,7 @@ def index(
return identity_layer.get_output(0)
elif len(tensor_indices) == 1:
indices_tensor = get_trt_tensor(
ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
)
index = adv_indx_indices[0]
_LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
Expand Down Expand Up @@ -204,7 +204,7 @@ def index(
cum_adv_index = cum_adv_index + adv_index
multiplier = multiplier * input_shape[adv_indx_indices[i]]
cum_adv_index = get_trt_tensor(
ctx, cum_adv_index, name + f"_index_sum_intermediate"
ctx, cum_adv_index, name + "_index_sum_intermediate"
)
else:
multiplier = get_trt_tensor(
Expand Down Expand Up @@ -263,7 +263,7 @@ def index(
adv_indx_count
== adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
):
_LOGGER.debug(f"The indices are continuous in this case")
_LOGGER.debug("The indices are continuous in this case")
concat_tensor_reshape.append(
get_trt_tensor(ctx, -1, name + "_dynamic_concat")
)
Expand All @@ -287,7 +287,7 @@ def index(
source_ir,
)
unfold_tensor = regular_index_shuffle_layer.get_output(0)
_LOGGER.debug(f"The tensor is unfolded now")
_LOGGER.debug("The tensor is unfolded now")
_LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")

# Transpose folded advanced indexed axis to its original location.
Expand Down Expand Up @@ -342,7 +342,7 @@ def index(
reshape_output = unfold_advanced_shuffle_layer.get_output(0)

else:
_LOGGER.debug(f"The indices are not continuous in this case")
_LOGGER.debug("The indices are not continuous in this case")
concat_final_tensor = []
concat_final_tensor.append(cum_adv_index_shape_tensor)
for i in range(0, rank):
Expand Down Expand Up @@ -370,3 +370,21 @@ def index(
reshape_output = reshape_layer.get_output(0)

return reshape_output


def index_select(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: int,
index: TRTTensor,
) -> TRTTensor:
# The axis parameter specifies the dimension along which to index.
dim = get_positive_dim(dim, len(input.shape))
gather_layer = ctx.net.add_gather(input, index, axis=dim)

set_layer_name(gather_layer, target, f"{name}_gather", source_ir)

return gather_layer.get_output(0)
41 changes: 41 additions & 0 deletions tests/py/dynamo/conversion/test_index_select_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestIndexSelectConverter(DispatchTestCase):
@parameterized.expand(
[
("1d_input", (10,), 0, (1,)),
("2d_input_dim_0", (10, 3), 0, (0, 2)),
("2d_input_dim_1", (5, 10), 1, (1, 2, 3)),
("2d_input_dim_-2", (5, 10), -2, (1, 2, 3)),
("3d_input_dim_0", (10, 5, 10), 0, (0, 5)),
("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)),
("3d_input_dim_-1", (10, 5, 10), -1, (3, 3, 4)),
("3d_input_dim_-3", (10, 5, 10), -3, (5, 3, 4)),
]
)
def test_index_select(self, _, source_shape, dim, indices_val):
class TestIndexSelect(torch.nn.Module):
def forward(self, source_tensor, indices_tensor):
return torch.ops.aten.index_select.default(
source_tensor, dim, indices_tensor
)

input = [
torch.randn(*source_shape, dtype=torch.float32),
torch.tensor([*indices_val], dtype=torch.int32),
]

self.run_test(
TestIndexSelect(),
input,
)


if __name__ == "__main__":
run_tests()

0 comments on commit cec3835

Please sign in to comment.