diff --git a/.gitignore b/.gitignore index 45f0c4f21..a5b907cfc 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ Manifest.toml docs/build docs/site +.idea/* \ No newline at end of file diff --git a/Project.toml b/Project.toml index 206a3143e..2006cd512 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.3.3" +version = "0.3.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,12 +11,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "0.6.1" +ChainRulesTestUtils = "0.1.3" FiniteDifferences = "^0.7" Reexport = "0.2" Requires = "0.5.2, 1" julia = "^1.0" [extras] +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -24,4 +26,4 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"] diff --git a/test/runtests.jl b/test/runtests.jl index b3683de7d..81539c0b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ using Base.Broadcast: broadcastable using ChainRules using ChainRulesCore +using ChainRulesTestUtils +using ChainRulesTestUtils: _fdm using FiniteDifferences using LinearAlgebra using LinearAlgebra.BLAS @@ -11,8 +13,6 @@ using Test Random.seed!(1) # Set seed that all testsets should reset to. -include("test_util.jl") - println("Testing ChainRules.jl") @testset "ChainRules" begin include("helper_functions.jl") diff --git a/test/test_util.jl b/test/test_util.jl deleted file mode 100644 index ba5b938dd..000000000 --- a/test/test_util.jl +++ /dev/null @@ -1,184 +0,0 @@ -using FiniteDifferences, Test -using FiniteDifferences: jvp, j′vp -using ChainRules -using ChainRulesCore: AbstractDifferential - -const _fdm = central_fdm(5, 1) - -# Useful for LinearAlgebra tests -function generate_well_conditioned_matrix(rng, N) - A = randn(rng, N, N) - return A * A' + I -end - -""" - test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) - -Given a function `f` with scalar input an scalar output, perform finite differencing checks, -at input point `x` to confirm that there are correct ChainRules provided. - -# Arguments -- `f`: Function for which the `frule` and `rrule` should be tested. -- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). - -All keyword arguments except for `fdm` is passed to `isapprox`. -""" -function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) - ensure_not_running_on_functor(f, "test_scalar") - - r_res = rrule(f, x) - f_res = frule(f, x, Zero(), 1) - @test r_res !== f_res !== nothing # Check the rule was defined - r_fx, prop_rule = r_res - f_fx, f_∂x = f_res - @testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in ( - (rrule, r_fx, prop_rule(1)), - (frule, f_fx, f_∂x) - ) - @test fx == f(x) # Check we still get the normal value, right - - if rule == rrule - ∂self, ∂x = ∂x - @test ∂self === NO_FIELDS - end - @test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...) - end -end - -function ensure_not_running_on_functor(f, name) - # if x itself is a Type, then it is a constructor, thus not a functor. - # This also catchs UnionAll constructors which have a `:var` and `:body` fields - f isa Type && return - - if fieldcount(typeof(f)) > 0 - throw(ArgumentError( - "$name cannot be used on closures/functors (such as $f)" - )) - end -end - -""" - frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) - -# Arguments -- `f`: Function for which the `frule` should be tested. -- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). -- `ẋ`: differential w.r.t. `x` (should generally be set randomly). - -All keyword arguments except for `fdm` are passed to `isapprox`. -""" -function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) - return frule_test(f, ((x, ẋ),); rtol=rtol, atol=atol, fdm=fdm, kwargs...) -end - -function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) - ensure_not_running_on_functor(f, "frule_test") - xs, ẋs = collect(zip(xẋs...)) - dself = Zero() - Ω, dΩ_ad = ChainRules.frule(f, xs..., dself, ẋs...) - @test f(xs...) == Ω - - # Correctness testing via finite differencing. - dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs)) - @test isapprox( - collect(extern.(dΩ_ad)), # Use collect so can use vector equality - collect(dΩ_fd); - rtol=rtol, - atol=atol, - kwargs... - ) -end - - -""" - rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) - -# Arguments -- `f`: Function to which rule should be applied. -- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly). - Should be same structure as `f(x)` (so if multiple returns should be a tuple) -- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). -- `x̄`: currently accumulated adjoint (should generally be set randomly). - -All keyword arguments except for `fdm` are passed to `isapprox`. -""" -function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) - ensure_not_running_on_functor(f, "rrule_test") - - # Check correctness of evaluation. - fx, pullback = ChainRules.rrule(f, x) - @test collect(fx) ≈ collect(f(x)) # use collect so can do vector equality - (∂self, x̄_ad) = if fx isa Tuple - # If the function returned multiple values, - # then it must have multiple seeds for propagating backwards - pullback(ȳ...) - else - pullback(ȳ) - end - - @test ∂self === NO_FIELDS # No internal fields - # Correctness testing via finite differencing. - x̄_fd = j′vp(fdm, f, ȳ, x) - @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) -end - -function _make_fdm_call(fdm, f, ȳ, xs, ignores) - sig = Expr(:tuple) - call = Expr(:call, f) - newxs = Any[] - arginds = Int[] - i = 1 - for (x, ignore) in zip(xs, ignores) - if ignore - push!(call.args, x) - else - push!(call.args, Symbol(:x, i)) - push!(sig.args, Symbol(:x, i)) - push!(newxs, x) - push!(arginds, i) - end - i += 1 - end - fdexpr = :(j′vp($fdm, $sig -> $call, $ȳ, $(newxs...))) - fd = eval(fdexpr) - fd isa Tuple || (fd = (fd,)) - args = Any[nothing for _ in 1:length(xs)] - for (dx, ind) in zip(fd, arginds) - args[ind] = dx - end - return (args...,) -end - -function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) - ensure_not_running_on_functor(f, "rrule_test") - - # Check correctness of evaluation. - xs, x̄s = collect(zip(xx̄s...)) - y, pullback = rrule(f, xs...) - @test f(xs...) == y - - @assert !(isa(ȳ, Thunk)) - ∂s = pullback(ȳ) - ∂self = ∂s[1] - x̄s_ad = ∂s[2:end] - @test ∂self === NO_FIELDS - - # Correctness testing via finite differencing. - x̄s_fd = _make_fdm_call(fdm, f, ȳ, xs, x̄s .== nothing) - for (x̄_ad, x̄_fd) in zip(x̄s_ad, x̄s_fd) - if x̄_fd === nothing - # The way we've structured the above, this tests that the rule is a DoesNotExistRule - @test x̄_ad isa DoesNotExist - else - @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) - end - end -end - -function Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) - error("Tried to differentiate w.r.t. a `DoesNotExist`") -end - -function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) - return isapprox(extern(d_ad), d_fd; kwargs...) -end