This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
338 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
|
||
#include "./np_rollaxis_op-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
DMLC_REGISTER_PARAMETER(NumpyRollaxisParam); | ||
|
||
bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs, | ||
mxnet::ShapeVector *in_attrs, | ||
mxnet::ShapeVector *out_attrs) { | ||
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed); | ||
// check 1 input, 1 output | ||
CHECK_EQ(in_attrs->size(), 1U); | ||
CHECK_EQ(out_attrs->size(), 1U); | ||
|
||
// check transpose dimentions no more than 6 | ||
mxnet::TShape& shp = (*in_attrs)[0]; | ||
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; | ||
|
||
// check axis and start range | ||
CHECK_GE(param.axis, -shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1; | ||
CHECK_LT(param.axis, shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1; | ||
CHECK_GE(param.start, -shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim(); | ||
CHECK_LE(param.start, shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim(); | ||
|
||
// generate output shape | ||
mxnet::TShape ret(shp.ndim(), -1); | ||
mxnet::TShape axes; | ||
|
||
axes = NumpyRollaxisShapeImpl(param.axis, param.start, shp.ndim()); | ||
for (int i = 0; i < shp.ndim(); ++i) { | ||
CHECK(axes[i] < static_cast<int64_t>(shp.ndim())); | ||
ret[i] = shp[axes[i]]; | ||
} | ||
SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); | ||
return shape_is_known(ret); | ||
} | ||
|
||
NNVM_REGISTER_OP(_npi_rollaxis) | ||
.describe(R"code(Roll the specified axis backwards, | ||
until it lies in a given position.)code" ADD_FILELINE) | ||
.set_num_inputs(1) | ||
.set_num_outputs(1) | ||
.set_attr_parser(ParamParser<NumpyRollaxisParam>) | ||
.set_attr<nnvm::FListInputNames>("FListInputNames", | ||
[](const NodeAttrs& attrs) { | ||
return std::vector<std::string>{"data"}; | ||
}) | ||
.set_attr<mxnet::FInferShape>("FInferShape", NumpyRollaxisShape) | ||
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) | ||
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisCompute<cpu>) | ||
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_npi_rollaxis_backward"}) | ||
.add_argument("data", "NDArray-or-Symbol", "Input ndarray") | ||
.add_arguments(NumpyRollaxisParam::__FIELDS__()); | ||
|
||
NNVM_REGISTER_OP(_npi_rollaxis_backward) | ||
.set_num_inputs(1) | ||
.set_num_outputs(1) | ||
.set_attr_parser(ParamParser<NumpyRollaxisParam>) | ||
.set_attr<nnvm::TIsBackward>("TIsBackward", true) | ||
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>); | ||
|
||
} // namespace op | ||
} // namespace mxnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
#ifndef MXNET_OPERATOR_NUMPY_NP_ROLLAXIS_OP_INL_H_ | ||
#define MXNET_OPERATOR_NUMPY_NP_ROLLAXIS_OP_INL_H_ | ||
|
||
#include "../operator_common.h" | ||
#include <mxnet/operator_util.h> | ||
#include "../tensor/matrix_op-inl.h" | ||
#include "../nn/concat-inl.h" | ||
#include "../../common/utils.h" | ||
#include "../mxnet_op.h" | ||
#include "../operator_common.h" | ||
#include "../elemwise_op_common.h" | ||
#include "../tensor/broadcast_reduce_op.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
struct NumpyRollaxisParam : public dmlc::Parameter<NumpyRollaxisParam> { | ||
int axis; | ||
int start; | ||
DMLC_DECLARE_PARAMETER(NumpyRollaxisParam) { | ||
DMLC_DECLARE_FIELD(axis) | ||
.describe("The axis to roll backwards. The positions of the other axes do not change relative to one another."); | ||
DMLC_DECLARE_FIELD(start) | ||
.set_default(0) | ||
.describe("The axis is rolled until it lies before this position. The default, 0, results in a “complete” roll."); | ||
} | ||
}; | ||
|
||
inline mxnet::TShape NumpyRollaxisShapeImpl(int axis, | ||
int start, | ||
const int& ndim) { | ||
mxnet::TShape axes(ndim, -1); | ||
if (axis < 0) { | ||
axis += ndim; | ||
} | ||
if (start < 0){ | ||
start += ndim; | ||
} | ||
if (axis < start){ | ||
axes[start - 1] = axis; | ||
} else { | ||
axes[start] = axis; | ||
} | ||
int new_axis = 0; | ||
for(int i = 0; i < axes.ndim(); i++){ | ||
if (axes[i] < 0){ | ||
if (new_axis == axis){ | ||
new_axis++; | ||
} | ||
axes[i] = new_axis++; | ||
} | ||
} | ||
return axes; | ||
} | ||
|
||
|
||
template<typename xpu> | ||
void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs, | ||
const OpContext& ctx, | ||
const std::vector<TBlob>& inputs, | ||
const std::vector<OpReqType>& req, | ||
const std::vector<TBlob>& outputs) { | ||
using namespace mshadow; | ||
using namespace mshadow::expr; | ||
CHECK_EQ(inputs.size(), 1U); | ||
CHECK_EQ(outputs.size(), 1U); | ||
CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace"; | ||
mxnet::TShape axes; | ||
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed); | ||
axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim()); | ||
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, { | ||
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes); | ||
}) | ||
} | ||
|
||
template<typename xpu> | ||
void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs, | ||
const OpContext &ctx, | ||
const std::vector<TBlob> &inputs, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<TBlob> &outputs) { | ||
using namespace mshadow; | ||
using namespace mshadow::expr; | ||
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed); | ||
int axis_origin = param.axis; | ||
int start_origin = param.start; | ||
int ndim = inputs[0].ndim(); | ||
|
||
int axis; | ||
int start; | ||
|
||
if (axis_origin < 0) { | ||
axis_origin += ndim; | ||
} | ||
|
||
if (start_origin < 0) { | ||
start_origin += ndim; | ||
} | ||
|
||
if (axis_origin < start_origin){ | ||
axis = start_origin - 1; | ||
start = axis_origin; | ||
} else { | ||
axis = start_origin; | ||
start = axis_origin + 1; | ||
} | ||
mxnet::TShape axes; | ||
axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim()); | ||
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, { | ||
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes); | ||
}) | ||
} | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
||
#endif // MXNET_OPERATOR_NUMPY_NP_ROLLAXIS_OP_INL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#include "./np_rollaxis_op-inl.h" | ||
|
||
namespace mxnet{ | ||
namespace op{ | ||
|
||
NNVM_REGISTER_OP(_npi_rollaxis) | ||
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisCompute<gpu>); | ||
|
||
NNVM_REGISTER_OP(_npi_rollaxis_backward) | ||
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisBackward<gpu>); | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters