From 05551d0454eb6e89a444df9229edf84e70553ef2 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Thu, 16 Apr 2020 12:32:44 -0700 Subject: [PATCH] [Numpy] Add ffi for np.sum, np.std, np.var, np.average and np.histogram (#17866) * add ffi for sum, var and std * add ffi wrapper for np.average * add ffi wrapper for np.histogram --- benchmark/python/ffi/benchmark_ffi.py | 5 + include/mxnet/runtime/ffi_helper.h | 18 ++ include/mxnet/runtime/object.h | 1 + python/mxnet/_ffi/_cython/convert.pxi | 6 + python/mxnet/_ffi/node_generic.py | 2 + python/mxnet/_numpy_op_doc.py | 92 -------- python/mxnet/ndarray/numpy/_op.py | 114 +++++++++- python/mxnet/numpy/multiarray.py | 104 ++++++++- python/mxnet/symbol/numpy/_symbol.py | 51 ++++- python/mxnet/symbol/numpy/linalg.py | 8 +- src/api/_api_internal/_api_internal.cc | 10 + src/api/operator/numpy/np_bincount_op.cc | 4 +- .../numpy/np_broadcast_reduce_op_value.cc | 67 +++++- src/api/operator/numpy/np_cumsum.cc | 4 +- src/api/operator/numpy/np_histogram_op.cc | 81 +++++++ src/api/operator/numpy/np_moments_op.cc | 209 ++++++++++++++++++ src/api/operator/numpy/np_tensordot_op.cc | 4 +- src/api/operator/utils.h | 10 + src/operator/numpy/np_broadcast_reduce_op.h | 32 ++- .../numpy/np_broadcast_reduce_op_value.cc | 22 +- .../numpy/np_broadcast_reduce_op_value.cu | 4 +- src/operator/tensor/histogram-inl.h | 42 ++-- 22 files changed, 739 insertions(+), 151 deletions(-) create mode 100644 src/api/operator/numpy/np_histogram_op.cc create mode 100644 src/api/operator/numpy/np_moments_op.cc diff --git a/benchmark/python/ffi/benchmark_ffi.py b/benchmark/python/ffi/benchmark_ffi.py index bc55164d9663..1e911598bbb1 100644 --- a/benchmark/python/ffi/benchmark_ffi.py +++ b/benchmark/python/ffi/benchmark_ffi.py @@ -60,6 +60,11 @@ def prepare_workloads(): OpArgMngr.add_workload("tensordot", pool['2x2'], pool['2x2'], ((1, 0), (0, 1))) OpArgMngr.add_workload("kron", pool['2x2'], pool['2x2']) OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2']) + OpArgMngr.add_workload("sum", pool['2x2'], axis=0, keepdims=True, out=pool['1x2']) + OpArgMngr.add_workload("std", pool['2x2'], axis=0, ddof=0, keepdims=True, out=pool['1x2']) + OpArgMngr.add_workload("var", pool['2x2'], axis=0, ddof=1, keepdims=True, out=pool['1x2']) + OpArgMngr.add_workload("average", pool['2x2'], weights=pool['2'], axis=1, returned=True) + OpArgMngr.add_workload("histogram", pool['2x2'], bins=10, range=(0.0, 10.0)) OpArgMngr.add_workload("add", pool['2x2'], pool['2x2']) OpArgMngr.add_workload("linalg.eig", pool['3x3']) OpArgMngr.add_workload("linalg.eigh", pool['3x3']) diff --git a/include/mxnet/runtime/ffi_helper.h b/include/mxnet/runtime/ffi_helper.h index 49134ca122a7..cfc79a6c4f47 100644 --- a/include/mxnet/runtime/ffi_helper.h +++ b/include/mxnet/runtime/ffi_helper.h @@ -99,6 +99,24 @@ class Integer: public ObjectRef { MXNET_DEFINE_OBJECT_REF_METHODS(Integer, ObjectRef, IntegerObj) }; +class FloatObj: public Object { + public: + double value; + static constexpr const uint32_t _type_index = TypeIndex::kFloat; + static constexpr const char* _type_key = "MXNet.Float"; + MXNET_DECLARE_FINAL_OBJECT_INFO(FloatObj, Object) +}; + +class Float: public ObjectRef { + public: + explicit Float(double value, + ObjectPtr&& data = make_object()) { + data->value = value; + data_ = std::move(data); + } + MXNET_DEFINE_OBJECT_REF_METHODS(Float, ObjectRef, FloatObj) +}; + // Helper functions for fast FFI implementations /*! * \brief A builder class that helps to incrementally build ADT. diff --git a/include/mxnet/runtime/object.h b/include/mxnet/runtime/object.h index a031a56d88ed..48c9badb3ba7 100644 --- a/include/mxnet/runtime/object.h +++ b/include/mxnet/runtime/object.h @@ -58,6 +58,7 @@ enum TypeIndex { kEllipsis = 5, kSlice = 6, kInteger = 7, + kFloat = 8, kStaticIndexEnd, /*! \brief Type index is allocated during runtime. */ kDynamic = kStaticIndexEnd diff --git a/python/mxnet/_ffi/_cython/convert.pxi b/python/mxnet/_ffi/_cython/convert.pxi index 2cbdc48b49a8..d7b1ea5659dc 100644 --- a/python/mxnet/_ffi/_cython/convert.pxi +++ b/python/mxnet/_ffi/_cython/convert.pxi @@ -43,6 +43,10 @@ cdef extern from "mxnet/runtime/ffi_helper.h" namespace "mxnet::runtime": Integer() Integer(int64_t) + cdef cppclass Float(ObjectRef): + Float() + Float(double) + cdef inline ADT convert_tuple(tuple src_tuple) except *: cdef uint32_t size = len(src_tuple) @@ -71,5 +75,7 @@ cdef inline ObjectRef convert_object(object src_obj) except *: return convert_list(src_obj) elif isinstance(src_obj, Integral): return Integer(src_obj) + elif isinstance(src_obj, float): + return Float(src_obj) else: raise TypeError("Don't know how to convert type %s" % type(src_obj)) diff --git a/python/mxnet/_ffi/node_generic.py b/python/mxnet/_ffi/node_generic.py index c7f332390ce7..07b4825654d1 100644 --- a/python/mxnet/_ffi/node_generic.py +++ b/python/mxnet/_ffi/node_generic.py @@ -52,6 +52,8 @@ def convert_to_node(value): """ if isinstance(value, Integral): return _api_internal._Integer(value) + elif isinstance(value, float): + return _api_internal._Float(value) elif isinstance(value, (list, tuple)): value = [convert_to_node(x) for x in value] return _api_internal._ADT(*value) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 8341d43608ce..857b87a7586f 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -231,98 +231,6 @@ def _np_dot(a, b, out=None): pass -def _np_sum(a, axis=None, dtype=None, keepdims=False, initial=None, out=None): - r""" - Sum of array elements over a given axis. - - Parameters - ---------- - a : ndarray - Input data. - axis : None or int, optional - Axis or axes along which a sum is performed. The default, - axis=None, will sum all of the elements of the input array. If - axis is negative it counts from the last to the first axis. - dtype : dtype, optional - The type of the returned array and of the accumulator in which the - elements are summed. The default type is float32. - keepdims : bool, optional - If this is set to True, the axes which are reduced are left - in the result as dimensions with size one. With this option, - the result will broadcast correctly against the input array. - - If the default value is passed, then `keepdims` will not be - passed through to the `sum` method of sub-classes of - `ndarray`, however any non-default value will be. If the - sub-classes `sum` method does not implement `keepdims` any - exceptions will be raised. - initial: Currently only supports None as input, optional - Starting value for the sum. - Currently not implemented. Please use ``None`` as input or skip this argument. - out : ndarray or None, optional - Alternative output array in which to place the result. It must have - the same shape and dtype as the expected output. - - Returns - ------- - sum_along_axis : ndarray - An ndarray with the same shape as `a`, with the specified - axis removed. If an output array is specified, a reference to - `out` is returned. - - Notes - ----- - - Input type does not support Python native iterables. - - "out" param: cannot perform auto type change. out ndarray's dtype must be the same as the expected output. - - "initial" param is not supported yet. Please use None as input. - - Arithmetic is modular when using integer types, and no error is raised on overflow. - - The sum of an empty array is the neutral element 0: - - >>> a = np.empty(1) - >>> np.sum(a) - array(0.) - - This function differs from the original `numpy.sum - `_ in - the following aspects: - - - Input type does not support Python native iterables(list, tuple, ...). - - "out" param: cannot perform auto type cast. out ndarray's dtype must be the same as the expected output. - - "initial" param is not supported yet. Please use ``None`` as input or skip it. - - Examples - -------- - >>> a = np.array([0.5, 1.5]) - >>> np.sum(a) - array(2.) - >>> a = np.array([0.5, 0.7, 0.2, 1.5]) - >>> np.sum(a, dtype=np.int32) - array(2, dtype=int32) - >>> a = np.array([[0, 1], [0, 5]]) - >>> np.sum(a) - array(6.) - >>> np.sum(a, axis=0) - array([0., 6.]) - >>> np.sum(a, axis=1) - array([1., 5.]) - - With output ndarray: - - >>> a = np.array([[0, 1], [0, 5]]) - >>> b = np.ones((2,), dtype=np.float32) - >>> np.sum(a, axis = 0, out=b) - array([0., 6.]) - >>> b - array([0., 6.]) - - If the accumulator is too small, overflow occurs: - - >>> np.ones(128, dtype=np.int8).sum(dtype=np.int8) - array(-128, dtype=int8) - """ - pass - - def _np_copy(a, out=None): """ Return an array copy of the given object. diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index c1ce9092904a..3074cf2b3502 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -48,7 +48,7 @@ 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp', 'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'atleast_1d', 'atleast_2d', 'atleast_3d', - 'where', 'bincount', 'rollaxis', 'pad', 'cumsum', 'diag', 'diagonal'] + 'where', 'bincount', 'rollaxis', 'pad', 'cumsum', 'sum', 'diag', 'diagonal'] @set_module('mxnet.ndarray.numpy') @@ -1739,13 +1739,13 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): if isinstance(bins, numeric_types): if range is None: raise NotImplementedError("automatic range is not supported yet...") - return _npi.histogram(a, bin_cnt=bins, range=range) + return tuple(_api_internal.histogram(a, None, bins, range)) if isinstance(bins, (list, tuple)): raise NotImplementedError("array_like bins is not supported yet...") if isinstance(bins, str): raise NotImplementedError("string bins is not supported yet...") if isinstance(bins, NDArray): - return _npi.histogram(a, bins=bins) + return tuple(_api_internal.histogram(a, bins, None, None)) raise ValueError("np.histogram fails with", locals()) @@ -4859,10 +4859,7 @@ def average(a, axis=None, weights=None, returned=False, out=None): >>> np.average(data, axis=1, weights=weights) array([0.75, 2.75, 4.75]) """ - if weights is None: - return _npi.average(a, axis=axis, weights=None, returned=returned, weighted=False, out=out) - else: - return _npi.average(a, axis=axis, weights=weights, returned=returned, out=out) + return _api_internal.average(a, weights, axis, returned, weights is not None, out) @set_module('mxnet.ndarray.numpy') @@ -4987,7 +4984,7 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: >>> np.std(a, dtype=np.float64) array(0.45, dtype=float64) """ - return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + return _api_internal.std(a, axis, dtype, ddof, keepdims, out) @set_module('mxnet.ndarray.numpy') @@ -5057,7 +5054,7 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: >>> ((1-0.55)**2 + (0.1-0.55)**2)/2 0.2025 """ - return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + return _api_internal.var(a, axis, dtype, ddof, keepdims, out) # pylint: disable=redefined-outer-name @@ -6257,7 +6254,7 @@ def outer(a, b): [-2., -1., 0., 1., 2.], [-2., -1., 0., 1., 2.]]) """ - return tensordot(a.flatten(), b.flatten(), 0) + return tensordot(a.reshape_view((-1, )), b.reshape_view((-1, )), 0) @set_module('mxnet.ndarray.numpy') @@ -8427,3 +8424,100 @@ def diagonal(a, offset=0, axis1=0, axis2=1): [1, 7]]) """ return _api_internal.diagonal(a, offset, axis1, axis2) + + +# pylint:disable=redefined-outer-name, too-many-arguments +@set_module('mxnet.ndarray.numpy') +def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None): + r""" + Sum of array elements over a given axis. + + Parameters + ---------- + a : ndarray + Input data. + axis : None or int, optional + Axis or axes along which a sum is performed. The default, + axis=None, will sum all of the elements of the input array. If + axis is negative it counts from the last to the first axis. + dtype : dtype, optional + The type of the returned array and of the accumulator in which the + elements are summed. The default type is float32. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `sum` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-classes `sum` method does not implement `keepdims` any + exceptions will be raised. + initial: Currently only supports None as input, optional + Starting value for the sum. + Currently not implemented. Please use ``None`` as input or skip this argument. + out : ndarray or None, optional + Alternative output array in which to place the result. It must have + the same shape and dtype as the expected output. + + Returns + ------- + sum_along_axis : ndarray + An ndarray with the same shape as `a`, with the specified + axis removed. If an output array is specified, a reference to + `out` is returned. + + Notes + ----- + - Input type does not support Python native iterables. + - "out" param: cannot perform auto type change. out ndarray's dtype must be the same as the expected output. + - "initial" param is not supported yet. Please use None as input. + - Arithmetic is modular when using integer types, and no error is raised on overflow. + - The sum of an empty array is the neutral element 0: + + >>> a = np.empty(1) + >>> np.sum(a) + array(0.) + + This function differs from the original `numpy.sum + `_ in + the following aspects: + + - Input type does not support Python native iterables(list, tuple, ...). + - "out" param: cannot perform auto type cast. out ndarray's dtype must be the same as the expected output. + - "initial" param is not supported yet. Please use ``None`` as input or skip it. + + Examples + -------- + >>> a = np.array([0.5, 1.5]) + >>> np.sum(a) + array(2.) + >>> a = np.array([0.5, 0.7, 0.2, 1.5]) + >>> np.sum(a, dtype=np.int32) + array(2, dtype=int32) + >>> a = np.array([[0, 1], [0, 5]]) + >>> np.sum(a) + array(6.) + >>> np.sum(a, axis=0) + array([0., 6.]) + >>> np.sum(a, axis=1) + array([1., 5.]) + + With output ndarray: + + >>> a = np.array([[0, 1], [0, 5]]) + >>> b = np.ones((2,), dtype=np.float32) + >>> np.sum(a, axis=0, out=b) + array([0., 6.]) + >>> b + array([0., 6.]) + + If the accumulator is too small, overflow occurs: + + >>> np.ones(128, dtype=np.int8).sum(dtype=np.int8) + array(-128, dtype=int8) + """ + if where is not None and where is not True: + raise ValueError("only where=None or where=True cases are supported for now") + return _api_internal.sum(a, axis, dtype, keepdims, initial, out) +# pylint:enable=redefined-outer-name, too-many-arguments diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 57807f78e388..d8f7f4a29ce3 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -73,7 +73,7 @@ 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount', 'atleast_1d', 'atleast_2d', 'atleast_3d', - 'pad', 'cumsum', 'rollaxis', 'diag', 'diagonal'] + 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal'] __all__ += fallback.__all__ @@ -6843,7 +6843,7 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: >>> np.std(a, dtype=np.float64) array(0.45, dtype=float64) """ - return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + return _mx_nd_np.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) # pylint: enable=redefined-outer-name @@ -6964,7 +6964,7 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: >>> ((1-0.55)**2 + (0.1-0.55)**2)/2 0.2025 """ - return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + return _mx_nd_np.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) # pylint: disable=redefined-outer-name @@ -7127,6 +7127,7 @@ def ravel(x, order='C'): return _mx_nd_np.ravel(x, order) +@set_module('mxnet.numpy') def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer-name """ Converts a flat index or array of flat indices into a tuple of coordinate arrays. @@ -7157,6 +7158,7 @@ def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer- return _mx_nd_np.unravel_index(indices, shape, order=order) +@set_module('mxnet.numpy') def flatnonzero(a): r""" Return indices that are non-zero in the flattened version of a. @@ -7196,6 +7198,7 @@ def flatnonzero(a): return _mx_nd_np.flatnonzero(a) +@set_module('mxnet.numpy') def diag_indices_from(arr): """ This returns a tuple of indices that can be used to access the main diagonal of an array @@ -10548,3 +10551,98 @@ def diagonal(a, offset=0, axis1=0, axis2=1): [1, 7]]) """ return _mx_nd_np.diagonal(a, offset=offset, axis1=axis1, axis2=axis2) + + +# pylint: disable=redefined-outer-name, too-many-arguments +@set_module('mxnet.numpy') +def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None): + r""" + Sum of array elements over a given axis. + + Parameters + ---------- + a : ndarray + Input data. + axis : None or int, optional + Axis or axes along which a sum is performed. The default, + axis=None, will sum all of the elements of the input array. If + axis is negative it counts from the last to the first axis. + dtype : dtype, optional + The type of the returned array and of the accumulator in which the + elements are summed. The default type is float32. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `sum` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-classes `sum` method does not implement `keepdims` any + exceptions will be raised. + initial: Currently only supports None as input, optional + Starting value for the sum. + Currently not implemented. Please use ``None`` as input or skip this argument. + out : ndarray or None, optional + Alternative output array in which to place the result. It must have + the same shape and dtype as the expected output. + + Returns + ------- + sum_along_axis : ndarray + An ndarray with the same shape as `a`, with the specified + axis removed. If an output array is specified, a reference to + `out` is returned. + + Notes + ----- + - Input type does not support Python native iterables. + - "out" param: cannot perform auto type change. out ndarray's dtype must be the same as the expected output. + - "initial" param is not supported yet. Please use None as input. + - Arithmetic is modular when using integer types, and no error is raised on overflow. + - The sum of an empty array is the neutral element 0: + + >>> a = np.empty(1) + >>> np.sum(a) + array(0.) + + This function differs from the original `numpy.sum + `_ in + the following aspects: + + - Input type does not support Python native iterables(list, tuple, ...). + - "out" param: cannot perform auto type cast. out ndarray's dtype must be the same as the expected output. + - "initial" param is not supported yet. Please use ``None`` as input or skip it. + + Examples + -------- + >>> a = np.array([0.5, 1.5]) + >>> np.sum(a) + array(2.) + >>> a = np.array([0.5, 0.7, 0.2, 1.5]) + >>> np.sum(a, dtype=np.int32) + array(2, dtype=int32) + >>> a = np.array([[0, 1], [0, 5]]) + >>> np.sum(a) + array(6.) + >>> np.sum(a, axis=0) + array([0., 6.]) + >>> np.sum(a, axis=1) + array([1., 5.]) + + With output ndarray: + + >>> a = np.array([[0, 1], [0, 5]]) + >>> b = np.ones((2,), dtype=np.float32) + >>> np.sum(a, axis = 0, out=b) + array([0., 6.]) + >>> b + array([0., 6.]) + + If the accumulator is too small, overflow occurs: + + >>> np.ones(128, dtype=np.int8).sum(dtype=np.int8) + array(-128, dtype=int8) + """ + return _mx_nd_np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) +# pylint: enable=redefined-outer-name, too-many-arguments diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index aae0ed21efc1..bf1d31329c4e 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -54,7 +54,7 @@ 'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'atleast_1d', 'atleast_2d', 'atleast_3d', - 'where', 'bincount', 'rollaxis', 'pad', 'cumsum', 'diag', 'diagonal'] + 'where', 'bincount', 'rollaxis', 'pad', 'cumsum', 'sum', 'diag', 'diagonal'] @set_module('mxnet.symbol.numpy') @@ -650,7 +650,7 @@ def diag(self, k=0, **kwargs): def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ """Return the sum of the array elements over the given axis.""" - return _mx_np_op.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims) + return _npi.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims) def nansum(self, *args, **kwargs): """Convenience fluent method for :py:func:`nansum`. @@ -7295,4 +7295,51 @@ def diagonal(a, offset=0, axis1=0, axis2=1): return _npi.diagonal(a, offset=offset, axis1=axis1, axis2=axis2) +# pylint:disable=redefined-outer-name, too-many-arguments +@set_module('mxnet.symbol.numpy') +def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None): + r""" + Sum of array elements over a given axis. + + Parameters + ---------- + a : _Symbol + Input data. + axis : None or int, optional + Axis or axes along which a sum is performed. The default, + axis=None, will sum all of the elements of the input array. If + axis is negative it counts from the last to the first axis. + dtype : dtype, optional + The type of the returned array and of the accumulator in which the + elements are summed. The default type is float32. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `sum` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-classes `sum` method does not implement `keepdims` any + exceptions will be raised. + initial: Currently only supports None as input, optional + Starting value for the sum. + Currently not implemented. Please use ``None`` as input or skip this argument. + out : ndarray or None, optional + Alternative output array in which to place the result. It must have + the same shape and dtype as the expected output. + + Returns + ------- + sum_along_axis : _Symbol + An ndarray with the same shape as `a`, with the specified + axis removed. If an output array is specified, a reference to + `out` is returned. + """ + if where is not None and where is not True: + raise ValueError("only where=None or where=True cases are supported for now") + return _npi.sum(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, out=out) +# pylint:enable=redefined-outer-name, too-many-arguments + + _set_np_symbol_class(_Symbol) diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index 3cea6ddae157..1fbac50b630a 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -324,18 +324,18 @@ def norm(x, ord=None, axis=None, keepdims=False): if row_axis > col_axis: row_axis -= 1 if ord == 'inf': - return _mx_sym_np.sum(_symbol.abs(x), axis=col_axis, keepdims=keepdims).max(axis=row_axis, keepdims=keepdims) # pylint: disable=line-too-long + return _npi.sum(_symbol.abs(x), axis=col_axis, keepdims=keepdims).max(axis=row_axis, keepdims=keepdims) # pylint: disable=line-too-long else: - return _mx_sym_np.sum(_symbol.abs(x), axis=col_axis, keepdims=keepdims).min(axis=row_axis, keepdims=keepdims) # pylint: disable=line-too-long + return _npi.sum(_symbol.abs(x), axis=col_axis, keepdims=keepdims).min(axis=row_axis, keepdims=keepdims) # pylint: disable=line-too-long if ord in [1, -1]: row_axis, col_axis = axis if not keepdims: if row_axis < col_axis: col_axis -= 1 if ord == 1: - return _mx_sym_np.sum(_symbol.abs(x), axis=row_axis, keepdims=keepdims).max(axis=col_axis, keepdims=keepdims) # pylint: disable=line-too-long + return _npi.sum(_symbol.abs(x), axis=row_axis, keepdims=keepdims).max(axis=col_axis, keepdims=keepdims) # pylint: disable=line-too-long elif ord == -1: - return _mx_sym_np.sum(_symbol.abs(x), axis=row_axis, keepdims=keepdims).min(axis=col_axis, keepdims=keepdims) # pylint: disable=line-too-long + return _npi.sum(_symbol.abs(x), axis=row_axis, keepdims=keepdims).min(axis=col_axis, keepdims=keepdims) # pylint: disable=line-too-long if ord in [2, -2]: return _npi.norm(x, ord=ord, axis=axis, keepdims=keepdims, flag=0) if ord is None: diff --git a/src/api/_api_internal/_api_internal.cc b/src/api/_api_internal/_api_internal.cc index 586dce82f383..7e1ce045f353 100644 --- a/src/api/_api_internal/_api_internal.cc +++ b/src/api/_api_internal/_api_internal.cc @@ -43,6 +43,16 @@ MXNET_REGISTER_GLOBAL("_Integer") } }); +MXNET_REGISTER_GLOBAL("_Float") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + if (args[0].type_code() == kDLFloat) { + *ret = Float(args[0].operator double()); + } else { + LOG(FATAL) << "only accept float"; + } +}); + MXNET_REGISTER_GLOBAL("_ADT") .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; diff --git a/src/api/operator/numpy/np_bincount_op.cc b/src/api/operator/numpy/np_bincount_op.cc index afa3278c24e4..7be884aefb1a 100644 --- a/src/api/operator/numpy/np_bincount_op.cc +++ b/src/api/operator/numpy/np_bincount_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/api/operator/numpy/np_broadcast_reduce_op_value.cc b/src/api/operator/numpy/np_broadcast_reduce_op_value.cc index c2d87a285cde..4cd2e485d987 100644 --- a/src/api/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/api/operator/numpy/np_broadcast_reduce_op_value.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,13 +18,14 @@ */ /*! - * \file broadcast_reduce_op_value.cc + * \file np_broadcast_reduce_op_value.cc * \brief Implementation of the API of functions in * src/operator/tensor/np_broadcast_reduce_op_value.cc */ #include #include #include "../utils.h" +#include "../../../operator/tensor/broadcast_reduce_op.h" #include "../../../operator/numpy/np_broadcast_reduce_op.h" namespace mxnet { @@ -51,6 +52,65 @@ MXNET_REGISTER_API("_npi.broadcast_to") *ret = ndoutputs[0]; }); +MXNET_REGISTER_API("_npi.sum") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_sum"); + op::NumpyReduceAxesParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + + // parse axis + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + if (args[1].type_code() == kDLInt) { + param.axis = Tuple(1, args[1].operator int64_t()); + } else { + param.axis = Tuple(args[1].operator ObjectRef()); + } + } + + // parse dtype + if (args[2].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[2].operator std::string()); + } + + // parse keepdims + if (args[3].type_code() == kNull) { + param.keepdims = false; + } else { + param.keepdims = args[3].operator bool(); + } + + // parse initial + if (args[4].type_code() == kNull) { + param.initial = dmlc::nullopt; + } else { + param.initial = args[4].operator double(); + } + + attrs.parsed = std::move(param); + + SetAttrDict(&attrs); + + NDArray* inputs[] = {args[0].operator NDArray*()}; + int num_inputs = 1; + + NDArray* outputs[] = {args[5].operator NDArray*()}; + NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; + int num_outputs = (outputs[0] != nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +}); + MXNET_REGISTER_API("_npi.mean") .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; @@ -67,6 +127,7 @@ MXNET_REGISTER_API("_npi.mean") } else { param.dtype = String2MXNetTypeWithBool(args[2].operator std::string()); } + if (args[3].type_code() == kNull) { param.keepdims = false; } else { diff --git a/src/api/operator/numpy/np_cumsum.cc b/src/api/operator/numpy/np_cumsum.cc index 0ef3b3fdf7bf..d0b200c66fd4 100644 --- a/src/api/operator/numpy/np_cumsum.cc +++ b/src/api/operator/numpy/np_cumsum.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/api/operator/numpy/np_histogram_op.cc b/src/api/operator/numpy/np_histogram_op.cc new file mode 100644 index 000000000000..b517cce80803 --- /dev/null +++ b/src/api/operator/numpy/np_histogram_op.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_histogram_op.cc + * \brief Implementation of the API of functions in src/operator/tensor/histogram.cc + */ + +#include +#include +#include "../utils.h" +#include "../../../operator/tensor/histogram-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.histogram") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npi_histogram"); + op::HistogramParam param; + // parse bin_cnt + if (args[2].type_code() == kNull) { + param.bin_cnt = dmlc::nullopt; + } else { + param.bin_cnt = args[2].operator int(); + } + + // parse range + if (args[3].type_code() == kNull) { + param.range = dmlc::nullopt; + } else { + param.range = Obj2Tuple(args[3].operator ObjectRef()); + } + + attrs.parsed = std::move(param); + attrs.op = op; + SetAttrDict(&attrs); + + std::vector inputs_vec; + int num_inputs = 0; + + if (args[2].type_code() != kNull) { + CHECK_EQ(args[1].type_code(), kNull) + << "bins should be None when bin_cnt is provided"; + inputs_vec.push_back((args[0].operator NDArray*())); + num_inputs = 1; + } else { + CHECK_NE(args[1].type_code(), kNull) + << "bins should not be None when bin_cnt is not provided"; + // inputs + inputs_vec.push_back((args[0].operator NDArray*())); + inputs_vec.push_back((args[1].operator NDArray*())); + num_inputs = 2; + } + + // outputs + NDArray** out = nullptr; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs_vec.data(), &num_outputs, out); + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), + NDArrayHandle(ndoutputs[1])}); +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/np_moments_op.cc b/src/api/operator/numpy/np_moments_op.cc new file mode 100644 index 000000000000..e4e9238bb6c1 --- /dev/null +++ b/src/api/operator/numpy/np_moments_op.cc @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_moments_op.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_moments_op.cc + */ + +#include +#include +#include "../utils.h" +#include "../../../operator/numpy/np_broadcast_reduce_op.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.std") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_std"); + op::NumpyMomentsParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + + // parse axis + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + if (args[1].type_code() == kDLInt) { + param.axis = Tuple(1, args[1].operator int64_t()); + } else { + param.axis = Tuple(args[1].operator ObjectRef()); + } + } + + // parse dtype + if (args[2].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[2].operator std::string()); + } + + // parse ddof + param.ddof = args[3].operator int(); + + // parse keepdims + if (args[4].type_code() == kNull) { + param.keepdims = false; + } else { + param.keepdims = args[4].operator bool(); + } + + attrs.parsed = std::move(param); + + SetAttrDict(&attrs); + + NDArray* inputs[] = {args[0].operator NDArray*()}; + int num_inputs = 1; + + NDArray* outputs[] = {args[5].operator NDArray*()}; + NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; + int num_outputs = (outputs[0] != nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +}); + +MXNET_REGISTER_API("_npi.var") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_var"); + op::NumpyMomentsParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + + // parse axis + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + if (args[1].type_code() == kDLInt) { + param.axis = Tuple(1, args[1].operator int64_t()); + } else { + param.axis = Tuple(args[1].operator ObjectRef()); + } + } + + // parse dtype + if (args[2].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[2].operator std::string()); + } + + // parse ddof + param.ddof = args[3].operator int(); + + // parse keepdims + if (args[4].type_code() == kNull) { + param.keepdims = false; + } else { + param.keepdims = args[4].operator bool(); + } + + attrs.parsed = std::move(param); + + SetAttrDict(&attrs); + + NDArray* inputs[] = {args[0].operator NDArray*()}; + int num_inputs = 1; + + NDArray* outputs[] = {args[5].operator NDArray*()}; + NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; + int num_outputs = (outputs[0] != nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +}); + +MXNET_REGISTER_API("_npi.average") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_average"); + op::NumpyWeightedAverageParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + + // parse axis + if (args[2].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + if (args[2].type_code() == kDLInt) { + param.axis = Tuple(1, args[2].operator int64_t()); + } else { + param.axis = Tuple(args[2].operator ObjectRef()); + } + } + + // parse returned + CHECK_NE(args[3].type_code(), kNull) + << "returned cannot be None"; + param.returned = args[3].operator bool(); + + // parse weighted + CHECK_NE(args[4].type_code(), kNull) + << "weighted cannot be None"; + param.weighted = args[4].operator bool(); + + attrs.parsed = std::move(param); + + SetAttrDict(&attrs); + + int num_inputs = param.weighted ? 2 : 1; + NDArray* outputs[] = {args[5].operator NDArray*()}; + NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; + int num_outputs = (outputs[0] != nullptr); + + if (param.weighted) { + NDArray* inputs[] = {args[0].operator NDArray*(), args[1].operator NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + if (out) { + *ret = PythonArg(5); + } else { + if (param.returned) { + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), + NDArrayHandle(ndoutputs[1])}); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + } + } else { + NDArray* inputs[] = {args[0].operator NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + if (out) { + *ret = PythonArg(5); + } else { + if (param.returned) { + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), + NDArrayHandle(ndoutputs[1])}); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + } + } +}); + +}; // namespace mxnet diff --git a/src/api/operator/numpy/np_tensordot_op.cc b/src/api/operator/numpy/np_tensordot_op.cc index eef58b5b3389..55c131468b12 100644 --- a/src/api/operator/numpy/np_tensordot_op.cc +++ b/src/api/operator/numpy/np_tensordot_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/api/operator/utils.h b/src/api/operator/utils.h index 53e62ee7635b..8943e8058a19 100644 --- a/src/api/operator/utils.h +++ b/src/api/operator/utils.h @@ -56,6 +56,16 @@ void SetAttrDict(nnvm::NodeAttrs* attrs) { } } +template +Tuple Obj2Tuple(const runtime::ObjectRef& src) { + runtime::ADT adt = Downcast(src); + Tuple ret(adt.size(), 0); + for (size_t i = 0; i < adt.size(); ++i) { + ret[i] = Downcast(adt[i])->value; + } + return ret; +} + } // namespace mxnet #endif // MXNET_API_OPERATOR_UTILS_H_ diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index 33cee78ebf80..dbef0cd96bec 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -67,6 +67,7 @@ struct NumpyReduceAxesParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional()) .describe("Starting value for the sum."); } + void SetAttrDict(std::unordered_map* dict) { std::ostringstream axis_s, dtype_s, keepdims_s, initial_s; axis_s << axis; @@ -447,6 +448,7 @@ inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs, } BroadcastComputeImpl(attrs, ctx, inputs, req, outputs, small); + if (normalize) { Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, { @@ -498,11 +500,27 @@ struct NumpyMomentsParam : public dmlc::Parameter { "precision than the default platform integer. In that case, if a is signed then " "the platform integer is used while if a is unsigned then an unsigned integer of " "the same precision as the platform integer is used."); - DMLC_DECLARE_FIELD(ddof).set_default(0) - .describe("Starting value for the sum."); DMLC_DECLARE_FIELD(keepdims).set_default(false) .describe("If this is set to `True`, the reduced axes are left " "in the result as dimension with size one."); + DMLC_DECLARE_FIELD(ddof).set_default(0) + .describe("Starting value for the sum."); + } + + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream axis_s, dtype_s, keepdims_s, ddof_s; + axis_s << axis; + keepdims_s << keepdims; + ddof_s << ddof; + (*dict)["axis"] = axis_s.str(); + dtype_s << dtype; + if (dtype.has_value()) { + (*dict)["dtype"] = MXNetTypeWithBool2String(dtype.value()); + } else { + (*dict)["dtype"] = dtype_s.str(); + } + (*dict)["keepdims"] = keepdims_s.str(); + (*dict)["ddof"] = ddof_s.str(); } }; @@ -558,6 +576,16 @@ struct NumpyWeightedAverageParam : public dmlc::Parameter* dict) { + std::ostringstream axis_s, returned_s, weighted_s; + axis_s << axis; + returned_s << returned; + weighted_s << weighted; + (*dict)["axis"] = axis_s.str(); + (*dict)["returned"] = returned_s.str(); + (*dict)["weighted"] = weighted_s.str(); + } }; inline bool NumpyWeightedAverageShape(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index 026e60e8bb25..33418667dfb7 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -51,12 +51,12 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs, if (param.dtype.has_value()) { if (in_attrs->at(0) == mshadow::kBool) { - CHECK(param.dtype.value() == mshadow::kInt32 - || param.dtype.value() == mshadow::kInt64 - || param.dtype.value() == mshadow::kFloat32 - || param.dtype.value() == mshadow::kFloat64) << "Only support the following output " - "dtypes when input dtype is bool: " - "int32, int64, float32, float64."; + CHECK(param.dtype.value() == mshadow::kInt32 || + param.dtype.value() == mshadow::kInt64 || + param.dtype.value() == mshadow::kFloat32 || + param.dtype.value() == mshadow::kFloat64) + << "Only support the following output dtypes when input dtype is bool: " + "int32, int64, float32, float64."; } TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value()); } else if (in_attrs->at(0) == mshadow::kBool) { @@ -126,7 +126,7 @@ void TVMOpReduce(const OpContext& ctx, #endif // MXNET_USE_TVM_OP } -NNVM_REGISTER_OP(_np_sum) +NNVM_REGISTER_OP(_npi_sum) .describe(R"code()code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -145,9 +145,9 @@ NNVM_REGISTER_OP(_np_sum) return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("THasDeterministicOutput", true) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_sum"}); +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_sum"}); -NNVM_REGISTER_OP(_backward_np_sum) +NNVM_REGISTER_OP(_backward_npi_sum) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) @@ -155,8 +155,8 @@ NNVM_REGISTER_OP(_backward_np_sum) .set_attr("FCompute", NumpyReduceAxesBackwardUseNone); inline bool NumpyReduceAxesNoDTypeType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector *in_attrs, + std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu index 684348fcaa37..c5111c2954cd 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu @@ -26,10 +26,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_np_sum) +NNVM_REGISTER_OP(_npi_sum) .set_attr("FCompute", NumpyReduceAxesCompute); -NNVM_REGISTER_OP(_backward_np_sum) +NNVM_REGISTER_OP(_backward_npi_sum) .set_attr("FCompute", NumpyReduceAxesBackwardUseNone); NNVM_REGISTER_OP(_np_max) diff --git a/src/operator/tensor/histogram-inl.h b/src/operator/tensor/histogram-inl.h index 7194445d7b52..29b27c6d659d 100644 --- a/src/operator/tensor/histogram-inl.h +++ b/src/operator/tensor/histogram-inl.h @@ -34,6 +34,8 @@ #include #include #include +#include +#include #include #include "./util/tensor_util-inl.h" #include "../elemwise_op_common.h" @@ -45,22 +47,30 @@ namespace mxnet { namespace op { struct HistogramParam : public dmlc::Parameter { - dmlc::optional bin_cnt; - dmlc::optional> range; - DMLC_DECLARE_PARAMETER(HistogramParam) { - DMLC_DECLARE_FIELD(bin_cnt) - .set_default(dmlc::optional()) - .describe("Number of bins for uniform case"); - DMLC_DECLARE_FIELD(range) - .set_default(dmlc::optional>()) - .describe("The lower and upper range of the bins. if not provided, " - "range is simply (a.min(), a.max()). values outside the " - "range are ignored. the first element of the range must be " - "less than or equal to the second. range affects the automatic " - "bin computation as well. while bin width is computed to be " - "optimal based on the actual data within range, the bin count " - "will fill the entire range including portions containing no data."); - } + dmlc::optional bin_cnt; + dmlc::optional> range; + DMLC_DECLARE_PARAMETER(HistogramParam) { + DMLC_DECLARE_FIELD(bin_cnt) + .set_default(dmlc::optional()) + .describe("Number of bins for uniform case"); + DMLC_DECLARE_FIELD(range) + .set_default(dmlc::optional>()) + .describe("The lower and upper range of the bins. if not provided, " + "range is simply (a.min(), a.max()). values outside the " + "range are ignored. the first element of the range must be " + "less than or equal to the second. range affects the automatic " + "bin computation as well. while bin width is computed to be " + "optimal based on the actual data within range, the bin count " + "will fill the entire range including portions containing no data."); + } + + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream bin_cnt_s, range_s; + bin_cnt_s << bin_cnt; + range_s << range; + (*dict)["bin_cnt"] = bin_cnt_s.str(); + (*dict)["range"] = range_s.str(); + } }; struct FillBinBoundsKernel {