Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
interp
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommliu committed Mar 9, 2020
1 parent 752461d commit 992d0c0
Show file tree
Hide file tree
Showing 10 changed files with 797 additions and 3 deletions.
82 changes: 81 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp',
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']

Expand Down Expand Up @@ -6803,6 +6803,86 @@ def may_share_memory(a, b, max_work=None):
return _npi.share_memory(a, b).item()


@set_module('mxnet.ndarray.numpy')
def interp(x, xp, fp, left=None, right=None, period=None):
"""
One-dimensional linear interpolation.
Returns the one-dimensional piecewise linear interpolant to a function
with given values at discrete data-points.
Parameters
----------
x : ndarray
The x-coordinates of the interpolated values.
xp : 1-D array of floats
The x-coordinates of the data points, must be increasing if argument
`period` is not specified. Otherwise, `xp` is internally sorted after
normalizing the periodic boundaries with ``xp = xp % period``.
fp : 1-D array of floats
The y-coordinates of the data points, same length as `xp`.
left : optional float corresponding to fp
Value to return for `x < xp[0]`, default is `fp[0]`.
right : optional float corresponding to fp
Value to return for `x > xp[-1]`, default is `fp[-1]`.
period : None or float, optional
A period for the x-coordinates. This parameter allows the proper
interpolation of angular x-coordinates. Parameters `left` and `right`
are ignored if `period` is specified.
.. versionadded:: 1.10.0
Returns
-------
y : float (corresponding to fp) or ndarray
The interpolated values, same shape as `x`.
Raises
------
ValueError
If `xp` and `fp` have different length
If `xp` or `fp` are not 1-D sequences
If `period == 0`
Notes
-----
Does not check that the x-coordinate sequence `xp` is increasing.
If `xp` is not increasing, the results are nonsense.
A simple check for increasing is::
np.all(np.diff(xp) > 0)
Examples
--------
>>> xp = [1, 2, 3]
>>> fp = [3, 2, 0]
>>> np.interp(2.5, xp, fp)
1.0
>>> np.interp([0, 1, 1.5, 2.72, 3.14], xp, fp)
array([ 3. , 3. , 2.5 , 0.56, 0. ])
>>> UNDEF = -99.0
>>> np.interp(3.14, xp, fp, right=UNDEF)
-99.0
Plot an interpolant to the sine function:
>>> x = np.linspace(0, 2*np.pi, 10)
>>> y = np.sin(x)
>>> xvals = np.linspace(0, 2*np.pi, 50)
>>> yinterp = np.interp(xvals, x, y)
>>> import matplotlib.pyplot as plt
>>> plt.plot(x, y, 'o')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.plot(xvals, yinterp, '-x')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.show()
Interpolation with periodic x-coordinates:
>>> x = [-180, -170, -185, 185, -10, -5, 0, 365]
>>> xp = [190, -190, 350, -350]
>>> fp = [5, 10, 3, 4]
>>> np.interp(x, xp, fp, period=360)
array([7.5, 5., 8.75, 6.25, 3., 3.25, 3.5, 3.75])
"""
if not isinstance(x, numeric_types):
x = x.astype(float)
return _api_internal.interp(xp.astype(float), fp.astype(float), x, left,
right, period)


@set_module('mxnet.ndarray.numpy')
def diff(a, n=1, axis=-1, prepend=None, append=None): # pylint: disable=redefined-outer-name
r"""
Expand Down
78 changes: 77 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero',
'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero', 'interp',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'pad', 'cumsum']
Expand Down Expand Up @@ -8953,6 +8953,82 @@ def resize(a, new_shape):
"""
return _mx_nd_np.resize(a, new_shape)

@set_module('mxnet.numpy')
def interp(x, xp, fp, left=None, right=None, period=None):
"""
One-dimensional linear interpolation.
Returns the one-dimensional piecewise linear interpolant to a function
with given values at discrete data-points.
Parameters
----------
x : ndarray
The x-coordinates of the interpolated values.
xp : 1-D array of floats
The x-coordinates of the data points, must be increasing if argument
`period` is not specified. Otherwise, `xp` is internally sorted after
normalizing the periodic boundaries with ``xp = xp % period``.
fp : 1-D array of floats
The y-coordinates of the data points, same length as `xp`.
left : optional float corresponding to fp
Value to return for `x < xp[0]`, default is `fp[0]`.
right : optional float corresponding to fp
Value to return for `x > xp[-1]`, default is `fp[-1]`.
period : None or float, optional
A period for the x-coordinates. This parameter allows the proper
interpolation of angular x-coordinates. Parameters `left` and `right`
are ignored if `period` is specified.
.. versionadded:: 1.10.0
Returns
-------
y : float (corresponding to fp) or ndarray
The interpolated values, same shape as `x`.
Raises
------
ValueError
If `xp` and `fp` have different length
If `xp` or `fp` are not 1-D sequences
If `period == 0`
Notes
-----
Does not check that the x-coordinate sequence `xp` is increasing.
If `xp` is not increasing, the results are nonsense.
A simple check for increasing is::
np.all(np.diff(xp) > 0)
Examples
--------
>>> xp = [1, 2, 3]
>>> fp = [3, 2, 0]
>>> np.interp(2.5, xp, fp)
1.0
>>> np.interp([0, 1, 1.5, 2.72, 3.14], xp, fp)
array([ 3. , 3. , 2.5 , 0.56, 0. ])
>>> UNDEF = -99.0
>>> np.interp(3.14, xp, fp, right=UNDEF)
-99.0
Plot an interpolant to the sine function:
>>> x = np.linspace(0, 2*np.pi, 10)
>>> y = np.sin(x)
>>> xvals = np.linspace(0, 2*np.pi, 50)
>>> yinterp = np.interp(xvals, x, y)
>>> import matplotlib.pyplot as plt
>>> plt.plot(x, y, 'o')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.plot(xvals, yinterp, '-x')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.show()
Interpolation with periodic x-coordinates:
>>> x = [-180, -170, -185, 185, -10, -5, 0, 365]
>>> xp = [190, -190, 350, -350]
>>> fp = [5, 10, 3, 4]
>>> np.interp(x, xp, fp, period=360)
array([7.5, 5., 8.75, 6.25, 3., 3.25, 3.5, 3.75])
"""
return _mx_nd_np.interp(x, xp, fp, left=left, right=right, period=period)


# pylint: disable=redefined-outer-name
@set_module('mxnet.numpy')
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'fliplr',
'inner',
'insert',
'interp',
'max',
'amax',
'mean',
Expand Down
55 changes: 54 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'interp',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']
Expand Down Expand Up @@ -6151,6 +6151,59 @@ def ediff1d(ary, to_end=None, to_begin=None):
to_begin_scalar=to_begin, to_end_scalar=to_end)


@set_module('mxnet.symbol.numpy')
def interp(x, xp, fp, left=None, right=None, period=None):
"""
One-dimensional linear interpolation.
Returns the one-dimensional piecewise linear interpolant to a function
with given values at discrete data-points.
Parameters
----------
x : _Symbol
The x-coordinates of the interpolated values.
xp : _Symbol
The x-coordinates of the data points, must be increasing if argument
`period` is not specified. Otherwise, `xp` is internally sorted after
normalizing the periodic boundaries with ``xp = xp % period``.
fp : _Symbol
The y-coordinates of the data points, same length as `xp`.
left : optional float corresponding to fp
Value to return for `x < xp[0]`, default is `fp[0]`.
right : optional float corresponding to fp
Value to return for `x > xp[-1]`, default is `fp[-1]`.
period : None or float, optional
A period for the x-coordinates. This parameter allows the proper
interpolation of angular x-coordinates. Parameters `left` and `right`
are ignored if `period` is specified.
.. versionadded:: 1.10.0
Returns
-------
y : _Symbol
The interpolated values, same shape as `x`.
Raises
------
ValueError
If `xp` and `fp` have different length
If `xp` or `fp` are not 1-D sequences
If `period == 0`
Notes
-----
Does not check that the x-coordinate sequence `xp` is increasing.
If `xp` is not increasing, the results are nonsense.
A simple check for increasing is::
np.all(np.diff(xp) > 0)
"""
if isinstance(x, numeric_types):
return _npi.interp(xp.astype(float), fp.astype(float), left=left,
right=right, period=period, x_scalar=x)
return _npi.interp(xp.astype(float), fp.astype(float),x.astype(float),
left=left, right=right, period=period, x_scalar=None)


@set_module('mxnet.symbol.numpy')
def resize(a, new_shape):
"""
Expand Down
74 changes: 74 additions & 0 deletions src/api/operator/numpy/np_interp_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_interp_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_interp_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/numpy/np_interp_op-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.interp")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_interp");
nnvm::NodeAttrs attrs;
op::NumpyInterpParam param;
if (args[3].type_code() == kNull) {
param.left = dmlc::nullopt;
} else {
param.left = args[3].operator double();
}
if (args[4].type_code() == kNull) {
param.right = dmlc::nullopt;
} else {
param.right = args[4].operator double();
}
if (args[5].type_code() == kNull) {
param.period = dmlc::nullopt;
} else {
param.period = args[5].operator double();
}
if (args[2].type_code() == kDLInt || args[2].type_code() == kDLFloat) {
param.x_scalar = args[2].operator double();
attrs.op = op;
attrs.parsed = std::move(param);
SetAttrDict<op::NumpyInterpParam>(&attrs);
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
} else {
param.x_scalar = dmlc::nullopt;
attrs.op = op;
attrs.parsed = std::move(param);
SetAttrDict<op::NumpyInterpParam>(&attrs);
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*(),
args[2].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, 3, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

} // namespace mxnet
Loading

0 comments on commit 992d0c0

Please sign in to comment.