This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

getnnz operator for csr matrix #12898

Closed
wants to merge 1 commit
2 changes: 2 additions & 0 deletions docs/api/python/ndarray/contrib.md
@@ -55,6 +55,8 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
foreach
while_loop
cond
index_copy
getnnz
```

## API Reference
188 changes: 188 additions & 0 deletions src/operator/contrib/nnz.cc
@@ -0,0 +1,188 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file nnz.cc
* \brief CPU Implementation of nnz operator
*/
#include <mxnet/operator_util.h>
#include <vector>
#include <limits>
#include <algorithm>
#include "../elemwise_op_common.h"
#include "../tensor/init_op.h"
#include "../mshadow_op.h"
#include "../mxnet_op.h"

namespace mxnet {
namespace op {

struct NNZParam : public dmlc::Parameter<NNZParam> {
dmlc::optional<int> axis;
DMLC_DECLARE_PARAMETER(NNZParam) {
DMLC_DECLARE_FIELD(axis)
.set_default(dmlc::optional<int>())
.describe("Select between the number of values across the whole matrix, "
"in each column, or in each row.");
}
};

static bool NNZType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
// infer int64 for count
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
return true;
}

inline bool NNZShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
// csr_matrix is 2-D
CHECK_EQ(in_attrs->at(0).ndim(), 2);
const NNZParam& param = nnvm::get<NNZParam>(attrs.parsed);
// whole matrix
if (!param.axis.has_value()) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1));
} else if (param.axis.value() == 0) {
// columns
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(in_attrs->at(0)[1]));
} else if (param.axis.value() == 1) {
// rows
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(in_attrs->at(0)[0]));
} else {
LOG(FATAL) << "Unexpected value for axis(" << param.axis.value()
<< "). Candidates are None, 0, and 1";
}
return true;
}

template<typename xpu>
void NNZComputeCsrImpl(const NNZParam& param,
const OpContext& ctx,
const NDArray& input,
const OpReqType req,
const TBlob& output);

struct CsrNNZRowKernel {
/*!
* \brief Map function computing the number of stored values in each row of a CSR matrix
* \param tid global thread id
* \param out ptr to output
* \param indptr ptr to source csr indptr
*/
template<typename IType, typename DType>
MSHADOW_XINLINE static void Map(int tid, DType* out, const IType* indptr) {
out[tid] = static_cast<DType>(indptr[tid + 1] - indptr[tid]);
}
};

template<>
void NNZComputeCsrImpl<cpu>(const NNZParam& param,
const OpContext& ctx,
const NDArray& input,
const OpReqType req,
const TBlob& output) {
using namespace csr;
CHECK_EQ(req, kWriteTo);
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
if (!input.storage_initialized()) {
Fill<false>(s, output, kWriteTo, 0);
return;
}
MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIndPtr), IType, {
DType* out_ptr = output.dptr<DType>();
const IType* indptr = input.aux_data(kIndPtr).dptr<IType>();
const nnvm::dim_t num_rows = input.shape()[0];
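// e.g. for a 3-row CSR matrix with indptr = {0, 2, 2, 5}: axis=None yields
// indptr[3] = 5 (total stored values); axis=1 yields {2, 0, 3}, i.e.
// indptr[i + 1] - indptr[i] for each row i.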
if (!param.axis.has_value()) {
// whole matrix
out_ptr[0] = indptr[num_rows];
} else if (param.axis.value() == 0) {
// column
LOG(FATAL) << "getnnz with axis = 1 is not supported yet";
} else if (param.axis.value() == 1) {
// row
mxnet_op::Kernel<CsrNNZRowKernel, cpu>::Launch(s, num_rows, out_ptr, indptr);
}
});
});
}

template<typename xpu>
void NNZComputeEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
const auto in_stype = inputs[0].storage_type();
const auto out_stype = outputs[0].storage_type();
const NNZParam& param = nnvm::get<NNZParam>(attrs.parsed);
if (in_stype == kCSRStorage && out_stype == kDefaultStorage) {
NNZComputeCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0].data());
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}

bool NNZStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
bool dispatched = false;
const auto in_stype = in_attrs->at(0);
auto& out_stype = out_attrs->at(0);
// only support csr for now
if (!dispatched && in_stype == kCSRStorage) {
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
return dispatched;
}

DMLC_REGISTER_PARAMETER(NNZParam);

NNVM_REGISTER_OP(_contrib_getnnz)
.describe(R"code(Number of stored values for a sparse tensor, including explicit zeros.

This operator only supports CSR matrix on CPU.

)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NNZParam>)
.set_attr<nnvm::FInferShape>("FInferShape", NNZShape)
.set_attr<nnvm::FInferType>("FInferType", NNZType)
.set_attr<FInferStorageType>("FInferStorageType", NNZStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", NNZComputeEx<cpu>)
.add_argument("data", "NDArray-or-Symbol", "Input");

} // namespace op
} // namespace mxnet
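
A quick usage sketch (not part of the patch; assumes a build of MXNet that includes this operator). The expected outputs follow from the CPU implementation above, and the unit test below checks the same results against scipy's getnnz:

import mxnet as mx

# 3x4 CSR matrix with 5 stored values; rows hold 2, 0, and 3 entries.
indptr  = [0, 2, 2, 5]
indices = [0, 3, 1, 2, 3]
data    = [1.0, 2.0, 3.0, 4.0, 5.0]
csr = mx.nd.sparse.csr_matrix((data, indices, indptr), shape=(3, 4))

print(mx.nd.contrib.getnnz(csr).asnumpy())          # [5]      whole matrix
print(mx.nd.contrib.getnnz(csr, axis=1).asnumpy())  # [2 0 3]  per row
# axis=0 (per column) is not implemented by this patch and fails with an error.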
16 changes: 16 additions & 0 deletions tests/python/unittest/test_sparse_ndarray.py
@@ -1031,6 +1031,22 @@ def check_sparse_take(density, mode):
for m in modes:
check_sparse_take(d, m)

@with_seed()
def test_sparse_getnnz():
    def check_sparse_getnnz(density, axis):
        shape = rand_shape_2d()
        data = rand_ndarray(shape, 'csr', density=density)
        data_sp = data.asscipy()
        result = mx.nd.contrib.getnnz(data, axis=axis)
        expected_result = data_sp.getnnz(axis=axis)
        assert_almost_equal(result.asnumpy(), expected_result)

    densities = [0, 0.5, 1]
    axis = [1, None]
    for d in densities:
        for a in axis:
            check_sparse_getnnz(d, a)

if __name__ == '__main__':
import nose
nose.runmodule()