Skip to content

Commit

Permalink
Extend sparse broadcast to VecOrMats
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Mar 9, 2017
1 parent 8436f45 commit 3258952
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 36 deletions.
19 changes: 12 additions & 7 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,24 @@ _containertype(::Type) = Any
_containertype(::Type{<:Ptr}) = Any
_containertype(::Type{<:Tuple}) = Tuple
_containertype(::Type{<:Ref}) = Array
_containertype(::Type{<:AbstractVecOrMat}) = VecOrMat
_containertype(::Type{<:AbstractArray}) = Array
_containertype(::Type{<:Nullable}) = Nullable
containertype(x) = _containertype(typeof(x))
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...))

promote_containertype(::Type{Array}, ::Type{Array}) = Array
promote_containertype(::Type{Array}, ct) = Array
promote_containertype(ct, ::Type{Array}) = Array
promote_containertype(ct1, ct2) = Array
promote_containertype{T<:AbstractArray}(::Type{T}, ::Type{Tuple}) = Array
promote_containertype{T<:AbstractArray}(::Type{Tuple}, ::Type{T}) = Array
promote_containertype(::Type{Array}, ::Type{VecOrMat}) = Array
promote_containertype(::Type{VecOrMat}, ::Type{Array}) = Array
promote_containertype(::Type{Tuple}, ::ScalarType) = Tuple
promote_containertype(::ScalarType, ::Type{Tuple}) = Tuple
promote_containertype(::Type{Any}, ::Type{Nullable}) = Nullable
promote_containertype(::Type{Nullable}, ::Type{Any}) = Nullable
promote_containertype(T::Type, ::ScalarType) = T
promote_containertype(::ScalarType, T::Type) = T
promote_containertype{T}(::Type{T}, ::Type{T}) = T

## Calculate the broadcast indices of the arguments, or error if incompatible
Expand All @@ -49,7 +54,7 @@ broadcast_indices(A) = broadcast_indices(containertype(A), A)
broadcast_indices(::ScalarType, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A::Ref) = ()
broadcast_indices(::Type{Array}, A) = indices(A)
broadcast_indices(::Type{<:AbstractArray}, A) = indices(A)
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
Expand Down Expand Up @@ -201,8 +206,8 @@ arguments to `f` unless it is also listed in the `As`,
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
@inline broadcast!{N}(f, C::AbstractArray, A, Bs::Vararg{Any,N}) =
broadcast_c!(f, containertype(C), containertype(A, Bs...), C, A, Bs...)
@inline function broadcast_c!{N}(f, ::Type, ::Type, C, A, Bs::Vararg{Any,N})
broadcast_c!(f, containertype(C, A, Bs...), C, A, Bs...)
@inline function broadcast_c!{N}(f, ::Type, C, A, Bs::Vararg{Any,N})
shape = indices(C)
@boundscheck check_broadcast_indices(shape, A, Bs...)
keeps, Idefaults = map_newindexer(shape, A, Bs)
Expand Down Expand Up @@ -285,7 +290,7 @@ eltypestuple(a, b...) = (Base.@_pure_meta; Tuple{eltypestuple(a).types..., eltyp
_broadcast_eltype(f, A, Bs...) = Base._return_type(f, eltypestuple(A, Bs...))

# broadcast methods that dispatch on the type of the final container
@inline function broadcast_c(f, ::Type{Array}, A, Bs...)
@inline function broadcast_c{T<:Array}(f, ::Type{T}, A, Bs...)
T = _broadcast_eltype(f, A, Bs...)
shape = broadcast_indices(A, Bs...)
iter = CartesianRange(shape)
Expand Down
44 changes: 27 additions & 17 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ module HigherOrderFns
# This module provides higher order functions specialized for sparse arrays,
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
import Base: map, map!, broadcast, broadcast!
import Base.Broadcast: _containertype, promote_containertype,
broadcast_indices, broadcast_c, broadcast_c!
import Base.Broadcast: _containertype, promote_containertype, broadcast_c, broadcast_c!

using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray,
AbstractSparseMatrix, AbstractSparseVector, indtype

# This module is organized as follows:
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
Expand Down Expand Up @@ -940,32 +940,30 @@ end

# (10) broadcast[!] over combinations of broadcast scalars and sparse vectors/matrices

# broadcast shape promotion for combinations of sparse arrays and other types
broadcast_indices(::Type{AbstractSparseArray}, A) = indices(A)
# broadcast container type promotion for combinations of sparse arrays and other types
_containertype(::Type{<:SparseVecOrMat}) = AbstractSparseArray
# combinations of sparse arrays with broadcast scalars should yield sparse arrays
promote_containertype(::Type{Any}, ::Type{AbstractSparseArray}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{Any}) = AbstractSparseArray
# combinations of sparse arrays with tuples should divert to the generic AbstractArray broadcast code
# (we handle combinations involving dense vectors/matrices below)
promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = Array
promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = Array
promote_containertype{T<:AbstractSparseArray}(::Type{VecOrMat}, ::Type{T}) = AbstractSparseArray
promote_containertype{T<:AbstractSparseArray}(::Type{T}, ::Type{VecOrMat}) = AbstractSparseArray
# combinations of sparse arrays with anything else should fall back to generic dense broadcast
promote_containertype{T<:AbstractSparseArray}(::Type{Array}, ::Type{T}) = Array
promote_containertype{T<:AbstractSparseArray}(::Type{T}, ::Type{Array}) = Array

# broadcast[!] entry points for combinations of sparse arrays and other (scalar) types
@inline function broadcast_c{N}(f, ::Type{AbstractSparseArray}, mixedargs::Vararg{Any,N})
parevalf, passedargstup = capturescalars(f, mixedargs)
@inline function broadcast_c{T<:AbstractSparseArray,N}(f, ::Type{T}, mixedargs::Vararg{Any,N})
parevalf, passedargstup = capturescalars(f, map(_sparsifystructured, mixedargs))
return broadcast(parevalf, passedargstup...)
end
@inline function broadcast_c!{N}(f, ::Type{AbstractSparseArray}, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N})
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
@inline function broadcast_c!{T<:AbstractSparseArray,N}(f, ::Type{T}, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N})
parevalf, passedsrcargstup = capturescalars(f, map(_sparsifystructured, mixedsrcargs))
return broadcast!(parevalf, dest, passedsrcargstup...)
end
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
# evaluated f) and a reduced argument tuple (passedargstup) containing only the sparse
# vectors/matrices in mixedargs in their orginal order, and such that the result of
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
@inline capturescalars{N}(f, mixedargs::Vararg{SparseVecOrMat,N}) = (f, mixedargs)
@inline capturescalars(f, mixedargs) =
capturescalars((passed, tofill) -> f(tofill...), (), mixedargs...)
# Recursion cases for capturescalars
Expand Down Expand Up @@ -1002,6 +1000,15 @@ broadcast{Tf,T}(f::Tf, A::SparseMatrixCSC, ::Type{T}) = broadcast(x -> f(x, T),
# (spbroadcast_c[!], spcontainertype, promote_spcontainertype), and then promote
# arguments to sparse as appropriate and rebroadcast.

# structured array container type promotion
struct StructuredArray{Tv,Ti,N} <: AbstractSparseArray{Tv,Ti,N} end
_containertype{T<:Diagonal}(::Type{T}) = StructuredArray
_containertype{T<:Bidiagonal}(::Type{T}) = StructuredArray
_containertype{T<:Tridiagonal}(::Type{T}) = StructuredArray
_containertype{T<:SymTridiagonal}(::Type{T}) = StructuredArray
# combinations involving sparse arrays continue in the structured array funnel
promote_containertype(::Type{StructuredArray}, ::Type{AbstractSparseArray}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{StructuredArray}) = AbstractSparseArray

# first (Broadcast containertype) dispatch layer's promotion logic
struct PromoteToSparse end
Expand Down Expand Up @@ -1078,8 +1085,11 @@ promote_spcontainertype(::FunnelToSparseBC, ::FunnelToSparseBC) = PromoteToSpars

@inline _sparsifystructured(A::AbstractSparseArray) = A
@inline _sparsifystructured(M::StructuredMatrix) = SparseMatrixCSC(M)
@inline _sparsifystructured(M::Matrix) = SparseMatrixCSC(M)
@inline _sparsifystructured(V::Vector) = SparseVector(V)
@inline _sparsifystructured(A::AbstractMatrix) = SparseMatrixCSC(A)
@inline _sparsifystructured(A::AbstractVector) = SparseVector(A)
@inline _sparsifystructured(A::AbstractSparseMatrix) = SparseMatrixCSC(A)
@inline _sparsifystructured(A::AbstractSparseVector) = SparseVector(A)
@inline _sparsifystructured(S::SparseVecOrMat) = S
@inline _sparsifystructured(x) = x


Expand Down
21 changes: 9 additions & 12 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,14 +423,12 @@ Base.size(A::Array19745) = size(A.data)

Base.Broadcast._containertype{T<:Array19745}(::Type{T}) = Array19745

Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ct) = Array19745
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(ct, ::Type{Array19745}) = Array19745

Base.Broadcast.broadcast_indices(::Type{Array19745}, A) = indices(A)
Base.Broadcast.broadcast_indices(::Type{Array19745}, A::Ref) = ()
# Only define promote_containertype methods with tight container types
# (scalars are properly handled by default)
Base.Broadcast.promote_containertype{T<:AbstractArray}(::Type{Array19745}, ::Type{T}) = Array19745
Base.Broadcast.promote_containertype{T<:AbstractArray}(::Type{T}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Tuple}) = Array19745
Base.Broadcast.promote_containertype(::Type{Tuple}, ::Type{Array19745}) = Array19745

getfield19745(x::Array19745) = x.data
getfield19745(x) = x
Expand All @@ -441,10 +439,9 @@ Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...) =
@testset "broadcasting for custom AbstractArray" begin
a = randn(10)
aa = Array19745(a)
@test a .+ 1 == @inferred(aa .+ 1)
@test a .* a' == @inferred(aa .* aa')
@test isa(aa .+ 1, Array19745)
@test isa(aa .* aa', Array19745)
@test a .+ 1 == @inferred(aa .+ 1)::Array19745
@test a .* a' == @inferred(aa .* aa')::Array19745
@test (aa .+ [1])::Array19745 == (aa .+ (1,))::Array19745
end

# broadcast should only "peel off" one container layer
Expand Down

0 comments on commit 3258952

Please sign in to comment.