Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use TestItems #78

Merged
merged 5 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
using Flux, Tsunami
using Documenter

DocMeta.setdocmeta!(Tsunami, :DocTestSetup,
:(using Tsunami, Flux);
recursive = true)
DocMeta.setdocmeta!(Tsunami, :DocTestSetup, :(using Tsunami, Flux); recursive = true)

prettyurls = get(ENV, "CI", nothing) == "true"
mathengine = MathJax3()
sidebar_sitename = true
assets = ["assets/flux.css"]

makedocs(;
modules = [Tsunami],
doctest = true,
checkdocs = :exports,
format = Documenter.HTML(; mathengine, prettyurls, assets, sidebar_sitename),
sitename = "Tsunami.jl",
pages = [
"Get Started" => "index.md",

"Guides" => "guides.md",

"API Reference" => [
# This essentially collects docstrings, with a bit of introduction.
"Callbacks" => "callbacks.md",
"FluxModule" => "fluxmodule.md",
"Hooks" => "hooks.md",
"Logging" => "logging.md",
"Trainer" => "trainer.md",
"Utils" => "utils.md",
],
# "Tutorials" => [] # TODO These walk you through various tasks. It's fine if they overlap quite a lot.

])
modules = [Tsunami],
doctest = true,
checkdocs = :exports,
format = Documenter.HTML(; mathengine, prettyurls, assets, sidebar_sitename),
sitename = "Tsunami.jl",
pages = [
"Get Started" => "index.md",

"Guides" => "guides.md",

# "Tutorials" => [] # TODO These walk you through various tasks. It's fine if they overlap quite a lot.

"API Reference" => [
# This essentially collects docstrings, with a bit of introduction.
"Callbacks" => "api/callbacks.md",
"FluxModule" => "api/fluxmodule.md",
"Foil" => "api/foil.md",
"Hooks" => "api/hooks.md",
"Logging" => "api/logging.md",
"Trainer" => "api/trainer.md",
"Utils" => "api/utils.md",
],
]
)

deploydocs(repo = "github.com/CarloLucibello/Tsunami.jl.git")
4 changes: 4 additions & 0 deletions docs/src/callbacks.md → docs/src/api/callbacks.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# Callbacks

Callbacks are functions that are called at certain points in the training process. They are useful for logging, early stopping, and other tasks.
Expand Down
4 changes: 4 additions & 0 deletions docs/src/fluxmodule.md → docs/src/api/fluxmodule.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# FluxModule

```@docs
Expand Down
9 changes: 9 additions & 0 deletions docs/src/api/foil.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Foil

The [`Foil`](@ref) is a minimalistic version of the [`Trainer`](@ref) that allows to make only minimal changes to your Flux code while still obtaining many of the benefits of Tsunami. This is similar to what `Lighting Fabric` is to `PyTorch Lightning`. `Foil` also resembles HuggingFace's `accelerate` library.

```@docs
Foil
Tsunami.setup
Tsunami.setup_batch
```
4 changes: 4 additions & 0 deletions docs/src/hooks.md → docs/src/api/hooks.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# Hooks

Hooks are a way to extend the functionality of Tsunami. They are a way to inject custom code into the FluxModule or
Expand Down
4 changes: 4 additions & 0 deletions docs/src/logging.md → docs/src/api/logging.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# Logging

```@docs
Expand Down
6 changes: 5 additions & 1 deletion docs/src/trainer.md → docs/src/api/trainer.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
```@meta
CollapsedDocStrings = true
```

# Trainer

The [`Trainer`](@ref) struct is the main entry point for training a model. It is responsible for managing the training loop, logging, and checkpointing. It is also responsible for managing the [`FitState`](@ref Tsunami.FitState) struct, which contains the state of the training loop.
Expand All @@ -13,5 +17,5 @@ Tsunami.fit!
Tsunami.FitState
Tsunami.test
Tsunami.validate
Tsunami.Foil
```

6 changes: 5 additions & 1 deletion docs/src/utils.md → docs/src/api/utils.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Utils
```@meta
CollapsedDocStrings = true
```

# Utils

Tsunami provides some utility functions to make your life easier.

Expand Down
1 change: 1 addition & 0 deletions src/Tsunami.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ export Checkpointer, load_checkpoint

include("foil.jl")
export Foil
@compat(public, (setup, setup_batch))

include("trainer.jl")
export Trainer
Expand Down
24 changes: 20 additions & 4 deletions src/foil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,31 @@ to_precision(foil::Foil, x) = x |> foil.fprec

is_using_gpu(foil::Foil) = !(foil.device isa CPUDevice)

function setup(foil::Foil, model::FluxModule)
return model |> to_precision(foil) |> to_device(foil)
end
"""
setup(foil::Foil, model, optimisers)

Setup the model and optimisers for training sending them to the device and setting the precision.
This function is called internally by [`Tsunami.fit!`](@ref).

function setup(foil::Foil, model::FluxModule, optimisers)
See also [`Foil`](@ref).
"""
function setup(foil::Foil, model, optimisers)
model = setup(foil, model)
optimisers = optimisers |> to_precision(foil) |> to_device(foil)
return model, optimisers
end

function setup(foil::Foil, model)
return model |> to_precision(foil) |> to_device(foil)
end


"""
setup_batch(foil::Foil, batch)

Setup the batch for training sending it to the device and setting the precision.
This function is called internally by [`Tsunami.fit!`](@ref).
"""
function setup_batch(foil::Foil, batch)
return batch |> to_precision(foil) |> to_device(foil)
end
Expand Down
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Tsunami = "36e41bbe-399b-4a86-8623-faa02b4c2ac8"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
6 changes: 4 additions & 2 deletions test/fluxmodule.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
@testset "abstract type FluxModule" begin
@testitem "abstract type FluxModule" begin
@test isabstracttype(FluxModule)
end

@testset "functor" begin
@testitem "FluxModule functor" setup=[TsunamiTest] begin
using .TsunamiTest
using Functors: Functors
m = TestModule1()
@test Functors.children(m) == (; net = m.net, tuple_field = m.tuple_field)

Expand Down
14 changes: 14 additions & 0 deletions test/foil.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@testitem "Foil constructor" begin
foil = Foil(accelerator=:cpu, precision=:f32, devices=nothing)
@test foil isa Foil
end

@testitem "Tsunami.setup" setup=[TsunamiTest] begin
using .TsunamiTest
foil = Foil(accelerator=:cpu, precision=:f32, devices=nothing)
model = Chain(Dense(28^2 => 512, relu), Dense(512 => 10))
opt_state = Flux.setup(AdamW(1e-3), model)
model, opt_state = Tsunami.setup(foil, model, opt_state)
@test model isa Chain
@test opt_state.layers[1].weight isa Optimisers.Leaf{<:AdamW}
end
32 changes: 19 additions & 13 deletions test/hooks.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
@testset "on_before_update" begin
@testitem "on_before_update" setup=[TsunamiTest] begin
using .TsunamiTest
out = []

struct OnBeforeUpdateCbk end
Expand All @@ -8,15 +9,16 @@
end
trainer = SilentTrainer(max_epochs=1, callbacks=[OnBeforeUpdateCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.fit!(model, trainer, train_dataloader)
@test length(out) == 2
@test out[1] isa NamedTuple
@test out[2] isa NamedTuple
@test size(out[1].net.layers[1].weight) == (3, 4)
end

@testset "on_train_epoch_start and on_train_epoch_end" begin
@testitem "on_train_epoch_start and on_train_epoch_end" setup=[TsunamiTest] begin
using .TsunamiTest
out = []

struct TrainEpochCbk end
Expand All @@ -30,13 +32,14 @@ end
end
trainer = SilentTrainer(max_epochs=2, callbacks=[TrainEpochCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.fit!(model, trainer, train_dataloader)
@test out == [1, 2, 1, 2]
end


@testset "on_val_epoch_start and on_val_epoch_end" begin
@testitem "on_val_epoch_start and on_val_epoch_end" setup=[TsunamiTest] begin
using .TsunamiTest
out = []

struct ValEpochCbk end
Expand All @@ -50,15 +53,17 @@ end
end
trainer = SilentTrainer(max_epochs=2, callbacks=[ValEpochCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.fit!(model, trainer, train_dataloader, train_dataloader)
@test out == [1, 2, 1, 2, 1, 2]

Tsunami.validate(model, trainer, train_dataloader)
@test out == [1, 2, 1, 2, 1, 2, 1, 2]
end

@testset "on_test_epoch_start and on_test_epoch_end" begin
@testitem "on_test_epoch_start and on_test_epoch_end" setup=[TsunamiTest] begin
using .TsunamiTest

out = []

struct TestEpochCbk end
Expand All @@ -72,26 +77,27 @@ end
end
trainer = SilentTrainer(max_epochs=2, callbacks=[TestEpochCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.test(model, trainer, train_dataloader)
@test out == [1, 2]
end

@testset "on_before_backprop" begin
@testitem "on_before_backprop" setup=[TsunamiTest] begin
using .TsunamiTest
out = []

struct BeforePullbackCbk end

function Tsunami.on_before_backprop(::BeforePullbackCbk, model, trainer, loss)
push!(out, loss)
push!(out, 1)
end
function Tsunami.on_before_update(::BeforePullbackCbk, model, trainer, grad)
@show grad
push!(out, 2)
end

trainer = SilentTrainer(max_epochs=2, callbacks=[BeforePullbackCbk()])
model = TestModule1()
train_dataloader = make_dataloader(io_sizes(model)..., 10, 5)
train_dataloader = make_dataloader(io_sizes(model)..., n=10, bs=5)
Tsunami.fit!(model, trainer, train_dataloader)
@test out == [1, 1]
@test out == [1, 2, 1, 2, 1, 2, 1, 2]
end
32 changes: 17 additions & 15 deletions test/linear_regression.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@

@testitem "linear regression" setup=[TsunamiTest] begin
using .TsunamiTest
N = 1000
α = 0.5
λ = 1f-5 / round(Int, N * α)

N = 1000
α = 0.5
λ = 1f-5 / round(Int, N * α)
M = round(Int, N * α)
teacher = LinearModel(N)
X = randn(Float32, N, M)
y = teacher(X)

M = round(Int, N * α)
teacher = LinearModel(N)
X = randn(Float32, N, M)
y = teacher(X)

model = LinearModel(N; λ)
@test model.W isa Matrix{Float32}
@test size(model.W) == (1, N)

trainer = SilentTrainer(max_epochs=1000, devices=[1])
fit_state = Tsunami.fit!(model, trainer, [(X, y)])
@test model.W isa Matrix{Float32} # by default precision is Float32
model = LinearModel(N; λ)
@test model.W isa Matrix{Float32}
@test size(model.W) == (1, N)
trainer = SilentTrainer(max_epochs=1000, devices=[1])
fit_state = Tsunami.fit!(model, trainer, [(X, y)])
@test model.W isa Matrix{Float32} # by default precision is Float32
@test Flux.mse(model(X), y) < 1e-1
end
3 changes: 2 additions & 1 deletion test/logging.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
@testset "TensorBoard logging" begin
@testitem "TensorBoard logging" setup=[TsunamiTest] begin
using .TsunamiTest
batch_sizes = [3, 2]
max_epochs = 4

Expand Down
Loading
Loading