Add a test harness based on keras-core's run_layer_test (#1238)
* Add a test harness based on keras-core's `run_layer_test`

Eventually, I want to add some dtype changes similar to
keras-team/keras-core#805.
But the nice thing about that PR on keras-core was that I could add a
dtype test to a common harness and run it against all layers at once.

So I think it's finally time to bite the bullet and add these for
keras-nlp. This ports and simplifies the `run_layer_test` code from
keras-core (a rough sketch follows this message) and applies it to our
modeling layers.

I am ditching the saved model tests for our individual layers; saved
model tests are slow, and we now get fairly robust serialization
coverage without them. If this is good enough for keras-core layers, I
think we can follow suit here. We still test saving end to end through
our modeling tests.

* Fix tests

* Address comments
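
For orientation, here is a minimal sketch of what a `run_layer_test`-style
harness checks. This is illustrative only, not the helper merged in this PR
(the real one is a `TestCase` method with more checks); the keyword names
mirror the tests below.

def run_layer_test(
    layer_cls,
    init_kwargs,
    input_data,
    expected_output_shape,
    expected_num_trainable_weights=None,
):
    # Serialization: a config round trip should reproduce the config.
    layer = layer_cls(**init_kwargs)
    config = layer.get_config()
    revived = layer_cls.from_config(config)
    assert revived.get_config() == config

    # Forward pass: dict inputs are splatted as keyword arguments.
    layer = layer_cls(**init_kwargs)
    if isinstance(input_data, dict):
        output = layer(**input_data)
    else:
        output = layer(input_data)
    assert tuple(output.shape) == tuple(expected_output_shape)

    # Weight counts, checked after the forward pass has built the layer.
    if expected_num_trainable_weights is not None:
        assert len(layer.trainable_weights) == expected_num_trainable_weights

Replacing per-layer saved-model round trips with this config round trip is
what makes the individual layer suites faster.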
mattdangerw authored Sep 11, 2023
1 parent 60af93f commit 8fea705
Showing 18 changed files with 381 additions and 695 deletions.
5 changes: 4 additions & 1 deletion keras_nlp/layers/modeling/cached_multi_head_attention.py
@@ -138,4 +138,7 @@ def call(
             self._combine_equation, attention_scores, value
         )
         attention_output = self._output_dense(attention_output)
-        return attention_output, cache
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
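
With this change, callers no longer unpack a tuple on uncached calls. A
hedged usage sketch (the `(batch, 2, seq_len, num_heads, key_dim)` cache
layout is an assumption here; all sizes are illustrative):

from keras_nlp.backend import ops
from keras_nlp.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)

layer = CachedMultiHeadAttention(num_heads=2, key_dim=4)
x = ops.random.uniform(shape=(2, 4, 6))

# Without a cache, the call now returns a single tensor.
outputs = layer(query=x, value=x)

# With a cache, the call still returns an (outputs, cache) tuple.
cache = ops.zeros((2, 2, 4, 2, 4))  # assumed (batch, 2, seq, heads, dim)
outputs, cache = layer(query=x, value=x, cache=cache, cache_update_index=0)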
19 changes: 15 additions & 4 deletions keras_nlp/layers/modeling/cached_multi_head_attention_test.py
@@ -20,10 +20,21 @@


 class CachedMultiHeadAttentionTest(TestCase):
-    def test_valid_call(self):
-        layer = CachedMultiHeadAttention(num_heads=2, key_dim=4)
-        x = ops.random.uniform(shape=(2, 2, 8))
-        layer(query=x, value=x)
+    def test_layer_behaviors(self):
+        self.run_layer_test(
+            layer_cls=CachedMultiHeadAttention,
+            init_kwargs={
+                "num_heads": 2,
+                "key_dim": 4,
+            },
+            input_data={
+                "query": ops.random.uniform(shape=(2, 4, 6)),
+                "value": ops.random.uniform(shape=(2, 4, 6)),
+            },
+            expected_output_shape=(2, 4, 6),
+            expected_num_trainable_weights=8,
+            expected_num_non_trainable_variables=1,
+        )

     def test_cache_call_is_correct(self):
         batch_size = 2
1 change: 1 addition & 0 deletions keras_nlp/layers/modeling/f_net_encoder.py
@@ -125,6 +125,7 @@ def build(self, inputs_shape):
             self._intermediate_dense.compute_output_shape(inputs_shape)
         )
         self._output_dropout = keras.layers.Dropout(rate=self.dropout)
+        self.built = True

     def call(self, inputs):
         """Forward pass of the FNetEncoder.
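The `self.built = True` additions in this commit follow the standard Keras
pattern for layers that build their sublayers eagerly inside `build()`. A
small illustrative example, not from this repo:

import keras


class FeedForward(keras.layers.Layer):
    """Toy layer showing the eager-build pattern used above."""

    def __init__(self, hidden_dim=8, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim

    def build(self, inputs_shape):
        # Build sublayers up front so all weights exist after build().
        self._hidden = keras.layers.Dense(self.hidden_dim, activation="relu")
        self._hidden.build(inputs_shape)
        self._output = keras.layers.Dense(inputs_shape[-1])
        self._output.build(tuple(inputs_shape[:-1]) + (self.hidden_dim,))
        # Mark the layer as built so Keras will not call build() again.
        self.built = True

    def call(self, inputs):
        return self._output(self._hidden(inputs))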
94 changes: 17 additions & 77 deletions keras_nlp/layers/modeling/f_net_encoder_test.py
@@ -12,93 +12,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
-
-from keras_nlp.backend import keras
 from keras_nlp.backend import ops
-from keras_nlp.layers.modeling import f_net_encoder
+from keras_nlp.layers.modeling.f_net_encoder import FNetEncoder
 from keras_nlp.tests.test_case import TestCase


 class FNetEncoderTest(TestCase):
-    def test_valid_call(self):
-        encoder = f_net_encoder.FNetEncoder(intermediate_dim=4)
-        model = keras.Sequential(
-            [
-                keras.Input(shape=(4, 6)),
-                encoder,
-            ]
-        )
-        input = ops.random.uniform(shape=[2, 4, 6])
-        model(input)
-
-    def test_get_config_and_from_config(self):
-        encoder = f_net_encoder.FNetEncoder(
-            intermediate_dim=4,
-            kernel_initializer="HeNormal",
-            bias_initializer="Zeros",
-        )
-        config = encoder.get_config()
-        expected_config_subset = {
-            "intermediate_dim": 4,
-            "dropout": 0,
-            "activation": "relu",
-            "layer_norm_epsilon": 1e-5,
-            "kernel_initializer": keras.initializers.serialize(
-                keras.initializers.HeNormal()
-            ),
-            "bias_initializer": keras.initializers.serialize(
-                keras.initializers.Zeros()
-            ),
-        }
-        self.assertEqual(config, {**config, **expected_config_subset})
-
-        restored_encoder = f_net_encoder.FNetEncoder.from_config(
-            config,
-        )
-        self.assertEqual(
-            restored_encoder.get_config(), {**config, **expected_config_subset}
-        )
+    def test_layer_behaviors(self):
+        self.run_layer_test(
+            layer_cls=FNetEncoder,
+            init_kwargs={
+                "intermediate_dim": 4,
+                "dropout": 0,
+                "activation": "relu",
+                "layer_norm_epsilon": 1e-5,
+                "kernel_initializer": "HeNormal",
+                "bias_initializer": "Zeros",
+            },
+            input_data=ops.random.uniform(shape=(2, 4, 6)),
+            expected_output_shape=(2, 4, 6),
+            expected_num_trainable_weights=8,
+            expected_num_non_trainable_variables=1,
+        )

     def test_value_error_when_invalid_kernel_initializer(self):
         with self.assertRaises(ValueError):
-            f_net_encoder.FNetEncoder(
+            FNetEncoder(
                 intermediate_dim=4,
                 dropout=0.5,
                 kernel_initializer="Invalid",
             )
-
-    def test_one_training_step_of_f_net_encoder(self):
-        encoder = f_net_encoder.FNetEncoder(intermediate_dim=4)
-        inputs = keras.Input(shape=(4, 6))
-        x = encoder(inputs)
-        x = keras.layers.Dense(1, activation="sigmoid")(x)
-        model = keras.Model(inputs=inputs, outputs=x)
-
-        data = ops.random.uniform(shape=[2, 4, 6])
-        label = ops.random.randint(minval=0, maxval=2, shape=(2, 4, 1))
-
-        loss = keras.losses.BinaryCrossentropy(from_logits=False)
-        optimizer = keras.optimizers.Adam()
-        model.compile(loss=loss, optimizer=optimizer)
-        loss = model.train_on_batch(x=data, y=label)
-        self.assertGreater(loss, 0)
-
-    def test_saved_model(self):
-        model = keras.Sequential(
-            [
-                keras.Input(shape=(4, 6)),
-                f_net_encoder.FNetEncoder(
-                    intermediate_dim=4,
-                ),
-            ]
-        )
-        data = ops.random.uniform(shape=[2, 4, 6])
-        model(data)
-        path = os.path.join(self.get_temp_dir(), "model.keras")
-        model.save(path, save_format="keras_v3")
-        loaded_model = keras.models.load_model(path)
-
-        model_output = model(data)
-        loaded_model_output = loaded_model(data)
-        self.assertAllClose(model_output, loaded_model_output)
7 changes: 4 additions & 3 deletions keras_nlp/layers/modeling/masked_lm_head.py
@@ -167,8 +167,11 @@ def build(self, inputs_shape, mask_positions_shape=None):
             initializer=self.bias_initializer,
             dtype=self.dtype,
         )
+        self.built = True

     def call(self, inputs, mask_positions):
+        # Avoid auto-converting numpy int arrays to float tensors.
+        mask_positions = ops.convert_to_tensor(mask_positions, dtype="int")
         # Gather the encoded tokens at the masked indices.
         mask_positions = ops.expand_dims(mask_positions, axis=-1)
         x = ops.take_along_axis(inputs, mask_positions, axis=1)
@@ -222,6 +225,4 @@ def get_config(self):
         return config

     def compute_output_shape(self, inputs_shape, mask_positions_shape):
-        output_shape = list(mask_positions_shape)
-        output_shape[-1] = self.vocabulary_size
-        return tuple(output_shape)
+        return mask_positions_shape + (self.vocabulary_size,)
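
As a quick sanity check on the new one-liner, the shapes used in the test
file below work out as follows:

# With 5 masked positions per example and a 100-token vocabulary:
mask_positions_shape = (4, 5)        # (batch, num_mask_positions)
vocabulary_size = 100
output_shape = mask_positions_shape + (vocabulary_size,)
assert output_shape == (4, 5, 100)   # matches expected_output_shape below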
136 changes: 35 additions & 101 deletions keras_nlp/layers/modeling/masked_lm_head_test.py
@@ -12,80 +12,53 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
-
-from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead
 from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
 from keras_nlp.tests.test_case import TestCase


 class MaskedLMHeadTest(TestCase):
-    def test_valid_call(self):
-        head = MaskedLMHead(
-            vocabulary_size=100,
-            activation="softmax",
+    def test_layer_behaviors(self):
+        self.run_layer_test(
+            layer_cls=MaskedLMHead,
+            init_kwargs={
+                "vocabulary_size": 100,
+                "activation": "softmax",
+                "kernel_initializer": "HeNormal",
+                "bias_initializer": "Zeros",
+            },
+            input_data={
+                "inputs": ops.random.uniform(shape=(4, 10, 16)),
+                "mask_positions": ops.random.randint(
+                    minval=0, maxval=10, shape=(4, 5)
+                ),
+            },
+            expected_output_shape=(4, 5, 100),
+            expected_num_trainable_weights=6,
         )
-        encoded_tokens = keras.Input(shape=(10, 16))
-        positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(encoded_tokens, mask_positions=positions)
-        model = keras.Model((encoded_tokens, positions), outputs)
-
-        token_data = ops.random.uniform(shape=(4, 10, 16))
-        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
-        model((token_data, position_data))

-    def test_valid_call_with_token_embedding(self):
+    def test_layer_behaviors_with_embedding(self):
         embedding = ReversibleEmbedding(100, 16)
-        head = MaskedLMHead(
-            vocabulary_size=100,
-            token_embedding=embedding,
-            activation="softmax",
-        )
-        # Use a difference "hidden dim" for the model than "embedding dim", we
-        # need to support this in the layer.
-        sequence = keras.Input(shape=(10, 32))
-        positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(sequence, mask_positions=positions)
-        model = keras.Model((sequence, positions), outputs)
-        sequence_data = ops.random.uniform(shape=(4, 10, 32))
-        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
-        model((sequence_data, position_data))
-
-    def test_get_config_and_from_config(self):
-        head = MaskedLMHead(
-            vocabulary_size=100,
-            kernel_initializer="HeNormal",
-            bias_initializer="Zeros",
-            activation="softmax",
-        )
-
-        config = head.get_config()
-
-        expected_params = {
-            "vocabulary_size": 100,
-            "kernel_initializer": keras.initializers.serialize(
-                keras.initializers.HeNormal()
-            ),
-            "bias_initializer": keras.initializers.serialize(
-                keras.initializers.Zeros()
-            ),
-            "activation": keras.activations.serialize(
-                keras.activations.softmax
-            ),
-        }
-
-        self.assertEqual(config, {**config, **expected_params})
-
-        restored = MaskedLMHead.from_config(config)
-        restored_config = restored.get_config()
-
-        self.assertEqual(
-            restored_config, {**restored_config, **expected_params}
+        embedding.build((4, 10))
+        self.run_layer_test(
+            layer_cls=MaskedLMHead,
+            init_kwargs={
+                "vocabulary_size": 100,
+                "activation": "softmax",
+                "kernel_initializer": "HeNormal",
+                "bias_initializer": "Zeros",
+                "token_embedding": embedding,
+            },
+            input_data={
+                "inputs": ops.random.uniform(shape=(4, 10, 16)),
+                "mask_positions": ops.random.randint(
+                    minval=0, maxval=10, shape=(4, 5)
+                ),
+            },
+            expected_output_shape=(4, 5, 100),
+            expected_num_trainable_weights=6,
         )
-        self.assertEqual(restored_config, config)

     def test_value_error_when_neither_embedding_or_vocab_size_set(self):
         with self.assertRaises(ValueError):
@@ -99,42 +72,3 @@ def test_value_error_when_vocab_size_mismatch(self):
                 vocabulary_size=101,
                 token_embedding=embedding,
             )
-
-    def test_one_train_step(self):
-        head = MaskedLMHead(
-            vocabulary_size=100,
-        )
-        encoded_tokens = keras.Input(shape=(10, 16))
-        positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(encoded_tokens, mask_positions=positions)
-        model = keras.Model((encoded_tokens, positions), outputs)
-
-        token_data = ops.random.uniform(shape=(4, 10, 16))
-        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
-        label_data = ops.random.randint(minval=0, maxval=2, shape=(4, 5, 1))
-
-        loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
-        optimizer = keras.optimizers.Adam()
-        model.compile(loss=loss, optimizer=optimizer)
-        loss = model.train_on_batch(x=(token_data, position_data), y=label_data)
-        self.assertGreater(loss, 0)
-
-    def test_saved_model(self):
-        head = MaskedLMHead(
-            vocabulary_size=100,
-            activation="softmax",
-        )
-        encoded_tokens = keras.Input(shape=(10, 16))
-        positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(encoded_tokens, mask_positions=positions)
-        model = keras.Model((encoded_tokens, positions), outputs)
-
-        token_data = ops.random.uniform(shape=(4, 10, 16))
-        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
-        model_output = model((token_data, position_data))
-        path = os.path.join(self.get_temp_dir(), "model.keras")
-        model.save(path, save_format="keras_v3")
-        restored_model = keras.models.load_model(path)
-
-        restored_output = restored_model((token_data, position_data))
-        self.assertAllClose(model_output, restored_output)
7 changes: 3 additions & 4 deletions keras_nlp/layers/modeling/position_embedding.py
@@ -92,16 +92,15 @@ def get_config(self):
         )
         return config

-    def build(self, input_shape):
-        feature_size = input_shape[-1]
+    def build(self, inputs_shape):
+        feature_size = inputs_shape[-1]
         self.position_embeddings = self.add_weight(
             name="embeddings",
             shape=[self.sequence_length, feature_size],
             initializer=self.initializer,
             trainable=True,
         )
-
-        super().build(input_shape)
+        self.built = True

     def call(self, inputs, start_index=0):
         shape = ops.shape(inputs)
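For reference, a hedged sketch of how this layer's call signature (including
the start_index argument visible in the context above) is typically used;
all shapes are illustrative:

from keras_nlp.backend import ops
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding

layer = PositionEmbedding(sequence_length=10)

# Full-sequence call: position embeddings for all 10 positions.
token_embeddings = ops.random.uniform(shape=(2, 10, 16))
position_embeddings = layer(token_embeddings)

# Single-step call during cached decoding: embed position 5 only.
step = ops.random.uniform(shape=(2, 1, 16))
step_embedding = layer(step, start_index=5)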
(Diffs for the remaining 11 changed files not shown.)
