Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] Add op fmax, fmin, fmod (#17567)
Browse files Browse the repository at this point in the history
* [Numpy] Add op fmax, fmin

* Fix sanity

* Fix bug of gpu part, add scalar compute

* Finish cpu,gpu test of fmax, fmin

* * Fix 3rd Party

* * Prune redundent alias

* * Add op fmod (rfmod still need check)

* * Fix rfmod function

* * Impl FFI

* * Fix windows oversize by adding files

Co-authored-by: Han <[email protected]>
  • Loading branch information
hanke580 and Han authored Mar 23, 2020
1 parent f01dc80 commit 2f358fd
Show file tree
Hide file tree
Showing 12 changed files with 559 additions and 6 deletions.
3 changes: 3 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def prepare_workloads():
OpArgMngr.add_workload("ones_like", pool['2x2'])
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1'])
OpArgMngr.add_workload("fmax", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("fmin", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("fmod", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1])
OpArgMngr.add_workload("roll", pool["2x2"], 1, axis=0)
OpArgMngr.add_workload("rot90", pool["2x2"], 2)
Expand Down
76 changes: 74 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@


__all__ = ['shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'invert', 'delete',
'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not',
'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod',
'power', 'bitwise_not',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'insert', 'fabs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'matmul',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'all', 'any', 'sort',
'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'around', 'round', 'round_', 'flatnonzero',
'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin', 'around', 'round', 'round_', 'flatnonzero',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
Expand Down Expand Up @@ -1169,6 +1170,35 @@ def mod(x1, x2, out=None, **kwargs):
return _api_internal.mod(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def fmod(x1, x2, out=None, **kwargs):
"""
Return element-wise remainder of division.
Parameters
----------
x1 : ndarray or scalar
Dividend array.
x2 : ndarray or scalar
Divisor array.
out : ndarray
A location into which the result is stored. If provided, it must have a shape
that the inputs broadcast to. If not provided or None, a freshly-allocated array
is returned.
Returns
-------
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.fmod(x1, x2, out=out)
return _api_internal.fmod(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
def delete(arr, obj, axis=None):
"""
Expand Down Expand Up @@ -4366,6 +4396,27 @@ def maximum(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def fmax(x1, x2, out=None, **kwargs):
"""
Returns element-wise maximum of the input arrays with broadcasting. (Ignores NaNs)
Parameters
----------
x1, x2 : scalar or mxnet.numpy.ndarray
The arrays holding the elements to be compared. They must have the same shape,
or shapes that can be broadcast to a single shape.
Returns
-------
out : mxnet.numpy.ndarray or scalar
The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.fmax(x1, x2, out=out)
return _api_internal.fmax(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def minimum(x1, x2, out=None, **kwargs):
Expand All @@ -4385,6 +4436,27 @@ def minimum(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def fmin(x1, x2, out=None, **kwargs):
"""
Returns element-wise minimum of the input arrays with broadcasting. (Ignores NaNs)
Parameters
----------
x1, x2 : scalar or mxnet.numpy.ndarray
The arrays holding the elements to be compared. They must have the same shape,
or shapes that can be broadcast to a single shape.
Returns
-------
out : mxnet.numpy.ndarray or scalar
The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.fmin(x1, x2, out=out)
return _api_internal.fmin(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
def swapaxes(a, axis1, axis2):
"""Interchange two axes of an array.
Expand Down
96 changes: 94 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,17 @@

__all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape',
'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'all', 'any', 'broadcast_to',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not', 'delete',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod', 'power', 'bitwise_not',
'delete',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'invert',
'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'array_split', 'split', 'hsplit', 'vsplit', 'dsplit', 'flatnonzero',
'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert',
'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert',
'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman',
'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
Expand Down Expand Up @@ -3147,6 +3149,38 @@ def mod(x1, x2, out=None, **kwargs):
return _mx_nd_np.mod(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def fmod(x1, x2, out=None, **kwargs):
"""
Return element-wise remainder of division.
Parameters
----------
x1 : ndarray or scalar
Dividend array.
x2 : ndarray or scalar
Divisor array.
out : ndarray
A location into which the result is stored. If provided, it must have a shape
that the inputs broadcast to. If not provided or None, a freshly-allocated array
is returned.
Returns
-------
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
Examples
--------
>>> np.fmod(np.arange(7), 5)
array([0., 1., 2., 3., 4., 0., 1.])
"""
return _mx_nd_np.fmod(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def matmul(a, b, out=None, **kwargs):
Expand Down Expand Up @@ -6185,6 +6219,35 @@ def maximum(x1, x2, out=None, **kwargs):
return _mx_nd_np.maximum(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def fmax(x1, x2, out=None, **kwargs):
"""
Returns element-wise maximum of the input arrays with broadcasting. (Ignores NaNs)
Parameters
----------
x1, x2 : scalar or mxnet.numpy.ndarray
The arrays holding the elements to be compared. They must have the same shape,
or shapes that can be broadcast to a single shape.
Returns
-------
out : mxnet.numpy.ndarray or scalar
The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Examples
--------
>>> np.fmax(np.array([2, 3, 4]), np.array([1, 5, 2]))
array([2., 5., 4.])
>>> np.fmax(np.eye(2), np.array([0.5, 2])) # broadcasting
array([[1. , 2. ],
[0.5, 2. ]])
"""
return _mx_nd_np.fmax(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def minimum(x1, x2, out=None, **kwargs):
Expand Down Expand Up @@ -6214,6 +6277,35 @@ def minimum(x1, x2, out=None, **kwargs):
return _mx_nd_np.minimum(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def fmin(x1, x2, out=None, **kwargs):
"""
Returns element-wise minimum of the input arrays with broadcasting. (Ignores NaNs)
Parameters
----------
x1, x2 : scalar or mxnet.numpy.ndarray
The arrays holding the elements to be compared. They must have the same shape,
or shapes that can be broadcast to a single shape.
Returns
-------
out : mxnet.numpy.ndarray or scalar
The fmin of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Examples
--------
>>> np.fmin(np.array([2, 3, 4]), np.array([1, 5, 2]))
array([1., 3., 2.])
>>> np.fmin(np.eye(2), np.array([0.5, 2])) # broadcasting
array([[0.5, 0. ],
[0. , 1. ]])
"""
return _mx_nd_np.fmin(x1, x2, out=out)


@set_module('mxnet.numpy')
def swapaxes(a, axis1, axis2):
"""Interchange two axes of an array.
Expand Down
3 changes: 3 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def _register_array_function():
'negative',
'power',
'mod',
'fmod',
'matmul',
'absolute',
'rint',
Expand Down Expand Up @@ -277,7 +278,9 @@ def _register_array_function():
'arccosh',
'arctanh',
'maximum',
'fmax',
'minimum',
'fmin',
'ceil',
'trunc',
'floor',
Expand Down
24 changes: 22 additions & 2 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@
from builtins import slice as py_slice

__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'bitwise_not', 'invert',
'delete', 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
'delete', 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod',
'power', 'arctan2',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'matmul',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', 'insert',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'any', 'all', 'around', 'round', 'round_', 'flatnonzero',
'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin', 'any', 'all', 'around', 'round', 'round_',
'flatnonzero',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
Expand Down Expand Up @@ -1620,6 +1622,12 @@ def mod(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.mod, _np.mod, _npi.mod_scalar, _npi.rmod_scalar, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def fmod(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, _npi.rfmod_scalar, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def remainder(x1, x2, out=None, **kwargs):
Expand Down Expand Up @@ -4127,12 +4135,24 @@ def maximum(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def fmax(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.fmax, _np.fmax, _npi.fmax_scalar, None, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def minimum(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def fmin(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.fmin, _np.fmin, _npi.fmin_scalar, None, out)


@set_module('mxnet.symbol.numpy')
def all(a, axis=None, out=None, keepdims=False):
"""
Expand Down
56 changes: 56 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_elemwise_broadcast_op_extended_sec.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../ufunc_helper.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.fmax")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_fmax");
const nnvm::Op* op_scalar = Op::Get("_npi_fmax_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.fmin")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_fmin");
const nnvm::Op* op_scalar = Op::Get("_npi_fmin_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.fmod")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_fmod");
const nnvm::Op* op_scalar = Op::Get("_npi_fmod_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rfmod_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

} // namespace mxnet
Loading

0 comments on commit 2f358fd

Please sign in to comment.