diff --git a/docs/src/utils.md b/docs/src/utils.md index 3adb1d4c1..fbe8a57cc 100644 --- a/docs/src/utils.md +++ b/docs/src/utils.md @@ -26,6 +26,7 @@ Zygote.hook Zygote.Buffer Zygote.forwarddiff Zygote.checkpointed +Zygote.eager_update! ``` `Params` and `Grads` can be copied to and from arrays using the `copy!` function. diff --git a/src/lib/grad.jl b/src/lib/grad.jl index 6b9002f73..92be71d34 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -27,6 +27,46 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs... return y, pullback_checkpointed end + + +""" + + eager_update!(state, model, update!) + +Eagerly updates the model parameters, discarding the updated gradients to save memory. +`model` stores the parameters to be updated, `state` is the optimization state (eg. from Optimisers.jl) matching your model component, and +`update!` is the function that updates the parameters (eg. from `Optimisers.jl`), usually called as `update!(state, model, grads)`. + +If `f` is a function that takes a single layer, called as `h = f(model.layers[i], h, other_args...)` then we can eagerly update with: + +```julia +h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...) +``` + +or combine this with gradient checkpointing (for additional memory saving at the cost of increased execution time) with: + +```julia +h = Zygote.checkpointed(f, eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...) +``` + +If `model.layers[i]` itself is callable, we can use the above by first wrapping it: + +```julia +f(model, xs...) = model(xs...) +h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...) +``` + +!!! warning + If different layers share trainable parameters, then `eager_update!` will likely give wrong results. +""" +function eager_update!(state, model, update!) + function update_hook(dmodel) + update!(state, model, dmodel) + return nothing + end + return Zygote.hook(update_hook, model) +end + """ hessian(f, x)