
Fix ndim = 0
hanke580 committed Feb 24, 2020
1 parent f877cd8 commit 48a6db7
Showing 1 changed file with 134 additions and 69 deletions.
203 changes: 134 additions & 69 deletions src/operator/numpy/np_kron-inl.h
@@ -116,43 +116,75 @@ void KronOpForwardImpl(const OpContext& ctx,
) {
  using namespace mshadow;

  if (req == kNullOp) {
    return;
  }

  if (out.shape_.Size() == 0U) {
    return;  // zero-size output, no need to launch kernel
  }

  const mxnet::TShape& ashape = a.shape_;
  const mxnet::TShape& bshape = b.shape_;
  const mxnet::TShape& oshape = out.shape_;

  // TensordotIntAxesImpl<xpu>(0, ctx, a, b, out, req[0]);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
    if (ashape.Size() == 0U || bshape.Size() == 0U) {
      // 0-size input
      if (req != kAddTo) {
        Tensor<xpu, 1, DType> out_data = out.get_with_shape<xpu, 1, DType>(
            Shape1(out.shape_.Size()), s);
        out_data = static_cast<DType>(0);
      }
    } else if (ashape.ndim() == 0 && bshape.ndim() == 0) {
      // Both 0-D scalars, equivalent to multiply
      Tensor<xpu, 1, DType> a_data = a.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      Tensor<xpu, 1, DType> b_data = b.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      Tensor<xpu, 1, DType> out_data = out.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      ASSIGN_DISPATCH(out_data, req, a_data * b_data);
    } else if (ashape.ndim() == 0 || bshape.ndim() == 0) {
      // Either of them is a scalar, just scale by one of them
      const DType* tensor = (ashape.ndim() == 0) ? b.dptr<DType>() : a.dptr<DType>();
      const DType* scalar = (ashape.ndim() == 0) ? a.dptr<DType>() : b.dptr<DType>();
      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
        mxnet_op::Kernel<scalar_mul_kernel<Req>, xpu>::Launch(
            s, out.Size(), out.dptr<DType>(), tensor, scalar);
      });
    } else {
      MXNET_NDIM_SWITCH(oshape.ndim(), ndim, {
        Shape<ndim> ashape_;
        Shape<ndim> bshape_;
        Shape<ndim> oshape_;
        int temp = ashape.ndim() - bshape.ndim();
        int s_dim = (temp > 0) ? bshape.ndim() : ashape.ndim();
        for (int i = 0; i < s_dim; i++) {
          ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
          bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
          oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
        }
        if (temp > 0) {
          for (int i = s_dim; i < ndim; i++) {
            ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
            bshape_[ndim - i - 1] = 1;
            oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
          }
        } else {
          for (int i = s_dim; i < ndim; i++) {
            ashape_[ndim - i - 1] = 1;
            bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
            oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
          }
        }
        MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
          mxnet_op::Kernel<kron<ndim, req_type>, xpu>::Launch(
              s, out.Size(), out.dptr<DType>(), a.dptr<DType>(), b.dptr<DType>(),
              ashape_, bshape_, oshape_);
        });
      });
    }
  });
}

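For reference, the scalar branches added above follow the usual np.kron semantics: when either operand is 0-d, the Kronecker product degenerates to scaling the other operand by that scalar. A minimal standalone C++ sketch of that behavior (not MXNet code; kron1d is a hypothetical helper name used only for illustration):

#include <cstddef>
#include <vector>

// Kronecker product of two 1-D arrays: out[i * b.size() + j] = a[i] * b[j].
// If either operand holds a single element (the 0-d / scalar case handled in
// the forward path above), this reduces to scaling the other operand.
std::vector<double> kron1d(const std::vector<double>& a,
                           const std::vector<double>& b) {
  std::vector<double> out(a.size() * b.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    for (std::size_t j = 0; j < b.size(); ++j) {
      out[i * b.size() + j] = a[i] * b[j];
    }
  }
  return out;
}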
@@ -167,47 +199,80 @@ void KronOpBackwardImpl(const OpContext& ctx,
  const mxnet::TShape& ashape = a.shape_;
  const mxnet::TShape& bshape = b.shape_;
  const mxnet::TShape& oshape = ograd.shape_;

  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
    if (ashape.ndim() == 0 && bshape.ndim() == 0) {
      // Both 0-D scalars, equivalent to multiply
      Tensor<xpu, 1, DType> ograd_data = ograd.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      Tensor<xpu, 1, DType> a_data = a.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      Tensor<xpu, 1, DType> b_data = b.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      Tensor<xpu, 1, DType> agrad_data = agrad.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      Tensor<xpu, 1, DType> bgrad_data = bgrad.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      ASSIGN_DISPATCH(agrad_data, req[0], b_data * ograd_data);
      ASSIGN_DISPATCH(bgrad_data, req[1], a_data * ograd_data);
    } else if (ashape.ndim() == 0 || bshape.ndim() == 0) {
      // Either of them is a scalar, just scale by one of them
      const TBlob& tensor = (ashape.ndim() == 0) ? b : a;
      const TBlob& tensor_grad = (ashape.ndim() == 0) ? bgrad : agrad;
      const TBlob& scalar = (ashape.ndim() == 0) ? a : b;
      const TBlob& scalar_grad = (ashape.ndim() == 0) ? agrad : bgrad;
      Tensor<xpu, 1, DType> scalar_ = scalar.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      Tensor<xpu, 1, DType> scalar_grad_ = scalar_grad.get_with_shape<xpu, 1, DType>(Shape1(1), s);
      Tensor<xpu, 1, DType> tensor_ = tensor.FlatTo1D<xpu, DType>(s);
      Tensor<xpu, 1, DType> tensor_grad_ = tensor_grad.FlatTo1D<xpu, DType>(s);
      Tensor<xpu, 1, DType> ograd_ = ograd.FlatTo1D<xpu, DType>(s);
      const OpReqType& tensor_req = (ashape.ndim() == 0) ? req[1] : req[0];
      const OpReqType& scalar_req = (ashape.ndim() == 0) ? req[0] : req[1];
      ASSIGN_DISPATCH(tensor_grad_, tensor_req,
                      broadcast_scalar(scalar_, tensor_grad_.shape_) * ograd_);
      Tensor<xpu, 1, DType> workspace =
          ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(ograd.shape_.Size()), s);
      ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * ograd_);

      ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
          ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);
    } else {
      MXNET_NDIM_SWITCH(oshape.ndim(), ndim, {
        Shape<ndim> ashape_;
        Shape<ndim> bshape_;
        Shape<ndim> oshape_;
        int temp = ashape.ndim() - bshape.ndim();
        int s_dim = (temp > 0) ? bshape.ndim() : ashape.ndim();
        for (int i = 0; i < s_dim; i++) {
          ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
          bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
          oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
        }
        if (temp > 0) {
          for (int i = s_dim; i < ndim; i++) {
            ashape_[ndim - i - 1] = ashape[ashape.ndim() - i - 1];
            bshape_[ndim - i - 1] = 1;
            oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
          }
        } else {
          for (int i = s_dim; i < ndim; i++) {
            ashape_[ndim - i - 1] = 1;
            bshape_[ndim - i - 1] = bshape[bshape.ndim() - i - 1];
            oshape_[ndim - i - 1] = oshape[oshape.ndim() - i - 1];
          }
        }
        MSHADOW_TYPE_SWITCH(agrad.type_flag_, DType, {
          MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
            mxnet_op::Kernel<kron_back_a<ndim, req_type>, xpu>::Launch(
                s, agrad.Size(), agrad.dptr<DType>(), b.dptr<DType>(), ograd.dptr<DType>(),
                ashape_, bshape_, oshape_);
          });
        });
        MSHADOW_TYPE_SWITCH(bgrad.type_flag_, DType, {
          MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, {
            mxnet_op::Kernel<kron_back_b<ndim, req_type>, xpu>::Launch(
                s, bgrad.Size(), a.dptr<DType>(), bgrad.dptr<DType>(), ograd.dptr<DType>(),
                ashape_, bshape_, oshape_);
          });
        });
      });
    }
  });
}

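The scalar branches in the backward pass follow from differentiating out = scalar * tensor: the tensor gradient is the scalar times the incoming gradient, and the scalar gradient is the full reduction of tensor * ograd, which the workspace plus ReduceAxesComputeImpl above computes. A standalone C++ sketch of the same arithmetic (not MXNet code; kron_scalar_backward is a hypothetical name used only for illustration):

#include <cstddef>
#include <vector>

// Gradients for out = scalar * tensor (the 0-d operand case):
//   tensor_grad[i] = scalar * ograd[i]           (elementwise scaling)
//   scalar_grad    = sum_i tensor[i] * ograd[i]  (reduction over all elements)
void kron_scalar_backward(double scalar,
                          const std::vector<double>& tensor,
                          const std::vector<double>& ograd,
                          std::vector<double>* tensor_grad,
                          double* scalar_grad) {
  tensor_grad->assign(tensor.size(), 0.0);
  *scalar_grad = 0.0;
  for (std::size_t i = 0; i < tensor.size(); ++i) {
    (*tensor_grad)[i] = scalar * ograd[i];
    *scalar_grad += tensor[i] * ograd[i];
  }
}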
