From 243ebc389eec78f060a47c53375b3fd48c462d76 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 7 Apr 2024 19:52:09 +0530 Subject: [PATCH] Support broadcasting over structured block matrices (#53909) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix https://github.com/JuliaLang/julia/issues/48664 After this, broadcasting over structured block matrices with matrix-valued elements works: ```julia julia> D = Diagonal([[1 2; 3 4], [5 6; 7 8]]) 2×2 Diagonal{Matrix{Int64}, Vector{Matrix{Int64}}}: [1 2; 3 4] ⋅ ⋅ [5 6; 7 8] julia> D .+ D 2×2 Diagonal{Matrix{Int64}, Vector{Matrix{Int64}}}: [2 4; 6 8] ⋅ ⋅ [10 12; 14 16] julia> cos.(D) 2×2 Matrix{Matrix{Float64}}: [0.855423 -0.110876; -0.166315 0.689109] [1.0 0.0; 0.0 1.0] [1.0 0.0; 0.0 1.0] [0.928384 -0.069963; -0.0816235 0.893403] ``` Such operations show up when using `BlockArrays`. The implementation is a bit hacky as it uses `0I` as the zero element in `fzero`, which isn't really the correct zero if the blocks are rectangular. Nonetheless, this works, as `fzero` is only used to determine if the structure is preserved. --- .../LinearAlgebra/src/structuredbroadcast.jl | 6 +- stdlib/LinearAlgebra/test/special.jl | 1 + .../LinearAlgebra/test/structuredbroadcast.jl | 58 +++++++++++++++++++ 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl index dfc58aabc63ec..6114c9a796390 100644 --- a/stdlib/LinearAlgebra/src/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl @@ -8,8 +8,8 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}() StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}() -const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular} -for ST in Base.uniontypes(StructuredMatrix) +const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}} +for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular) @eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}()) end @@ -133,6 +133,7 @@ fails as `zero(::Tuple{Int})` is not defined. However, iszerodefined(::Type) = false iszerodefined(::Type{<:Number}) = true iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T) +iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T) fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0)) # Like sparse matrices, we assume that the zero-preservation property of a broadcasted @@ -144,6 +145,7 @@ fzero(::Type{T}) where T = T fzero(r::Ref) = r[] fzero(t::Tuple{Any}) = t[1] fzero(S::StructuredMatrix) = zero(eltype(S)) +fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = haszero(T) ? zero(T)*I : missing fzero(x) = missing function fzero(bc::Broadcast.Broadcasted) args = map(fzero, bc.args) diff --git a/stdlib/LinearAlgebra/test/special.jl b/stdlib/LinearAlgebra/test/special.jl index 6465927db5235..a5198892ff995 100644 --- a/stdlib/LinearAlgebra/test/special.jl +++ b/stdlib/LinearAlgebra/test/special.jl @@ -111,6 +111,7 @@ Random.seed!(1) struct TypeWithZero end Base.promote_rule(::Type{TypeWithoutZero}, ::Type{TypeWithZero}) = TypeWithZero Base.convert(::Type{TypeWithZero}, ::TypeWithoutZero) = TypeWithZero() + Base.zero(x::Union{TypeWithoutZero, TypeWithZero}) = zero(typeof(x)) Base.zero(::Type{<:Union{TypeWithoutZero, TypeWithZero}}) = TypeWithZero() LinearAlgebra.symmetric(::TypeWithoutZero, ::Symbol) = TypeWithoutZero() LinearAlgebra.symmetric_type(::Type{TypeWithoutZero}) = TypeWithoutZero diff --git a/stdlib/LinearAlgebra/test/structuredbroadcast.jl b/stdlib/LinearAlgebra/test/structuredbroadcast.jl index 02d7ead55af17..3767fc10055f2 100644 --- a/stdlib/LinearAlgebra/test/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/test/structuredbroadcast.jl @@ -280,4 +280,62 @@ end # structured broadcast with function returning non-number type @test tuple.(Diagonal([1, 2])) == [(1,) (0,); (0,) (2,)] +@testset "broadcast over structured matrices with matrix elements" begin + function standardbroadcastingtests(D, T) + M = [x for x in D] + Dsum = D .+ D + @test Dsum isa T + @test Dsum == M .+ M + Dcopy = copy.(D) + @test Dcopy isa T + @test Dcopy == D + Df = float.(D) + @test Df isa T + @test Df == D + @test eltype(eltype(Df)) <: AbstractFloat + @test (x -> (x,)).(D) == (x -> (x,)).(M) + @test (x -> 1).(D) == ones(Int,size(D)) + @test all(==(2), ndims.(D)) + @test_throws MethodError size.(D) + end + @testset "Diagonal" begin + @testset "square" begin + A = [1 3; 2 4] + D = Diagonal([A, A]) + standardbroadcastingtests(D, Diagonal) + @test sincos.(D) == sincos.(Matrix{eltype(D)}(D)) + M = [x for x in D] + @test cos.(D) == cos.(M) + end + + @testset "different-sized square blocks" begin + D = Diagonal([ones(3,3), fill(3.0,2,2)]) + standardbroadcastingtests(D, Diagonal) + end + + @testset "rectangular blocks" begin + D = Diagonal([ones(Bool,3,4), ones(Bool,2,3)]) + standardbroadcastingtests(D, Diagonal) + end + + @testset "incompatible sizes" begin + A = reshape(1:12, 4, 3) + B = reshape(1:12, 3, 4) + D1 = Diagonal(fill(A, 2)) + D2 = Diagonal(fill(B, 2)) + @test_throws DimensionMismatch D1 .+ D2 + end + end + @testset "Bidiagonal" begin + A = [1 3; 2 4] + B = Bidiagonal(fill(A,3), fill(A,2), :U) + standardbroadcastingtests(B, Bidiagonal) + end + @testset "UpperTriangular" begin + A = [1 3; 2 4] + U = UpperTriangular([(i+j)*A for i in 1:3, j in 1:3]) + standardbroadcastingtests(U, UpperTriangular) + end +end + end