Skip to content

Commit

Permalink
add train_autodiff macro
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 10, 2022
1 parent 7aa74e0 commit e6f7b9e
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion src/train/Train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra
using Optimisers: Optimisers
using Functors: fmap

export train!, update!, adjust!, FluxState,
export train!, update!, adjust!, FluxState, @train_autodiff,
Descent, Adam, Momentum, Nesterov, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #,
# InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
Expand Down Expand Up @@ -108,6 +108,48 @@ include("implicit_train.jl") # Params etc, Zygote only

explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor

"""
@train_autodiff Diffractor
@train_autodiff Yota
@train_autodiff Zygote
This macro allows the use of `train!` with various automatic differentiation packages,
instead of the default Zygote.jl. You should load the package, then call this macro.
Only affects "explicit-mode" versions `train!(loss, model, data, opt)` or `train!(loss, model, opt)`,
since the (deprecated) "implicit-mode" `train!(loss, ps::Params, data, opt)` is Zygote-specific.
!!! note
Experimental
"""
macro train_autodiff(pkg)
if pkg == :Diffractor
return quote
Diffractor.gradient(sin, 0.0)[1] 1.0 # ensures an error if not loaded
function Flux.Train.explicit_withgradient(f, args...)
y, back = Diffractor.∂⃖¹(f, args...)
dy1 = Flux.Zygote.sensitivity(y) # Zygote is loaded, and this gives nice errors
return (; value = y, gradient = Base.tail(back(dy1)))
end
end |> esc
elseif pkg == :Yota
return quote
Yota.grad(sin, 0.0) # [2][1] ≈ 1.0
function Flux.Train.explicit_withgradient(f, args...)
value, (_, gradient...) = Yota.grad(f, args...)
return (; value, gradient)
end
end |> esc
elseif pkg == :Zygote
return quote
Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
end |> esc
else
throw("@train_autodiff expects either Zygote, Yota, or Diffractor. No other arguments are understood.")
end
end


### Misc. related utilities

"""
Expand Down

0 comments on commit e6f7b9e

Please sign in to comment.