Skip to content

Commit

Permalink
Extend sparse broadcast to VecOrMats
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Jan 20, 2017
1 parent 4f5510f commit f513300
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 24 deletions.
15 changes: 10 additions & 5 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,20 @@ _containertype(::Type) = Any
_containertype{T<:Ptr}(::Type{T}) = Any
_containertype{T<:Tuple}(::Type{T}) = Tuple
_containertype{T<:Ref}(::Type{T}) = Array
_containertype{T<:AbstractVecOrMat}(::Type{T}) = VecOrMat
_containertype{T<:AbstractArray}(::Type{T}) = Array
_containertype{T<:Nullable}(::Type{T}) = Nullable
containertype(x) = _containertype(typeof(x))
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...))

promote_containertype(::Type{Array}, ::Type{Array}) = Array
promote_containertype(::Type{Array}, ct) = Array
promote_containertype(ct, ::Type{Array}) = Array
promote_containertype(ct1, ct2) = Array
promote_containertype(::Type{Array}, ::Type{VecOrMat}) = Array
promote_containertype(::Type{VecOrMat}, ::Type{Array}) = Array
promote_containertype(::Type{VecOrMat}, ::Type{Tuple}) = Array
promote_containertype(::Type{Tuple}, ::Type{VecOrMat}) = Array
promote_containertype{T<:Array}(::Type{T}, ::ScalarType) = T
promote_containertype{T<:Array}(::ScalarType, ::Type{T}) = T
promote_containertype(::Type{Tuple}, ::ScalarType) = Tuple
promote_containertype(::ScalarType, ::Type{Tuple}) = Tuple
promote_containertype(::Type{Any}, ::Type{Nullable}) = Nullable
Expand All @@ -47,7 +52,7 @@ broadcast_indices(A) = broadcast_indices(containertype(A), A)
broadcast_indices(::ScalarType, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A::Ref) = ()
broadcast_indices(::Type{Array}, A) = indices(A)
broadcast_indices{T<:Array}(::Type{T}, A) = indices(A)
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
Expand Down Expand Up @@ -283,7 +288,7 @@ eltypestuple(a, b...) = (Base.@_pure_meta; Tuple{eltypestuple(a).types..., eltyp
_broadcast_eltype(f, A, Bs...) = Base._return_type(f, eltypestuple(A, Bs...))

# broadcast methods that dispatch on the type of the final container
@inline function broadcast_c(f, ::Type{Array}, A, Bs...)
@inline function broadcast_c{T<:Array}(f, ::Type{T}, A, Bs...)
T = _broadcast_eltype(f, A, Bs...)
shape = broadcast_indices(A, Bs...)
iter = CartesianRange(shape)
Expand Down
34 changes: 17 additions & 17 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import Base.Broadcast: _containertype, promote_containertype,
broadcast_indices, broadcast_c, broadcast_c!

using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray,
AbstractSparseMatrix, AbstractSparseVector, indtype

# This module is organized as follows:
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
Expand Down Expand Up @@ -863,6 +864,8 @@ _containertype{T<:SparseVecOrMat}(::Type{T}) = AbstractSparseArray
# combinations of sparse arrays with broadcast scalars should yield sparse arrays
promote_containertype(::Type{Any}, ::Type{AbstractSparseArray}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{Any}) = AbstractSparseArray
promote_containertype(::Type{VecOrMat}, ::Type{AbstractSparseArray}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{VecOrMat}) = AbstractSparseArray
# combinations of sparse arrays with anything else should fall back to generic dense broadcast
promote_containertype(::Type{Array}, ::Type{AbstractSparseArray}) = Array
promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = Array
Expand All @@ -871,11 +874,11 @@ promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = Array

# broadcast[!] entry points for combinations of sparse arrays and other (scalar) types
@inline function broadcast_c{N}(f, ::Type{AbstractSparseArray}, mixedargs::Vararg{Any,N})
parevalf, passedargstup = capturescalars(f, mixedargs)
parevalf, passedargstup = capturescalars(f, map(_sparsifystructured, mixedargs))
return broadcast(parevalf, passedargstup...)
end
@inline function broadcast_c!{N}(f, ::Type{AbstractSparseArray}, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N})
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
parevalf, passedsrcargstup = capturescalars(f, map(_sparsifystructured, mixedsrcargs))
return broadcast!(parevalf, dest, passedsrcargstup...)
end
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
Expand Down Expand Up @@ -913,33 +916,30 @@ _containertype{T<:Diagonal}(::Type{T}) = StructuredArray
_containertype{T<:Bidiagonal}(::Type{T}) = StructuredArray
_containertype{T<:Tridiagonal}(::Type{T}) = StructuredArray
_containertype{T<:SymTridiagonal}(::Type{T}) = StructuredArray
promote_containertype(::Type{StructuredArray}, ::Type{StructuredArray}) = StructuredArray
# combinations involving sparse arrays continue in the structured array funnel
promote_containertype(::Type{StructuredArray}, ::Type{AbstractSparseArray}) = StructuredArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{StructuredArray}) = StructuredArray
promote_containertype(::Type{StructuredArray}, ::Type{AbstractSparseArray}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{StructuredArray}) = AbstractSparseArray
promote_containertype(::Type{StructuredArray}, ::Type{VecOrMat}) = AbstractSparseArray
promote_containertype(::Type{VecOrMat}, ::Type{StructuredArray}) = AbstractSparseArray
# combinations involving scalars continue in the structured array funnel
promote_containertype(::Type{StructuredArray}, ::Type{Any}) = StructuredArray
promote_containertype(::Type{Any}, ::Type{StructuredArray}) = StructuredArray
promote_containertype(::Type{StructuredArray}, ::Type{Any}) = AbstractSparseArray
promote_containertype(::Type{Any}, ::Type{StructuredArray}) = AbstractSparseArray
# combinations involving arrays divert to the generic array code
promote_containertype(::Type{StructuredArray}, ::Type{Array}) = Array
promote_containertype(::Type{Array}, ::Type{StructuredArray}) = Array
# combinations involving tuples divert to the generic array code
promote_containertype(::Type{StructuredArray}, ::Type{Tuple}) = Array
promote_containertype(::Type{Tuple}, ::Type{StructuredArray}) = Array

# for combinations involving sparse/structured arrays and scalars only,
# promote all structured arguments to sparse and then rebroadcast
@inline broadcast_c{N}(f, ::Type{StructuredArray}, As::Vararg{Any,N}) =
broadcast(f, map(_sparsifystructured, As)...)
@inline broadcast_c!{N}(f, ::Type{AbstractSparseArray}, ::Type{StructuredArray}, C, B, As::Vararg{Any,N}) =
broadcast!(f, C, _sparsifystructured(B), map(_sparsifystructured, As)...)
@inline broadcast_c!{N}(f, CT::Type, ::Type{StructuredArray}, C, B, As::Vararg{Any,N}) =
broadcast_c!(f, CT, Array, C, B, As...)
@inline _sparsifystructured(S::SymTridiagonal) = SparseMatrixCSC(S)
@inline _sparsifystructured(T::Tridiagonal) = SparseMatrixCSC(T)
@inline _sparsifystructured(B::Bidiagonal) = SparseMatrixCSC(B)
@inline _sparsifystructured(D::Diagonal) = SparseMatrixCSC(D)
@inline _sparsifystructured(A::AbstractSparseArray) = A
@inline _sparsifystructured(A::AbstractMatrix) = SparseMatrixCSC(A)
@inline _sparsifystructured(A::AbstractVector) = SparseVector(A)
@inline _sparsifystructured(A::AbstractSparseMatrix) = SparseMatrixCSC(A)
@inline _sparsifystructured(A::AbstractSparseVector) = SparseVector(A)
@inline _sparsifystructured(S::SparseVecOrMat) = S
@inline _sparsifystructured(x) = x


Expand Down
4 changes: 2 additions & 2 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,10 @@ Base.size(A::Array19745) = size(A.data)

Base.Broadcast._containertype{T<:Array19745}(::Type{T}) = Array19745

# This way of defining promote_containertype methods is discouraged. The recommended
# way is by defining definitions for combinations of tight containers types
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ct) = Array19745
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(ct, ::Type{Array19745}) = Array19745

Base.Broadcast.broadcast_indices(::Type{Array19745}, A) = indices(A)
Expand Down

0 comments on commit f513300

Please sign in to comment.