From 0a445cb8436062a15963e08f356cfddc288af71f Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 15 Jul 2020 16:46:29 +0000 Subject: [PATCH] add large tensor test for linalg_gemm2 --- tests/nightly/test_large_array.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 306c827bab9f..9c8ef2cb6a46 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1337,6 +1337,15 @@ def run_trsm(inp): assert(grad[0, 0, 0] == 0) assert(grad[1, 0, 0] == 0) + def check_linalg_gemm2(): + a = mx.nd.ones(shape=(SMALL_Y, LARGE_X)) + b = mx.nd.ones(shape=(LARGE_X, SMALL_Y)) + res = nd.linalg_gemm2(a, b) + res.shape == (SMALL_Y, SMALL_Y) + assert res.asnumpy()[0][0] == LARGE_X + assert res.asnumpy()[-1][-1] == LARGE_X + + check_linalg_gemm2() check_potrf() check_potri() check_syrk_batch()