feat: support aten index_put converter for accumulate=False #2880
Conversation
```python
) -> TRTTensor:
    # Reshape indices to add an extra dimension if necessary (indices is a Tuple of ITensors)
    reshaped_indices = []
    for i, each_input in enumerate(indices):
```
Since `indices` can be an `ITensor` per the schema, you may not be able to iterate over it. In the test case, you can try changing line 173 to `inputs=[source_tensor, indices_tensor, value_tensor]`. It's kind of similar to the `offsets` in the annoying `embedding_bag`. You could think about how to use native TRT layers, like `ILoop`, to handle this.
Besides, what blocks you when `accumulate=True`?
Thank you very much for your review. When `indices` is a `torch.Tensor`, an error occurs in PyTorch, as shown in the example below. This situation is somewhat different from `embedding_bag`: here the input is a tuple of tensors, which we discussed earlier. As the example shows, `index_put_` throws an error when `indices` is of `torch.Tensor` type and only works correctly when `indices` is a `tuple` or `list`.
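For illustration, a minimal reproduction of that behavior (a sketch assuming standard PyTorch semantics, not the original snippet from this thread):

```python
import torch

tensor = torch.zeros(4, 4)
values = torch.tensor([1.0, 2.0])

# Passing indices as a tuple of index tensors works:
tensor.index_put_((torch.tensor([0, 1]), torch.tensor([2, 3])), values)

# Passing a single torch.Tensor as indices raises an error, because
# index_put_ expects a tuple/list of index tensors:
try:
    tensor.index_put_(torch.tensor([0, 1]), values)
except (TypeError, RuntimeError) as e:
    print(e)
```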
Therefore, `indices` can be iterated over with a for loop, and I did not iterate over `each_input` itself, since it is an `ITensor`. If I am mistaken, your comments would be very helpful.

One more question I have is about the type definition of `indices` when it is a tuple of tensors. Is it correct to define `indices` as `Union[TRTTensor, Tuple[TRTTensor, ...]]`?
When `accumulate=True`, if there are duplicate index pairs in `indices`, the corresponding `values` should be summed and the duplicates removed. Therefore, I aimed to obtain `indices` without duplicated pairs, together with correspondingly modified `values`, and then feed these into the scatter layer. However, I encountered difficulties in implementing the loop that checks for duplicate index pairs in `indices`.
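For context, a small sketch (assuming standard PyTorch semantics) of why duplicate index pairs matter for `accumulate`:

```python
import torch

indices = (torch.tensor([1, 1, 2]),)          # index 1 appears twice
values = torch.tensor([10.0, 20.0, 5.0])

# accumulate=True sums the duplicates: result[1] == 30.0
acc = torch.zeros(4).index_put_(indices, values, accumulate=True)

# accumulate=False effectively keeps one of the writes; which write wins
# for duplicate indices is not guaranteed to be deterministic.
non_acc = torch.zeros(4).index_put_(indices, values, accumulate=False)

print(acc)      # tensor([ 0., 30.,  5.,  0.])
print(non_acc)  # e.g. tensor([ 0., 20.,  5.,  0.])
```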
Thanks for the detailed explanations! Yes, you are right: `indices` should be a list or tuple, and thus it can be iterated over. Then your current implementation LGTM.

> One more question I have is about the type definition of indices when it is a tuple of tensors. Is it correct to define indices as Union[TRTTensor, Tuple[TRTTensor, ...]]?

I think it could be `Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]`, since a single TRTTensor cannot be iterated and per the schema, right?
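For concreteness, a sketch of how the converter signature could look with that annotation (the names and the local `TRTTensor` alias below are illustrative, not copied from the PR):

```python
from typing import Sequence, Union

import numpy as np
import tensorrt as trt
import torch

# In Torch-TensorRT, TRTTensor is essentially an alias for trt.ITensor;
# it is aliased locally here only to keep the sketch self-contained.
TRTTensor = trt.ITensor


def index_put_converter(
    input: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
    values: TRTTensor,
    accumulate: bool = False,
) -> TRTTensor:
    ...
```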
When you say `accumulate=True` is causing issues, I believe the duplicate indices are the problem. I faced the same in `scatter_reduce`, and I believe advanced indexing would be the way to deal with it (lengthy code that would be, I believe :( ). Do you have any other ideas?
I have written a validator to handle the `accumulate=True` case, and I have created a separate issue for implementing the converter for `accumulate=True`. It would be great to share ideas and work together on this.
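A rough sketch of what such a validator might look like (the argument layout and the registration hook mentioned below are assumptions based on the dynamo converter registry, not the exact code in this PR):

```python
from torch.fx.node import Node


def index_put_validator(node: Node) -> bool:
    # aten.index_put has args (self, indices, values, accumulate=False).
    # Reject accumulate=True so the op falls back to PyTorch instead of
    # being converted to TensorRT.
    accumulate = (
        node.args[3] if len(node.args) > 3 else node.kwargs.get("accumulate", False)
    )
    return not accumulate


# The validator would then be attached when registering the converter,
# e.g. through the decorator's capability-validator argument.
```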
The implementation looks good to me. Can you add a test case like this:

```python
tensor = torch.zeros([4, 4, 4, 4], dtype=torch.int32)
indices = (torch.tensor([0, 1, 2, 3]), torch.tensor([2, 3, 1, 0]))
values = torch.tensor([10, 20, 30, 40], dtype=torch.int32)
out = torch.index_put_(tensor, indices, values)
```
Let's write a validator for this case and resolve it in a new PR.
I have written a validator for broadcasting.
Description
I have implemented the `aten::index_put` operation using the `add_scatter` layer with `trt.ScatterMode.ND`. However, I was unable to implement the `accumulate=True` case, which is currently handled by the validator.

Fixes # (issue)
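For readers unfamiliar with the layer, a minimal standalone sketch of an ND-mode scatter built with the TensorRT Python API (simplified; not the converter code from this PR):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# data: (4, 4); indices: two (row, col) coordinate pairs; updates: (2,)
data = network.add_input("data", trt.float32, (4, 4))
indices = network.add_input("indices", trt.int32, (2, 2))
updates = network.add_input("updates", trt.float32, (2,))

# In ND mode, each row of `indices` addresses one element (or slice) of
# `data`, and the matching entry of `updates` overwrites it, i.e. the
# accumulate=False behavior of index_put.
scatter = network.add_scatter(data, indices, updates, trt.ScatterMode.ND)
network.mark_output(scatter.get_output(0))
```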
Type of change
Checklist: