From a006232590f8040ce0c62620b829a5d777f5bf4f Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Sat, 16 Apr 2022 07:19:52 +0530 Subject: [PATCH] Adding a UnicodeCharacterTokenizer (#100) * Debugging * Debugging * Fixed Sequence Length Issue * Sequence Length Changes * Removed _ From Class Attributes * Fixed Null Bytes in Detokenization * Testing regex_replace * Testing * Helper Function and Debug Statements * Testing Regex Replace New Ordering * Added Checks for Errors and Normalization Form * Doc String Completed * Ran lint/format * New Tests and Decoding Changes * Changes * Minor Tweak * Tweaking Detokenizer * Added Tests and Updated Docstrings * Ran format.sh and lint.sh * Refactoring and Removing Unused Lines * Fixed Some Broken Tests * Fixed All Tests * Testing Decode * Testing * Debug * Fixes + Replaced Regex with BooleanMask * Added Debug Lines * Added Debug Line for .numpy() * Testing Byte Tokenizer Approach * Testing With Unicode_transcode * Listing Methods of Object * Testing _numpy * Added Decode Call * Checking Methods post _numpy() * Removed Debug Statements and Improved Docstring * Fixed Failing Test * Ran format/lint * Fixed Docstring and Improved Examples * Ran format and lint * Copy edits * Copy edits Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com> --- keras_nlp/tokenizers/__init__.py | 3 + keras_nlp/tokenizers/byte_tokenizer.py | 4 +- .../tokenizers/unicode_character_tokenizer.py | 272 +++++++++++++++ .../unicode_character_tokenizer_test.py | 317 ++++++++++++++++++ 4 files changed, 594 insertions(+), 2 deletions(-) create mode 100644 keras_nlp/tokenizers/unicode_character_tokenizer.py create mode 100644 keras_nlp/tokenizers/unicode_character_tokenizer_test.py diff --git a/keras_nlp/tokenizers/__init__.py b/keras_nlp/tokenizers/__init__.py index 7206410a14..36c530ee70 100644 --- a/keras_nlp/tokenizers/__init__.py +++ b/keras_nlp/tokenizers/__init__.py @@ -14,4 +14,7 @@ from keras_nlp.tokenizers.byte_tokenizer import ByteTokenizer from keras_nlp.tokenizers.tokenizer import Tokenizer +from keras_nlp.tokenizers.unicode_character_tokenizer import ( + UnicodeCharacterTokenizer, +) from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer diff --git a/keras_nlp/tokenizers/byte_tokenizer.py b/keras_nlp/tokenizers/byte_tokenizer.py index 033699d31e..1a106cd91c 100644 --- a/keras_nlp/tokenizers/byte_tokenizer.py +++ b/keras_nlp/tokenizers/byte_tokenizer.py @@ -97,7 +97,7 @@ class ByteTokenizer(tokenizer.Tokenizer): >>> ds.take(1).get_single_element() - Batch up the inputs and then tokenize. + Batch the inputs and then tokenize. >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer() >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"]) >>> ds = ds.batch(2).map(tokenizer) @@ -114,7 +114,7 @@ class ByteTokenizer(tokenizer.Tokenizer): array([[104, 101, 108, 108, 111], [102, 117, 110, 0, 0]], dtype=int32)> - Batch up the inputs and then tokenize (`sequence_length` provided). + Batch the inputs and then tokenize (`sequence_length` provided). 
>>> tokenizer = keras_nlp.tokenizers.ByteTokenizer(sequence_length=5) >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"]) >>> ds = ds.batch(2).map(tokenizer) diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer.py b/keras_nlp/tokenizers/unicode_character_tokenizer.py new file mode 100644 index 0000000000..ade085e399 --- /dev/null +++ b/keras_nlp/tokenizers/unicode_character_tokenizer.py @@ -0,0 +1,272 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Dict + +import tensorflow as tf +import tensorflow_text as tf_text + +from keras_nlp.tokenizers import tokenizer + + +class UnicodeCharacterTokenizer(tokenizer.Tokenizer): + """A unicode character tokenizer layer. + + This tokenizer is a vocabulary-free tokenizer which tokenizes text as + unicode character codepoints. + + Args: + lowercase: If true, the input text will be lowercased before + tokenization. + sequence_length: If set, the output will be converted to a dense + tensor and padded/trimmed so all outputs are of `sequence_length`. + normalization_form: One of the following string values (None, 'NFC', + 'NFKC', 'NFD', 'NFKD'). If set, the input text will be normalized + to the given form before tokenizing. + errors: One of ('replace', 'ignore', 'strict'). Specifies the + `detokenize()` behavior when an invalid codepoint is encountered. + (same behavior as + https://www.tensorflow.org/api_docs/python/tf/strings/unicode_transcode) + replacement_char: The unicode codepoint to use in place of invalid + codepoints. Defaults to 65533 (U+FFFD). + input_encoding: One of ("UTF-8", "UTF-16-BE", or "UTF-32-BE"). + The encoding of the input text. Defaults to "UTF-8". + output_encoding: One of ("UTF-8", "UTF-16-BE", or "UTF-32-BE"). + The encoding of the output text. Defaults to "UTF-8". + + Examples: + + Basic Usage. + >>> inputs = "Unicode Tokenizer" + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer() + >>> tokenizer(inputs) + + + Ragged outputs. + >>> inputs = ["Book", "पुस्तक", "کتاب"] + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer() + >>> tokenizer(inputs) + + + Dense outputs. + >>> inputs = ["Book", "पुस्तक", "کتاب"] + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer( + ... sequence_length=8) + >>> tokenizer(inputs) + + + Tokenize first, then batch the dataset. + >>> inputs = ["Book", "पुस्तक", "کتاب"] + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer() + >>> ds = tf.data.Dataset.from_tensor_slices(inputs) + >>> ds = ds.map(tokenizer) + >>> ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(3)) + >>> ds.take(1).get_single_element() + + + Batch the inputs and then tokenize.
+ >>> inputs = ["Book", "पुस्तक", "کتاب"] + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer() + >>> ds = tf.data.Dataset.from_tensor_slices(inputs) + >>> ds = ds.batch(3).map(tokenizer) + >>> ds.take(1).get_single_element() + + + Tokenize first, then batch for dense outputs (`sequence_length` provided). + >>> inputs = ["Book", "पुस्तक", "کتاب"] + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer( + ... sequence_length=5) + >>> ds = tf.data.Dataset.from_tensor_slices(inputs) + >>> ds = ds.map(tokenizer) + >>> ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(3)) + >>> ds.take(1).get_single_element() + + + Batch first, then tokenize for dense outputs (`sequence_length` provided). + (`sequence_length` provided). + >>> inputs = ["Book", "पुस्तक", "کتاب"] + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer( + ... sequence_length=5) + >>> ds = tf.data.Dataset.from_tensor_slices(inputs) + >>> ds = ds.batch(3).map(tokenizer) + >>> ds.take(1).get_single_element() + + + Tokenization showcasing truncation of long sequences. + >>> inputs = ["I Like to Travel a Lot", "मैं किताबें पढ़ना पसंद करता हूं"] + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer( + ... sequence_length=5) + >>> tokenizer(inputs) + + + Detokenization. + >>> inputs = tf.constant([110, 105, 110, 106, 97], dtype=tf.int32) + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer() + >>> tokenizer.detokenize(inputs) + + + Detokenization while showcasing padded characters being removed + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer( + ... sequence_length=7) + >>> dataset = tf.data.Dataset.from_tensor_slices(["a b c", "b c", "a"]) + >>> dataset = dataset.map(tokenizer) + >>> dataset.take(1).get_single_element() + + >>> detokunbatched = dataset.map(tokenizer.detokenize) + >>> detokunbatched = dataset.map(tokenizer.detokenize) + >>> detokunbatched.take(1).get_single_element() + + + Detokenization with invalid bytes. + >>> # The 10000000 in the inputs tensor below is an invalid value + >>> # Hence it replaces to the replacement_char 75 which represents 'K' + >>> inputs = tf.constant([110, 105, 10000000, 110, 106, 97]) + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer( + ... errors="replace", replacement_char=75) + >>> tokenizer.detokenize(inputs).numpy().decode('utf-8') + 'niKnja' + """ + + def __init__( + self, + sequence_length: int = None, + lowercase: bool = True, + normalization_form: str = None, + errors: str = "replace", + replacement_char: int = 65533, + input_encoding: str = "UTF-8", + output_encoding: str = "UTF-8", + **kwargs, + ) -> None: + # Check dtype and provide a default. + if "dtype" not in kwargs or kwargs["dtype"] is None: + kwargs["dtype"] = tf.int32 + else: + dtype = tf.dtypes.as_dtype(kwargs["dtype"]) + if not dtype.is_integer and dtype != tf.string: + raise ValueError( + "Output dtype must be an integer type of a string. " + f"Received: dtype={dtype}" + ) + + # Check normalization_form. + if normalization_form not in [None, "NFC", "NFKC", "NFD", "NFKD"]: + raise ValueError( + '`normalization_form` must be one of None, "NFC", "NFKC", ' + '"NFD", "NFKD". Received: normalization_form=' + f"{normalization_form}" + ) + + # Check errors. + if errors not in ["strict", "replace", "ignore"]: + raise ValueError( + '`errors` must be one of "strict", "replace", "ignore" ' + f"Received: errors={errors}" + ) + + # Check normalization_form matches input_encoding. 
+ if normalization_form: + if input_encoding != "UTF-8": + raise ValueError( + "Normalization forms are only supported for input " + 'encoding "UTF-8".' + ) + + super().__init__(**kwargs) + + self.sequence_length = sequence_length + self.lowercase = lowercase + self.normalization_form = normalization_form + self.errors = errors + self.replacement_char = replacement_char + self.input_encoding = input_encoding + self.output_encoding = output_encoding + + def get_config(self) -> Dict[str, Any]: + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "lowercase": self.lowercase, + "normalization_form": self.normalization_form, + "errors": self.errors, + "replacement_char": self.replacement_char, + "input_encoding": self.input_encoding, + "output_encoding": self.output_encoding, + } + ) + return config + + def tokenize(self, inputs): + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + + scalar_input = inputs.shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) + + # Optionally lowercase the text + if self.lowercase: + inputs = tf_text.case_fold_utf8(inputs) + + # Optionally normalize the text to a given form + if self.normalization_form: + inputs = tf_text.normalize_utf8(inputs, self.normalization_form) + + tokens = tf.strings.unicode_decode( + inputs, + errors=self.errors, + replacement_char=self.replacement_char, + input_encoding=self.input_encoding, + ) + + if self.sequence_length: + output_shape = tokens.shape.as_list() + output_shape[-1] = self.sequence_length + tokens = tokens.to_tensor(shape=output_shape) + + if scalar_input: + tokens = tf.squeeze(tokens, 0) + return tokens + + def detokenize(self, inputs): + inputs = tf.ragged.boolean_mask(inputs, tf.not_equal(inputs, 0)) + encoded_string = tf.strings.unicode_encode( + inputs, + errors=self.errors, + replacement_char=self.replacement_char, + output_encoding=self.output_encoding, + ) + return encoded_string diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer_test.py b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py new file mode 100644 index 0000000000..34df6a5094 --- /dev/null +++ b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py @@ -0,0 +1,317 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.tokenizers.unicode_character_tokenizer import ( + UnicodeCharacterTokenizer, +) + + +class UnicodeCharacterTokenizerTest(tf.test.TestCase): + def test_tokenize(self): + input_data = tf.constant(["ninja", "samurai", "▀▁▂▃"]) + tokenizer = UnicodeCharacterTokenizer() + call_output = tokenizer(input_data) + tokenize_output = tokenizer.tokenize(input_data) + self.assertIsInstance(call_output, tf.RaggedTensor) + exp_outputs = [ + [110, 105, 110, 106, 97], + [115, 97, 109, 117, 114, 97, 105], + [9600, 9601, 9602, 9603], + ] + for i in range(call_output.shape[0]): + self.assertAllEqual(call_output[i], exp_outputs[i]) + self.assertAllEqual(tokenize_output[i], exp_outputs[i]) + + def test_tokenize_scalar(self): + input_data = "ninja" + tokenizer = UnicodeCharacterTokenizer() + call_output = tokenizer(input_data) + tokenize_output = tokenizer.tokenize(input_data) + + self.assertAllEqual(call_output, [110, 105, 110, 106, 97]) + self.assertAllEqual(tokenize_output, [110, 105, 110, 106, 97]) + + def test_dense_output(self): + input_data = tf.constant(["ninja", "samurai", "▀▁▂▃"]) + tokenizer = UnicodeCharacterTokenizer(sequence_length=10) + call_output = tokenizer(input_data) + self.assertIsInstance(call_output, tf.Tensor) + self.assertAllEqual( + call_output, + [ + [110, 105, 110, 106, 97, 0, 0, 0, 0, 0], + [115, 97, 109, 117, 114, 97, 105, 0, 0, 0], + [9600, 9601, 9602, 9603, 0, 0, 0, 0, 0, 0], + ], + ) + + def test_detokenize(self): + input_data = tf.ragged.constant( + [ + [110, 105, 110, 106, 97], + [115, 97, 109, 117, 114, 97, 105], + [9600, 9601, 9602, 9603], + ] + ) + + tokenizer = UnicodeCharacterTokenizer() + detokenize_output = tokenizer.detokenize(input_data) + self.assertAllEqual( + detokenize_output, + [ + b"ninja", + b"samurai", + b"\xe2\x96\x80\xe2\x96\x81\xe2\x96\x82\xe2\x96\x83", + ], + ) + + def test_detokenize_replace_error(self): + # 10000000 is an invalid value + input_data = tf.ragged.constant([[110, 105, 10000000, 110, 106, 97]]) + tokenizer = UnicodeCharacterTokenizer( + errors="replace", replacement_char=75 + ) + detokenize_output = tokenizer.detokenize(input_data) + self.assertAllEqual(detokenize_output, [b"niKnja"]) + + def test_detokenize_ignore_error(self): + input_data = tf.ragged.constant([[110, 105, 10000000, 110, 106, 97]]) + tokenizer = UnicodeCharacterTokenizer(errors="ignore") + detokenize_output = tokenizer.detokenize(input_data) + self.assertAllEqual(detokenize_output, [b"ninja"]) + + def test_detokenize_strict_error(self): + input_data = tf.ragged.constant([[110, 105, 10000000, 110, 106, 97]]) + tokenizer = UnicodeCharacterTokenizer(errors="strict") + with self.assertRaises(tf.errors.InvalidArgumentError): + _ = tokenizer.detokenize(input_data) + + def test_normalization_without_UTF8_valueerror(self): + with self.assertRaises(ValueError): + _ = UnicodeCharacterTokenizer( + errors="strict", + input_encoding="UTF-16", + normalization_form="NFC", + ) + + def test_lowercase(self): + input_data = tf.constant(["NiNJaS"]) + tokenizer = UnicodeCharacterTokenizer() + call_output = tokenizer(input_data) + self.assertAllEqual( + call_output, + [[110, 105, 110, 106, 97, 115]], + ) + + def test_skip_lowercase(self): + input_data = tf.constant(["NiNJaS"]) + tokenizer = UnicodeCharacterTokenizer(lowercase=False) + call_output = tokenizer(input_data) + self.assertAllEqual( + call_output, + [[78, 105, 78, 74, 97, 83]], + ) + + def test_tokenize_first_batch_second(self): + tokenizer = 
UnicodeCharacterTokenizer() + + ds = tf.data.Dataset.from_tensor_slices( + ["ninja", "samurai", "▀▁▂▃", "keras", "tensorflow"] + ) + ds = ds.map(tokenizer) + ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(5)) + output = ds.take(1).get_single_element() + + exp_output = [ + [110, 105, 110, 106, 97], + [115, 97, 109, 117, 114, 97, 105], + [9600, 9601, 9602, 9603], + [107, 101, 114, 97, 115], + [116, 101, 110, 115, 111, 114, 102, 108, 111, 119], + ] + for i in range(output.shape[0]): + self.assertAllEqual(output[i], exp_output[i]) + + def test_tokenize_first_batch_second_with_sequence_length(self): + tokenizer = UnicodeCharacterTokenizer(sequence_length=10) + + ds = tf.data.Dataset.from_tensor_slices( + ["ninja", "samurai", "▀▁▂▃", "keras", "tensorflow"] + ) + ds = ds.map(tokenizer) + ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(5)) + output = ds.take(1).get_single_element() + + exp_output = [ + [110, 105, 110, 106, 97, 0, 0, 0, 0, 0], + [115, 97, 109, 117, 114, 97, 105, 0, 0, 0], + [9600, 9601, 9602, 9603, 0, 0, 0, 0, 0, 0], + [107, 101, 114, 97, 115, 0, 0, 0, 0, 0], + [116, 101, 110, 115, 111, 114, 102, 108, 111, 119], + ] + for i in range(output.shape[0]): + self.assertAllEqual(output[i], exp_output[i]) + + def test_batch_first_tokenize_second(self): + tokenizer = UnicodeCharacterTokenizer() + + ds = tf.data.Dataset.from_tensor_slices( + ["ninja", "samurai", "▀▁▂▃", "keras", "tensorflow"] + ) + ds = ds.batch(5).map(tokenizer) + output = ds.take(1).get_single_element() + + exp_output = [ + [110, 105, 110, 106, 97], + [115, 97, 109, 117, 114, 97, 105], + [9600, 9601, 9602, 9603], + [107, 101, 114, 97, 115], + [116, 101, 110, 115, 111, 114, 102, 108, 111, 119], + ] + for i in range(output.shape[0]): + self.assertAllEqual(output[i], exp_output[i]) + + def test_batch_first_tokenize_second_with_sequence_length(self): + tokenizer = UnicodeCharacterTokenizer(sequence_length=10) + + ds = tf.data.Dataset.from_tensor_slices( + ["ninja", "samurai", "▀▁▂▃", "keras", "tensorflow"] + ) + ds = ds.batch(5).map(tokenizer) + output = ds.take(1).get_single_element() + + exp_output = [ + [110, 105, 110, 106, 97, 0, 0, 0, 0, 0], + [115, 97, 109, 117, 114, 97, 105, 0, 0, 0], + [9600, 9601, 9602, 9603, 0, 0, 0, 0, 0, 0], + [107, 101, 114, 97, 115, 0, 0, 0, 0, 0], + [116, 101, 110, 115, 111, 114, 102, 108, 111, 119], + ] + for i in range(output.shape[0]): + self.assertAllEqual(output[i], exp_output[i]) + + def test_functional_model(self): + input_data = tf.constant( + ["ninja", "samurai", "▀▁▂▃", "keras", "tensorflow"] + ) + tokenizer = UnicodeCharacterTokenizer() + inputs = tf.keras.Input(dtype="string", shape=()) + outputs = tokenizer.detokenize(tokenizer.tokenize(inputs)) + model = tf.keras.Model(inputs, outputs) + model_output = model(input_data) + self.assertAllEqual( + model_output, + [ + b"ninja", + b"samurai", + b"\xe2\x96\x80\xe2\x96\x81\xe2\x96\x82\xe2\x96\x83", + b"keras", + b"tensorflow", + ], + ) + + def test_load_model_with_config(self): + input_data = tf.constant(["hello"]) + + original_tokenizer = UnicodeCharacterTokenizer( + lowercase=False, + sequence_length=11, + normalization_form="NFC", + errors="strict", + ) + cloned_tokenizer = UnicodeCharacterTokenizer.from_config( + original_tokenizer.get_config() + ) + self.assertAllEqual( + original_tokenizer(input_data), + cloned_tokenizer(input_data), + ) + + decoded_input = [107, 101, 114, 97, 115] + self.assertAllEqual( + original_tokenizer.detokenize(decoded_input), + cloned_tokenizer.detokenize(decoded_input), + ) + + def 
test_config(self): + tokenizer = UnicodeCharacterTokenizer( + name="unicode_character_tokenizer_config_gen", + lowercase=False, + sequence_length=8, + normalization_form="NFC", + errors="ignore", + replacement_char=0, + ) + exp_config = { + "dtype": "int32", + "errors": "ignore", + "lowercase": False, + "name": "unicode_character_tokenizer_config_gen", + "normalization_form": "NFC", + "replacement_char": 0, + "sequence_length": 8, + "input_encoding": "UTF-8", + "output_encoding": "UTF-8", + "trainable": True, + } + self.assertEqual(tokenizer.get_config(), exp_config) + + tokenize_different_encoding = UnicodeCharacterTokenizer( + name="unicode_character_tokenizer_config_gen", + lowercase=False, + sequence_length=8, + errors="ignore", + replacement_char=0, + input_encoding="UTF-16", + output_encoding="UTF-16", + ) + exp_config_different_encoding = { + "dtype": "int32", + "errors": "ignore", + "lowercase": False, + "name": "unicode_character_tokenizer_config_gen", + "normalization_form": None, + "replacement_char": 0, + "sequence_length": 8, + "input_encoding": "UTF-16", + "output_encoding": "UTF-16", + "trainable": True, + } + self.assertEqual( + tokenize_different_encoding.get_config(), + exp_config_different_encoding, + ) + + def test_saving(self): + input_data = tf.constant(["ninjas and samurais", "time travel"]) + + tokenizer = UnicodeCharacterTokenizer( + name="unicode_character_tokenizer_config_gen", + lowercase=False, + sequence_length=20, + normalization_form="NFKC", + errors="replace", + ) + inputs = keras.Input(dtype="string", shape=()) + outputs = tokenizer(inputs) + model = keras.Model(inputs, outputs) + model.save(self.get_temp_dir()) + restored_model = keras.models.load_model(self.get_temp_dir()) + self.assertAllEqual( + model(input_data), + restored_model(input_data), + )
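
Note for reviewers (illustrative sketch, not part of the patch): the snippet below shows how the new tokenizer composes with a plain Keras `Embedding` layer and round-trips back to strings. It assumes `keras_nlp` is installed from this branch; the vocabulary size `0x10FFFF + 1` is simply an assumption that covers every codepoint id the tokenizer can emit, including the 0 padding id.

# Illustrative sketch only -- not part of this patch. Assumes keras_nlp is
# installed from this branch so UnicodeCharacterTokenizer is importable.
import tensorflow as tf
import keras_nlp

tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer(sequence_length=16)

# Tokenize a small batch of strings into dense, zero-padded codepoint ids.
inputs = tf.constant(["a cliché", "मानक हिन्दी"])
token_ids = tokenizer(inputs)  # shape (2, 16), dtype int32

# Codepoint ids can feed an embedding directly. 0x10FFFF is the largest
# valid codepoint, so 0x10FFFF + 1 rows cover every possible id.
embedding = tf.keras.layers.Embedding(input_dim=0x10FFFF + 1, output_dim=8)
embedded = embedding(token_ids)  # shape (2, 16, 8)

# Round-trip back to strings; detokenize() strips the zero padding.
print(tokenizer.detokenize(token_ids).numpy())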