From fef6aa821e871db5df11061312ba02929b7fa16d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 15 Aug 2022 12:22:09 -0700 Subject: [PATCH] Remove `NTuple` from `unbroadcast` (#661) * remove NTuple * spaces Co-authored-by: Frames Catherine White Co-authored-by: Frames Catherine White --- Project.toml | 2 +- src/rulesets/Base/broadcast.jl | 4 ++-- test/rulesets/Base/broadcast.jl | 4 ++++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index a420d3ed9..15ce2ddb5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.44.1" +version = "1.44.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index be11eb76a..ddf4dc426 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -328,13 +328,13 @@ end unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} - val = if length(x) == length(dx) + val = if N == length(dx) dx else sum(dx; dims=2:ndims(dx)) end eltype(val) <: AbstractZero && return NoTangent() - return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent + return ProjectTo(x)(Tuple{Vararg{Any,N}}(val)) # Tangent end unbroadcast(x::Tuple, dx::AbstractZero) = dx diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 68d47a7d4..219b45a71 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -173,4 +173,8 @@ BT1 = Broadcast.BroadcastStyle(Tuple) test_rrule(copy∘broadcasted, complex, rand()) end end + + @testset "bugs" begin + @test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type + end end \ No newline at end of file