From ae09e7bac049cbd5fd12567020f38915dabaa10a Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Tue, 13 Jun 2017 16:10:19 +0000 Subject: [PATCH] csr slice bug fix --- python/mxnet/sparse_ndarray.py | 2 ++ tests/python/unittest/test_sparse_ndarray.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py index da54c3584ac6..c3d4453c1d99 100644 --- a/python/mxnet/sparse_ndarray.py +++ b/python/mxnet/sparse_ndarray.py @@ -247,6 +247,8 @@ def _slice(self, start, stop): assert(stype == 'csr'), "_slice for " + str(stype) + " not implemented yet" warnings.warn('slicing SparseNDArray is not efficient', RuntimeWarning) shape = list(self.shape) + stop = shape[0] if stop is None else stop + start = 0 if start is None else start shape[0] = stop - start handle = _new_alloc_handle(self.storage_type, tuple(shape), self.context, True, self.dtype, self.aux_types) diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 25ba83463f80..5048f3d15962 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -157,6 +157,8 @@ def check_sparse_nd_csr_slice(shape): start = rnd.randint(0, shape[0] - 1) end = rnd.randint(start + 1, shape[0]) assert same(A[start:end].asnumpy(), A2[start:end]) + assert same(A[start:].asnumpy(), A2[start:]) + assert same(A[:end].asnumpy(), A2[:end]) shape = (rnd.randint(2, 10), rnd.randint(1, 10)) check_sparse_nd_csr_slice(shape)