Skip to content

Commit

Permalink
Add reactant forward and reverse pass tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 5, 2025
1 parent 101106e commit 8f623e4
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
51 changes: 51 additions & 0 deletions test/ext_enzyme/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,36 @@ using Enzyme: Enzyme, Duplicated, Const, Active
end
end

@testset "Reactant Models" begin
    # Scalar objective so a single gradient is well-defined for every model.
    function loss(model, x)
        mean(model(x))
    end

    # (model, input, label) triples covering the main layer types.
    # NOTE(review): the `∘` in `first ∘ ...` below was dropped by text garbling
    # (the raw diff showed `first LayerNorm(2)`, which is invalid Julia syntax);
    # restored here — `first` reduces the layer output before the loss.
    models_xs = [
        (Dense(2=>4), randn(Float32, 2), "Dense"),
        (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"),
        (f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"),
        (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
        (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
        (Chain(Conv((3, 3), 2 => 3), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
        (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
        (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
        (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
        (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), # Passes on 1.10, fails on 1.11 with MethodError: no method matching function_attributes(::LLVM.UserOperandSet)
        (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
        (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
        (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # AssertionError: Base.isconcretetype(typ)
        (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # AssertionError: Base.isconcretetype(typ)
    ]

    # Run the shared gradient checker on each case; Reactant compilation and
    # Enzyme reverse-mode comparison are enabled, finite differences disabled.
    for (model, x, name) in models_xs
        @testset "Enzyme grad check $name" begin
            println("testing $name with Enzyme")
            test_gradients(model, x; loss, compare_finite_diff=false,
                           compare_enzyme=true, test_reactant=true)
        end
    end
end

@testset "Recurrent Layers" begin
function loss(model, x)
mean(model(x))
Expand All @@ -51,6 +81,27 @@ end
end
end

@testset "Reactant Recurrent Layers" begin
    # Scalar objective: reduce the model output to its mean.
    loss(model, x) = mean(model(x))

    # Each case is a (model, input, name) triple for a recurrent architecture.
    recurrent_cases = [
        (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
        (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
        (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
        (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
        (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
    ]

    # Exercise the shared gradient checker with Reactant + Enzyme enabled and
    # finite-difference comparison disabled.
    for (model, x, name) in recurrent_cases
        @testset "check grad $name" begin
            println("testing $name")
            test_gradients(model, x; loss = loss, compare_finite_diff = false,
                           compare_enzyme = true, test_reactant = true)
        end
    end
end

@testset "gradient, withgradient, Duplicated" begin
# Tests above are about how Enzyme digests Flux layers.
# Tests here are just the interface Flux.gradient(f, Duplicated(model)) etc.
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using Pkg
using FiniteDifferences: FiniteDifferences
using Functors: fmapstructure_with_path

using Reactant

## Uncomment below to change the default test settings
# ENV["FLUX_TEST_AMDGPU"] = "true"
# ENV["FLUX_TEST_CUDA"] = "true"
Expand Down
24 changes: 24 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ function test_gradients(
xs...;
rtol=1e-4, atol=1e-4,
test_gpu = false,
test_reactant = false,
test_grad_f = true,
test_grad_x = true,
compare_finite_diff = true,
Expand All @@ -73,6 +74,15 @@ function test_gradients(
@test l_gpu isa Number
end

if test_reactant
reactant_dev = MLDataDevices.reactant_device(force=true)
cpu_dev = cpu_device()
xs_re = xs |> reactant_dev
f_re = f |> reactant_dev
l_re = Reactant.@jit loss(f_re, xs_re...)
@test l_re isa Reactant.ConcreteRNumber
end

if test_grad_x
# Zygote gradient with respect to input.
y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...)
Expand All @@ -99,6 +109,13 @@ function test_gradients(
@test y_gpu ≈ y rtol=rtol atol=atol
check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
end

if test_reactant
# Enzyme gradient with respect to input on Reactant.
y_re, g_re = Reactant.@jit enzyme_withgradient((xs...) -> loss(f_re, xs...), xs_re...)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
end

if test_grad_f
Expand Down Expand Up @@ -128,6 +145,13 @@ function test_gradients(
@test y_gpu ≈ y rtol=rtol atol=atol
check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
end

if test_reactant
# Enzyme gradient with respect to the model on Reactant.
y_re, g_re = Reactant.@jit enzyme_withgradient(f -> loss(f, xs_re...), f_re)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
end
return true
end

0 comments on commit 8f623e4

Please sign in to comment.