From 19578d6a2814fbcdbefd78ec5e14f1e8d271753e Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Thu, 24 Dec 2020 13:40:41 +0100 Subject: [PATCH 01/22] widen type of norm pullback, add inplace + tests --- src/rulesets/LinearAlgebra/norm.jl | 90 +++++++++++++++-------------- test/rulesets/LinearAlgebra/norm.jl | 28 +++++++++ 2 files changed, 76 insertions(+), 42 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 969702826..d222c85fe 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -1,3 +1,6 @@ + +rrule(f::Function, args...; kwargs...) = (@error "no rrule defined!" f args ; nothing) + ##### ##### `norm` ##### @@ -17,20 +20,25 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real) return y, ∂y end -function rrule( - ::typeof(norm), - x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal}, - p::Real, -) +function rrule(::typeof(norm), x::AbstractArray, p::Real) y = LinearAlgebra.norm(x, p) - function norm_pullback(Δy) + function norm_pullback_p(Δy) ∂x = Thunk() do return if isempty(x) || p == 0 - zero.(x) .* (zero(y) * zero(real(Δy))) + InplaceableThunk( + @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))), + dx -> dx .= zero(eltype(dx)), + ) elseif p == 2 - _norm2_back(x, y, Δy) + InplaceableThunk( + @thunk(_norm2_back(x, y, Δy)), + dx -> _norm2_back!(dx, x, y, Δy), + ) elseif p == 1 - _norm1_back(x, y, Δy) + InplaceableThunk( + @thunk(_norm1_back(x, y, Δy)), + dx -> _norm1_back!(dx, x, y, Δy), + ) elseif p == Inf _normInf_back(x, y, Δy) elseif p == -Inf @@ -42,24 +50,24 @@ function rrule( ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NO_FIELDS, ∂x, ∂p) end - norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) - return y, norm_pullback + norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero()) + return y, norm_pullback_p end -function rrule( - ::typeof(norm), - x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal}, -) +function 
rrule(::typeof(norm), x::AbstractArray) y = LinearAlgebra.norm(x) - function norm_pullback(Δy) + function norm_pullback_2(Δy) ∂x = if isempty(x) zero.(x) .* (zero(y) * zero(real(Δy))) else - _norm2_back(x, y, Δy) + InplaceableThunk( + @thunk(_norm2_back(x, y, Δy)), + dx -> _norm2_back!(dx, x, y, Δy), + ) end return (NO_FIELDS, ∂x) end - norm_pullback(::Zero) = (NO_FIELDS, Zero()) - return y, norm_pullback + norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) + return y, norm_pullback_2 end function rrule( ::typeof(norm), @@ -94,11 +102,7 @@ end ##### `normp` ##### -function rrule( - ::typeof(LinearAlgebra.normp), - x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal}, - p, -) +function rrule(::typeof(LinearAlgebra.normp),x::AbstractArray, p) y = LinearAlgebra.normp(x, p) function normp_pullback(Δy) ∂x = @thunk _normp_back_x(x, p, y, Δy) @@ -135,20 +139,14 @@ end ##### `normMinusInf`/`normInf` ##### -function rrule( - ::typeof(LinearAlgebra.normMinusInf), - x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal}, -) +function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray) y = LinearAlgebra.normMinusInf(x) normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero()) return y, normMinusInf_pullback end -function rrule( - ::typeof(LinearAlgebra.normInf), - x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal}, -) +function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray) y = LinearAlgebra.normInf(x) normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) normInf_pullback(::Zero) = (NO_FIELDS, Zero()) @@ -172,12 +170,12 @@ end ##### `norm1` ##### -function rrule( - ::typeof(LinearAlgebra.norm1), - x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal}, -) +function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray) y = LinearAlgebra.norm1(x) - norm1_pullback(Δy) = (NO_FIELDS, _norm1_back(x, y, Δy)) + norm1_pullback(Δy) = (NO_FIELDS, 
InplaceableThunk( + @thunk(_norm1_back(x, y, Δy)), + dx -> _norm1_back!(dx, x, y, Δy), + )) norm1_pullback(::Zero) = (NO_FIELDS, Zero()) return y, norm1_pullback end @@ -187,6 +185,10 @@ function _norm1_back(x, y, Δy) ∂x .= sign.(x) .* real(Δy) return ∂x end +function _norm1_back!(∂x, x, y, Δy) + ∂x .+= sign.(x) .* real(Δy) + return ∂x +end ##### ##### `norm2` @@ -197,12 +199,12 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x) return y, _norm2_forward(x, Δx, y) end -function rrule( - ::typeof(LinearAlgebra.norm2), - x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal}, -) +function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray) y = LinearAlgebra.norm2(x) - norm2_pullback(Δy) = (NO_FIELDS, _norm2_back(x, y, Δy)) + norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk( + @thunk(_norm2_back(x, y, Δy)), + dx -> _norm2_back!(dx, x, y, Δy), + )) norm2_pullback(::Zero) = (NO_FIELDS, Zero()) return y, norm2_pullback end @@ -216,6 +218,10 @@ function _norm2_back(x, y, Δy) ∂x .= x .* (real(Δy) * pinv(y)) return ∂x end +function _norm2_back!(∂x, x, y, Δy) + ∂x .+= x .* (real(Δy) * pinv(y)) + return ∂x # must return after mutating +end ##### ##### `normalize` diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index cc2a4b034..cf436cb8f 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -36,6 +36,20 @@ @test extern(rrule(fnorm, zero(x))[2](ȳ)[2]) ≈ zero(x) @test rrule(fnorm, x)[2](Zero())[2] isa Zero end + ndims(x) > 1 && @testset "non-strided" begin + xp = if x isa Matrix + view(x, [1,2,3], 1:3) + elseif x isa Array{T,3} + PermutedDimsArray(x, (1,2,3)) + end + @test !(xp isa StridedArray) + y = fnorm(x) + # ẋ = rand(T, size(xp)) # rand_tangent(xp) + x̄ = rand(T, size(xp)) # rand_tangent(xp) + ȳ = rand_tangent(y) + # frule_test(fnorm, (xp, ẋ)) + rrule_test(fnorm, ȳ, (xp, x̄)) + end end @testset "norm(x::Array{$T,$(length(sz))})" for T in (Float64, ComplexF64), @@ -62,6 
+76,20 @@ @test extern(rrule(norm, zero(x))[2](ȳ)[2]) ≈ zero(x) @test rrule(norm, x)[2](Zero())[2] isa Zero end + ndims(x) > 1 && @testset "non-strided" begin + xp = if x isa Matrix + view(x, [1,2,3], 1:3) + elseif x isa Array{T,3} + PermutedDimsArray(x, (1,2,3)) + end + @test !(xp isa StridedArray) + y = norm(x) + ẋ = rand(T, size(xp)) # rand_tangent(xp) + x̄ = rand(T, size(xp)) # rand_tangent(xp) + ȳ = rand_tangent(y) + frule_test(norm, (xp, ẋ)) + rrule_test(norm, ȳ, (xp, x̄)) + end end @testset "$fnorm(x::Array{$T,$(length(sz))}, $p) with size $sz" for fnorm in (norm, LinearAlgebra.normp), From 99939795d47dbf01344d6f8c3da5953827efd970 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 26 Dec 2020 17:58:28 +0100 Subject: [PATCH 02/22] tidy --- src/rulesets/LinearAlgebra/norm.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index d222c85fe..c40a41444 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -1,6 +1,3 @@ - -rrule(f::Function, args...; kwargs...) = (@error "no rrule defined!" 
f args ; nothing) - ##### ##### `norm` ##### @@ -69,11 +66,7 @@ function rrule(::typeof(norm), x::AbstractArray) norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) return y, norm_pullback_2 end -function rrule( - ::typeof(norm), - x::Union{LinearAlgebra.TransposeAbsVec, LinearAlgebra.AdjointAbsVec}, - p::Real, -) +function rrule(::typeof(norm), x::Union{LinearAlgebra.AdjOrTransAbsVec}, p::Real) y, inner_pullback = rrule(norm, parent(x), p) function norm_pullback(Δy) (∂self, ∂x′, ∂p) = inner_pullback(Δy) From 6b6487d825bcbb161d8cdfd48b3c16e77b966f07 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 26 Dec 2020 18:43:50 +0100 Subject: [PATCH 03/22] allow, and test, integer x --- src/rulesets/LinearAlgebra/norm.jl | 9 +++------ test/rulesets/LinearAlgebra/norm.jl | 12 ++++++++++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index c40a41444..688c16a8b 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -108,8 +108,7 @@ end function _normp_back_x(x, p, y, Δy) c = real(Δy) / y - ∂x = similar(x) - broadcast!(∂x, x) do xi + ∂x = broadcast(x) do xi a = norm(xi) ∂xi = xi * ((a / y)^(p - 2) * c) return ifelse(isfinite(∂xi), ∂xi, zero(∂xi)) @@ -174,8 +173,7 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray) end function _norm1_back(x, y, Δy) - ∂x = similar(x) - ∂x .= sign.(x) .* real(Δy) + ∂x = sign.(x) .* real(Δy) return ∂x end function _norm1_back!(∂x, x, y, Δy) @@ -207,8 +205,7 @@ function _norm2_forward(x, Δx, y) return ∂y end function _norm2_back(x, y, Δy) - ∂x = similar(x) - ∂x .= x .* (real(Δy) * pinv(y)) + ∂x = x .* (real(Δy) * pinv(y)) return ∂x end function _norm2_back!(∂x, x, y, Δy) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index cf436cb8f..d77b10aac 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -50,6 +50,12 @@ # 
frule_test(fnorm, (xp, ẋ)) rrule_test(fnorm, ȳ, (xp, x̄)) end + T == Float64 && ndims(x) == 1 && @testset "Integer input" begin + x = [1,2,3] + _, int_back = rrule(fnorm, x) + _, float_back = rrule(fnorm, float(x)) + @test unthunk(int_back(1.0)[2]) ≈ unthunk(float_back(1.0)[2]) + end end @testset "norm(x::Array{$T,$(length(sz))})" for T in (Float64, ComplexF64), @@ -127,6 +133,12 @@ ȳ = rand_tangent(fnorm(x, p)) @test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) ≈ zero(x) @test rrule(fnorm, x, p)[2](Zero())[2] isa Zero + T == Float64 && sz == (3,) && @testset "Integer input, p=$p" begin + x = [1,2,3] + _, int_back = rrule(fnorm, x, p) + _, float_back = rrule(fnorm, float(x), p) + @test unthunk(unthunk(int_back(1.0)[2])) ≈ unthunk(unthunk(float_back(1.0)[2])) + end end @testset "norm($fdual(::Vector{$T}), p)" for T in (Float64, ComplexF64), From b03b83dba11aab3efd327c13b24f131ea73e5ba4 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sun, 27 Dec 2020 17:52:23 +0100 Subject: [PATCH 04/22] =?UTF-8?q?don't=20let=20broadcast=20allocate=20?= =?UTF-8?q?=E2=88=82x=20because=20it's=20weird?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rulesets/LinearAlgebra/norm.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 688c16a8b..575e587fc 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -173,7 +173,11 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray) end function _norm1_back(x, y, Δy) - ∂x = sign.(x) .* real(Δy) + T = promote_type(eltype(x), real(eltype(Δy))) + ∂x = similar(x, T) + # The reason not to let broadcast allocate ∂x is that NaN .* Diagonal(ones(3)) isa Matrix, + # while pi .* Diagonal(ones(3)) isa Diagonal, hence this would be type-unstable. 
+ ∂x .= sign.(x) .* real(Δy) return ∂x end function _norm1_back!(∂x, x, y, Δy) @@ -205,7 +209,9 @@ function _norm2_forward(x, Δx, y) return ∂y end function _norm2_back(x, y, Δy) - ∂x = x .* (real(Δy) * pinv(y)) + T = typeof(one(eltype(x)) / one(real(eltype(Δy)))) + ∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability. + ∂x .= x .* (real(Δy) * pinv(y)) return ∂x end function _norm2_back!(∂x, x, y, Δy) From b666c37006ee587e0e5d85c186fef95a18365aad Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sun, 27 Dec 2020 18:18:50 +0100 Subject: [PATCH 05/22] one more --- src/rulesets/LinearAlgebra/norm.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 575e587fc..37222a31a 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -108,7 +108,9 @@ end function _normp_back_x(x, p, y, Δy) c = real(Δy) / y - ∂x = broadcast(x) do xi + T = promote_type(eltype(x), typeof(c)) + ∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability. 
+ ∂x = broadcast!(∂x, x) do xi a = norm(xi) ∂xi = xi * ((a / y)^(p - 2) * c) return ifelse(isfinite(∂xi), ∂xi, zero(∂xi)) From 972b6dd86221652ed9e575a0e2e07ff352e0dbec Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 27 Dec 2020 23:44:48 +0100 Subject: [PATCH 06/22] Apply suggestions from code review Co-authored-by: Simeon Schaub --- src/rulesets/LinearAlgebra/norm.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 37222a31a..b8c93a203 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -95,7 +95,7 @@ end ##### `normp` ##### -function rrule(::typeof(LinearAlgebra.normp),x::AbstractArray, p) +function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray, p) y = LinearAlgebra.normp(x, p) function normp_pullback(Δy) ∂x = @thunk _normp_back_x(x, p, y, Δy) @@ -110,7 +110,7 @@ function _normp_back_x(x, p, y, Δy) c = real(Δy) / y T = promote_type(eltype(x), typeof(c)) ∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability. 
- ∂x = broadcast!(∂x, x) do xi + map!(∂x, x) do xi a = norm(xi) ∂xi = xi * ((a / y)^(p - 2) * c) return ifelse(isfinite(∂xi), ∂xi, zero(∂xi)) From c12b82972c07bfc179cf246578859c8b7c5a872f Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Mon, 28 Dec 2020 00:29:39 +0100 Subject: [PATCH 07/22] + tests --- test/rulesets/LinearAlgebra/norm.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index d77b10aac..542798e90 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -52,8 +52,9 @@ end T == Float64 && ndims(x) == 1 && @testset "Integer input" begin x = [1,2,3] - _, int_back = rrule(fnorm, x) - _, float_back = rrule(fnorm, float(x)) + int_fwd, int_back = rrule(fnorm, x) + float_fwd, float_back = rrule(fnorm, float(x)) + @test int_fwd ≈ float_fwd @test unthunk(int_back(1.0)[2]) ≈ unthunk(float_back(1.0)[2]) end end @@ -135,8 +136,9 @@ @test rrule(fnorm, x, p)[2](Zero())[2] isa Zero T == Float64 && sz == (3,) && @testset "Integer input, p=$p" begin x = [1,2,3] - _, int_back = rrule(fnorm, x, p) - _, float_back = rrule(fnorm, float(x), p) + int_fwd, int_back = rrule(fnorm, x, p) + float_fwd, float_back = rrule(fnorm, float(x), p) + @test int_fwd ≈ float_fwd @test unthunk(unthunk(int_back(1.0)[2])) ≈ unthunk(unthunk(float_back(1.0)[2])) end end From e4d60f24dc64369c1d1f55d01eec7d907484a935 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Mon, 4 Jan 2021 22:58:25 +0100 Subject: [PATCH 08/22] a type-stability fix via parent & re-wrap --- src/rulesets/LinearAlgebra/norm.jl | 32 ++++++++++++++++++---------- src/rulesets/LinearAlgebra/utils.jl | 33 +++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index b8c93a203..341f667bf 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ 
-108,15 +108,22 @@ end function _normp_back_x(x, p, y, Δy) c = real(Δy) / y - T = promote_type(eltype(x), typeof(c)) - ∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability. - map!(∂x, x) do xi + ∂x = map(x) do xi a = norm(xi) ∂xi = xi * ((a / y)^(p - 2) * c) return ifelse(isfinite(∂xi), ∂xi, zero(∂xi)) end return ∂x end +function _normp_back_x(x::WithSomeZeros, p, y, Δy) # Diagonal, UpperTriangular, etc. + c = real(Δy) / y + ∂x_data = map(parent(x)) do xi + a = norm(xi) + ∂xi = xi * ((a / y)^(p - 2) * c) + return ifelse(isfinite(∂xi), ∂xi, zero(∂xi)) + end + return withsomezeros_rewrap(x, ∂x_data) +end function _normp_back_p(x, p, y, Δy) y > 0 && isfinite(y) && !iszero(p) || return zero(real(Δy)) * zero(y) / one(p) @@ -175,13 +182,13 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray) end function _norm1_back(x, y, Δy) - T = promote_type(eltype(x), real(eltype(Δy))) - ∂x = similar(x, T) - # The reason not to let broadcast allocate ∂x is that NaN .* Diagonal(ones(3)) isa Matrix, - # while pi .* Diagonal(ones(3)) isa Diagonal, hence this would be type-unstable. - ∂x .= sign.(x) .* real(Δy) + ∂x = sign.(x) .* real(Δy) return ∂x end +function _norm1_back(x::WithSomeZeros, y, Δy) + ∂x_data = sign.(parent(x)) .* real(Δy) + return withsomezeros_rewrap(x, ∂x_data) +end function _norm1_back!(∂x, x, y, Δy) ∂x .+= sign.(x) .* real(Δy) return ∂x @@ -211,11 +218,14 @@ function _norm2_forward(x, Δx, y) return ∂y end function _norm2_back(x, y, Δy) - T = typeof(one(eltype(x)) / one(real(eltype(Δy)))) - ∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability. 
- ∂x .= x .* (real(Δy) * pinv(y)) + ∂x = x .* (real(Δy) * pinv(y)) return ∂x end +function _norm2_back(x::WithSomeZeros, y, Δy) + T = typeof(one(eltype(x)) / one(real(eltype(Δy)))) + ∂x_data = parent(x) .* (real(Δy) * pinv(y)) + return withsomezeros_rewrap(x, ∂x_data) +end function _norm2_back!(∂x, x, y, Δy) ∂x .+= x .* (real(Δy) * pinv(y)) return ∂x # must return after mutating diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 9aa7b5257..9a5a4ce0b 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -43,3 +43,36 @@ Symmetric ```` """ _unionall_wrapper(::Type{T}) where {T} = T.name.wrapper + +""" + WithSomeZeros{T} + +This is a union of LinearAlgebra types, all of which are partly structural zeros, +with a simple backing array given by `parent(x)`. All have methods of `withsomezeros_rewrap` +to re-create. + +This exists to solve a type instability, as broadcasting for instance +`λ .* Diagonal(rand(3))` gives a dense matrix when `λ == Inf`. +But `withsomezeros_rewrap(x, λ .* parent(x))` is type-stable. +""" +WithSomeZeros{T} = Union{ + Diagonal{T}, + UpperTriangular{T}, + UnitUpperTriangular{T}, + UpperHessenberg{T}, + LowerTriangular{T}, + UnitLowerTriangular{T}, +} +for S in [ + :Diagonal, + :UpperTriangular, + :UnitUpperTriangular, + :UpperHessenberg, + :LowerTriangular, + :UnitLowerTriangular, +] + @eval withsomezeros_rewrap(::$S, x) = $S(x) +end + +# Bidiagonal, Tridiagonal have more complicated storage. 
+# AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent())) From 96eafd6f8dca849156a4dcbd96c881b623d2d94b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 10:47:10 -0400 Subject: [PATCH 09/22] comments from review --- src/rulesets/LinearAlgebra/norm.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 341f667bf..910e37e1f 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -20,11 +20,10 @@ end function rrule(::typeof(norm), x::AbstractArray, p::Real) y = LinearAlgebra.norm(x, p) function norm_pullback_p(Δy) - ∂x = Thunk() do - return if isempty(x) || p == 0 + ∂x = if isempty(x) || p == 0 InplaceableThunk( @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))), - dx -> dx .= zero(eltype(dx)), + identity, ) elseif p == 2 InplaceableThunk( @@ -43,7 +42,6 @@ function rrule(::typeof(norm), x::AbstractArray, p::Real) else _normp_back_x(x, p, y, Δy) end - end ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NO_FIELDS, ∂x, ∂p) end From 07eb77b75bb754be8e47a30d4628ba69d9b2dba9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 10:51:04 -0400 Subject: [PATCH 10/22] restrict all arrays to eltype Number --- src/rulesets/LinearAlgebra/norm.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 910e37e1f..abd422387 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -17,7 +17,7 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real) return y, ∂y end -function rrule(::typeof(norm), x::AbstractArray, p::Real) +function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) y = LinearAlgebra.norm(x, p) function norm_pullback_p(Δy) ∂x = if isempty(x) 
|| p == 0 @@ -48,7 +48,7 @@ function rrule(::typeof(norm), x::AbstractArray, p::Real) norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero()) return y, norm_pullback_p end -function rrule(::typeof(norm), x::AbstractArray) +function rrule(::typeof(norm), x::AbstractArray{<:Number}) y = LinearAlgebra.norm(x) function norm_pullback_2(Δy) ∂x = if isempty(x) @@ -64,7 +64,7 @@ function rrule(::typeof(norm), x::AbstractArray) norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) return y, norm_pullback_2 end -function rrule(::typeof(norm), x::Union{LinearAlgebra.AdjOrTransAbsVec}, p::Real) +function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::Real) y, inner_pullback = rrule(norm, parent(x), p) function norm_pullback(Δy) (∂self, ∂x′, ∂p) = inner_pullback(Δy) @@ -93,7 +93,7 @@ end ##### `normp` ##### -function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray, p) +function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p) y = LinearAlgebra.normp(x, p) function normp_pullback(Δy) ∂x = @thunk _normp_back_x(x, p, y, Δy) @@ -138,14 +138,14 @@ end ##### `normMinusInf`/`normInf` ##### -function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray) +function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normMinusInf(x) normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero()) return y, normMinusInf_pullback end -function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray) +function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normInf(x) normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) normInf_pullback(::Zero) = (NO_FIELDS, Zero()) @@ -169,7 +169,7 @@ end ##### `norm1` ##### -function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray) +function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number}) y = LinearAlgebra.norm1(x) norm1_pullback(Δy) = (NO_FIELDS, 
InplaceableThunk( @thunk(_norm1_back(x, y, Δy)), @@ -201,7 +201,7 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x) return y, _norm2_forward(x, Δx, y) end -function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray) +function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number}) y = LinearAlgebra.norm2(x) norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk( @thunk(_norm2_back(x, y, Δy)), @@ -233,7 +233,7 @@ end ##### `normalize` ##### -function rrule(::typeof(normalize), x::AbstractVector, p::Real) +function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) nrm, inner_pullback = rrule(norm, x, p) Ty = typeof(first(x) / nrm) y = copyto!(similar(x, Ty), x) @@ -248,7 +248,7 @@ function rrule(::typeof(normalize), x::AbstractVector, p::Real) normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) return y, normalize_pullback end -function rrule(::typeof(normalize), x::AbstractVector) +function rrule(::typeof(normalize), x::AbstractVector{<:Number}) nrm = LinearAlgebra.norm2(x) Ty = typeof(first(x) / nrm) y = copyto!(similar(x, Ty), x) From d3833fc9c3315a7d0ef0df27d6bd0d20e6b97d09 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 11:00:17 -0400 Subject: [PATCH 11/22] skip UpperHessenberg --- src/rulesets/LinearAlgebra/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 9a5a4ce0b..2beb856b6 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -59,7 +59,7 @@ WithSomeZeros{T} = Union{ Diagonal{T}, UpperTriangular{T}, UnitUpperTriangular{T}, - UpperHessenberg{T}, + # UpperHessenberg{T}, # doesn't exist in Julia 1.0 LowerTriangular{T}, UnitLowerTriangular{T}, } @@ -67,7 +67,7 @@ for S in [ :Diagonal, :UpperTriangular, :UnitUpperTriangular, - :UpperHessenberg, + # :UpperHessenberg, :LowerTriangular, :UnitLowerTriangular, ] From 
79ecf7347a5c3266c379b185dd05714b2e2743cf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 11:23:47 -0400 Subject: [PATCH 12/22] version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1c45e5a0f..d44b0ab75 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.63" +version = "0.7.64" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From f1b5ba4d1cc17d4e244bc5d3e2b10caa585abc36 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 11:26:11 -0400 Subject: [PATCH 13/22] move branches into InplaceableThunk --- src/rulesets/LinearAlgebra/norm.jl | 59 ++++++++++++++++++------------ 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index abd422387..be9ca1ac7 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -20,28 +20,36 @@ end function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) y = LinearAlgebra.norm(x, p) function norm_pullback_p(Δy) - ∂x = if isempty(x) || p == 0 - InplaceableThunk( - @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))), - identity, - ) + ∂x = InplaceableThunk( + # out-of-place versions + if isempty(x) || p == 0 + @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))) elseif p == 2 - InplaceableThunk( - @thunk(_norm2_back(x, y, Δy)), - dx -> _norm2_back!(dx, x, y, Δy), - ) + @thunk(_norm2_back(x, y, Δy)) elseif p == 1 - InplaceableThunk( - @thunk(_norm1_back(x, y, Δy)), - dx -> _norm1_back!(dx, x, y, Δy), - ) + @thunk(_norm1_back(x, y, Δy)) elseif p == Inf - _normInf_back(x, y, Δy) + @thunk(_normInf_back(x, y, Δy)) elseif p == -Inf - _normInf_back(x, y, Δy) + @thunk(_normInf_back(x, y, Δy)) else - _normp_back_x(x, p, y, Δy) + 
@thunk(_normp_back_x(x, p, y, Δy)) + end, + # in-place versions + if isempty(x) || p == 0 + identity + elseif p == 2 + dx -> _norm2_back!(dx, x, y, Δy) + elseif p == 1 + dx -> _norm1_back!(dx, x, y, Δy) + elseif p == Inf + dx -> dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved + elseif p == -Inf + dx -> dx .+= _normInf_back(x, y, Δy) + else + dx -> dx .+= _normp_back_x(x, p, y, Δy) end + ) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NO_FIELDS, ∂x, ∂p) end @@ -51,14 +59,19 @@ end function rrule(::typeof(norm), x::AbstractArray{<:Number}) y = LinearAlgebra.norm(x) function norm_pullback_2(Δy) - ∂x = if isempty(x) - zero.(x) .* (zero(y) * zero(real(Δy))) - else - InplaceableThunk( - @thunk(_norm2_back(x, y, Δy)), - dx -> _norm2_back!(dx, x, y, Δy), + ∂x = InplaceableThunk( + if isempty(x) + @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))) + else + @thunk(_norm2_back(x, y, Δy)) + end + , + if isempty(x) + identity + else + dx -> _norm2_back!(dx, x, y, Δy) + end ) - end return (NO_FIELDS, ∂x) end norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) From 3f893e42b396edd96904ac1ea7c85ca77c72b05f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 14:18:14 -0400 Subject: [PATCH 14/22] give up on inplacethunk --- src/rulesets/LinearAlgebra/norm.jl | 72 +++++++++++++++--------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index be9ca1ac7..5444695d2 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -20,36 +20,36 @@ end function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) y = LinearAlgebra.norm(x, p) function norm_pullback_p(Δy) - ∂x = InplaceableThunk( + # ∂x = InplaceableThunk( # out-of-place versions - if isempty(x) || p == 0 - @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))) + ∂x = @thunk(if isempty(x) || p == 0 + zero.(x) .* (zero(y) * 
zero(real(Δy))) elseif p == 2 - @thunk(_norm2_back(x, y, Δy)) + _norm2_back(x, y, Δy) elseif p == 1 - @thunk(_norm1_back(x, y, Δy)) + _norm1_back(x, y, Δy) elseif p == Inf - @thunk(_normInf_back(x, y, Δy)) + _normInf_back(x, y, Δy) elseif p == -Inf - @thunk(_normInf_back(x, y, Δy)) + _normInf_back(x, y, Δy) else - @thunk(_normp_back_x(x, p, y, Δy)) - end, - # in-place versions - if isempty(x) || p == 0 - identity - elseif p == 2 - dx -> _norm2_back!(dx, x, y, Δy) - elseif p == 1 - dx -> _norm1_back!(dx, x, y, Δy) - elseif p == Inf - dx -> dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved - elseif p == -Inf - dx -> dx .+= _normInf_back(x, y, Δy) - else - dx -> dx .+= _normp_back_x(x, p, y, Δy) - end - ) + _normp_back_x(x, p, y, Δy) + end) + # , # in-place versions -- can be fixed when actually useful? + # dx -> if isempty(x) || p == 0 + # dx + # elseif p == 2 + # _norm2_back!(dx, x, y, Δy) + # elseif p == 1 + # _norm1_back!(dx, x, y, Δy) + # elseif p == Inf + # dx .+= _normInf_back(x, y, Δy) # not really in-place! 
could perhaps be improved + # elseif p == -Inf + # dx .+= _normInf_back(x, y, Δy) + # else + # dx .+= _normp_back_x(x, p, y, Δy) + # end + # ) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NO_FIELDS, ∂x, ∂p) end @@ -59,19 +59,19 @@ end function rrule(::typeof(norm), x::AbstractArray{<:Number}) y = LinearAlgebra.norm(x) function norm_pullback_2(Δy) - ∂x = InplaceableThunk( - if isempty(x) - @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))) - else - @thunk(_norm2_back(x, y, Δy)) - end - , - if isempty(x) - identity + # ∂x = InplaceableThunk( + ∂x = @thunk(if isempty(x) + zero.(x) .* (zero(y) * zero(real(Δy))) else - dx -> _norm2_back!(dx, x, y, Δy) - end - ) + _norm2_back(x, y, Δy) + end) + # , + # dx -> if isempty(x) + # dx + # else + # _norm2_back!(dx, x, y, Δy) + # end + # ) return (NO_FIELDS, ∂x) end norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) From 46bf4c598bc92ca9962fb4f7dd9d2edeb520b9a1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 15:05:47 -0400 Subject: [PATCH 15/22] try to find out where we are --- test/rulesets/LinearAlgebra/norm.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 542798e90..b497e2745 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -8,6 +8,7 @@ ), T in (Float64, ComplexF64), sz in [(3,), (3, 3), (3, 2, 1)] +println("starting unexported fnorm=$fnorm, T=$T, sz=$sz") x = randn(T, sz) # finite differences is unstable if maxabs (minabs) values are not well @@ -37,6 +38,7 @@ @test rrule(fnorm, x)[2](Zero())[2] isa Zero end ndims(x) > 1 && @testset "non-strided" begin +println("... non-strided") xp = if x isa Matrix view(x, [1,2,3], 1:3) elseif x isa Array{T,3} @@ -51,6 +53,7 @@ rrule_test(fnorm, ȳ, (xp, x̄)) end T == Float64 && ndims(x) == 1 && @testset "Integer input" begin +println("... 
integer") x = [1,2,3] int_fwd, int_back = rrule(fnorm, x) float_fwd, float_back = rrule(fnorm, float(x)) @@ -61,6 +64,7 @@ @testset "norm(x::Array{$T,$(length(sz))})" for T in (Float64, ComplexF64), sz in [(0,), (3,), (3, 3), (3, 2, 1)] +println("starting exported norm T=$T, sz=$sz") x = randn(T, sz) @@ -84,6 +88,7 @@ @test rrule(norm, x)[2](Zero())[2] isa Zero end ndims(x) > 1 && @testset "non-strided" begin +println("... non-strided'") xp = if x isa Matrix view(x, [1,2,3], 1:3) elseif x isa Array{T,3} @@ -103,6 +108,7 @@ p in (1.0, 2.0, Inf, -Inf, 2.5), T in (Float64, ComplexF64), sz in (fnorm === norm ? [(0,), (3,), (3, 3), (3, 2, 1)] : [(3,), (3, 3), (3, 2, 1)]) +println("starting p-norm p=$p, T=$T, sz=$sz") x = randn(T, sz) # finite differences is unstable if maxabs (minabs) values are not well @@ -145,6 +151,7 @@ @testset "norm($fdual(::Vector{$T}), p)" for T in (Float64, ComplexF64), fdual in (adjoint, transpose) +println("starting $fdual norm T=$T") x = fdual(randn(T, 3)) p = 2.5 @@ -155,6 +162,7 @@ end @testset "norm(x::$T, p)" for T in (Float64, ComplexF64) @testset "p = $p" for p in (-1.0, 2.0, 2.5) +println("starting scalar p-norm tests, p=$p, T=$T") test_frule(norm, randn(T), p) test_rrule(norm, randn(T), p) @@ -162,6 +170,7 @@ @test back(Zero()) == (NO_FIELDS, Zero(), Zero()) end @testset "p = 0" begin +println("starting 0-norm tests, T=$T") p = 0.0 x = randn(T) y = norm(x, p) From 71d4ac9cae7f3793fd1a40f2da4ed1f18fba2471 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 17:32:46 -0400 Subject: [PATCH 16/22] don't count the thunks --- test/rulesets/LinearAlgebra/norm.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index b497e2745..8b3828fd0 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -1,3 +1,5 @@ +@eval ChainRulesTestUtils check_thunking_is_appropriate(_) = 
nothing + @testset "norm functions" begin @testset "$fnorm(x::Array{$T,$(length(sz))})" for fnorm in ( From e400bd523449265b76d52a744d3aad1905f4946c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 18:00:36 -0400 Subject: [PATCH 17/22] rm nested unthunking --- test/rulesets/LinearAlgebra/norm.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 8b3828fd0..f4081b295 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -147,7 +147,7 @@ println("starting p-norm p=$p, T=$T, sz=$sz") int_fwd, int_back = rrule(fnorm, x, p) float_fwd, float_back = rrule(fnorm, float(x), p) @test int_fwd ≈ float_fwd - @test unthunk(unthunk(int_back(1.0)[2])) ≈ unthunk(unthunk(float_back(1.0)[2])) + @test unthunk(int_back(1.0)[2]) ≈ unthunk(float_back(1.0)[2]) end end @testset "norm($fdual(::Vector{$T}), p)" for From 81f85b7faef44dfd9837435eb6fbbf3d004cf6cf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 4 May 2021 20:59:42 -0400 Subject: [PATCH 18/22] make tests pass --- src/rulesets/LinearAlgebra/norm.jl | 58 ++++++++++++++++------------- test/rulesets/LinearAlgebra/norm.jl | 10 +---- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 5444695d2..41792ad4d 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -6,6 +6,7 @@ function frule((_, Δx), ::typeof(norm), x) y = norm(x) return y, _norm2_forward(x, Δx, norm(x)) end + function frule((_, Δx), ::typeof(norm), x::Number, p::Real) y = norm(x, p) ∂y = if iszero(Δx) || iszero(p) @@ -20,9 +21,9 @@ end function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) y = LinearAlgebra.norm(x, p) function norm_pullback_p(Δy) - # ∂x = InplaceableThunk( + ∂x = 
InplaceableThunk( # out-of-place versions - ∂x = @thunk(if isempty(x) || p == 0 + @thunk(if isempty(x) || p == 0 zero.(x) .* (zero(y) * zero(real(Δy))) elseif p == 2 _norm2_back(x, y, Δy) @@ -35,48 +36,50 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) else _normp_back_x(x, p, y, Δy) end) - # , # in-place versions -- can be fixed when actually useful? - # dx -> if isempty(x) || p == 0 - # dx - # elseif p == 2 - # _norm2_back!(dx, x, y, Δy) - # elseif p == 1 - # _norm1_back!(dx, x, y, Δy) - # elseif p == Inf - # dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved - # elseif p == -Inf - # dx .+= _normInf_back(x, y, Δy) - # else - # dx .+= _normp_back_x(x, p, y, Δy) - # end - # ) + , # in-place versions -- can be fixed when actually useful? + dx -> if isempty(x) || p == 0 + dx + elseif p == 2 + _norm2_back!(dx, x, y, Δy) + elseif p == 1 + _norm1_back!(dx, x, y, Δy) + elseif p == Inf + dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved + elseif p == -Inf + dx .+= _normInf_back(x, y, Δy) + else + dx .+= _normp_back_x(x, p, y, Δy) + end + ) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NO_FIELDS, ∂x, ∂p) end norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero()) return y, norm_pullback_p end + function rrule(::typeof(norm), x::AbstractArray{<:Number}) y = LinearAlgebra.norm(x) function norm_pullback_2(Δy) - # ∂x = InplaceableThunk( - ∂x = @thunk(if isempty(x) + ∂x = InplaceableThunk( + @thunk(if isempty(x) zero.(x) .* (zero(y) * zero(real(Δy))) else _norm2_back(x, y, Δy) end) - # , - # dx -> if isempty(x) - # dx - # else - # _norm2_back!(dx, x, y, Δy) - # end - # ) + , + dx -> if isempty(x) + dx + else + _norm2_back!(dx, x, y, Δy) + end + ) return (NO_FIELDS, ∂x) end norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) return y, norm_pullback_2 end + function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::Real) y, inner_pullback = rrule(norm, parent(x), p) function 
norm_pullback(Δy) @@ -87,6 +90,7 @@ function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::R end return y, norm_pullback end + function rrule(::typeof(norm), x::Number, p::Real) y = norm(x, p) function norm_pullback(Δy) @@ -126,6 +130,7 @@ function _normp_back_x(x, p, y, Δy) end return ∂x end + function _normp_back_x(x::WithSomeZeros, p, y, Δy) # Diagonal, UpperTriangular, etc. c = real(Δy) / y ∂x_data = map(parent(x)) do xi @@ -261,6 +266,7 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) return y, normalize_pullback end + function rrule(::typeof(normalize), x::AbstractVector{<:Number}) nrm = LinearAlgebra.norm2(x) Ty = typeof(first(x) / nrm) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index f4081b295..9ea072812 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -1,5 +1,3 @@ -@eval ChainRulesTestUtils check_thunking_is_appropriate(_) = nothing - @testset "norm functions" begin @testset "$fnorm(x::Array{$T,$(length(sz))})" for fnorm in ( @@ -80,8 +78,6 @@ println("starting exported norm T=$T, sz=$sz") @testset "rrule" begin test_rrule(norm, x) x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular) - # we don't check inference on older julia versions. Improvements to - # inference mean on 1.5+ it works, and that is good enough test_rrule(norm, MT(x); check_inferred=VERSION>=v"1.5") end @@ -130,13 +126,9 @@ println("starting p-norm p=$p, T=$T, sz=$sz") kwargs = NamedTuple() end - test_rrule(fnorm, x, p; kwargs...) x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular) - test_rrule(fnorm, MT(x), p; - #Don't check inference on old julia, what matters is that works on new - check_inferred=VERSION>=v"1.5", kwargs... 
- ) + test_rrule(fnorm, MT(x), p; kwargs..., check_inferred=VERSION>=v"1.5") end ȳ = rand_tangent(fnorm(x, p)) From 43a275678d90290432f3bf94b27f63e39000b4b2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 4 May 2021 20:59:54 -0400 Subject: [PATCH 19/22] comments --- test/rulesets/LinearAlgebra/norm.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 9ea072812..0085d8921 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -1,4 +1,8 @@ @testset "norm functions" begin + + # First test the un-exported functions which norm(A,p) calls + # ========================================================== + @testset "$fnorm(x::Array{$T,$(length(sz))})" for fnorm in ( LinearAlgebra.norm1, @@ -61,6 +65,10 @@ println("... integer") @test unthunk(int_back(1.0)[2]) ≈ unthunk(float_back(1.0)[2]) end end + + # Next test norm(x, p=2) -- two methods + # ===================================== + @testset "norm(x::Array{$T,$(length(sz))})" for T in (Float64, ComplexF64), sz in [(0,), (3,), (3, 3), (3, 2, 1)] @@ -142,7 +150,8 @@ println("starting p-norm p=$p, T=$T, sz=$sz") @test unthunk(int_back(1.0)[2]) ≈ unthunk(float_back(1.0)[2]) end end - @testset "norm($fdual(::Vector{$T}), p)" for + # Extra test for norm(adjoint vector, p) + @testset "norm($fdual(::Vector{$T}), 2.5)" for T in (Float64, ComplexF64), fdual in (adjoint, transpose) println("starting $fdual norm T=$T") @@ -154,6 +163,10 @@ println("starting $fdual norm T=$T") ȳ = rand_tangent(norm(x, p)) @test extern(rrule(norm, x, p)[2](ȳ)[2]) isa typeof(x) end + + # Scalar norm(x, p) + # ================= + @testset "norm(x::$T, p)" for T in (Float64, ComplexF64) @testset "p = $p" for p in (-1.0, 2.0, 2.5) println("starting scalar p-norm tests, p=$p, T=$T") @@ -181,6 +194,9 @@ println("starting 0-norm tests, T=$T") end end +# 
normalise(x, p) and normalise(A, p) +# =================================== + @testset "normalize" begin @testset "x::Vector{$T}" for T in (Float64, ComplexF64) x = randn(T, 3) From 0331eda7f71cafa01fd8902d9146e966b7f298d2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 4 May 2021 21:18:11 -0400 Subject: [PATCH 20/22] update notation --- test/rulesets/LinearAlgebra/norm.jl | 57 ++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 0085d8921..07089b991 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -28,8 +28,10 @@ println("starting unexported fnorm=$fnorm, T=$T, sz=$sz") kwargs = NamedTuple() end - fnorm === LinearAlgebra.norm2 && @testset "frule" begin - test_frule(fnorm, x) + if fnorm === LinearAlgebra.norm2 + @testset "frule" begin + test_frule(fnorm, x) + end end @testset "rrule" begin test_rrule(fnorm, x; kwargs...) @@ -49,12 +51,37 @@ println("... non-strided") PermutedDimsArray(x, (1,2,3)) end @test !(xp isa StridedArray) - y = fnorm(x) - # ẋ = rand(T, size(xp)) # rand_tangent(xp) - x̄ = rand(T, size(xp)) # rand_tangent(xp) - ȳ = rand_tangent(y) - # frule_test(fnorm, (xp, ẋ)) - rrule_test(fnorm, ȳ, (xp, x̄)) + # y = fnorm(x) + # # ẋ = rand(T, size(xp)) # rand_tangent(xp) + # x̄ = rand(T, size(xp)) # rand_tangent(xp) + # ȳ = rand_tangent(y) + # # frule_test(fnorm, (xp, ẋ)) + # rrule_test(fnorm, ȳ, (xp, x̄)) # old notation, gives a depwarn +#= +┌ Warning: `rrule_test(f, ȳ, inputs::Tuple{Any, Any}...; kwargs...)` is deprecated, use `test_rrule(f, (x ⊢ dx for (x, dx) = inputs)...; output_tangent = ȳ, kwargs...)` instead. 
+│ caller = macro expansion at norm.jl:57 [inlined] +└ @ Core ~/.julia/dev/ChainRules/test/rulesets/LinearAlgebra/norm.jl:57 +=# + # @show typeof(xp) + # test_rrule(fnorm, xp) # new notation, gives a spectacular failure: +#= +typeof(xp) = SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false} +test_rrule: norm1 at ([0.2972879845354616 -0.01044524463737564 2.2950878238373105; 0.3823959677906078 -0.839026854388764 -2.2670863488005306; -0.5976344767282311 0.31111133849833383 0.5299655761667461],): Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/bDd51/src/testers.jl:168 + Got exception outside of a @test + MethodError: no method matching +(::Composite{SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, NamedTuple{(:parent, :indices, :offset1, :stride1), Tuple{Matrix{Float64}, Composite{Tuple{Vector{Int64}, UnitRange{Int64}}, Tuple{Vector{DoesNotExist}, Composite{UnitRange{Int64}, NamedTuple{(:start, :stop), Tuple{DoesNotExist, DoesNotExist}}}}}, DoesNotExist, DoesNotExist}}}, ::Matrix{Float64}) + Closest candidates are: + +(::Any, ::Any, ::Any, ::Any...) at operators.jl:560 + +(::Composite{P, T} where T, ::Composite{P, T} where T) where P at /Users/me/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:167 + +(::Composite, ::AbstractThunk) at /Users/me/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:161 + ... 
+ Stacktrace: + [1] +(a::Composite{SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, NamedTuple{(:parent, :indices, :offset1, :stride1), Tuple{Matrix{Float64}, Composite{Tuple{Vector{Int64}, UnitRange{Int64}}, Tuple{Vector{DoesNotExist}, Composite{UnitRange{Int64}, NamedTuple{(:start, :stop), Tuple{DoesNotExist, DoesNotExist}}}}}, DoesNotExist, DoesNotExist}}}, b::InplaceableThunk{Thunk{ChainRules.var"#1798#1801"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}, ChainRules.var"#1799#1802"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}) + @ ChainRulesCore ~/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:161 + [2] add!!(x::Composite{SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, NamedTuple{(:parent, :indices, :offset1, :stride1), Tuple{Matrix{Float64}, Composite{Tuple{Vector{Int64}, UnitRange{Int64}}, Tuple{Vector{DoesNotExist}, Composite{UnitRange{Int64}, NamedTuple{(:start, :stop), Tuple{DoesNotExist, DoesNotExist}}}}}, DoesNotExist, DoesNotExist}}}, t::InplaceableThunk{Thunk{ChainRules.var"#1798#1801"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}, ChainRules.var"#1799#1802"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}) + @ ChainRulesCore ~/.julia/packages/ChainRulesCore/1qau5/src/accumulation.jl:23 +=# + test_rrule(fnorm, xp ⊢ rand(T, size(xp))) # ok, this passes! + end T == Float64 && ndims(x) == 1 && @testset "Integer input" begin println("... integer") @@ -101,12 +128,14 @@ println("... 
non-strided'") PermutedDimsArray(x, (1,2,3)) end @test !(xp isa StridedArray) - y = norm(x) - ẋ = rand(T, size(xp)) # rand_tangent(xp) - x̄ = rand(T, size(xp)) # rand_tangent(xp) - ȳ = rand_tangent(y) - frule_test(norm, (xp, ẋ)) - rrule_test(norm, ȳ, (xp, x̄)) + # y = norm(x) + # ẋ = rand(T, size(xp)) # rand_tangent(xp) + # x̄ = rand(T, size(xp)) # rand_tangent(xp) + # ȳ = rand_tangent(y) + # frule_test(norm, (xp, ẋ)) + # rrule_test(norm, ȳ, (xp, x̄)) + test_frule(norm, xp ⊢ rand(T, size(xp))) + test_rrule(norm, xp ⊢ rand(T, size(xp))) # rand_tangent does not work here end end @testset "$fnorm(x::Array{$T,$(length(sz))}, $p) with size $sz" for From 69a92dcddeb08a94c610304eaf6c7656c191dc72 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 4 May 2021 21:49:15 -0400 Subject: [PATCH 21/22] tidy --- test/rulesets/LinearAlgebra/norm.jl | 51 ++--------------------------- 1 file changed, 3 insertions(+), 48 deletions(-) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 07089b991..612d04a8c 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -12,7 +12,6 @@ ), T in (Float64, ComplexF64), sz in [(3,), (3, 3), (3, 2, 1)] -println("starting unexported fnorm=$fnorm, T=$T, sz=$sz") x = randn(T, sz) # finite differences is unstable if maxabs (minabs) values are not well @@ -44,47 +43,15 @@ println("starting unexported fnorm=$fnorm, T=$T, sz=$sz") @test rrule(fnorm, x)[2](Zero())[2] isa Zero end ndims(x) > 1 && @testset "non-strided" begin -println("... 
non-strided") xp = if x isa Matrix view(x, [1,2,3], 1:3) elseif x isa Array{T,3} PermutedDimsArray(x, (1,2,3)) end @test !(xp isa StridedArray) - # y = fnorm(x) - # # ẋ = rand(T, size(xp)) # rand_tangent(xp) - # x̄ = rand(T, size(xp)) # rand_tangent(xp) - # ȳ = rand_tangent(y) - # # frule_test(fnorm, (xp, ẋ)) - # rrule_test(fnorm, ȳ, (xp, x̄)) # old notation, gives a depwarn -#= -┌ Warning: `rrule_test(f, ȳ, inputs::Tuple{Any, Any}...; kwargs...)` is deprecated, use `test_rrule(f, (x ⊢ dx for (x, dx) = inputs)...; output_tangent = ȳ, kwargs...)` instead. -│ caller = macro expansion at norm.jl:57 [inlined] -└ @ Core ~/.julia/dev/ChainRules/test/rulesets/LinearAlgebra/norm.jl:57 -=# - # @show typeof(xp) - # test_rrule(fnorm, xp) # new notation, gives a spectacular failure: -#= -typeof(xp) = SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false} -test_rrule: norm1 at ([0.2972879845354616 -0.01044524463737564 2.2950878238373105; 0.3823959677906078 -0.839026854388764 -2.2670863488005306; -0.5976344767282311 0.31111133849833383 0.5299655761667461],): Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/bDd51/src/testers.jl:168 - Got exception outside of a @test - MethodError: no method matching +(::Composite{SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, NamedTuple{(:parent, :indices, :offset1, :stride1), Tuple{Matrix{Float64}, Composite{Tuple{Vector{Int64}, UnitRange{Int64}}, Tuple{Vector{DoesNotExist}, Composite{UnitRange{Int64}, NamedTuple{(:start, :stop), Tuple{DoesNotExist, DoesNotExist}}}}}, DoesNotExist, DoesNotExist}}}, ::Matrix{Float64}) - Closest candidates are: - +(::Any, ::Any, ::Any, ::Any...) 
at operators.jl:560 - +(::Composite{P, T} where T, ::Composite{P, T} where T) where P at /Users/me/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:167 - +(::Composite, ::AbstractThunk) at /Users/me/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:161 - ... - Stacktrace: - [1] +(a::Composite{SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, NamedTuple{(:parent, :indices, :offset1, :stride1), Tuple{Matrix{Float64}, Composite{Tuple{Vector{Int64}, UnitRange{Int64}}, Tuple{Vector{DoesNotExist}, Composite{UnitRange{Int64}, NamedTuple{(:start, :stop), Tuple{DoesNotExist, DoesNotExist}}}}}, DoesNotExist, DoesNotExist}}}, b::InplaceableThunk{Thunk{ChainRules.var"#1798#1801"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}, ChainRules.var"#1799#1802"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}) - @ ChainRulesCore ~/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:161 - [2] add!!(x::Composite{SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, NamedTuple{(:parent, :indices, :offset1, :stride1), Tuple{Matrix{Float64}, Composite{Tuple{Vector{Int64}, UnitRange{Int64}}, Tuple{Vector{DoesNotExist}, Composite{UnitRange{Int64}, NamedTuple{(:start, :stop), Tuple{DoesNotExist, DoesNotExist}}}}}, DoesNotExist, DoesNotExist}}}, t::InplaceableThunk{Thunk{ChainRules.var"#1798#1801"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}, ChainRules.var"#1799#1802"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}) - @ ChainRulesCore ~/.julia/packages/ChainRulesCore/1qau5/src/accumulation.jl:23 -=# - test_rrule(fnorm, xp ⊢ rand(T, size(xp))) # ok, this passes! 
- + test_rrule(fnorm, xp ⊢ rand(T, size(xp))) end T == Float64 && ndims(x) == 1 && @testset "Integer input" begin -println("... integer") x = [1,2,3] int_fwd, int_back = rrule(fnorm, x) float_fwd, float_back = rrule(fnorm, float(x)) @@ -93,13 +60,12 @@ println("... integer") end end - # Next test norm(x, p=2) -- two methods + # Next test norm(A, p=2) -- two methods # ===================================== @testset "norm(x::Array{$T,$(length(sz))})" for T in (Float64, ComplexF64), sz in [(0,), (3,), (3, 3), (3, 2, 1)] -println("starting exported norm T=$T, sz=$sz") x = randn(T, sz) @@ -121,19 +87,12 @@ println("starting exported norm T=$T, sz=$sz") @test rrule(norm, x)[2](Zero())[2] isa Zero end ndims(x) > 1 && @testset "non-strided" begin -println("... non-strided'") xp = if x isa Matrix view(x, [1,2,3], 1:3) elseif x isa Array{T,3} PermutedDimsArray(x, (1,2,3)) end @test !(xp isa StridedArray) - # y = norm(x) - # ẋ = rand(T, size(xp)) # rand_tangent(xp) - # x̄ = rand(T, size(xp)) # rand_tangent(xp) - # ȳ = rand_tangent(y) - # frule_test(norm, (xp, ẋ)) - # rrule_test(norm, ȳ, (xp, x̄)) test_frule(norm, xp ⊢ rand(T, size(xp))) test_rrule(norm, xp ⊢ rand(T, size(xp))) # rand_tangent does not work here end @@ -143,11 +102,10 @@ println("... non-strided'") p in (1.0, 2.0, Inf, -Inf, 2.5), T in (Float64, ComplexF64), sz in (fnorm === norm ? 
[(0,), (3,), (3, 3), (3, 2, 1)] : [(3,), (3, 3), (3, 2, 1)]) -println("starting p-norm p=$p, T=$T, sz=$sz") x = randn(T, sz) # finite differences is unstable if maxabs (minabs) values are not well - # separated from other values + # separated from other values (same as above) if p == Inf if !isempty(x) x[end] = 1000rand(T) @@ -183,7 +141,6 @@ println("starting p-norm p=$p, T=$T, sz=$sz") @testset "norm($fdual(::Vector{$T}), 2.5)" for T in (Float64, ComplexF64), fdual in (adjoint, transpose) -println("starting $fdual norm T=$T") x = fdual(randn(T, 3)) p = 2.5 @@ -198,7 +155,6 @@ println("starting $fdual norm T=$T") @testset "norm(x::$T, p)" for T in (Float64, ComplexF64) @testset "p = $p" for p in (-1.0, 2.0, 2.5) -println("starting scalar p-norm tests, p=$p, T=$T") test_frule(norm, randn(T), p) test_rrule(norm, randn(T), p) @@ -206,7 +162,6 @@ println("starting scalar p-norm tests, p=$p, T=$T") @test back(Zero()) == (NO_FIELDS, Zero(), Zero()) end @testset "p = 0" begin -println("starting 0-norm tests, T=$T") p = 0.0 x = randn(T) y = norm(x, p) From 87e431344e61e2e16c94e12acd4928b2b6b93e24 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 7 May 2021 15:03:56 -0400 Subject: [PATCH 22/22] integer comment Co-authored-by: Lyndon White --- test/rulesets/LinearAlgebra/norm.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 612d04a8c..bb2d20bfd 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -94,7 +94,7 @@ end @test !(xp isa StridedArray) test_frule(norm, xp ⊢ rand(T, size(xp))) - test_rrule(norm, xp ⊢ rand(T, size(xp))) # rand_tangent does not work here + test_rrule(norm, xp ⊢ rand(T, size(xp))) # rand_tangent does not work here because eltype(xp)==Int end end @testset "$fnorm(x::Array{$T,$(length(sz))}, $p) with size $sz" for