From cb9bb307b30b9111617e5b28ea1c5398ef3693f3 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 1 Jan 2025 01:55:30 +0100 Subject: [PATCH] fix test enzyme (#2563) * fix test enzyme * fix * don't test enzyme on nightly * remove enzyme from test project * update --- test/Project.toml | 2 - test/ext_enzyme/enzyme.jl | 121 +++----------------------------------- test/runtests.jl | 18 ++++-- test/test_utils.jl | 34 ++++++++++- test/train.jl | 11 +--- 5 files changed, 53 insertions(+), 133 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index eb2fe438d0..f4161e8af3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,7 +2,6 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -22,7 +21,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Enzyme = "0.13" FiniteDifferences = "0.12" GPUArraysCore = "0.1" GPUCompiler = "0.27" diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 5d91770363..7b772fcca4 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -1,100 +1,8 @@ -using Test -using Flux -import Zygote - -using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal - -using Functors -using FiniteDifferences - - -function gradient_fd(f, x...) - f = f |> f64 - x = [cpu(x) for x in x] - ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x] - ps = [f64(x[1]) for x in ps_and_res] - res = [x[2] for x in ps_and_res] - fdm = FiniteDifferences.central_fdm(5, 1) - gs = FiniteDifferences.grad(fdm, (ps...) -> f((re(p) for (p,re) in zip(ps, res))...), ps...) 
- return ((re(g) for (re, g) in zip(res, gs))...,) -end - -function gradient_ez(f, x...) - args = [] - for x in x - if x isa Number - push!(args, Active(x)) - else - push!(args, Duplicated(x, make_zero(x))) - end - end - ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...) - g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x)) - return g -end - -function test_grad(g1, g2; broken=false) - fmap_with_path(g1, g2) do kp, x, y - :state ∈ kp && return # ignore RNN and LSTM state - if x isa AbstractArray{<:Number} - # @show kp - @test x ≈ y rtol=1e-2 atol=1e-6 broken=broken - end - return x - end -end - -function test_enzyme_grad(loss, model, x) - Flux.trainmode!(model) - l = loss(model, x) - @test loss(model, x) == l # Check loss doesn't change with multiple runs - - grads_fd = gradient_fd(loss, model, x) |> cpu - grads_flux = Flux.gradient(loss, model, x) |> cpu - grads_enzyme = gradient_ez(loss, model, x) |> cpu - - # test_grad(grads_flux, grads_enzyme) - test_grad(grads_fd, grads_enzyme) -end - -@testset "gradient_ez" begin - @testset "number and arrays" begin - f(x, y) = sum(x.^2) + y^3 - x = Float32[1, 2, 3] - y = 3f0 - g = gradient_ez(f, x, y) - @test g[1] isa Array{Float32} - @test g[2] isa Float32 - @test g[1] ≈ 2x - @test g[2] ≈ 3*y^2 - end - - @testset "struct" begin - struct SimpleDense{W, B, F} - weight::W - bias::B - σ::F - end - SimpleDense(in::Integer, out::Integer; σ=identity) = SimpleDense(randn(Float32, out, in), zeros(Float32, out), σ) - (m::SimpleDense)(x) = m.σ.(m.weight * x .+ m.bias) - - model = SimpleDense(2, 4) - x = randn(Float32, 2) - loss(model, x) = sum(model(x)) - - g = gradient_ez(loss, model, x) - @test g[1] isa SimpleDense - @test g[2] isa Array{Float32} - @test g[1].weight isa Array{Float32} - @test g[1].bias isa Array{Float32} - @test g[1].weight ≈ ones(Float32, 4, 1) .* x' - @test g[1].bias ≈ ones(Float32, 4) - end -end +using Enzyme: Enzyme, Duplicated, Const, Active @testset "Models" begin function 
loss(model, x) - sum(model(x)) + mean(model(x)) end models_xs = [ @@ -117,27 +25,14 @@ end for (model, x, name) in models_xs @testset "Enzyme grad check $name" begin println("testing $name with Enzyme") - test_enzyme_grad(loss, model, x) + test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true) end end end -@testset "Recurrence Tests" begin +@testset "Recurrent Layers" begin function loss(model, x) - for i in 1:3 - x = model(x) - end - return sum(x) - end - - struct LSTMChain - rnn1 - rnn2 - end - function (m::LSTMChain)(x) - st = m.rnn1(x) - st = m.rnn2(st[1]) - return st[1] + mean(model(x)) end models_xs = [ @@ -145,13 +40,13 @@ end # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"), # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), - # (LSTMChain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "LSTMChain(LSTM, LSTM)"), + # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), ] for (model, x, name) in models_xs @testset "check grad $name" begin println("testing $name") - test_enzyme_grad(loss, model, x) + test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true) end end end @@ -219,7 +114,7 @@ end z = _duplicated(zeros32(3)) @test_broken Flux.gradient(sum ∘ LayerNorm(3), z)[1] ≈ [0.0, 0.0, 0.0] # Constant memory is stored (or returned) to a differentiable variable @test Flux.gradient(|>, z, _duplicated(sum ∘ LayerNorm(3)))[1] ≈ [0.0, 0.0, 0.0] - @test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing + @test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing broken=VERSION >= v"1.11" @test_broken Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0] # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any} @test_broken Flux.withgradient(|>, z, _duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0] diff --git a/test/runtests.jl 
b/test/runtests.jl index d3476cc4d1..205523ec4d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Flux: OneHotArray, OneHotMatrix, OneHotVector using Test using Random, Statistics, LinearAlgebra using IterTools: ncycle +import Optimisers using Zygote const gradient = Flux.gradient # both Flux & Zygote export this on 0.15 @@ -21,6 +22,12 @@ using Functors: fmapstructure_with_path # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true" # ENV["FLUX_TEST_ENZYME"] = "false" +const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true" +if FLUX_TEST_ENZYME + Pkg.add("Enzyme") + using Enzyme: Enzyme +end + include("test_utils.jl") # for test_gradients Random.seed!(0) @@ -28,11 +35,11 @@ Random.seed!(0) include("testsuite/normalization.jl") function flux_testsuite(dev) - @testset "Flux Test Suite" begin - @testset "Normalization" begin - normalization_testsuite(dev) - end + @testset "Flux Test Suite" begin + @testset "Normalization" begin + normalization_testsuite(dev) end + end end @testset verbose=true "Flux.jl" begin @@ -157,9 +164,8 @@ end @info "Skipping Distributed tests, set FLUX_TEST_DISTRIBUTED_MPI or FLUX_TEST_DISTRIBUTED_NCCL=true to run them." end - if get(ENV, "FLUX_TEST_ENZYME", "true") == "true" + if FLUX_TEST_ENZYME @testset "Enzyme" begin - import Enzyme include("ext_enzyme/enzyme.jl") end else diff --git a/test/test_utils.jl b/test/test_utils.jl index c736943f1c..6a548173a0 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -19,6 +19,22 @@ function finitediff_withgradient(f, x...) return y, FiniteDifferences.grad(fdm, f, x...) end +function enzyme_withgradient(f, x...) + args = [] + for x in x + if x isa Number + push!(args, Enzyme.Active(x)) + else + push!(args, Enzyme.Duplicated(x, Enzyme.make_zero(x))) + end + end + ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal) + ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...) + g = ntuple(i -> x[i] isa Number ? 
ret[1][i] : args[i].dval, length(x)) + return ret[2], g +end + + function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4) fmapstructure_with_path(a, b) do kp, x, y if x isa AbstractArray @@ -37,12 +53,12 @@ function test_gradients( test_grad_f = true, test_grad_x = true, compare_finite_diff = true, + compare_enzyme = false, loss = (f, xs...) -> mean(f(xs...)), ) - if !test_gpu && !compare_finite_diff - error("You should either compare finite diff vs CPU AD \ - or CPU AD vs GPU AD.") + if !test_gpu && !compare_finite_diff && !compare_enzyme + error("You should either compare numerical gradients methods or CPU vs GPU.") end ## Let's make sure first that the forward pass works. @@ -70,6 +86,12 @@ function test_gradients( check_equal_leaves(g, g_fd; rtol, atol) end + if compare_enzyme + y_ez, g_ez = enzyme_withgradient((xs...) -> loss(f, xs...), xs...) + @test y ≈ y_ez rtol=rtol atol=atol + check_equal_leaves(g, g_ez; rtol, atol) + end + if test_gpu # Zygote gradient with respect to input on GPU. y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, xs...), xs_gpu...) @@ -93,6 +115,12 @@ function test_gradients( check_equal_leaves(g, g_fd; rtol, atol) end + if compare_enzyme + y_ez, g_ez = enzyme_withgradient(f -> loss(f, xs...), f) + @test y ≈ y_ez rtol=rtol atol=atol + check_equal_leaves(g, g_ez; rtol, atol) + end + if test_gpu # Zygote gradient with respect to f on GPU. y_gpu, g_gpu = Zygote.withgradient(f -> loss(f, xs_gpu...), f_gpu) diff --git a/test/train.jl b/test/train.jl index 4f75d247ab..45f4f84319 100644 --- a/test/train.jl +++ b/test/train.jl @@ -1,10 +1,3 @@ -using Flux -# using Flux.Train -import Optimisers - -using Test -using Random -import Enzyme function train_enzyme!(fn, model, args...; kwargs...) Flux.train!(fn, Enzyme.Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...) 
@@ -12,7 +5,7 @@ end for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) - if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false") + if name == "Enzyme" && !FLUX_TEST_ENZYME #= skip the Enzyme variant when Enzyme testing is disabled (Enzyme is not loaded then) =# continue end @@ -50,7 +43,7 @@ end for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) # TODO reinstate Enzyme name == "Enzyme" && continue - # if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false") + # if name == "Enzyme" && !FLUX_TEST_ENZYME # continue # end