From c37fc2798d9bb7349ff8eadb350ae68cf17cee61 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 16 Mar 2023 10:53:06 +0000 Subject: [PATCH] Add *(::Diagonal, ::Diagonal, ::Diagonal) (#49005) (#49007) --- stdlib/LinearAlgebra/src/diagonal.jl | 6 ++++++ stdlib/LinearAlgebra/test/diagonal.jl | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index df37e8935d2af..7c54ce4009e33 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -385,6 +385,12 @@ function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) return broadcast(*, Da.diag, A, permutedims(Db.diag)) end +function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal) + _muldiag_size_check(Da, Db) + _muldiag_size_check(Db, Dc) + return Diagonal(Da.diag .* Db.diag .* Dc.diag) +end + # Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat @inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) = _muldiag!(out, D, V, alpha, beta) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index bcfed234a51ee..b8f3789ae2674 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1155,4 +1155,15 @@ end @test all(isone, diag(D)) end +@testset "diagonal triple multiplication (#49005)" begin + n = 10 + @test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n))) isa Diagonal + @test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n+1)))) + @test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n+1), Diagonal(ones(n+1)))) + @test_throws DimensionMismatch (*(Diagonal(ones(n+1)), Diagonal(1:n), Diagonal(ones(n)))) + + # currently falls back to two-term * + @test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n)), Diagonal(1:n)) isa Diagonal +end + end # module TestDiagonal