diff --git a/Project.toml b/Project.toml
index 14d911028..581934ac9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "ChainRules"
 uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "1.44.4"
+version = "1.44.5"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl
index 5878b93b4..3673c0a43 100644
--- a/src/rulesets/Base/arraymath.jl
+++ b/src/rulesets/Base/arraymath.jl
@@ -415,7 +415,8 @@ frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...)
 function rrule(::typeof(+), arrs::AbstractArray...)
     y = +(arrs...)
     arr_axs = map(axes, arrs)
-    function add_pullback(dy)
+    function add_pullback(dy_raw)
+        dy = unthunk(dy_raw) # reshape will otherwise unthunk N times
         return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...)
     end
     return y, add_pullback
diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl
index ddf4dc426..d376a64f0 100644
--- a/src/rulesets/Base/broadcast.jl
+++ b/src/rulesets/Base/broadcast.jl
@@ -316,7 +316,8 @@ rrule(::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |
 # When sizes disagree, broadcasting gradient uses `unbroadcast` to reduce to correct shape.
 # It's sometimes a little wasteful to allocate a too-large `dx`, but difficult to make more efficient.
 
-function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
+function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx_raw)
+    dx = unthunk(dx_raw)
     N = ndims(dx)
     if length(x) == length(dx)
         ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
@@ -327,7 +328,8 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
 end
 unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx
 
-function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
+function unbroadcast(x::T, dx_raw) where {T<:Tuple{Vararg{Any,N}}} where {N}
+    dx = unthunk(dx_raw)
     val = if N == length(dx)
         dx
     else