fix test enzyme (#2563)
* fix test enzyme

* fix

* don't test enzyme on nightly

* remove enzyme from test project

* update
CarloLucibello authored Jan 1, 2025
1 parent b205266 commit cb9bb30
Showing 5 changed files with 53 additions and 133 deletions.
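The main change gates every Enzyme test behind a single `FLUX_TEST_ENZYME` switch that defaults to off on Julia nightly, with Enzyme installed on demand instead of being pinned in `test/Project.toml`. A minimal sketch of the gating pattern added in `test/runtests.jl`, lightly condensed from the diff below:

```julia
using Pkg

# Default to running Enzyme tests on Julia < 1.12 and skipping them on nightly;
# the FLUX_TEST_ENZYME environment variable overrides the default.
const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"

if FLUX_TEST_ENZYME
    Pkg.add("Enzyme")      # no longer a hard dependency of test/Project.toml
    using Enzyme: Enzyme
end
```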
2 changes: 0 additions & 2 deletions test/Project.toml
@@ -2,7 +2,6 @@
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -22,7 +21,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Enzyme = "0.13"
FiniteDifferences = "0.12"
GPUArraysCore = "0.1"
GPUCompiler = "0.27"
121 changes: 8 additions & 113 deletions test/ext_enzyme/enzyme.jl
@@ -1,100 +1,8 @@
using Test
using Flux
import Zygote

using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal

using Functors
using FiniteDifferences


function gradient_fd(f, x...)
f = f |> f64
x = [cpu(x) for x in x]
ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x]
ps = [f64(x[1]) for x in ps_and_res]
res = [x[2] for x in ps_and_res]
fdm = FiniteDifferences.central_fdm(5, 1)
gs = FiniteDifferences.grad(fdm, (ps...) -> f((re(p) for (p,re) in zip(ps, res))...), ps...)
return ((re(g) for (re, g) in zip(res, gs))...,)
end

function gradient_ez(f, x...)
args = []
for x in x
if x isa Number
push!(args, Active(x))
else
push!(args, Duplicated(x, make_zero(x)))
end
end
ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
return g
end

function test_grad(g1, g2; broken=false)
fmap_with_path(g1, g2) do kp, x, y
:state ∈ kp && return # ignore RNN and LSTM state
if x isa AbstractArray{<:Number}
# @show kp
@test x ≈ y rtol=1e-2 atol=1e-6 broken=broken
end
return x
end
end

function test_enzyme_grad(loss, model, x)
Flux.trainmode!(model)
l = loss(model, x)
@test loss(model, x) == l # Check loss doesn't change with multiple runs

grads_fd = gradient_fd(loss, model, x) |> cpu
grads_flux = Flux.gradient(loss, model, x) |> cpu
grads_enzyme = gradient_ez(loss, model, x) |> cpu

# test_grad(grads_flux, grads_enzyme)
test_grad(grads_fd, grads_enzyme)
end

@testset "gradient_ez" begin
@testset "number and arrays" begin
f(x, y) = sum(x.^2) + y^3
x = Float32[1, 2, 3]
y = 3f0
g = gradient_ez(f, x, y)
@test g[1] isa Array{Float32}
@test g[2] isa Float32
@test g[1] ≈ 2x
@test g[2] ≈ 3*y^2
end

@testset "struct" begin
struct SimpleDense{W, B, F}
weight::W
bias::B
σ::F
end
SimpleDense(in::Integer, out::Integer; σ=identity) = SimpleDense(randn(Float32, out, in), zeros(Float32, out), σ)
(m::SimpleDense)(x) = m.σ.(m.weight * x .+ m.bias)

model = SimpleDense(2, 4)
x = randn(Float32, 2)
loss(model, x) = sum(model(x))

g = gradient_ez(loss, model, x)
@test g[1] isa SimpleDense
@test g[2] isa Array{Float32}
@test g[1].weight isa Array{Float32}
@test g[1].bias isa Array{Float32}
@test g[1].weight ≈ ones(Float32, 4, 1) .* x'
@test g[1].bias ≈ ones(Float32, 4)
end
end
using Enzyme: Enzyme, Duplicated, Const, Active

@testset "Models" begin
function loss(model, x)
sum(model(x))
mean(model(x))
end

models_xs = [
@@ -117,41 +25,28 @@ end
for (model, x, name) in models_xs
@testset "Enzyme grad check $name" begin
println("testing $name with Enzyme")
test_enzyme_grad(loss, model, x)
test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
end
end
end

@testset "Recurrence Tests" begin
@testset "Recurrent Layers" begin
function loss(model, x)
for i in 1:3
x = model(x)
end
return sum(x)
end

struct LSTMChain
rnn1
rnn2
end
function (m::LSTMChain)(x)
st = m.rnn1(x)
st = m.rnn2(st[1])
return st[1]
mean(model(x))
end

models_xs = [
# (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
# (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
# (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
# (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
# (LSTMChain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "LSTMChain(LSTM, LSTM)"),
# (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
]

for (model, x, name) in models_xs
@testset "check grad $name" begin
println("testing $name")
test_enzyme_grad(loss, model, x)
test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
end
end
end
@@ -219,7 +114,7 @@ end
z = _duplicated(zeros32(3))
@test_broken Flux.gradient(sum ∘ LayerNorm(3), z)[1] ≈ [0.0, 0.0, 0.0] # Constant memory is stored (or returned) to a differentiable variable
@test Flux.gradient(|>, z, _duplicated(sum ∘ LayerNorm(3)))[1] ≈ [0.0, 0.0, 0.0]
@test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing
@test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing broken=VERSION >= v"1.11"

@test_broken Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0] # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
@test_broken Flux.withgradient(|>, z, _duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0]
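The hand-rolled `gradient_ez` and `test_enzyme_grad` helpers removed above are superseded by the shared `enzyme_withgradient` utility added to `test/test_utils.jl` further down. Both follow the same Enzyme calling convention: wrap differentiable arguments in `Duplicated` with a zeroed shadow (numbers would use `Active`), then run reverse-mode autodiff together with the primal. A self-contained sketch of that pattern; the model, input, and loss here are illustrative, not taken from the test suite:

```julia
using Flux
using Enzyme

model = Dense(3 => 2)
x = randn(Float32, 3, 4)
loss(m, x) = sum(m(x))

# Shadows accumulate the gradients of the loss.
dmodel = Enzyme.Duplicated(model, Enzyme.make_zero(model))
dx = Enzyme.Duplicated(x, Enzyme.make_zero(x))

ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
ret = Enzyme.autodiff(ad, Enzyme.Const(loss), Enzyme.Active, dmodel, dx)

ret[2]        # primal value of the loss
dmodel.dval   # gradient with respect to the model
dx.dval       # gradient with respect to the input
```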
18 changes: 12 additions & 6 deletions test/runtests.jl
@@ -3,6 +3,7 @@ using Flux: OneHotArray, OneHotMatrix, OneHotVector
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
import Optimisers

using Zygote
const gradient = Flux.gradient # both Flux & Zygote export this on 0.15
@@ -21,18 +22,24 @@ using Functors: fmapstructure_with_path
# ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
# ENV["FLUX_TEST_ENZYME"] = "false"

const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"
if FLUX_TEST_ENZYME
Pkg.add("Enzyme")
using Enzyme: Enzyme
end

include("test_utils.jl") # for test_gradients

Random.seed!(0)

include("testsuite/normalization.jl")

function flux_testsuite(dev)
@testset "Flux Test Suite" begin
@testset "Normalization" begin
normalization_testsuite(dev)
end
@testset "Flux Test Suite" begin
@testset "Normalization" begin
normalization_testsuite(dev)
end
end
end

@testset verbose=true "Flux.jl" begin
@@ -157,9 +164,8 @@ end
@info "Skipping Distributed tests, set FLUX_TEST_DISTRIBUTED_MPI or FLUX_TEST_DISTRIBUTED_NCCL=true to run them."
end

if get(ENV, "FLUX_TEST_ENZYME", "true") == "true"
if FLUX_TEST_ENZYME
@testset "Enzyme" begin
import Enzyme
include("ext_enzyme/enzyme.jl")
end
else
34 changes: 31 additions & 3 deletions test/test_utils.jl
@@ -19,6 +19,22 @@ function finitediff_withgradient(f, x...)
return y, FiniteDifferences.grad(fdm, f, x...)
end

function enzyme_withgradient(f, x...)
args = []
for x in x
if x isa Number
push!(args, Enzyme.Active(x))
else
push!(args, Enzyme.Duplicated(x, Enzyme.make_zero(x)))
end
end
ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
return ret[2], g
end


function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
fmapstructure_with_path(a, b) do kp, x, y
if x isa AbstractArray
@@ -37,12 +53,12 @@ function test_gradients(
test_grad_f = true,
test_grad_x = true,
compare_finite_diff = true,
compare_enzyme = false,
loss = (f, xs...) -> mean(f(xs...)),
)

if !test_gpu && !compare_finite_diff
error("You should either compare finite diff vs CPU AD \
or CPU AD vs GPU AD.")
if !test_gpu && !compare_finite_diff && !compare_enzyme
error("You should either compare numerical gradients methods or CPU vs GPU.")
end

## Let's make sure first that the forward pass works.
@@ -70,6 +86,12 @@
check_equal_leaves(g, g_fd; rtol, atol)
end

if compare_enzyme
y_ez, g_ez = enzyme_withgradient((xs...) -> loss(f, xs...), xs...)
@test y ≈ y_ez rtol=rtol atol=atol
check_equal_leaves(g, g_ez; rtol, atol)
end

if test_gpu
# Zygote gradient with respect to input on GPU.
y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, xs...), xs_gpu...)
@@ -93,6 +115,12 @@
check_equal_leaves(g, g_fd; rtol, atol)
end

if compare_enzyme
y_ez, g_ez = enzyme_withgradient(f -> loss(f, xs...), f)
@test y ≈ y_ez rtol=rtol atol=atol
check_equal_leaves(g, g_ez; rtol, atol)
end

if test_gpu
# Zygote gradient with respect to f on GPU.
y_gpu, g_gpu = Zygote.withgradient(f -> loss(f, xs_gpu...), f_gpu)
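With the new `compare_enzyme` keyword, a layer test can check Zygote's gradient against Enzyme's instead of against finite differences. A sketch of how the updated suite invokes the helper; the layer and input are illustrative:

```julia
using Flux
using Statistics: mean

model = Chain(Dense(3 => 4, relu), Dense(4 => 2))
x = randn(Float32, 3, 5)

# Zygote vs Enzyme; finite differences are skipped for this comparison.
test_gradients(model, x;
               loss = (m, x) -> mean(m(x)),
               compare_finite_diff = false,
               compare_enzyme = true)
```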
11 changes: 2 additions & 9 deletions test/train.jl
@@ -1,18 +1,11 @@
using Flux
# using Flux.Train
import Optimisers

using Test
using Random
import Enzyme

function train_enzyme!(fn, model, args...; kwargs...)
Flux.train!(fn, Enzyme.Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
end

for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))

if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false")
if name == "Enzyme" && !FLUX_TEST_ENZYME
continue
end

@@ -50,7 +43,7 @@ end
for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
# TODO reinstate Enzyme
name == "Enzyme" && continue
# if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false")
# if name == "Enzyme" && !FLUX_TEST_ENZYME
# continue
# end

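The `train_enzyme!` wrapper kept in `test/train.jl` exercises the end-user pattern: passing a model wrapped in `Enzyme.Duplicated` to `Flux.train!` switches the AD backend from Zygote to Enzyme. A hedged, self-contained sketch of that usage; the model, data, and optimiser settings are illustrative:

```julia
using Flux
using Enzyme
import Optimisers

model = Dense(2 => 1)
data = [(randn(Float32, 2, 8), randn(Float32, 1, 8)) for _ in 1:10]
opt_state = Flux.setup(Optimisers.Adam(1e-3), model)

loss(m, x, y) = Flux.mse(m(x), y)

# The Duplicated wrapper carries a zeroed shadow model for Enzyme's gradients,
# mirroring what train_enzyme! does in the test above.
dup = Enzyme.Duplicated(model, Enzyme.make_zero(model))
Flux.train!(loss, dup, data, opt_state)
```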
