diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index abd422387..be9ca1ac7 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -20,28 +20,36 @@ end function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) y = LinearAlgebra.norm(x, p) function norm_pullback_p(Δy) - ∂x = if isempty(x) || p == 0 - InplaceableThunk( - @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))), - identity, - ) + ∂x = InplaceableThunk( + # out-of-place versions + if isempty(x) || p == 0 + @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))) elseif p == 2 - InplaceableThunk( - @thunk(_norm2_back(x, y, Δy)), - dx -> _norm2_back!(dx, x, y, Δy), - ) + @thunk(_norm2_back(x, y, Δy)) elseif p == 1 - InplaceableThunk( - @thunk(_norm1_back(x, y, Δy)), - dx -> _norm1_back!(dx, x, y, Δy), - ) + @thunk(_norm1_back(x, y, Δy)) elseif p == Inf - _normInf_back(x, y, Δy) + @thunk(_normInf_back(x, y, Δy)) elseif p == -Inf - _normInf_back(x, y, Δy) + @thunk(_normInf_back(x, y, Δy)) else - _normp_back_x(x, p, y, Δy) + @thunk(_normp_back_x(x, p, y, Δy)) + end, + # in-place versions + if isempty(x) || p == 0 + identity + elseif p == 2 + dx -> _norm2_back!(dx, x, y, Δy) + elseif p == 1 + dx -> _norm1_back!(dx, x, y, Δy) + elseif p == Inf + dx -> dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved + elseif p == -Inf + dx -> dx .+= _normInf_back(x, y, Δy) + else + dx -> dx .+= _normp_back_x(x, p, y, Δy) end + ) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NO_FIELDS, ∂x, ∂p) end @@ -51,14 +59,19 @@ end function rrule(::typeof(norm), x::AbstractArray{<:Number}) y = LinearAlgebra.norm(x) function norm_pullback_2(Δy) - ∂x = if isempty(x) - zero.(x) .* (zero(y) * zero(real(Δy))) - else - InplaceableThunk( - @thunk(_norm2_back(x, y, Δy)), - dx -> _norm2_back!(dx, x, y, Δy), + ∂x = InplaceableThunk( + if isempty(x) + @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))) + else + @thunk(_norm2_back(x, y, Δy)) + end + , + if isempty(x) + identity + else + dx -> _norm2_back!(dx, x, y, Δy) + end ) - end return (NO_FIELDS, ∂x) end norm_pullback_2(::Zero) = (NO_FIELDS, Zero())