Skip to content

Commit

Permalink
edge_id op csr forward on CPU (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
HyperZealot authored and zheng-da committed Oct 26, 2018
1 parent ffeaf31 commit 5dbf068
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api/python/ndarray/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
cond
index_copy
getnnz
edge_id
```

## API Reference
Expand Down
163 changes: 163 additions & 0 deletions src/operator/contrib/edge_id-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file edge_id-inl.h
* \brief Operator implementing edge_id function.
*/
#ifndef MXNET_OPERATOR_CONTRIB_EDGE_ID_INL_H_
#define MXNET_OPERATOR_CONTRIB_EDGE_ID_INL_H_

#include <mxnet/operator_util.h>
#include <algorithm>
#include <vector>
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../tensor/init_op.h"

namespace mxnet {
namespace op {

/*!
 * \brief Shape inference for edge_id.
 *
 * Expects three inputs (data, u, v) and one output. u and v must be 1D and of
 * equal length; the output takes the shape of u, and that shape is propagated
 * back onto u and v.
 *
 * \param attrs node attributes (referenced by the shape-assign macros)
 * \param in_attrs shapes of the inputs, updated in place
 * \param out_attrs shapes of the outputs, updated in place
 * \return true when the output shape has a non-zero ndim and a non-zero size
 */
inline bool EdgeIDShape(const nnvm::NodeAttrs& attrs,
                        std::vector<TShape>* in_attrs,
                        std::vector<TShape>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), 1U);

  // u and v are both 1D vectors with one entry per queried edge.
  const TShape& u_shape = in_attrs->at(1);
  const TShape& v_shape = in_attrs->at(2);
  CHECK_EQ(u_shape.ndim(), 1U);
  CHECK_EQ(v_shape.ndim(), 1U);
  CHECK_EQ(u_shape[0], v_shape[0]);

  // Forward: output <- u.  Backward: u, v <- output.
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
  SHAPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
  SHAPE_ASSIGN_CHECK(*in_attrs, 2, out_attrs->at(0));

  const TShape& out_shape = out_attrs->at(0);
  return out_shape.ndim() != 0U && out_shape.Size() != 0U;
}

/*!
 * \brief Type inference for edge_id.
 *
 * The output dtype is tied to the dtype of the first input (data) in both
 * directions. The dtypes of u and v are not constrained here.
 *
 * \param attrs node attributes (referenced by the type-assign macros)
 * \param in_attrs dtypes of the inputs, updated in place
 * \param out_attrs dtypes of the outputs, updated in place
 * \return true once the output dtype is known (i.e. not -1)
 */
inline bool EdgeIDType(const nnvm::NodeAttrs& attrs,
                       std::vector<int>* in_attrs,
                       std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), 1U);

  // Forward: output <- data.  Backward: data <- output.
  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));

  const bool out_type_known = (out_attrs->at(0) != -1);
  return out_type_known;
}

/*!
 * \brief Storage type inference for edge_id.
 *
 * Only the combination "csr data -> default (dense) output" is dispatched,
 * using DispatchMode::kFComputeEx. Any other storage type of the first input
 * logs an error and reports failure.
 *
 * \param attrs node attributes (unused)
 * \param dev_mask device mask (unused)
 * \param dispatch_mode dispatch mode to be assigned
 * \param in_attrs storage types of the inputs
 * \param out_attrs storage types of the outputs, updated in place
 * \return true when the storage type and dispatch mode were assigned
 */
inline bool EdgeIDStorageType(const nnvm::NodeAttrs& attrs,
                              const int dev_mask,
                              DispatchMode* dispatch_mode,
                              std::vector<int>* in_attrs,
                              std::vector<int>* out_attrs) {
  // NOTE(review): the message below talks about 2d arrays although the check
  // is on the number of inputs; kept as-is to preserve the emitted text.
  CHECK_EQ(in_attrs->size(), 3U) << "Only works for 2d arrays";
  CHECK_EQ(out_attrs->size(), 1U);

  const int data_stype = in_attrs->at(0);
  // csr -> dns
  if (data_stype == kCSRStorage &&
      storage_type_assign(&out_attrs->at(0), kDefaultStorage,
                          dispatch_mode, DispatchMode::kFComputeEx)) {
    return true;
  }

  LOG(ERROR) << "Cannot dispatch edge_id storage type, only works for csr matrices";
  return false;
}

/*!
 * \brief Element-wise kernel: looks up entry (u[i], v[i]) of a CSR matrix.
 *
 * For query i the column indices stored for row u[i] are scanned linearly for
 * v[i]. If found, the matching stored value is written to out_data[i];
 * otherwise out_data[i] is set to -1. No bounds checking is done on u[i].
 */
struct edge_id_csr_forward {
  template<typename DType, typename IType, typename CType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
                                  const IType* in_indices, const IType* in_indptr,
                                  const CType* u, const CType* v) {
    const int64_t row = static_cast<int64_t>(u[i]);
    const IType col = static_cast<IType>(v[i]);

    // [row_begin, row_end) delimits the column indices stored for this row.
    const IType* row_begin = in_indices + in_indptr[row];
    const IType* row_end = in_indices + in_indptr[row + 1];

    const IType* hit = std::find(row_begin, row_end, col);
    if (hit != row_end) {
      // The offset into the index array is also the offset into the values.
      out_data[i] = in_data[hit - in_indices];
    } else {
      // does not exist in the range
      out_data[i] = DType(-1);
    }
  }
};

template<typename xpu>
void EdgeIDForwardCsrImpl(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const OpReqType req,
const NDArray& output) {
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
if (req == kNullOp) return;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(req, kWriteTo) << "EdgeID with CSR only supports kWriteTo";
Stream<xpu> *s = ctx.get_stream<xpu>();
const NDArray& u = inputs[1];
const nnvm::dim_t out_elems = u.shape().Size();
if (!inputs[0].storage_initialized()) {
MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, xpu>::Launch(
s, out_elems, output.data().dptr<DType>(), DType(-1));
});
return;
}
const NDArray& data = inputs[0];
const TBlob& in_data = data.data();
const TBlob& in_indices = data.aux_data(kIdx);
const TBlob& in_indptr = data.aux_data(kIndPtr);
const NDArray& v = inputs[2];

CHECK_EQ(data.aux_type(kIdx), data.aux_type(kIndPtr))
<< "The dtypes of indices and indptr don't match";
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(data.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(u.dtype(), CType, {
Kernel<edge_id_csr_forward, xpu>::Launch(
s, out_elems, output.data().dptr<DType>(), in_data.dptr<DType>(),
in_indices.dptr<IType>(), in_indptr.dptr<IType>(),
u.data().dptr<CType>(), v.data().dptr<CType>());
});
});
});
}

/*!
 * \brief FComputeEx entry point of edge_id.
 *
 * Forwards to EdgeIDForwardCsrImpl when the first input is stored as CSR and
 * the output uses default storage; every other storage combination is
 * reported through LogUnimplementedOp.
 *
 * \param attrs node attributes
 * \param ctx operator context
 * \param inputs the three inputs: data, u, v
 * \param req write requests, exactly one
 * \param outputs output arrays, exactly one
 */
template<typename xpu>
void EdgeIDForwardEx(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<NDArray>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<NDArray>& outputs) {
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);

  const bool csr_to_dense = inputs[0].storage_type() == kCSRStorage &&
                            outputs[0].storage_type() == kDefaultStorage;
  if (!csr_to_dense) {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
    return;
  }
  EdgeIDForwardCsrImpl<xpu>(ctx, inputs, req[0], outputs[0]);
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONTRIB_EDGE_ID_INL_H_
65 changes: 65 additions & 0 deletions src/operator/contrib/edge_id.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file edge_id.cc
* \brief CPU Implementation of edge_id op.
*/
#include "./edge_id-inl.h"

namespace mxnet {
namespace op {

// Registers the CPU implementation of the edge_id operator: three inputs
// (data, u, v), one output, computed through the FComputeEx path.
//
// No FInplaceOption is declared: the kernel reads arbitrary positions of the
// data input while writing the output, so the output must never alias it.
NNVM_REGISTER_OP(_contrib_edge_id)
.describe(R"code(This operator implements the edge_id function for csr arrays,
where output[i] = input[u[i], v[i]] if input[u[i], v[i]] is a non-zero element of input,
otherwise output[i] will be -1. Both u and v should be 1D vectors.

Example::

  x = [[ 1, 0, 0 ],
       [ 0, 2, 0 ],
       [ 0, 0, 3 ]]
  u = [ 0, 0, 1, 1, 2, 2 ]
  v = [ 0, 1, 1, 2, 0, 2 ]
  edge_id(x, u, v) = [ 1, -1, 2, -1, -1, 3 ]

The storage type of ``edge_id`` output depends on storage types of inputs
  - edge_id(csr, default, default) = default
  - default and rsp inputs are not supported

)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"data", "u", "v"};
  })
.set_attr<nnvm::FInferShape>("FInferShape", EdgeIDShape)
.set_attr<nnvm::FInferType>("FInferType", EdgeIDType)
.set_attr<FInferStorageType>("FInferStorageType", EdgeIDStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", EdgeIDForwardEx<cpu>)
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_argument("u", "NDArray-or-Symbol", "u ndarray")
.add_argument("v", "NDArray-or-Symbol", "v ndarray");

} // namespace op
} // namespace mxnet
21 changes: 21 additions & 0 deletions tests/python/unittest/test_contrib_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,27 @@ def test_multibox_target_op():
assert_array_equal(loc_mask.asnumpy(), expected_loc_mask)
assert_array_equal(cls_target.asnumpy(), expected_cls_target)

def test_edge_id():
    """Check contrib.edge_id against a dense reference built from the CSR parts."""
    shape = rand_shape_2d()
    data = rand_ndarray(shape, stype='csr', density=0.4)

    # Dense reference: -1 everywhere except at the entries stored in the CSR.
    ground_truth = np.full(shape, -1.0, dtype=np.float32)
    indptr_np = data.indptr.asnumpy()
    data_np = data.data.asnumpy()
    indices_np = data.indices.asnumpy()
    for row in range(shape[0]):
        start, stop = indptr_np[row], indptr_np[row + 1]
        for col, val in zip(indices_np[start:stop], data_np[start:stop]):
            ground_truth[row, col] = val

    # Five random (row, column) queries inside the matrix bounds.
    np_u = np.random.randint(0, shape[0], size=(5, ))
    np_v = np.random.randint(0, shape[1], size=(5, ))
    mx_u = mx.nd.array(np_u)
    mx_v = mx.nd.array(np_v)

    result = mx.nd.contrib.edge_id(data, mx_u, mx_v).asnumpy()
    expected = ground_truth[np_u, np_v]
    assert_almost_equal(result, expected, rtol=1e-5, atol=1e-6)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 5dbf068

Please sign in to comment.