use TestItems (#78)
CarloLucibello authored Jan 1, 2025
1 parent b01c445 commit 3753d67
Showing 22 changed files with 219 additions and 142 deletions.
50 changes: 25 additions & 25 deletions docs/make.jl
@@ -1,37 +1,37 @@
using Flux, Tsunami
using Documenter

DocMeta.setdocmeta!(Tsunami, :DocTestSetup,
:(using Tsunami, Flux);
recursive = true)
DocMeta.setdocmeta!(Tsunami, :DocTestSetup, :(using Tsunami, Flux); recursive = true)

prettyurls = get(ENV, "CI", nothing) == "true"
mathengine = MathJax3()
sidebar_sitename = true
assets = ["assets/flux.css"]

makedocs(;
modules = [Tsunami],
doctest = true,
checkdocs = :exports,
format = Documenter.HTML(; mathengine, prettyurls, assets, sidebar_sitename),
sitename = "Tsunami.jl",
pages = [
"Get Started" => "index.md",

"Guides" => "guides.md",

"API Reference" => [
# This essentially collects docstrings, with a bit of introduction.
"Callbacks" => "callbacks.md",
"FluxModule" => "fluxmodule.md",
"Hooks" => "hooks.md",
"Logging" => "logging.md",
"Trainer" => "trainer.md",
"Utils" => "utils.md",
],
# "Tutorials" => [] # TODO These walk you through various tasks. It's fine if they overlap quite a lot.

])
modules = [Tsunami],
doctest = true,
checkdocs = :exports,
format = Documenter.HTML(; mathengine, prettyurls, assets, sidebar_sitename),
sitename = "Tsunami.jl",
pages = [
"Get Started" => "index.md",

"Guides" => "guides.md",

# "Tutorials" => [] # TODO These walk you through various tasks. It's fine if they overlap quite a lot.

"API Reference" => [
# This essentially collects docstrings, with a bit of introduction.
"Callbacks" => "api/callbacks.md",
"FluxModule" => "api/fluxmodule.md",
"Foil" => "api/foil.md",
"Hooks" => "api/hooks.md",
"Logging" => "api/logging.md",
"Trainer" => "api/trainer.md",
"Utils" => "api/utils.md",
],
]
)

deploydocs(repo = "github.com/CarloLucibello/Tsunami.jl.git")
4 changes: 4 additions & 0 deletions docs/src/callbacks.md → docs/src/api/callbacks.md
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# Callbacks

Callbacks are functions that are called at certain points in the training process. They are useful for logging, early stopping, and other tasks.
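
In practice a callback is a struct for which you define one or more hook methods and whose instance you pass to the trainer. A minimal sketch, modeled on `test/hooks.jl` in this commit (the `callbacks` keyword appears there only through the `SilentTrainer` test helper, so treat its exact spelling on the public `Trainer` as an assumption):

```julia
using Tsunami

# A callback that records the gradients passed to the optimiser.
# Hooks dispatch on the callback's type, so any struct will do.
struct GradLogger
    grads::Vector{Any}
end
GradLogger() = GradLogger(Any[])

# Same hook signature as used in test/hooks.jl.
function Tsunami.on_before_update(cb::GradLogger, model, trainer, grad)
    push!(cb.grads, grad)
end

trainer = Trainer(max_epochs=1, callbacks=[GradLogger()])
# ... then Tsunami.fit!(model, trainer, train_dataloader) as usual.
```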
4 changes: 4 additions & 0 deletions docs/src/fluxmodule.md → docs/src/api/fluxmodule.md
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# FluxModule

```@docs
9 changes: 9 additions & 0 deletions docs/src/api/foil.md
@@ -0,0 +1,9 @@
# Foil

The [`Foil`](@ref) is a minimalistic version of the [`Trainer`](@ref) that lets you make only minimal changes to your Flux code while still obtaining many of the benefits of Tsunami. It plays the same role that `Lightning Fabric` plays for `PyTorch Lightning`, and it also resembles HuggingFace's `accelerate` library.
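
A minimal sketch of the intended workflow, following `test/foil.jl` from this commit; the training loop and the dummy batch at the end are illustrative assumptions, not part of the diff:

```julia
using Flux, Optimisers, Tsunami

foil = Foil(accelerator=:cpu, precision=:f32, devices=nothing)

model = Chain(Dense(28^2 => 512, relu), Dense(512 => 10))
opt_state = Flux.setup(AdamW(1e-3), model)

# Move model and optimiser state to the configured device and precision.
model, opt_state = Tsunami.setup(foil, model, opt_state)

# A single dummy batch standing in for a real dataloader.
train_dataloader = [(randn(Float32, 28^2, 32), Flux.onehotbatch(rand(0:9, 32), 0:9))]

for batch in train_dataloader
    x, y = Tsunami.setup_batch(foil, batch)   # device + precision for the batch
    grad = Flux.gradient(m -> Flux.logitcrossentropy(m(x), y), model)[1]
    Optimisers.update!(opt_state, model, grad)
end
```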

```@docs
Foil
Tsunami.setup
Tsunami.setup_batch
```
4 changes: 4 additions & 0 deletions docs/src/hooks.md → docs/src/api/hooks.md
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# Hooks

Hooks are a way to extend the functionality of Tsunami: they let you inject custom code into the FluxModule or
4 changes: 4 additions & 0 deletions docs/src/logging.md → docs/src/api/logging.md
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# Logging

```@docs
6 changes: 5 additions & 1 deletion docs/src/trainer.md → docs/src/api/trainer.md
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# Trainer

The [`Trainer`](@ref) struct is the main entry point for training a model. It manages the training loop, logging, and checkpointing, as well as the [`FitState`](@ref Tsunami.FitState) struct, which holds the state of the training loop.
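
For orientation, a minimal sketch of how these pieces fit together, modeled on `test/linear_regression.jl` from this commit. `LinearModel` is a helper defined in the test suite (stand-in for any `FluxModule` of your own), and the `Trainer` keyword mirrors the one used there through the `SilentTrainer` test helper:

```julia
using Flux, Tsunami

# LinearModel is the FluxModule helper from Tsunami's test suite;
# substitute any FluxModule of your own.
N = 1000
M = 500
teacher = LinearModel(N)
X = randn(Float32, N, M)          # one column per sample
y = teacher(X)

model = LinearModel(N; λ = 1f-5 / M)
trainer = Trainer(max_epochs = 1000)                 # assumed keyword, as in the tests
fit_state = Tsunami.fit!(model, trainer, [(X, y)])   # a vector of batches acts as a dataloader
@show Flux.mse(model(X), y)                          # ends up below 1e-1 in the test suite
```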
Expand All @@ -13,5 +17,5 @@ Tsunami.fit!
Tsunami.FitState
Tsunami.test
Tsunami.validate
Tsunami.Foil
```

6 changes: 5 additions & 1 deletion docs/src/utils.md → docs/src/api/utils.md
@@ -1,4 +1,8 @@
# Utils
```@meta
CollapsedDocStrings = true
```

# Utils

Tsunami provides some utility functions to make your life easier.

1 change: 1 addition & 0 deletions src/Tsunami.jl
@@ -68,6 +68,7 @@ export Checkpointer, load_checkpoint

include("foil.jl")
export Foil
@compat(public, (setup, setup_batch))

include("trainer.jl")
export Trainer
24 changes: 20 additions & 4 deletions src/foil.jl
@@ -92,15 +92,31 @@ to_precision(foil::Foil, x) = x |> foil.fprec

is_using_gpu(foil::Foil) = !(foil.device isa CPUDevice)

function setup(foil::Foil, model::FluxModule)
    return model |> to_precision(foil) |> to_device(foil)
end

function setup(foil::Foil, model::FluxModule, optimisers)

"""
    setup(foil::Foil, model, optimisers)
Setup the model and optimisers for training, sending them to the device and setting the precision.
This function is called internally by [`Tsunami.fit!`](@ref).
See also [`Foil`](@ref).
"""
function setup(foil::Foil, model, optimisers)
    model = setup(foil, model)
    optimisers = optimisers |> to_precision(foil) |> to_device(foil)
    return model, optimisers
end

function setup(foil::Foil, model)
    return model |> to_precision(foil) |> to_device(foil)
end


"""
    setup_batch(foil::Foil, batch)
Setup the batch for training, sending it to the device and setting the precision.
This function is called internally by [`Tsunami.fit!`](@ref).
"""
function setup_batch(foil::Foil, batch)
    return batch |> to_precision(foil) |> to_device(foil)
end
3 changes: 3 additions & 0 deletions test/Project.toml
@@ -8,7 +8,10 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Tsunami = "36e41bbe-399b-4a86-8623-faa02b4c2ac8"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
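
For context, the `setup=[TsunamiTest]` / `using .TsunamiTest` pattern used throughout the converted tests below relies on a shared TestItems.jl setup module. The actual `TsunamiTest` module is defined elsewhere in the test suite and is not part of this diff; the sketch below is a hypothetical reconstruction of the pattern:

```julia
using TestItems

# A shared setup module: anything defined (or reexported) here is available
# to every @testitem that lists it in `setup = [...]`.
@testmodule TsunamiTest begin
    using Reexport
    @reexport using Tsunami, Flux, Optimisers
    # The real module also provides helpers such as SilentTrainer,
    # TestModule1, LinearModel, and make_dataloader (names taken from the tests).
end

@testitem "uses the shared setup" setup=[TsunamiTest] begin
    using .TsunamiTest          # brings the reexported packages into scope
    @test isabstracttype(FluxModule)
end
```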
6 changes: 4 additions & 2 deletions test/fluxmodule.jl
@@ -1,8 +1,10 @@
@testset "abstract type FluxModule" begin
@testitem "abstract type FluxModule" begin
@test isabstracttype(FluxModule)
end

@testset "functor" begin
@testitem "FluxModule functor" setup=[TsunamiTest] begin
using .TsunamiTest
using Functors: Functors
m = TestModule1()
@test Functors.children(m) == (; net = m.net, tuple_field = m.tuple_field)

14 changes: 14 additions & 0 deletions test/foil.jl
@@ -0,0 +1,14 @@
@testitem "Foil constructor" begin
foil = Foil(accelerator=:cpu, precision=:f32, devices=nothing)
@test foil isa Foil
end

@testitem "Tsunami.setup" setup=[TsunamiTest] begin
using .TsunamiTest
foil = Foil(accelerator=:cpu, precision=:f32, devices=nothing)
model = Chain(Dense(28^2 => 512, relu), Dense(512 => 10))
opt_state = Flux.setup(AdamW(1e-3), model)
model, opt_state = Tsunami.setup(foil, model, opt_state)
@test model isa Chain
@test opt_state.layers[1].weight isa Optimisers.Leaf{<:AdamW}
end
32 changes: 19 additions & 13 deletions test/hooks.jl
@@ -1,4 +1,5 @@
@testset "on_before_update" begin
@testitem "on_before_update" setup=[TsunamiTest] begin
using .TsunamiTest
out = []

struct OnBeforeUpdateCbk end
Expand All @@ -8,15 +9,16 @@
end
trainer = SilentTrainer(max_epochs=1, callbacks=[OnBeforeUpdateCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.fit!(model, trainer, train_dataloader)
@test length(out) == 2
@test out[1] isa NamedTuple
@test out[2] isa NamedTuple
@test size(out[1].net.layers[1].weight) == (3, 4)
end

@testset "on_train_epoch_start and on_train_epoch_end" begin
@testitem "on_train_epoch_start and on_train_epoch_end" setup=[TsunamiTest] begin
using .TsunamiTest
out = []

struct TrainEpochCbk end
Expand All @@ -30,13 +32,14 @@ end
end
trainer = SilentTrainer(max_epochs=2, callbacks=[TrainEpochCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.fit!(model, trainer, train_dataloader)
@test out == [1, 2, 1, 2]
end


@testset "on_val_epoch_start and on_val_epoch_end" begin
@testitem "on_val_epoch_start and on_val_epoch_end" setup=[TsunamiTest] begin
using .TsunamiTest
out = []

struct ValEpochCbk end
Expand All @@ -50,15 +53,17 @@ end
end
trainer = SilentTrainer(max_epochs=2, callbacks=[ValEpochCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.fit!(model, trainer, train_dataloader, train_dataloader)
@test out == [1, 2, 1, 2, 1, 2]

Tsunami.validate(model, trainer, train_dataloader)
@test out == [1, 2, 1, 2, 1, 2, 1, 2]
end

@testset "on_test_epoch_start and on_test_epoch_end" begin
@testitem "on_test_epoch_start and on_test_epoch_end" setup=[TsunamiTest] begin
using .TsunamiTest

out = []

struct TestEpochCbk end
Expand All @@ -72,26 +77,27 @@ end
end
trainer = SilentTrainer(max_epochs=2, callbacks=[TestEpochCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.test(model, trainer, train_dataloader)
@test out == [1, 2]
end

@testset "on_before_backprop" begin
@testitem "on_before_backprop" setup=[TsunamiTest] begin
using .TsunamiTest
out = []

struct BeforePullbackCbk end

function Tsunami.on_before_backprop(::BeforePullbackCbk, model, trainer, loss)
push!(out, loss)
push!(out, 1)
end
function Tsunami.on_before_update(::BeforePullbackCbk, model, trainer, grad)
@show grad
push!(out, 2)
end

trainer = SilentTrainer(max_epochs=2, callbacks=[BeforePullbackCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.fit!(model, trainer, train_dataloader)
@test out == [1, 1]
@test out == [1, 2, 1, 2, 1, 2, 1, 2]
end
32 changes: 17 additions & 15 deletions test/linear_regression.jl
@@ -1,18 +1,20 @@

@testitem "linear regression" setup=[TsunamiTest] begin
using .TsunamiTest
N = 1000
α = 0.5
λ = 1f-5 / round(Int, N * α)

N = 1000
α = 0.5
λ = 1f-5 / round(Int, N * α)
M = round(Int, N * α)
teacher = LinearModel(N)
X = randn(Float32, N, M)
y = teacher(X)

M = round(Int, N * α)
teacher = LinearModel(N)
X = randn(Float32, N, M)
y = teacher(X)

model = LinearModel(N; λ)
@test model.W isa Matrix{Float32}
@test size(model.W) == (1, N)

trainer = SilentTrainer(max_epochs=1000, devices=[1])
fit_state = Tsunami.fit!(model, trainer, [(X, y)])
@test model.W isa Matrix{Float32} # by default precision is Float32
model = LinearModel(N; λ)
@test model.W isa Matrix{Float32}
@test size(model.W) == (1, N)
trainer = SilentTrainer(max_epochs=1000, devices=[1])
fit_state = Tsunami.fit!(model, trainer, [(X, y)])
@test model.W isa Matrix{Float32} # by default precision is Float32
@test Flux.mse(model(X), y) < 1e-1
end
3 changes: 2 additions & 1 deletion test/logging.jl
@@ -1,4 +1,5 @@
@testset "TensorBoard logging" begin
@testitem "TensorBoard logging" setup=[TsunamiTest] begin
using .TsunamiTest
batch_sizes = [3, 2]
max_epochs = 4

