Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
support np.dsplit, fix some error msgs and corner cases for hsplit an…
Browse files Browse the repository at this point in the history
…d vsplit, add interoperability tests for h/v/dsplit
  • Loading branch information
haojin2 committed Feb 15, 2020
1 parent b6b1de0 commit 96b9439
Show file tree
Hide file tree
Showing 8 changed files with 312 additions and 22 deletions.
86 changes: 73 additions & 13 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort',
'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount']
Expand Down Expand Up @@ -3734,21 +3734,19 @@ def hsplit(ary, indices_or_sections):
>>> np.hsplit(x, [2, 2])
[array([0., 1.]), array([], dtype=float32), array([2., 3.])]
"""
axis = 1
if len(ary.shape) == 1:
axis = 0
axis_size = ary.shape[axis]
if len(ary.shape) < 1:
raise ValueError('hsplit only works on arrays of 1 or more dimensions')
indices = []
sections = 0
if isinstance(indices_or_sections, integer_types):
sections = indices_or_sections
if axis_size % sections:
raise ValueError('array hsplit does not result in an equal division')
section_size = int(axis_size / sections)
indices = [i * section_size for i in range(sections)]
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
ret = _npi.hsplit(ary, indices, axis, False)
raise ValueError('indices_or_sections must be either int, or tuple / list / set of ints')
ret = _npi.hsplit(ary, indices, 1, False, sections)
if not isinstance(ret, list):
return [ret]
return ret
# pylint: enable=redefined-outer-name

Expand Down Expand Up @@ -3827,9 +3825,71 @@ def vsplit(ary, indices_or_sections):
[6., 7.]]])]
"""
if len(ary.shape) < 2:
raise ValueError("vsplit only works on arrays of 2 or more dimensions")
return split(ary, indices_or_sections, 0)


# pylint: disable=redefined-outer-name
@set_module('mxnet.ndarray.numpy')
def dsplit(ary, indices_or_sections):
"""
Split array into multiple sub-arrays along the 3rd axis (depth).
Please refer to the `split` documentation. `dsplit` is equivalent
to `split` with ``axis=2``, the array is always split along the third
axis provided the array dimension is greater than or equal to 3.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1 - D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
along axis 2. If such a split is not possible, an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
along axis 2 the array is split. For example, ``[2, 3]`` would result in
- ary[:, :, :2]
- ary[:, :, 2:3]
- ary[:, :, 3:]
If an index exceeds the dimension of the array along axis 2, an error will be thrown.
Examples
--------
>>> x = np.arange(16.0).reshape(2, 2, 4)
>>> x
array([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.]],
[[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]])
>>> np.dsplit(x, 2)
[array([[[ 0., 1.],
[ 4., 5.]],
[[ 8., 9.],
[12., 13.]]]), array([[[ 2., 3.],
[ 6., 7.]],
[[10., 11.],
[14., 15.]]])]
>>> np.dsplit(x, np.array([3, 6]))
[array([[[ 0., 1., 2.],
[ 4., 5., 6.]],
[[ 8., 9., 10.],
[12., 13., 14.]]]),
array([[[ 3.],
[ 7.]],
[[11.],
[15.]]]),
array([], shape=(2, 2, 0), dtype=float64)]
"""
if len(ary.shape) < 3:
raise ValueError('dsplit only works on arrays of 3 or more dimensions')
return split(ary, indices_or_sections, 2)
# pylint: enable=redefined-outer-name


@set_module('mxnet.ndarray.numpy')
def concatenate(seq, axis=0, out=None):
"""
Expand Down
89 changes: 83 additions & 6 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert',
'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman',
'flip', 'flipud', 'fliplr', 'around', 'round', 'arctan2', 'hypot',
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero',
'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount']

Expand Down Expand Up @@ -5439,8 +5440,8 @@ def vsplit(ary, indices_or_sections):
Notes
-------
This function differs from the original `numpy.degrees
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.degrees.html>`_ in
This function differs from the original `numpy.vsplit
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.vsplit.html>`_ in
the following aspects:
- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
Expand Down Expand Up @@ -5474,7 +5475,83 @@ def vsplit(ary, indices_or_sections):
[6., 7.]]])]
"""
return split(ary, indices_or_sections, 0)
return _mx_nd_np.vsplit(ary, indices_or_sections)


@set_module('mxnet.numpy')
def dsplit(ary, indices_or_sections):
r"""
Split array into multiple sub-arrays along the 3rd axis (depth).
Please refer to the `split` documentation. `dsplit` is equivalent
to `split` with ``axis=2``, the array is always split along the third
axis provided the array dimension is greater than or equal to 3.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1 - D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
along axis 2. If such a split is not possible, an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
along axis 2 the array is split. For example, ``[2, 3]`` would result in
- ary[:, :, :2]
- ary[:, :, 2:3]
- ary[:, :, 3:]
If an index exceeds the dimension of the array along axis 2, an error will be thrown.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
See Also
--------
split : Split an array into multiple sub-arrays of equal size.
Notes
-------
This function differs from the original `numpy.dsplit
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.dsplit.html>`_ in
the following aspects:
- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
tuple and list.
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 2,
an error will be thrown.
Examples
--------
>>> x = np.arange(16.0).reshape(2, 2, 4)
>>> x
array([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.]],
[[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]])
>>> np.dsplit(x, 2)
[array([[[ 0., 1.],
[ 4., 5.]],
[[ 8., 9.],
[12., 13.]]]), array([[[ 2., 3.],
[ 6., 7.]],
[[10., 11.],
[14., 15.]]])]
>>> np.dsplit(x, np.array([3, 6]))
[array([[[ 0., 1., 2.],
[ 4., 5., 6.]],
[[ 8., 9., 10.],
[12., 13., 14.]]]),
array([[[ 3.],
[ 7.]],
[[11.],
[15.]]]),
array([], shape=(2, 2, 0), dtype=float64)]
"""
return _mx_nd_np.dsplit(ary, indices_or_sections)


@set_module('mxnet.numpy')
Expand Down
3 changes: 3 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'roll',
'split',
'array_split',
'hsplit',
'vsplit',
'dsplit',
'squeeze',
'stack',
'std',
Expand Down
46 changes: 43 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'matmul',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', 'insert',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount']
Expand Down Expand Up @@ -3704,6 +3704,46 @@ def vsplit(ary, indices_or_sections):
return split(ary, indices_or_sections, 0)


# pylint: disable=redefined-outer-name
@set_module('mxnet.symbol.numpy')
def dsplit(ary, indices_or_sections):
"""
Split array into multiple sub-arrays along the 3rd axis (depth).
Please refer to the `split` documentation. `dsplit` is equivalent
to `split` with ``axis=2``, the array is always split along the third
axis provided the array dimension is greater than or equal to 3.
Parameters
----------
ary : _Symbol
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
along axis 2. If such a split is not possible, an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
along axis 2 the array is split. For example, ``[2, 3]`` would result in
- ary[:, :, :2]
- ary[:, :, 2:3]
- ary[:, :, 3:]
If an index exceeds the dimension of the array along axis 2, an error will be thrown.
"""
indices = []
sections = 0
if isinstance(indices_or_sections, int):
sections = indices_or_sections
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
ret = _npi.dsplit(ary, indices, 2, False, sections)
return ret
# pylint: enable=redefined-outer-name


@set_module('mxnet.symbol.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Expand Down
32 changes: 32 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1363,6 +1363,8 @@ inline bool HSplitOpShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_attrs->size(), 1U);
mxnet::TShape dshape = in_attrs->at(split_enum::kData);
CHECK_GE(dshape.ndim(), 1U)
<< "ValueError: hsplit only works on arrays of 1 or more dimensions";
if (!mxnet::ndim_is_known(dshape)) return false;
int real_axis;
if (dshape.ndim() > 1) {
Expand Down Expand Up @@ -1403,6 +1405,36 @@ NNVM_REGISTER_OP(_npi_hsplit_backward)
})
.set_attr<FCompute>("FCompute<cpu>", HSplitOpBackward<cpu>);

inline bool DSplitOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
using namespace mshadow;
CHECK_EQ(in_attrs->size(), 1U);
mxnet::TShape dshape = in_attrs->at(split_enum::kData);
if (!mxnet::ndim_is_known(dshape)) return false;
CHECK(dshape.ndim() >= 3) << "ValueError: dsplit only works on arrays of 3 or more dimensions";
return SplitOpShapeImpl(attrs, in_attrs, out_attrs, 2);
}

NNVM_REGISTER_OP(_npi_dsplit)
.set_attr_parser(ParamParser<SplitParam>)
.set_num_inputs(1)
.set_num_outputs(SplitNumOutputs)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", DSplitOpShape)
.set_attr<nnvm::FInferType>("FInferType", SplitOpType)
.set_attr<FCompute>("FCompute<cpu>", SplitOpForward<cpu>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_split_v2_backward"})
.add_argument("data", "NDArray-or-Symbol", "The input")
.add_arguments(SplitParam::__FIELDS__());

NNVM_REGISTER_OP(_np_diag)
.set_attr_parser(ParamParser<NumpyDiagParam>)
.set_num_inputs(1)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ NNVM_REGISTER_OP(_npi_hsplit)
NNVM_REGISTER_OP(_npi_hsplit_backward)
.set_attr<FCompute>("FCompute<gpu>", HSplitOpBackward<gpu>);

NNVM_REGISTER_OP(_npi_dsplit)
.set_attr<FCompute>("FCompute<gpu>", SplitOpForward<gpu>);

NNVM_REGISTER_OP(_npx_reshape)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);

Expand Down
Loading

0 comments on commit 96b9439

Please sign in to comment.