-
Notifications
You must be signed in to change notification settings - Fork 231
/
Copy pathexample_gan.py
118 lines (97 loc) · 4.46 KB
/
example_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import matplotlib as mpl
# This line allows mpl to run with no DISPLAY defined
mpl.use('Agg')
import pandas as pd
import numpy as np
import os
from keras.layers import Reshape, Flatten, LeakyReLU, Activation
from keras.models import Sequential
from keras.optimizers import Adam
from keras.callbacks import TensorBoard
from keras_adversarial.image_grid_callback import ImageGridCallback
from keras_adversarial import AdversarialModel, simple_gan, gan_targets
from keras_adversarial import normal_latent_sampling, AdversarialOptimizerSimultaneous
from keras_adversarial.legacy import l1l2, Dense, fit
import keras.backend as K
from mnist_utils import mnist_data
def model_generator(latent_dim, input_shape, hidden_dim=1024, reg=lambda: l1l2(1e-5, 1e-5)):
    """Build the fully-connected generator network (latent vector -> sample).

    Three hidden Dense layers of width ``hidden_dim/4``, ``hidden_dim/2`` and
    ``hidden_dim`` (each followed by LeakyReLU(0.2)) feed a Dense layer with
    ``prod(input_shape)`` units. A sigmoid activation is applied to that layer's
    output, which is then reshaped to ``input_shape``.

    Args:
        latent_dim: length of the latent input vector (``input_dim`` of the
            first Dense layer).
        input_shape: shape of one generated sample, e.g. ``(28, 28)``.
        hidden_dim: width of the widest (last) hidden layer.
        reg: zero-argument factory for a weight regularizer. It is invoked
            once per Dense layer, so each layer receives its own instance.

    Returns:
        A ``Sequential`` model named ``"generator"``.
    """
    quarter, half = int(hidden_dim / 4), int(hidden_dim / 2)

    # First hidden layer is special: it declares the latent input size.
    stack = [
        Dense(quarter, name="generator_h1", input_dim=latent_dim, W_regularizer=reg()),
        LeakyReLU(0.2),
    ]
    # Remaining hidden layers widen towards hidden_dim.
    for width, layer_name in ((half, "generator_h2"), (hidden_dim, "generator_h3")):
        stack.extend([Dense(width, name=layer_name, W_regularizer=reg()), LeakyReLU(0.2)])

    # Project to one unit per output element, squash with sigmoid, restore the sample shape.
    stack.extend([
        Dense(np.prod(input_shape), name="generator_x_flat", W_regularizer=reg()),
        Activation('sigmoid'),
        Reshape(input_shape, name="generator_x"),
    ])
    return Sequential(stack, name="generator")
def model_discriminator(input_shape, hidden_dim=1024, reg=lambda: l1l2(1e-5, 1e-5), output_activation="sigmoid"):
    """Build the fully-connected discriminator network (sample -> score).

    The input is flattened, then passed through three Dense layers of width
    ``hidden_dim``, ``hidden_dim/2`` and ``hidden_dim/4`` (each followed by
    LeakyReLU(0.2)). A final single-unit Dense layer is followed by
    ``output_activation``.

    Args:
        input_shape: shape of one input sample, e.g. ``(28, 28)``.
        hidden_dim: width of the first (widest) hidden layer.
        reg: zero-argument factory for a weight regularizer. It is invoked
            once per Dense layer, so each layer receives its own instance.
        output_activation: name of the activation applied to the single
            output unit (``"sigmoid"`` by default).

    Returns:
        A ``Sequential`` model named ``"discriminator"``.
    """
    hidden_plan = (
        ("discriminator_h1", hidden_dim),
        ("discriminator_h2", int(hidden_dim / 2)),
        ("discriminator_h3", int(hidden_dim / 4)),
    )

    stack = [Flatten(name="discriminator_flatten", input_shape=input_shape)]
    # Narrowing hidden stack; each Dense is paired with its LeakyReLU in order.
    for layer_name, width in hidden_plan:
        stack.append(Dense(width, name=layer_name, W_regularizer=reg()))
        stack.append(LeakyReLU(0.2))
    # Single output unit.
    stack.append(Dense(1, name="discriminator_y", W_regularizer=reg()))
    stack.append(Activation(output_activation))
    return Sequential(stack, name="discriminator")
def example_gan(adversarial_optimizer, path, opt_g, opt_d, nb_epoch, generator, discriminator, latent_dim,
                targets=gan_targets, loss='binary_crossentropy'):
    """Train a generator/discriminator pair adversarially on MNIST and save the results.

    The run is skipped entirely (after printing a message) when
    ``<path>/history.csv`` already exists. Otherwise the run:

    * wires the two networks into a GAN via ``simple_gan``;
    * compiles them as an ``AdversarialModel`` with one optimizer per player;
    * trains with batch size 32, writing an image grid per epoch to
      ``<path>/epoch-NNN.png`` and, on the TensorFlow backend only,
      TensorBoard logs to ``<path>/logs``;
    * writes the training history to ``history.csv``;
    * saves both networks as ``generator.h5`` / ``discriminator.h5`` under ``path``.

    Args:
        adversarial_optimizer: keras_adversarial optimizer strategy
            (e.g. ``AdversarialOptimizerSimultaneous()``).
        path: output directory for all artifacts.
        opt_g: optimizer for the generator's trainable weights.
        opt_d: optimizer for the discriminator's trainable weights.
        nb_epoch: number of training epochs.
        generator: generator model.
        discriminator: discriminator model.
        latent_dim: length of the latent vector fed to ``generator``.
        targets: callable mapping a sample count to training targets.
        loss: loss passed to ``adversarial_compile``.

    Returns:
        None.
    """
    csvpath = os.path.join(path, "history.csv")
    if os.path.exists(csvpath):
        # A finished run left its history behind; do not retrain.
        print("Already exists: {}".format(csvpath))
        return
    print("Training: {}".format(csvpath))

    # gan (x -> yfake, yreal); the latent z is sampled by the model itself.
    gan = simple_gan(generator, discriminator, normal_latent_sampling((latent_dim,)))
    for net in (generator, discriminator, gan):
        net.summary()

    # Each player updates only its own network's trainable weights.
    model = AdversarialModel(base_model=gan,
                             player_params=[generator.trainable_weights, discriminator.trainable_weights],
                             player_names=["generator", "discriminator"])
    model.adversarial_compile(adversarial_optimizer=adversarial_optimizer,
                              player_optimizers=[opt_g, opt_d],
                              loss=loss)

    # Fixed latent batch so successive epoch images are comparable.
    # NOTE(review): the reshape hard-codes a 10x10 grid of 28x28 images.
    grid_z = np.random.normal(size=(10 * 10, latent_dim))
    image_cb = ImageGridCallback(os.path.join(path, "epoch-{:03d}.png"),
                                 lambda: generator.predict(grid_z).reshape((10, 10, 28, 28)))

    xtrain, xtest = mnist_data()
    y, ytest = targets(xtrain.shape[0]), targets(xtest.shape[0])

    callbacks = [image_cb]
    if K.backend() == "tensorflow":
        callbacks.append(
            TensorBoard(log_dir=os.path.join(path, 'logs'), histogram_freq=0, write_graph=True, write_images=True))

    history = fit(model, x=xtrain, y=y, validation_data=(xtest, ytest), callbacks=callbacks, nb_epoch=nb_epoch,
                  batch_size=32)

    # Persist the per-epoch metrics, then both trained networks.
    pd.DataFrame(history.history).to_csv(csvpath)
    for net, filename in ((generator, "generator.h5"), (discriminator, "discriminator.h5")):
        net.save(os.path.join(path, filename))
def main():
    """Train the default MLP GAN on MNIST, writing artifacts to ``output/gan``.

    Uses a 100-dimensional latent space, 28x28 samples, simultaneous
    adversarial updates, Adam(1e-4) for the generator, Adam(1e-3) for the
    discriminator (both with decay 1e-4), and 100 epochs.
    """
    latent_dim = 100  # z \in R^100
    input_shape = (28, 28)  # x \in R^{28x28}

    generator = model_generator(latent_dim, input_shape)  # z -> x
    discriminator = model_discriminator(input_shape)  # x -> y

    # Constructed in the same order as they would be as inline call arguments.
    adversarial_optimizer = AdversarialOptimizerSimultaneous()
    generator_opt = Adam(1e-4, decay=1e-4)
    discriminator_opt = Adam(1e-3, decay=1e-4)

    example_gan(adversarial_optimizer, "output/gan",
                opt_g=generator_opt,
                opt_d=discriminator_opt,
                nb_epoch=100,
                generator=generator,
                discriminator=discriminator,
                latent_dim=latent_dim)


# Only start training when executed as a script, not on import.
if __name__ == "__main__":
    main()