cleanup Reactant and Enzyme tests #2578

Merged: 6 commits, Jan 6, 2025
2 changes: 0 additions & 2 deletions Project.toml
@@ -18,7 +18,6 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -61,7 +60,6 @@ OneHotArrays = "0.2.4"
Optimisers = "0.4.1"
Preferences = "1"
ProgressLogging = "0.1"
Reactant = "0.2.16"
Reexport = "1.0"
Setfield = "1.1"
SpecialFunctions = "2.1.2"
1 change: 0 additions & 1 deletion test/Project.toml
@@ -14,7 +14,6 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
22 changes: 11 additions & 11 deletions test/ext_enzyme/enzyme.jl
@@ -12,20 +12,20 @@ using Enzyme: Enzyme, Duplicated, Const, Active
(Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
(Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
(Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
# (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
(Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
(Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
# (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), # Passes on 1.10, fails on 1.11 with MethodError: no method matching function_attributes(::LLVM.UserOperandSet)
(Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
(first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
# (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # AssertionError: Base.isconcretetype(typ)
# (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # AssertionError: Base.isconcretetype(typ)
(BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
(first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
]

for (model, x, name) in models_xs
@testset "Enzyme grad check $name" begin
println("testing $name with Enzyme")
test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
end
end
end
@@ -36,17 +36,17 @@ end
end

models_xs = [
# (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
# (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
# (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
# (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
# (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
(RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
(LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
(GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
(Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
(Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
]

for (model, x, name) in models_xs
@testset "check grad $name" begin
println("testing $name")
test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
end
end
end
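For context, the `Duplicated`, `Const`, and `Active` names imported at the top of this file are Enzyme's activity annotations. Below is a minimal sketch, assuming Flux and Enzyme are loaded, of the kind of reverse-mode call the `test_enzyme=true` path ultimately exercises; the layer, input, and loss are illustrative and not taken from the test matrix above.

```julia
using Flux, Enzyme, Statistics

model = Dense(3 => 2, tanh)
x = randn(Float32, 3, 8)

loss(m, x) = mean(abs2, m(x))

# Shadow structure that Enzyme fills with d(loss)/d(model).
dmodel = Enzyme.make_zero(model)

# Active marks the scalar return as differentiated; Const marks x as data only.
Enzyme.autodiff(Enzyme.Reverse, loss, Active, Duplicated(model, dmodel), Const(x))

# dmodel.weight and dmodel.bias now hold the accumulated gradients.
```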
13 changes: 13 additions & 0 deletions test/ext_reactant/test_utils_reactant.jl
@@ -0,0 +1,13 @@
# These helpers are used only in test_utils.jl but cannot live there,
# because Reactant is only optionally loaded and the macros fail when it is not loaded.

function reactant_withgradient(f, x...)
    y, g = Reactant.@jit enzyme_withgradient(f, x...)
    return y, g
end

function reactant_loss(loss, x...)
    l = Reactant.@jit loss(x...)
    @test l isa Reactant.ConcreteRNumber
    return l
end
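A rough sketch of how these two wrappers are intended to be called from test_utils.jl; the `reactant_dev` device and the `enzyme_withgradient` helper are defined elsewhere in the test suite, and the model, input, and loss here are illustrative only.

```julia
# Only reached when FLUX_TEST_REACTANT is true, so Reactant is guaranteed to be loaded.
f_re = Dense(2 => 3) |> reactant_dev
x_re = randn(Float32, 2, 4) |> reactant_dev
loss(f, x) = sum(abs2, f(x))

l_re = reactant_loss(loss, f_re, x_re)                        # compiled forward pass
y_re, g_re = reactant_withgradient(x -> loss(f_re, x), x_re)  # compiled Enzyme gradient
```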
22 changes: 14 additions & 8 deletions test/runtests.jl
@@ -13,8 +13,6 @@ using Pkg
using FiniteDifferences: FiniteDifferences
using Functors: fmapstructure_with_path

using Reactant

## Uncomment below to change the default test settings
# ENV["FLUX_TEST_AMDGPU"] = "true"
# ENV["FLUX_TEST_CUDA"] = "true"
@@ -23,20 +21,20 @@ using Reactant
# ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
# ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
# ENV["FLUX_TEST_ENZYME"] = "false"
# ENV["FLUX_TEST_REACTANT"] = "false"

const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"
const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT", VERSION < v"1.12-" && !Sys.iswindows() ? "true" : "false") == "true"

# Reactant will automatically select a GPU or TPU backend if one is available;
# otherwise it will fall back to the CPU.
const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT",
    VERSION < v"1.12-" && !Sys.iswindows() ? "true" : "false") == "true"

if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT
    Pkg.add("Enzyme")
    using Enzyme: Enzyme
end

if FLUX_TEST_REACTANT
    Pkg.add("Reactant")
    using Reactant: Reactant
end

include("test_utils.jl") # for test_gradients

Random.seed!(0)
Expand Down Expand Up @@ -182,7 +180,15 @@ end
end

if FLUX_TEST_REACTANT
    ## This Pkg.add has to be done after Pkg.add("CUDA"), otherwise CUDA.jl
    ## will not be functional and will complain with:
    # ┌ Error: CUDA.jl could not find an appropriate CUDA runtime to use.
    # │
    # │ CUDA.jl's JLLs were precompiled without an NVIDIA driver present.
    Pkg.add("Reactant")
    using Reactant: Reactant
    @testset "Reactant" begin
        include("ext_reactant/test_utils_reactant.jl")
        include("ext_reactant/reactant.jl")
    end
else
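An illustrative way to drive these flags when running the suite locally; the flag names come from the file above, and setting them in the parent Julia session is enough because environment variables are inherited by the test subprocess.

```julia
# Hypothetical local run: skip Reactant and force the Enzyme tests on,
# overriding the Julia-version / OS defaults computed in runtests.jl.
ENV["FLUX_TEST_REACTANT"] = "false"
ENV["FLUX_TEST_ENZYME"] = "true"

using Pkg
Pkg.test("Flux")
```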
22 changes: 14 additions & 8 deletions test/test_utils.jl
@@ -37,6 +37,7 @@ end

function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
fmapstructure_with_path(a, b) do kp, x, y
# @show kp
if x isa AbstractArray
@test x ≈ y rtol=rtol atol=atol
elseif x isa Number
@@ -45,23 +46,29 @@ function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
end
end

# By default, this computes the gradients on CPU using the default AD (Zygote)
# and compares them with finite differences.
# By changing the keyword arguments, you can instead treat the CPU Zygote gradients
# as the ground truth and test other scenarios (GPU, Enzyme, Reactant).
function test_gradients(
f,
xs...;
rtol=1e-4, atol=1e-4,
test_gpu = false,
test_reactant = false,
test_enzyme = false,
test_grad_f = true,
test_grad_x = true,
compare_finite_diff = true,
compare_enzyme = false,
loss = (f, xs...) -> mean(f(xs...)),
)

if !test_gpu && !compare_finite_diff && !compare_enzyme && !test_reactant
if !test_gpu && !compare_finite_diff && !test_enzyme && !test_reactant
error("You should either compare numerical gradients methods or CPU vs GPU.")
end

Flux.trainmode!(f) # for layers like BatchNorm

## Let's make sure first that the forward pass works.
l = loss(f, xs...)
@test l isa Number
@@ -79,8 +86,7 @@ function test_gradients(
cpu_dev = cpu_device()
xs_re = xs |> reactant_dev
f_re = f |> reactant_dev
l_re = Reactant.@jit loss(f_re, xs_re...)
@test l_re isa Reactant.ConcreteRNumber
l_re = reactant_loss(loss, f_re, xs_re...)
@test l ≈ l_re rtol=rtol atol=atol
end

@@ -97,7 +103,7 @@ function test_gradients(
check_equal_leaves(g, g_fd; rtol, atol)
end

if compare_enzyme
if test_enzyme
y_ez, g_ez = enzyme_withgradient((xs...) -> loss(f, xs...), xs...)
@test y ≈ y_ez rtol=rtol atol=atol
check_equal_leaves(g, g_ez; rtol, atol)
@@ -113,7 +119,7 @@ function test_gradients(

if test_reactant
# Enzyme gradient with respect to input on Reactant.
y_re, g_re = Reactant.@jit enzyme_withgradient((xs...) -> loss(f_re, xs...), xs_re...)
y_re, g_re = reactant_withgradient((xs...) -> loss(f_re, xs...), xs_re...)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
@@ -133,7 +139,7 @@ function test_gradients(
check_equal_leaves(g, g_fd; rtol, atol)
end

if compare_enzyme
if test_enzyme
y_ez, g_ez = enzyme_withgradient(f -> loss(f, xs...), f)
@test y ≈ y_ez rtol=rtol atol=atol
check_equal_leaves(g, g_ez; rtol, atol)
@@ -149,7 +155,7 @@ function test_gradients(

if test_reactant
# Enzyme gradient with respect to input on Reactant.
y_re, g_re = Reactant.@jit enzyme_withgradient(f -> loss(f, xs_re...), f_re)
y_re, g_re = reactant_withgradient(f -> loss(f, xs_re...), f_re)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
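Finally, a minimal sketch of calling the updated `test_gradients`, assuming test/test_utils.jl has been included; the layer and input are illustrative, and the second call mirrors the keyword combination used in test/ext_enzyme/enzyme.jl above.

```julia
using Flux

model = Dense(3 => 2, tanh)
x = randn(Float32, 3, 8)

# Default: Zygote gradients on CPU, checked against finite differences.
test_gradients(model, x)

# Use Enzyme as the comparison backend instead of finite differences
# (keyword renamed from `compare_enzyme` to `test_enzyme` in this PR).
test_gradients(model, x; compare_finite_diff=false, test_enzyme=true)
```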