Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add: numpy op rollaxis
Browse files Browse the repository at this point in the history
  • Loading branch information
yijunc committed Mar 18, 2020
1 parent dfb1b88 commit 5fcd1f7
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 3 deletions.
36 changes: 35 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']
'where', 'bincount', 'pad', 'cumsum', 'rollaxis']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -7745,3 +7745,37 @@ def cumsum(a, axis=None, dtype=None, out=None):
[ 4, 9, 15]])
"""
return _api_internal.cumsum(a, axis, dtype, out)

@set_module('mxnet.ndarray.numpy')
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
Parameters
----------
a : ndarray
Input array.
axis : integer
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start: int, optional
The axis is rolled until it lies before this position.
The default, 0, results in a “complete” roll.
Returns
-------
res : ndarray
A view after applying rollaxis to `a` is returned.
-----
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
return _npi.rollaxis(a, axis, start)
37 changes: 36 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'pad', 'cumsum']
'pad', 'cumsum', 'rollaxis']

__all__ += fallback.__all__

Expand Down Expand Up @@ -9891,4 +9891,39 @@ def cumsum(a, axis=None, dtype=None, out=None):
[ 4, 9, 15]])
"""
return _mx_nd_np.cumsum(a, axis=axis, dtype=dtype, out=out)

@set_module('mxnet.numpy')
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
Parameters
----------
a : ndarray
Input array.
axis : integer
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start: int, optional
The axis is rolled until it lies before this position.
The default, 0, results in a “complete” roll.
Returns
-------
res : ndarray
A view after applying rollaxis to `a` is returned.
-----
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
return _mx_nd_np.rollaxis(a, axis, start)

# pylint: enable=redefined-outer-name
35 changes: 34 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']
'where', 'bincount', 'pad', 'cumsum', 'rollaxis']


@set_module('mxnet.symbol.numpy')
Expand Down Expand Up @@ -6866,5 +6866,38 @@ def cumsum(a, axis=None, dtype=None, out=None):
"""
return _npi.cumsum(a, axis=axis, dtype=dtype, out=out)

@set_module('mxnet.symbol.numpy')
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
Parameters
----------
a : ndarray
Input array.
axis : integer
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start: int, optional
The axis is rolled until it lies before this position.
The default, 0, results in a “complete” roll.
Returns
-------
res : ndarray
A view after applying rollaxis to `a` is returned.
-----
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
return _npi.rollaxis(a, axis, start)

_set_np_symbol_class(_Symbol)
65 changes: 65 additions & 0 deletions src/operator/numpy/np_rollaixs_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@

#include "./np_rollaxis_op-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyRollaxisParam);

bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
// check 1 input, 1 output
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

// check transpose dimentions no more than 6
mxnet::TShape& shp = (*in_attrs)[0];
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";

// check axis and start range
CHECK_GE(param.axis, -shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1;
CHECK_LT(param.axis, shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1;
CHECK_GE(param.start, -shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim();
CHECK_LE(param.start, shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim();

// generate output shape
mxnet::TShape ret(shp.ndim(), -1);
mxnet::TShape axes;

axes = NumpyRollaxisShapeImpl(param.axis, param.start, shp.ndim());
for (int i = 0; i < shp.ndim(); ++i) {
CHECK(axes[i] < static_cast<int64_t>(shp.ndim()));
ret[i] = shp[axes[i]];
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
return shape_is_known(ret);
}

NNVM_REGISTER_OP(_npi_rollaxis)
.describe(R"code(Roll the specified axis backwards,
until it lies in a given position.)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyRollaxisParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", NumpyRollaxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_npi_rollaxis_backward"})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyRollaxisParam::__FIELDS__());

NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyRollaxisParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>);

} // namespace op
} // namespace mxnet
117 changes: 117 additions & 0 deletions src/operator/numpy/np_rollaxis_op-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#ifndef MXNET_OPERATOR_NUMPY_NP_ROLLAXIS_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_ROLLAXIS_OP_INL_H_

#include "../operator_common.h"
#include <mxnet/operator_util.h>
#include "../tensor/matrix_op-inl.h"
#include "../nn/concat-inl.h"
#include "../../common/utils.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../tensor/broadcast_reduce_op.h"

namespace mxnet {
namespace op {

struct NumpyRollaxisParam : public dmlc::Parameter<NumpyRollaxisParam> {
int axis;
int start;
DMLC_DECLARE_PARAMETER(NumpyRollaxisParam) {
DMLC_DECLARE_FIELD(axis)
.describe("The axis to roll backwards. The positions of the other axes do not change relative to one another.");
DMLC_DECLARE_FIELD(start)
.set_default(0)
.describe("The axis is rolled until it lies before this position. The default, 0, results in a “complete” roll.");
}
};

inline mxnet::TShape NumpyRollaxisShapeImpl(int axis,
int start,
const int& ndim) {
mxnet::TShape axes(ndim, -1);
if (axis < 0) {
axis += ndim;
}
if (start < 0){
start += ndim;
}
if (axis < start){
axes[start - 1] = axis;
} else {
axes[start] = axis;
}
int new_axis = 0;
for(int i = 0; i < axes.ndim(); i++){
if (axes[i] < 0){
if (new_axis == axis){
new_axis++;
}
axes[i] = new_axis++;
}
}
return axes;
}


template<typename xpu>
void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace";
mxnet::TShape axes;
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
}

template<typename xpu>
void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
int axis_origin = param.axis;
int start_origin = param.start;
int ndim = inputs[0].ndim();

int axis;
int start;

if (axis_origin < 0) {
axis_origin += ndim;
}

if (start_origin < 0) {
start_origin += ndim;
}

if (axis_origin < start_origin){
axis = start_origin - 1;
start = axis_origin;
} else {
axis = start_origin;
start = axis_origin + 1;
}
mxnet::TShape axes;
axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_ROLLAXIS_OP_INL_H_
13 changes: 13 additions & 0 deletions src/operator/numpy/np_rollaxis_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "./np_rollaxis_op-inl.h"

namespace mxnet{
namespace op{

NNVM_REGISTER_OP(_npi_rollaxis)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisCompute<gpu>);

NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisBackward<gpu>);

}
}
38 changes: 38 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8374,6 +8374,44 @@ def hybrid_forward(self, F, x, *args, **kwargs):
ret = np.empty_like(prototype, dtype, order, subok)
assert ret.asnumpy().shape == expected_ret.shape

@with_seed()
@use_np
def test_np_rollaxis():
class TestRollaxis(HybridBlock):
def __init__(self, axis=0, start=0):
super(TestRollaxis, self).__init__()
self._axis = axis
self._start = start

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.rollaxis(a, axis=self._axis, start=self._start)

dtypes = ['int32', 'int64', 'float16', 'float32', 'float64']
for hybridize in [False, True]:
for dtype in dtypes:
for ndim in [0, 1, 2, 3, 4, 5, 6]:
shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True)
np_data = _np.random.uniform(low=-100, high=100, size=shape).astype(dtype)
mx_data = np.array(np_data, dtype=dtype)
for axis in range(-ndim, ndim):
for start in range(-ndim, ndim + 1):
# test gluon
test_rollaxis = TestRollaxis(axis, start)
if hybridize:
test_rollaxis.hybridize()
np_out = _np.rollaxis(np_data, axis=axis, start=start)
mx_data.attach_grad()
with mx.autograd.record():
mx_out = test_rollaxis(mx_data)
assert mx_out.shape == np_out.shape
mx_out.backward()
assert same(mx_data.grad.shape, mx_data.shape)
assert same(mx_data.grad.asnumpy(), _np.ones(shape))
# test imperative
np_out = _np.rollaxis(np_data, axis=axis, start=start)
mx_out = np.rollaxis(mx_data, axis=axis, start=start)
assert np_out.dtype == mx_out.dtype
assert same(mx_out.asnumpy(), np_out)

if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 5fcd1f7

Please sign in to comment.