diff --git a/src/operator/numpy/np_kron-inl.h b/src/operator/numpy/np_kron-inl.h
index bad30a82cc1a..e8985152eb4e 100644
--- a/src/operator/numpy/np_kron-inl.h
+++ b/src/operator/numpy/np_kron-inl.h
@@ -116,43 +116,75 @@ void KronOpForwardImpl(const OpContext& ctx,
 ) {
   using namespace mshadow;
+  if (req == kNullOp) {
+    return;
+  }
+
+  if (out.shape_.Size() == 0U) {
+    return;  // zero-size output, no need to launch kernel
+  }
+
   const mxnet::TShape& ashape = a.shape_;
   const mxnet::TShape& bshape = b.shape_;
   const mxnet::TShape& oshape = out.shape_;
-  MXNET_NDIM_SWITCH(oshape.ndim(), ndim, {
-    Shape<ndim> ashape_;
-    Shape<ndim> bshape_;
-    Shape<ndim> oshape_;
-    int temp = ashape.ndim()-bshape.ndim();
-    int s_dim = (temp > 0)?bshape.ndim():ashape.ndim();
-    for (int i = 0; i < s_dim; i++) {
-      ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
-      bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
-      oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
-    }
-    if (temp > 0) {
-      for (int i = s_dim; i < ndim; i++) {
-        ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
-        bshape_[ndim - i - 1] = 1;
-        oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
+
+
+  // TensordotIntAxesImpl<xpu>(0, ctx, a, b, out, req[0]);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
+    if (ashape.Size() == 0U || bshape.Size() == 0U) {
+      // 0-size input
+      if (req != kAddTo) {
+        Tensor<xpu, 1, DType> out_data = out.get_with_shape<xpu, 1, DType>(
+            Shape1(out.shape_.Size()), s);
+        out_data = static_cast<DType>(0);
       }
+    } else if (ashape.ndim() == 0 && bshape.ndim() == 0) {
+      // Both 0-D scalars, equivalent to multiply
+      Tensor<xpu, 1, DType> a_data = a.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      Tensor<xpu, 1, DType> b_data = b.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      Tensor<xpu, 1, DType> out_data = out.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      ASSIGN_DISPATCH(out_data, req, a_data * b_data);
+    } else if (ashape.ndim() == 0 || bshape.ndim() == 0) {
+      // Either of them is a scalar, just scale by one of them
+      const DType* tensor = (ashape.ndim() == 0) ? b.dptr<DType>() : a.dptr<DType>();
+      const DType* scalar = (ashape.ndim() == 0) ? a.dptr<DType>() : b.dptr<DType>();
+      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+        mxnet_op::Kernel<scalar_mul_kernel<Req>, xpu>::Launch(
+            s, out.Size(), out.dptr<DType>(), tensor, scalar);
+      });
     } else {
-      for (int i = s_dim; i < ndim; i++) {
-        ashape_[ndim - i - 1] = 1;
-        bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
-        oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
-      }
-    }
-
-    // TensordotIntAxesImpl(0, ctx, a, b, out, req[0]);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
-        mxnet_op::Kernel<kron<ndim, req_type>, xpu>::Launch(
-          s, out.Size(), out.dptr<DType>(), a.dptr<DType>(), b.dptr<DType>(),
-          ashape_, bshape_, oshape_);
+      MXNET_NDIM_SWITCH(oshape.ndim(), ndim, {
+        Shape<ndim> ashape_;
+        Shape<ndim> bshape_;
+        Shape<ndim> oshape_;
+        int temp = ashape.ndim()-bshape.ndim();
+        int s_dim = (temp > 0)?bshape.ndim():ashape.ndim();
+        for (int i = 0; i < s_dim; i++) {
+          ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
+          bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
+          oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
+        }
+        if (temp > 0) {
+          for (int i = s_dim; i < ndim; i++) {
+            ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
+            bshape_[ndim - i - 1] = 1;
+            oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
+          }
+        } else {
+          for (int i = s_dim; i < ndim; i++) {
+            ashape_[ndim - i - 1] = 1;
+            bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
+            oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
+          }
+        }
+        MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+          mxnet_op::Kernel<kron<ndim, req_type>, xpu>::Launch(
+            s, out.Size(), out.dptr<DType>(), a.dptr<DType>(), b.dptr<DType>(),
+            ashape_, bshape_, oshape_);
+        });
       });
-    });
+    }
   });
 }
@@ -167,47 +199,80 @@ void KronOpBackwardImpl(const OpContext& ctx,
   const mxnet::TShape& ashape = a.shape_;
   const mxnet::TShape& bshape = b.shape_;
   const mxnet::TShape& oshape = ograd.shape_;
-  MXNET_NDIM_SWITCH(oshape.ndim(), ndim, {
-    Shape<ndim> ashape_;
-    Shape<ndim> bshape_;
-    Shape<ndim> oshape_;
-    int temp = ashape.ndim()-bshape.ndim();
-    int s_dim = (temp > 0)?bshape.ndim():ashape.ndim();
-    for (int i = 0; i < s_dim; i++) {
-      ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
-      bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
-      oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
-    }
-    if (temp > 0) {
-      for (int i = s_dim; i < ndim; i++) {
-        ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
-        bshape_[ndim - i - 1] = 1;
-        oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
-      }
-    } else {
-      for (int i = s_dim; i < ndim; i++) {
-        ashape_[ndim - i - 1] = 1;
-        bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
-        oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
-      }
-    }
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    MSHADOW_TYPE_SWITCH(agrad.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        mxnet_op::Kernel<kron_back_a<ndim, req_type>, xpu>::Launch(
-          s, agrad.Size(), agrad.dptr<DType>(), b.dptr<DType>(), ograd.dptr<DType>(),
-          ashape_, bshape_, oshape_);
-      });
-    });
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+    if (ashape.ndim() == 0 && bshape.ndim() == 0) {
+      // Both 0-D scalars, equivalent to multiply
+      Tensor<xpu, 1, DType> ograd_data = ograd.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      Tensor<xpu, 1, DType> a_data = a.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      Tensor<xpu, 1, DType> b_data = b.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      Tensor<xpu, 1, DType> agrad_data = agrad.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      Tensor<xpu, 1, DType> bgrad_data = bgrad.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      ASSIGN_DISPATCH(agrad_data, req[0], b_data * ograd_data);
+      ASSIGN_DISPATCH(bgrad_data, req[1], a_data * ograd_data);
+    } else if (ashape.ndim() == 0 || bshape.ndim() == 0) {
+      // Either of them is a scalar, just scale by one of them
+      const TBlob& tensor = (ashape.ndim() == 0) ? b : a;
+      const TBlob& tensor_grad = (ashape.ndim() == 0) ? bgrad : agrad;
+      const TBlob& scalar = (ashape.ndim() == 0) ? a : b;
+      const TBlob& scalar_grad = (ashape.ndim() == 0) ? agrad : bgrad;
+      Tensor<xpu, 1, DType> scalar_ = scalar.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      Tensor<xpu, 1, DType> scalar_grad_ = scalar_grad.get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      Tensor<xpu, 1, DType> tensor_ = tensor.FlatTo1D<xpu, DType>(s);
+      Tensor<xpu, 1, DType> tensor_grad_ = tensor_grad.FlatTo1D<xpu, DType>(s);
+      Tensor<xpu, 1, DType> ograd_ = ograd.FlatTo1D<xpu, DType>(s);
+      const OpReqType& tensor_req = (ashape.ndim() == 0) ? req[1] : req[0];
+      const OpReqType& scalar_req = (ashape.ndim() == 0) ? req[0] : req[1];
+      ASSIGN_DISPATCH(tensor_grad_, tensor_req,
+                      broadcast_scalar(scalar_, tensor_grad_.shape_) * ograd_);
+      Tensor<xpu, 1, DType> workspace =
+          ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(ograd.shape_.Size()), s);
+      ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * ograd_);
 
-    MSHADOW_TYPE_SWITCH(bgrad.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, {
-        mxnet_op::Kernel<kron_back_b<ndim, req_type>, xpu>::Launch(
-          s, bgrad.Size(), a.dptr<DType>(), bgrad.dptr<DType>(), ograd.dptr<DType>(),
-          ashape_, bshape_, oshape_);
+      ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
+          ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);
+    } else {
+      MXNET_NDIM_SWITCH(oshape.ndim(), ndim, {
+        Shape<ndim> ashape_;
+        Shape<ndim> bshape_;
+        Shape<ndim> oshape_;
+        int temp = ashape.ndim()-bshape.ndim();
+        int s_dim = (temp > 0)?bshape.ndim():ashape.ndim();
+        for (int i = 0; i < s_dim; i++) {
+          ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
+          bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
+          oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
+        }
+        if (temp > 0) {
+          for (int i = s_dim; i < ndim; i++) {
+            ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
+            bshape_[ndim - i - 1] = 1;
+            oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
+          }
+        } else {
+          for (int i = s_dim; i < ndim; i++) {
+            ashape_[ndim - i - 1] = 1;
+            bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
+            oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
+          }
+        }
+        MSHADOW_TYPE_SWITCH(agrad.type_flag_, DType, {
+          MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+            mxnet_op::Kernel<kron_back_a<ndim, req_type>, xpu>::Launch(
+              s, agrad.Size(), agrad.dptr<DType>(), b.dptr<DType>(), ograd.dptr<DType>(),
+              ashape_, bshape_, oshape_);
+          });
+        });
+        MSHADOW_TYPE_SWITCH(bgrad.type_flag_, DType, {
+          MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, {
+            mxnet_op::Kernel<kron_back_b<ndim, req_type>, xpu>::Launch(
+              s, bgrad.Size(), a.dptr<DType>(), bgrad.dptr<DType>(), ograd.dptr<DType>(),
+              ashape_, bshape_, oshape_);
+          });
+        });
       });
-    });
+    }
   });
 }
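
Note on the shape-alignment loops over `temp`/`s_dim`, which the patch keeps verbatim in the general path of both hunks: they right-align the lower-rank operand against the output rank and pad its missing leading dimensions with 1, so each output extent is the product of the two padded extents. A minimal standalone sketch of that padding (plain C++, outside mshadow; all names illustrative, not the operator's code):

```cpp
#include <cstdio>

int main() {
  // Illustrative shapes: a is (2, 3, 4), b is (5, 6); oshape.ndim() == 3.
  const int ndim = 3;
  const int a_nd = 3, b_nd = 2;
  const int ashape[] = {2, 3, 4};
  const int bshape[] = {5, 6};

  int ashape_[ndim], bshape_[ndim], oshape_[ndim];
  for (int i = 0; i < ndim; ++i) {
    // Right-align each operand; pad missing leading dims with 1.
    // (The diff reads oshape from the already-inferred output instead
    // of recomputing it; the product below is equivalent for kron.)
    ashape_[ndim - i - 1] = (i < a_nd) ? ashape[a_nd - i - 1] : 1;
    bshape_[ndim - i - 1] = (i < b_nd) ? bshape[b_nd - i - 1] : 1;
    oshape_[ndim - i - 1] = ashape_[ndim - i - 1] * bshape_[ndim - i - 1];
  }

  for (int i = 0; i < ndim; ++i)
    std::printf("dim %d: a=%d b=%d out=%d\n", i, ashape_[i], bshape_[i], oshape_[i]);
  // Output shape (2, 15, 24), matching np.kron for shapes (2, 3, 4) and (5, 6).
  return 0;
}
```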
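The `kron<ndim, req_type>` kernel launched by that general path maps each output coordinate back to one element of `a` and one of `b`: per padded dimension, coordinate `o` splits into `a_idx = o / bshape_[k]` and `b_idx = o % bshape_[k]`. The new scalar fast path in the forward hunk bypasses this machinery entirely; when one operand is 0-D, every output element is just `tensor[i] * scalar`. A self-contained 2-D sketch of the index mapping (illustrative, not the operator's kernel):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // kron of a (2x2) and b (2x2) produces a 4x4 block matrix.
  const int ar = 2, ac = 2, br = 2, bc = 2;
  const double a[ar][ac] = {{1, 2}, {3, 4}};
  const double b[br][bc] = {{0, 5}, {6, 7}};
  std::vector<double> out((ar * br) * (ac * bc));

  for (int i = 0; i < ar * br; ++i) {
    for (int j = 0; j < ac * bc; ++j) {
      // Per-dimension decomposition: o = a_idx * b_extent + b_idx.
      const int ai = i / br, bi = i % br;
      const int aj = j / bc, bj = j % bc;
      out[i * (ac * bc) + j] = a[ai][aj] * b[bi][bj];
    }
  }

  for (int i = 0; i < ar * br; ++i) {
    for (int j = 0; j < ac * bc; ++j)
      std::printf("%5.1f", out[i * (ac * bc) + j]);
    std::printf("\n");
  }
  return 0;
}
```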
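For the new scalar branch of the backward pass: with `out = scalar * tensor`, the tensor gradient is `scalar * ograd` (elementwise scaling), and the scalar gradient is the full reduction `sum(tensor * ograd)`, which the diff computes by staging `tensor * ograd` in a requested workspace and reducing it with `ReduceAxesComputeImpl`. A plain-loop sketch of the same math (illustrative values, not the mshadow code):

```cpp
#include <cstdio>

int main() {
  const int n = 4;
  const double tensor[n] = {1, 2, 3, 4};
  const double ograd[n]  = {0.1, 0.2, 0.3, 0.4};
  const double scalar = 2.5;

  double tensor_grad[n];
  double scalar_grad = 0.0;  // plays the role of the workspace + reduction
  for (int i = 0; i < n; ++i) {
    tensor_grad[i] = scalar * ograd[i];    // broadcast_scalar(scalar_, ...) * ograd_
    scalar_grad  += tensor[i] * ograd[i];  // sum over workspace = tensor_ * ograd_
  }

  for (int i = 0; i < n; ++i)
    std::printf("tensor_grad[%d] = %g\n", i, tensor_grad[i]);
  std::printf("scalar_grad = %g\n", scalar_grad);  // 0.1 + 0.4 + 0.9 + 1.6 = 3
  return 0;
}
```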