From 9d6bf45ab4975c354ac6cd01efdb162c9119434a Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 19 Oct 2022 21:15:03 +0000 Subject: [PATCH 1/4] initial commit --- keras_nlp/models/bert/bert_models.py | 2 ++ keras_nlp/models/bert/bert_models_test.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/bert/bert_models.py b/keras_nlp/models/bert/bert_models.py index 654f050c2b..49c8aaa373 100644 --- a/keras_nlp/models/bert/bert_models.py +++ b/keras_nlp/models/bert/bert_models.py @@ -98,6 +98,7 @@ def __init__( dropout=0.1, max_sequence_length=512, num_segments=2, + name="encoder", **kwargs, ): @@ -183,6 +184,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + name=name, **kwargs, ) # All references to `self` below this line diff --git a/keras_nlp/models/bert/bert_models_test.py b/keras_nlp/models/bert/bert_models_test.py index 803d4b7be1..eca86df2a9 100644 --- a/keras_nlp/models/bert/bert_models_test.py +++ b/keras_nlp/models/bert/bert_models_test.py @@ -31,7 +31,6 @@ def setUp(self): hidden_dim=64, intermediate_dim=128, max_sequence_length=128, - name="encoder", ) self.batch_size = 8 self.input_batch = { @@ -52,6 +51,7 @@ def setUp(self): def test_valid_call_bert(self): self.model(self.input_batch) + # Check default name passed through self.assertEqual(self.model.name, "encoder") def test_variable_sequence_length_call_bert(self): @@ -95,7 +95,6 @@ def test_unknown_preset_error(self): Bert.from_preset( "bert_base_uncased_clowntown", load_weights=False, - name="encoder", ) def test_preset_mutability(self): From 68f3e8a83a00630732580248d0f38c508ef1afa3 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 20 Oct 2022 18:37:04 +0000 Subject: [PATCH 2/4] Address comments --- keras_nlp/models/bert/bert_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/bert/bert_models.py b/keras_nlp/models/bert/bert_models.py index 49c8aaa373..3853675f97 100644 --- a/keras_nlp/models/bert/bert_models.py +++ b/keras_nlp/models/bert/bert_models.py @@ -98,7 +98,7 @@ def __init__( dropout=0.1, max_sequence_length=512, num_segments=2, - name="encoder", + name="backbone", **kwargs, ): From 99b3a63600c9e668fa92c8c41b77c255ebe936d7 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 20 Oct 2022 18:39:16 +0000 Subject: [PATCH 3/4] fix test --- keras_nlp/models/bert/bert_models_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/bert/bert_models_test.py b/keras_nlp/models/bert/bert_models_test.py index eca86df2a9..0ecb7418c8 100644 --- a/keras_nlp/models/bert/bert_models_test.py +++ b/keras_nlp/models/bert/bert_models_test.py @@ -52,7 +52,7 @@ def setUp(self): def test_valid_call_bert(self): self.model(self.input_batch) # Check default name passed through - self.assertEqual(self.model.name, "encoder") + self.assertEqual(self.model.name, "backbone") def test_variable_sequence_length_call_bert(self): for seq_length in (25, 50, 75): From 131940cb9ea9581d11696c628645b9d59daf6485 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 20 Oct 2022 22:44:13 +0000 Subject: [PATCH 4/4] Keep `name` out of arg list for consistency --- keras_nlp/models/bert/bert_models.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras_nlp/models/bert/bert_models.py b/keras_nlp/models/bert/bert_models.py index 3853675f97..3d9f820d05 100644 --- a/keras_nlp/models/bert/bert_models.py +++ b/keras_nlp/models/bert/bert_models.py @@ -82,7 +82,6 @@ class Bert(keras.Model): hidden_dim=768, intermediate_dim=3072, max_sequence_length=12, - name="encoder", ) output = model(input_data) ``` @@ -98,7 +97,6 @@ def __init__( dropout=0.1, max_sequence_length=512, num_segments=2, - name="backbone", **kwargs, ): @@ -173,6 +171,10 @@ def __init__( name="pooled_dense", )(x[:, cls_token_index, :]) + # Set default for `name` if none given + if "name" not in kwargs: + kwargs["name"] = "backbone" + # Instantiate using Functional API Model constructor super().__init__( inputs={ @@ -184,7 +186,6 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, - name=name, **kwargs, ) # All references to `self` below this line