From e6f7b9ebc94cba5cb0ef6fe3418e7b068ec708d4 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 10 Aug 2022 09:28:31 -0700
Subject: [PATCH] add train_autodiff macro

---
 src/train/Train.jl | 44 +++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/src/train/Train.jl b/src/train/Train.jl
index ac5724dbc6..4ca7861ccc 100644
--- a/src/train/Train.jl
+++ b/src/train/Train.jl
@@ -4,7 +4,7 @@ using LinearAlgebra
 using Optimisers: Optimisers
 using Functors: fmap
 
-export train!, update!, adjust!, FluxState,
+export train!, update!, adjust!, FluxState, @train_autodiff,
   Descent, Adam, Momentum, Nesterov, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #,
   # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
@@ -108,6 +108,48 @@ include("implicit_train.jl") # Params etc, Zygote only
 
 explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor
 
+"""
+    @train_autodiff Diffractor
+    @train_autodiff Yota
+    @train_autodiff Zygote
+
+This macro allows `train!` to use automatic differentiation packages other than
+the default Zygote.jl. Load the chosen package first, then call this macro.
+
+This only affects the "explicit-mode" methods `train!(loss, model, data, opt)` and `train!(loss, model, opt)`,
+since the (deprecated) "implicit-mode" `train!(loss, ps::Params, data, opt)` is Zygote-specific.
+
+!!! note
+    Experimental
+"""
+macro train_autodiff(pkg)
+  if pkg == :Diffractor
+    return quote
+      Diffractor.gradient(sin, 0.0)[1] ≈ 1.0  # ensures an error if not loaded
+      function Flux.Train.explicit_withgradient(f, args...)
+        y, back = Diffractor.∂⃖¹(f, args...)
+        dy1 = Flux.Zygote.sensitivity(y)  # Zygote is loaded, and this gives nice errors
+        return (; value = y, gradient = Base.tail(back(dy1)))
+      end
+    end |> esc
+  elseif pkg == :Yota
+    return quote
+      Yota.grad(sin, 0.0)  # ensures an error if not loaded
+      function Flux.Train.explicit_withgradient(f, args...)
+        value, (_, gradient...) = Yota.grad(f, args...)
+        return (; value, gradient)
+      end
+    end |> esc
+  elseif pkg == :Zygote
+    return quote
+      Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
+    end |> esc
+  else
+    throw(ArgumentError("@train_autodiff expects either Zygote, Yota, or Diffractor. No other arguments are understood."))
+  end
+end
+
+
 ### Misc. related utilities
 
 """
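
A minimal usage sketch of the macro above, following its docstring. It assumes Yota.jl is
installed, and that the explicit-mode `train!(loss, model, data, opt)` calls the loss as
`loss(model, x, y)` for each `(x, y)` in `data`; the model, loss, and data below are
hypothetical, invented purely for illustration.

    using Flux, Yota

    Flux.@train_autodiff Yota    # train! now takes gradients from Yota.grad

    loss(m, x, y) = abs2(only(m(x)) - y)    # hypothetical per-sample squared error
    model = Dense(2 => 1)                   # any Flux model
    data = [(randn(Float32, 2), 1f0)]       # one hypothetical (input, target) pair
    train!(loss, model, data, Adam())       # differentiated by Yota, not Zygote

    Flux.@train_autodiff Zygote  # restore the default backend

Calling the macro with `Zygote` simply redefines `explicit_withgradient` back to
`Zygote.withgradient`, so the backend can be switched again at any point in a session.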