From ed5a883beccf2b11a49be716335c9a89ae2e39fa Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Tue, 28 Jul 2020 11:48:48 +0000 Subject: [PATCH] adding error message when attempting to use Large tensor with linalg_syevd --- 3rdparty/mshadow/mshadow/base.h | 2 ++ src/operator/tensor/la_op.h | 3 +++ tests/nightly/test_large_array.py | 13 ++++++++++++- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h index 9f538574f093..6eb794f005f7 100755 --- a/3rdparty/mshadow/mshadow/base.h +++ b/3rdparty/mshadow/mshadow/base.h @@ -335,6 +335,8 @@ const float kPi = 3.1415926f; #else typedef int32_t index_t; #endif + /*! \brief maximum signed integer limit used to check integer overflow */ +const index_t kInt32Limit = (int64_t{1} << 31) - 1; #ifdef _WIN32 /*! \brief openmp index for windows */ diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index e15390ecde5a..f67c302e4de3 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -470,11 +470,14 @@ inline bool DetType(const nnvm::NodeAttrs& attrs, inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_attrs, mxnet::ShapeVector* out_attrs) { + using namespace mshadow; CHECK_EQ(in_attrs->size(), 1); CHECK_EQ(out_attrs->size(), 2); const mxnet::TShape& in_a = (*in_attrs)[0]; const mxnet::TShape& out_u = (*out_attrs)[0]; const mxnet::TShape& out_l = (*out_attrs)[1]; + CHECK_LE(in_a.Size(), kInt32Limit) + << "Large tensors are not supported by Linear Algebra operator syevd"; if ( in_a.ndim() >= 2 ) { // Forward shape inference. const int ndim(in_a.ndim()); diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 306c827bab9f..de36cc4f118d 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -27,7 +27,8 @@ from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor, get_identity_mat, get_identity_mat_batch from mxnet import gluon, nd -from common import with_seed, with_post_test_cleanup +from common import with_seed, with_post_test_cleanup, assertRaises +from mxnet.base import MXNetError from nose.tools import with_setup import unittest @@ -1350,6 +1351,16 @@ def run_trsm(inp): check_batch_trsm() +def test_linalg_errors(): + def check_syevd_error(): + A = get_identity_mat(LARGE_SQ_X) + for i in range(LARGE_SQ_X): + A[i,i] = 1 + assertRaises(MXNetError, mx.nd.linalg.syevd, A) + + check_syevd_error() + + def test_basic(): def check_elementwise(): a = nd.ones(shape=(LARGE_X, SMALL_Y))