diff --git a/benchmark/python/ffi/benchmark_ffi.py b/benchmark/python/ffi/benchmark_ffi.py index 029b660a8099..2b9d8dfa58fe 100644 --- a/benchmark/python/ffi/benchmark_ffi.py +++ b/benchmark/python/ffi/benchmark_ffi.py @@ -59,6 +59,22 @@ def prepare_workloads(): OpArgMngr.add_workload("add", pool['2x2'], pool['2x2']) OpArgMngr.add_workload("linalg.svd", pool['3x3']) OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1) + OpArgMngr.add_workload("subtract", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("multiply", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("mod", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("remainder", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("divide", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("true_divide", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("power", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("lcm", pool['2x2'].astype('int32'), pool['2x2'].astype('int32')) + OpArgMngr.add_workload("diff", pool['2x2'], n=1, axis=-1) + OpArgMngr.add_workload("nonzero", pool['2x2']) + OpArgMngr.add_workload("tril", pool['2x2'], k=0) + OpArgMngr.add_workload("expand_dims", pool['2x2'], axis=0) + OpArgMngr.add_workload("broadcast_to", pool['2x2'], (2, 2, 2)) + OpArgMngr.add_workload("full_like", pool['2x2'], 2) + OpArgMngr.add_workload("zeros_like", pool['2x2']) + OpArgMngr.add_workload("ones_like", pool['2x2']) OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1) OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1']) OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1]) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 54d90bbb362e..e235c54b8711 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -212,9 +212,7 @@ def zeros_like(a, dtype=None, order='C', ctx=None, out=None): """ if order != 'C': raise NotImplementedError - if ctx is None: - ctx = current_context() - return _npi.full_like(a, fill_value=0, dtype=dtype, ctx=ctx, out=out) + return full_like(a, 0, dtype=dtype, order=order, ctx=ctx, out=out) @set_module('mxnet.ndarray.numpy') @@ -270,11 +268,7 @@ def ones_like(a, dtype=None, order='C', ctx=None, out=None): >>> np.ones_like(y) array([1., 1., 1.], dtype=float64) """ - if order != 'C': - raise NotImplementedError - if ctx is None: - ctx = current_context() - return _npi.full_like(a, fill_value=1, dtype=dtype, ctx=ctx, out=out) + return full_like(a, 1, dtype=dtype, order=order, ctx=ctx, out=out) @set_module('mxnet.ndarray.numpy') @@ -433,11 +427,15 @@ def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin """ if order != 'C': raise NotImplementedError - if ctx is None: - ctx = current_context() if isinstance(fill_value, bool): fill_value = int(fill_value) - return _npi.full_like(a, fill_value=fill_value, dtype=dtype, ctx=ctx, out=out) + if ctx is None: + ctx = str(current_context()) + else: + ctx = str(ctx) + if dtype is not None and not isinstance(dtype, str): + dtype = _np.dtype(dtype).name + return _api_internal.full_like(a, fill_value, dtype, ctx, out) @set_module('mxnet.ndarray.numpy') @@ -1025,8 +1023,9 @@ def subtract(x1, x2, out=None, **kwargs): * If only one of the inputs is floating number type, the result is that type. * If both inputs are of integer types (including boolean), not supported yet. 
""" - return _ufunc_helper(x1, x2, _npi.subtract, _np.subtract, _npi.subtract_scalar, - _npi.rsubtract_scalar, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + return _np.subtract(x1, x2, out=out) + return _api_internal.subtract(x1, x2, out) @set_module('mxnet.ndarray.numpy') @@ -1060,7 +1059,9 @@ def multiply(x1, x2, out=None, **kwargs): * If only one of the inputs is floating number type, the result is that type. * If both inputs are of integer types (including boolean), not supported yet. """ - return _ufunc_helper(x1, x2, _npi.multiply, _np.multiply, _npi.multiply_scalar, None, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + return _np.multiply(x1, x2, out=out) + return _api_internal.multiply(x1, x2, out) @set_module('mxnet.ndarray.numpy') @@ -1095,8 +1096,9 @@ def divide(x1, x2, out=None, **kwargs): * If only one of the inputs is floating number type, the result is that type. * If both inputs are of integer types (including boolean), the output is of float32 type. """ - return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar, - _npi.rtrue_divide_scalar, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + return _np.divide(x1, x2, out=out) + return _api_internal.true_divide(x1, x2, out) @set_module('mxnet.ndarray.numpy') @@ -1133,8 +1135,9 @@ def true_divide(x1, x2, out=None): * If only one of the inputs is floating number type, the result is that type. * If both inputs are of integer types (including boolean), the output is of float32 type. """ - return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar, - _npi.rtrue_divide_scalar, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + return _np.true_divide(x1, x2, out=out) + return _api_internal.true_divide(x1, x2, out) @set_module('mxnet.ndarray.numpy') @@ -1161,7 +1164,9 @@ def mod(x1, x2, out=None, **kwargs): out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. """ - return _ufunc_helper(x1, x2, _npi.mod, _np.mod, _npi.mod_scalar, _npi.rmod_scalar, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + return _np.mod(x1, x2, out=out) + return _api_internal.mod(x1, x2, out) @set_module('mxnet.ndarray.numpy') @@ -1349,7 +1354,9 @@ def remainder(x1, x2, out=None): out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. """ - return _ufunc_helper(x1, x2, _npi.mod, _np.mod, _npi.mod_scalar, _npi.rmod_scalar, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + _np.mod(x1, x2, out=out) + return _api_internal.mod(x1, x2, out) @set_module('mxnet.ndarray.numpy') @@ -1377,7 +1384,9 @@ def power(x1, x2, out=None, **kwargs): The bases in x1 raised to the exponents in x2. This is a scalar if both x1 and x2 are scalars. 
""" - return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + return _np.power(x1, x2, out=out) + return _api_internal.power(x1, x2, out) @set_module('mxnet.ndarray.numpy') @@ -1976,7 +1985,9 @@ def lcm(x1, x2, out=None, **kwargs): >>> np.lcm(np.arange(6, dtype=int), 20) array([ 0, 20, 20, 60, 20, 20], dtype=int64) """ - return _ufunc_helper(x1, x2, _npi.lcm, _np.lcm, _npi.lcm_scalar, None, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + return _np.lcm(x1, x2, out=out) + return _api_internal.lcm(x1, x2, out) @set_module('mxnet.ndarray.numpy') @@ -6658,7 +6669,7 @@ def nonzero(a): >>> (a > 3).nonzero() (array([1, 1, 1, 2, 2, 2], dtype=int64), array([0, 1, 2, 0, 1, 2], dtype=int64)) """ - out = _npi.nonzero(a).transpose() + out = _api_internal.nonzero(a).transpose() return tuple([out[i] for i in range(len(out))]) diff --git a/src/api/operator/numpy/np_broadcast_reduce_op_value.cc b/src/api/operator/numpy/np_broadcast_reduce_op_value.cc index 2322860aa609..224451c70570 100644 --- a/src/api/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/api/operator/numpy/np_broadcast_reduce_op_value.cc @@ -46,7 +46,8 @@ MXNET_REGISTER_API("_npi.broadcast_to") int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr); + int num_inputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); *ret = ndoutputs[0]; }); diff --git a/src/api/operator/numpy/np_diff_op.cc b/src/api/operator/numpy/np_diff_op.cc index dec73b8496a7..7be5b804eade 100644 --- a/src/api/operator/numpy/np_diff_op.cc +++ b/src/api/operator/numpy/np_diff_op.cc @@ -43,7 +43,8 @@ MXNET_REGISTER_API("_npi.diff") int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr); + int num_inputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); *ret = ndoutputs[0]; }); diff --git a/src/api/operator/numpy/np_elemwise_broadcast_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_op.cc index e724a7c58bd3..7a9eb0139439 100644 --- a/src/api/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/api/operator/numpy/np_elemwise_broadcast_op.cc @@ -36,4 +36,56 @@ MXNET_REGISTER_API("_npi.add") UFuncHelper(args, ret, op, op_scalar, nullptr); }); +MXNET_REGISTER_API("_npi.subtract") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_subtract"); + const nnvm::Op* op_scalar = Op::Get("_npi_subtract_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_rsubtract_scalar"); + UFuncHelper(args, ret, op, op_scalar, op_rscalar); +}); + +MXNET_REGISTER_API("_npi.multiply") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_multiply"); + const nnvm::Op* op_scalar = Op::Get("_npi_multiply_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); +}); + +MXNET_REGISTER_API("_npi.true_divide") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_true_divide"); + const nnvm::Op* op_scalar = Op::Get("_npi_true_divide_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_rtrue_divide_scalar"); + UFuncHelper(args, ret, op, op_scalar, 
+              op_rscalar);
+});
+
+MXNET_REGISTER_API("_npi.mod")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_mod");
+  const nnvm::Op* op_scalar = Op::Get("_npi_mod_scalar");
+  const nnvm::Op* op_rscalar = Op::Get("_npi_rmod_scalar");
+  UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+});
+
+MXNET_REGISTER_API("_npi.power")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_power");
+  const nnvm::Op* op_scalar = Op::Get("_npi_power_scalar");
+  const nnvm::Op* op_rscalar = Op::Get("_npi_rpower_scalar");
+  UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+});
+
+MXNET_REGISTER_API("_npi.lcm")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_lcm");
+  const nnvm::Op* op_scalar = Op::Get("_npi_lcm_scalar");
+  UFuncHelper(args, ret, op, op_scalar, nullptr);
+});
+
 }  // namespace mxnet
diff --git a/src/api/operator/numpy/np_init_op.cc b/src/api/operator/numpy/np_init_op.cc
index c65f90c841f4..4f7c6e497616 100644
--- a/src/api/operator/numpy/np_init_op.cc
+++ b/src/api/operator/numpy/np_init_op.cc
@@ -21,6 +21,7 @@
 * \file np_init_op.cc
 * \brief Implementation of the API of functions in src/operator/numpy/np_init_op.cc
 */
+#include "../../../operator/tensor/init_op.h"
 #include <mxnet/api_registry.h>
 #include <mxnet/runtime/packed_func.h>
 #include "../utils.h"
@@ -55,4 +56,35 @@ MXNET_REGISTER_API("_npi.zeros")
   *ret = ndoutputs[0];
 });

+MXNET_REGISTER_API("_npi.full_like")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_full_like");
+  nnvm::NodeAttrs attrs;
+  op::FullLikeOpParam param;
+  param.fill_value = args[1].operator double();
+  if (args[2].type_code() == kNull) {
+    param.dtype = dmlc::nullopt;
+  } else {
+    param.dtype = String2MXNetTypeWithBool(args[2].operator std::string());
+  }
+  attrs.parsed = std::move(param);
+  attrs.op = op;
+  if (args[3].type_code() != kNull) {
+    attrs.dict["ctx"] = args[3].operator std::string();
+  }
+  SetAttrDict<op::FullLikeOpParam>(&attrs);
+  NDArray* out = args[4].operator mxnet::NDArray*();
+  NDArray** outputs = out == nullptr ? nullptr : &out;
+  int num_outputs = out != nullptr;
+  NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
+  int num_inputs = 1;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
+  if (out) {
+    *ret = PythonArg(4);
+  } else {
+    *ret = ndoutputs[0];
+  }
+});
+
 }  // namespace mxnet
diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc
index cc268c202c9b..b4bb583c0511 100644
--- a/src/api/operator/numpy/np_matrix_op.cc
+++ b/src/api/operator/numpy/np_matrix_op.cc
@@ -42,7 +42,8 @@ MXNET_REGISTER_API("_npi.expand_dims")
   int num_outputs = 0;
   NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
-  auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
+  int num_inputs = 1;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
   *ret = ndoutputs[0];
 });
diff --git a/src/api/operator/numpy/np_nonzero_op.cc b/src/api/operator/numpy/np_nonzero_op.cc
new file mode 100644
index 000000000000..85510633c054
--- /dev/null
+++ b/src/api/operator/numpy/np_nonzero_op.cc
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_nonzero_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy/np_nonzero_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+
+namespace mxnet {
+
+MXNET_REGISTER_API("_npi.nonzero")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_nonzero");
+  nnvm::NodeAttrs attrs;
+
+  attrs.op = op;
+
+  int num_inputs = 1;
+  int num_outputs = 0;
+  NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
diff --git a/src/api/operator/numpy/np_tril_op.cc b/src/api/operator/numpy/np_tril_op.cc
index 105ff58bf559..1acb1b8e4b10 100644
--- a/src/api/operator/numpy/np_tril_op.cc
+++ b/src/api/operator/numpy/np_tril_op.cc
@@ -42,7 +42,8 @@ MXNET_REGISTER_API("_npi.tril")
   int num_outputs = 0;
   NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
-  auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
+  int num_inputs = 1;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
   *ret = ndoutputs[0];
 });
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index fb739c690607..ed0569c799ce 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -105,6 +105,17 @@ struct FullLikeOpParam : public dmlc::Parameter<FullLikeOpParam> {
     MXNET_ADD_ALL_TYPES_WITH_BOOL
     .describe("Target data type.");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream fill_value_s, dtype_s;
+    fill_value_s << fill_value;
+    dtype_s << dtype;
+    (*dict)["fill_value"] = fill_value_s.str();
+    if (dtype.has_value()) {
+      (*dict)["dtype"] = MXNetTypeWithBool2String(dtype.value());
+    } else {
+      (*dict)["dtype"] = dtype_s.str();
+    }
+  }
 };

 /*! \brief Infer type of FullLikeOpCompute*/
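
A quick end-to-end sanity check of the new FFI paths can be done against the public mxnet.numpy API. This is a minimal sketch, assuming a build that contains the changes above; it uses only documented public calls (np.array, np.subtract, np.true_divide, np.full_like, np.nonzero), whose Python wrappers now dispatch through _api_internal instead of _ufunc_helper:

from mxnet import np, npx
npx.set_np()

a = np.array([[1., 2.], [3., 4.]])
b = np.array([[5., 6.], [7., 8.]])

print(np.subtract(a, b))                  # both ndarrays -> _api_internal.subtract
print(np.true_divide(a, b))               # -> _api_internal.true_divide
print(np.full_like(a, 2, dtype='int32'))  # dtype is normalized to a string before the FFI call
print(np.nonzero(a > 2))                  # tuple of index arrays built from _npi_nonzero's output
print(np.subtract(1.0, 2.0))              # both plain scalars -> falls back to official numpy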
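
The C++ registrations encode the left/right scalar asymmetry: non-commutative ops register both a `*_scalar` and a reflected `r*_scalar` kernel, while commutative ones (multiply, lcm) pass nullptr for the reflected slot. A sketch of what that means at the Python level, under the same build assumption as above:

from mxnet import np, npx
npx.set_np()

a = np.array([[1., 2.], [3., 4.]])
print(np.subtract(a, 2))   # ndarray - scalar -> _npi_subtract_scalar
print(np.subtract(2, a))   # scalar - ndarray -> _npi_rsubtract_scalar (reflected variant)
print(np.multiply(2, a))   # commutative, so a single scalar kernel covers both argument orders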
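
For context on the benchmark hunk: `OpArgMngr.add_workload` records an op name together with the arguments to call it with, and the driver later times each recorded call. A hypothetical minimal re-implementation of that pattern (the real class lives in benchmark_ffi.py and differs in detail; `run_all` below is an illustrative name, not the driver's API):

from mxnet import np, npx
npx.set_np()

class OpArgMngr:
    """Hypothetical minimal workload registry in the spirit of benchmark_ffi.py."""
    args = {}

    @staticmethod
    def add_workload(funcname, *args, **kwargs):
        # Record the op name and the exact arguments to invoke it with later.
        OpArgMngr.args[funcname] = {'args': args, 'kwargs': kwargs}

def run_all():
    for funcname, workload in OpArgMngr.args.items():
        func = getattr(np, funcname)  # e.g. np.zeros_like
        func(*workload['args'], **workload['kwargs'])

pool = {'2x2': np.ones((2, 2))}
OpArgMngr.add_workload("zeros_like", pool['2x2'])
OpArgMngr.add_workload("expand_dims", pool['2x2'], axis=0)
run_all()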