Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix length calculation of broadcast over tuples #23887

Merged
merged 1 commit into from
Nov 7, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ end
end

Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
# `(x,)`, where `x` is a scalar, broadcasts the same way as `[x]` or `x`
Base.@propagate_inbounds _broadcast_getindex(::Type{Tuple}, A::Tuple{Any}, I) = A[1]
Base.@propagate_inbounds _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
Base.@propagate_inbounds _broadcast_getindex(::ScalarType, A, I) = A
Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
Expand Down Expand Up @@ -333,13 +335,32 @@ end
end
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
@inline broadcast_c(f, ::Type{Tuple}, A, Bs...) =
tuplebroadcast(f, first_tuple(A, Bs...), A, Bs...)
tuplebroadcast(f, tuplebroadcast_maxtuple(A, Bs...), A, Bs...)
@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} =
ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val(N))
@inline tuplebroadcast(f, ::NTuple{N,Any}, ::Type{T}, As...) where {N,T} =
ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val(N))
first_tuple(A::Tuple, Bs...) = A
@inline first_tuple(A, Bs...) = first_tuple(Bs...)
# When the result of broadcast is a tuple it can only come from mixing n-tuples
# of the same length with scalars and 1-tuples. So, in order to have a
# type-stable broadcast, we need to find a tuple of maximum length (except when
# there are only scalars, empty tuples and 1-tuples, in which case the
# returned value will be an empty tuple).
# The following methods compare broadcast arguments pairwise to determine the
# length of the final tuple.
tuplebroadcast_maxtuple(A, B) =
_tuplebroadcast_maxtuple(containertype(A), containertype(B), A, B)
@inline tuplebroadcast_maxtuple(A, Bs...) =
tuplebroadcast_maxtuple(A, tuplebroadcast_maxtuple(Bs...))
tuplebroadcast_maxtuple(A::NTuple{N,Any}, ::NTuple{N,Any}...) where {N} = A
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be a _tuplebroadcast_maxtuple method, for example with the testcase broadcast(+, (1, 2), (3, 4), 5).

Copy link
Contributor Author

@pabloferz pabloferz Oct 2, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That example works with this PR as is. I had it like this to reduce the compiler/inference work in cases where you have a bunch of tuples of the same length and a scalar element (which I imagined would be common), e.g. (1,2) .+ (2,3) .+ (3,4) .+ (4,5) .+ 1. But I can make it a _tuplebroadcast_maxtuple for consistency if preferred.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I screwed up the order of my example. It should be broadcast(+, 5, (1, 2), (3, 4)) which I believe would not currently work.

Copy link
Contributor Author

@pabloferz pabloferz Nov 7, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can confirm that example (or broadcast(+, (1, 2), 5, (3, 4))) works fine.

tuplebroadcast_maxtuple will call _tuplebroadcast_maxtuple only when the number of arguments to tuplebroadcast_maxtuple is two and both are not tuples of the same length (if they are, dispatch would select tuplebroadcast_maxtuple(A::NTuple{N,Any}, ::NTuple{N,Any}...) where {N} instead).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I understand. Thanks for the correction.

# Here we use the containertype trait to easier disambiguate between methods
_tuplebroadcast_maxtuple(::Type{Any}, ::Type{Any}, A, B) = (nothing,)
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Any}, A, B) = A
_tuplebroadcast_maxtuple(::Type{Any}, ::Type{Tuple}, A, B) = B
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A, B::Tuple{Any}) = A
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A::Tuple{Any}, B) = B
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A::Tuple{Any}, ::Tuple{Any}) = A
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A, B) =
throw(DimensionMismatch("tuples could not be broadcast to a common size"))
tuplebroadcast_getargs(::Tuple{}, k) = ()
@inline tuplebroadcast_getargs(As, k) =
(_broadcast_getindex(first(As), k), tuplebroadcast_getargs(tail(As), k)...)
Expand Down
16 changes: 12 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,16 @@ end
Nullable("hello"))
end

# Issue #21291
let t = (0, 1, 2)
o = 1
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
@testset "broadcast resulting in tuples" begin
# Issue #21291
let t = (0, 1, 2)
o = 1
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
end

# Issue #23647
@test (1, 2, 3) .+ (1,) == (1,) .+ (1, 2, 3) == (2, 3, 4)
@test (1,) .+ () == () .+ (1,) == () .+ () == ()
@test (1, 2) .+ (1, 2) == (2, 4)
@test_throws DimensionMismatch (1, 2) .+ (1, 2, 3)
end