Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Numpy pick large tensor fix (#19025)
Browse files Browse the repository at this point in the history
* fix indexing

* tweak test

* add more check

* tweak test to set last element

Co-authored-by: Zhu <[email protected]>
  • Loading branch information
Zha0q1 and Zhu authored Sep 1, 2020
1 parent af467d2 commit 9268f89
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,7 @@ template<int ndim, bool clip = true>
struct pick {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
const IType *idx, index_t M, int stride,
const IType *idx, index_t M, index_t stride,
mshadow::Shape<ndim> bshape,
mshadow::Shape<ndim> sshape) {
using namespace mxnet_op;
Expand All @@ -1652,7 +1652,7 @@ template<int ndim, bool clip = true>
struct pick_grad {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i, DType* igrad, const DType* ograd,
const IType *idx, index_t M, int stride,
const IType *idx, index_t M, index_t stride,
mshadow::Shape<ndim> bshape,
mshadow::Shape<ndim> sshape) {
using namespace mxnet_op;
Expand Down Expand Up @@ -1717,7 +1717,7 @@ void PickOpForward(const nnvm::NodeAttrs& attrs,

const mxnet::TShape& ishape = inputs[0].shape_;
index_t axis = CheckAxis(param.axis.value(), ishape.ndim());
int leading = 1, trailing = 1, M = ishape[axis];
index_t leading = 1, trailing = 1, M = ishape[axis];
for (index_t i = 0; i < axis; ++i) leading *= ishape[i];
for (index_t i = axis+1; i < ishape.ndim(); ++i) trailing *= ishape[i];

Expand Down Expand Up @@ -1764,7 +1764,7 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,

const mxnet::TShape& ishape = outputs[0].shape_;
const index_t axis = CheckAxis(param.axis.value(), ishape.ndim());
int leading = 1, trailing = 1, M = ishape[axis];
index_t leading = 1, trailing = 1, M = ishape[axis];
for (index_t i = 0; i < axis; ++i) leading *= ishape[i];
for (index_t i = axis+1; i < ishape.ndim(); ++i) trailing *= ishape[i];

Expand Down
4 changes: 3 additions & 1 deletion tests/nightly/test_np_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,16 +709,18 @@ def test_one_hot():
assert A.grad[0] == 0

@use_np
@pytest.mark.skip(reason='backward value broken on large tensor')
def test_pick():
INT_OVERFLOW = 2**31
A = np.zeros((INT_OVERFLOW, 2))
B = np.zeros((INT_OVERFLOW))
A[-1, 0] = 3
A.attach_grad()
B.attach_grad()
with mx.autograd.record():
C = npx.pick(A, B)
assert C.shape == (INT_OVERFLOW, )
assert C[0] == 0
assert C[-1] == 3
C.backward()
assert A.grad.shape == (INT_OVERFLOW, 2)
assert B.grad.shape == (INT_OVERFLOW, )
Expand Down

0 comments on commit 9268f89

Please sign in to comment.