Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] OP_interp #17793

Merged
merged 4 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp',
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum', 'diag', 'diagonal']

Expand Down Expand Up @@ -7081,6 +7081,86 @@ def may_share_memory(a, b, max_work=None):
return _api_internal.share_memory(a, b).item()


@set_module('mxnet.ndarray.numpy')
def interp(x, xp, fp, left=None, right=None, period=None): # pylint: disable=too-many-arguments
"""
One-dimensional linear interpolation.
Returns the one-dimensional piecewise linear interpolant to a function
with given values at discrete data-points.

Parameters
----------
x : ndarray
The x-coordinates of the interpolated values.
xp : 1-D array of floats
The x-coordinates of the data points, must be increasing if argument
`period` is not specified. Otherwise, `xp` is internally sorted after
normalizing the periodic boundaries with ``xp = xp % period``.
fp : 1-D array of floats
The y-coordinates of the data points, same length as `xp`.
left : optional float corresponding to fp
Value to return for `x < xp[0]`, default is `fp[0]`.
right : optional float corresponding to fp
Value to return for `x > xp[-1]`, default is `fp[-1]`.
period : None or float, optional
A period for the x-coordinates. This parameter allows the proper
interpolation of angular x-coordinates. Parameters `left` and `right`
are ignored if `period` is specified.
.. versionadded:: 1.10.0

Returns
-------
y : float (corresponding to fp) or ndarray
The interpolated values, same shape as `x`.
Raises
------
ValueError
If `xp` and `fp` have different length
If `xp` or `fp` are not 1-D sequences
If `period == 0`

Notes
-----
Does not check that the x-coordinate sequence `xp` is increasing.
If `xp` is not increasing, the results are nonsense.
A simple check for increasing is::
np.all(np.diff(xp) > 0)

Examples
--------
>>> xp = [1, 2, 3]
>>> fp = [3, 2, 0]
>>> np.interp(2.5, xp, fp)
1.0
>>> np.interp([0, 1, 1.5, 2.72, 3.14], xp, fp)
array([ 3. , 3. , 2.5 , 0.56, 0. ])
>>> UNDEF = -99.0
>>> np.interp(3.14, xp, fp, right=UNDEF)
-99.0
Plot an interpolant to the sine function:
>>> x = np.linspace(0, 2*np.pi, 10)
>>> y = np.sin(x)
>>> xvals = np.linspace(0, 2*np.pi, 50)
>>> yinterp = np.interp(xvals, x, y)
>>> import matplotlib.pyplot as plt
>>> plt.plot(x, y, 'o')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.plot(xvals, yinterp, '-x')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.show()
Interpolation with periodic x-coordinates:
>>> x = [-180, -170, -185, 185, -10, -5, 0, 365]
>>> xp = [190, -190, 350, -350]
>>> fp = [5, 10, 3, 4]
>>> np.interp(x, xp, fp, period=360)
array([7.5, 5., 8.75, 6.25, 3., 3.25, 3.5, 3.75])
"""
if not isinstance(x, numeric_types):
x = x.astype(float)
return _api_internal.interp(xp.astype(float), fp.astype(float), x, left,
right, period)


@set_module('mxnet.ndarray.numpy')
def diff(a, n=1, axis=-1, prepend=None, append=None): # pylint: disable=redefined-outer-name
r"""
Expand Down
2 changes: 0 additions & 2 deletions python/mxnet/numpy/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
'histogramdd',
'i0',
'in1d',
'interp',
'intersect1d',
'isclose',
'isin',
Expand Down Expand Up @@ -138,7 +137,6 @@
histogramdd = onp.histogramdd
i0 = onp.i0
in1d = onp.in1d
interp = onp.interp
intersect1d = onp.intersect1d
isclose = onp.isclose
isin = onp.isin
Expand Down
79 changes: 78 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron',
'equal', 'not_equal',
'equal', 'not_equal', 'interp',
'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
Expand Down Expand Up @@ -9275,6 +9275,83 @@ def resize(a, new_shape):
return _mx_nd_np.resize(a, new_shape)


@set_module('mxnet.numpy')
def interp(x, xp, fp, left=None, right=None, period=None): # pylint: disable=too-many-arguments
"""
One-dimensional linear interpolation.
Returns the one-dimensional piecewise linear interpolant to a function
with given values at discrete data-points.

Parameters
----------
x : ndarray
The x-coordinates of the interpolated values.
xp : 1-D array of floats
The x-coordinates of the data points, must be increasing if argument
`period` is not specified. Otherwise, `xp` is internally sorted after
normalizing the periodic boundaries with ``xp = xp % period``.
fp : 1-D array of floats
The y-coordinates of the data points, same length as `xp`.
left : optional float corresponding to fp
Value to return for `x < xp[0]`, default is `fp[0]`.
right : optional float corresponding to fp
Value to return for `x > xp[-1]`, default is `fp[-1]`.
period : None or float, optional
A period for the x-coordinates. This parameter allows the proper
interpolation of angular x-coordinates. Parameters `left` and `right`
are ignored if `period` is specified.
.. versionadded:: 1.10.0

Returns
-------
y : float (corresponding to fp) or ndarray
The interpolated values, same shape as `x`.
Raises
------
ValueError
If `xp` and `fp` have different length
If `xp` or `fp` are not 1-D sequences
If `period == 0`

Notes
-----
Does not check that the x-coordinate sequence `xp` is increasing.
If `xp` is not increasing, the results are nonsense.
A simple check for increasing is::
np.all(np.diff(xp) > 0)

Examples
--------
>>> xp = [1, 2, 3]
>>> fp = [3, 2, 0]
>>> np.interp(2.5, xp, fp)
1.0
>>> np.interp([0, 1, 1.5, 2.72, 3.14], xp, fp)
array([ 3. , 3. , 2.5 , 0.56, 0. ])
>>> UNDEF = -99.0
>>> np.interp(3.14, xp, fp, right=UNDEF)
-99.0
Plot an interpolant to the sine function:
>>> x = np.linspace(0, 2*np.pi, 10)
>>> y = np.sin(x)
>>> xvals = np.linspace(0, 2*np.pi, 50)
>>> yinterp = np.interp(xvals, x, y)
>>> import matplotlib.pyplot as plt
>>> plt.plot(x, y, 'o')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.plot(xvals, yinterp, '-x')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.show()
Interpolation with periodic x-coordinates:
>>> x = [-180, -170, -185, 185, -10, -5, 0, 365]
>>> xp = [190, -190, 350, -350]
>>> fp = [5, 10, 3, 4]
>>> np.interp(x, xp, fp, period=360)
array([7.5, 5., 8.75, 6.25, 3., 3.25, 3.5, 3.75])
"""
return _mx_nd_np.interp(x, xp, fp, left=left, right=right, period=period)


# pylint: disable=redefined-outer-name
@set_module('mxnet.numpy')
def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None): # pylint: disable=too-many-arguments
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'fliplr',
'inner',
'insert',
'interp',
'max',
'amax',
'mean',
Expand Down
55 changes: 54 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
'flatnonzero',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'interp',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
Expand Down Expand Up @@ -6349,6 +6349,59 @@ def ediff1d(ary, to_end=None, to_begin=None):
to_begin_scalar=to_begin, to_end_scalar=to_end)


@set_module('mxnet.symbol.numpy')
def interp(x, xp, fp, left=None, right=None, period=None): # pylint: disable=too-many-arguments
"""
One-dimensional linear interpolation.
Returns the one-dimensional piecewise linear interpolant to a function
with given values at discrete data-points.

Parameters
----------
x : _Symbol
The x-coordinates of the interpolated values.
xp : _Symbol
The x-coordinates of the data points, must be increasing if argument
`period` is not specified. Otherwise, `xp` is internally sorted after
normalizing the periodic boundaries with ``xp = xp % period``.
fp : _Symbol
The y-coordinates of the data points, same length as `xp`.
left : optional float corresponding to fp
Value to return for `x < xp[0]`, default is `fp[0]`.
right : optional float corresponding to fp
Value to return for `x > xp[-1]`, default is `fp[-1]`.
period : None or float, optional
A period for the x-coordinates. This parameter allows the proper
interpolation of angular x-coordinates. Parameters `left` and `right`
are ignored if `period` is specified.
.. versionadded:: 1.10.0

Returns
-------
y : _Symbol
The interpolated values, same shape as `x`.

Raises
------
ValueError
If `xp` and `fp` have different length
If `xp` or `fp` are not 1-D sequences
If `period == 0`

Notes
-----
Does not check that the x-coordinate sequence `xp` is increasing.
If `xp` is not increasing, the results are nonsense.
A simple check for increasing is::
np.all(np.diff(xp) > 0)
"""
if isinstance(x, numeric_types):
return _npi.interp(xp.astype(float), fp.astype(float), left=left,
right=right, period=period, x_scalar=x, x_is_scalar=True)
return _npi.interp(xp.astype(float), fp.astype(float), x.astype(float), left=left,
right=right, period=period, x_scalar=0.0, x_is_scalar=False)


@set_module('mxnet.symbol.numpy')
def resize(a, new_shape):
"""
Expand Down
78 changes: 78 additions & 0 deletions src/api/operator/numpy/np_interp_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_interp_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_interp_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/numpy/np_interp_op-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.interp")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_interp");
nnvm::NodeAttrs attrs;
op::NumpyInterpParam param;
if (args[3].type_code() == kNull) {
param.left = dmlc::nullopt;
} else {
param.left = args[3].operator double();
}
if (args[4].type_code() == kNull) {
param.right = dmlc::nullopt;
} else {
param.right = args[4].operator double();
}
if (args[5].type_code() == kNull) {
param.period = dmlc::nullopt;
} else {
param.period = args[5].operator double();
}
if (args[2].type_code() == kDLInt || args[2].type_code() == kDLFloat) {
param.x_scalar = args[2].operator double();
param.x_is_scalar = true;
attrs.op = op;
attrs.parsed = std::move(param);
SetAttrDict<op::NumpyInterpParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
int num_inputs = 2;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
} else {
param.x_scalar = 0.0;
param.x_is_scalar = false;
attrs.op = op;
attrs.parsed = std::move(param);
SetAttrDict<op::NumpyInterpParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*(),
args[2].operator mxnet::NDArray*()};
int num_inputs = 3;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

} // namespace mxnet
Loading