From 938b35b0d39751518e29c0b4ee911f70d02cd3ad Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 4 Mar 2020 22:06:46 -0800 Subject: [PATCH] [Numpy] FFI for cumsum and add (#17747) * FFI cumsum * Dispatch ufunc * Add PythonArg * Remove unused data type * Separate op_utils and utils --- include/mxnet/runtime/c_runtime_api.h | 13 +-- include/mxnet/runtime/packed_func.h | 36 +----- include/mxnet/runtime/py_arg.h | 42 +++++++ python/mxnet/_ffi/_ctypes/function.py | 7 +- python/mxnet/_ffi/_ctypes/types.py | 16 ++- python/mxnet/_ffi/_cython/base.pxi | 13 +-- python/mxnet/_ffi/_cython/function.pxi | 18 ++- python/mxnet/_numpy_op_doc.py | 51 -------- python/mxnet/ndarray/numpy/_op.py | 58 ++++++++- python/mxnet/numpy/multiarray.py | 57 ++++++++- python/mxnet/symbol/numpy/_symbol.py | 39 ++++++- src/api/operator/numpy/np_cumsum.cc | 67 +++++++++++ .../numpy/np_elemwise_broadcast_op.cc | 39 +++++++ src/api/operator/numpy/np_init_op.cc | 5 +- src/api/operator/numpy/np_tensordot_op.cc | 11 +- src/api/operator/op_utils.cc | 55 +++++++++ src/api/operator/op_utils.h | 35 ++++++ src/api/operator/ufunc_helper.cc | 110 ++++++++++++++++++ src/api/operator/ufunc_helper.h | 36 ++++++ src/api/operator/utils.cc | 22 ++++ src/api/operator/utils.h | 21 +--- src/operator/numpy/np_cumsum-inl.h | 13 +++ src/operator/numpy/np_cumsum.cc | 7 +- src/operator/numpy/np_cumsum.cu | 4 +- 24 files changed, 629 insertions(+), 146 deletions(-) create mode 100644 include/mxnet/runtime/py_arg.h create mode 100644 src/api/operator/numpy/np_cumsum.cc create mode 100644 src/api/operator/numpy/np_elemwise_broadcast_op.cc create mode 100644 src/api/operator/op_utils.cc create mode 100644 src/api/operator/op_utils.h create mode 100644 src/api/operator/ufunc_helper.cc create mode 100644 src/api/operator/ufunc_helper.h diff --git a/include/mxnet/runtime/c_runtime_api.h b/include/mxnet/runtime/c_runtime_api.h index 208a64326ac4..bbc8862d5439 100644 --- a/include/mxnet/runtime/c_runtime_api.h +++ 
b/include/mxnet/runtime/c_runtime_api.h @@ -47,14 +47,11 @@ typedef enum { kNull = 4U, kMXNetType = 5U, kMXNetContext = 6U, - kArrayHandle = 7U, - kObjectHandle = 8U, - kModuleHandle = 9U, - kFuncHandle = 10U, - kStr = 11U, - kBytes = 12U, - kNDArrayContainer = 13U, - kNDArrayHandle = 14U, + kObjectHandle = 7U, + kStr = 8U, + kBytes = 9U, + kPyArg = 10U, + kNDArrayHandle = 11U, // Extension codes for other frameworks to integrate MXNet PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. diff --git a/include/mxnet/runtime/packed_func.h b/include/mxnet/runtime/packed_func.h index 16351a7604dc..ac7b462ce471 100644 --- a/include/mxnet/runtime/packed_func.h +++ b/include/mxnet/runtime/packed_func.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -416,7 +417,6 @@ class MXNetPODValue_ { } operator void*() const { if (type_code_ == kNull) return nullptr; - if (type_code_ == kArrayHandle) return value_.v_handle; MXNET_CHECK_TYPE_CODE(type_code_, kHandle); return value_.v_handle; } @@ -520,11 +520,6 @@ class MXNetArgValue : public MXNetPODValue_ { MXNET_CHECK_TYPE_CODE(type_code_, kNDArrayHandle); return reinterpret_cast<::mxnet::NDArray*>(value_.v_handle); } - operator PackedFunc() const { - if (type_code_ == kNull) return PackedFunc(); - MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle); - return *ptr(); - } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -597,11 +592,6 @@ class MXNetRetValue : public MXNetPODValue_ { operator MXNetDataType() const { return MXNetDataType(operator DLDataType()); } - operator PackedFunc() const { - if (type_code_ == kNull) return PackedFunc(); - MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle); - return *ptr(); - } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -668,10 +658,6 @@ class MXNetRetValue : public MXNetPODValue_ { SwitchToObject(kObjectHandle, 
std::move(other)); return *this; } - MXNetRetValue& operator=(PackedFunc f) { - this->SwitchToClass(kFuncHandle, f); - return *this; - } template MXNetRetValue& operator=(const TypedPackedFunc& f) { return operator=(f.packed()); @@ -689,6 +675,11 @@ class MXNetRetValue : public MXNetPODValue_ { value_.v_handle = reinterpret_cast(value); return *this; } + MXNetRetValue& operator=(const PythonArg& value) { + this->SwitchToPOD(kPyArg); + value_.v_int64 = value.offset(); + return *this; + } template::code != 0>::type> @@ -717,7 +708,6 @@ class MXNetRetValue : public MXNetPODValue_ { /*! \return The value field, if the data is POD */ const MXNetValue& value() const { CHECK(type_code_ != kObjectHandle && - type_code_ != kFuncHandle && type_code_ != kStr) << "MXNetRetValue.value can only be used for POD data"; return value_; } @@ -741,10 +731,6 @@ class MXNetRetValue : public MXNetPODValue_ { SwitchToClass(kBytes, other); break; } - case kFuncHandle: { - SwitchToClass(kFuncHandle, other); - break; - } case kObjectHandle: { *this = other.operator ObjectRef(); break; @@ -792,7 +778,6 @@ class MXNetRetValue : public MXNetPODValue_ { if (type_code_ == kNull) return; switch (type_code_) { case kStr: delete ptr(); break; - case kFuncHandle: delete ptr(); break; case kObjectHandle: { static_cast(value_.v_handle)->DecRef(); break; @@ -857,7 +842,6 @@ inline const char* TypeCode2Str(int type_code) { case kBytes: return "bytes"; case kHandle: return "handle"; case kNull: return "NULL"; - case kFuncHandle: return "FunctionHandle"; case kObjectHandle: return "ObjectCell"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; @@ -1012,10 +996,6 @@ class MXNetArgsSetter { values_[i].v_handle = value; type_codes_[i] = kHandle; } - void operator()(size_t i, DLTensor* value) const { - values_[i].v_handle = value; - type_codes_[i] = kArrayHandle; - } void operator()(size_t i, const char* value) const { values_[i].v_str = value; type_codes_[i] = kStr; @@ -1038,10 
+1018,6 @@ class MXNetArgsSetter { values_[i].v_handle = const_cast(&value); type_codes_[i] = kBytes; } - void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*) - values_[i].v_handle = const_cast(&value); - type_codes_[i] = kFuncHandle; - } template void operator()(size_t i, const TypedPackedFunc& value) const { // NOLINT(*) operator()(i, value.packed()); diff --git a/include/mxnet/runtime/py_arg.h b/include/mxnet/runtime/py_arg.h new file mode 100644 index 000000000000..81d1b30a573e --- /dev/null +++ b/include/mxnet/runtime/py_arg.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file py_arg.h + * \brief Python runtime arguments specifier. 
+ */ +#ifndef MXNET_RUNTIME_PY_ARG_H_ +#define MXNET_RUNTIME_PY_ARG_H_ + +namespace mxnet { +namespace runtime { + +class PythonArg { + public: + explicit PythonArg(int offset): offset_(offset) {} + int offset() const { + return offset_; + } + private: + int offset_; +}; + +} // namespace runtime + +} // namespace mxnet +#endif // MXNET_RUNTIME_PY_ARG_H_ diff --git a/python/mxnet/_ffi/_ctypes/function.py b/python/mxnet/_ffi/_ctypes/function.py index 5b126913b998..0a005dd7b749 100644 --- a/python/mxnet/_ffi/_ctypes/function.py +++ b/python/mxnet/_ffi/_ctypes/function.py @@ -22,6 +22,7 @@ """ import ctypes from numbers import Number, Integral +import numpy as onp from ...base import get_last_ffi_error, _LIB from ..base import c_str @@ -66,6 +67,9 @@ def _make_mxnet_args(args, temp_args): elif isinstance(arg, ctypes.c_void_p): values[i].v_handle = arg type_codes[i] = TypeCode.HANDLE + elif isinstance(arg, type): + values[i].v_str = c_str(onp.dtype(arg).name) + type_codes[i] = TypeCode.STR else: raise TypeError("Don't know how to handle type %s" % type(arg)) return values, type_codes, num_args @@ -110,7 +114,8 @@ def __call__(self, *args): raise get_last_ffi_error() _ = temp_args _ = args - return RETURN_SWITCH[ret_tcode.value](ret_val) + return (RETURN_SWITCH[ret_tcode.value](ret_val) if ret_tcode.value != TypeCode.PYARG + else RETURN_SWITCH[ret_tcode.value](ret_val, args)) _CLASS_OBJECT = None diff --git a/python/mxnet/_ffi/_ctypes/types.py b/python/mxnet/_ffi/_ctypes/types.py index 265408e5ba93..d1b253af2da3 100644 --- a/python/mxnet/_ffi/_ctypes/types.py +++ b/python/mxnet/_ffi/_ctypes/types.py @@ -32,14 +32,11 @@ class TypeCode(object): NULL = 4 MXNET_TYPE = 5 MXNET_CONTEXT = 6 - ARRAY_HANDLE = 7 - OBJECT_HANDLE = 8 - MODULE_HANDLE = 9 - FUNC_HANDLE = 10 - STR = 11 - BYTES = 12 - NDARRAY_CONTAINER = 13 - NDARRAYHANDLE = 14 + OBJECT_HANDLE = 7 + STR = 8 + BYTES = 9 + PYARG = 10 + NDARRAYHANDLE = 11 EXT_BEGIN = 15 @@ -54,5 +51,6 @@ class MXNetValue(ctypes.Union): 
TypeCode.INT: lambda x: x.v_int64, TypeCode.FLOAT: lambda x: x.v_float64, TypeCode.NULL: lambda x: None, - TypeCode.NDARRAYHANDLE: lambda x: _global_var._np_ndarray_cls(handle=NDArrayHandle(x.v_handle)) + TypeCode.NDARRAYHANDLE: lambda x: _global_var._np_ndarray_cls(handle=NDArrayHandle(x.v_handle)), + TypeCode.PYARG: lambda x, args: args[x.v_int64], } diff --git a/python/mxnet/_ffi/_cython/base.pxi b/python/mxnet/_ffi/_cython/base.pxi index 1c393e8e241c..bc2273bacd0d 100644 --- a/python/mxnet/_ffi/_cython/base.pxi +++ b/python/mxnet/_ffi/_cython/base.pxi @@ -32,14 +32,11 @@ cdef enum MXNetTypeCode: kNull = 4 kMXNetType = 5 kMXNetContext = 6 - kArrayHandle = 7 - kObjectHandle = 8 - kModuleHandle = 9 - kFuncHandle = 10 - kStr = 11 - kBytes = 12 - kNDArrayContainer = 13 - kNDArrayHandle = 14 + kObjectHandle = 7 + kStr = 8 + kBytes = 9 + kPyArg = 10 + kNDArrayHandle = 11 kExtBegin = 15 cdef extern from "mxnet/runtime/c_runtime_api.h": diff --git a/python/mxnet/_ffi/_cython/function.pxi b/python/mxnet/_ffi/_cython/function.pxi index 2683868cba03..d4c629a618d5 100644 --- a/python/mxnet/_ffi/_cython/function.pxi +++ b/python/mxnet/_ffi/_cython/function.pxi @@ -18,6 +18,7 @@ """Acknowledgement: This file originates from incubator-tvm""" import ctypes +import numpy as onp import traceback from ...ndarray._internal import NDArrayBase from numbers import Number, Integral @@ -58,14 +59,23 @@ cdef inline int make_arg(object arg, elif isinstance(arg, ctypes.c_void_p): value[0].v_handle = c_handle(arg) tcode[0] = kHandle + elif isinstance(arg, type): + tstr = c_str(onp.dtype(arg).name) + value[0].v_str = tstr + tcode[0] = kStr + temp_args.append(tstr) else: raise TypeError("Don't know how to handle type %s" % type(arg)) return 0 -cdef inline object make_ret(MXNetValue value, int tcode): +cdef inline object make_ret(MXNetValue value, int tcode, tuple args): """convert result to return value.""" - if tcode == kNull: + if tcode == kNDArrayHandle: + return 
c_make_array(value.v_handle) + elif tcode == kPyArg: + return args[value.v_int64] + elif tcode == kNull: return None elif tcode == kInt: return value.v_int64 @@ -75,8 +85,6 @@ cdef inline object make_ret(MXNetValue value, int tcode): return py_str(value.v_str) elif tcode == kHandle: return ctypes_handle(value.v_handle) - elif tcode == kNDArrayHandle: - return c_make_array(value.v_handle) raise ValueError("Unhandled type code %d" % tcode) @@ -160,4 +168,4 @@ cdef class FunctionBase: cdef MXNetValue ret_val cdef int ret_tcode FuncCall(self.chandle, args, &ret_val, &ret_tcode) - return make_ret(ret_val, ret_tcode) + return make_ret(ret_val, ret_tcode, args) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 271bb1827b97..279501d385f8 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -134,57 +134,6 @@ def _np_sometrue(a, axis=None, keepdims=False, out=None): pass -def _np_cumsum(a, axis=None, dtype=None, out=None): - """ - Return the cumulative sum of the elements along a given axis. - - Parameters - ---------- - a : array_like - Input array. - axis : int, optional - Axis along which the cumulative sum is computed. The default - (None) is to compute the cumsum over the flattened array. - dtype : dtype, optional - Type of the returned array and of the accumulator in which the - elements are summed. If `dtype` is not specified, it defaults - to the dtype of `a`, unless `a` has an integer dtype with a - precision less than that of the default platform integer. In - that case, the default platform integer is used. - out : ndarray, optional - Alternative output array in which to place the result. It must - have the same shape and buffer length as the expected output - but the type will be cast if necessary. See `doc.ufuncs` - (Section "Output arguments") for more details. - - Returns - ------- - cumsum_along_axis : ndarray. 
- A new array holding the result is returned unless `out` is - specified, in which case a reference to `out` is returned. The - result has the same size as `a`, and the same shape as `a` if - `axis` is not None or `a` is a 1-d array. - - Examples - -------- - >>> a = np.array([[1,2,3], [4,5,6]]) - >>> a - array([[1, 2, 3], - [4, 5, 6]]) - >>> np.cumsum(a) - array([ 1, 3, 6, 10, 15, 21]) - >>> np.cumsum(a, dtype=float) # specifies type of output value(s) - array([ 1., 3., 6., 10., 15., 21.]) - >>> np.cumsum(a,axis=0) # sum over rows for each of the 3 columns - array([[1, 2, 3], - [5, 7, 9]]) - >>> np.cumsum(a,axis=1) # sum over columns for each of the 2 rows - array([[ 1, 3, 6], - [ 4, 9, 15]]) - """ - pass - - def _npx_nonzero(a): """ Return the indices of the elements that are non-zero. diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 82b57fb8cc1f..2d73699afff0 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -46,7 +46,7 @@ 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', - 'where', 'bincount', 'pad'] + 'where', 'bincount', 'pad', 'cumsum'] @set_module('mxnet.ndarray.numpy') @@ -989,7 +989,9 @@ def add(x1, x2, out=None, **kwargs): * If only one of the inputs is floating number type, the result is that type. * If both inputs are of integer types (including boolean), not supported yet. 
""" - return _ufunc_helper(x1, x2, _npi.add, _np.add, _npi.add_scalar, None, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + _np.add(x1, x2, out=out) + return _api_internal.add(x1, x2, out) @set_module('mxnet.ndarray.numpy') @@ -7637,3 +7639,55 @@ def pad(x, pad_width, mode='constant', **kwargs): # pylint: disable=too-many-arg raise ValueError("unsupported stat_length '{}'".format(values)) return _npi.pad(x, pad_width, mode='minimum') return _npi.pad(x, pad_width, mode='constant', constant_value=0) + + +@set_module('mxnet.ndarray.numpy') +def cumsum(a, axis=None, dtype=None, out=None): + """ + Return the cumulative sum of the elements along a given axis. + + Parameters + ---------- + a : array_like + Input array. + axis : int, optional + Axis along which the cumulative sum is computed. The default + (None) is to compute the cumsum over the flattened array. + dtype : dtype, optional + Type of the returned array and of the accumulator in which the + elements are summed. If `dtype` is not specified, it defaults + to the dtype of `a`, unless `a` has an integer dtype with a + precision less than that of the default platform integer. In + that case, the default platform integer is used. + out : ndarray, optional + Alternative output array in which to place the result. It must + have the same shape and buffer length as the expected output + but the type will be cast if necessary. See `doc.ufuncs` + (Section "Output arguments") for more details. + + Returns + ------- + cumsum_along_axis : ndarray. + A new array holding the result is returned unless `out` is + specified, in which case a reference to `out` is returned. The + result has the same size as `a`, and the same shape as `a` if + `axis` is not None or `a` is a 1-d array. 
+ + Examples + -------- + >>> a = np.array([[1,2,3], [4,5,6]]) + >>> a + array([[1, 2, 3], + [4, 5, 6]]) + >>> np.cumsum(a) + array([ 1, 3, 6, 10, 15, 21]) + >>> np.cumsum(a, dtype=float) # specifies type of output value(s) + array([ 1., 3., 6., 10., 15., 21.]) + >>> np.cumsum(a,axis=0) # sum over rows for each of the 3 columns + array([[1, 2, 3], + [5, 7, 9]]) + >>> np.cumsum(a,axis=1) # sum over columns for each of the 2 rows + array([[ 1, 3, 6], + [ 4, 9, 15]]) + """ + return _api_internal.cumsum(a, axis, dtype, out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index b6a68dbcbdd8..eefef48f55ce 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -66,7 +66,8 @@ 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul', - 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount', 'pad'] + 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount', + 'pad', 'cumsum'] __all__ += fallback.__all__ @@ -1808,7 +1809,7 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): def cumsum(self, axis=None, dtype=None, out=None): """Return the cumulative sum of the elements along the given axis.""" - return _mx_np_op.cumsum(self, axis=axis, dtype=dtype, out=out) + return _mx_nd_np.cumsum(self, axis=axis, dtype=dtype, out=out) def tolist(self): return self.asnumpy().tolist() @@ -9698,3 +9699,55 @@ def pad(x, pad_width=None, mode="constant", **kwargs): # pylint: disable=too-man [10, 10, 10, 10, 10, 10, 10]]) """ return _mx_nd_np.pad(x, pad_width, mode, **kwargs) + + +@set_module('mxnet.numpy') +def cumsum(a, axis=None, dtype=None, out=None): + """ + Return the cumulative sum 
of the elements along a given axis. + + Parameters + ---------- + a : array_like + Input array. + axis : int, optional + Axis along which the cumulative sum is computed. The default + (None) is to compute the cumsum over the flattened array. + dtype : dtype, optional + Type of the returned array and of the accumulator in which the + elements are summed. If `dtype` is not specified, it defaults + to the dtype of `a`, unless `a` has an integer dtype with a + precision less than that of the default platform integer. In + that case, the default platform integer is used. + out : ndarray, optional + Alternative output array in which to place the result. It must + have the same shape and buffer length as the expected output + but the type will be cast if necessary. See `doc.ufuncs` + (Section "Output arguments") for more details. + + Returns + ------- + cumsum_along_axis : ndarray. + A new array holding the result is returned unless `out` is + specified, in which case a reference to `out` is returned. The + result has the same size as `a`, and the same shape as `a` if + `axis` is not None or `a` is a 1-d array. 
+ + Examples + -------- + >>> a = np.array([[1,2,3], [4,5,6]]) + >>> a + array([[1, 2, 3], + [4, 5, 6]]) + >>> np.cumsum(a) + array([ 1, 3, 6, 10, 15, 21]) + >>> np.cumsum(a, dtype=float) # specifies type of output value(s) + array([ 1., 3., 6., 10., 15., 21.]) + >>> np.cumsum(a,axis=0) # sum over rows for each of the 3 columns + array([[1, 2, 3], + [5, 7, 9]]) + >>> np.cumsum(a,axis=1) # sum over columns for each of the 2 rows + array([[ 1, 3, 6], + [ 4, 9, 15]]) + """ + return _mx_nd_np.cumsum(a, axis=axis, dtype=dtype, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 8756b6a78ac9..bf3e50ba1388 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -51,7 +51,7 @@ 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', - 'where', 'bincount', 'pad'] + 'where', 'bincount', 'pad', 'cumsum'] @set_module('mxnet.symbol.numpy') @@ -683,7 +683,7 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylin def cumsum(self, axis=None, dtype=None, out=None): """Return the cumulative sum of the elements along the given axis.""" - return _mx_np_op.cumsum(self, axis=axis, dtype=dtype, out=out) + return _npi.cumsum(self, axis=axis, dtype=dtype, out=out) def max(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ """Return the maximum along a given axis.""" @@ -6732,4 +6732,39 @@ def pad(x, pad_width, mode='constant', **kwargs): # pylint: disable=too-many-arg return _npi.pad(x, pad_width, mode='constant', constant_value=0) +@set_module('mxnet.symbol.numpy') +def cumsum(a, axis=None, dtype=None, out=None): + """ + Return the cumulative sum of the elements along a given axis. 
+ + Parameters + ---------- + a : _Symbol + Input array. + axis : int, optional + Axis along which the cumulative sum is computed. The default + (None) is to compute the cumsum over the flattened array. + dtype : dtype, optional + Type of the returned array and of the accumulator in which the + elements are summed. If `dtype` is not specified, it defaults + to the dtype of `a`, unless `a` has an integer dtype with a + precision less than that of the default platform integer. In + that case, the default platform integer is used. + out : _Symbol, optional + Alternative output array in which to place the result. It must + have the same shape and buffer length as the expected output + but the type will be cast if necessary. See `doc.ufuncs` + (Section "Output arguments") for more details. + + Returns + ------- + cumsum_along_axis : _Symbol. + A new array holding the result is returned unless `out` is + specified, in which case a reference to `out` is returned. The + result has the same size as `a`, and the same shape as `a` if + `axis` is not None or `a` is a 1-d array. + """ + return _npi.cumsum(a, axis=axis, dtype=dtype, out=out) + + _set_np_symbol_class(_Symbol) diff --git a/src/api/operator/numpy/np_cumsum.cc b/src/api/operator/numpy/np_cumsum.cc new file mode 100644 index 000000000000..0ef3b3fdf7bf --- /dev/null +++ b/src/api/operator/numpy/np_cumsum.cc @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_cumsum.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_cumsum.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/numpy/np_cumsum-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.cumsum") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npi_cumsum"); + op::CumsumParam param; + // axis + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[1].operator int(); + } + // dtype + if (args[2].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[2].operator std::string()); + } + attrs.parsed = std::move(param); + attrs.op = op; + SetAttrDict(&attrs); + // inputs + NDArray* inputs[] = {args[0].operator NDArray*()}; + int num_inputs = 1; + // outputs + NDArray* outputs[] = {args[3].operator NDArray*()}; + NDArray** out = outputs[0] == nullptr ? 
nullptr : outputs; + int num_outputs = outputs[0] != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + if (out) { + *ret = PythonArg(3); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/np_elemwise_broadcast_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_op.cc new file mode 100644 index 000000000000..e724a7c58bd3 --- /dev/null +++ b/src/api/operator/numpy/np_elemwise_broadcast_op.cc @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file np_elemwise_broadcast_op.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_elemwise_broadcast_op.cc + */ +#include +#include +#include "../utils.h" +#include "../ufunc_helper.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.add") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_add"); + const nnvm::Op* op_scalar = Op::Get("_npi_add_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/np_init_op.cc b/src/api/operator/numpy/np_init_op.cc index 746985c6e9f3..c65f90c841f4 100644 --- a/src/api/operator/numpy/np_init_op.cc +++ b/src/api/operator/numpy/np_init_op.cc @@ -21,6 +21,8 @@ * \file np_init_op.cc * \brief Implementation of the API of functions in src/operator/numpy/np_init_op.cc */ +#include +#include #include "../utils.h" #include "../../../operator/tensor/init_op.h" @@ -44,11 +46,12 @@ MXNET_REGISTER_API("_npi.zeros") } attrs.parsed = std::move(param); attrs.op = op; + SetAttrDict(&attrs); if (args[2].type_code() != kNull) { attrs.dict["ctx"] = args[2].operator std::string(); } int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); *ret = ndoutputs[0]; }); diff --git a/src/api/operator/numpy/np_tensordot_op.cc b/src/api/operator/numpy/np_tensordot_op.cc index ade2a0314d01..b163757f85b1 100644 --- a/src/api/operator/numpy/np_tensordot_op.cc +++ b/src/api/operator/numpy/np_tensordot_op.cc @@ -21,6 +21,7 @@ * \file np_tensordot_op.cc * \brief Implementation of the API of functions in src/operator/numpy/np_tensordot_op.cc */ +#include #include "../utils.h" #include "../../../operator/numpy/np_tensordot_op-inl.h" @@ -32,13 +33,14 @@ inline static void _npi_tensordot_int_axes(runtime::MXNetArgs args, const nnvm::Op* op = Op::Get("_npi_tensordot_int_axes"); 
op::TensordotIntAxesParam param; nnvm::NodeAttrs attrs; - attrs.op = op; param.axes = args[2].operator int(); + attrs.op = op; // we directly copy TensordotIntAxesParam, which is trivially-copyable attrs.parsed = param; + SetAttrDict(&attrs); int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); *ret = reinterpret_cast(ndoutputs[0]); } @@ -48,7 +50,6 @@ inline static void _npi_tensordot(runtime::MXNetArgs args, const nnvm::Op* op = Op::Get("_npi_tensordot"); op::TensordotParam param; nnvm::NodeAttrs attrs; - attrs.op = op; ADT adt = Downcast(args[2].operator ObjectRef()); if (const IntegerObj* lop = adt[0].as()) { param.a_axes_summed = Tuple(1, lop->value); @@ -57,10 +58,12 @@ inline static void _npi_tensordot(runtime::MXNetArgs args, param.a_axes_summed = Tuple(adt[0]); param.b_axes_summed = Tuple(adt[1]); } + attrs.op = op; attrs.parsed = std::move(param); + SetAttrDict(&attrs); int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); *ret = reinterpret_cast(ndoutputs[0]); } diff --git a/src/api/operator/op_utils.cc b/src/api/operator/op_utils.cc new file mode 100644 index 000000000000..220a880336db --- /dev/null +++ b/src/api/operator/op_utils.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file op_utils.cc + * \brief Utility functions for modification in src/operator + */ + +#include "op_utils.h" +#include + +namespace mxnet { + +std::string String2MXNetTypeWithBool(int dtype) { + switch (dtype) { + case mshadow::kFloat32: + return "float32"; + case mshadow::kFloat64: + return "float64"; + case mshadow::kFloat16: + return "float16"; + case mshadow::kUint8: + return "uint8"; + case mshadow::kInt8: + return "int8"; + case mshadow::kInt32: + return "int32"; + case mshadow::kInt64: + return "int64"; + case mshadow::kBool: + return "bool"; + default: + LOG(FATAL) << "Unknown type enum " << dtype; + } + LOG(FATAL) << "should not reach here "; + return ""; +} + +} // namespace mxnet diff --git a/src/api/operator/op_utils.h b/src/api/operator/op_utils.h new file mode 100644 index 000000000000..4c577983c405 --- /dev/null +++ b/src/api/operator/op_utils.h @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file op_utils.h + * \brief Utility functions for modification in src/operator + */ +#ifndef MXNET_API_OPERATOR_OP_UTILS_H_ +#define MXNET_API_OPERATOR_OP_UTILS_H_ + +#include + +namespace mxnet { + +std::string String2MXNetTypeWithBool(int dtype); + +} // namespace mxnet + +#endif // MXNET_API_OPERATOR_OP_UTILS_H_ diff --git a/src/api/operator/ufunc_helper.cc b/src/api/operator/ufunc_helper.cc new file mode 100644 index 000000000000..67bc68031417 --- /dev/null +++ b/src/api/operator/ufunc_helper.cc @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file ufunc_helper.cc + * \brief ufunc helper + */ +#include "ufunc_helper.h" +#include "utils.h" + +namespace mxnet { + +template<> +void SetAttrDict(nnvm::NodeAttrs* attrs) { + if (Imperative::Get()->is_recording()) { + attrs->dict["scalar"] = std::to_string(::dmlc::get(attrs->parsed)); + } +} + +void UFuncHelper(NDArray* lhs, NDArray* rhs, NDArray* out, + runtime::MXNetRetValue* ret, const nnvm::Op* op) { + using namespace runtime; + nnvm::NodeAttrs attrs; + attrs.op = op; + NDArray* inputs[] = {lhs, rhs}; + int num_inputs = 2; + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (outputs) { + *ret = PythonArg(2); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +} + +void UFuncHelper(NDArray* lhs, double rhs, NDArray* out, + runtime::MXNetRetValue* ret, const nnvm::Op* op) { + using namespace runtime; + nnvm::NodeAttrs attrs; + attrs.op = op; + attrs.parsed = rhs; + SetAttrDict(&attrs); + NDArray** inputs = &lhs; + int num_inputs = 1; + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (outputs) { + *ret = PythonArg(2); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +} + +void UFuncHelper(double lhs, NDArray* rhs, NDArray* out, + runtime::MXNetRetValue* ret, const nnvm::Op* op) { + using namespace runtime; + nnvm::NodeAttrs attrs; + attrs.op = op; + attrs.parsed = lhs; + SetAttrDict(&attrs); + NDArray** inputs = &rhs; + int num_inputs = 1; + NDArray** outputs = out == nullptr ? 
nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (outputs) { + *ret = PythonArg(2); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +} + +void UFuncHelper(runtime::MXNetArgs args, + runtime::MXNetRetValue* ret, + const nnvm::Op* fn_array, + const nnvm::Op* lfn_scalar, + const nnvm::Op* rfn_scalar) { + using namespace runtime; + NDArray* out = args[2].operator NDArray*(); + if (args[0].type_code() == kNDArrayHandle) { + if (args[1].type_code() == kNDArrayHandle) { + UFuncHelper(args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array); + } else { + UFuncHelper(args[0].operator NDArray*(), args[1].operator double(), out, ret, lfn_scalar); + } + } else { + UFuncHelper(args[0].operator double(), args[1].operator NDArray*(), out, ret, + rfn_scalar ? rfn_scalar : lfn_scalar); + } +} + +} // namespace mxnet diff --git a/src/api/operator/ufunc_helper.h b/src/api/operator/ufunc_helper.h new file mode 100644 index 000000000000..793d0b22ed9f --- /dev/null +++ b/src/api/operator/ufunc_helper.h @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file ufunc_helper.h + * \brief ufunc helper + */ +#ifndef MXNET_API_OPERATOR_UFUNC_HELPER_H_ +#define MXNET_API_OPERATOR_UFUNC_HELPER_H_ +#include +namespace mxnet { + +void UFuncHelper(runtime::MXNetArgs args, + runtime::MXNetRetValue* ret, + const nnvm::Op* fn_array, + const nnvm::Op* lfn_scalar, + const nnvm::Op* rfn_scalar); +} // namespace mxnet + +#endif // MXNET_API_OPERATOR_UFUNC_HELPER_H_ diff --git a/src/api/operator/utils.cc b/src/api/operator/utils.cc index d8cd4c922603..3d8401270a40 100644 --- a/src/api/operator/utils.cc +++ b/src/api/operator/utils.cc @@ -66,4 +66,26 @@ void SetInOut(std::vector* ndinputs, } } +std::vector Invoke(const nnvm::Op* op, + nnvm::NodeAttrs* attrs, + int num_inputs, + NDArray** inputs, + int* num_outputs, + NDArray** outputs) { + int infered_num_outputs; + int num_visible_outputs; + imperative::SetNumOutputs(op, *attrs, num_inputs, &infered_num_outputs, &num_visible_outputs); + + std::vector ndinputs, ndoutputs; + SetInOut(&ndinputs, &ndoutputs, num_inputs, inputs, + num_outputs, infered_num_outputs, num_visible_outputs, outputs); + + auto state = Imperative::Get()->Invoke(Context::CPU(), *attrs, ndinputs, ndoutputs); + if (Imperative::Get()->is_recording()) { + Imperative::Get()->RecordOp(std::move(*attrs), ndinputs, ndoutputs, state); + } + for (int i = *num_outputs; i < infered_num_outputs; ++i) delete ndoutputs[i]; + return ndoutputs; +} + } // namespace mxnet diff --git a/src/api/operator/utils.h b/src/api/operator/utils.h index 7a31e4537780..49ee6bf2c9af 100644 --- a/src/api/operator/utils.h +++ b/src/api/operator/utils.h @@ -24,13 +24,10 @@ #ifndef MXNET_API_OPERATOR_UTILS_H_ #define MXNET_API_OPERATOR_UTILS_H_ -#include -#include -#include -#include #include #include #include +#include #include "../../imperative/imperative_utils.h" namespace mxnet { @@ -44,28 +41,18 @@ void SetInOut(std::vector* ndinputs, int num_visible_outputs, NDArray** out_array); -template std::vector Invoke(const nnvm::Op* op, 
nnvm::NodeAttrs* attrs, int num_inputs, NDArray** inputs, int* num_outputs, - NDArray** outputs) { - int infered_num_outputs; - int num_visible_outputs; - imperative::SetNumOutputs(op, *attrs, num_inputs, &infered_num_outputs, &num_visible_outputs); - - std::vector ndinputs, ndoutputs; - SetInOut(&ndinputs, &ndoutputs, num_inputs, inputs, - num_outputs, infered_num_outputs, num_visible_outputs, outputs); + NDArray** outputs); - auto state = Imperative::Get()->Invoke(Context::CPU(), *attrs, ndinputs, ndoutputs); +template +void SetAttrDict(nnvm::NodeAttrs* attrs) { if (Imperative::Get()->is_recording()) { ::dmlc::get(attrs->parsed).SetAttrDict(&(attrs->dict)); - Imperative::Get()->RecordOp(std::move(*attrs), ndinputs, ndoutputs, state); } - for (int i = *num_outputs; i < infered_num_outputs; ++i) delete ndoutputs[i]; - return ndoutputs; } } // namespace mxnet diff --git a/src/operator/numpy/np_cumsum-inl.h b/src/operator/numpy/np_cumsum-inl.h index 65e658115dc4..b6e0eab5a8f5 100644 --- a/src/operator/numpy/np_cumsum-inl.h +++ b/src/operator/numpy/np_cumsum-inl.h @@ -28,9 +28,11 @@ #include #include #include +#include #include "../mxnet_op.h" #include "../operator_common.h" #include "../elemwise_op_common.h" +#include "../../api/operator/op_utils.h" namespace mxnet { namespace op { @@ -56,6 +58,17 @@ struct CumsumParam : public dmlc::Parameter { " unless a has an integer dtype with a precision less than that of the" " default platform integer. 
In that case, the default platform integer is used."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream axis_s, dtype_s; + axis_s << axis; + dtype_s << dtype; + (*dict)["axis"] = axis_s.str(); + if (dtype.has_value()) { + (*dict)["dtype"] = String2MXNetTypeWithBool(dtype.value()); + } else { + (*dict)["dtype"] = dtype_s.str(); + } + } }; struct cumsum_forward { diff --git a/src/operator/numpy/np_cumsum.cc b/src/operator/numpy/np_cumsum.cc index 2d5dbb99f90a..ea0f9b6b11bc 100644 --- a/src/operator/numpy/np_cumsum.cc +++ b/src/operator/numpy/np_cumsum.cc @@ -65,8 +65,7 @@ inline bool CumsumType(const nnvm::NodeAttrs& attrs, DMLC_REGISTER_PARAMETER(CumsumParam); -NNVM_REGISTER_OP(_np_cumsum) -.add_alias("cumsum") +NNVM_REGISTER_OP(_npi_cumsum) .describe(R"code(Return the cumulative sum of the elements along a given axis.)code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(1) @@ -78,7 +77,7 @@ NNVM_REGISTER_OP(_np_cumsum) .set_attr("FInferShape", CumsumShape) .set_attr("FInferType", CumsumType) .set_attr("FCompute", CumsumForward) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_cumsum"}) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_cumsum"}) .set_attr("FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; @@ -86,7 +85,7 @@ NNVM_REGISTER_OP(_np_cumsum) .add_argument("a", "NDArray-or-Symbol", "Input ndarray") .add_arguments(CumsumParam::__FIELDS__()); -NNVM_REGISTER_OP(_backward_np_cumsum) +NNVM_REGISTER_OP(_backward_npi_cumsum) .set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(1) diff --git a/src/operator/numpy/np_cumsum.cu b/src/operator/numpy/np_cumsum.cu index cc574ebf72c5..438bab2b5efe 100644 --- a/src/operator/numpy/np_cumsum.cu +++ b/src/operator/numpy/np_cumsum.cu @@ -27,10 +27,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_np_cumsum) +NNVM_REGISTER_OP(_npi_cumsum) .set_attr("FCompute", CumsumForward); -NNVM_REGISTER_OP(_backward_np_cumsum) 
+NNVM_REGISTER_OP(_backward_npi_cumsum) .set_attr("FCompute", CumsumBackward); } // namespace op