Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ffi wrappers for polyval, ediff1d, nan_to_num (#17832)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alicia1529 authored Mar 17, 2020
1 parent 796fa50 commit a7ecb35
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 23 deletions.
3 changes: 3 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def generate_workloads():
def prepare_workloads():
pool = generate_workloads()
OpArgMngr.add_workload("zeros", (2, 2))
OpArgMngr.add_workload("polyval", dnp.arange(10), pool['2x2'])
OpArgMngr.add_workload("ediff1d", pool['2x2'], pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("nan_to_num", pool['2x2'])
OpArgMngr.add_workload("tensordot", pool['2x2'], pool['2x2'], ((1, 0), (0, 1)))
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
Expand Down
29 changes: 6 additions & 23 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6983,24 +6983,7 @@ def ediff1d(ary, to_end=None, to_begin=None):
>>> np.ediff1d(x, to_begin=y)
array([ 1., 2., 4., 1., 6., 24., 1., 2., 3., -7.])
"""
from ...numpy import ndarray as np_ndarray
input_type = (isinstance(to_begin, np_ndarray), isinstance(to_end, np_ndarray))
# case 1: when both `to_begin` and `to_end` are arrays
if input_type == (True, True):
return _npi.ediff1d(ary, to_begin, to_end, to_begin_arr_given=True, to_end_arr_given=True,
to_begin_scalar=None, to_end_scalar=None)
# case 2: only `to_end` is array but `to_begin` is scalar/None
elif input_type == (False, True):
return _npi.ediff1d(ary, to_end, to_begin_arr_given=False, to_end_arr_given=True,
to_begin_scalar=to_begin, to_end_scalar=None)
# case 3: only `to_begin` is array but `to_end` is scalar/None
elif input_type == (True, False):
return _npi.ediff1d(ary, to_begin, to_begin_arr_given=True, to_end_arr_given=False,
to_begin_scalar=None, to_end_scalar=to_end)
# case 4: both `to_begin` and `to_end` are scalar/None
else:
return _npi.ediff1d(ary, to_begin_arr_given=False, to_end_arr_given=False,
to_begin_scalar=to_begin, to_end_scalar=to_end)
return _api_internal.ediff1d(ary, to_end, to_begin)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -7148,8 +7131,8 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
if x.dtype in ['int8', 'uint8', 'int32', 'int64']:
return x
if not copy:
return _npi.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf, out=x)
return _npi.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf, out=None)
return _api_internal.nan_to_num(x, copy, nan, posinf, neginf, x)
return _api_internal.nan_to_num(x, copy, nan, posinf, neginf, None)
else:
raise TypeError('type {} not supported'.format(str(type(x))))

Expand Down Expand Up @@ -7529,10 +7512,10 @@ def polyval(p, x):
array([76., 49.])
"""
from ...numpy import ndarray
if isinstance(p, ndarray) and isinstance(x, ndarray):
return _npi.polyval(p, x)
elif not isinstance(p, ndarray) and not isinstance(x, ndarray):
if isinstance(p, numeric_types) and isinstance(x, numeric_types):
return _np.polyval(p, x)
elif isinstance(p, ndarray) and isinstance(x, ndarray):
return _api_internal.polyval(p, x)
else:
raise TypeError('type not supported')

Expand Down
75 changes: 75 additions & 0 deletions src/api/operator/numpy/np_ediff1d_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_ediff1d_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_ediff1d_op.cc
*/
#include <mxnet/api_registry.h>
#include "../utils.h"
#include "../../../operator/numpy/np_ediff1d_op-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.ediff1d")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_ediff1d");
nnvm::NodeAttrs attrs;
op::EDiff1DParam param;
int num_inputs = 1;
NDArray* inputs[3];
inputs[0] = args[0].operator mxnet::NDArray*();
// the order of `to_end` and `to_begin` array in the backend is different from the front-end
if (args[2].type_code() == kDLFloat || args[2].type_code() == kDLInt) {
param.to_begin_scalar = args[2].operator double();
param.to_begin_arr_given = false;
} else if (args[2].type_code() == kNull) {
param.to_begin_scalar = dmlc::nullopt;
param.to_begin_arr_given = false;
} else {
param.to_begin_scalar = dmlc::nullopt;
param.to_begin_arr_given = true;
inputs[num_inputs] = args[2].operator mxnet::NDArray*();
num_inputs++;
}

if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) {
param.to_end_scalar = args[1].operator double();
param.to_end_arr_given = false;
} else if (args[1].type_code() == kNull) {
param.to_end_scalar = dmlc::nullopt;
param.to_end_arr_given = false;
} else {
param.to_end_scalar = dmlc::nullopt;
param.to_end_arr_given = true;
inputs[num_inputs] = args[1].operator mxnet::NDArray*();
num_inputs++;
}

attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::EDiff1DParam>(&attrs);

int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
72 changes: 72 additions & 0 deletions src/api/operator/numpy/np_nan_to_num_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_nan_to_num_op.cc
* \brief Implementation of the API of nan_to_num function in
* src/operator/tensor/np_elemwise_unary_op_basic.cc
*/
#include <mxnet/api_registry.h>
#include "../utils.h"
#include "../../../operator/tensor/elemwise_unary_op.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.nan_to_num")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_nan_to_num");
nnvm::NodeAttrs attrs;

op::NumpyNanToNumParam param;
int num_inputs = 1;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};

param.copy = args[1].operator bool();
param.nan = args[2].operator double();

if (args[3].type_code() == kNull) {
param.posinf = dmlc::nullopt;
} else {
param.posinf = args[3].operator double();
}

if (args[4].type_code() == kNull) {
param.neginf = dmlc::nullopt;
} else {
param.neginf = args[4].operator double();
}

attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::NumpyNanToNumParam>(&attrs);

NDArray* out = args[5].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
// set the number of outputs provided by the `out` arugment
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(5);
} else {
*ret = ndoutputs[0];
}
});

} // namespace mxnet
44 changes: 44 additions & 0 deletions src/api/operator/numpy/np_polynomial_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_polynomial_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_polynomial_op.cc
*/
#include <mxnet/api_registry.h>
#include "../utils.h"
#include "../../../operator/numpy/np_polynomial_op-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.polyval")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_polyval");
nnvm::NodeAttrs attrs;
attrs.op = op;

NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
int num_inputs = 2;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
13 changes: 13 additions & 0 deletions src/operator/numpy/np_ediff1d_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <mxnet/base.h>
#include <mxnet/operator_util.h>
#include <vector>
#include <string>
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
Expand All @@ -53,6 +54,18 @@ struct EDiff1DParam : public dmlc::Parameter<EDiff1DParam> {
.set_default(dmlc::optional<double>())
.describe("If the `to_end`is a scalar, the value of this parameter.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream to_end_arr_given_s, to_begin_arr_given_s,
to_end_scalar_s, to_begin_scalar_s;
to_end_arr_given_s << to_end_arr_given;
to_begin_arr_given_s << to_begin_arr_given;
to_end_scalar_s << to_end_scalar;
to_begin_scalar_s << to_begin_scalar;
(*dict)["to_end_arr_given"] = to_end_arr_given_s.str();
(*dict)["to_begin_arr_given"] = to_begin_arr_given_s.str();
(*dict)["to_end_scalar"] = to_end_scalar_s.str();
(*dict)["to_begin_scalar"] = to_begin_scalar_s.str();
}
};

template<typename DType>
Expand Down
12 changes: 12 additions & 0 deletions src/operator/tensor/elemwise_unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <mxnet/operator_util.h>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include <climits>
Expand Down Expand Up @@ -704,6 +705,17 @@ struct NumpyNanToNumParam : public dmlc::Parameter<NumpyNanToNumParam> {
"If no value is passed then negative infinity values"
"will be replaced with a very small (or negative) number.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream copy_s, nan_s, posinf_s, neginf_s;
copy_s << copy;
nan_s << nan;
posinf_s << posinf;
neginf_s << neginf;
(*dict)["copy"] = copy_s.str();
(*dict)["nan"] = nan_s.str();
(*dict)["posinf"] = posinf_s.str();
(*dict)["neginf"] = neginf_s.str();
}
};

template<int req>
Expand Down

0 comments on commit a7ecb35

Please sign in to comment.