# Variational Autoencoder (VAE)
#
# Auto-Encoding Variational Bayes
# Diederik P Kingma, Max Welling
# https://arxiv.org/abs/1312.6114
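#
# The encoder maps each image x to a diagonal Gaussian q(z|x) = N(μ, exp(logσ)^2);
# the decoder maps a latent sample z back to Bernoulli logits over pixels.
# Training minimizes the negative ELBO (reconstruction loss plus the KL
# divergence to the N(0, I) prior), with an L2 penalty on the decoder weights.
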
using BSON
using CUDA
using DrWatson: struct2dict
using Flux
using Flux: @functor, chunk
using Flux.Losses: logitbinarycrossentropy
using Flux.Data: DataLoader
using Images
using Logging: with_logger
using MLDatasets
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using TensorBoardLogger: TBLogger, tb_overwrite
using Random

# load MNIST images and return loader
function get_data(batch_size)
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtrain = reshape(xtrain, 28^2, :)
    DataLoader((xtrain, ytrain), batchsize=batch_size, shuffle=true)
end
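
# Each batch yielded by the loader is a 784 × batch_size matrix of pixel
# intensities in [0, 1]; the labels are carried along but unused below.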

struct Encoder
    linear
    μ
    logσ
end
@functor Encoder

Encoder(input_dim::Int, latent_dim::Int, hidden_dim::Int) = Encoder(
    Dense(input_dim, hidden_dim, tanh),   # linear
    Dense(hidden_dim, latent_dim),        # μ
    Dense(hidden_dim, latent_dim),        # logσ
)

function (encoder::Encoder)(x)
    h = encoder.linear(x)
    encoder.μ(h), encoder.logσ(h)
end

Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int) = Chain(
    Dense(latent_dim, hidden_dim, tanh),
    Dense(hidden_dim, input_dim)
)
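
# reparameterization trick: z = μ + σ .* ε with ε ~ N(0, I), which keeps the
# sampling step differentiable with respect to μ and logσ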
function reconstruct(encoder, decoder, x, device)
    μ, logσ = encoder(x)
    z = μ + device(randn(Float32, size(logσ))) .* exp.(logσ)
    μ, logσ, decoder(z)
end

function model_loss(encoder, decoder, λ, x, device)
    μ, logσ, decoder_z = reconstruct(encoder, decoder, x, device)
    len = size(x)[end]
    # KL-divergence between q(z|x) and the standard normal prior
    kl_q_p = 0.5f0 * sum(@. (exp(2f0 * logσ) + μ^2 - 1f0 - 2f0 * logσ)) / len
    # reconstruction log-likelihood (the decoder outputs logits)
    logp_x_z = -logitbinarycrossentropy(decoder_z, x, agg=sum) / len
    # L2 regularization on the decoder weights
    reg = λ * sum(x -> sum(x .^ 2), Flux.params(decoder))
    -logp_x_z + kl_q_p + reg
end

function convert_to_image(x, y_size)
    Gray.(permutedims(vcat(reshape.(chunk(x |> cpu, y_size), 28, :)...), (2, 1)))
end
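
# `chunk` splits the flattened pixels into `y_size` equal parts; after
# reshaping and stacking, a batch of y_size^2 digits becomes a single
# y_size × y_size grayscale grid.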

# arguments for the `train` function
@with_kw mutable struct Args
    η = 1e-3                # learning rate
    λ = 0.01f0              # regularization parameter
    batch_size = 128        # batch size
    sample_size = 10        # sampling size for output
    epochs = 20             # number of epochs
    seed = 0                # random seed
    cuda = true             # use GPU
    input_dim = 28^2        # image size
    latent_dim = 2          # latent dimension
    hidden_dim = 500        # hidden dimension
    verbose_freq = 10       # logging for every verbose_freq iterations
    tblogger = false        # log training with tensorboard
    save_path = "output"    # results path
end
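
# Every field above can be overridden as a keyword argument of `train`,
# e.g. `train(epochs=5, tblogger=true)` (illustrative values).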

function train(; kws...)
    # load hyperparameters
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    # GPU config
    if args.cuda && CUDA.has_cuda()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # load MNIST images
    loader = get_data(args.batch_size)

    # initialize encoder and decoder
    encoder = Encoder(args.input_dim, args.latent_dim, args.hidden_dim) |> device
    decoder = Decoder(args.input_dim, args.latent_dim, args.hidden_dim) |> device

    # ADAM optimizer
    opt = ADAM(args.η)

    # parameters
    ps = Flux.params(encoder.linear, encoder.μ, encoder.logσ, decoder)

    !ispath(args.save_path) && mkpath(args.save_path)

    # logging by TensorBoard.jl
    if args.tblogger
        tblogger = TBLogger(args.save_path, tb_overwrite)
    end

    # fixed input, used to visualize reconstructions after every epoch
    original, _ = first(get_data(args.sample_size^2))
    original = original |> device
    image = convert_to_image(original, args.sample_size)
    image_path = joinpath(args.save_path, "original.png")
    save(image_path, image)

    # training
    train_steps = 0
    @info "Start Training, total $(args.epochs) epochs"
    for epoch = 1:args.epochs
        @info "Epoch $(epoch)"
        progress = Progress(length(loader))

        for (x, _) in loader
            loss, back = Flux.pullback(ps) do
                model_loss(encoder, decoder, args.λ, x |> device, device)
            end
            grad = back(1f0)
            Flux.Optimise.update!(opt, ps, grad)

            # progress meter
            next!(progress; showvalues=[(:loss, loss)])

            # logging with TensorBoard
            if args.tblogger && train_steps % args.verbose_freq == 0
                with_logger(tblogger) do
                    @info "train" loss=loss
                end
            end
            train_steps += 1
        end

        # save image
        _, _, rec_original = reconstruct(encoder, decoder, original, device)
        rec_original = sigmoid.(rec_original)
        image = convert_to_image(rec_original, args.sample_size)
        image_path = joinpath(args.save_path, "epoch_$(epoch).png")
        save(image_path, image)
        @info "Image saved: $(image_path)"
    end

    # save model
    model_path = joinpath(args.save_path, "model.bson")
    let encoder = cpu(encoder), decoder = cpu(decoder), args = struct2dict(args)
        BSON.@save model_path encoder decoder args
        @info "Model saved: $(model_path)"
    end
end

if abspath(PROGRAM_FILE) == @__FILE__
    train()
end
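
# To reuse the trained model later, it can be loaded back from the BSON file.
# A minimal sketch, assuming the default `save_path = "output"` and a
# hypothetical input matrix `x` of flattened digits (784 × n):
#
#     using BSON: @load
#     @load "output/model.bson" encoder decoder args
#     μ, logσ = encoder(x)        # per-image latent mean and log-std
#     x̂ = sigmoid.(decoder(μ))    # decode the mean latent code to pixels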