Skip to content

Commit

Permalink
Replace std::lower_bound with own impl for gpu use too
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Jun 30, 2017
1 parent 528d1eb commit 3d9e201
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
21 changes: 19 additions & 2 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,25 @@ struct DotCsrRspDnsByRowBlocks {
for (size_t j = seg_start; j < seg_end; ++j) {
if (indptr_l[j] == indptr_l[j+1]) continue;
const size_t offset_out = j * num_cols;
const RType* row_idx_ptr = std::lower_bound(row_idx_r, row_idx_r+nnr_r,
col_idx_l[indptr_l[j]]);
// Binary search for the first element of row_idx_r that is not less than val
// (hand-written equivalent of std::lower_bound so the code can be used on GPU too)
const RType* first = row_idx_r;
const RType* last = row_idx_r + nnr_r;
const auto val = col_idx_l[indptr_l[j]];
const RType* it;
int count = last - first, step;
while (count > 0) {
it = first;
step = count / 2;
it += step;
if (*it < val) {
first = ++it;
count -= step + 1;
} else {
count = step;
}
}
const RType* row_idx_ptr = first;
// end of binary search
if (row_idx_ptr == row_idx_r+nnr_r || *row_idx_ptr> col_idx_l[indptr_l[j+1]-1]) continue;
for (auto k = indptr_l[j]; k < indptr_l[j+1] && row_idx_ptr != row_idx_r+nnr_r;) {
if (col_idx_l[k] == *row_idx_ptr) {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
grad_req={'lhs': 'null', 'rhs': 'write'},
rtol=1e-3, atol=1e-4)

lhs_shape = rand_shape_2d()
lhs_shape = rand_shape_2d(50, 200)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False)
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
Expand Down

0 comments on commit 3d9e201

Please sign in to comment.