From a73399873bd07357d46678b476727d68e661b0b2 Mon Sep 17 00:00:00 2001 From: hanke580 <38852697+hanke580@users.noreply.github.com> Date: Tue, 24 Mar 2020 03:41:58 +0800 Subject: [PATCH] [Numpy] Kron operator (#17323) * [Numpy]Add kron * Implement the forward of Kron op * Implement the Backward of a * Implement the Backward of b * Fix 3rd party * Fix cpp sanity * Finish grad check * address comments: fix test_np_op and reduce req to req[0] * * Fix ndim = 0 * * Fix uninitialize bugs * * Impl FFI --- benchmark/python/ffi/benchmark_ffi.py | 1 + python/mxnet/ndarray/numpy/_op.py | 47 ++- python/mxnet/numpy/multiarray.py | 48 ++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 48 ++- src/api/operator/numpy/np_kron.cc | 44 +++ src/operator/numpy/np_kron-inl.h | 322 ++++++++++++++++++ src/operator/numpy/np_kron.cc | 94 +++++ src/operator/numpy/np_kron.cu | 37 ++ .../unittest/test_numpy_interoperability.py | 8 + tests/python/unittest/test_numpy_op.py | 81 +++++ 11 files changed, 728 insertions(+), 3 deletions(-) create mode 100644 src/api/operator/numpy/np_kron.cc create mode 100644 src/operator/numpy/np_kron-inl.h create mode 100644 src/operator/numpy/np_kron.cc create mode 100644 src/operator/numpy/np_kron.cu diff --git a/benchmark/python/ffi/benchmark_ffi.py b/benchmark/python/ffi/benchmark_ffi.py index 98addb02ffda..4a4c4107d481 100644 --- a/benchmark/python/ffi/benchmark_ffi.py +++ b/benchmark/python/ffi/benchmark_ffi.py @@ -55,6 +55,7 @@ def prepare_workloads(): OpArgMngr.add_workload("ediff1d", pool['2x2'], pool['2x2'], pool['2x2']) OpArgMngr.add_workload("nan_to_num", pool['2x2']) OpArgMngr.add_workload("tensordot", pool['2x2'], pool['2x2'], ((1, 0), (0, 1))) + OpArgMngr.add_workload("kron", pool['2x2'], pool['2x2']) OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2']) OpArgMngr.add_workload("add", pool['2x2'], pool['2x2']) OpArgMngr.add_workload("linalg.svd", pool['3x3']) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 84aea04b10c3..4bcc4a55cc17 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -43,7 +43,7 @@ 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', - 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', + 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', @@ -6146,6 +6146,51 @@ def outer(a, b): return tensordot(a.flatten(), b.flatten(), 0) +@set_module('mxnet.ndarray.numpy') +def kron(a, b): + r""" + Kronecker product of two arrays. + Computes the Kronecker product, a composite array made of blocks of the + second array scaled by the first. + + Parameters + ---------- + a, b : ndarray + + Returns + ------- + out : ndarray + + See Also + -------- + outer : The outer product + + Notes + ----- + The function assumes that the number of dimensions of `a` and `b` + are the same, if necessary prepending the smallest with ones. + If `a.shape = (r0,r1,..,rN)` and `b.shape = (s0,s1,...,sN)`, + the Kronecker product has shape `(r0*s0, r1*s1, ..., rN*SN)`. + The elements are products of elements from `a` and `b`, organized + explicitly by:: + kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN] + where:: + kt = it * st + jt, t = 0,...,N + In the common 2-D case (N=1), the block structure can be visualized:: + [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], + [ ... ... ], + [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] + + Examples + -------- + >>> np.kron([1,10,100], [5,6,7]) + array([ 5, 6, 7, 50, 60, 70, 500, 600, 700]) + >>> np.kron([5,6,7], [1,10,100]) + array([ 5, 50, 500, 6, 60, 600, 7, 70, 700]) + """ + return _api_internal.kron(a, b) + + @set_module('mxnet.ndarray.numpy') def vdot(a, b): r""" diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index ee9df300756d..61f0705824a3 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -68,7 +68,8 @@ 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', + 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron', + 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount', @@ -8035,6 +8036,51 @@ def outer(a, b): return tensordot(a.flatten(), b.flatten(), 0) +@set_module('mxnet.numpy') +def kron(a, b): + r""" + Kronecker product of two arrays. + Computes the Kronecker product, a composite array made of blocks of the + second array scaled by the first. + + Parameters + ---------- + a, b : ndarray + + Returns + ------- + out : ndarray + + See Also + -------- + outer : The outer product + + Notes + ----- + The function assumes that the number of dimensions of `a` and `b` + are the same, if necessary prepending the smallest with ones. + If `a.shape = (r0,r1,..,rN)` and `b.shape = (s0,s1,...,sN)`, + the Kronecker product has shape `(r0*s0, r1*s1, ..., rN*SN)`. + The elements are products of elements from `a` and `b`, organized + explicitly by:: + kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN] + where:: + kt = it * st + jt, t = 0,...,N + In the common 2-D case (N=1), the block structure can be visualized:: + [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], + [ ... ... ], + [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] + + Examples + -------- + >>> np.kron([1,10,100], [5,6,7]) + array([ 5, 6, 7, 50, 60, 70, 500, 600, 700]) + >>> np.kron([5,6,7], [1,10,100]) + array([ 5, 50, 500, 6, 60, 600, 7, 70, 700]) + """ + return _mx_nd_np.kron(a, b) + + @set_module('mxnet.numpy') def vdot(a, b): r""" diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index a4b251b55607..110f2273a852 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -168,6 +168,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'tril', 'meshgrid', 'outer', + 'kron', 'einsum', 'polyval', 'shares_memory', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index d29768b49d1b..897f856ae84f 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -49,7 +49,7 @@ 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', - 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', + 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', @@ -5626,6 +5626,52 @@ def outer(a, b): return tensordot(a.flatten(), b.flatten(), 0) +@set_module('mxnet.symbol.numpy') +def kron(a, b): + r""" + kron(a, b) + Kronecker product of two arrays. + Computes the Kronecker product, a composite array made of blocks of the + second array scaled by the first. + + Parameters + ---------- + a, b : ndarray + + Returns + ------- + out : ndarray + + See Also + -------- + outer : The outer product + + Notes + ----- + The function assumes that the number of dimensions of `a` and `b` + are the same, if necessary prepending the smallest with ones. + If `a.shape = (r0,r1,..,rN)` and `b.shape = (s0,s1,...,sN)`, + the Kronecker product has shape `(r0*s0, r1*s1, ..., rN*SN)`. + The elements are products of elements from `a` and `b`, organized + explicitly by:: + kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN] + where:: + kt = it * st + jt, t = 0,...,N + In the common 2-D case (N=1), the block structure can be visualized:: + [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], + [ ... ... ], + [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] + + Examples + -------- + >>> np.kron([1,10,100], [5,6,7]) + array([ 5, 6, 7, 50, 60, 70, 500, 600, 700]) + >>> np.kron([5,6,7], [1,10,100]) + array([ 5, 50, 500, 6, 60, 600, 7, 70, 700]) + """ + return _npi.kron(a, b) + + @set_module('mxnet.symbol.numpy') def vdot(a, b): r""" diff --git a/src/api/operator/numpy/np_kron.cc b/src/api/operator/numpy/np_kron.cc new file mode 100644 index 000000000000..753798208b4f --- /dev/null +++ b/src/api/operator/numpy/np_kron.cc @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_kron.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_kron.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/numpy/np_kron-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.kron") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npi_kron"); + attrs.op = op; + NDArray* inputs[] = {args[0].operator NDArray*(), args[1].operator NDArray*()}; + int num_inputs = 2; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +} // namespace mxnet diff --git a/src/operator/numpy/np_kron-inl.h b/src/operator/numpy/np_kron-inl.h new file mode 100644 index 000000000000..0d72921691a9 --- /dev/null +++ b/src/operator/numpy/np_kron-inl.h @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_kron-inl.h + * \brief Function definition of matrix numpy-compatible kron operator + */ +#ifndef MXNET_OPERATOR_NUMPY_NP_KRON_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_KRON_INL_H_ + +#include +#include "np_tensordot_op-inl.h" +#include "../mxnet_op.h" + +namespace mxnet { +namespace op { + +template +struct kron { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out, + const DType* a, const DType* b, + mshadow::Shape ashape, + mshadow::Shape bshape, + mshadow::Shape oshape) { + using namespace mxnet_op; + + auto k = unravel(i, oshape); + Shape ia; + Shape jb; + for (int q = 0; q < ndim; q++) { + ia[q] = static_cast(k[q] / bshape[q]); + jb[q] = k[q] % bshape[q]; + } + auto idx_a = ravel(ia, ashape); + auto idx_b = ravel(jb, bshape); + + KERNEL_ASSIGN(out[i], req, a[idx_a] * b[idx_b]); + } +}; + +template +struct kron_back_a { + template + MSHADOW_XINLINE static void Map(index_t i, DType* agrad, + const DType* b, const DType* ograd, + mshadow::Shape ashape, + mshadow::Shape bshape, + mshadow::Shape oshape) { + using namespace mxnet_op; + + auto ia = unravel(i, ashape); + Shape k; + DType temp_agrad = 0; + + for (int idx_b = 0; idx_b < bshape.Size(); idx_b++) { + auto jb = unravel(idx_b, bshape); + for (int q = 0; q < ndim; q++) { + k[q] = ia[q]*bshape[q] + jb[q]; + } + auto idx_o = ravel(k, oshape); + temp_agrad += b[idx_b]*ograd[idx_o]; + } + KERNEL_ASSIGN(agrad[i], req, temp_agrad); + } +}; + +template +struct kron_back_b { + template + MSHADOW_XINLINE static void Map(index_t i, const DType* a, + DType* bgrad, const DType* ograd, + mshadow::Shape ashape, + mshadow::Shape bshape, + mshadow::Shape oshape) { + using namespace mxnet_op; + + auto jb = unravel(i, bshape); + Shape k; + DType temp_bgrad = 0; + + for (int idx_a = 0; idx_a < ashape.Size(); idx_a++) { + auto ia = unravel(idx_a, ashape); + for (int q = 0; q < ndim; q++) { + k[q] = ia[q] * bshape[q] + jb[q]; + } + auto idx_o = ravel(k, oshape); + temp_bgrad += a[idx_a]*ograd[idx_o]; + } + KERNEL_ASSIGN(bgrad[i], req, temp_bgrad); + } +}; + +template +void KronOpForwardImpl(const OpContext& ctx, + OpReqType req, + const TBlob& a, + const TBlob& b, + const TBlob& out + ) { + using namespace mshadow; + + if (req == kNullOp) { + return; + } + + if (out.shape_.Size() == 0U) { + return; // zero-size output, no need to launch kernel + } + + const mxnet::TShape& ashape = a.shape_; + const mxnet::TShape& bshape = b.shape_; + const mxnet::TShape& oshape = out.shape_; + + + // TensordotIntAxesImpl(0, ctx, a, b, out, req[0]); + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { + if (ashape.Size() == 0U || bshape.Size() == 0U) { + // 0-size input + if (req != kAddTo) { + Tensor out_data = out.get_with_shape( + Shape1(out.shape_.Size()), s); + out_data = static_cast(0); + } + } else if (ashape.ndim() == 0 && bshape.ndim() == 0) { + // Both 0-D scalars, equivalent to multiply + Tensor a_data = a.get_with_shape(Shape1(1), s); + Tensor b_data = b.get_with_shape(Shape1(1), s); + Tensor out_data = out.get_with_shape(Shape1(1), s); + ASSIGN_DISPATCH(out_data, req, a_data * b_data); + } else if (ashape.ndim() == 0 || bshape.ndim() == 0) { + // Either of them is a scalar, just scale by one of them + const DType* tensor = (ashape.ndim() == 0) ? b.dptr() : a.dptr(); + const DType* scalar = (ashape.ndim() == 0) ? a.dptr() : b.dptr(); + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + mxnet_op::Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), tensor, scalar); + }); + } else { + MXNET_NDIM_SWITCH(oshape.ndim(), ndim, { + Shape ashape_ = oshape.get(); + Shape bshape_ = oshape.get(); + Shape oshape_ = oshape.get(); + int temp = ashape.ndim()-bshape.ndim(); + int s_dim = (temp > 0)?bshape.ndim():ashape.ndim(); + for (int i = 0; i < s_dim; i++) { + ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1]; + bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1]; + oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1]; + } + if (temp > 0) { + for (int i = s_dim; i < ndim; i++) { + ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1]; + bshape_[ndim - i - 1] = 1; + oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1]; + } + } else { + for (int i = s_dim; i < ndim; i++) { + ashape_[ndim - i - 1] = 1; + bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1]; + oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1]; + } + } + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), a.dptr(), b.dptr(), + ashape_, bshape_, oshape_); + }); + }); + } + }); +} + +template +void KronOpBackwardImpl(const OpContext& ctx, + const std::vector& req, + const TBlob& a, + const TBlob& b, + const TBlob& ograd, + const TBlob& agrad, + const TBlob& bgrad) { + const mxnet::TShape& ashape = a.shape_; + const mxnet::TShape& bshape = b.shape_; + const mxnet::TShape& oshape = ograd.shape_; + + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, { + if (ashape.ndim() == 0 && bshape.ndim() == 0) { + // Both 0-D scalars, equivalent to multiply + Tensor ograd_data = ograd.get_with_shape(Shape1(1), s); + Tensor a_data = a.get_with_shape(Shape1(1), s); + Tensor b_data = b.get_with_shape(Shape1(1), s); + Tensor agrad_data = agrad.get_with_shape(Shape1(1), s); + Tensor bgrad_data = bgrad.get_with_shape(Shape1(1), s); + ASSIGN_DISPATCH(agrad_data, req[0], b_data * ograd_data); + ASSIGN_DISPATCH(bgrad_data, req[1], a_data * ograd_data); + } else if (ashape.ndim() == 0 || bshape.ndim() == 0) { + // Either of them is a scalar, just scale by one of them + const TBlob& tensor = (ashape.ndim() == 0) ? b : a; + const TBlob& tensor_grad = (ashape.ndim() == 0) ? bgrad : agrad; + const TBlob& scalar = (ashape.ndim() == 0) ? a : b; + const TBlob& scalar_grad = (ashape.ndim() == 0) ? agrad : bgrad; + Tensor scalar_ = scalar.get_with_shape(Shape1(1), s); + Tensor scalar_grad_ = scalar_grad.get_with_shape(Shape1(1), s); + Tensor tensor_ = tensor.FlatTo1D(s); + Tensor tensor_grad_ = tensor_grad.FlatTo1D(s); + Tensor ograd_ = ograd.FlatTo1D(s); + const OpReqType& tensor_req = (ashape.ndim() == 0) ? req[1] : req[0]; + const OpReqType& scalar_req = (ashape.ndim() == 0) ? req[0] : req[1]; + ASSIGN_DISPATCH(tensor_grad_, tensor_req, + broadcast_scalar(scalar_, tensor_grad_.shape_) * ograd_); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(ograd.shape_.Size()), s); + ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * ograd_); + + ReduceAxesComputeImpl( + ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + } else { + MXNET_NDIM_SWITCH(oshape.ndim(), ndim, { + Shape ashape_ = oshape.get(); + Shape bshape_ = oshape.get(); + Shape oshape_ = oshape.get(); + int temp = ashape.ndim()-bshape.ndim(); + int s_dim = (temp > 0)?bshape.ndim():ashape.ndim(); + for (int i = 0; i < s_dim; i++) { + ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1]; + bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1]; + oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1]; + } + if (temp > 0) { + for (int i = s_dim; i < ndim; i++) { + ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1]; + bshape_[ndim - i - 1] = 1; + oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1]; + } + } else { + for (int i = s_dim; i < ndim; i++) { + ashape_[ndim - i - 1] = 1; + bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1]; + oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1]; + } + } + MSHADOW_TYPE_SWITCH(agrad.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, agrad.Size(), agrad.dptr(), b.dptr(), ograd.dptr(), + ashape_, bshape_, oshape_); + }); + }); + MSHADOW_TYPE_SWITCH(bgrad.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, bgrad.Size(), a.dptr(), bgrad.dptr(), ograd.dptr(), + ashape_, bshape_, oshape_); + }); + }); + }); + } + }); +} + +template +inline void KronOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& a = inputs[0]; + const TBlob& b = inputs[1]; + const TBlob& out = outputs[0]; + + KronOpForwardImpl(ctx, req[0], a, b, out); +} + + +template +inline void KronOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + + const TBlob& ograd = inputs[0]; + const TBlob& a = inputs[1]; + const TBlob& b = inputs[2]; + const TBlob& grad_a = outputs[0]; + const TBlob& grad_b = outputs[1]; + + KronOpBackwardImpl(ctx, req, a, b, ograd, grad_a, grad_b); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_KRON_INL_H_ diff --git a/src/operator/numpy/np_kron.cc b/src/operator/numpy/np_kron.cc new file mode 100644 index 000000000000..321e51bc259e --- /dev/null +++ b/src/operator/numpy/np_kron.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_kron.cc + * \brief CPU Implementation of numpy-compatible Kronecker product + */ + +#include "./np_kron-inl.h" + +namespace mxnet { +namespace op { + +inline bool KronOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& a_shape = in_attrs->at(0); + const mxnet::TShape& b_shape = in_attrs->at(1); + + if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) { + return false; + } + + mxnet::TShape out_shape(std::max(a_shape.ndim(), b_shape.ndim()), -1); + if (a_shape.ndim() > b_shape.ndim()) { + for (int i = 0; i < a_shape.ndim() - b_shape.ndim(); i++) { + out_shape[i] = a_shape[i]; + } + for (int i = a_shape.ndim() - b_shape.ndim(); i < a_shape.ndim(); i++) { + out_shape[i] = a_shape[i] * b_shape[i - a_shape.ndim() + b_shape.ndim()]; + } + } else { + for (int i = 0; i < b_shape.ndim() - a_shape.ndim(); i++) { + out_shape[i] = b_shape[i]; + } + for (int i = b_shape.ndim() - a_shape.ndim(); i < b_shape.ndim(); i++) { + out_shape[i] = b_shape[i] * a_shape[i - b_shape.ndim() + a_shape.ndim()]; + } + } + + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); + + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +NNVM_REGISTER_OP(_npi_kron) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; + }) +.set_attr("FInferShape", KronOpShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", KronOpForward) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_kron"}) +.add_argument("a", "NDArray-or-Symbol", "First input") +.add_argument("b", "NDArray-or-Symbol", "Second input"); + +NNVM_REGISTER_OP(_backward_npi_kron) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", KronOpBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_kron.cu b/src/operator/numpy/np_kron.cu new file mode 100644 index 000000000000..fc2fb1f765b9 --- /dev/null +++ b/src/operator/numpy/np_kron.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_kron.cu + * \brief GPU Implementation of numpy-compatible Kronecker product + */ + +#include "./np_kron-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_kron) +.set_attr("FCompute", KronOpForward); + +NNVM_REGISTER_OP(_backward_npi_kron) +.set_attr("FCompute", KronOpBackward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index f58002c8634b..a4492a3beab1 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1229,6 +1229,13 @@ def _add_workload_outer(): OpArgMngr.add_workload('outer', np.ones((5)), np.ones((2))) +def _add_workload_kron(): + OpArgMngr.add_workload('kron', np.ones((5)), np.ones((2))) + OpArgMngr.add_workload('kron', np.arange(16).reshape((4,4)), np.ones((4,4))) + OpArgMngr.add_workload('kron', np.ones((2,4)), np.zeros((2,4))) + OpArgMngr.add_workload('kron', np.ones(()), np.ones(())) + + def _add_workload_meshgrid(): OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3])) OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]), np.array([4, 5, 6, 7])) @@ -2833,6 +2840,7 @@ def _prepare_workloads(): _add_workload_trace() _add_workload_tril() _add_workload_outer() + _add_workload_kron() _add_workload_meshgrid() _add_workload_einsum() _add_workload_abs() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index ef5de655cbca..d3964af0bab7 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -555,6 +555,87 @@ def ShapeReduce(mat, shape, is_b=False): assert_raises(MXNetError, lambda: np.matmul(a, b)) +@with_seed() +@use_np +def test_np_kron(): + def np_kron_backward(ograd, a, b): + ndim = ograd.ndim + # Make ndim equal + if ndim > a.ndim: + a = a.reshape((1,)*(ndim - a.ndim) + a.shape) + else: + b = b.reshape((1,)*(ndim - b.ndim) + b.shape) + assert(a.ndim == b.ndim) + + # Compute agrad + agrad = _np.zeros(a.shape) + for i in range(a.size): + ia = _np.asarray(_np.unravel_index(i, a.shape)) + for j in range(b.size): + jb = _np.asarray(_np.unravel_index(j, b.shape)) + k = ia * _np.asarray(b.shape) + jb + agrad[tuple(ia)] += ograd[tuple(k)] * b[tuple(jb)] + # Compute bgrad + bgrad = _np.zeros(b.shape) + for j in range(b.size): + jb = _np.asarray(_np.unravel_index(j, b.shape)) + for i in range(a.size): + ia = _np.asarray(_np.unravel_index(i, a.shape)) + k = ia * _np.asarray(b.shape) + jb + bgrad[tuple(jb)] += ograd[tuple(k)] * a[tuple(ia)] + return [agrad, bgrad] + + class TestKron(HybridBlock): + def __init__(self): + super(TestKron, self).__init__() + + def hybrid_forward(self, F, a, b): + return F.np.kron(a, b) + + # test input + tensor_shapes = [ + ((3,), (3,)), + ((2, 3), (3,)), + ((2, 3, 4), (2,)), + ((3, 2), ()) + ] + + for hybridize in [True, False]: + for a_shape, b_shape in tensor_shapes: + for dtype in [_np.float32, _np.float64]: + test_kron = TestKron() + if hybridize: + test_kron.hybridize() + a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray() + b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray() + a.attach_grad() + b.attach_grad() + + np_out = _np.kron(a.asnumpy(), b.asnumpy()) + with mx.autograd.record(): + mx_out = test_kron(a, b) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + mx_out.backward() + + # Test imperative once again + mx_out = np.kron(a, b) + np_out = _np.kron(a.asnumpy(), b.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + + # test numeric gradient + a_sym = mx.sym.Variable("a").as_np_ndarray() + b_sym = mx.sym.Variable("b").as_np_ndarray() + mx_sym = mx.sym.np.kron(a_sym, b_sym).as_nd_ndarray() + check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()], + rtol=1e-2, atol=1e-2, dtype=dtype) + + # test gradient via backward implemented by numpy + np_backward = np_kron_backward(_np.ones(np_out.shape, dtype = dtype), a.asnumpy(), b.asnumpy()) + assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol=1e-2, atol=1e-2) + assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol=1e-2, atol=1e-2) + + @with_seed() @use_np def test_np_sum():