Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support aten.index_select converter #2710

Merged
merged 4 commits into from
Apr 12, 2024
Merged

Conversation

chohk88
Copy link
Collaborator

@chohk88 chohk88 commented Mar 25, 2024

Description

New feature to support aten.index_select converter. I also add test case for different dimensions.

Fixes # (#2708)

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@chohk88 chohk88 requested review from zewenli98 and apbose March 25, 2024 11:19
@chohk88 chohk88 self-assigned this Mar 25, 2024
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Mar 25, 2024
@github-actions github-actions bot requested a review from gs-olive March 25, 2024 11:19
@chohk88 chohk88 linked an issue Mar 25, 2024 that may be closed by this pull request
index: TRTTensor,
) -> TRTTensor:
# The axis parameter specifies the dimension along which to index.
gather_layer = ctx.net.add_gather(input, index, axis=dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dim likely needs to be corrected using get_positive_dim to ensure the value is positive for add_gather

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have modified it. Thanks!

("2d_input_dim_0", (10, 3), 0, (0, 2)),
("2d_input_dim_1", (5, 10), 1, (1, 2, 3)),
("3d_input_dim_0", (10, 5, 10), 0, (0, 5)),
("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test case for a negative dim input

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added a test case for a negative dim input and verified a test case. Thank you!

kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.index.index_select(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the index_select function could be put into select.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved index_select inside select.py. Thank you!

Copy link
Collaborator

@zewenli98 zewenli98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@@ -12,6 +12,7 @@
elementwise,
embedding,
grid,
index,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can likely be removed - it seems to be causing a circular import error in CI

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! It seems I overlooked removing an unnecessary import.

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

@narendasan narendasan merged commit cec3835 into main Apr 12, 2024
16 of 21 checks passed
@narendasan narendasan deleted the aten_index_select_converter branch April 12, 2024 00:46
HolyWu added a commit to HolyWu/TensorRT that referenced this pull request Apr 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

aten.index_select
5 participants