Skip to content

Commit

Permalink
Better mapreduce for Broadcasted
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Feb 9, 2019
1 parent 51cd56f commit f4c68ee
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 14 deletions.
14 changes: 13 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
# methods that instead specialize on `BroadcastStyle`,
# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle})

struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple}
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted
f::F
args::Args
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`)
Expand Down Expand Up @@ -219,6 +219,18 @@ argtype(bc::Broadcasted) = argtype(typeof(bc))
_eachindex(t::Tuple{Any}) = t[1]
_eachindex(t::Tuple) = CartesianIndices(t)

_combined_indexstyle(::Type{Tuple{A}}) where A = IndexStyle(A)
_combined_indexstyle(::Type{TT}) where TT =
IndexStyle(IndexStyle(Base.tuple_type_head(TT)),
_combined_indexstyle(Base.tuple_type_tail(TT)))

Base.IndexStyle(bc::Broadcasted) = IndexStyle(typeof(bc))
Base.IndexStyle(::Type{<:Broadcasted{<:Any,Nothing,<:Any,Args}}) where {Args <: Tuple} =
_combined_indexstyle(Args)
Base.IndexStyle(::Type{<:Broadcasted{<:Any}}) = IndexCartesian()

Base.LinearIndices(bc::Broadcasted) = LinearIndices(eachindex(bc))

Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N

Expand Down
20 changes: 12 additions & 8 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ else
const SmallUnsigned = Union{UInt8,UInt16,UInt32}
end

abstract type AbstractBroadcasted end
const AbstractArrayOrBroadcasted = Union{AbstractArray, AbstractBroadcasted}

"""
Base.add_sum(x, y)
Expand Down Expand Up @@ -152,7 +155,8 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...)

# This is a generic implementation of `mapreduce_impl()`,
# certain `op` (e.g. `min` and `max`) may have their own specialized versions.
@noinline function mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer, blksize::Int)
@noinline function mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted,
ifirst::Integer, ilast::Integer, blksize::Int)
if ifirst == ilast
@inbounds a1 = A[ifirst]
return mapreduce_first(f, op, a1)
Expand All @@ -175,7 +179,7 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...)
end
end

mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer) =
mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted, ifirst::Integer, ilast::Integer) =
mapreduce_impl(f, op, A, ifirst, ilast, pairwise_blocksize(f, op))

"""
Expand Down Expand Up @@ -296,13 +300,13 @@ The default is `reduce_first(op, f(x))`.
"""
mapreduce_first(f, op, x) = reduce_first(op, f(x))

_mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, IndexStyle(A), A)
_mapreduce(f, op, A::AbstractArrayOrBroadcasted) = _mapreduce(f, op, IndexStyle(A), A)

function _mapreduce(f, op, ::IndexLinear, A::AbstractArray{T}) where T
function _mapreduce(f, op, ::IndexLinear, A::AbstractArrayOrBroadcasted)
inds = LinearIndices(A)
n = length(inds)
if n == 0
return mapreduce_empty(f, op, T)
return mapreduce_empty_iter(f, op, A, IteratorEltype(A))
elseif n == 1
@inbounds a1 = A[first(inds)]
return mapreduce_first(f, op, a1)
Expand All @@ -323,7 +327,7 @@ end

mapreduce(f, op, a::Number) = mapreduce_first(f, op, a)

_mapreduce(f, op, ::IndexCartesian, A::AbstractArray) = mapfoldl(f, op, A)
_mapreduce(f, op, ::IndexCartesian, A::AbstractArrayOrBroadcasted) = mapfoldl(f, op, A)

"""
reduce(op, itr; [init])
Expand Down Expand Up @@ -473,7 +477,7 @@ isgoodzero(::typeof(max), x) = isbadzero(min, x)
isgoodzero(::typeof(min), x) = isbadzero(max, x)

function mapreduce_impl(f, op::Union{typeof(max), typeof(min)},
A::AbstractArray, first::Int, last::Int)
A::AbstractArrayOrBroadcasted, first::Int, last::Int)
a1 = @inbounds A[first]
v1 = mapreduce_first(f, op, a1)
v2 = v3 = v4 = v1
Expand Down Expand Up @@ -746,7 +750,7 @@ function count(pred, itr)
end
return n
end
function count(pred, a::AbstractArray)
function count(pred, a::AbstractArrayOrBroadcasted)
n = 0
for i in eachindex(a)
@inbounds n += pred(a[i])::Bool
Expand Down
13 changes: 8 additions & 5 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,19 @@ julia> mapreduce(isodd, |, a, dims=1)
1 1 1 1
```
"""
mapreduce(f, op, A::AbstractArray; dims=:, kw...) = _mapreduce_dim(f, op, kw.data, A, dims)
mapreduce(f, op, A::AbstractArrayOrBroadcasted; dims=:, kw...) =
_mapreduce_dim(f, op, kw.data, A, dims)

_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArray, ::Colon) = mapfoldl(f, op, A; nt...)
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArrayOrBroadcasted, ::Colon) =
mapfoldl(f, op, A; nt...)

_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArray, ::Colon) = _mapreduce(f, op, IndexStyle(A), A)
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArrayOrBroadcasted, ::Colon) =
_mapreduce(f, op, IndexStyle(A), A)

_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArray, dims) =
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArrayOrBroadcasted, dims) =
mapreducedim!(f, op, reducedim_initarray(A, dims, nt.init), A)

_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArray, dims) =
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArrayOrBroadcasted, dims) =
mapreducedim!(f, op, reducedim_init(f, op, A, dims), A)

"""
Expand Down

0 comments on commit f4c68ee

Please sign in to comment.