Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add ffi wrapper for np.histogram
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Mar 27, 2020
1 parent 809b504 commit f687d2d
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 22 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/_ffi/node_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def convert_to_node(value):
"""
if isinstance(value, Integral):
return _api_internal._Integer(value)
elif isinstance(value, float):
return _api_internal._Float(value)
elif isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value]
return _api_internal._ADT(*value)
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,13 +1734,13 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
if isinstance(bins, numeric_types):
if range is None:
raise NotImplementedError("automatic range is not supported yet...")
return _npi.histogram(a, bin_cnt=bins, range=range)
return _api_internal.histogram(a, None, bins, range)
if isinstance(bins, (list, tuple)):
raise NotImplementedError("array_like bins is not supported yet...")
if isinstance(bins, str):
raise NotImplementedError("string bins is not supported yet...")
if isinstance(bins, NDArray):
return _npi.histogram(a, bins=bins)
return _npi.histogram(a, bins, None, None)
raise ValueError("np.histogram fails with", locals())


Expand Down
10 changes: 10 additions & 0 deletions src/api/_api_internal/_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ MXNET_REGISTER_GLOBAL("_Integer")
}
});

MXNET_REGISTER_GLOBAL("_Float")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
if (args[0].type_code() == kDLFloat) {
*ret = Integer(args[0].operator double());
} else {
LOG(FATAL) << "only accept float";
}
});

MXNET_REGISTER_GLOBAL("_ADT")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
Expand Down
4 changes: 2 additions & 2 deletions src/api/operator/numpy/np_bincount_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down
4 changes: 2 additions & 2 deletions src/api/operator/numpy/np_cumsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down
81 changes: 81 additions & 0 deletions src/api/operator/numpy/np_histogram_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_histogram_op.cc
* \brief Implementation of the API of functions in src/operator/tensor/histogram.cc
*/

#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/histogram-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.histogram")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
nnvm::NodeAttrs attrs;
const nnvm::Op* op = Op::Get("_npi_histogram");
op::HistogramParam param;
// parse bin_cnt
if (args[2].type_code() == kNull) {
param.bin_cnt = dmlc::nullopt;
} else {
param.bin_cnt = args[2].operator int();
}

// parse range
if (args[3].type_code() == kNull) {
param.range = dmlc::nullopt;
} else {
param.range = Tuple<double>(args[3].operator ObjectRef());
}

attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::HistogramParam>(&attrs);

std::vector<NDArray*> inputs_vec;
int num_inputs = 0;

if (param.bin_cnt.has_value()) {
CHECK_EQ(args[1].type_code(), kNull)
<< "bins should be None when bin_cnt is provided";
inputs_vec.push_back((args[0].operator NDArray*()));
num_inputs = 1;
} else {
CHECK_NE(args[1].type_code(), kNull)
<< "bins should not be None when bin_cnt is not provided";
// inputs
inputs_vec.push_back((args[0].operator NDArray*()));
inputs_vec.push_back((args[1].operator NDArray*()));
num_inputs = 2;
}

// outputs
NDArray** out = nullptr;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs_vec.data(), &num_outputs, out);
*ret = ADT(0, {NDArrayHandle(ndoutputs[0]),
NDArrayHandle(ndoutputs[1])});
});

} // namespace mxnet
40 changes: 24 additions & 16 deletions src/operator/tensor/histogram-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,30 @@ namespace mxnet {
namespace op {

struct HistogramParam : public dmlc::Parameter<HistogramParam> {
dmlc::optional<int> bin_cnt;
dmlc::optional<mxnet::Tuple<double>> range;
DMLC_DECLARE_PARAMETER(HistogramParam) {
DMLC_DECLARE_FIELD(bin_cnt)
.set_default(dmlc::optional<int>())
.describe("Number of bins for uniform case");
DMLC_DECLARE_FIELD(range)
.set_default(dmlc::optional<mxnet::Tuple<double>>())
.describe("The lower and upper range of the bins. if not provided, "
"range is simply (a.min(), a.max()). values outside the "
"range are ignored. the first element of the range must be "
"less than or equal to the second. range affects the automatic "
"bin computation as well. while bin width is computed to be "
"optimal based on the actual data within range, the bin count "
"will fill the entire range including portions containing no data.");
}
dmlc::optional<int> bin_cnt;
dmlc::optional<mxnet::Tuple<double>> range;
DMLC_DECLARE_PARAMETER(HistogramParam) {
DMLC_DECLARE_FIELD(bin_cnt)
.set_default(dmlc::optional<int>())
.describe("Number of bins for uniform case");
DMLC_DECLARE_FIELD(range)
.set_default(dmlc::optional<mxnet::Tuple<double>>())
.describe("The lower and upper range of the bins. if not provided, "
"range is simply (a.min(), a.max()). values outside the "
"range are ignored. the first element of the range must be "
"less than or equal to the second. range affects the automatic "
"bin computation as well. while bin width is computed to be "
"optimal based on the actual data within range, the bin count "
"will fill the entire range including portions containing no data.");
}

void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream bin_cnt_s, range_s;
bin_cnt_s << bin_cnt;
range_s << range;
(*dict)["bin_cnt"] = bin_cnt_s.str();
(*dict)["range"] = range_s.str();
}
};

struct FillBinBoundsKernel {
Expand Down

0 comments on commit f687d2d

Please sign in to comment.