From 9ea4536f1a3693fd016133a7a9a611b08802f66b Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 15 May 2024 20:41:21 +0200 Subject: [PATCH] Fix `kron` ambiguity for triangular/hermitian matrices (#54413) --- stdlib/LinearAlgebra/src/symmetric.jl | 16 +++++++--------- stdlib/LinearAlgebra/src/triangular.jl | 14 ++++++-------- stdlib/LinearAlgebra/test/symmetric.jl | 9 +++++++++ stdlib/LinearAlgebra/test/triangular.jl | 8 ++++++++ 4 files changed, 30 insertions(+), 17 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 32faef6df3448..b8b733cb63186 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -525,19 +525,18 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni end end -function kron(A::Hermitian{T}, B::Hermitian{S}) where {T<:Union{Real,Complex},S<:Union{Real,Complex}} +function kron(A::Hermitian{<:Union{Real,Complex},<:StridedMatrix}, B::Hermitian{<:Union{Real,Complex},<:StridedMatrix}) resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L - C = Hermitian(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo) + C = Hermitian(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B)), resultuplo) return kron!(C, A, B) end - -function kron(A::Symmetric{T}, B::Symmetric{S}) where {T<:Number,S<:Number} +function kron(A::Symmetric{<:Number,<:StridedMatrix}, B::Symmetric{<:Number,<:StridedMatrix}) resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L - C = Symmetric(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo) + C = Symmetric(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B)), resultuplo) return kron!(C, A, B) end -function kron!(C::Hermitian{<:Union{Real,Complex}}, A::Hermitian{<:Union{Real,Complex}}, B::Hermitian{<:Union{Real,Complex}}) +function kron!(C::Hermitian{<:Union{Real,Complex},<:StridedMatrix}, A::Hermitian{<:Union{Real,Complex},<:StridedMatrix}, B::Hermitian{<:Union{Real,Complex},<:StridedMatrix}) size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!")) if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L') throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)")) @@ -545,8 +544,7 @@ function kron!(C::Hermitian{<:Union{Real,Complex}}, A::Hermitian{<:Union{Real,Co _hermkron!(C.data, A.data, B.data, conj, real, A.uplo, B.uplo) return C end - -function kron!(C::Symmetric{<:Number}, A::Symmetric{<:Number}, B::Symmetric{<:Number}) +function kron!(C::Symmetric{<:Number,<:StridedMatrix}, A::Symmetric{<:Number,<:StridedMatrix}, B::Symmetric{<:Number,<:StridedMatrix}) size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!")) if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L') throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)")) @@ -555,7 +553,7 @@ function kron!(C::Symmetric{<:Number}, A::Symmetric{<:Number}, B::Symmetric{<:Nu return C end -function _hermkron!(C, A, B, conj::TC, real::TR, Auplo, Buplo) where {TC,TR} +function _hermkron!(C, A, B, conj, real, Auplo, Buplo) n_A = size(A, 1) n_B = size(B, 1) @inbounds if Auplo == 'U' && Buplo == 'U' diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index f052585d93285..469411eb77d40 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -763,23 +763,21 @@ for op in (:+, :-) end end -function kron(A::UpperTriangular{T}, B::UpperTriangular{S}) where {T<:Number,S<:Number} - C = UpperTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B))) +function kron(A::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}) + C = UpperTriangular(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B))) return kron!(C, A, B) end - -function kron(A::LowerTriangular{T}, B::LowerTriangular{S}) where {T<:Number,S<:Number} - C = LowerTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B))) +function kron(A::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat}) + C = LowerTriangular(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B))) return kron!(C, A, B) end -function kron!(C::UpperTriangular{<:Number}, A::UpperTriangular{<:Number}, B::UpperTriangular{<:Number}) +function kron!(C::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, A::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}) size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!")) _triukron!(C.data, A.data, B.data) return C end - -function kron!(C::LowerTriangular{<:Number}, A::LowerTriangular{<:Number}, B::LowerTriangular{<:Number}) +function kron!(C::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, A::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat}) size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!")) _trilkron!(C.data, A.data, B.data) return C diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index f9da7d004822b..218171cab2192 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -580,6 +580,15 @@ end @test kron(A, B) ≈ kron(Symmetric(A), Symmetric(B)) end +@testset "kron with symmetric/hermitian matrices of matrices" begin + M = fill(ones(2,2), 2, 2) + for W in (Symmetric, Hermitian) + for (t1, t2) in ((W(M, :U), W(M, :U)), (W(M, :U), W(M, :L)), (W(M, :L), W(M, :L))) + @test kron(t1, t2) ≈ kron(Matrix(t1), Matrix(t2)) + end + end +end + #Issue #7647: test xsyevr, xheevr, xstevr drivers. @testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in (Symmetric(diagm(0 => 1.0:3.0)), diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index 40319b644b3cf..4ec6b37ea2a37 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -1052,6 +1052,14 @@ end end end +@testset "kron with triangular matrices of matrices" begin + for T in (UpperTriangular, LowerTriangular) + t = T(fill(ones(2,2), 2, 2)) + m = Matrix(t) + @test kron(t, t) ≈ kron(m, m) + end +end + @testset "copyto! with aliasing (#39460)" begin M = Matrix(reshape(1:36, 6, 6)) @testset for T in (UpperTriangular, LowerTriangular)