Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUGFIX] fix numpy op fallback bug when ndarray in kwargs #20233

Merged
merged 2 commits into from
May 3, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,32 @@ def _as_mx_np_array(object, ctx=None, zero_copy=False):
raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object))))


def _as_onp_array(object):
def _as_onp_array(object, cur_ctx=None):
wkcn marked this conversation as resolved.
Show resolved Hide resolved
"""Convert object to mxnet.numpy.ndarray."""
cur_ctx = None
def _update_ctx(cur_ctx, tmp_ctx):
if cur_ctx is None:
cur_ctx = tmp_ctx
elif tmp_ctx is not None and cur_ctx != tmp_ctx:
raise ValueError('Ambiguous to set the context for the output ndarray since' # pylint: disable=too-few-format-args
' input ndarrays are allocated on different devices: {} and {}'
.format(str(cur_ctx, tmp_ctx)))
return cur_ctx

if isinstance(object, ndarray):
return object.asnumpy(), object.ctx
elif isinstance(object, (list, tuple)):
tmp = []
for arr in object:
arr, tmp_ctx = _as_onp_array(arr)
# if isinstance(arr, (list, tuple)):
# raise TypeError('type {} not supported'.format(str(type(arr))))
arr, tmp_ctx = _as_onp_array(arr, cur_ctx)
tmp.append(arr)
if cur_ctx is None:
cur_ctx = tmp_ctx
elif tmp_ctx is not None and cur_ctx != tmp_ctx:
raise ValueError('Ambiguous to set the context for the output ndarray since' # pylint: disable=too-few-format-args
' input ndarrays are allocated on different devices: {} and {}'
.format(str(cur_ctx, tmp_ctx)))
cur_ctx = _update_ctx(cur_ctx, tmp_ctx)
return object.__class__(tmp), cur_ctx
elif isinstance(object, dict):
tmp = dict()
for key, value in object.items():
value, tmp_ctx = _as_onp_array(value, cur_ctx)
tmp[key] = value
cur_ctx = _update_ctx(cur_ctx, tmp_ctx)
return object.__class__(tmp), cur_ctx
else:
return object, cur_ctx
Expand Down Expand Up @@ -377,13 +385,12 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
raise ValueError("Falling back to NumPy operator {} with autograd active is not supported."
"Please consider moving the operator to the outside of the autograd scope.")\
.format(func)
new_args, cur_ctx = _as_onp_array(args)
cur_ctx = None
new_args, cur_ctx = _as_onp_array(args, cur_ctx)
new_kwargs, cur_ctx = _as_onp_array(kwargs, cur_ctx)
if cur_ctx is None:
raise ValueError('Unknown context for the input ndarrays. It is probably a bug. Please'
' create an issue on GitHub.')
new_kwargs = {}
for k, v in kwargs.items():
new_kwargs[k] = v.asnumpy() if isinstance(v, ndarray) else v
if func not in _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD:
import logging
logging.warning("np.%s is a fallback operator, "
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -10349,3 +10349,17 @@ def test_broadcast_like_different_types():
z = mx.npx.broadcast_like(x, y, 1, 1)
assert_almost_equal(z.asnumpy(), np.array([[0,0],[0,0]]))
assert x.dtype == z.dtype


@use_np
def test_np_apply_along_axis_fallback():
data = np.array([[1, 2, 3.], [4., 5., 6]])
wkcn marked this conversation as resolved.
Show resolved Hide resolved
axis = 1
func1d = lambda x: x.mean()
np_y = _np.apply_along_axis(func1d, 1, data.asnumpy())
y1 = np.apply_along_axis(func1d, 1, data)
y2 = np.apply_along_axis(func1d, 1, arr=data)
assert_almost_equal(y1.asnumpy(), np_y)
assert y1.asnumpy().dtype == np_y.dtype
assert_almost_equal(y2.asnumpy(), np_y)
assert y2.asnumpy().dtype == np_y.dtype