diff --git a/Project.toml b/Project.toml index 1c45e5a0f..d44b0ab75 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.63" +version = "0.7.64" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 969702826..41792ad4d 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -6,6 +6,7 @@ function frule((_, Δx), ::typeof(norm), x) y = norm(x) return y, _norm2_forward(x, Δx, norm(x)) end + function frule((_, Δx), ::typeof(norm), x::Number, p::Real) y = norm(x, p) ∂y = if iszero(Δx) || iszero(p) @@ -17,15 +18,12 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real) return y, ∂y end -function rrule( - ::typeof(norm), - x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal}, - p::Real, -) +function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) y = LinearAlgebra.norm(x, p) - function norm_pullback(Δy) - ∂x = Thunk() do - return if isempty(x) || p == 0 + function norm_pullback_p(Δy) + ∂x = InplaceableThunk( + # out-of-place versions + @thunk(if isempty(x) || p == 0 zero.(x) .* (zero(y) * zero(real(Δy))) elseif p == 2 _norm2_back(x, y, Δy) @@ -37,35 +35,52 @@ function rrule( _normInf_back(x, y, Δy) else _normp_back_x(x, p, y, Δy) + end) + , # in-place versions -- can be fixed when actually useful? + dx -> if isempty(x) || p == 0 + dx + elseif p == 2 + _norm2_back!(dx, x, y, Δy) + elseif p == 1 + _norm1_back!(dx, x, y, Δy) + elseif p == Inf + dx .+= _normInf_back(x, y, Δy) # not really in-place! 
could perhaps be improved + elseif p == -Inf + dx .+= _normInf_back(x, y, Δy) + else + dx .+= _normp_back_x(x, p, y, Δy) end - end + ) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NO_FIELDS, ∂x, ∂p) end - norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) - return y, norm_pullback + norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero()) + return y, norm_pullback_p end -function rrule( - ::typeof(norm), - x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal}, -) + +function rrule(::typeof(norm), x::AbstractArray{<:Number}) y = LinearAlgebra.norm(x) - function norm_pullback(Δy) - ∂x = if isempty(x) - zero.(x) .* (zero(y) * zero(real(Δy))) - else - _norm2_back(x, y, Δy) - end + function norm_pullback_2(Δy) + ∂x = InplaceableThunk( + @thunk(if isempty(x) + zero.(x) .* (zero(y) * zero(real(Δy))) + else + _norm2_back(x, y, Δy) + end) + , + dx -> if isempty(x) + dx + else + _norm2_back!(dx, x, y, Δy) + end + ) return (NO_FIELDS, ∂x) end - norm_pullback(::Zero) = (NO_FIELDS, Zero()) - return y, norm_pullback + norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) + return y, norm_pullback_2 end -function rrule( - ::typeof(norm), - x::Union{LinearAlgebra.TransposeAbsVec, LinearAlgebra.AdjointAbsVec}, - p::Real, -) + +function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::Real) y, inner_pullback = rrule(norm, parent(x), p) function norm_pullback(Δy) (∂self, ∂x′, ∂p) = inner_pullback(Δy) @@ -75,6 +90,7 @@ function rrule( end return y, norm_pullback end + function rrule(::typeof(norm), x::Number, p::Real) y = norm(x, p) function norm_pullback(Δy) @@ -94,11 +110,7 @@ end ##### `normp` ##### -function rrule( - ::typeof(LinearAlgebra.normp), - x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal}, - p, -) +function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p) y = LinearAlgebra.normp(x, p) function normp_pullback(Δy) ∂x = @thunk _normp_back_x(x, p, y, Δy) @@ -111,8 +123,7 @@ end function _normp_back_x(x, p, y, Δy) 
c = real(Δy) / y - ∂x = similar(x) - broadcast!(∂x, x) do xi + ∂x = map(x) do xi a = norm(xi) ∂xi = xi * ((a / y)^(p - 2) * c) return ifelse(isfinite(∂xi), ∂xi, zero(∂xi)) @@ -120,6 +131,16 @@ function _normp_back_x(x, p, y, Δy) return ∂x end +function _normp_back_x(x::WithSomeZeros, p, y, Δy) # Diagonal, UpperTriangular, etc. + c = real(Δy) / y + ∂x_data = map(parent(x)) do xi + a = norm(xi) + ∂xi = xi * ((a / y)^(p - 2) * c) + return ifelse(isfinite(∂xi), ∂xi, zero(∂xi)) + end + return withsomezeros_rewrap(x, ∂x_data) +end + function _normp_back_p(x, p, y, Δy) y > 0 && isfinite(y) && !iszero(p) || return zero(real(Δy)) * zero(y) / one(p) s = sum(x) do xi @@ -135,20 +156,14 @@ end ##### `normMinusInf`/`normInf` ##### -function rrule( - ::typeof(LinearAlgebra.normMinusInf), - x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal}, -) +function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normMinusInf(x) normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero()) return y, normMinusInf_pullback end -function rrule( - ::typeof(LinearAlgebra.normInf), - x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal}, -) +function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normInf(x) normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) normInf_pullback(::Zero) = (NO_FIELDS, Zero()) @@ -172,19 +187,26 @@ end ##### `norm1` ##### -function rrule( - ::typeof(LinearAlgebra.norm1), - x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal}, -) +function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number}) y = LinearAlgebra.norm1(x) - norm1_pullback(Δy) = (NO_FIELDS, _norm1_back(x, y, Δy)) + norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk( + @thunk(_norm1_back(x, y, Δy)), + dx -> _norm1_back!(dx, x, y, Δy), + )) norm1_pullback(::Zero) = (NO_FIELDS, Zero()) return y, norm1_pullback end function 
_norm1_back(x, y, Δy) - ∂x = similar(x) - ∂x .= sign.(x) .* real(Δy) + ∂x = sign.(x) .* real(Δy) + return ∂x +end +function _norm1_back(x::WithSomeZeros, y, Δy) + ∂x_data = sign.(parent(x)) .* real(Δy) + return withsomezeros_rewrap(x, ∂x_data) +end +function _norm1_back!(∂x, x, y, Δy) + ∂x .+= sign.(x) .* real(Δy) return ∂x end @@ -197,12 +219,12 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x) return y, _norm2_forward(x, Δx, y) end -function rrule( - ::typeof(LinearAlgebra.norm2), - x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal}, -) +function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number}) y = LinearAlgebra.norm2(x) - norm2_pullback(Δy) = (NO_FIELDS, _norm2_back(x, y, Δy)) + norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk( + @thunk(_norm2_back(x, y, Δy)), + dx -> _norm2_back!(dx, x, y, Δy), + )) norm2_pullback(::Zero) = (NO_FIELDS, Zero()) return y, norm2_pullback end @@ -212,16 +234,24 @@ function _norm2_forward(x, Δx, y) return ∂y end function _norm2_back(x, y, Δy) - ∂x = similar(x) - ∂x .= x .* (real(Δy) * pinv(y)) + ∂x = x .* (real(Δy) * pinv(y)) return ∂x end +function _norm2_back(x::WithSomeZeros, y, Δy) + T = typeof(one(eltype(x)) / one(real(eltype(Δy)))) + ∂x_data = parent(x) .* (real(Δy) * pinv(y)) + return withsomezeros_rewrap(x, ∂x_data) +end +function _norm2_back!(∂x, x, y, Δy) + ∂x .+= x .* (real(Δy) * pinv(y)) + return ∂x # must return after mutating +end ##### ##### `normalize` ##### -function rrule(::typeof(normalize), x::AbstractVector, p::Real) +function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) nrm, inner_pullback = rrule(norm, x, p) Ty = typeof(first(x) / nrm) y = copyto!(similar(x, Ty), x) @@ -236,7 +266,8 @@ function rrule(::typeof(normalize), x::AbstractVector, p::Real) normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) return y, normalize_pullback end -function rrule(::typeof(normalize), x::AbstractVector) + +function rrule(::typeof(normalize), 
x::AbstractVector{<:Number}) nrm = LinearAlgebra.norm2(x) Ty = typeof(first(x) / nrm) y = copyto!(similar(x, Ty), x) diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 9aa7b5257..2beb856b6 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -43,3 +43,36 @@ Symmetric ```` """ _unionall_wrapper(::Type{T}) where {T} = T.name.wrapper + +""" + WithSomeZeros{T} + +This is a union of LinearAlgebra types, all of which are partly structural zeros, +with a simple backing array given by `parent(x)`. All have methods of `withsomezeros_rewrap` +to re-create. + +This exists to solve a type instability, as broadcasting for instance +`λ .* Diagonal(rand(3))` gives a dense matrix when `λ==Inf`. +But `withsomezeros_rewrap(x, λ .* parent(x))` is type-stable. +""" +WithSomeZeros{T} = Union{ + Diagonal{T}, + UpperTriangular{T}, + UnitUpperTriangular{T}, + # UpperHessenberg{T}, # doesn't exist in Julia 1.0 + LowerTriangular{T}, + UnitLowerTriangular{T}, +} +for S in [ + :Diagonal, + :UpperTriangular, + :UnitUpperTriangular, + # :UpperHessenberg, + :LowerTriangular, + :UnitLowerTriangular, +] + @eval withsomezeros_rewrap(::$S, x) = $S(x) +end + +# Bidiagonal, Tridiagonal have more complicated storage. 
+# AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent())) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index cc2a4b034..bb2d20bfd 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -1,4 +1,8 @@ @testset "norm functions" begin + + # First test the un-exported functions which norm(A,p) calls + # ========================================================== + @testset "$fnorm(x::Array{$T,$(length(sz))})" for fnorm in ( LinearAlgebra.norm1, @@ -23,8 +27,10 @@ kwargs = NamedTuple() end - fnorm === LinearAlgebra.norm2 && @testset "frule" begin - test_frule(fnorm, x) + if fnorm === LinearAlgebra.norm2 + @testset "frule" begin + test_frule(fnorm, x) + end end @testset "rrule" begin test_rrule(fnorm, x; kwargs...) @@ -36,7 +42,27 @@ @test extern(rrule(fnorm, zero(x))[2](ȳ)[2]) ≈ zero(x) @test rrule(fnorm, x)[2](Zero())[2] isa Zero end + ndims(x) > 1 && @testset "non-strided" begin + xp = if x isa Matrix + view(x, [1,2,3], 1:3) + elseif x isa Array{T,3} + PermutedDimsArray(x, (1,2,3)) + end + @test !(xp isa StridedArray) + test_rrule(fnorm, xp ⊢ rand(T, size(xp))) + end + T == Float64 && ndims(x) == 1 && @testset "Integer input" begin + x = [1,2,3] + int_fwd, int_back = rrule(fnorm, x) + float_fwd, float_back = rrule(fnorm, float(x)) + @test int_fwd ≈ float_fwd + @test unthunk(int_back(1.0)[2]) ≈ unthunk(float_back(1.0)[2]) + end end + + # Next test norm(A, p=2) -- two methods + # ===================================== + @testset "norm(x::Array{$T,$(length(sz))})" for T in (Float64, ComplexF64), sz in [(0,), (3,), (3, 3), (3, 2, 1)] @@ -53,8 +79,6 @@ @testset "rrule" begin test_rrule(norm, x) x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular) - # we don't check inference on older julia versions. 
Improvements to - # inference mean on 1.5+ it works, and that is good enough test_rrule(norm, MT(x); check_inferred=VERSION>=v"1.5") end @@ -62,6 +86,16 @@ @test extern(rrule(norm, zero(x))[2](ȳ)[2]) ≈ zero(x) @test rrule(norm, x)[2](Zero())[2] isa Zero end + ndims(x) > 1 && @testset "non-strided" begin + xp = if x isa Matrix + view(x, [1,2,3], 1:3) + elseif x isa Array{T,3} + PermutedDimsArray(x, (1,2,3)) + end + @test !(xp isa StridedArray) + test_frule(norm, xp ⊢ rand(T, size(xp))) + test_rrule(norm, xp ⊢ rand(T, size(xp))) # rand_tangent does not work here because eltype(xp)==Int + end end @testset "$fnorm(x::Array{$T,$(length(sz))}, $p) with size $sz" for fnorm in (norm, LinearAlgebra.normp), @@ -71,7 +105,7 @@ x = randn(T, sz) # finite differences is unstable if maxabs (minabs) values are not well - # separated from other values + # separated from other values (same as above) if p == Inf if !isempty(x) x[end] = 1000rand(T) @@ -87,20 +121,24 @@ kwargs = NamedTuple() end - test_rrule(fnorm, x, p; kwargs...) x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular) - test_rrule(fnorm, MT(x), p; - #Don't check inference on old julia, what matters is that works on new - check_inferred=VERSION>=v"1.5", kwargs... 
- ) + test_rrule(fnorm, MT(x), p; kwargs..., check_inferred=VERSION>=v"1.5") end ȳ = rand_tangent(fnorm(x, p)) @test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) ≈ zero(x) @test rrule(fnorm, x, p)[2](Zero())[2] isa Zero + T == Float64 && sz == (3,) && @testset "Integer input, p=$p" begin + x = [1,2,3] + int_fwd, int_back = rrule(fnorm, x, p) + float_fwd, float_back = rrule(fnorm, float(x), p) + @test int_fwd ≈ float_fwd + @test unthunk(int_back(1.0)[2]) ≈ unthunk(float_back(1.0)[2]) + end end - @testset "norm($fdual(::Vector{$T}), p)" for + # Extra test for norm(adjoint vector, p) + @testset "norm($fdual(::Vector{$T}), 2.5)" for T in (Float64, ComplexF64), fdual in (adjoint, transpose) @@ -111,6 +149,10 @@ ȳ = rand_tangent(norm(x, p)) @test extern(rrule(norm, x, p)[2](ȳ)[2]) isa typeof(x) end + + # Scalar norm(x, p) + # ================= + @testset "norm(x::$T, p)" for T in (Float64, ComplexF64) @testset "p = $p" for p in (-1.0, 2.0, 2.5) test_frule(norm, randn(T), p) @@ -136,6 +178,9 @@ end end +# normalize(x, p) and normalize(A, p) +# =================================== + @testset "normalize" begin @testset "x::Vector{$T}" for T in (Float64, ComplexF64) x = randn(T, 3)