-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[v1.x] Add onnx export support for where and greater_scalar operators. #19745
Conversation
Hey @josephevans , Thanks for submitting the PR
CI supported jobs: [windows-cpu, edge, unix-cpu, unix-gpu, sanity, centos-cpu, clang, centos-gpu, windows-gpu, miscellaneous, website] Note: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -378,3 +378,27 @@ def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, params): | |||
x = mx.nd.arange(0, 160).reshape((2, 2, 5, 8)) | |||
M = def_model('contrib.BilinearResize2D', **params) | |||
op_export_test('contrib_BilinearResize2D', M, [x], tmp_path) | |||
|
|||
|
|||
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we also try float16
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update to work with float16 and also test it :) Thanks for the help.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Description
Add onnx export support and unit tests for the "where" and "greater_scalar" operators.