using DeepPumas
using DeepPumas.SimpleChains
using StableRNGs
using CairoMakie
using Distributions
using Random
#
# TABLE OF CONTENTS
#
# 1. A SIMPLE MACHINE LEARNING (ML) MODEL
#
# 1.1. Sample subjects with an obvious `true_function`
# 1.2. Model `true_function` with a linear regression
#
# 2. CAPTURING COMPLEX RELATIONSHIPS
#
# 2.1. Sample subjects with a more complex `true_function`
# 2.2. Exercise: Reason about using a linear regression to model the current `true_function`
# 2.3. Use a neural network (NN) to model `true_function`
#
# 3. BASIC UNDERFITTING AND OVERFITTING
#
# 3.1. Exercise: Investigate the impact of the number of fitting iterations in NNs
# (Hint: Train `model_ex2` on `population_ex2` for few and for many iterations.)
# 3.2. Exercise: Reason about Exercise 2.2 again (that is, using a linear regression
# to model a quadratic relationship). Is the number of iterations relevant there?
# 3.3. The impact of the NN size
#
# 4. INSPECTION OF THE VALIDATION LOSS AND REGULARIZATION
#
# 4.1. Validation loss as a proxy for generalization performance
# 4.2. Regularization to prevent overfitting
#
"""
Helper Pumas model to generate synthetic data. Subjects will have one
covariate `x` and one observation `y ~ Normal(true_function(x), σ)`.
`true_function` and `σ` have to be defined independently, and the probability
distribution of `x` has to be determined in the call to `synthetic_data`.
"""
data_model = @model begin
    @covariates x
    @pre x_ = x
    @derived begin
        y ~ @. Normal(true_function(x_), σ)
    end
end
#
# 1. A SIMPLE MACHINE LEARNING (ML) MODEL
#
# 1.1. Sample subjects with an obvious `true_function`
# 1.2. Model `true_function` with a DeepPumas linear regression
#
# 1.1. Sample subjects with an obvious `true_function`
true_function = x -> x
σ = 0.25
population_ex1 = synthetic_data(
    data_model;
    covariates = (; x = Uniform(-1, 1)),
    obstimes = [0.0],
    rng = StableRNG(0), # must use `StableRNGs` until bug fix in next release
)
x = [only(subject.covariates().x) for subject in population_ex1]
y = [only(subject.observations.y) for subject in population_ex1]
begin
    f = scatter(
        x,
        y;
        axis = (xlabel = "covariate x", ylabel = "observation y"),
        label = "data (each dot is a subject)",
    )
    lines!(-1:0.1:1, true_function.(-1:0.1:1); color = :gray, label = "true")
    axislegend()
    f
end
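# Optional sanity check (not part of the original workflow): before fitting the
# Pumas model in 1.2, the slope and intercept can be estimated with ordinary
# least squares using only Base Julia. `X_design`, `a_ols`, and `b_ols` are
# illustrative names introduced here.
X_design = hcat(x, ones(length(x)))  # design matrix [x 1]
a_ols, b_ols = X_design \ y          # least-squares slope and intercept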
# 1.2. Model `true_function` with a linear regression
model_ex1 = @model begin
    @param begin
        a ∈ RealDomain()
        b ∈ RealDomain()
        σ ∈ RealDomain(; lower = 0.0)
    end
    @covariates x
    @pre ŷ = a * x + b
    @derived y ~ @. Normal(ŷ, σ)
end
fpm = fit(model_ex1, population_ex1, init_params(model_ex1), NaivePooled());
fpm # `true_function` is y = x (that is, a = 1, b = 0) and σ = 0.25
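# The fitted coefficients can also be inspected directly and compared against
# the data-generating values a = 1, b = 0, σ = 0.25 (a quick check added here;
# `coef` is the same accessor used later in this script).
coef(fpm)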
ŷ = [only(subject_prediction.pred.y) for subject_prediction in predict(fpm)]
begin
    f = scatter(
        x,
        y;
        axis = (xlabel = "covariate x", ylabel = "observation y"),
        label = "data (each dot is a subject)",
    )
    scatter!(x, ŷ, label = "prediction")
    lines!(-1:0.1:1, true_function.(-1:0.1:1); color = :gray, label = "true")
    axislegend()
    f
end
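# Rough goodness-of-fit check (added here as an illustration): the training
# mean squared error should be close to the noise variance σ² = 0.25² = 0.0625.
sum(abs2, y .- ŷ) / length(y)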
#
# 2. CAPTURING COMPLEX RELATIONSHIPS
#
# 2.1. Sample subjects with a more complex `true_function`
# 2.2. Exercise: Reason about using a linear regression to model the current `true_function`
# 2.3. Use a neural network (NN) to model `true_function`
#
# 2.1. Sample subjects with a more complex `true_function`
true_function = x -> x^2 # the examples aim to be insightful; please play along!
σ = 0.25
population_ex2 = synthetic_data(
    data_model;
    covariates = (; x = Uniform(-1, 1)),
    obstimes = [0.0],
    rng = StableRNG(0), # must use `StableRNGs` until bug fix in next release
)
x = [only(subject.covariates().x) for subject in population_ex2]
y = [only(subject.observations.y) for subject in population_ex2]
begin
    f = scatter(
        x,
        y;
        axis = (xlabel = "covariate x", ylabel = "observation y"),
        label = "data (each dot is a subject)",
    )
    lines!(-1:0.1:1, true_function.(-1:0.1:1); color = :gray, label = "true")
    axislegend()
    f
end
# 2.2. Exercise: Reason about using a linear regression to model the current `true_function`
solution_ex22 = begin
    fpm = fit(model_ex1, population_ex2, init_params(model_ex1), MAP(NaivePooled()))
    ŷ_ex22 = [only(subject_prediction.pred.y) for subject_prediction in predict(fpm)]
    f = scatter(
        x,
        y;
        axis = (xlabel = "covariate x", ylabel = "observation y"),
        label = "data (each dot is a subject)",
    )
    scatter!(x, ŷ_ex22, label = "prediction (fpm)")
    lines!(-1:0.1:1, true_function.(-1:0.1:1); color = :gray, label = "true")
    axislegend()
    f
end
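# A possible extra diagnostic (not part of the original exercise): plotting the
# residuals of the linear fit against x makes visible the systematic curvature
# that a straight line cannot capture. `f_res` is an illustrative name.
begin
    f_res = scatter(
        x,
        y .- ŷ_ex22;
        axis = (xlabel = "covariate x", ylabel = "residual y - ŷ"),
        label = "residuals of the linear fit",
    )
    axislegend()
    f_res
end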
# 2.3. Use a neural network (NN) to model `true_function`
model_ex2 = @model begin
    @param begin
        nn ∈ MLP(1, (8, tanh), (1, identity); bias = true)
        σ ∈ RealDomain(; lower = 0.0)
    end
    @covariates x
    @pre ŷ = only(nn(x))
    @derived y ~ @. Normal(ŷ, σ)
end
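# Back-of-the-envelope parameter count for this MLP (an illustrative sketch,
# not a DeepPumas API call): 1 input → 8 tanh units → 1 linear output, with biases.
n_hidden_params = 8 * 1 + 8        # hidden-layer weights + biases
n_output_params = 1 * 8 + 1        # output-layer weights + bias
n_hidden_params + n_output_params  # 25 parameters in total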
fpm = fit(
    model_ex2,
    population_ex2,
    init_params(model_ex2),
    NaivePooled();
    optim_options = (; iterations = 100),
);
fpm # try to make sense of the parameters in the NN
ŷ_ex23 = [only(subject_prediction.pred.y) for subject_prediction in predict(fpm)]
begin
    f = scatter(
        x,
        y;
        axis = (xlabel = "covariate x", ylabel = "observation y"),
        label = "data (each dot is a subject)",
    )
    scatter!(x, ŷ_ex23, label = "prediction")
    lines!(-1:0.1:1, true_function.(-1:0.1:1); color = :gray, label = "true")
    axislegend()
    f
end
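# Hedged comparison (added here): the NN's training MSE should sit closer to the
# noise floor σ² = 0.0625 than that of the linear predictions from exercise 2.2.
mse_linear = sum(abs2, y .- ŷ_ex22) / length(y)
mse_nn = sum(abs2, y .- ŷ_ex23) / length(y)
(; mse_linear, mse_nn)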
#
# 3. BASIC UNDERFITTING AND OVERFITTING
#
# 3.1. Exercise: Investigate the impact of the number of fitting iterations in NNs
# (Hint: Train `model_ex2` on `population_ex2` for few and for many iterations.)
# 3.2. Exercise: Reason about Exercise 2.2 again (that is, using a linear regression
# to model a quadratic relationship). Is the number of iterations relevant there?
# 3.3. The impact of the NN size
#
# 3.1. Exercise: Investigate the impact of the number of fitting iterations in NNs
# (Hint: Train `model_ex2` on `population_ex2` for few and for many iterations.)
solution_ex31 = begin
    fpm = fit(
        model_ex2,
        population_ex2,
        init_params(model_ex2),
        NaivePooled();
        optim_options = (; iterations = 10),
    )
    ŷ_underfit =
        [only(subject_prediction.pred.y) for subject_prediction in predict(fpm)]
    fpm = fit(
        model_ex2,
        population_ex2,
        init_params(model_ex2),
        NaivePooled();
        optim_options = (; iterations = 5_000),
    )
    ŷ_overfit =
        [only(subject_prediction.pred.y) for subject_prediction in predict(fpm)]
    f = scatter(
        x,
        y;
        axis = (xlabel = "covariate x", ylabel = "observation y"),
        label = "data (each dot is a subject)",
    )
    scatter!(x, ŷ_underfit, label = "prediction (10 iterations)")
    scatter!(x, ŷ_ex23, label = "prediction (100 iterations)")
    scatter!(x, ŷ_overfit, label = "prediction (5k iterations)")
    lines!(-1:0.1:1, true_function.(-1:0.1:1); color = :gray, label = "true")
    axislegend()
    f
end
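# A rough numerical counterpart to the plot above (added illustration; assumes
# training MSE relative to the noise floor σ² = 0.0625 is a useful proxy):
mse_underfit = sum(abs2, y .- ŷ_underfit) / length(y)
mse_100iter = sum(abs2, y .- ŷ_ex23) / length(y)
mse_overfit = sum(abs2, y .- ŷ_overfit) / length(y)
(; mse_underfit, mse_100iter, mse_overfit)  # overfitting pushes training MSE below σ²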
# 3.2. Exercise: Reason about Exercise 2.2 again (that is, using a linear regression
# to model a quadratic relationship). Is the number of iterations relevant there?
# Investigate the effect of `max_iterations`.
solution_ex32 = begin
    max_iterations = 10
    fpm = fit(
        model_ex1,
        population_ex2,
        init_params(model_ex1),
        NaivePooled();
        optim_options = (; iterations = max_iterations),
    )
    ŷ = [only(subject_prediction.pred.y) for subject_prediction in predict(fpm)]
    f = scatter(
        x,
        y;
        axis = (xlabel = "covariate x", ylabel = "observation y"),
        label = "data (each dot is a subject)",
    )
    scatter!(x, ŷ, label = "prediction ($max_iterations iterations)")
    scatter!(x, ŷ_ex22, label = "prediction (exercise 2.2)")
    lines!(-1:0.1:1, true_function.(-1:0.1:1); color = :gray, label = "true")
    axislegend()
    f
end
# 3.3. The impact of the NN size
model_ex3 = @model begin
    @param begin
        nn ∈ MLP(1, (32, tanh), (32, tanh), (1, identity); bias = true)
        σ ∈ RealDomain(; lower = 0.0)
    end
    @covariates x
    @pre ŷ = only(nn(x))
    @derived y ~ @. Normal(ŷ, σ)
end
fpm = fit(
    model_ex3,
    population_ex2,
    init_params(model_ex3),
    NaivePooled();
    optim_options = (; iterations = 1000),
);
ŷ = [only(subject_prediction.pred.y) for subject_prediction in predict(fpm)]
begin
    f = scatter(
        x,
        y;
        axis = (xlabel = "covariate x", ylabel = "observation y"),
        label = "data (each dot is a subject)",
    )
    scatter!(x, ŷ, label = "prediction (32x32 units - 1k iter)")
    lines!(-1:0.1:1, true_function.(-1:0.1:1); color = :gray, label = "true")
    axislegend()
    f
end
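# Same back-of-the-envelope count as before (illustrative, not a DeepPumas call):
# two hidden layers of 32 units give this network on the order of a thousand
# parameters, very likely far more than there are subjects in this small dataset.
n_params_ex3 = (1 * 32 + 32) + (32 * 32 + 32) + (32 * 1 + 1)  # = 1153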
#
# 4. INSPECTION OF THE VALIDATION LOSS AND REGULARIZATION
#
# 4.1. Validation loss as a proxy for generalization performance
# 4.2. Regularization to prevent overfitting
#
# 4.1. Validation loss as a proxy for generalization performance
population_train = population_ex2
x_train, y_train = x, y
population_valid = synthetic_data(
    data_model;
    covariates = (; x = Uniform(-1, 1)),
    obstimes = [0.0],
    rng = StableRNG(1), # must use `StableRNGs` until bug fix in next release
)
x_valid = [only(subject.covariates().x) for subject in population_valid]
y_valid = [only(subject.observations.y) for subject in population_valid]
begin
    f = scatter(
        x_train,
        y_train;
        axis = (xlabel = "covariate x", ylabel = "observation y"),
        label = "training data",
    )
    scatter!(x_valid, y_valid; label = "validation data")
    lines!(-1:0.1:1, true_function.(-1:0.1:1); color = :gray, label = "true")
    axislegend()
    f
end
begin
    loss_train_l, loss_valid_l = [], []
    fpm = fit(
        model_ex3,
        population_train,
        init_params(model_ex3),
        NaivePooled();
        optim_options = (; iterations = 10),
    )
    push!(loss_train_l, cost(model_ex3, population_train, coef(fpm), nothing, mse))
    push!(loss_valid_l, cost(model_ex3, population_valid, coef(fpm), nothing, mse))
    iteration_blocks = 100
    for _ = 2:iteration_blocks
        fpm = fit(
            model_ex3,
            population_train,
            coef(fpm),
            MAP(NaivePooled());
            optim_options = (; iterations = 10),
        )
        push!(loss_train_l, cost(model_ex3, population_train, coef(fpm), nothing, mse))
        push!(loss_valid_l, cost(model_ex3, population_valid, coef(fpm), nothing, mse))
    end
end
begin
    f, ax = scatterlines(
        1:iteration_blocks,
        Float32.(loss_train_l);
        label = "training",
        axis = (; xlabel = "Blocks of 10 iterations", ylabel = "Mean squared loss"),
    )
    scatterlines!(1:iteration_blocks, Float32.(loss_valid_l); label = "validation")
    axislegend()
    f
end
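# A hedged "early stopping" reading of the curves above (added here): pick the
# block of iterations at which the validation loss is smallest.
best_block = argmin(loss_valid_l)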
# 4.2. Regularization to prevent overfitting
model_ex4 = @model begin
    @param begin
        nn ∈ MLP(1, (32, tanh), (32, tanh), (1, identity); bias = true, reg = L2(1.0))
        σ ∈ RealDomain(; lower = 0.0)
    end
    @covariates x
    @pre ŷ = only(nn(x))
    @derived y ~ @. Normal(ŷ, σ)
end
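# Conceptual sketch only (not the DeepPumas internals): an L2 penalty adds
# λ * Σ wᵢ² over the network weights to the training objective, discouraging
# large weights and hence overly wiggly fits. `λ` plays the role of the `1.0`
# passed to `L2(1.0)` above; `l2_penalty` is a hypothetical helper.
l2_penalty(weights, λ) = λ * sum(abs2, weights)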
begin
    reg_loss_train_l, reg_loss_valid_l = [], []
    fpm = fit(
        model_ex4,
        population_train,
        init_params(model_ex4),
        MAP(NaivePooled());
        optim_options = (; iterations = 10),
    )
    push!(reg_loss_train_l, cost(model_ex4, population_train, coef(fpm), nothing, mse))
    push!(reg_loss_valid_l, cost(model_ex4, population_valid, coef(fpm), nothing, mse))
    iteration_blocks = 100
    for _ = 2:iteration_blocks
        fpm = fit(
            model_ex4,
            population_train,
            coef(fpm),
            MAP(NaivePooled());
            optim_options = (; iterations = 10),
        )
        push!(reg_loss_train_l, cost(model_ex4, population_train, coef(fpm), nothing, mse))
        push!(reg_loss_valid_l, cost(model_ex4, population_valid, coef(fpm), nothing, mse))
    end
end
begin
    f, ax = scatterlines(
        1:iteration_blocks,
        Float32.(loss_train_l);
        label = "training",
        axis = (; xlabel = "Blocks of 10 iterations", ylabel = "Mean squared loss"),
    )
    scatterlines!(1:iteration_blocks, Float32.(loss_valid_l); label = "validation")
    scatterlines!(1:iteration_blocks, Float32.(reg_loss_train_l); label = "training (L2)")
    scatterlines!(1:iteration_blocks, Float32.(reg_loss_valid_l); label = "validation (L2)")
    axislegend()
    f
end
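# Hedged takeaway (added here): comparing the best validation loss reached with
# and without L2 regularization summarizes the benefit of regularizing.
(; best_unregularized = minimum(Float32.(loss_valid_l)),
   best_regularized = minimum(Float32.(reg_loss_valid_l)))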