Skip to content

Commit

Permalink
Fix kron ambiguity for triangular/hermitian matrices (#54413)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored May 15, 2024
1 parent 28aaafc commit 9ea4536
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 17 deletions.
16 changes: 7 additions & 9 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -525,28 +525,26 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni
end
end

function kron(A::Hermitian{T}, B::Hermitian{S}) where {T<:Union{Real,Complex},S<:Union{Real,Complex}}
function kron(A::Hermitian{<:Union{Real,Complex},<:StridedMatrix}, B::Hermitian{<:Union{Real,Complex},<:StridedMatrix})
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
C = Hermitian(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
C = Hermitian(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B)), resultuplo)
return kron!(C, A, B)
end

function kron(A::Symmetric{T}, B::Symmetric{S}) where {T<:Number,S<:Number}
function kron(A::Symmetric{<:Number,<:StridedMatrix}, B::Symmetric{<:Number,<:StridedMatrix})
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
C = Symmetric(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
C = Symmetric(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B)), resultuplo)
return kron!(C, A, B)
end

function kron!(C::Hermitian{<:Union{Real,Complex}}, A::Hermitian{<:Union{Real,Complex}}, B::Hermitian{<:Union{Real,Complex}})
function kron!(C::Hermitian{<:Union{Real,Complex},<:StridedMatrix}, A::Hermitian{<:Union{Real,Complex},<:StridedMatrix}, B::Hermitian{<:Union{Real,Complex},<:StridedMatrix})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
end
_hermkron!(C.data, A.data, B.data, conj, real, A.uplo, B.uplo)
return C
end

function kron!(C::Symmetric{<:Number}, A::Symmetric{<:Number}, B::Symmetric{<:Number})
function kron!(C::Symmetric{<:Number,<:StridedMatrix}, A::Symmetric{<:Number,<:StridedMatrix}, B::Symmetric{<:Number,<:StridedMatrix})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
Expand All @@ -555,7 +553,7 @@ function kron!(C::Symmetric{<:Number}, A::Symmetric{<:Number}, B::Symmetric{<:Nu
return C
end

function _hermkron!(C, A, B, conj::TC, real::TR, Auplo, Buplo) where {TC,TR}
function _hermkron!(C, A, B, conj, real, Auplo, Buplo)
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds if Auplo == 'U' && Buplo == 'U'
Expand Down
14 changes: 6 additions & 8 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -763,23 +763,21 @@ for op in (:+, :-)
end
end

function kron(A::UpperTriangular{T}, B::UpperTriangular{S}) where {T<:Number,S<:Number}
C = UpperTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
function kron(A::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat})
C = UpperTriangular(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B)))
return kron!(C, A, B)
end

function kron(A::LowerTriangular{T}, B::LowerTriangular{S}) where {T<:Number,S<:Number}
C = LowerTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
function kron(A::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat})
C = LowerTriangular(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B)))
return kron!(C, A, B)
end

function kron!(C::UpperTriangular{<:Number}, A::UpperTriangular{<:Number}, B::UpperTriangular{<:Number})
function kron!(C::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, A::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
_triukron!(C.data, A.data, B.data)
return C
end

function kron!(C::LowerTriangular{<:Number}, A::LowerTriangular{<:Number}, B::LowerTriangular{<:Number})
function kron!(C::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, A::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::LowerTriangular{<:Number,<:StridedMaybeAdjOrTransMat})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
_trilkron!(C.data, A.data, B.data)
return C
Expand Down
9 changes: 9 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,15 @@ end
@test kron(A, B) kron(Symmetric(A), Symmetric(B))
end

@testset "kron with symmetric/hermitian matrices of matrices" begin
M = fill(ones(2,2), 2, 2)
for W in (Symmetric, Hermitian)
for (t1, t2) in ((W(M, :U), W(M, :U)), (W(M, :U), W(M, :L)), (W(M, :L), W(M, :L)))
@test kron(t1, t2) kron(Matrix(t1), Matrix(t2))
end
end
end

#Issue #7647: test xsyevr, xheevr, xstevr drivers.
@testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in
(Symmetric(diagm(0 => 1.0:3.0)),
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,14 @@ end
end
end

@testset "kron with triangular matrices of matrices" begin
for T in (UpperTriangular, LowerTriangular)
t = T(fill(ones(2,2), 2, 2))
m = Matrix(t)
@test kron(t, t) kron(m, m)
end
end

@testset "copyto! with aliasing (#39460)" begin
M = Matrix(reshape(1:36, 6, 6))
@testset for T in (UpperTriangular, LowerTriangular)
Expand Down

0 comments on commit 9ea4536

Please sign in to comment.