module Training
using ADTypes: AbstractADType, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote
using Compat: @compat
using ConcreteStructs: @concrete
using FastClosures: @closure
using Optimisers: Optimisers
using ..Lux: Lux
using LuxCore: LuxCore, AbstractLuxLayer
"""
TrainState
Training State containing:
- `model`: `Lux` model.
- `parameters`: Trainable Variables of the `model`.
- `states`: Non-trainable Variables of the `model`.
- `optimizer`: Optimizer from `Optimisers.jl`.
- `optimizer_state`: Optimizer State.
- `step`: Number of parameter updates performed so far.
Internal fields:
- `cache`: Cached values. Implementations are free to use this for whatever they want.
- `objective_function`: The objective function, which might be cached.
!!! warning
Constructing this object directly shouldn't be considered a stable API. Use the
constructor that takes an `Optimisers.AbstractRule` instead.
"""
@concrete struct TrainState
cache
objective_function
model
parameters
states
optimizer
optimizer_state
step::Int
end
"""
TrainState(model::Lux.AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
Constructor for [`TrainState`](@ref).
## Arguments
- `model`: `Lux` model.
- `ps`: Parameters of the model.
- `st`: States of the model.
- `optimizer`: Optimizer from `Optimisers.jl`.
## Returns
[`TrainState`](@ref) object.
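## Example
A minimal sketch; the RNG, layer, and optimiser below are illustrative choices:
```julia
using Lux, Optimisers, Random

model = Lux.Dense(2 => 4, tanh)
ps, st = Lux.setup(Random.default_rng(), model)
ts = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(0.001f0))
```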
"""
function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
st_opt = Optimisers.setup(optimizer, ps)
return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
end
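# Backend-specific cache stored in `TrainState.cache`: `dparameters` holds storage for the
# gradients w.r.t. the parameters and `extras` holds any additional backend-specific state.
# The `first_try` type parameter records whether the next `compute_gradients` call is the
# first one made with this cache.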
@concrete struct TrainingBackendCache{backend, first_try}
dparameters
extras
end
training_backend(::TrainingBackendCache{backend}) where {backend} = backend
function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
println(io, "TrainState")
println(io, " model: ", ts.model)
println(io, " # of parameters: ", LuxCore.parameterlength(ts.parameters))
println(io, " # of states: ", LuxCore.statelength(ts.states))
println(io, " optimizer: ", ts.optimizer)
print(io, " step: ", ts.step)
if ts.cache !== nothing
if ts.cache isa TrainingBackendCache
print(io,
"\n cache: $(nameof(typeof(ts.cache))){$(training_backend(ts.cache))}")
else
print(io, "\n cache: $(nameof(typeof(ts.cache)))")
end
end
ts.objective_function !== nothing &&
print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
end
const APPLY_GRAD_DOCSTRING = """
## Arguments
- `ts`: [`TrainState`](@ref) object.
- `grads`: Gradients of the loss function wrt `ts.parameters`.
## Returns
Updated [`TrainState`](@ref) object.
"""
"""
apply_gradients(ts::TrainState, grads)
Update the parameters stored in `ts` using the gradients `grads`.
$(APPLY_GRAD_DOCSTRING)
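## Example
A sketch of the typical pairing with [`compute_gradients`](@ref); `obj_fn` and `data` are
placeholders for a user-supplied objective function and batch:
```julia
grads, loss, stats, ts = Lux.Training.compute_gradients(AutoZygote(), obj_fn, data, ts)
ts = Lux.Training.apply_gradients(ts, grads)
```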
"""
function apply_gradients(ts::TrainState, grads)
optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
return TrainState(ts.cache, ts.objective_function, ts.model, ps,
ts.states, ts.optimizer, optimizer_state, ts.step + 1)
end
"""
apply_gradients!(ts::TrainState, grads)
Update the parameters stored in `ts` using the gradients `grads`. This is an inplace version
of [`apply_gradients`](@ref).
$(APPLY_GRAD_DOCSTRING)
"""
function apply_gradients!(ts::TrainState, grads)
Optimisers.update!(ts.optimizer_state, ts.parameters, grads)
return TrainState(ts.cache, ts.objective_function, ts.model, ts.parameters,
ts.states, ts.optimizer, ts.optimizer_state, ts.step + 1)
end
"""
compute_gradients(ad::AbstractADType, objective_function::Function, data,
ts::TrainState)
Compute the gradients of the objective function wrt parameters stored in `ts`.
## Backends & AD Packages
| Supported Backends | Packages Needed |
|:---------------------------- |:---------------- |
| `AutoZygote` | `Zygote.jl` |
| `AutoReverseDiff(; compile)` | `ReverseDiff.jl` |
| `AutoTracker` | `Tracker.jl` |
| `AutoEnzyme` | `Enzyme.jl` |
## Arguments
- `ad`: Backend (from [ADTypes.jl](https://github.com/SciML/ADTypes.jl)) used to compute
the gradients.
- `objective_function`: Objective function. The function must take 4 inputs: model,
parameters, states and data. It must return 3 values: the loss, the updated model state,
and any computed statistics.
- `data`: Data used to compute the gradients.
- `ts`: Current Training State. See [`TrainState`](@ref).
## Return
A 4-Tuple containing:
- `grads`: Computed Gradients.
- `loss`: Loss from the objective function.
- `stats`: Any computed statistics from the objective function.
- `ts`: Updated Training State.
## Known Limitations
- `AutoReverseDiff(; compile=true)` is not supported for Lux models with non-empty state
`st`. Additionally the returned stats must be empty (`NamedTuple()`). We catch these
issues in most cases and throw an error.
!!! danger "Aliased Gradients"
The `grads` returned by this function might be aliased by the implementation of the
gradient backend. For example, if you cache the `grads` from step `i`, the gradients
returned in step `i + 1` might share memory with (and overwrite) the cached ones. To
prevent this, use `copy(grads)` or `deepcopy(grads)` to make a copy of the gradients.
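## Example
A minimal sketch, assuming `Zygote` is loaded and `ts` was constructed as shown in
[`TrainState`](@ref); the objective function and data below are illustrative:
```julia
function sse_objective(model, ps, st, (x, y))
    y_pred, st_ = model(x, ps, st)
    return sum(abs2, y_pred .- y), st_, NamedTuple()
end

x, y = rand(Float32, 2, 16), rand(Float32, 4, 16)
grads, loss, stats, ts = Lux.Training.compute_gradients(
    AutoZygote(), sse_objective, (x, y), ts)
```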
"""
function compute_gradients(ad::AbstractADType, ::F, _, ::TrainState) where {F}
return check_if_compute_gradients_implemented(ad)
end
function check_if_compute_gradients_implemented(::T) where {T <: AbstractADType}
throw(ArgumentError("Support for AD backend $(nameof(T)) has not been implemented \
yet!"))
end
for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
adtype = Symbol(:Auto, package)
msg = "Load `$(package)` with `using $(package)`/`import $(package)` before using this \
function!"
@eval function check_if_compute_gradients_implemented(::$(adtype))
throw(ArgumentError($msg))
end
end
function generate_wrappers(::F, m, ps, st, data, ::Val{false}) where {F}
@warn "Detected function wrapper generation with function being updated between calls. \
This will generate type-unstable code. A possible reason for this is \
`TrainState` was compiled (first call to `compute_gradients`) with function \
`foo` and is being called with `bar`. A common pattern for this would be \
passing an anonymous function as `objective_function` inside a loop." maxlog=1
return Ref{Any}(), Ref{NamedTuple}()
end
# Run the code when trying to compile the function for the first time.
function generate_wrappers(objective_function::F, m, ps, st, data, ::Val{true}) where {F}
_, stₙ, statsₙ = objective_function(m, ps, st, data)
return Ref{typeof(stₙ)}(stₙ), Ref{typeof(statsₙ)}(statsₙ)
end
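# Wraps `objective_function` in a closure that returns only the loss; the updated state and
# statistics are captured into the `Ref`s produced by `generate_wrappers` via
# `Lux.Utils.set_refval!`.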
function wrap_objective_function(
objective_function::F, m, ps, st, data, first_try::Val) where {F}
st_updated, stats = generate_wrappers(objective_function, m, ps, st, data, first_try)
wrapped_objective_function = @closure (model, ps, st, data) -> begin
loss, st_, stats_ = objective_function(model, ps, st, data)
Lux.Utils.set_refval!(st_updated, st_)
Lux.Utils.set_refval!(stats, stats_)
return loss
end
return wrapped_objective_function, st_updated, stats
end
"""
single_train_step!(backend, obj_fn::F, data, ts::TrainState)
Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients!`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.
## Return
Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`,
only the parameters in `ts` are updated inplace. Users should use the returned `ts`
object for further training steps, otherwise there is no caching and performance will be
suboptimal (and absolutely terrible for backends like `AutoReactant`).
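## Example
A sketch of a training loop; `train_loader` and `sse_objective` are placeholders for a
user-supplied data iterator and objective function:
```julia
function train_model(ts, train_loader; nepochs=10)
    for epoch in 1:nepochs, (x, y) in train_loader
        _, loss, _, ts = Lux.Training.single_train_step!(
            AutoZygote(), sse_objective, (x, y), ts)
    end
    return ts
end
```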
"""
function single_train_step! end
"""
single_train_step(backend, obj_fn::F, data, ts::TrainState)
Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.
In most cases you should use [`single_train_step!`](@ref) instead of this function.
## Return
Returned values are the same as [`compute_gradients`](@ref).
"""
function single_train_step end
for inplace in ("!", "")
step, apply_fn = Symbol(:single_train_step, inplace), Symbol(:apply_gradients, inplace)
@eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
ts = $(apply_fn)(ts, grads)
return grads, loss, stats, ts
end
end
# Simple extension to the `adjust!` API
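# Usage sketch: `ts = Optimisers.adjust!(ts, 3.0f-4)` mutates the optimizer state stored in
# `ts` to use the new learning rate and returns a `TrainState` carrying the adjusted
# optimizer (the learning-rate value here is illustrative).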
function Optimisers.adjust!(ts::TrainState, eta::Real)
st_opt = ts.optimizer_state
Optimisers.adjust!(st_opt, eta)
optimizer = Optimisers.adjust(ts.optimizer, eta)
return TrainState(ts.cache, ts.objective_function, ts.model,
ts.parameters, ts.states, optimizer, st_opt, ts.step)
end
function Optimisers.adjust!(ts::TrainState; kwargs...)
st_opt = ts.optimizer_state
Optimisers.adjust!(st_opt; kwargs...)
optimizer = Optimisers.adjust(ts.optimizer; kwargs...)
return TrainState(ts.cache, ts.objective_function, ts.model,
ts.parameters, ts.states, optimizer, st_opt, ts.step)
end
function Optimisers.adjust(ts::TrainState, eta::Real)
st_opt = Optimisers.adjust(ts.optimizer_state, eta)
optimizer = Optimisers.adjust(ts.optimizer, eta)
return TrainState(ts.cache, ts.objective_function, ts.model,
ts.parameters, ts.states, optimizer, st_opt, ts.step)
end
function Optimisers.adjust(ts::TrainState; kwargs...)
st_opt = Optimisers.adjust(ts.optimizer_state; kwargs...)
optimizer = Optimisers.adjust(ts.optimizer; kwargs...)
return TrainState(ts.cache, ts.objective_function, ts.model,
ts.parameters, ts.states, optimizer, st_opt, ts.step)
end
@compat(public,
(TrainState, apply_gradients, apply_gradients!,
compute_gradients, single_train_step, single_train_step!))
export AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote
end