Skip to content

Commit

Permalink
Fix length calculation of broadcast over tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Sep 27, 2017
1 parent 674e64b commit 7890971
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
19 changes: 16 additions & 3 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ end
end

Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
Base.@propagate_inbounds _broadcast_getindex(::Type{Tuple}, A::Tuple{Any}, I) = A[1]
Base.@propagate_inbounds _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
Base.@propagate_inbounds _broadcast_getindex(::ScalarType, A, I) = A
Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
Expand Down Expand Up @@ -333,13 +334,25 @@ end
end
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
@inline broadcast_c(f, ::Type{Tuple}, A, Bs...) =
tuplebroadcast(f, first_tuple(A, Bs...), A, Bs...)
tuplebroadcast(f, tuplebroadcast_length(A, Bs...), A, Bs...)
@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} =
ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val(N))
@inline tuplebroadcast(f, ::NTuple{N,Any}, ::Type{T}, As...) where {N,T} =
ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val(N))
first_tuple(A::Tuple, Bs...) = A
@inline first_tuple(A, Bs...) = first_tuple(Bs...)
tuplebroadcast_length(A, B) = (nothing,)
tuplebroadcast_length(A, ::Tuple{}) = ()
tuplebroadcast_length(::Tuple{}, A) = ()
tuplebroadcast_length(::Tuple, ::Tuple{}) = ()
tuplebroadcast_length(::Tuple{}, ::Tuple) = ()
tuplebroadcast_length(::Tuple{Any}, ::Tuple{}) = ()
tuplebroadcast_length(::Tuple{}, ::Tuple{Any}) = ()
tuplebroadcast_length(A::Tuple, ::Tuple{Any}) = A
tuplebroadcast_length(::Tuple{Any}, A::Tuple) = A
tuplebroadcast_length(A::NTuple{N,Any}, ::NTuple{N,Any}...) where {N} = A
tuplebroadcast_length(::Tuple, ::Tuple) =
throw(DimensionMismatch("tuples could not be broadcast to a common size"))
@inline tuplebroadcast_length(A, Bs...) =
tuplebroadcast_length(A, tuplebroadcast_length(Bs...))
tuplebroadcast_getargs(::Tuple{}, k) = ()
@inline tuplebroadcast_getargs(As, k) =
(_broadcast_getindex(first(As), k), tuplebroadcast_getargs(tail(As), k)...)
Expand Down
16 changes: 12 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,16 @@ end
Nullable("hello"))
end

# Issue #21291
let t = (0, 1, 2)
o = 1
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
@testset "broadcast resulting in tuples" begin
# Issue #21291
let t = (0, 1, 2)
o = 1
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
end

# Issue #23647
@test (1, 2, 3) .+ (1,) == (1,) .+ (1, 2, 3) == (2, 3, 4)
@test (1,) .+ () == () .+ (1,) == () .+ () == ()
@test (1, 2) .+ (1, 2) == (2, 4)
@test_throws DimensionMismatch (1, 2) .+ (1, 2, 3)
end

0 comments on commit 7890971

Please sign in to comment.