diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index ea9aa6a7cf65..c4e3dae35ab9 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -1630,7 +1630,7 @@ template<bool clip = true>
 struct pick {
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
-                                  const IType *idx, index_t M, int stride,
+                                  const IType *idx, index_t M, index_t stride,
                                   mshadow::Shape<2> bshape, mshadow::Shape<2> sshape) {
     using namespace mxnet_op;
@@ -1652,7 +1652,7 @@ template<bool clip = true>
 struct pick_grad {
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(index_t i, DType* igrad, const DType* ograd,
-                                  const IType *idx, index_t M, int stride,
+                                  const IType *idx, index_t M, index_t stride,
                                   mshadow::Shape<2> bshape, mshadow::Shape<2> sshape) {
     using namespace mxnet_op;
@@ -1717,7 +1717,7 @@ void PickOpForward(const nnvm::NodeAttrs& attrs,
   const mxnet::TShape& ishape = inputs[0].shape_;
   index_t axis = CheckAxis(param.axis.value(), ishape.ndim());
-  int leading = 1, trailing = 1, M = ishape[axis];
+  index_t leading = 1, trailing = 1, M = ishape[axis];
   for (index_t i = 0; i < axis; ++i) leading *= ishape[i];
   for (index_t i = axis+1; i < ishape.ndim(); ++i) trailing *= ishape[i];
@@ -1764,7 +1764,7 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
   const mxnet::TShape& ishape = outputs[0].shape_;
   const index_t axis = CheckAxis(param.axis.value(), ishape.ndim());
-  int leading = 1, trailing = 1, M = ishape[axis];
+  index_t leading = 1, trailing = 1, M = ishape[axis];
   for (index_t i = 0; i < axis; ++i) leading *= ishape[i];
   for (index_t i = axis+1; i < ishape.ndim(); ++i) trailing *= ishape[i];
diff --git a/tests/nightly/test_np_large_array.py b/tests/nightly/test_np_large_array.py
index 90fc58f0f27b..d2b653b500cd 100644
--- a/tests/nightly/test_np_large_array.py
+++ b/tests/nightly/test_np_large_array.py
@@ -709,16 +709,18 @@ def test_one_hot():
     assert A.grad[0] == 0
 
 @use_np
-@pytest.mark.skip(reason='backward value broken on large tensor')
 def test_pick():
+    INT_OVERFLOW = 2**31
     A = np.zeros((INT_OVERFLOW, 2))
     B = np.zeros((INT_OVERFLOW))
+    A[-1, 0] = 3
     A.attach_grad()
     B.attach_grad()
     with mx.autograd.record():
         C = npx.pick(A, B)
     assert C.shape == (INT_OVERFLOW, )
     assert C[0] == 0
+    assert C[-1] == 3
     C.backward()
     assert A.grad.shape == (INT_OVERFLOW, 2)
     assert B.grad.shape == (INT_OVERFLOW, )
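
For reviewers unfamiliar with the failure mode, here is a minimal standalone sketch (not part of this patch, and not MXNet code) of the overflow that the `int` to `index_t` change addresses. It assumes `index_t` is a 64-bit type, as in MXNet's large-tensor builds (`MSHADOW_INT64_TENSOR_SIZE=1`); the shape values mirror the `(2**31, 2)` tensor used in the nightly test.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Shape of the nightly test tensor: (2^31, 2), picking along axis 1,
  // so leading = 2^31, M = 2, trailing = 1.
  const int64_t ishape[2] = {int64_t{1} << 31, 2};
  const int axis = 1;

  // Old code pattern: `int leading = 1; ... leading *= ishape[i];`
  // The product is computed in 64 bits but narrowed back into a 32-bit
  // int, which wraps (typically to INT32_MIN), yielding a negative count.
  int32_t leading32 = 1;
  for (int i = 0; i < axis; ++i) leading32 *= ishape[i];

  // Fixed code pattern: `index_t leading = 1;` with index_t a 64-bit type.
  int64_t leading64 = 1;
  for (int i = 0; i < axis; ++i) leading64 *= ishape[i];

  std::printf("32-bit leading: %d\n", static_cast<int>(leading32));          // -2147483648
  std::printf("64-bit leading: %lld\n", static_cast<long long>(leading64));  //  2147483648
}
```

With 64-bit `leading`, `trailing`, `M`, and `stride`, the offsets computed in `PickOpForward` and `PickOpBackward` stay non-negative for inputs larger than 2^31 elements, which is why the previously skipped nightly test can be re-enabled and can check the picked value at the last row rather than only output shapes.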