-
-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathtraining.jl
175 lines (137 loc) · 4.86 KB
/
training.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
epoch!(learner, phase[, dataiter])
Train `learner` for one epoch on `dataiter`. Iterates through
`dataiter` and [`step!`](#)s for each batch/item.
If no data iterator is passed in, use `learner.data[phasedataiter(phase)]`.
## Extending
The default implementation iterates over every batch in `dataiter`
and calls [`step!`](#) for each. This behavior can be overloaded
by implementing `epoch!(learner, ::MyPhase, dataiter)`.
If you're implementing a custom `epoch!` method, it is recommended
you make use of [`runepoch`](#) to get begin and end events as well
as proper handling of [`CancelEpochException`](#)s.
See the default implementation for reference.
"""
function epoch!(learner, phase::Phase, dataiter=learner.data[phasedataiter(phase)])
runepoch(learner, phase) do _
for batch in dataiter
step!(learner, phase, batch)
end
end
end
"""
step!(learner, phase::Phase, batch)
Run one step of training for `learner` on batch.
Behavior is customized through `phase`.
## Extending
This is a required method for custom [`Phase`](#)s to implement.
To implement `step!`, it is recommended you make use of [`runstep`](#)
to get begin and end events as well as proper handling of
[`CancelStepException`](#)s.
See the implementations of [`TrainingPhase`](#) and [`ValidationPhase`](#)
for reference.
"""
function step! end
# Training step: forward pass, loss computation, backward pass, and
# parameter update. `state` accumulates `xs`, `ys`, `ŷs`, `loss` and `grads`
# so callbacks can inspect them. Events fire in the order
# LossBegin -> BackwardBegin -> BackwardEnd, bracketing loss computation
# and gradient computation respectively.
function step!(learner, phase::TrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (; xs=xs, ys=ys)) do handle, state
        # `_gradient` dispatches on the optimiser type: gradients w.r.t. the
        # model (Optimisers.jl style) or w.r.t. `learner.params` (old Flux
        # implicit-`Params` style). The closure returns the loss to be
        # differentiated.
        state.grads = _gradient(learner.optimizer, learner.model, learner.params) do model
            state.ŷs = model(state.xs)
            handle(LossBegin())
            state.loss = learner.lossfn(state.ŷs, state.ys)
            handle(BackwardBegin())
            return state.loss
        end
        handle(BackwardEnd())
        # `_update!` returns (possibly new) optimiser state/params and model so
        # both mutating (old Flux) and functional (Optimisers.jl) APIs work.
        learner.params, learner.model = _update!(
            learner.optimizer, learner.params, learner.model, state.grads)
    end
end
# Handle both old Flux.jl and new Optimisers.jl optimisers.
# New-style (Optimisers.jl) fallback: differentiate w.r.t. the model itself.
# `gradient(f, m)` returns a tuple of gradients; take the one for `m`.
_gradient(f, _, m, _) = gradient(f, m)[1]
# Old-style (Flux.Optimise): differentiate w.r.t. the implicit `Params`.
_gradient(f, ::Flux.Optimise.AbstractOptimiser, m, ps::Params) = gradient(() -> f(m), ps)
# Old-style Flux optimisers mutate the parameters in place; the model
# object itself is returned unchanged.
function _update!(opt::Flux.Optimise.AbstractOptimiser, ps, m, gs)
    update!(opt, ps, gs)
    return ps, m
end
# New-style Optimisers.jl: `Optimisers.update!` already returns the pair of
# updated optimiser state and updated model, which is forwarded as-is.
_update!(_, st, model, grads) = Optimisers.update!(st, model, grads)
# Validation step: forward pass and loss only — no gradients, no updates.
function step!(learner, phase::ValidationPhase, batch)
    inputs, targets = batch
    runstep(learner, phase, (; xs = inputs, ys = targets)) do _, state
        state.ŷs = learner.model(state.xs)
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end
"""
runepoch(epochfn, learner, phase)
Run `epochfn` inside the context of an epoch. Calls `epochfn(handle)`
where `handle(e)` can be called to dispatch events.
Takes care of dispatching [`EpochBegin`](#) and [`EpochEnd`](#)
events as well as handling [`CancelEpochException`](#)s.
"""
function runepoch(epochfn, learner, phase::Phase)
handlefn(e) = handle(learner.callbacks.runner, e, phase, learner)
try
handlefn(EpochBegin())
epochfn(handlefn)
handlefn(EpochEnd())
catch e
if e isa CancelEpochException
@debug "Epoch skipped" error = e
handlefn(EpochEnd())
else
rethrow()
end
end
end
"""
runstep(stepfn, learner, phase) -> state
Run `stepfn` inside the context of a step. Calls `stepfn(handle, state)`
where `handle(e)` can be called to dispatch events and `state` is a [`PropDict`](#)
which step data, gradients and losses can be written to. Return `state`.
Takes care of dispatching [`StepBegin`](#) and [`StepEnd`](#)
events as well as handling [`CancelStepException`](#)s.
"""
function runstep(stepfn, learner, phase::Phase, initialstate = (;))
state = PropDict(pairs(initialstate))
handlefn(e) = handle(learner.callbacks.runner, e, phase, learner)
try
learner.step = state
handlefn(StepBegin())
stepfn(handlefn, state)
handlefn(StepEnd())
return state
catch e
if e isa CancelStepException
@debug "Step skipped" error = e
else
rethrow()
end
end
return state
end
# Utilities

"""
    fit!(learner, nepochs)
    fit!(learner, nepochs, (trainiter, validiter))

Train `learner` for `nepochs`, each consisting of one training epoch
followed by one validation epoch. Uses the data iterators passed in;
if none are given, defaults to `learner.data.training` and
`learner.data.validation`.

## Examples

```julia
fit!(learner, 10)
fit!(learner, 10, (traindl, valdl))
```
"""
function fit!(learner, nepochs::Int, (trainiter, validiter))
    for _ in 1:nepochs
        epoch!(learner, TrainingPhase(), trainiter)
        epoch!(learner, ValidationPhase(), validiter)
    end
end
# Convenience method: fall back to the learner's own training/validation data.
fit!(learner, nepochs::Int) =
    fit!(learner, nepochs, (learner.data.training, learner.data.validation))