-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathvanilla_models_loading_training.py
106 lines (90 loc) · 3.48 KB
/
vanilla_models_loading_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Libraries
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import os
import tensorflow as tf
import lib.GAN as GAN
import lib.VAE as VAE
# VAE Loading
# Restore a previously trained VAE stored under models/VAEs/<name>_model/<name>
# relative to the current working directory.
model_name = 'VAE64D1'
model_filepath = os.path.join(os.getcwd(), 'models', 'VAEs', f'{model_name}_model', model_name)
vae = VAE.load(model_filepath)
# VAE Generation Test
# Decode `num_img` random latent vectors with the loaded VAE's decoder and show
# the resulting images side by side in a single row.
num_img = 7
random_latent_vectors = tf.random.normal(shape=(num_img, vae.latent_dim))
# `.numpy()` returns a new array instead of converting in place, so its result
# must be assigned for the conversion to have any effect.
generated_images = (vae.decoder(random_latent_vectors) * 255).numpy()
plt.figure(figsize=(5 * num_img, 5))
print("VAE Generation")
# NOTE(review): assumes the decoder returns one image per latent vector,
# i.e. a batch of `num_img` images — confirm against lib.VAE.
for i, image in enumerate(generated_images):
    plt.subplot(1, num_img, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(keras.preprocessing.image.array_to_img(image))
plt.show()
# GAN Loading
# The forward model and its inverse live in sibling 'forward' / 'backward'
# sub-directories of models/GANs/<name>_model/.
model_name = 'GAN64D1'
gan_model_dir = os.path.join(os.getcwd(), 'models', 'GANs', f'{model_name}_model')
model_filepath_fwd = os.path.join(gan_model_dir, 'forward', model_name)
model_filepath_bck = os.path.join(gan_model_dir, 'backward', model_name)
gan = GAN.load(model_filepath_fwd)
# load_inverse receives the already-loaded forward GAN alongside its own path.
igan = GAN.load_inverse(model_filepath_bck, gan)
# GAN Generation Test
# Feed `num_img` random latent vectors through the loaded GAN's generator and
# show the resulting images side by side in a single row.
num_img = 7
random_latent_vectors = tf.random.normal(shape=(num_img, gan.latent_dim))
# `.numpy()` returns a new array instead of converting in place, so its result
# must be assigned for the conversion to have any effect.
generated_images = (gan.generator(random_latent_vectors) * 255).numpy()
plt.figure(figsize=(5 * num_img, 5))
print("GAN Generation")
# NOTE(review): assumes the generator returns one image per latent vector,
# i.e. a batch of `num_img` images — confirm against lib.GAN.
for i, image in enumerate(generated_images):
    plt.subplot(1, num_img, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(keras.preprocessing.image.array_to_img(image))
plt.show()
# VAE Training Example
# Build a fresh VAE and run a minimal (1 epoch, 1 step) training pass on
# placeholder data, saving it under models/VAEs/<name>_model/<name>.
model_name = 'VAE64D6'
input_shape = (64, 64, 3)
latent_dim = 64
# Keyword arguments forwarded to VAE.create and also handed to VAE.train.
parameters = {}
# Placeholder batch of 5 random images. randint's upper bound is exclusive,
# so 256 is required for pixel values to span the full 0..255 range.
train_data = np.random.randint(0, 256, size=(5, *input_shape)).astype('float')  # LOAD DATASET HERE (batches of images)
val_data = None  # FID-based validation. Skipped if this is None
encoder, decoder = VAE.create(input_shape, latent_dim, **parameters)
info = {
    'dataset': "CelebA_align",
    'name': model_name,
}
model_filepath = os.path.join(os.getcwd(), 'models', 'VAEs', model_name + '_model', model_name)
VAE.train(input_shape, latent_dim, train_data, model_filepath, encoder, decoder, info=info, parameters=parameters,
          val_data=val_data, fid_samples=1000, epochs=1, steps_per_epoch=1)
# GAN Training Example
# Build a fresh GAN (discriminator, generator, recoder), run a minimal training
# pass on placeholder data, then train the inverse model against the trained
# generator. Forward/backward artifacts go to models/GANs/<name>_model/{forward,backward}.
model_name = 'GAN64D6'
input_shape = (64, 64, 3)
latent_dim = 64
# Keyword arguments forwarded to GAN.create and also handed to both training calls.
parameters = {
    'recoder_args': {
        'base_filters_n': 128,
        'filters_multiplier': 2,
        'n_layers': 3,
        'stride': 2,
        'kernel_size': 4,
    },
}
# Placeholder batch of 5 random images. randint's upper bound is exclusive,
# so 256 is required for pixel values to span the full 0..255 range.
train_data = np.random.randint(0, 256, size=(5, *input_shape)).astype('float')  # LOAD DATASET HERE (batches of images)
val_data = None  # FID-based validation. Skipped if this is None
discriminator, generator, recoder = GAN.create(input_shape, latent_dim, **parameters)
info = {
    'dataset': "CelebA_align",
    'name': model_name,
}
model_filepath_fwd = os.path.join(os.getcwd(), 'models', 'GANs', model_name + '_model', 'forward', model_name)
model_filepath_bck = os.path.join(os.getcwd(), 'models', 'GANs', model_name + '_model', 'backward', model_name)
GAN.train(input_shape, latent_dim, train_data, model_filepath_fwd, discriminator, generator, info=info,
          parameters=parameters, val_data=val_data, fid_samples=1000, epochs=1, steps_per_epoch=1)
# Reload the forward model from the path given to GAN.train (presumably where it
# saved the trained weights — confirm in lib.GAN) and train the inverse against
# its generator.
gan = GAN.load(model_filepath_fwd)
GAN.train_inverse(input_shape, latent_dim, model_filepath_bck, recoder, gan.generator, info=info,
                  parameters=parameters, epochs=1, steps_per_epoch=1)