-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel_def.py
60 lines (45 loc) · 1.93 KB
/
model_def.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
This example demonstrates how to train a GAN with Determined's TF Keras API.
The Determined TF Keras API supports using a subclassed `tf.keras.Model` which
defines a custom `train_step()` and `test_step()`.
"""
import tensorflow as tf
from data import get_train_dataset, get_validation_dataset
from dc_gan import DCGan
from determined.keras import InputData, TFKerasTrial, TFKerasTrialContext
class DCGanTrial(TFKerasTrial):
    """Determined Keras trial that builds the ``DCGan`` model and its data.

    The trial keeps the ``TFKerasTrialContext`` it receives and uses it to
    read hyperparameters and to wrap the model, the optimizers, and the
    datasets.
    """

    def __init__(self, context: TFKerasTrialContext) -> None:
        """Store the trial context for later use by the ``build_*`` methods."""
        self.context = context

    def build_model(self) -> tf.keras.models.Model:
        """Build, wrap, and compile the GAN.

        Reads the ``noise_dim``, ``generator_lr``, and ``discriminator_lr``
        hyperparameters through the trial context.

        Returns:
            The context-wrapped ``DCGan`` model, compiled with one wrapped
            legacy Adam optimizer for the generator and one for the
            discriminator.
        """
        ctx = self.context

        # The model is wrapped by the context right after construction.
        gan = ctx.wrap_model(
            DCGan(
                batch_size=ctx.get_per_slot_batch_size(),
                noise_dim=ctx.get_hparam("noise_dim"),
            )
        )

        # Each network gets its own optimizer with its own learning rate;
        # the generator's optimizer is created and wrapped first.
        generator_opt = ctx.wrap_optimizer(
            tf.keras.optimizers.legacy.Adam(
                learning_rate=ctx.get_hparam("generator_lr")
            )
        )
        discriminator_opt = ctx.wrap_optimizer(
            tf.keras.optimizers.legacy.Adam(
                learning_rate=ctx.get_hparam("discriminator_lr")
            )
        )

        gan.compile(
            discriminator_optimizer=discriminator_opt,
            generator_optimizer=generator_opt,
        )
        return gan

    def build_training_data_loader(self) -> InputData:
        """Return the wrapped, batched training dataset.

        ``get_train_dataset`` is called with the rank reported by
        ``self.context.distributed.get_rank()``. The result is wrapped by
        the context and then batched with the per-slot batch size.
        """
        ctx = self.context
        rank = ctx.distributed.get_rank()

        # Wrap first, then batch.
        wrapped = ctx.wrap_dataset(get_train_dataset(rank))
        return wrapped.batch(ctx.get_per_slot_batch_size())

    def build_validation_data_loader(self) -> InputData:
        """Return the wrapped, batched validation dataset.

        ``get_validation_dataset`` is called with the rank reported by
        ``self.context.distributed.get_rank()``. The result is wrapped by
        the context and then batched with the per-slot batch size.
        """
        ctx = self.context
        rank = ctx.distributed.get_rank()

        # Wrap first, then batch.
        wrapped = ctx.wrap_dataset(get_validation_dataset(rank))
        return wrapped.batch(ctx.get_per_slot_batch_size())