Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] Add onnx export support for one_hot and random_uniform_like an…
Browse files Browse the repository at this point in the history
…d unit tests for one_hot. (#19952)

* Add onnx export support and unit tests for one_hot.

* Add onnx export function for random_uniform_like.

* Fix lint.

* Update tests and use correct dtype for on and off values.

Co-authored-by: Joe Evans <[email protected]>
Co-authored-by: Zhaoqi Zhu <[email protected]>
  • Loading branch information
3 people authored Mar 2, 2021
1 parent 9b8a5e6 commit 8493c33
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
41 changes: 41 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4035,6 +4035,47 @@ def convert_argsort(node, **kwargs):
return nodes


@mx_op.register('one_hot')
def convert_one_hot(node, **kwargs):
"""Map MXNet's one_hot operator attributes to onnx's OneHot operator
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

depth = int(attrs.get('depth'))
on_value = float(attrs.get('on_value', 1.))
off_value = float(attrs.get('off_value', 0.))
dtype = attrs.get('dtype', 'float32')

create_tensor([off_value, on_value], name+'_values', kwargs['initializer'], dtype=np.dtype(dtype))
create_tensor([depth], name+'_depth', kwargs['initializer'])
nodes = [
make_node('OneHot', [input_nodes[0], name+'_depth', name+'_values'], [name], name=name)
]

return nodes


@mx_op.register('_random_uniform_like')
def convert_random_uniform_like(node, **kwargs):
"""Map MXNet's random_uniform_like operator attributes to onnx's RandomUniformLike operator
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

low = float(attrs.get('low', 0.))
high = float(attrs.get('high', 1.))
dtype = attrs.get('dtype', 'float32')

nodes = [
make_node('RandomUniformLike', [input_nodes[0]], [name], name=name,
dtype=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)],
low=low, high=high)
]

return nodes


@mx_op.register('SequenceReverse')
def convert_sequence_reverse(node, **kwargs):
"""Map MXNet's SequenceReverse op
Expand Down
10 changes: 10 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,16 @@ def test_onnx_export_take_raise(tmp_path, dtype, axis):
op_export_test('take', M, [x, y], tmp_path)


# onnxruntime currently does not support int32
@pytest.mark.parametrize("dtype", ["float16", "float32", "int64"])
@pytest.mark.parametrize("depth", [1, 3, 5, 10])
@pytest.mark.parametrize("shape", [(1,1), (1,5), (5,5), (3,4,5)])
def test_onnx_export_one_hot(tmp_path, dtype, depth, shape):
M = def_model('one_hot', depth=depth, dtype=dtype)
x = mx.random.randint(0, 10, shape).astype('int64')
op_export_test('one_hot', M, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('params', [((6, 5, 4), [1, 2, 4, 5, 6]),
((7, 3, 5), [1, 7, 4]),
Expand Down

0 comments on commit 8493c33

Please sign in to comment.