diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index a129cc1d792f..5060af1a348c 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -5975,15 +5975,19 @@ def kron(a, b): Kronecker product of two arrays. Computes the Kronecker product, a composite array made of blocks of the second array scaled by the first. + Parameters ---------- a, b : ndarray + Returns ------- out : ndarray + See Also -------- outer : The outer product + Notes ----- The function assumes that the number of dimensions of `a` and `b` @@ -5999,6 +6003,7 @@ def kron(a, b): [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], [ ... ... ], [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] + Examples -------- >>> np.kron([1,10,100], [5,6,7]) @@ -6006,7 +6011,7 @@ def kron(a, b): >>> np.kron([5,6,7], [1,10,100]) array([ 5, 50, 500, 6, 60, 600, 7, 70, 700]) """ - return _npi.kron(a, b) + return _api_internal.kron(a, b) @set_module('mxnet.ndarray.numpy') diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 50bd5a9e2bf0..fb15436544a3 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -7833,15 +7833,19 @@ def kron(a, b): Kronecker product of two arrays. Computes the Kronecker product, a composite array made of blocks of the second array scaled by the first. + Parameters ---------- a, b : ndarray + Returns ------- out : ndarray + See Also -------- outer : The outer product + Notes ----- The function assumes that the number of dimensions of `a` and `b` @@ -7857,6 +7861,7 @@ def kron(a, b): [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], [ ... ... ], [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] + Examples -------- >>> np.kron([1,10,100], [5,6,7]) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index e3505cc9c629..690270afcbdd 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -5516,15 +5516,19 @@ def kron(a, b): Kronecker product of two arrays. Computes the Kronecker product, a composite array made of blocks of the second array scaled by the first. + Parameters ---------- a, b : ndarray + Returns ------- out : ndarray + See Also -------- outer : The outer product + Notes ----- The function assumes that the number of dimensions of `a` and `b` @@ -5540,6 +5544,7 @@ def kron(a, b): [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], [ ... ... ], [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] + Examples -------- >>> np.kron([1,10,100], [5,6,7]) diff --git a/src/api/operator/numpy/np_kron.cc b/src/api/operator/numpy/np_kron.cc new file mode 100644 index 000000000000..688f185e906f --- /dev/null +++ b/src/api/operator/numpy/np_kron.cc @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_kron.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_kron.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/numpy/np_kron-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.kron") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + attrs.op = op; + const nnvm::Op* op = Op::Get("_npi_kron"); + NDArray* inputs[] = {args[0].operator NDArray*(), args[1].operator NDArray*()}; + int num_inputs = 2; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +} // namespace mxnet