From 30b34d3bdec20b2eb36af8067b8c857c1bc433cd Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Fri, 7 Jun 2024 15:19:23 -0700
Subject: [PATCH] Fix some testing on the latest version of keras (#1663)

---
 .../whisper_audio_feature_extractor.py     | 10 ----
 .../src/tokenizers/byte_tokenizer_test.py  | 17 +++----
 .../unicode_codepoint_tokenizer_test.py    | 46 +++----------------
 3 files changed, 12 insertions(+), 61 deletions(-)

diff --git a/keras_nlp/src/models/whisper/whisper_audio_feature_extractor.py b/keras_nlp/src/models/whisper/whisper_audio_feature_extractor.py
index aec7dd3c07..f524fd7d44 100644
--- a/keras_nlp/src/models/whisper/whisper_audio_feature_extractor.py
+++ b/keras_nlp/src/models/whisper/whisper_audio_feature_extractor.py
@@ -80,16 +80,6 @@ def __init__(
         max_audio_length=30,
         **kwargs,
     ):
-        # Check dtype and provide a default.
-        if "dtype" not in kwargs or kwargs["dtype"] is None:
-            kwargs["dtype"] = "float32"
-        else:
-            dtype = tf.dtypes.as_dtype(kwargs["dtype"])
-            if not dtype.is_floating:
-                raise ValueError(
-                    f"dtype must be a floating type. Received: dtype={dtype}"
-                )
-
         super().__init__(**kwargs)
 
         self._convert_input_args = False
diff --git a/keras_nlp/src/tokenizers/byte_tokenizer_test.py b/keras_nlp/src/tokenizers/byte_tokenizer_test.py
index b2444f36a0..0ae09688df 100644
--- a/keras_nlp/src/tokenizers/byte_tokenizer_test.py
+++ b/keras_nlp/src/tokenizers/byte_tokenizer_test.py
@@ -208,6 +208,7 @@ def test_load_model_with_config(self):
         )
 
     def test_config(self):
+        input_data = ["hello", "fun", "▀▁▂▃", "haha"]
         tokenizer = ByteTokenizer(
             name="byte_tokenizer_config_test",
             lowercase=False,
@@ -216,14 +217,8 @@ def test_config(self):
             errors="ignore",
             replacement_char=0,
         )
-        exp_config = {
-            "dtype": "int32",
-            "errors": "ignore",
-            "lowercase": False,
-            "name": "byte_tokenizer_config_test",
-            "normalization_form": "NFC",
-            "replacement_char": 0,
-            "sequence_length": 8,
-            "trainable": True,
-        }
-        self.assertEqual(tokenizer.get_config(), exp_config)
+        cloned_tokenizer = ByteTokenizer.from_config(tokenizer.get_config())
+        self.assertAllEqual(
+            tokenizer(input_data),
+            cloned_tokenizer(input_data),
+        )
diff --git a/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer_test.py b/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer_test.py
index 773733bb5b..4f324da15a 100644
--- a/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer_test.py
+++ b/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer_test.py
@@ -263,6 +263,7 @@ def test_load_model_with_config(self):
         )
 
     def test_config(self):
+        input_data = ["ninja", "samurai", "▀▁▂▃"]
         tokenizer = UnicodeCodepointTokenizer(
             name="unicode_character_tokenizer_config_gen",
             lowercase=False,
@@ -272,45 +273,10 @@ def test_config(self):
             replacement_char=0,
             vocabulary_size=100,
         )
-        exp_config = {
-            "dtype": "int32",
-            "errors": "ignore",
-            "lowercase": False,
-            "name": "unicode_character_tokenizer_config_gen",
-            "normalization_form": "NFC",
-            "replacement_char": 0,
-            "sequence_length": 8,
-            "input_encoding": "UTF-8",
-            "output_encoding": "UTF-8",
-            "trainable": True,
-            "vocabulary_size": 100,
-        }
-        self.assertEqual(tokenizer.get_config(), exp_config)
-
-        tokenize_different_encoding = UnicodeCodepointTokenizer(
-            name="unicode_character_tokenizer_config_gen",
-            lowercase=False,
-            sequence_length=8,
-            errors="ignore",
-            replacement_char=0,
-            input_encoding="UTF-16",
-            output_encoding="UTF-16",
-            vocabulary_size=None,
+        cloned_tokenizer = UnicodeCodepointTokenizer.from_config(
+            tokenizer.get_config()
         )
-        exp_config_different_encoding = {
-            "dtype": "int32",
-            "errors": "ignore",
-            "lowercase": False,
-            "name": "unicode_character_tokenizer_config_gen",
-            "normalization_form": None,
-            "replacement_char": 0,
-            "sequence_length": 8,
-            "input_encoding": "UTF-16",
-            "output_encoding": "UTF-16",
-            "trainable": True,
-            "vocabulary_size": None,
-        }
-        self.assertEqual(
-            tokenize_different_encoding.get_config(),
-            exp_config_different_encoding,
+        self.assertAllEqual(
+            tokenizer(input_data),
+            cloned_tokenizer(input_data),
         )
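Below is a minimal standalone sketch (not part of the patch) of the round-trip pattern the updated tests switch to: rebuild the tokenizer from get_config() and compare tokenized outputs, rather than asserting on an exact config dict that can vary across Keras versions. It assumes keras_nlp is installed and uses ByteTokenizer, the class exercised in byte_tokenizer_test.py above; the inputs and settings are illustrative, not taken from the test suite.

import numpy as np

from keras_nlp.tokenizers import ByteTokenizer

# Illustrative inputs and settings (hypothetical, chosen for this sketch).
input_data = ["hello", "fun", "haha"]
tokenizer = ByteTokenizer(lowercase=False, sequence_length=8)

# Round-trip the config: serialize, then rebuild a new tokenizer from it.
cloned_tokenizer = ByteTokenizer.from_config(tokenizer.get_config())

# The clone should tokenize identically to the original.
original_ids = np.array(tokenizer(input_data))
cloned_ids = np.array(cloned_tokenizer(input_data))
np.testing.assert_array_equal(original_ids, cloned_ids)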