Fix distilbert serialization (#392)
* Fix distilbert serialization

* fixup
mattdangerw authored Oct 20, 2022
1 parent 0fee8c1 commit 1217543
Showing 3 changed files with 26 additions and 15 deletions.
29 changes: 16 additions & 13 deletions keras_nlp/models/distilbert/distilbert_models.py
@@ -28,6 +28,7 @@ def distilbert_kernel_initializer(stddev=0.02):
     return keras.initializers.TruncatedNormal(stddev=stddev)
 
 
+@keras.utils.register_keras_serializable(package="keras_nlp")
 class DistilBertCustom(keras.Model):
     """DistilBERT encoder network with custom hyperparameters.
@@ -157,19 +158,21 @@ def __init__(
         self.max_sequence_length = max_sequence_length
 
     def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "vocabulary_size": self.vocabulary_size,
-                "num_layers": self.num_layers,
-                "num_heads": self.num_heads,
-                "hidden_dim": self.hidden_dim,
-                "intermediate_dim": self.intermediate_dim,
-                "dropout": self.dropout,
-                "max_sequence_length": self.max_sequence_length,
-            }
-        )
-        return config
+        return {
+            "vocabulary_size": self.vocabulary_size,
+            "num_layers": self.num_layers,
+            "num_heads": self.num_heads,
+            "hidden_dim": self.hidden_dim,
+            "intermediate_dim": self.intermediate_dim,
+            "dropout": self.dropout,
+            "max_sequence_length": self.max_sequence_length,
+            "name": self.name,
+            "trainable": self.trainable,
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
 
 
 MODEL_DOCSTRING = """DistilBert "{type}" architecture.
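The change above swaps the `super().get_config()` chaining for a flat dict of constructor arguments plus an explicit `from_config`, so the model can be round-tripped through its config alone. A minimal sketch of that round trip (not part of the commit; the `keras_nlp.models` export path and the small hyperparameter values are assumptions for illustration):

# Hedged sketch: config round trip enabled by this commit. The import path
# and the tiny hyperparameter values below are assumed, not from the diff.
from keras_nlp.models import DistilBertCustom

model = DistilBertCustom(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    dropout=0.1,
    max_sequence_length=128,
)
config = model.get_config()  # flat dict of constructor kwargs (+ name, trainable)
restored = DistilBertCustom.from_config(config)  # cls(**config) rebuilds the model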
11 changes: 9 additions & 2 deletions keras_nlp/models/distilbert/distilbert_models_test.py
@@ -91,11 +91,18 @@ def test_distilbert_base_compile_batched_ds(self, jit_compile):
         model.compile(jit_compile=jit_compile)
         model.predict(self.input_dataset)
 
-    def test_saving_model(self):
+    @parameterized.named_parameters(
+        ("save_format_tf", "tf"), ("save_format_h5", "h5")
+    )
+    def test_saving_model(self, save_format):
         model_output = self.model(self.input_batch)
         save_path = os.path.join(self.get_temp_dir(), "model")
-        self.model.save(save_path)
+        self.model.save(save_path, save_format)
         restored_model = keras.models.load_model(save_path)
 
+        # Check we got the real object back.
+        self.assertIsInstance(restored_model, DistilBertCustom)
+
+        # Check that output matches.
         restored_output = restored_model(self.input_batch)
         self.assertAllClose(model_output, restored_output)
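Outside the test harness, the same save/load round trip might look like this sketch (assuming a built `model` instance as in the previous sketch; `save_format` is passed by keyword here since it is not the second positional parameter of `Model.save`):

# Hedged sketch of the save/load round trip the test exercises. Because the
# class is decorated with @keras.utils.register_keras_serializable, load_model
# needs no custom_objects argument to resolve DistilBertCustom.
import os
import tempfile

from tensorflow import keras

for save_format in ("tf", "h5"):
    save_path = os.path.join(tempfile.mkdtemp(), "model")
    model.save(save_path, save_format=save_format)
    restored = keras.models.load_model(save_path)
    assert isinstance(restored, type(model))  # the real class, not a plain Model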
1 change: 1 addition & 0 deletions keras_nlp/models/distilbert/distilbert_preprocessing.py
@@ -88,6 +88,7 @@
 """
 
 
+@keras.utils.register_keras_serializable(package="keras_nlp")
 class DistilBertPreprocessor(keras.layers.Layer):
     def __init__(
         self,
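The decorator records the class in Keras's global serialization registry under the `keras_nlp` package key, which is what lets generic deserialization find it by name. A small sketch of that mechanism using the public `tf.keras.utils` registry helpers (the module import path follows the file path in this diff):

# Hedged sketch: what register_keras_serializable(package="keras_nlp") does.
# The class is registered as "keras_nlp>DistilBertPreprocessor", so
# deserialization can resolve the name without a custom_objects dict.
from tensorflow import keras

from keras_nlp.models.distilbert.distilbert_preprocessing import (
    DistilBertPreprocessor,
)

name = keras.utils.get_registered_name(DistilBertPreprocessor)
print(name)  # "keras_nlp>DistilBertPreprocessor"
assert keras.utils.get_registered_object(name) is DistilBertPreprocessor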
