Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new broadcast API #348

Closed
wants to merge 17 commits into from
Closed
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 72 additions & 39 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,77 @@
## broadcast! ##
################

import Base.Broadcast:
if VERSION < v"0.7.0-DEV.2638"
## Old Broadcast API ##
import Base.Broadcast:
_containertype, promote_containertype, broadcast_indices,
broadcast_c, broadcast_c!

# Add StaticArray as a new output type in Base.Broadcast promotion machinery.
# This isn't the precise output type, just a placeholder to return from
# promote_containertype, which will control dispatch to our broadcast_c.
_containertype(::Type{<:StaticArray}) = StaticArray
_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray

# With the above, the default promote_containertype gives reasonable defaults:
# StaticArray, StaticArray -> StaticArray
# Array, StaticArray -> Array
#
# We could be more precise about the latter, but this isn't really possible
# without using Array{N} rather than Array in Base's promote_containertype.
#
# Base also has broadcast with tuple + Array, but while implementing this would
# be consistent with Base, it's not exactly clear it's a good idea when you can
# just use an SVector instead?
promote_containertype(::Type{StaticArray}, ::Type{Any}) = StaticArray
promote_containertype(::Type{Any}, ::Type{StaticArray}) = StaticArray

broadcast_indices(::Type{StaticArray}, A) = indices(A)


# Override for when output type is deduced to be a StaticArray.
@inline function broadcast_c(f, ::Type{StaticArray}, as...)
_broadcast(f, broadcast_sizes(as...), as...)
# Add StaticArray as a new output type in Base.Broadcast promotion machinery.
# This isn't the precise output type, just a placeholder to return from
# promote_containertype, which will control dispatch to our broadcast_c.
_containertype(::Type{<:StaticArray}) = StaticArray
_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray

# With the above, the default promote_containertype gives reasonable defaults:
# StaticArray, StaticArray -> StaticArray
# Array, StaticArray -> Array
#
# We could be more precise about the latter, but this isn't really possible
# without using Array{N} rather than Array in Base's promote_containertype.
#
# Base also has broadcast with tuple + Array, but while implementing this would
# be consistent with Base, it's not exactly clear it's a good idea when you can
# just use an SVector instead?
promote_containertype(::Type{StaticArray}, ::Type{Any}) = StaticArray
promote_containertype(::Type{Any}, ::Type{StaticArray}) = StaticArray

# Override for when output type is deduced to be a StaticArray.
@inline function broadcast_c(f, ::Type{StaticArray}, as...)
_broadcast(f, broadcast_sizes(as...), as...)
end

# TODO: This signature could be relaxed to (::Any, ::Type{StaticArray}, ::Type, ...), though
# we'd need to rework how _broadcast!() and broadcast_sizes() interact with normal AbstractArray.
@inline function broadcast_c!(f, ::Type{StaticArray}, ::Type{StaticArray}, dest, as...)
_broadcast!(f, Size(dest), dest, broadcast_sizes(as...), as...)
end
else
## New Broadcast API ##
import Base.Broadcast:
BroadcastStyle, AbstractArrayStyle, broadcast

# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
# A constructor that changes the style parameter N (array dimension) is also required
struct StaticArrayStyle{N} <: AbstractArrayStyle{N} end
StaticArrayStyle{M}(::Val{N}) where {M,N} = StaticArrayStyle{N}()

BroadcastStyle(::Type{<:StaticArray{D, T, N}}) where {D, T, N} = StaticArrayStyle{N}()

# Fix Precedence: Make StaticArray - Array -> Array
BroadcastStyle(::StaticArrayStyle{M}, ::Broadcast.VectorStyle) where M = Broadcast.Unknown()
BroadcastStyle(::StaticArrayStyle{M}, ::Broadcast.MatrixStyle) where M = Broadcast.Unknown()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VectorStyle and MatrixStyle are unfortunate (and hopefully temporary) objects because we don't trust that sparse broadcasting generalizes correctly beyond Array (see JuliaLang/julia#23939 (comment)). However, this suggests we'll have to support them throughout julia 1.x. 😢

Long-term the one we want to implement would be with DefaultArrayStyle. Instead of returning Unknown, it might be best to do something like this (untested):

BroadcastStyle(::StaticArrayStyle{M}, ::Broadcast.DefaultArrayStyle{N}) where {M,N} =
    DefaultArrayStyle(Broadcast._max(Val(M), Val(N))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey Tim, thanks for reviewing! Your designs are a pleasure to work with.

So you suggest including this rule, in addition to the VectorStyle and MatrixStyle ones, to take over once the latter are removed, yes? Done.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. Also the DefaultArrayStyle is necessary even now for dimensions higher than 2.

Thanks for the kind words about the API. I think it's also fair to say that unfortunately they don't help StaticArrays that much; they're more oriented around improving the API for mutable arrays. But at least it's documented, and hopefully that's worth something.

Really grateful to you for tackling this!


# Add a broadcast method that calls the @generated routine
@inline function broadcast(f, ::StaticArrayStyle, ::Void, ::Void, As...)
_broadcast(f, broadcast_sizes(As...), As...)
end

# Add a specialized broadcast! method that overrides the Base fallback and calls the old routine
# TODO: This signature could be relaxed to (::AbstractArray, ::Vararg{StaticArray,N}), ::Type, ...), though
# we'd need to rework how _broadcast!() and broadcast_sizes() interact with normal AbstractArray.
@inline function broadcast!(f, C::StaticArray, As::Vararg{StaticArray,N}) where {N}
_broadcast!(f, Size(C), C, broadcast_sizes(As...), As...)
end
end


##############################################
## Old broadcast machinery for StaticArrays ##
##############################################

broadcast_indices(A::StaticArray) = indices(A)

@inline broadcast_sizes(a::RowVector{<:Any,<:SVector}, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...)
Expand Down Expand Up @@ -113,21 +153,14 @@ end

if VERSION < v"0.7.0-DEV"
# Workaround for #329
@inline function Base.broadcast(f, ::Type{T}, a::StaticArray) where {T}
map(x->f(T,x), a)
end
end

################
## broadcast! ##
################

# TODO: This signature could be relaxed to (::Any, ::Type{StaticArray}, ::Type, ...), though
# we'd need to rework how _broadcast!() and broadcast_sizes() interact with normal AbstractArray.
@inline function broadcast_c!(f, ::Type{StaticArray}, ::Type{StaticArray}, dest, as...)
_broadcast!(f, Size(dest), dest, broadcast_sizes(as...), as...)
@inline function Base.broadcast(f, ::Type{T}, a::StaticArray) where {T}
map(x->f(T,x), a)
end
end

###############################################
## Old broadcast! machinery for StaticArrays ##
###############################################

@generated function _broadcast!(f, ::Size{newsize}, dest::StaticArray, s::Tuple{Vararg{Size}}, as...) where {newsize}
sizes = [sz.parameters[1] for sz ∈ s.parameters]
Expand Down