Skip to content

Commit

Permalink
Fix length calculation of broadcast over tuples (#23887)
Browse files Browse the repository at this point in the history
(cherry picked from commit 698ef27)
  • Loading branch information
pabloferz authored and ararslan committed Nov 14, 2017
1 parent 9de323f commit a2a45bb
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
27 changes: 24 additions & 3 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ end
end

Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
# `(x,)`, where `x` is a scalar, broadcasts the same way as `[x]` or `x`
Base.@propagate_inbounds _broadcast_getindex(::Type{Tuple}, A::Tuple{Any}, I) = A[1]
Base.@propagate_inbounds _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
Base.@propagate_inbounds _broadcast_getindex(::ScalarType, A, I) = A
Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
Expand Down Expand Up @@ -334,13 +336,32 @@ end
end
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
@inline broadcast_c(f, ::Type{Tuple}, A, Bs...) =
tuplebroadcast(f, first_tuple(A, Bs...), A, Bs...)
tuplebroadcast(f, tuplebroadcast_maxtuple(A, Bs...), A, Bs...)
@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} =
ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val{N})
@inline tuplebroadcast(f, ::NTuple{N,Any}, ::Type{T}, As...) where {N,T} =
ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val{N})
first_tuple(A::Tuple, Bs...) = A
@inline first_tuple(A, Bs...) = first_tuple(Bs...)
# When the result of broadcast is a tuple it can only come from mixing n-tuples
# of the same length with scalars and 1-tuples. So, in order to have a
# type-stable broadcast, we need to find a tuple of maximum length (except when
# there are only scalars, empty tuples and 1-tuples, in which case the
# returned value will be an empty tuple).
# The following methods compare broadcast arguments pairwise to determine the
# length of the final tuple.
tuplebroadcast_maxtuple(A, B) =
_tuplebroadcast_maxtuple(containertype(A), containertype(B), A, B)
@inline tuplebroadcast_maxtuple(A, Bs...) =
tuplebroadcast_maxtuple(A, tuplebroadcast_maxtuple(Bs...))
tuplebroadcast_maxtuple(A::NTuple{N,Any}, ::NTuple{N,Any}...) where {N} = A
# Here we use the containertype trait to easier disambiguate between methods
_tuplebroadcast_maxtuple(::Type{Any}, ::Type{Any}, A, B) = (nothing,)
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Any}, A, B) = A
_tuplebroadcast_maxtuple(::Type{Any}, ::Type{Tuple}, A, B) = B
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A, B::Tuple{Any}) = A
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A::Tuple{Any}, B) = B
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A::Tuple{Any}, ::Tuple{Any}) = A
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A, B) =
throw(DimensionMismatch("tuples could not be broadcast to a common size"))
tuplebroadcast_getargs(::Tuple{}, k) = ()
@inline tuplebroadcast_getargs(As, k) =
(_broadcast_getindex(first(As), k), tuplebroadcast_getargs(tail(As), k)...)
Expand Down
16 changes: 12 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,16 @@ end
Nullable("hello"))
end

# Issue #21291
let t = (0, 1, 2)
o = 1
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
@testset "broadcast resulting in tuples" begin
# Issue #21291
let t = (0, 1, 2)
o = 1
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
end

# Issue #23647
@test (1, 2, 3) .+ (1,) == (1,) .+ (1, 2, 3) == (2, 3, 4)
@test (1,) .+ () == () .+ (1,) == () .+ () == ()
@test (1, 2) .+ (1, 2) == (2, 4)
@test_throws DimensionMismatch (1, 2) .+ (1, 2, 3)
end

0 comments on commit a2a45bb

Please sign in to comment.