From 1344a00b495f3f431e57d844a9c63661d0f3f4fb Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Mon, 8 Aug 2022 16:45:30 +0530 Subject: [PATCH] Minor fixes --- keras_nlp/layers/random_deletion.py | 2 +- keras_nlp/layers/random_deletion_test.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/keras_nlp/layers/random_deletion.py b/keras_nlp/layers/random_deletion.py index 582f3918b5..d25ae7231e 100644 --- a/keras_nlp/layers/random_deletion.py +++ b/keras_nlp/layers/random_deletion.py @@ -46,7 +46,7 @@ class RandomDeletion(keras.layers.Layer): indicates that should not be considered a candidate for deletion. Unlike the `skip_fn` argument, this argument need not be tracable--it can be any python function. - seed: A seed for the rng. + seed: A seed for the random number generator. Examples: diff --git a/keras_nlp/layers/random_deletion_test.py b/keras_nlp/layers/random_deletion_test.py index edacd3f35e..b345400a7f 100644 --- a/keras_nlp/layers/random_deletion_test.py +++ b/keras_nlp/layers/random_deletion_test.py @@ -105,6 +105,8 @@ def skip_py_fn(word): output = tf.strings.reduce_join(augmented, separator=" ", axis=-1) self.assertAllEqual(output.shape, tf.convert_to_tensor(inputs).shape) exp_output = [b"Hey like", b"Keras Tensorflow"] + for i in range(output.shape[0]): + self.assertAllEqual(output[i], exp_output[i]) def test_get_config_and_from_config(self): augmenter = random_deletion.RandomDeletion(