From c48875c9d81dfc620102e1f3fbff386796054533 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 1 Sep 2022 00:08:13 -0400 Subject: [PATCH] tests --- src/runtime.jl | 1 + test/runtests.jl | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/runtime.jl b/src/runtime.jl index e3c1edec..48776239 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -27,3 +27,4 @@ accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing _tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z) _tangent(::Type, ::NamedTuple{()}) = NoTangent() +_tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent() diff --git a/test/runtests.jl b/test/runtests.jl index 38088d0f..cd9387c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -257,13 +257,21 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] @test tup_adj[2] isa Transpose @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal + + @test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure end @testset "broadcast, 2nd order" begin + @test gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] + @test gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2] + @test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1] + @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # MethodError: no method matching copy(::Nothing) - @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3.0])[1] ≈ exp.(1:3) + @test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] ≈ [0,0,0] @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # ERROR: (1, current_logger_for_env(std_level::Base.CoreLogging.LogLevel, group, _module) @ Base.CoreLogging logging.jl:500, :($(Expr(:meta, :noinline)))) @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] + + @test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,) end # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)