Skip to content

Commit

Permalink
Fix gpt2 serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Oct 18, 2022
1 parent fa65faa commit 9b3fe87
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
27 changes: 14 additions & 13 deletions keras_nlp/models/gpt2/gpt2_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def _gpt_2_kernel_initializer(stddev=0.02):
return keras.initializers.RandomNormal(stddev=stddev)


@keras.utils.register_keras_serializable(package="keras_nlp")
class Gpt2Custom(keras.Model):
"""GPT-2 core network with customizable hyperparameters.
Expand Down Expand Up @@ -164,19 +165,19 @@ def __init__(
self.max_sequence_length = max_sequence_length

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
}
)
return config
return {
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
}

@classmethod
def from_config(cls, config):
return cls(**config)


MODEL_DOCSTRING = """GPT-2 "{type}" architecture.
Expand Down
11 changes: 9 additions & 2 deletions keras_nlp/models/gpt2/gpt2_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,18 @@ def test_gpt2_base_compile_batched_ds(self, jit_compile):
model.compile(jit_compile=jit_compile)
model.predict(self.input_dataset)

def test_saving_model(self):
@parameterized.named_parameters(
("save_format_tf", "tf"), ("save_format_h5", "h5")
)
def test_saving_model(self, save_format):
model_output = self.model(self.input_batch)
save_path = os.path.join(self.get_temp_dir(), "model")
self.model.save(save_path)
self.model.save(save_path, save_format)
restored_model = keras.models.load_model(save_path)

# Check we got the real object back.
self.assertIsInstance(restored_model, Gpt2Custom)

# Check that output matches.
restored_output = restored_model(self.input_batch)
self.assertAllClose(model_output, restored_output)

0 comments on commit 9b3fe87

Please sign in to comment.