Skip to content

Commit

Permalink
Extend sparse broadcast to VecOrMats
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Jan 18, 2017
1 parent b8971ea commit d6e85e8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 39 deletions.
24 changes: 15 additions & 9 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using Base: linearindices, tail, OneTo, to_shape,
import Base: broadcast, broadcast!
export broadcast_getindex, broadcast_setindex!, dotview

immutable OneOrTwoD end
typealias ArrayType Union{Type{AbstractArray}, Type{OneOrTwoD}}
typealias ScalarType Union{Type{Any}, Type{Nullable}}

## Broadcasting utilities ##
Expand All @@ -24,16 +26,19 @@ broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X)
_containertype(::Type) = Any
_containertype{T<:Ptr}(::Type{T}) = Any
_containertype{T<:Tuple}(::Type{T}) = Tuple
_containertype{T<:Ref}(::Type{T}) = Array
_containertype{T<:AbstractArray}(::Type{T}) = Array
_containertype{T<:Ref}(::Type{T}) = AbstractArray
_containertype{T<:AbstractArray}(::Type{T}) = AbstractArray
_containertype{T<:AbstractVecOrMat}(::Type{T}) = OneOrTwoD
_containertype{T<:Nullable}(::Type{T}) = Nullable
containertype(x) = _containertype(typeof(x))
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...))

promote_containertype(::Type{Array}, ::Type{Array}) = Array
promote_containertype(::Type{Array}, ct) = Array
promote_containertype(ct, ::Type{Array}) = Array
promote_containertype(ct1, ct2) = AbstractArray
promote_containertype(::ArrayType, ::Type{Tuple}) = AbstractArray
promote_containertype(::Type{Tuple}, ::ArrayType) = AbstractArray
promote_containertype(::ArrayType, ::ScalarType) = AbstractArray
promote_containertype(::ScalarType, ::ArrayType) = AbstractArray
promote_containertype(::Type{Tuple}, ::ScalarType) = Tuple
promote_containertype(::ScalarType, ::Type{Tuple}) = Tuple
promote_containertype(::Type{Any}, ::Type{Nullable}) = Nullable
Expand All @@ -46,8 +51,8 @@ broadcast_indices() = ()
broadcast_indices(A) = broadcast_indices(containertype(A), A)
broadcast_indices(::ScalarType, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A::Ref) = ()
broadcast_indices(::Type{Array}, A) = indices(A)
broadcast_indices(::Type{AbstractArray}, A::Ref) = ()
broadcast_indices(::ArrayType, A) = indices(A)
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
Expand Down Expand Up @@ -125,7 +130,7 @@ end
end

Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
Base.@propagate_inbounds _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
Base.@propagate_inbounds _broadcast_getindex(::Type{AbstractArray}, A::Ref, I) = A[]
Base.@propagate_inbounds _broadcast_getindex(::ScalarType, A, I) = A
Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]

Expand Down Expand Up @@ -283,7 +288,8 @@ eltypestuple(a, b...) = (Base.@_pure_meta; Tuple{eltypestuple(a).types..., eltyp
_broadcast_eltype(f, A, Bs...) = Base._return_type(f, eltypestuple(A, Bs...))

# broadcast methods that dispatch on the type of the final container
@inline function broadcast_c(f, ::Type{Array}, A, Bs...)
@inline broadcast_c(f, ::Type{OneOrTwoD}, A, Bs...) = broadcast_c(f, AbstractArray, A, Bs...)
@inline function broadcast_c(f, ::Type{AbstractArray}, A, Bs...)
T = _broadcast_eltype(f, A, Bs...)
shape = broadcast_indices(A, Bs...)
iter = CartesianRange(shape)
Expand Down
51 changes: 24 additions & 27 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ module HigherOrderFns
# This module provides higher order functions specialized for sparse arrays,
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
import Base: map, map!, broadcast, broadcast!
import Base.Broadcast: _containertype, promote_containertype,
broadcast_indices, broadcast_c, broadcast_c!
import Base.Broadcast: ScalarType, OneOrTwoD, _containertype,
promote_containertype, broadcast_indices, broadcast_c, broadcast_c!

using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype
Expand Down Expand Up @@ -848,21 +848,23 @@ broadcast_indices(::Type{AbstractSparseArray}, A) = indices(A)
# broadcast container type promotion for combinations of sparse arrays and other types
_containertype{T<:SparseVecOrMat}(::Type{T}) = AbstractSparseArray
# combinations of sparse arrays with broadcast scalars should yield sparse arrays
promote_containertype(::Type{Any}, ::Type{AbstractSparseArray}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{Any}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::ScalarType) = AbstractSparseArray
promote_containertype(::ScalarType, ::Type{AbstractSparseArray}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{OneOrTwoD}) = AbstractSparseArray
promote_containertype(::Type{OneOrTwoD}, ::Type{AbstractSparseArray}) = AbstractSparseArray
# combinations of sparse arrays with anything else should fall back to generic dense broadcast
promote_containertype(::Type{Array}, ::Type{AbstractSparseArray}) = Array
promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = Array
promote_containertype(::Type{AbstractSparseArray}, ::Type{Array}) = Array
promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = Array
promote_containertype(::Type{AbstractSparseArray}, ::Type{AbstractArray}) = AbstractArray
promote_containertype(::Type{AbstractArray}, ::Type{AbstractSparseArray}) = AbstractArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = AbstractArray
promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = AbstractArray

# broadcast[!] entry points for combinations of sparse arrays and other (scalar) types
@inline function broadcast_c{N}(f, ::Type{AbstractSparseArray}, mixedargs::Vararg{Any,N})
parevalf, passedargstup = capturescalars(f, mixedargs)
parevalf, passedargstup = capturescalars(f, map(_sparsifystructured, mixedargs))
return broadcast(parevalf, passedargstup...)
end
@inline function broadcast_c!{N}(f, ::Type{AbstractSparseArray}, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N})
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
parevalf, passedsrcargstup = capturescalars(f, map(_sparsifystructured, mixedsrcargs))
return broadcast!(parevalf, dest, passedsrcargstup...)
end
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
Expand Down Expand Up @@ -900,32 +902,27 @@ _containertype{T<:Diagonal}(::Type{T}) = StructuredArray
_containertype{T<:Bidiagonal}(::Type{T}) = StructuredArray
_containertype{T<:Tridiagonal}(::Type{T}) = StructuredArray
_containertype{T<:SymTridiagonal}(::Type{T}) = StructuredArray
promote_containertype(::Type{StructuredArray}, ::Type{StructuredArray}) = StructuredArray
# combinations involving sparse arrays continue in the structured array funnel
promote_containertype(::Type{StructuredArray}, ::Type{AbstractSparseArray}) = StructuredArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{StructuredArray}) = StructuredArray
promote_containertype(::Type{StructuredArray}, ::Type{AbstractSparseArray}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{StructuredArray}) = AbstractSparseArray
promote_containertype(::Type{StructuredArray}, ::Type{OneOrTwoD}) = AbstractSparseArray
promote_containertype(::Type{OneOrTwoD}, ::Type{StructuredArray}) = AbstractSparseArray
# combinations involving scalars continue in the structured array funnel
promote_containertype(::Type{StructuredArray}, ::Type{Any}) = StructuredArray
promote_containertype(::Type{Any}, ::Type{StructuredArray}) = StructuredArray
promote_containertype(::Type{StructuredArray}, ::ScalarType) = AbstractSparseArray
promote_containertype(::ScalarType, ::Type{StructuredArray}) = AbstractSparseArray
# combinations involving arrays divert to the generic array code
promote_containertype(::Type{StructuredArray}, ::Type{Array}) = Array
promote_containertype(::Type{Array}, ::Type{StructuredArray}) = Array
promote_containertype(::Type{StructuredArray}, ::Type{AbstractArray}) = AbstractArray
promote_containertype(::Type{AbstractArray}, ::Type{StructuredArray}) = AbstractArray
# combinations involving tuples divert to the generic array code
promote_containertype(::Type{StructuredArray}, ::Type{Tuple}) = Array
promote_containertype(::Type{Tuple}, ::Type{StructuredArray}) = Array
promote_containertype(::Type{StructuredArray}, ::Type{Tuple}) = AbstractArray
promote_containertype(::Type{Tuple}, ::Type{StructuredArray}) = AbstractArray

# for combinations involving sparse/structured arrays and scalars only,
# promote all structured arguments to sparse and then rebroadcast
@inline broadcast_c{N}(f, ::Type{StructuredArray}, As::Vararg{Any,N}) =
broadcast(f, map(_sparsifystructured, As)...)
@inline broadcast_c!{N}(f, ::Type{AbstractSparseArray}, ::Type{StructuredArray}, C, B, As::Vararg{Any,N}) =
broadcast!(f, C, _sparsifystructured(B), map(_sparsifystructured, As)...)
@inline broadcast_c!{N}(f, CT::Type, ::Type{StructuredArray}, C, B, As::Vararg{Any,N}) =
broadcast_c!(f, CT, Array, C, B, As...)
@inline _sparsifystructured(S::SymTridiagonal) = SparseMatrixCSC(S)
@inline _sparsifystructured(T::Tridiagonal) = SparseMatrixCSC(T)
@inline _sparsifystructured(B::Bidiagonal) = SparseMatrixCSC(B)
@inline _sparsifystructured(D::Diagonal) = SparseMatrixCSC(D)
@inline _sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M)
@inline _sparsifystructured(V::AbstractVector) = SparseVector(V)
@inline _sparsifystructured(A::AbstractSparseArray) = A
@inline _sparsifystructured(x) = x

Expand Down
6 changes: 3 additions & 3 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,10 @@ Base.size(A::Array19745) = size(A.data)

Base.Broadcast._containertype{T<:Array19745}(::Type{T}) = Array19745

# This way of defining promote_containertype methods is discouraged. The recommended
# way is by defining definitions for combinations of tight containers types
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ct) = Array19745
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(ct, ::Type{Array19745}) = Array19745

Base.Broadcast.broadcast_indices(::Type{Array19745}, A) = indices(A)
Expand All @@ -406,7 +406,7 @@ getfield19745(x::Array19745) = x.data
getfield19745(x) = x

Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...) =
Array19745(Base.Broadcast.broadcast_c(f, Array, getfield19745(A), map(getfield19745, Bs)...))
Array19745(Base.Broadcast.broadcast_c(f, AbstractArray, getfield19745(A), map(getfield19745, Bs)...))

@testset "broadcasting for custom AbstractArray" begin
a = randn(10)
Expand Down

0 comments on commit d6e85e8

Please sign in to comment.