diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 306c827bab9f..b90c2a31ecaf 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1963,6 +1963,15 @@ def check_minimum(): check_maximum() check_minimum() +def test_linalg_ops(): + def check_linalg_gemm2(): + a= mx.nd.ones(shape=(SMALL_Y, LARGE_X)) + b= mx.nd.ones(shape=(LARGE_X, SMALL_Y)) + res=nd.linalg_gemm2(a, b) + res.shape == (SMALL_Y, SMALL_Y) + assert res.asnumpy()[0][0] == LARGE_X + assert res.asnumpy()[-1][-1] == LARGE_X + check_linalg_gemm2() def test_sparse_dot(): shape = (2, VLARGE_X)