diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 86f15d06a1..95ad6d5f36 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2829,3 +2829,28 @@ def aten_ops_roll(
         args[1],
         args_bounds_check(args, 2, []),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.index_select.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+        2: (TRTTensor,),
+    }
+)
+def aten_ops_index_select(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.index_select(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index db586be65f..470abb8f48 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -90,7 +90,7 @@ def index(
     # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
     # If any is not this flag will be set to False
     _LOGGER.debug(
-        f"Determining whether aten.index constant-index optimization can be invoked"
+        "Determining whether aten.index constant-index optimization can be invoked"
     )
     is_numpy = all(
         isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
         indices_tensor = get_trt_tensor(
-            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+            ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
         )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
@@ -204,7 +204,7 @@ def index(
                 cum_adv_index = cum_adv_index + adv_index
                 multiplier = multiplier * input_shape[adv_indx_indices[i]]
             cum_adv_index = get_trt_tensor(
-                ctx, cum_adv_index, name + f"_index_sum_intermediate"
+                ctx, cum_adv_index, name + "_index_sum_intermediate"
             )
         else:
             multiplier = get_trt_tensor(
@@ -263,7 +263,7 @@ def index(
                adv_indx_count
                == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
            ):
-                _LOGGER.debug(f"The indices are continuous in this case")
+                _LOGGER.debug("The indices are continuous in this case")
                concat_tensor_reshape.append(
                    get_trt_tensor(ctx, -1, name + "_dynamic_concat")
                )
@@ -287,7 +287,7 @@ def index(
                    source_ir,
                )
                unfold_tensor = regular_index_shuffle_layer.get_output(0)
-                _LOGGER.debug(f"The tensor is unfolded now")
+                _LOGGER.debug("The tensor is unfolded now")
                _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")

                # Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@ def index(
                reshape_output = unfold_advanced_shuffle_layer.get_output(0)

            else:
-                _LOGGER.debug(f"The indices are not continuous in this case")
+                _LOGGER.debug("The indices are not continuous in this case")
                concat_final_tensor = []
                concat_final_tensor.append(cum_adv_index_shape_tensor)
                for i in range(0, rank):
@@ -370,3 +370,21 @@ def index(
        reshape_output = reshape_layer.get_output(0)

    return reshape_output
+
+
+def index_select(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+    index: TRTTensor,
+) -> TRTTensor:
+    # The gather axis must be non-negative, so normalize a possibly negative dim first.
+    dim = get_positive_dim(dim, len(input.shape))
+    gather_layer = ctx.net.add_gather(input, index, axis=dim)
+
+    set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
+
+    return gather_layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py
new file mode 100644
index 0000000000..83eaedb944
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_index_select_aten.py
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestIndexSelectConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("1d_input", (10,), 0, (1,)),
+            ("2d_input_dim_0", (10, 3), 0, (0, 2)),
+            ("2d_input_dim_1", (5, 10), 1, (1, 2, 3)),
+            ("2d_input_dim_-2", (5, 10), -2, (1, 2, 3)),
+            ("3d_input_dim_0", (10, 5, 10), 0, (0, 5)),
+            ("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)),
+            ("3d_input_dim_-1", (10, 5, 10), -1, (3, 3, 4)),
+            ("3d_input_dim_-3", (10, 5, 10), -3, (5, 3, 4)),
+        ]
+    )
+    def test_index_select(self, _, source_shape, dim, indices_val):
+        class TestIndexSelect(torch.nn.Module):
+            def forward(self, source_tensor, indices_tensor):
+                return torch.ops.aten.index_select.default(
+                    source_tensor, dim, indices_tensor
+                )
+
+        input = [
+            torch.randn(*source_shape, dtype=torch.float32),
+            torch.tensor([*indices_val], dtype=torch.int32),
+        ]
+
+        self.run_test(
+            TestIndexSelect(),
+            input,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
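
For context, a minimal end-to-end sketch of the path this patch enables (the SelectRows module, shapes, and index values below are hypothetical; it assumes a working CUDA build of torch_tensorrt). Dynamo traces torch.index_select to torch.ops.aten.index_select.default, which the new converter lowers to a single TensorRT gather layer:

    import torch
    import torch_tensorrt

    class SelectRows(torch.nn.Module):
        def forward(self, x, idx):
            # Traced to torch.ops.aten.index_select.default and handled
            # by the aten_ops_index_select converter added above.
            return torch.index_select(x, 0, idx)

    x = torch.randn(8, 4).cuda()
    idx = torch.tensor([0, 3, 5], dtype=torch.int32).cuda()
    trt_mod = torch_tensorrt.compile(SelectRows(), ir="dynamo", inputs=[x, idx])
    assert torch.allclose(trt_mod(x, idx), torch.index_select(x, 0, idx))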