From 2158a4f7495fcf421fb54f4bec8207cc5cff22e8 Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 3 May 2021 09:09:53 +0800 Subject: [PATCH 1/2] fix numpy op fallback when ndarray in kwargs --- python/mxnet/numpy/multiarray.py | 37 +++++++++++++++----------- tests/python/unittest/test_numpy_op.py | 14 ++++++++++ 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index dd6504ef8fb7..6882091d76ac 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -198,24 +198,32 @@ def _as_mx_np_array(object, ctx=None, zero_copy=False): raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object)))) -def _as_onp_array(object): +def _as_onp_array(object, cur_ctx=None): """Convert object to mxnet.numpy.ndarray.""" - cur_ctx = None + def _update_ctx(cur_ctx, tmp_ctx): + if cur_ctx is None: + cur_ctx = tmp_ctx + elif tmp_ctx is not None and cur_ctx != tmp_ctx: + raise ValueError('Ambiguous to set the context for the output ndarray since' # pylint: disable=too-few-format-args + ' input ndarrays are allocated on different devices: {} and {}' + .format(str(cur_ctx, tmp_ctx))) + return cur_ctx + if isinstance(object, ndarray): return object.asnumpy(), object.ctx elif isinstance(object, (list, tuple)): tmp = [] for arr in object: - arr, tmp_ctx = _as_onp_array(arr) - # if isinstance(arr, (list, tuple)): - # raise TypeError('type {} not supported'.format(str(type(arr)))) + arr, tmp_ctx = _as_onp_array(arr, cur_ctx) tmp.append(arr) - if cur_ctx is None: - cur_ctx = tmp_ctx - elif tmp_ctx is not None and cur_ctx != tmp_ctx: - raise ValueError('Ambiguous to set the context for the output ndarray since' # pylint: disable=too-few-format-args - ' input ndarrays are allocated on different devices: {} and {}' - .format(str(cur_ctx, tmp_ctx))) + cur_ctx = _update_ctx(cur_ctx, tmp_ctx) + return object.__class__(tmp), cur_ctx + elif isinstance(object, dict): + tmp = dict() + for key, value in object.items(): + value, tmp_ctx = _as_onp_array(value, cur_ctx) + tmp[key] = value + cur_ctx = _update_ctx(cur_ctx, tmp_ctx) return object.__class__(tmp), cur_ctx else: return object, cur_ctx @@ -377,13 +385,12 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad- raise ValueError("Falling back to NumPy operator {} with autograd active is not supported." "Please consider moving the operator to the outside of the autograd scope.")\ .format(func) - new_args, cur_ctx = _as_onp_array(args) + cur_ctx = None + new_args, cur_ctx = _as_onp_array(args, cur_ctx) + new_kwargs, cur_ctx = _as_onp_array(kwargs, cur_ctx) if cur_ctx is None: raise ValueError('Unknown context for the input ndarrays. It is probably a bug. Please' ' create an issue on GitHub.') - new_kwargs = {} - for k, v in kwargs.items(): - new_kwargs[k] = v.asnumpy() if isinstance(v, ndarray) else v if func not in _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD: import logging logging.warning("np.%s is a fallback operator, " diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1e253f20923d..257ea4711b4f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -10349,3 +10349,17 @@ def test_broadcast_like_different_types(): z = mx.npx.broadcast_like(x, y, 1, 1) assert_almost_equal(z.asnumpy(), np.array([[0,0],[0,0]])) assert x.dtype == z.dtype + + +@use_np +def test_np_apply_along_axis_fallback(): + data = np.array([[1, 2, 3.], [4., 5., 6]]) + axis = 1 + func1d = lambda x: x.mean() + np_y = _np.apply_along_axis(func1d, 1, data.asnumpy()) + y1 = np.apply_along_axis(func1d, 1, data) + y2 = np.apply_along_axis(func1d, 1, arr=data) + assert_almost_equal(y1.asnumpy(), np_y) + assert y1.asnumpy().dtype == np_y.dtype + assert_almost_equal(y2.asnumpy(), np_y) + assert y2.asnumpy().dtype == np_y.dtype From 09717aef582a315ad8cdf2ce7839e04b0eb56341 Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 3 May 2021 15:46:16 +0800 Subject: [PATCH 2/2] update note --- python/mxnet/numpy/multiarray.py | 2 +- tests/python/unittest/test_numpy_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 6882091d76ac..b64c170cd7df 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -199,7 +199,7 @@ def _as_mx_np_array(object, ctx=None, zero_copy=False): def _as_onp_array(object, cur_ctx=None): - """Convert object to mxnet.numpy.ndarray.""" + """Convert object to numpy.ndarray.""" def _update_ctx(cur_ctx, tmp_ctx): if cur_ctx is None: cur_ctx = tmp_ctx diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 257ea4711b4f..ba8e3278330b 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -10353,7 +10353,7 @@ def test_broadcast_like_different_types(): @use_np def test_np_apply_along_axis_fallback(): - data = np.array([[1, 2, 3.], [4., 5., 6]]) + data = np.random.randint(-100, 100, (2, 3)) axis = 1 func1d = lambda x: x.mean() np_y = _np.apply_along_axis(func1d, 1, data.asnumpy())