Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add: numpy rollaxis
Browse files Browse the repository at this point in the history
  • Loading branch information
yijunc committed Apr 1, 2020
1 parent 56e7985 commit 75df9d1
Show file tree
Hide file tree
Showing 8 changed files with 312 additions and 6 deletions.
34 changes: 33 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum', 'diag', 'diagonal']
'where', 'bincount', 'rollaxis', 'pad', 'cumsum', 'diag', 'diagonal']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -7938,6 +7938,38 @@ def cumsum(a, axis=None, dtype=None, out=None):
return _api_internal.cumsum(a, axis, dtype, out)


@set_module('mxnet.ndarray.numpy')
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
a
Input array.
axis : integer
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start: int, optional
The axis is rolled until it lies before this position.
The default, 0, results in a “complete” roll.
Returns
-------
res : ndarray
A view after applying rollaxis to `a` is returned.
-----
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
return _npi.rollaxis(a, axis, start)


@set_module('mxnet.ndarray.numpy')
def diag(v, k=0):
"""
Expand Down
2 changes: 0 additions & 2 deletions python/mxnet/numpy/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@
'rate',
'real',
'result_type',
'rollaxis',
'roots',
'searchsorted',
'select',
Expand Down Expand Up @@ -182,7 +181,6 @@
rate = onp.rate
real = onp.real
result_type = onp.result_type
rollaxis = onp.rollaxis
roots = onp.roots
searchsorted = onp.searchsorted
select = onp.select
Expand Down
39 changes: 37 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'pad', 'cumsum', 'diag', 'diagonal']
'pad', 'cumsum', 'rollaxis', 'diag', 'diagonal']

__all__ += fallback.__all__

Expand Down Expand Up @@ -10101,7 +10101,42 @@ def cumsum(a, axis=None, dtype=None, out=None):
[ 4, 9, 15]])
"""
return _mx_nd_np.cumsum(a, axis=axis, dtype=dtype, out=out)
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.numpy')
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
Parameters
----------
a : ndarray
Input array.
axis : integer
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start: int, optional
The axis is rolled until it lies before this position.
The default, 0, results in a “complete” roll.
Returns
-------
res : ndarray
A view after applying rollaxis to `a` is returned.
-----
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
return _mx_nd_np.rollaxis(a, axis, start)


@set_module('mxnet.numpy')
Expand Down
37 changes: 36 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum', 'diag', 'diagonal']
'where', 'bincount', 'pad', 'rollaxis', 'cumsum', 'diag', 'diagonal']


@set_module('mxnet.symbol.numpy')
Expand Down Expand Up @@ -6968,6 +6968,41 @@ def cumsum(a, axis=None, dtype=None, out=None):
return _npi.cumsum(a, axis=axis, dtype=dtype, out=out)


@set_module('mxnet.symbol.numpy')
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
Parameters
----------
a : ndarray
Input array.
axis : integer
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start: int, optional
The axis is rolled until it lies before this position.
The default, 0, results in a “complete” roll.
Returns
-------
res : ndarray
A view after applying rollaxis to `a` is returned.
-----
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
return _npi.rollaxis(a, axis, start)


@set_module('mxnet.symbol.numpy')
def diag(v, k=0):
"""
Expand Down
98 changes: 98 additions & 0 deletions src/operator/numpy/np_matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,47 @@ void NumpyFlipForward(const nnvm::NodeAttrs& attrs,
NumpyFlipForwardImpl<xpu>(ctx, inputs, outputs, stride_, trailing_, flip_index);
}

struct NumpyRollaxisParam : public dmlc::Parameter<NumpyRollaxisParam> {
int axis;
int start;
DMLC_DECLARE_PARAMETER(NumpyRollaxisParam) {
DMLC_DECLARE_FIELD(axis)
.describe("The axis to roll backwards. "
"The positions of the other axes do not change relative to one another.");
DMLC_DECLARE_FIELD(start)
.set_default(0)
.describe("The axis is rolled until it lies before this position. "
"The default, 0, results in a “complete” roll.");
}
};

inline mxnet::TShape NumpyRollaxisShapeImpl(int axis,
int start,
const int& ndim) {
mxnet::TShape axes(ndim, -1);
if (axis < 0) {
axis += ndim;
}
if (start < 0) {
start += ndim;
}
if (axis < start) {
axes[start - 1] = axis;
} else {
axes[start] = axis;
}
int new_axis = 0;
for (int i = 0; i < axes.ndim(); i++) {
if (axes[i] < 0) {
if (new_axis == axis) {
new_axis++;
}
axes[i] = new_axis++;
}
}
return axes;
}

struct NumpyMoveaxisParam : public dmlc::Parameter<NumpyMoveaxisParam> {
mxnet::TShape source;
mxnet::TShape destination;
Expand Down Expand Up @@ -601,6 +642,63 @@ void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
})
}

template<typename xpu>
void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace";
mxnet::TShape axes;
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
}

template<typename xpu>
void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
int axis_origin = param.axis;
int start_origin = param.start;
int ndim = inputs[0].ndim();

int axis;
int start;

if (axis_origin < 0) {
axis_origin += ndim;
}

if (start_origin < 0) {
start_origin += ndim;
}

if (axis_origin < start_origin) {
axis = start_origin - 1;
start = axis_origin;
} else {
axis = start_origin;
start = axis_origin + 1;
}
mxnet::TShape axes;
axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
}

struct NumpyRot90Param : public dmlc::Parameter<NumpyRot90Param> {
int k;
dmlc::optional<mxnet::TShape> axes;
Expand Down
64 changes: 64 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace op {
DMLC_REGISTER_PARAMETER(NumpyTransposeParam);
DMLC_REGISTER_PARAMETER(NumpyRollParam);
DMLC_REGISTER_PARAMETER(NumpyMoveaxisParam);
DMLC_REGISTER_PARAMETER(NumpyRollaxisParam);
DMLC_REGISTER_PARAMETER(NumpyRot90Param);
DMLC_REGISTER_PARAMETER(NumpyReshapeParam);
DMLC_REGISTER_PARAMETER(NumpyXReshapeParam);
Expand Down Expand Up @@ -1190,6 +1191,69 @@ NNVM_REGISTER_OP(_npi_roll)
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyRollParam::__FIELDS__());

bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
// check 1 input, 1 output
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

// check transpose dimentions no more than 6
mxnet::TShape& shp = (*in_attrs)[0];
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";

// check axis and start range
CHECK_GE(param.axis, -shp.ndim())
<< "axis must be within the range of "
<< -shp.ndim() << " and " << shp.ndim() - 1;
CHECK_LT(param.axis, shp.ndim())
<< "axis must be within the range of "
<< -shp.ndim() << " and " << shp.ndim() - 1;
CHECK_GE(param.start, -shp.ndim())
<< "start must be within the range of "
<< -shp.ndim() << " and " << shp.ndim();
CHECK_LE(param.start, shp.ndim())
<< "start must be within the range of "
<< -shp.ndim() << " and " << shp.ndim();

// generate output shape
mxnet::TShape ret(shp.ndim(), -1);
mxnet::TShape axes;

axes = NumpyRollaxisShapeImpl(param.axis, param.start, shp.ndim());
for (int i = 0; i < shp.ndim(); ++i) {
CHECK(axes[i] < static_cast<int64_t>(shp.ndim()));
ret[i] = shp[axes[i]];
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
return shape_is_known(ret);
}

NNVM_REGISTER_OP(_npi_rollaxis)
.describe(R"code(Roll the specified axis backwards,
until it lies in a given position.)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyRollaxisParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", NumpyRollaxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_npi_rollaxis_backward"})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyRollaxisParam::__FIELDS__());

NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyRollaxisParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>);

template<>
void NumpyFlipForwardImpl<cpu>(const OpContext& ctx,
const std::vector<TBlob>& inputs,
Expand Down
6 changes: 6 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ NNVM_REGISTER_OP(_backward_npi_flip)
NNVM_REGISTER_OP(_np_moveaxis)
.set_attr<FCompute>("FCompute<gpu>", NumpyMoveaxisCompute<gpu>);

NNVM_REGISTER_OP(_npi_rollaxis)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisCompute<gpu>);

NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisBackward<gpu>);

NNVM_REGISTER_OP(_npi_rot90)
.set_attr<FCompute>("FCompute<gpu>", NumpyRot90Compute<gpu>);

Expand Down
Loading

0 comments on commit 75df9d1

Please sign in to comment.