Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add ffi wrapper for np.average
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Mar 27, 2020
1 parent e5440e7 commit 2607ace
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 4 deletions.
1 change: 1 addition & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def prepare_workloads():
OpArgMngr.add_workload("sum", pool['2x2'], axis=0, keepdims=True, out=pool['1x2'])
OpArgMngr.add_workload("std", pool['2x2'], axis=0, ddof=0, keepdims=True, out=pool['1x2'])
OpArgMngr.add_workload("var", pool['2x2'], axis=0, ddof=1, keepdims=True, out=pool['1x2'])
OpArgMngr.add_workload("average", pool['2x2'], weights=pool['2'], axis=1, returned=True)
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("linalg.svd", pool['3x3'])
OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1)
Expand Down
5 changes: 1 addition & 4 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4754,10 +4754,7 @@ def average(a, axis=None, weights=None, returned=False, out=None):
>>> np.average(data, axis=1, weights=weights)
array([0.75, 2.75, 4.75])
"""
if weights is None:
return _npi.average(a, axis=axis, weights=None, returned=returned, weighted=False, out=out)
else:
return _npi.average(a, axis=axis, weights=weights, returned=returned, out=out)
return _api_internal.average(a, weights, axis, returned, False if weights is None else True, out)


@set_module('mxnet.ndarray.numpy')
Expand Down
67 changes: 67 additions & 0 deletions src/api/operator/numpy/np_moments_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,71 @@ MXNET_REGISTER_API("_npi.var")
}
});

MXNET_REGISTER_API("_npi.average")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_average");
op::NumpyWeightedAverageParam param;
nnvm::NodeAttrs attrs;
attrs.op = op;

// parse axis
if (args[2].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else {
if (args[2].type_code() == kDLInt) {
param.axis = Tuple<int>(1, args[2].operator int64_t());
} else {
param.axis = Tuple<int>(args[2].operator ObjectRef());
}
}

// parse returned
CHECK_NE(args[3].type_code(), kNull)
<< "returned cannot be None";
param.returned = args[3].operator bool();

// parse weighted
CHECK_NE(args[4].type_code(), kNull)
<< "weighted cannot be None";
param.weighted = args[4].operator bool();

attrs.parsed = std::move(param);

SetAttrDict<op::NumpyWeightedAverageParam>(&attrs);

int num_inputs = param.weighted ? 2 : 1;
NDArray* outputs[] = {args[5].operator NDArray*()};
NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs;
int num_outputs = (outputs[0] != nullptr);

if (param.weighted) {
NDArray* inputs[] = {args[0].operator NDArray*(), args[1].operator NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out);
if (out) {
*ret = PythonArg(5);
} else {
if (param.returned) {
*ret = ADT(0, {NDArrayHandle(ndoutputs[0]),
NDArrayHandle(ndoutputs[1])});
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
}
} else {
NDArray* inputs[] = {args[0].operator NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out);
if (out) {
*ret = PythonArg(5);
} else {
if (param.returned) {
*ret = ADT(0, {NDArrayHandle(ndoutputs[0]),
NDArrayHandle(ndoutputs[1])});
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
}
}
});

}; // namespace mxnet
10 changes: 10 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,16 @@ struct NumpyWeightedAverageParam : public dmlc::Parameter<NumpyWeightedAveragePa
.set_default(true)
.describe("Auxiliary flag to deal with none weights.");
}

void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axis_s, returned_s, weighted_s;
axis_s << axis;
returned_s << returned;
weighted_s << weighted;
(*dict)["axis"] = axis_s.str();
(*dict)["returned"] = returned_s.str();
(*dict)["weighted"] = weighted_s.str();
}
};

inline bool NumpyWeightedAverageShape(const nnvm::NodeAttrs& attrs,
Expand Down

0 comments on commit 2607ace

Please sign in to comment.