diff --git a/base/broadcast.jl b/base/broadcast.jl index 6e2809daa2a48..e07ccab2443bc 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -162,6 +162,32 @@ BroadcastStyle(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where N = a BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} = typeof(a)(_max(Val(M),Val(N))) +# FIXME +# The following definitions are necessary to limit SparseArray broadcasting to "plain Arrays" +# (see https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382). +# They should be deleted once the sparse broadcast infrastucture is capable of handling +# arbitrary AbstractArrays. +struct VectorStyle <: AbstractArrayStyle{1} end +struct MatrixStyle <: AbstractArrayStyle{2} end +const VMStyle = Union{VectorStyle,MatrixStyle} +# These lose to DefaultArrayStyle +VectorStyle(::Val{N}) where N = DefaultArrayStyle{N}() +MatrixStyle(::Val{N}) where N = DefaultArrayStyle{N}() + +BroadcastStyle(::Type{<:Vector}) = VectorStyle() +BroadcastStyle(::Type{<:Matrix}) = MatrixStyle() + +BroadcastStyle(::MatrixStyle, ::VectorStyle) = MatrixStyle() +BroadcastStyle(a::AbstractArrayStyle{Any}, ::VectorStyle) = a +BroadcastStyle(a::AbstractArrayStyle{Any}, ::MatrixStyle) = a +BroadcastStyle(a::AbstractArrayStyle{N}, ::VectorStyle) where N = typeof(a)(_max(Val(N), Val(1))) +BroadcastStyle(a::AbstractArrayStyle{N}, ::MatrixStyle) where N = typeof(a)(_max(Val(N), Val(2))) +BroadcastStyle(::VectorStyle, ::DefaultArrayStyle{N}) where N = DefaultArrayStyle(_max(Val(N), Val(1))) +BroadcastStyle(::MatrixStyle, ::DefaultArrayStyle{N}) where N = DefaultArrayStyle(_max(Val(N), Val(2))) +# to avoid the VectorStyle(::Val) constructor we also need the following +BroadcastStyle(::VectorStyle, ::MatrixStyle) = MatrixStyle() +# end FIXME + ## Allocating the output container """ broadcast_similar(f, ::BroadcastStyle, ::Type{ElType}, inds, As...) @@ -181,6 +207,17 @@ broadcast_similar(f, ::ArrayConflict, ::Type{ElType}, inds::Indices, As...) wher broadcast_similar(f, ::ArrayConflict, ::Type{Bool}, inds::Indices, As...) = similar(BitArray, inds) +# FIXME: delete when we get rid of VectorStyle and MatrixStyle +broadcast_similar(f, ::VectorStyle, ::Type{ElType}, inds::Indices{1}, As...) where ElType = + similar(Vector{ElType}, inds) +broadcast_similar(f, ::MatrixStyle, ::Type{ElType}, inds::Indices{2}, As...) where ElType = + similar(Matrix{ElType}, inds) +broadcast_similar(f, ::VectorStyle, ::Type{Bool}, inds::Indices{1}, As...) = + similar(BitArray, inds) +broadcast_similar(f, ::MatrixStyle, ::Type{Bool}, inds::Indices{2}, As...) = + similar(BitArray, inds) +# end FIXME + ## Computing the result's indices. Most types probably won't need to specialize this. broadcast_indices() = () broadcast_indices(::Type{T}) where T = () @@ -582,7 +619,7 @@ Nullable{Complex{Float64}}() broadcast(f, s, combine_eltypes(f, A, Bs...), combine_indices(A, Bs...), A, Bs...) -const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict} +const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict,VectorStyle,MatrixStyle} @inline function broadcast(f, s::NonleafHandlingTypes, ::Type{ElType}, inds::Indices, As...) where ElType if !Base._isleaftype(ElType) diff --git a/base/sparse/higherorderfns.jl b/base/sparse/higherorderfns.jl index 1c551fe446aba..447d247746838 100644 --- a/base/sparse/higherorderfns.jl +++ b/base/sparse/higherorderfns.jl @@ -989,9 +989,18 @@ PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse() Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() -Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse() -Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse() -Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse() +# FIXME: switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle +# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse() +# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse() +# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse() +BroadcastStyle(::Type{<:Base.RowVector{T,<:Vector}}) where T = Broadcast.MatrixStyle() # RowVector not yet defined when broadcast.jl loaded +Broadcast.BroadcastStyle(::SPVM, ::Broadcast.VectorStyle) = PromoteToSparse() +Broadcast.BroadcastStyle(::SPVM, ::Broadcast.MatrixStyle) = PromoteToSparse() +Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.DefaultArrayStyle{N}) where N = + Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(1))) +Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) where N = + Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(2))) +# end FIXME broadcast(f, ::PromoteToSparse, ::Void, ::Void, As::Vararg{Any,N}) where {N} = broadcast(f, map(_sparsifystructured, As)...)