From e62cb63a10be84aa50337c346a734e619bfebd39 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 17 Mar 2024 10:22:30 +0100
Subject: [PATCH] follow `@functor` -> `@layer` changes in Flux (#62)

* adapt_structure
* adapt compat
* fixes
* doc fixes
* cleanup
---
 Project.toml        |  2 +-
 docs/src/trainer.md |  1 +
 src/Tsunami.jl      |  4 ++--
 src/fluxmodule.jl   | 12 +++++++++---
 src/show.jl         |  7 ++++++-
 5 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/Project.toml b/Project.toml
index eccb58d..55c4ab5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -22,7 +22,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-Adapt = "3"
+Adapt = "3, 4"
 BSON = "0.3.6"
 ChainRulesCore = "1"
 Crayons = "4"
diff --git a/docs/src/trainer.md b/docs/src/trainer.md
index 205b2ef..8c4f762 100644
--- a/docs/src/trainer.md
+++ b/docs/src/trainer.md
@@ -14,4 +14,5 @@ Tsunami.fit!
 Tsunami.FitState
 Tsunami.test
 Tsunami.validate
+Tsunami.Foil
 ```
diff --git a/src/Tsunami.jl b/src/Tsunami.jl
index 937969b..49ce18d 100644
--- a/src/Tsunami.jl
+++ b/src/Tsunami.jl
@@ -30,8 +30,6 @@ include("utils.jl")
 include("stats.jl")
 # export Stats
 
-include("show.jl")
-
 include("fluxmodule.jl")
 export FluxModule
 # train_step,
@@ -40,6 +38,8 @@ export FluxModule
 # predict_step,
 # configure_optimizers
 
+include("show.jl")
+
 include("hooks.jl")
 # export on_before_update,
 # on_before_backprop,
diff --git a/src/fluxmodule.jl b/src/fluxmodule.jl
index c5f8cf0..957786d 100644
--- a/src/fluxmodule.jl
+++ b/src/fluxmodule.jl
@@ -4,7 +4,10 @@
 An abstract type for Flux models.
 A `FluxModule` helps organize your code and provides a standard interface for training.
 
-A `FluxModule` comes with `functor` already implemented.
+A `FluxModule` comes with the functionality provided by `Flux.@layer`
+(cpu/gpu movement, parameter management, etc.) and the ability to interact with
+[`Trainer`](@ref) and `Optimisers.jl`.
+
 You can change the trainables by implementing `Optimisers.trainable`.
 
 Types inheriting from `FluxModule` have to be mutable. They also
@@ -62,13 +65,16 @@ model, fit_state = Tsunami.fit(model, trainer, train_dataloader)
 """
 abstract type FluxModule end
 
-function Functors.functor(::Type{<:FluxModule}, m::T) where T
+
+function Functors.functor(::Type{T}, m) where {T<:FluxModule}
     childr = (; (f => getfield(m, f) for f in fieldnames(T))...)
     Tstripped = Base.typename(T).wrapper # remove all parameters. From https://discourse.julialang.org/t/stripping-parameter-from-parametric-types/8293/16
-    re = x -> Tstripped(x...)
+    re = Base.splat(Tstripped)
     return childr, re
 end
 
+Adapt.adapt_structure(to, m::FluxModule) = Functors.fmap(Adapt.adapt(to), m)
+
 Base.show(io::IO, mime::MIME"text/plain", m::FluxModule) = fluxshow(io, mime, m)
 Base.show(io::IO, m::FluxModule) = shortshow(io, m)
 
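The fluxmodule.jl hunk above replaces the old `@functor`-era integration with an explicit `Functors.functor` method plus an `Adapt.adapt_structure` overload. Below is a self-contained Julia sketch of that pattern, for illustration only; `MyModule` and `Net` are hypothetical stand-ins for `FluxModule` and a user model, and the `Array{Float64}` target is just one possible adaptor.

    # Sketch of the functor/adapt pattern introduced in src/fluxmodule.jl.
    using Functors, Adapt

    abstract type MyModule end   # stand-in for Tsunami's FluxModule

    # Children are all fields; `re` rebuilds through the parameter-stripped constructor.
    function Functors.functor(::Type{T}, m) where {T<:MyModule}
        childr = (; (f => getfield(m, f) for f in fieldnames(T))...)
        Tstripped = Base.typename(T).wrapper  # drop type parameters
        re = Base.splat(Tstripped)            # re(children::NamedTuple) -> new instance
        return childr, re
    end

    # Structure-preserving adaption: apply `adapt(to, leaf)` to every leaf array.
    Adapt.adapt_structure(to, m::MyModule) = Functors.fmap(Adapt.adapt(to), m)

    mutable struct Net <: MyModule
        weight
        bias
    end

    net = Net(randn(Float32, 3, 3), zeros(Float32, 3))
    net64 = Adapt.adapt(Array{Float64}, net)  # fmap recurses and rebuilds a Net via `re`
    @assert net64 isa Net && eltype(net64.weight) == Float64
    # With CUDA.jl loaded, Adapt.adapt(CuArray, net) would likewise move the arrays to the GPU.

Compared with the old `x -> Tstripped(x...)` closure, `Base.splat(Tstripped)` builds an equivalent reconstructor with a clearer printed form.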
diff --git a/src/show.jl b/src/show.jl
index 79a6512..a4466bd 100644
--- a/src/show.jl
+++ b/src/show.jl
@@ -1,4 +1,6 @@
-function fluxshow(io::IO, m::MIME"text/plain", x::T) where T
+# Show methods that Flux defines through `@layer`
+# https://github.com/FluxML/Flux.jl/blob/master/src/layers/show.jl#L4
+function fluxshow(io::IO, m::MIME"text/plain", x)
     if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
         Flux._big_show(io, x)
     elseif !get(io, :compact, false)  # e.g. printed inside a Vector, but not a Matrix
@@ -7,6 +9,9 @@ function fluxshow(io::IO, m::MIME"text/plain", x::T) where T
         show(io, x)
     end
 end
+
+# Don't show Chain(Tuple(...)), always splat that. And ignore Recur's non-trainable state:
+Flux._show_children(x::FluxModule) = Flux._flat_children(trainable(x))
 
 function shortshow(io::IO, x::T) where T
     str = string(T.name.name)
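Both the `FluxModule` docstring and the `Flux._show_children` override above route through `Optimisers.trainable`. A minimal sketch of the customization the docstring mentions, assuming Tsunami with this patch applied; the `Frozen` type and its fields are illustrative, not part of the patch:

    using Tsunami, Optimisers

    mutable struct Frozen <: FluxModule
        encoder
        head
    end

    # Only `head` is reported as a trainable child: optimiser updates skip
    # `encoder`, and the `_show_children` override above recurses only into `head`.
    Optimisers.trainable(m::Frozen) = (; head = m.head)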