From 34d1d30611efed490886f5eef9944e95e4563fcf Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Mon, 29 Feb 2016 06:09:45 -0600 Subject: [PATCH] Check stride on preallocated output for matmul (fixes #15286) --- base/linalg/matmul.jl | 6 +++--- test/linalg/matmul.jl | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/base/linalg/matmul.jl b/base/linalg/matmul.jl index ca97c342ffbc6..7f16a8af885b9 100644 --- a/base/linalg/matmul.jl +++ b/base/linalg/matmul.jl @@ -256,7 +256,7 @@ function syrk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedVe return matmul3x3!(C,tA,tAt,A,A) end - if stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) + if stride(A, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(C, 2) >= size(C, 1) return copytri!(BLAS.syrk!('U', tA, one(T), A, zero(T), C), 'U') end return generic_matmatmul!(C, tA, tAt, A, A) @@ -287,7 +287,7 @@ function herk_wrapper!{T<:BlasReal}(C::Union{StridedMatrix{T}, StridedMatrix{Com # Result array does not need to be initialized as long as beta==0 # C = Array(T, mA, mA) - if stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) + if stride(A, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(C, 2) >= size(C, 1) return copytri!(BLAS.herk!('U', tA, one(T), A, zero(T), C), 'U', true) end return generic_matmatmul!(C,tA, tAt, A, A) @@ -330,7 +330,7 @@ function gemm_wrapper!{T<:BlasFloat}(C::StridedVecOrMat{T}, tA::Char, tB::Char, return matmul3x3!(C,tA,tB,A,B) end - if stride(A, 1) == stride(B, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) + if stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) && stride(C, 2) >= size(C, 1) return BLAS.gemm!(tA, tB, one(T), A, B, zero(T), C) end generic_matmatmul!(C, tA, tB, A, B) diff --git a/test/linalg/matmul.jl b/test/linalg/matmul.jl index ea4b0f2847da8..e7d8dbe9e6dfd 100644 --- a/test/linalg/matmul.jl +++ b/test/linalg/matmul.jl @@ -109,6 +109,16 @@ Aref = Ai[1:2:end,1:2:end] Asub = sub(Ai, 1:2:5, 1:2:4) @test Ac_mul_B(Asub, Asub) == Ac_mul_B(Aref, Aref) @test A_mul_Bc(Asub, Asub) == A_mul_Bc(Aref, Aref) +# issue #15286 +let C = zeros(8, 8), sC = sub(C, 1:2:8, 1:2:8), B = reshape(map(Float64,-9:10),5,4) + @test At_mul_B!(sC, A, A) == A'*A + @test At_mul_B!(sC, A, B) == A'*B +end +let Aim = A .- im, C = zeros(Complex128,8,8), sC = sub(C, 1:2:8, 1:2:8), B = reshape(map(Float64,-9:10),5,4) .+ im + @test Ac_mul_B!(sC, Aim, Aim) == Aim'*Aim + @test Ac_mul_B!(sC, Aim, B) == Aim'*B +end + # syrk & herk A = reshape(1:1503, 501, 3).-750.0