-
Notifications
You must be signed in to change notification settings - Fork 360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support aten.roll dynamo converter #2569
Conversation
5bc8314
to
5be128d
Compare
5be128d
to
d58f865
Compare
from .harness import DispatchTestCase | ||
|
||
|
||
class TestRollConverter(DispatchTestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding a case where both shifts
and dims
are single integers, which is a supported case in the docstring. These may be casted to lists in the operator before the converter ever gets them, but it is still a valid input I believe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your review! I found the schema is:
- func: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
Does this mean shifts
and dims
should be a 1d list?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
roll(tensor, 2, 3)
--> roll(tensor, [2], [3])
pool(3)
--> pool([3, 3])
To share additional documentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gs-olive Thanks for the details!
Unfortunately, when testing shifts=2, dims=0
, I got error:
File "<eval_with_key>.0 from /home/zewenl/TensorRT/tests/py/dynamo/conversion/test_roll_aten.py:35 in forward", line 5, in forward
roll_default = torch.ops.aten.roll.default(x, 2, 0); x = None
File "/home/zewenl/.local/lib/python3.10/site-packages/torch/_ops.py", line 571, in __call__
return self_._op(*args, **(kwargs or {}))
RuntimeError: aten::roll() Expected a value of type 'List[int]' for argument 'shifts' but instead found type 'int'.
Position: 1
Value: 2
Declaration: aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
Python error details: TypeError: 'int' object is not iterable
Then, I also tested shifts=(2,), dims=0
, it works.
It seems that pytorch requires shifts
to be a list.
According to the schema - func: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
, I guess SymInt[1]
and int[1]
may have different behaviors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be so, yes, though that is strange - thanks for the update.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'm working on adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
converter, output_size
expects List[int]
as well:
RuntimeError: aten::adaptive_avg_pool2d() Expected a value of type 'List[int]' for argument 'output_size' but instead found type 'int'.
Position: 1
Value: 3
Declaration: aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
d58f865
to
6e19c81
Compare
6e19c81
to
a10ac64
Compare
Description
Support
aten.roll
dynamo converter.Fixes #2567
Type of change
Checklist: