Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding error message when attempting to use Large tensor with linalg_…
Browse files Browse the repository at this point in the history
…syevd
  • Loading branch information
Rohit Kumar Srivastava committed Jul 28, 2020
1 parent d009345 commit ed5a883
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
2 changes: 2 additions & 0 deletions 3rdparty/mshadow/mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ const float kPi = 3.1415926f;
#else
typedef int32_t index_t;
#endif
/*! \brief maximum signed integer limit used to check integer overflow */
const index_t kInt32Limit = (int64_t{1} << 31) - 1;

#ifdef _WIN32
/*! \brief openmp index for windows */
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/la_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,11 +470,14 @@ inline bool DetType(const nnvm::NodeAttrs& attrs,
inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
using namespace mshadow;
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 2);
const mxnet::TShape& in_a = (*in_attrs)[0];
const mxnet::TShape& out_u = (*out_attrs)[0];
const mxnet::TShape& out_l = (*out_attrs)[1];
CHECK_LE(in_a.Size(), kInt32Limit)
<< "Large tensors are not supported by Linear Algebra operator syevd";
if ( in_a.ndim() >= 2 ) {
// Forward shape inference.
const int ndim(in_a.ndim());
Expand Down
13 changes: 12 additions & 1 deletion tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor, get_identity_mat, get_identity_mat_batch
from mxnet import gluon, nd
from common import with_seed, with_post_test_cleanup
from common import with_seed, with_post_test_cleanup, assertRaises
from mxnet.base import MXNetError
from nose.tools import with_setup
import unittest

Expand Down Expand Up @@ -1350,6 +1351,16 @@ def run_trsm(inp):
check_batch_trsm()


def test_linalg_errors():
def check_syevd_error():
A = get_identity_mat(LARGE_SQ_X)
for i in range(LARGE_SQ_X):
A[i,i] = 1
assertRaises(MXNetError, mx.nd.linalg.syevd, A)

check_syevd_error()


def test_basic():
def check_elementwise():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
Expand Down

0 comments on commit ed5a883

Please sign in to comment.