From f151982e6b3e01818d72236b62a0ab2fb4fde37d Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Mon, 13 Jun 2022 07:22:52 +0530 Subject: [PATCH 01/14] Add rough BLEU Score implementation --- keras_nlp/metrics/bleu.py | 176 +++++++++++++++++++++++++++++++ keras_nlp/metrics/ngram_utils.py | 27 +++++ 2 files changed, 203 insertions(+) create mode 100644 keras_nlp/metrics/bleu.py create mode 100644 keras_nlp/metrics/ngram_utils.py diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py new file mode 100644 index 0000000000..c05adb7417 --- /dev/null +++ b/keras_nlp/metrics/bleu.py @@ -0,0 +1,176 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BLEU score implementation based on `keras.metrics.Metric`.""" + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.metrics.ngram_utils import get_ngram_count + + +class Bleu(keras.metrics.Metric): + def __init__( + self, + max_order=2, + smooth=False, + dtype=None, + name="bleu", + **kwargs, + ): + super().__init__(name=name, dtype=dtype, **kwargs) + + if not tf.as_dtype(self.dtype).is_floating: + raise ValueError( + "`dtype` must be a floating point type. " + f"Received: dtype={dtype}" + ) + + self.max_order = max_order + self.smooth = smooth + + self._bleu = self.add_weight( + name="bleu", + initializer="zeros", + dtype=self.dtype, + ) + self._number_of_samples = self.add_weight( + name="number_of_samples", initializer="zeros", dtype=self.dtype + ) + + def update_state(self, y_true, y_pred, sample_weight=None): + batch_size = tf.shape(y_true)[0] + + # Tokenise the strings (we will later replace this with a more + # complicated tokeniser) + y_true = tf.strings.split(y_true) + y_pred = tf.strings.split(y_pred) + + agg_reference_len = tf.cast(0, dtype=self.dtype) + agg_translation_len = tf.cast(0, dtype=self.dtype) + p_log_sum = tf.cast(0, dtype=self.dtype) + for idx in range(batch_size): + reference = y_true[idx] + translation = y_pred[idx] + agg_reference_len += tf.cast( + tf.shape(reference)[0], dtype=self.dtype + ) + agg_translation_len += tf.cast( + tf.shape(translation)[0], dtype=self.dtype + ) + + min_precision = tf.cast(1, dtype=self.dtype) + for order in range(1, self.max_order + 1): + matches = tf.cast(0, dtype=self.dtype) + possible_matches = tf.cast(0, dtype=self.dtype) + + for idx in range(batch_size): + reference = y_true[idx] + translation = y_pred[idx] + translation_len = tf.cast( + tf.shape(translation)[0], dtype=self.dtype + ) + + # Get n-grams and ngram count. + reference_ngrams, reference_ngram_freq = get_ngram_count( + reference, order + ) + translation_ngrams, translation_ngram_freq = get_ngram_count( + translation, order + ) + + # Get the intersection of the two ngram tensors. + common_ngrams = tf.sets.intersection( + reference_ngrams[tf.newaxis, :], + translation_ngrams[tf.newaxis, :], + ).values + + common_reference_ngram_freq = tf.gather( + reference_ngram_freq, + tf.argmax( + (reference_ngrams[:, None] == common_ngrams), axis=0 + ), + ) + common_translation_ngram_freq = tf.gather( + translation_ngram_freq, + tf.argmax( + (translation_ngrams[:, None] == common_ngrams), axis=0 + ), + ) + + # Compute number of ngram matches. + matches += tf.cast( + tf.reduce_sum( + tf.minimum( + common_reference_ngram_freq, + common_translation_ngram_freq, + ) + ), + dtype=self.dtype, + ) + if translation_len - order + 1 > 0: + possible_matches += translation_len + + if self.smooth: + precision = (matches + tf.cast(1, dtype=self.dtype)) / ( + possible_matches + tf.cast(1, dtype=self.dtype) + ) + else: + if possible_matches > 0: + precision = matches / possible_matches + else: + precision = tf.cast(0, dtype=self.dtype) + + if precision > 0: + p_log_sum += ( + tf.cast(1, dtype=self.dtype) + / tf.cast(self.max_order, dtype=self.dtype) + ) * tf.math.log(precision) + min_precision = tf.minimum(min_precision, precision) + + if min_precision > 0: + geo_mean = tf.exp(p_log_sum) + else: + geo_mean = tf.cast(0, dtype=self.dtype) + + # Compute the brevity penalty. + ratio = agg_translation_len / agg_reference_len + if ratio > 1: + bp = tf.cast(1, dtype=self.dtype) + else: + bp = tf.exp( + tf.cast(1, dtype=self.dtype) + - tf.cast(1, dtype=self.dtype) / ratio + ) + + self._bleu.assign_add(geo_mean * bp) + self._number_of_samples.assign_add( + tf.cast(batch_size, dtype=tf.float32) + ) + + def result(self): + if self._number_of_samples == 0: + return 0.0 + bleu = self._bleu / self._number_of_samples + + return bleu + + def reset_state(self): + self._bleu.assign(0.0) + self._number_of_samples.assign(0.0) + + def get_config(self): + config = super().get_config() + config.update({"max_order": self.max_order, "smooth": self.smooth}) + return config diff --git a/keras_nlp/metrics/ngram_utils.py b/keras_nlp/metrics/ngram_utils.py new file mode 100644 index 0000000000..f856e97eb0 --- /dev/null +++ b/keras_nlp/metrics/ngram_utils.py @@ -0,0 +1,27 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import tensorflow_text as tf_text + + +def get_ngram_count(tokenized_string, order): + ngrams = tf_text.ngrams( + data=tokenized_string, + width=order, + axis=-1, + reduction_type=tf_text.Reduction.STRING_JOIN, + ) + unique_ngrams, _, ngram_freq = tf.unique_with_counts(ngrams) + return unique_ngrams, ngram_freq From 0f757d5e524afce90ae9593690b1b667e94373cc Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 29 Jun 2022 00:03:32 +0530 Subject: [PATCH 02/14] Add BLEU score class --- keras_nlp/metrics/bleu.py | 347 +++++++++++++++++++++---------- keras_nlp/metrics/ngram_utils.py | 27 --- 2 files changed, 241 insertions(+), 133 deletions(-) delete mode 100644 keras_nlp/metrics/ngram_utils.py diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index c05adb7417..710fb2c4e2 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -12,18 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""BLEU score implementation based on `keras.metrics.Metric`.""" +import collections +import math import tensorflow as tf from tensorflow import keras -from keras_nlp.metrics.ngram_utils import get_ngram_count +from keras_nlp.utils.tensor_utils import tensor_to_string_list + +REPLACE_SUBSTRINGS = [ + ("", ""), + ("-\n", ""), + ("\n", " "), + (""", '"'), + ("&", "&"), + ("<", "<"), + (">", ">"), +] + + +REGEX_PATTERNS = [ + # language-dependent part (assuming Western languages) + (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "), + # tokenize period and comma unless preceded by a digit + (r"([^0-9])([\.,])", r"\1 \2 "), + # tokenize period and comma unless followed by a digit + (r"([\.,])([^0-9])", r" \1 \2"), + # tokenize dash when preceded by a digit + (r"([0-9])(-)", r"\1 \2 "), + # If last character is "." or ",", add space. + (r"[\.,]$", r" \0 \1"), + # one space only between words + (r"\s+", r" "), +] class Bleu(keras.metrics.Metric): def __init__( self, - max_order=2, + tokenizer=None, + max_order=4, smooth=False, dtype=None, name="bleu", @@ -37,140 +65,247 @@ def __init__( f"Received: dtype={dtype}" ) + def default_tokenizer(inputs): + """ + Default tokenizer. Replicates the behaviour of SacreBLEU's + default tokenizer, namely, `tokenizer_13a`. + """ + for pattern, replacement in REPLACE_SUBSTRINGS + REGEX_PATTERNS: + inputs = tf.strings.regex_replace( + input=inputs, + pattern=pattern, + rewrite=replacement, + replace_global=True, + name=None, + ) + inputs = tf.strings.split(inputs) + return inputs + + if tokenizer is None: + self.tokenizer = default_tokenizer + else: + self.tokenizer = tokenizer self.max_order = max_order self.smooth = smooth + self._matches = self.add_weight( + shape=(self.max_order,), + name="bleu_matches", + initializer="zeros", + dtype=self.dtype, + ) + self._possible_matches = self.add_weight( + shape=(self.max_order,), + name="bleu_possible_matches", + initializer="zeros", + dtype=self.dtype, + ) + self._translation_length = self.add_weight( + name="bleu_translation_length", + initializer="zeros", + dtype=self.dtype, + ) + self._reference_length = self.add_weight( + name="bleu_reference_length", initializer="zeros", dtype=self.dtype + ) + self._bleu = self.add_weight( name="bleu", initializer="zeros", dtype=self.dtype, ) - self._number_of_samples = self.add_weight( - name="number_of_samples", initializer="zeros", dtype=self.dtype - ) def update_state(self, y_true, y_pred, sample_weight=None): - batch_size = tf.shape(y_true)[0] - - # Tokenise the strings (we will later replace this with a more - # complicated tokeniser) - y_true = tf.strings.split(y_true) - y_pred = tf.strings.split(y_pred) - - agg_reference_len = tf.cast(0, dtype=self.dtype) - agg_translation_len = tf.cast(0, dtype=self.dtype) - p_log_sum = tf.cast(0, dtype=self.dtype) - for idx in range(batch_size): - reference = y_true[idx] - translation = y_pred[idx] - agg_reference_len += tf.cast( - tf.shape(reference)[0], dtype=self.dtype - ) - agg_translation_len += tf.cast( - tf.shape(translation)[0], dtype=self.dtype - ) - - min_precision = tf.cast(1, dtype=self.dtype) - for order in range(1, self.max_order + 1): - matches = tf.cast(0, dtype=self.dtype) - possible_matches = tf.cast(0, dtype=self.dtype) + def validate_and_fix_rank(inputs, tensor_name, base_rank=0): + if not isinstance(inputs, tf.Tensor): + inputs = tf.convert_to_tensor(inputs) - for idx in range(batch_size): - reference = y_true[idx] - translation = y_pred[idx] - translation_len = tf.cast( - tf.shape(translation)[0], dtype=self.dtype + if inputs.shape.rank == base_rank: + return inputs[tf.newaxis] + elif inputs.shape.rank == base_rank + 1: + return inputs + else: + raise ValueError( + f"{tensor_name} must be of rank {base_rank} or {base_rank+1}. " + f"Found rank: {inputs.shape.rank}" ) - # Get n-grams and ngram count. - reference_ngrams, reference_ngram_freq = get_ngram_count( - reference, order - ) - translation_ngrams, translation_ngram_freq = get_ngram_count( - translation, order - ) + def _get_ngrams(segment, max_order): + """Extracts all n-grams upto a given maximum order from an input + segment. Uses Python ops. Inspired from + https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py. - # Get the intersection of the two ngram tensors. - common_ngrams = tf.sets.intersection( - reference_ngrams[tf.newaxis, :], - translation_ngrams[tf.newaxis, :], - ).values - - common_reference_ngram_freq = tf.gather( - reference_ngram_freq, - tf.argmax( - (reference_ngrams[:, None] == common_ngrams), axis=0 - ), - ) - common_translation_ngram_freq = tf.gather( - translation_ngram_freq, - tf.argmax( - (translation_ngrams[:, None] == common_ngrams), axis=0 - ), - ) + Args: + segment: string. Text segment from which n-grams will be + extracted. + max_order: int. Maximum length in tokens of the n-grams returned + by this methods. + """ + ngram_counts = collections.Counter() + for order in range(1, max_order + 1): + for i in range(0, len(segment) - order + 1): + ngram = tuple(segment[i : i + order]) + ngram_counts[ngram] += 1 + return ngram_counts + + def compute_bleu( + reference_corpus, + translation_corpus, + matches_by_order, + possible_matches_by_order, + translation_length, + reference_length, + max_order=4, + smooth=False, + ): + """Computes BLEU score of translated segments against one or more + references. Uses Python ops. Inspired from + https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py. + + Args: + reference_corpus: list of lists of references for each + translation. Each reference should be tokenized into a list + of tokens. + translation_corpus: list of translations to score. Each + translation should be tokenized into a list of tokens. + max_order: int. Maximum n-gram order to use when computing + BLEU score. + smooth: boolean. Whether or not to apply Lin et al. 2004 + smoothing. + """ + for (references, translation) in zip( + reference_corpus, translation_corpus + ): + reference_length += min(len(r) for r in references) + translation_length += len(translation) + + merged_ref_ngram_counts = collections.Counter() + for reference in references: + merged_ref_ngram_counts |= _get_ngrams(reference, max_order) + translation_ngram_counts = _get_ngrams(translation, max_order) + overlap = translation_ngram_counts & merged_ref_ngram_counts + for ngram in overlap: + matches_by_order[len(ngram) - 1] += overlap[ngram] + for order in range(1, max_order + 1): + possible_matches = len(translation) - order + 1 + if possible_matches > 0: + possible_matches_by_order[order - 1] += possible_matches - # Compute number of ngram matches. - matches += tf.cast( - tf.reduce_sum( - tf.minimum( - common_reference_ngram_freq, - common_translation_ngram_freq, + precisions = [0] * max_order + for i in range(0, max_order): + if smooth: + precisions[i] = (matches_by_order[i] + 1.0) / ( + possible_matches_by_order[i] + 1.0 + ) + else: + if possible_matches_by_order[i] > 0: + precisions[i] = ( + float(matches_by_order[i]) + / possible_matches_by_order[i] ) - ), - dtype=self.dtype, - ) - if translation_len - order + 1 > 0: - possible_matches += translation_len + else: + precisions[i] = 0.0 - if self.smooth: - precision = (matches + tf.cast(1, dtype=self.dtype)) / ( - possible_matches + tf.cast(1, dtype=self.dtype) + if min(precisions) > 0: + p_log_sum = sum( + (1.0 / max_order) * math.log(p) for p in precisions ) + geo_mean = math.exp(p_log_sum) else: - if possible_matches > 0: - precision = matches / possible_matches - else: - precision = tf.cast(0, dtype=self.dtype) + geo_mean = 0 - if precision > 0: - p_log_sum += ( - tf.cast(1, dtype=self.dtype) - / tf.cast(self.max_order, dtype=self.dtype) - ) * tf.math.log(precision) - min_precision = tf.minimum(min_precision, precision) + ratio = float(translation_length) / reference_length - if min_precision > 0: - geo_mean = tf.exp(p_log_sum) - else: - geo_mean = tf.cast(0, dtype=self.dtype) + if ratio > 1.0: + bp = 1.0 + else: + bp = math.exp(1 - 1.0 / ratio) - # Compute the brevity penalty. - ratio = agg_translation_len / agg_reference_len - if ratio > 1: - bp = tf.cast(1, dtype=self.dtype) - else: - bp = tf.exp( - tf.cast(1, dtype=self.dtype) - - tf.cast(1, dtype=self.dtype) / ratio + bleu = geo_mean * bp + + return ( + bleu, + matches_by_order, + possible_matches_by_order, + translation_length, + reference_length, + ) + + def calculate_bleu_score(references, translation): + references = tensor_to_string_list(references) + translation = tensor_to_string_list(translation) + + matches = self._matches.numpy().tolist() + possible_matches = self._possible_matches.numpy().tolist() + translation_length = self._translation_length.numpy() + reference_length = self._reference_length.numpy() + + ( + bleu_score, + matches, + possible_matches, + translation_length, + reference_length, + ) = compute_bleu( + reference_corpus=references, + translation_corpus=translation, + matches_by_order=matches, + possible_matches_by_order=possible_matches, + translation_length=translation_length, + reference_length=reference_length, + max_order=self.max_order, + smooth=self.smooth, + ) + return ( + tf.constant(bleu_score, dtype=self.dtype), + tf.constant(matches, dtype=self.dtype), + tf.constant(possible_matches, dtype=self.dtype), + tf.constant(translation_length, dtype=self.dtype), + tf.constant(reference_length, dtype=self.dtype), ) - self._bleu.assign_add(geo_mean * bp) - self._number_of_samples.assign_add( - tf.cast(batch_size, dtype=tf.float32) + y_true = validate_and_fix_rank(y_true, "y_true", 1) + y_pred = validate_and_fix_rank(y_pred, "y_pred", 0) + + # Tokenize the inputs. + y_true = self.tokenizer(y_true) + y_pred = self.tokenizer(y_pred) + + ( + bleu_score, + matches, + possible_matches, + translation_length, + reference_length, + ) = tf.py_function( + func=calculate_bleu_score, + inp=[y_true, y_pred], + Tout=[self.dtype, self.dtype, self.dtype, self.dtype, self.dtype], ) - def result(self): - if self._number_of_samples == 0: - return 0.0 - bleu = self._bleu / self._number_of_samples + self._matches.assign(matches) + self._possible_matches.assign(possible_matches) + self._translation_length.assign(translation_length) + self._reference_length.assign(reference_length) + self._bleu.assign(bleu_score) - return bleu + def result(self): + return self._bleu def reset_state(self): + self._matches.assign(0.0) + self._possible_matches.assign(0.0) + self._translation_length.assign(0.0) + self._reference_length.assign(0.0) self._bleu.assign(0.0) - self._number_of_samples.assign(0.0) def get_config(self): config = super().get_config() - config.update({"max_order": self.max_order, "smooth": self.smooth}) + config.update( + { + "tokenizer": self.tokenizer, + "max_order": self.max_order, + "smooth": self.smooth, + } + ) return config diff --git a/keras_nlp/metrics/ngram_utils.py b/keras_nlp/metrics/ngram_utils.py deleted file mode 100644 index f856e97eb0..0000000000 --- a/keras_nlp/metrics/ngram_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2022 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tensorflow as tf -import tensorflow_text as tf_text - - -def get_ngram_count(tokenized_string, order): - ngrams = tf_text.ngrams( - data=tokenized_string, - width=order, - axis=-1, - reduction_type=tf_text.Reduction.STRING_JOIN, - ) - unique_ngrams, _, ngram_freq = tf.unique_with_counts(ngrams) - return unique_ngrams, ngram_freq From eface1e48d829581d3850d5ab57d5cc2f2de0955 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 29 Jun 2022 01:20:18 +0530 Subject: [PATCH 03/14] Add arg for corpus BLEU and sentence BLEU --- keras_nlp/metrics/bleu.py | 207 ++++++++++++++++++++++++++------------ 1 file changed, 140 insertions(+), 67 deletions(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index 710fb2c4e2..4622397165 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -53,6 +53,7 @@ def __init__( tokenizer=None, max_order=4, smooth=False, + variant="corpus", dtype=None, name="bleu", **kwargs, @@ -65,6 +66,12 @@ def __init__( f"Received: dtype={dtype}" ) + if variant not in ("corpus_bleu", "sentence_bleu"): + raise ValueError( + "`variant` must be either 'corpus_bleu' or 'sentence_bleu'. " + f"Received: variant={variant}" + ) + def default_tokenizer(inputs): """ Default tokenizer. Replicates the behaviour of SacreBLEU's @@ -87,27 +94,37 @@ def default_tokenizer(inputs): self.tokenizer = tokenizer self.max_order = max_order self.smooth = smooth - - self._matches = self.add_weight( - shape=(self.max_order,), - name="bleu_matches", - initializer="zeros", - dtype=self.dtype, - ) - self._possible_matches = self.add_weight( - shape=(self.max_order,), - name="bleu_possible_matches", - initializer="zeros", - dtype=self.dtype, - ) - self._translation_length = self.add_weight( - name="bleu_translation_length", - initializer="zeros", - dtype=self.dtype, - ) - self._reference_length = self.add_weight( - name="bleu_reference_length", initializer="zeros", dtype=self.dtype - ) + self.variant = variant + + if variant == "corpus_bleu": + self._matches = self.add_weight( + shape=(self.max_order,), + name="bleu_matches", + initializer="zeros", + dtype=self.dtype, + ) + self._possible_matches = self.add_weight( + shape=(self.max_order,), + name="bleu_possible_matches", + initializer="zeros", + dtype=self.dtype, + ) + self._translation_length = self.add_weight( + name="bleu_translation_length", + initializer="zeros", + dtype=self.dtype, + ) + self._reference_length = self.add_weight( + name="bleu_reference_length", + initializer="zeros", + dtype=self.dtype, + ) + else: + self._number_of_samples = self.add_weight( + name="number_of_samples", + initializer="zeros", + dtype=self.dtype, + ) self._bleu = self.add_weight( name="bleu", @@ -148,7 +165,7 @@ def _get_ngrams(segment, max_order): ngram_counts[ngram] += 1 return ngram_counts - def compute_bleu( + def corpus_bleu( reference_corpus, translation_corpus, matches_by_order, @@ -231,66 +248,121 @@ def compute_bleu( reference_length, ) + def sentence_bleu( + reference_corpus, + translation_corpus, + max_order=4, + smooth=False, + ): + bleu_score = 0.0 + for references, translation in zip( + reference_corpus, translation_corpus + ): + bleu_score += corpus_bleu( + reference_corpus=[references], + translation_corpus=translation, + matches_by_order=[0] * max_order, + possible_matches_by_order=[0] * max_order, + translation_length=0, + reference_length=0, + max_order=max_order, + smooth=smooth, + ) + return bleu_score + def calculate_bleu_score(references, translation): references = tensor_to_string_list(references) translation = tensor_to_string_list(translation) - matches = self._matches.numpy().tolist() - possible_matches = self._possible_matches.numpy().tolist() - translation_length = self._translation_length.numpy() - reference_length = self._reference_length.numpy() - - ( - bleu_score, - matches, - possible_matches, - translation_length, - reference_length, - ) = compute_bleu( - reference_corpus=references, - translation_corpus=translation, - matches_by_order=matches, - possible_matches_by_order=possible_matches, - translation_length=translation_length, - reference_length=reference_length, - max_order=self.max_order, - smooth=self.smooth, - ) - return ( - tf.constant(bleu_score, dtype=self.dtype), - tf.constant(matches, dtype=self.dtype), - tf.constant(possible_matches, dtype=self.dtype), - tf.constant(translation_length, dtype=self.dtype), - tf.constant(reference_length, dtype=self.dtype), - ) + if self.variant == "corpus_bleu": + matches = self._matches.numpy().tolist() + possible_matches = self._possible_matches.numpy().tolist() + translation_length = self._translation_length.numpy() + reference_length = self._reference_length.numpy() + + ( + bleu_score, + matches, + possible_matches, + translation_length, + reference_length, + ) = corpus_bleu( + reference_corpus=references, + translation_corpus=translation, + matches_by_order=matches, + possible_matches_by_order=possible_matches, + translation_length=translation_length, + reference_length=reference_length, + max_order=self.max_order, + smooth=self.smooth, + ) + return ( + tf.constant(bleu_score, dtype=self.dtype), + tf.constant(matches, dtype=self.dtype), + tf.constant(possible_matches, dtype=self.dtype), + tf.constant(translation_length, dtype=self.dtype), + tf.constant(reference_length, dtype=self.dtype), + ) + else: + bleu_score = sentence_bleu( + reference_corpus=references, + translation_corpus=translation, + max_order=self.max_order, + smooth=self.smooth, + ) + return tf.constant(bleu_score, dtype=self.dtype) y_true = validate_and_fix_rank(y_true, "y_true", 1) y_pred = validate_and_fix_rank(y_pred, "y_pred", 0) + if self.variant == "sentence_bleu": + batch_size = tf.cast(tf.shape(y_true)[0], dtype=self.dtype) + self._number_of_samples.assign_add(batch_size) + # Tokenize the inputs. y_true = self.tokenizer(y_true) y_pred = self.tokenizer(y_pred) - ( - bleu_score, - matches, - possible_matches, - translation_length, - reference_length, - ) = tf.py_function( - func=calculate_bleu_score, - inp=[y_true, y_pred], - Tout=[self.dtype, self.dtype, self.dtype, self.dtype, self.dtype], - ) + if self.variant == "corpus_bleu": + ( + bleu_score, + matches, + possible_matches, + translation_length, + reference_length, + ) = tf.py_function( + func=calculate_bleu_score, + inp=[y_true, y_pred], + Tout=[ + self.dtype, + self.dtype, + self.dtype, + self.dtype, + self.dtype, + ], + ) - self._matches.assign(matches) - self._possible_matches.assign(possible_matches) - self._translation_length.assign(translation_length) - self._reference_length.assign(reference_length) - self._bleu.assign(bleu_score) + self._matches.assign(matches) + self._possible_matches.assign(possible_matches) + self._translation_length.assign(translation_length) + self._reference_length.assign(reference_length) + self._bleu.assign(bleu_score) + else: + bleu_score = tf.py_function( + func=calculate_bleu_score, + inp=[y_true, y_pred], + Tout=self.dtype, + ) + self._bleu.assign_add(bleu_score) def result(self): - return self._bleu + if self.variant == "corpus_bleu": + return self._bleu + else: + if self._number_of_samples == 0: + return 0.0 + else: + return self._bleu / self._number_of_samples def reset_state(self): self._matches.assign(0.0) @@ -306,6 +378,7 @@ def get_config(self): "tokenizer": self.tokenizer, "max_order": self.max_order, "smooth": self.smooth, + "variant": self.variant, } ) return config From dc2110ef6e14ffdd0bd2621da8f2bcae633b29a0 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 29 Jun 2022 01:33:43 +0530 Subject: [PATCH 04/14] Minor bug fixes --- keras_nlp/metrics/bleu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index 4622397165..8c329c6a03 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -53,7 +53,7 @@ def __init__( tokenizer=None, max_order=4, smooth=False, - variant="corpus", + variant="corpus_bleu", dtype=None, name="bleu", **kwargs, @@ -267,7 +267,7 @@ def sentence_bleu( reference_length=0, max_order=max_order, smooth=smooth, - ) + )[0] return bleu_score def calculate_bleu_score(references, translation): From e18cc508dfa6d8c7f2c7f0fabe49036256341cf4 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 29 Jun 2022 01:47:11 +0530 Subject: [PATCH 05/14] More bug fixes --- keras_nlp/metrics/bleu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index 8c329c6a03..878432288e 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -248,7 +248,7 @@ def corpus_bleu( reference_length, ) - def sentence_bleu( + def aggregate_sentence_bleu( reference_corpus, translation_corpus, max_order=4, @@ -260,7 +260,7 @@ def sentence_bleu( ): bleu_score += corpus_bleu( reference_corpus=[references], - translation_corpus=translation, + translation_corpus=[translation], matches_by_order=[0] * max_order, possible_matches_by_order=[0] * max_order, translation_length=0, @@ -304,7 +304,7 @@ def calculate_bleu_score(references, translation): tf.constant(reference_length, dtype=self.dtype), ) else: - bleu_score = sentence_bleu( + bleu_score = aggregate_sentence_bleu( reference_corpus=references, translation_corpus=translation, max_order=self.max_order, From d59058b353303bb5be99608dbd2842660a11f9d3 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 29 Jun 2022 02:12:48 +0530 Subject: [PATCH 06/14] Add doc-strings --- keras_nlp/metrics/__init__.py | 1 + keras_nlp/metrics/bleu.py | 61 +++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/keras_nlp/metrics/__init__.py b/keras_nlp/metrics/__init__.py index 55ade6dc8a..dee54aca02 100644 --- a/keras_nlp/metrics/__init__.py +++ b/keras_nlp/metrics/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp.metrics.bleu import Bleu from keras_nlp.metrics.perplexity import Perplexity from keras_nlp.metrics.rouge_l import RougeL from keras_nlp.metrics.rouge_n import RougeN diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index 878432288e..f4145abbde 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -48,6 +48,45 @@ class Bleu(keras.metrics.Metric): + """BLEU metric. + + This class implements the BLEU metric. BLEU is generally used to evaluate + machine translation systems. Succinctly put, in BLEU score, we count the + number of matching n-grams in the candidate translation to n-grams in the + reference text. We find the "clipped count" of matching n-grams so as to not + give a high score to a reference, prediction pair with repeated tokens. + Secondly, BLEU score tends to reward shorter predictions more, which is why + a brevity penalty is applied to penalise short predictions. + + Note on input shapes: + For `y_true` and `y_pred`, this class supports scalar values and batch + inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`. + + Args: + tokenizer: callable. A function that takes a string `tf.Tensor` (of + any shape), and tokenizes the strings in the tensor. This function + should use TensorFlow graph ops. If the tokenizer is not specified, + the default tokenizer (`"tokenizer_13a"` present in the SacreBLEU + package) will be used. + max_order: int. The maximum n-gram order to use. For example, if + `max_order` is set to 3, unigrams, bigrams, and trigrams will be + considered. Defaults to 4. + smooth: bool. Whether to apply Lin et al. 2004 smoothing to the BLEU + score. Defaults to False. + variant: string. Either `"corpus_bleu"` or `"sentence_bleu"`. The former + computes the micro-average precision, which is equivalent to + passing all samples (across batches) all at once. In other words, + summing the numerators and denominators for each + hypothesis-reference(s) pairs before the division (in order to + calculate the precision). The latter is the macro-average BLEU score + , which means that it computes the per sample BLEU score and + averages it. Defaults to `"corpus_bleu"`. + dtype: string or tf.dtypes.Dtype. Precision of metric computation. If + not specified, it defaults to tf.float32. + name: string. Name of the metric instance. + **kwargs: Other keyword arguments. + """ + def __init__( self, tokenizer=None, @@ -185,6 +224,14 @@ def corpus_bleu( of tokens. translation_corpus: list of translations to score. Each translation should be tokenized into a list of tokens. + matches_by_order: list of floats containing the initial number + of matches for each order. + possible_matches_by_order: list of floats containing the initial + number of possible matches for each order. + translation_length: float. Initial number of tokens in all the + translations. + reference_length: float. Initial number of tokens in all the + references. max_order: int. Maximum n-gram order to use when computing BLEU score. smooth: boolean. Whether or not to apply Lin et al. 2004 @@ -254,6 +301,20 @@ def aggregate_sentence_bleu( max_order=4, smooth=False, ): + """Computes the per-sample BLEU score and returns the aggregate of + all samples. Uses Python ops. + + Args: + reference_corpus: list of lists of references for each + translation. Each reference should be tokenized into a list + of tokens. + translation_corpus: list of translations to score. Each + translation should be tokenized into a list of tokens. + max_order: int. Maximum n-gram order to use when computing + BLEU score. + smooth: boolean. Whether or not to apply Lin et al. 2004 + smoothing. + """ bleu_score = 0.0 for references, translation in zip( reference_corpus, translation_corpus From 5ddcfa74106e207301b2abaf092904f505063b5e Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 29 Jun 2022 08:15:17 +0530 Subject: [PATCH 07/14] Add references --- keras_nlp/metrics/bleu.py | 40 ++++++++++++++++++--------------- keras_nlp/metrics/rouge_base.py | 3 +++ keras_nlp/metrics/rouge_l.py | 3 +++ keras_nlp/metrics/rouge_n.py | 3 +++ 4 files changed, 31 insertions(+), 18 deletions(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index f4145abbde..e4e5b67a84 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -52,39 +52,43 @@ class Bleu(keras.metrics.Metric): This class implements the BLEU metric. BLEU is generally used to evaluate machine translation systems. Succinctly put, in BLEU score, we count the - number of matching n-grams in the candidate translation to n-grams in the - reference text. We find the "clipped count" of matching n-grams so as to not - give a high score to a reference, prediction pair with repeated tokens. - Secondly, BLEU score tends to reward shorter predictions more, which is why - a brevity penalty is applied to penalise short predictions. + number of matching n-grams in the candidate translation and the reference + text. We find the "clipped count" of matching n-grams so as to not + give a high score to a (reference, prediction) pair with redundant, repeated + tokens. Secondly, BLEU score tends to reward shorter predictions more, which + is why a brevity penalty is applied to penalise short predictions. Note on input shapes: For `y_true` and `y_pred`, this class supports scalar values and batch inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`. Args: - tokenizer: callable. A function that takes a string `tf.Tensor` (of - any shape), and tokenizes the strings in the tensor. This function - should use TensorFlow graph ops. If the tokenizer is not specified, - the default tokenizer (`"tokenizer_13a"` present in the SacreBLEU - package) will be used. + tokenizer: callable. A function that takes a string `tf.RaggedTensor` + (of any shape), and tokenizes the strings in the tensor. This + function should use TensorFlow graph ops. If the tokenizer is not + specified, the default tokenizer is used. The default tokenizer + replicates the behaviour of SacreBLEU's `"tokenizer_13a"` tokenizer + (https://github.com/mjpost/sacrebleu/blob/v2.1.0/sacrebleu/tokenizers/tokenizer_13a.py). max_order: int. The maximum n-gram order to use. For example, if `max_order` is set to 3, unigrams, bigrams, and trigrams will be considered. Defaults to 4. smooth: bool. Whether to apply Lin et al. 2004 smoothing to the BLEU score. Defaults to False. variant: string. Either `"corpus_bleu"` or `"sentence_bleu"`. The former - computes the micro-average precision, which is equivalent to - passing all samples (across batches) all at once. In other words, - summing the numerators and denominators for each - hypothesis-reference(s) pairs before the division (in order to - calculate the precision). The latter is the macro-average BLEU score - , which means that it computes the per sample BLEU score and - averages it. Defaults to `"corpus_bleu"`. + computes micro-average precision, which is equivalent to passing all + samples (across batches) all at once. In other words, summing the + numerators and denominators for each hypothesis-reference(s) pairs + before the division (in order to calculate precision). The latter is + the macro-averaged BLEU score which means that it computes the BLEU + score for every sample separately and averages over these scores. + Defaults to `"corpus_bleu"`. dtype: string or tf.dtypes.Dtype. Precision of metric computation. If not specified, it defaults to tf.float32. name: string. Name of the metric instance. **kwargs: Other keyword arguments. + + References: + - [Papineni et al., 2002](https://aclanthology.org/P02-1040/) """ def __init__( @@ -302,7 +306,7 @@ def aggregate_sentence_bleu( smooth=False, ): """Computes the per-sample BLEU score and returns the aggregate of - all samples. Uses Python ops. + BLEU scores over all samples. Uses Python ops. Args: reference_corpus: list of lists of references for each diff --git a/keras_nlp/metrics/rouge_base.py b/keras_nlp/metrics/rouge_base.py index 22d4adf3b8..9df0408f17 100644 --- a/keras_nlp/metrics/rouge_base.py +++ b/keras_nlp/metrics/rouge_base.py @@ -48,6 +48,9 @@ class RougeBase(keras.metrics.Metric): not specified, it defaults to tf.float32. name: string. Name of the metric instance. **kwargs: Other keyword arguments. + + References: + - [Lin et al., 2004](https://aclanthology.org/W04-1013/) """ def __init__( diff --git a/keras_nlp/metrics/rouge_l.py b/keras_nlp/metrics/rouge_l.py index f6969a85f6..0d55dc4741 100644 --- a/keras_nlp/metrics/rouge_l.py +++ b/keras_nlp/metrics/rouge_l.py @@ -38,6 +38,9 @@ class RougeL(RougeBase): name: string. Name of the metric instance. **kwargs: Other keyword arguments. + References: + - [Lin et al., 2004](https://aclanthology.org/W04-1013/) + Examples: 1. Various Input Types. diff --git a/keras_nlp/metrics/rouge_n.py b/keras_nlp/metrics/rouge_n.py index 4bfe532ee2..180765fadd 100644 --- a/keras_nlp/metrics/rouge_n.py +++ b/keras_nlp/metrics/rouge_n.py @@ -40,6 +40,9 @@ class RougeN(RougeBase): name: string. Name of the metric instance. **kwargs: Other keyword arguments. + References: + - [Lin et al., 2004](https://aclanthology.org/W04-1013/) + Examples: 1. Various Input Types. From 45136fb446a7b5954681b54dd0c035528baf204a Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Fri, 1 Jul 2022 21:49:24 +0530 Subject: [PATCH 08/14] Address reew comments - I --- keras_nlp/metrics/bleu.py | 325 ++++++++++++++++++++------------------ 1 file changed, 171 insertions(+), 154 deletions(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index e4e5b67a84..316f40a957 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""BLEU metric implementation.""" + import collections import math @@ -51,12 +53,16 @@ class Bleu(keras.metrics.Metric): """BLEU metric. This class implements the BLEU metric. BLEU is generally used to evaluate - machine translation systems. Succinctly put, in BLEU score, we count the - number of matching n-grams in the candidate translation and the reference - text. We find the "clipped count" of matching n-grams so as to not - give a high score to a (reference, prediction) pair with redundant, repeated - tokens. Secondly, BLEU score tends to reward shorter predictions more, which - is why a brevity penalty is applied to penalise short predictions. + machine translation systems. by default, this implementation replicates + SacreBLEU, but user-defined tokenizers can be passed to deal with other + languages. + + For BLEU score, we count the number of matching n-grams in the candidate + translation and the reference text. We find the "clipped count" of matching + n-grams so as to not give a high score to a (reference, prediction) pair + with redundant, repeated tokens. Secondly, BLEU score tends to reward + shorter predictions more, which is why a brevity penalty is applied to + penalise short predictions. Note on input shapes: For `y_true` and `y_pred`, this class supports scalar values and batch @@ -73,7 +79,9 @@ class Bleu(keras.metrics.Metric): `max_order` is set to 3, unigrams, bigrams, and trigrams will be considered. Defaults to 4. smooth: bool. Whether to apply Lin et al. 2004 smoothing to the BLEU - score. Defaults to False. + score. Adds 1 to the matched n-gram count (i.e., numerator) and 1 + to the total n-gram count (i.e., denominator) for every order while + calculating precision. Defaults to False. variant: string. Either `"corpus_bleu"` or `"sentence_bleu"`. The former computes micro-average precision, which is equivalent to passing all samples (across batches) all at once. In other words, summing the @@ -89,6 +97,8 @@ class Bleu(keras.metrics.Metric): References: - [Papineni et al., 2002](https://aclanthology.org/P02-1040/) + - [SacreBLEU](https://github.com/mjpost/sacrebleu) + - [Lin et al., 2004](https://aclanthology.org/P04-1077/) """ def __init__( @@ -175,6 +185,158 @@ def default_tokenizer(inputs): dtype=self.dtype, ) + def _get_ngrams(self, segment, max_order): + """Extracts all n-grams upto a given maximum order from an input segment. + + Uses Python ops. Inspired from + https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py. + + Args: + segment: string. Text segment from which n-grams will be + extracted. + max_order: int. Maximum length in tokens of the n-grams returned + by this methods. + """ + ngram_counts = collections.Counter() + for order in range(1, max_order + 1): + for i in range(0, len(segment) - order + 1): + ngram = tuple(segment[i : i + order]) + ngram_counts[ngram] += 1 + return ngram_counts + + def _corpus_bleu( + self, + reference_corpus, + translation_corpus, + matches_by_order, + possible_matches_by_order, + translation_length, + reference_length, + max_order=4, + smooth=False, + ): + """Corpus BLEU implementation using Python ops. + + Computes BLEU score of translated segments against one or more + references. Inspired from + https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py. + + Args: + reference_corpus: list of lists of references for each + translation. Each reference should be tokenized into a list + of tokens. + translation_corpus: list of translations to score. Each + translation should be tokenized into a list of tokens. + matches_by_order: list of floats containing the initial number + of matches for each order. + possible_matches_by_order: list of floats containing the initial + number of possible matches for each order. + translation_length: float. Initial number of tokens in all the + translations. + reference_length: float. Initial number of tokens in all the + references. + max_order: int. Maximum n-gram order to use when computing + BLEU score. + smooth: boolean. Whether or not to apply Lin et al. 2004 + smoothing. + """ + for (references, translation) in zip( + reference_corpus, translation_corpus + ): + reference_length += min(len(r) for r in references) + translation_length += len(translation) + + merged_ref_ngram_counts = collections.Counter() + for reference in references: + merged_ref_ngram_counts |= self._get_ngrams( + reference, max_order + ) + translation_ngram_counts = self._get_ngrams(translation, max_order) + overlap = translation_ngram_counts & merged_ref_ngram_counts + for ngram in overlap: + matches_by_order[len(ngram) - 1] += overlap[ngram] + for order in range(1, max_order + 1): + possible_matches = len(translation) - order + 1 + if possible_matches > 0: + possible_matches_by_order[order - 1] += possible_matches + + precisions = [0] * max_order + for i in range(0, max_order): + if smooth: + precisions[i] = (matches_by_order[i] + 1.0) / ( + possible_matches_by_order[i] + 1.0 + ) + else: + if possible_matches_by_order[i] > 0: + precisions[i] = ( + float(matches_by_order[i]) + / possible_matches_by_order[i] + ) + else: + precisions[i] = 0.0 + + if min(precisions) > 0: + p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions) + geo_mean = math.exp(p_log_sum) + else: + geo_mean = 0 + + ratio = float(translation_length) / reference_length + + if ratio > 1.0: + bp = 1.0 + else: + bp = math.exp(1 - 1.0 / ratio) + + bleu = geo_mean * bp + + return ( + bleu, + matches_by_order, + possible_matches_by_order, + translation_length, + reference_length, + ) + + def _aggregate_sentence_bleu( + self, + reference_corpus, + translation_corpus, + max_order=4, + smooth=False, + ): + """Aggregate Sentence BLEU implementation using Python ops. + + Computes the per-sample BLEU score and returns the aggregate of BLEU + scores over all samples. + + Args: + reference_corpus: list of lists of references for each + translation. Each reference should be tokenized into a list + of tokens. + translation_corpus: list of translations to score. Each + translation should be tokenized into a list of tokens. + max_order: int. Maximum n-gram order to use when computing + BLEU score. + smooth: boolean. Whether or not to apply Lin et al. 2004 + smoothing. + """ + bleu_score = 0.0 + for references, translation in zip( + reference_corpus, translation_corpus + ): + bleu_score += self._corpus_bleu( + reference_corpus=[references], + translation_corpus=[translation], + matches_by_order=[0] * max_order, + possible_matches_by_order=[0] * max_order, + translation_length=0, + reference_length=0, + max_order=max_order, + smooth=smooth, + )[0] + return bleu_score + def update_state(self, y_true, y_pred, sample_weight=None): def validate_and_fix_rank(inputs, tensor_name, base_rank=0): if not isinstance(inputs, tf.Tensor): @@ -190,151 +352,6 @@ def validate_and_fix_rank(inputs, tensor_name, base_rank=0): f"Found rank: {inputs.shape.rank}" ) - def _get_ngrams(segment, max_order): - """Extracts all n-grams upto a given maximum order from an input - segment. Uses Python ops. Inspired from - https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py. - - Args: - segment: string. Text segment from which n-grams will be - extracted. - max_order: int. Maximum length in tokens of the n-grams returned - by this methods. - """ - ngram_counts = collections.Counter() - for order in range(1, max_order + 1): - for i in range(0, len(segment) - order + 1): - ngram = tuple(segment[i : i + order]) - ngram_counts[ngram] += 1 - return ngram_counts - - def corpus_bleu( - reference_corpus, - translation_corpus, - matches_by_order, - possible_matches_by_order, - translation_length, - reference_length, - max_order=4, - smooth=False, - ): - """Computes BLEU score of translated segments against one or more - references. Uses Python ops. Inspired from - https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py. - - Args: - reference_corpus: list of lists of references for each - translation. Each reference should be tokenized into a list - of tokens. - translation_corpus: list of translations to score. Each - translation should be tokenized into a list of tokens. - matches_by_order: list of floats containing the initial number - of matches for each order. - possible_matches_by_order: list of floats containing the initial - number of possible matches for each order. - translation_length: float. Initial number of tokens in all the - translations. - reference_length: float. Initial number of tokens in all the - references. - max_order: int. Maximum n-gram order to use when computing - BLEU score. - smooth: boolean. Whether or not to apply Lin et al. 2004 - smoothing. - """ - for (references, translation) in zip( - reference_corpus, translation_corpus - ): - reference_length += min(len(r) for r in references) - translation_length += len(translation) - - merged_ref_ngram_counts = collections.Counter() - for reference in references: - merged_ref_ngram_counts |= _get_ngrams(reference, max_order) - translation_ngram_counts = _get_ngrams(translation, max_order) - overlap = translation_ngram_counts & merged_ref_ngram_counts - for ngram in overlap: - matches_by_order[len(ngram) - 1] += overlap[ngram] - for order in range(1, max_order + 1): - possible_matches = len(translation) - order + 1 - if possible_matches > 0: - possible_matches_by_order[order - 1] += possible_matches - - precisions = [0] * max_order - for i in range(0, max_order): - if smooth: - precisions[i] = (matches_by_order[i] + 1.0) / ( - possible_matches_by_order[i] + 1.0 - ) - else: - if possible_matches_by_order[i] > 0: - precisions[i] = ( - float(matches_by_order[i]) - / possible_matches_by_order[i] - ) - else: - precisions[i] = 0.0 - - if min(precisions) > 0: - p_log_sum = sum( - (1.0 / max_order) * math.log(p) for p in precisions - ) - geo_mean = math.exp(p_log_sum) - else: - geo_mean = 0 - - ratio = float(translation_length) / reference_length - - if ratio > 1.0: - bp = 1.0 - else: - bp = math.exp(1 - 1.0 / ratio) - - bleu = geo_mean * bp - - return ( - bleu, - matches_by_order, - possible_matches_by_order, - translation_length, - reference_length, - ) - - def aggregate_sentence_bleu( - reference_corpus, - translation_corpus, - max_order=4, - smooth=False, - ): - """Computes the per-sample BLEU score and returns the aggregate of - BLEU scores over all samples. Uses Python ops. - - Args: - reference_corpus: list of lists of references for each - translation. Each reference should be tokenized into a list - of tokens. - translation_corpus: list of translations to score. Each - translation should be tokenized into a list of tokens. - max_order: int. Maximum n-gram order to use when computing - BLEU score. - smooth: boolean. Whether or not to apply Lin et al. 2004 - smoothing. - """ - bleu_score = 0.0 - for references, translation in zip( - reference_corpus, translation_corpus - ): - bleu_score += corpus_bleu( - reference_corpus=[references], - translation_corpus=[translation], - matches_by_order=[0] * max_order, - possible_matches_by_order=[0] * max_order, - translation_length=0, - reference_length=0, - max_order=max_order, - smooth=smooth, - )[0] - return bleu_score - def calculate_bleu_score(references, translation): references = tensor_to_string_list(references) translation = tensor_to_string_list(translation) @@ -351,7 +368,7 @@ def calculate_bleu_score(references, translation): possible_matches, translation_length, reference_length, - ) = corpus_bleu( + ) = self._corpus_bleu( reference_corpus=references, translation_corpus=translation, matches_by_order=matches, @@ -369,7 +386,7 @@ def calculate_bleu_score(references, translation): tf.constant(reference_length, dtype=self.dtype), ) else: - bleu_score = aggregate_sentence_bleu( + bleu_score = self._aggregate_sentence_bleu( reference_corpus=references, translation_corpus=translation, max_order=self.max_order, From 0217b71f0c0c352bf043739cf6d2844c61ea5214 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 6 Jul 2022 23:41:51 +0530 Subject: [PATCH 09/14] Add UTs, allow dtypes other than string tensors, remove Sentence BLEU --- keras_nlp/metrics/bleu.py | 261 +++++++++------------------ keras_nlp/metrics/bleu_test.py | 210 +++++++++++++++++++++ keras_nlp/utils/tensor_utils.py | 24 ++- keras_nlp/utils/tensor_utils_test.py | 18 ++ 4 files changed, 335 insertions(+), 178 deletions(-) create mode 100644 keras_nlp/metrics/bleu_test.py diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index 316f40a957..12c6c28f43 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -20,6 +20,7 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.utils.tensor_utils import tensor_to_list from keras_nlp.utils.tensor_utils import tensor_to_string_list REPLACE_SUBSTRINGS = [ @@ -65,8 +66,10 @@ class Bleu(keras.metrics.Metric): penalise short predictions. Note on input shapes: - For `y_true` and `y_pred`, this class supports scalar values and batch - inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`. + For `y_true` and `y_pred`, this class supports the following shapes: + If `y_pred` is a scalar value, `y_true` has to be a 1D dense tensor. + For batched inputs, if `y_pred` is a 1D dense tensor, `y_true` has to be + a dense/ragged tensor with shape `(batch_size, None)`. Args: tokenizer: callable. A function that takes a string `tf.RaggedTensor` @@ -82,14 +85,6 @@ class Bleu(keras.metrics.Metric): score. Adds 1 to the matched n-gram count (i.e., numerator) and 1 to the total n-gram count (i.e., denominator) for every order while calculating precision. Defaults to False. - variant: string. Either `"corpus_bleu"` or `"sentence_bleu"`. The former - computes micro-average precision, which is equivalent to passing all - samples (across batches) all at once. In other words, summing the - numerators and denominators for each hypothesis-reference(s) pairs - before the division (in order to calculate precision). The latter is - the macro-averaged BLEU score which means that it computes the BLEU - score for every sample separately and averages over these scores. - Defaults to `"corpus_bleu"`. dtype: string or tf.dtypes.Dtype. Precision of metric computation. If not specified, it defaults to tf.float32. name: string. Name of the metric instance. @@ -106,7 +101,6 @@ def __init__( tokenizer=None, max_order=4, smooth=False, - variant="corpus_bleu", dtype=None, name="bleu", **kwargs, @@ -119,12 +113,6 @@ def __init__( f"Received: dtype={dtype}" ) - if variant not in ("corpus_bleu", "sentence_bleu"): - raise ValueError( - "`variant` must be either 'corpus_bleu' or 'sentence_bleu'. " - f"Received: variant={variant}" - ) - def default_tokenizer(inputs): """ Default tokenizer. Replicates the behaviour of SacreBLEU's @@ -147,38 +135,29 @@ def default_tokenizer(inputs): self.tokenizer = tokenizer self.max_order = max_order self.smooth = smooth - self.variant = variant - - if variant == "corpus_bleu": - self._matches = self.add_weight( - shape=(self.max_order,), - name="bleu_matches", - initializer="zeros", - dtype=self.dtype, - ) - self._possible_matches = self.add_weight( - shape=(self.max_order,), - name="bleu_possible_matches", - initializer="zeros", - dtype=self.dtype, - ) - self._translation_length = self.add_weight( - name="bleu_translation_length", - initializer="zeros", - dtype=self.dtype, - ) - self._reference_length = self.add_weight( - name="bleu_reference_length", - initializer="zeros", - dtype=self.dtype, - ) - else: - self._number_of_samples = self.add_weight( - name="number_of_samples", - initializer="zeros", - dtype=self.dtype, - ) + self._matches = self.add_weight( + shape=(self.max_order,), + name="bleu_matches", + initializer="zeros", + dtype=self.dtype, + ) + self._possible_matches = self.add_weight( + shape=(self.max_order,), + name="bleu_possible_matches", + initializer="zeros", + dtype=self.dtype, + ) + self._translation_length = self.add_weight( + name="bleu_translation_length", + initializer="zeros", + dtype=self.dtype, + ) + self._reference_length = self.add_weight( + name="bleu_reference_length", + initializer="zeros", + dtype=self.dtype, + ) self._bleu = self.add_weight( name="bleu", initializer="zeros", @@ -298,48 +277,9 @@ def _corpus_bleu( reference_length, ) - def _aggregate_sentence_bleu( - self, - reference_corpus, - translation_corpus, - max_order=4, - smooth=False, - ): - """Aggregate Sentence BLEU implementation using Python ops. - - Computes the per-sample BLEU score and returns the aggregate of BLEU - scores over all samples. - - Args: - reference_corpus: list of lists of references for each - translation. Each reference should be tokenized into a list - of tokens. - translation_corpus: list of translations to score. Each - translation should be tokenized into a list of tokens. - max_order: int. Maximum n-gram order to use when computing - BLEU score. - smooth: boolean. Whether or not to apply Lin et al. 2004 - smoothing. - """ - bleu_score = 0.0 - for references, translation in zip( - reference_corpus, translation_corpus - ): - bleu_score += self._corpus_bleu( - reference_corpus=[references], - translation_corpus=[translation], - matches_by_order=[0] * max_order, - possible_matches_by_order=[0] * max_order, - translation_length=0, - reference_length=0, - max_order=max_order, - smooth=smooth, - )[0] - return bleu_score - def update_state(self, y_true, y_pred, sample_weight=None): def validate_and_fix_rank(inputs, tensor_name, base_rank=0): - if not isinstance(inputs, tf.Tensor): + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): inputs = tf.convert_to_tensor(inputs) if inputs.shape.rank == base_rank: @@ -353,102 +293,83 @@ def validate_and_fix_rank(inputs, tensor_name, base_rank=0): ) def calculate_bleu_score(references, translation): - references = tensor_to_string_list(references) - translation = tensor_to_string_list(translation) - - if self.variant == "corpus_bleu": - matches = self._matches.numpy().tolist() - possible_matches = self._possible_matches.numpy().tolist() - translation_length = self._translation_length.numpy() - reference_length = self._reference_length.numpy() - - ( - bleu_score, - matches, - possible_matches, - translation_length, - reference_length, - ) = self._corpus_bleu( - reference_corpus=references, - translation_corpus=translation, - matches_by_order=matches, - possible_matches_by_order=possible_matches, - translation_length=translation_length, - reference_length=reference_length, - max_order=self.max_order, - smooth=self.smooth, - ) - return ( - tf.constant(bleu_score, dtype=self.dtype), - tf.constant(matches, dtype=self.dtype), - tf.constant(possible_matches, dtype=self.dtype), - tf.constant(translation_length, dtype=self.dtype), - tf.constant(reference_length, dtype=self.dtype), - ) + if references.dtype == tf.string: + references = tensor_to_string_list(references) + translation = tensor_to_string_list(translation) else: - bleu_score = self._aggregate_sentence_bleu( - reference_corpus=references, - translation_corpus=translation, - max_order=self.max_order, - smooth=self.smooth, - ) - return tf.constant(bleu_score, dtype=self.dtype) + references = tensor_to_list(references) + translation = tensor_to_list(translation) - y_true = validate_and_fix_rank(y_true, "y_true", 1) - y_pred = validate_and_fix_rank(y_pred, "y_pred", 0) + matches = self._matches.numpy().tolist() + possible_matches = self._possible_matches.numpy().tolist() + translation_length = self._translation_length.numpy() + reference_length = self._reference_length.numpy() - if self.variant == "sentence_bleu": - batch_size = tf.cast(tf.shape(y_true)[0], dtype=self.dtype) - self._number_of_samples.assign_add(batch_size) - - # Tokenize the inputs. - y_true = self.tokenizer(y_true) - y_pred = self.tokenizer(y_pred) - - if self.variant == "corpus_bleu": ( bleu_score, matches, possible_matches, translation_length, reference_length, - ) = tf.py_function( - func=calculate_bleu_score, - inp=[y_true, y_pred], - Tout=[ - self.dtype, - self.dtype, - self.dtype, - self.dtype, - self.dtype, - ], + ) = self._corpus_bleu( + reference_corpus=references, + translation_corpus=translation, + matches_by_order=matches, + possible_matches_by_order=possible_matches, + translation_length=translation_length, + reference_length=reference_length, + max_order=self.max_order, + smooth=self.smooth, ) - - self._matches.assign(matches) - self._possible_matches.assign(possible_matches) - self._translation_length.assign(translation_length) - self._reference_length.assign(reference_length) - self._bleu.assign(bleu_score) - else: - bleu_score = tf.py_function( - func=calculate_bleu_score, - inp=[y_true, y_pred], - Tout=self.dtype, + return ( + tf.constant(bleu_score, dtype=self.dtype), + tf.constant(matches, dtype=self.dtype), + tf.constant(possible_matches, dtype=self.dtype), + tf.constant(translation_length, dtype=self.dtype), + tf.constant(reference_length, dtype=self.dtype), ) - self._bleu.assign_add(bleu_score) + + y_true = validate_and_fix_rank(y_true, "y_true", 1) + y_pred = validate_and_fix_rank(y_pred, "y_pred", 0) + + # Tokenize the inputs. + y_true = self.tokenizer(y_true) + y_pred = self.tokenizer(y_pred) + + ( + bleu_score, + matches, + possible_matches, + translation_length, + reference_length, + ) = tf.py_function( + func=calculate_bleu_score, + inp=[y_true, y_pred], + Tout=[ + self.dtype, + self.dtype, + self.dtype, + self.dtype, + self.dtype, + ], + ) + + self._matches.assign(matches) + self._possible_matches.assign(possible_matches) + self._translation_length.assign(translation_length) + self._reference_length.assign(reference_length) + self._bleu.assign(bleu_score) def result(self): - if self.variant == "corpus_bleu": - return self._bleu - else: - if self._number_of_samples == 0: - return 0.0 - else: - return self._bleu / self._number_of_samples + return self._bleu def reset_state(self): - self._matches.assign(0.0) - self._possible_matches.assign(0.0) + self._matches.assign( + tf.zeros(shape=(self.max_order,), dtype=self.dtype) + ) + self._possible_matches.assign( + tf.zeros(shape=(self.max_order,), dtype=self.dtype) + ) self._translation_length.assign(0.0) self._reference_length.assign(0.0) self._bleu.assign(0.0) @@ -457,10 +378,8 @@ def get_config(self): config = super().get_config() config.update( { - "tokenizer": self.tokenizer, "max_order": self.max_order, "smooth": self.smooth, - "variant": self.variant, } ) return config diff --git a/keras_nlp/metrics/bleu_test.py b/keras_nlp/metrics/bleu_test.py new file mode 100644 index 0000000000..ae612980ff --- /dev/null +++ b/keras_nlp/metrics/bleu_test.py @@ -0,0 +1,210 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Bleu.""" + +import tensorflow as tf + +from keras_nlp.metrics import Bleu +from keras_nlp.tokenizers import ByteTokenizer + + +class BleuTest(tf.test.TestCase): + def test_initialization(self): + bleu = Bleu() + result = bleu.result() + + self.assertEqual(result, 0.0) + + def test_scalar_input(self): + bleu = Bleu(smooth=True) + y_true = [ + "He eats a sweet apple.", + "He is eating a tasty apple, isn't he?", + ] + y_pred = "He He He eats sweet apple which is a fruit." + + bleu_val = bleu(y_true, y_pred) + self.assertAlmostEqual(bleu_val.numpy(), 0.212, delta=1e-3) + + def test_1d_list_input(self): + bleu = Bleu() + y_true = [ + ["He eats a sweet apple."], + ["Silicon Valley is one of my favourite shows!"], + ] + y_pred = [ + "He He He eats sweet apple which is a fruit.", + "I love Silicon Valley, it's one of my favourite shows.", + ] + + bleu_val = bleu(y_true, y_pred) + self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3) + + def test_1d_tensor_input(self): + bleu = Bleu() + y_true = tf.ragged.constant( + [ + ["He eats a sweet apple."], + ["Silicon Valley is one of my favourite shows!"], + ] + ) + y_pred = tf.constant( + [ + "He He He eats sweet apple which is a fruit.", + "I love Silicon Valley, it's one of my favourite shows.", + ] + ) + + bleu_val = bleu(y_true, y_pred) + self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3) + + def test_custom_tokenizer(self): + byte_tokenizer = ByteTokenizer() + bleu = Bleu(tokenizer=byte_tokenizer) + y_true = tf.ragged.constant( + [ + ["He eats a sweet apple."], + ["Silicon Valley is one of my favourite shows!"], + ] + ) + y_pred = tf.constant( + [ + "He He He eats sweet apple which is a fruit.", + "I love Silicon Valley, it's one of my favourite shows.", + ] + ) + + bleu_val = bleu(y_true, y_pred) + self.assertAlmostEqual(bleu_val.numpy(), 0.609, delta=1e-3) + + def test_different_order(self): + bleu = Bleu(max_order=5) + y_true = tf.ragged.constant( + [ + ["He eats a sweet apple."], + ["Silicon Valley is one of my favourite shows!"], + ] + ) + y_pred = tf.constant( + [ + "He He He eats sweet apple which is a fruit.", + "I love Silicon Valley, it's one of my favourite shows.", + ] + ) + + bleu_val = bleu(y_true, y_pred) + self.assertAlmostEqual(bleu_val.numpy(), 0.188, delta=1e-3) + + def test_reset_state(self): + bleu = Bleu() + y_true = tf.ragged.constant( + [ + ["He eats a sweet apple."], + ["Silicon Valley is one of my favourite shows!"], + ] + ) + y_pred = tf.constant( + [ + "He He He eats sweet apple which is a fruit.", + "I love Silicon Valley, it's one of my favourite shows.", + ] + ) + + bleu.update_state(y_true, y_pred) + bleu_val = bleu.result() + self.assertNotEqual(bleu_val.numpy(), 0.0) + + bleu.reset_state() + bleu_val = bleu.result() + self.assertEqual(bleu_val, 0.0) + + def test_update_state(self): + bleu = Bleu() + y_true_1 = tf.ragged.constant( + [ + ["He eats a sweet apple."], + ["Silicon Valley is one of my favourite shows!"], + ] + ) + y_pred_1 = tf.constant( + [ + "He He He eats sweet apple which is a fruit.", + "I love Silicon Valley, it's one of my favourite shows.", + ] + ) + + bleu.update_state(y_true_1, y_pred_1) + bleu_val = bleu.result() + self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3) + + y_true_2 = tf.constant(["Virat Kohli is the GOAT."]) + y_pred_2 = tf.constant("Virat Kohli is the greatest of all time!") + + bleu.update_state(y_true_2, y_pred_2) + bleu_val = bleu.result() + self.assertAlmostEqual(bleu_val.numpy(), 0.26, delta=1e-3) + + def test_merge_state_normalize(self): + bleu_1 = Bleu(smooth=True) + bleu_2 = Bleu(smooth=True) + + y_true_1 = tf.ragged.constant( + [ + ["He eats a sweet apple."], + ["Silicon Valley is one of my favourite shows!"], + ] + ) + y_pred_1 = tf.constant( + [ + "He He He eats sweet apple which is a fruit.", + "I love Silicon Valley, it's one of my favourite shows.", + ] + ) + + y_true_2 = tf.constant(["Virat Kohli is the GOAT."]) + y_pred_2 = tf.constant("Virat Kohli is the greatest of all time!") + + y_true_3 = tf.constant([["Watching Test cricket is so much fun."]]) + y_pred_3 = tf.constant(["Test is the best format in cricket."]) + + bleu_1.update_state(y_true_1, y_pred_1) + bleu_1.update_state(y_true_2, y_pred_2) + bleu_val = bleu_1.result() + self.assertAlmostEqual(bleu_val.numpy(), 0.293, delta=1e-3) + + bleu_2.update_state(y_true_3, y_pred_3) + bleu_val = bleu_2.result() + self.assertAlmostEqual(bleu_val.numpy(), 0.202, delta=1e-3) + + merged_bleu = Bleu(smooth=True) + merged_bleu.merge_state([bleu_1, bleu_2]) + bleu_val = merged_bleu.result() + self.assertAlmostEqual(bleu_val.numpy(), 0.495, delta=1e-3) + + def test_get_config(self): + bleu = Bleu( + tokenizer=None, + max_order=8, + smooth=True, + dtype=tf.float64, + name="bleu_test", + ) + + config = bleu.get_config() + expected_config_subset = { + "max_order": 8, + "smooth": True, + } + self.assertEqual(config, {**config, **expected_config_subset}) diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tensor_utils.py index 26fc815f11..30d2288f10 100644 --- a/keras_nlp/utils/tensor_utils.py +++ b/keras_nlp/utils/tensor_utils.py @@ -25,16 +25,11 @@ def _decode_strings_to_utf8(inputs): return [_decode_strings_to_utf8(x) for x in inputs] -def tensor_to_string_list(inputs): - """Detokenize and convert tensor to nested lists of python strings. - - This is a convenience method which converts each byte string to a python - string. +def tensor_to_list(inputs): + """Converts a tensor to nested lists. Args: inputs: Input tensor, or dict/list/tuple of input tensors. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. """ if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)): inputs = tf.convert_to_tensor(inputs) @@ -44,4 +39,19 @@ def tensor_to_string_list(inputs): list_outputs = inputs.numpy() if inputs.shape.rank != 0: list_outputs = list_outputs.tolist() + return list_outputs + + +def tensor_to_string_list(inputs): + """Detokenize and convert tensor to nested lists of python strings. + + This is a convenience method which converts each byte string to a python + string. + + Args: + inputs: Input tensor, or dict/list/tuple of input tensors. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + """ + list_outputs = tensor_to_list(inputs) return _decode_strings_to_utf8(list_outputs) diff --git a/keras_nlp/utils/tensor_utils_test.py b/keras_nlp/utils/tensor_utils_test.py index d9941f750f..bdb9d728c4 100644 --- a/keras_nlp/utils/tensor_utils_test.py +++ b/keras_nlp/utils/tensor_utils_test.py @@ -14,9 +14,27 @@ import tensorflow as tf +from keras_nlp.utils.tensor_utils import tensor_to_list from keras_nlp.utils.tensor_utils import tensor_to_string_list +class TensorToListTest(tf.test.TestCase): + def test_ragged_input(self): + input_data = tf.ragged.constant([[1, 2], [4, 5, 6]]) + list_output = tensor_to_list(input_data) + self.assertAllEqual(list_output, [[1, 2], [4, 5, 6]]) + + def test_dense_input(self): + input_data = tf.constant([[1, 2], [3, 4]]) + list_output = tensor_to_list(input_data) + self.assertAllEqual(list_output, [[1, 2], [3, 4]]) + + def test_scalar_input(self): + input_data = tf.constant(1) + list_output = tensor_to_list(input_data) + self.assertEqual(list_output, 1) + + class TensorToStringListTest(tf.test.TestCase): def test_detokenize_to_strings_for_ragged(self): input_data = tf.ragged.constant([["▀▁▂▃", "samurai"]]) From 0b6ebfafe2a819bf39061d07f6382d4f0727d55e Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Thu, 7 Jul 2022 17:52:46 +0530 Subject: [PATCH 10/14] Address review comments-II, make shape changes --- keras_nlp/metrics/bleu.py | 27 ++++++++++++++++++--------- keras_nlp/metrics/bleu_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index 12c6c28f43..fee9063e99 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -54,7 +54,7 @@ class Bleu(keras.metrics.Metric): """BLEU metric. This class implements the BLEU metric. BLEU is generally used to evaluate - machine translation systems. by default, this implementation replicates + machine translation systems. By default, this implementation replicates SacreBLEU, but user-defined tokenizers can be passed to deal with other languages. @@ -63,13 +63,15 @@ class Bleu(keras.metrics.Metric): n-grams so as to not give a high score to a (reference, prediction) pair with redundant, repeated tokens. Secondly, BLEU score tends to reward shorter predictions more, which is why a brevity penalty is applied to - penalise short predictions. + penalise short predictions. For more details, see the following article: + https://cloud.google.com/translate/automl/docs/evaluate#bleu. Note on input shapes: - For `y_true` and `y_pred`, this class supports the following shapes: - If `y_pred` is a scalar value, `y_true` has to be a 1D dense tensor. - For batched inputs, if `y_pred` is a 1D dense tensor, `y_true` has to be - a dense/ragged tensor with shape `(batch_size, None)`. + `y_pred` can be a scalar (of shape `()`), or a dense tensor of shape + `(batch_size,)` or `(batch_size, 1)`. `y_true` can either be a dense tensor + of shape `(num_references,)`, or a ragged tensor of shapes + `(batch_size, None)` or `(batch_size, None, 1)`. This is because every + sample can have multiple references. Args: tokenizer: callable. A function that takes a string `tf.RaggedTensor` @@ -171,7 +173,7 @@ def _get_ngrams(self, segment, max_order): https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py. Args: - segment: string. Text segment from which n-grams will be + segment: list. Text segment from which n-grams will be extracted. max_order: int. Maximum length in tokens of the n-grams returned by this methods. @@ -286,10 +288,17 @@ def validate_and_fix_rank(inputs, tensor_name, base_rank=0): return inputs[tf.newaxis] elif inputs.shape.rank == base_rank + 1: return inputs + elif inputs.shape.rank == base_rank + 2: + if tf.shape(inputs)[-1] != 1: + raise ValueError( + f"{tensor_name} is of rank {input.shape.rank}. The " + f"last dimension must be of size 1." + ) + return tf.squeeze(inputs, axis=-1) else: raise ValueError( - f"{tensor_name} must be of rank {base_rank} or {base_rank+1}. " - f"Found rank: {inputs.shape.rank}" + f"{tensor_name} must be of rank {base_rank}, {base_rank+1} " + f"or {base_rank+2}. Found rank: {inputs.shape.rank}" ) def calculate_bleu_score(references, translation): diff --git a/keras_nlp/metrics/bleu_test.py b/keras_nlp/metrics/bleu_test.py index ae612980ff..fd832b9fb0 100644 --- a/keras_nlp/metrics/bleu_test.py +++ b/keras_nlp/metrics/bleu_test.py @@ -52,6 +52,20 @@ def test_1d_list_input(self): bleu_val = bleu(y_true, y_pred) self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3) + def test_2d_list_input(self): + bleu = Bleu() + y_true = [ + [["He eats a sweet apple."]], + [["Silicon Valley is one of my favourite shows!"]], + ] + y_pred = [ + ["He He He eats sweet apple which is a fruit."], + ["I love Silicon Valley, it's one of my favourite shows."], + ] + + bleu_val = bleu(y_true, y_pred) + self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3) + def test_1d_tensor_input(self): bleu = Bleu() y_true = tf.ragged.constant( @@ -70,6 +84,24 @@ def test_1d_tensor_input(self): bleu_val = bleu(y_true, y_pred) self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3) + def test_2d_tensor_input(self): + bleu = Bleu() + y_true = tf.constant( + [ + [["He eats a sweet apple."]], + [["Silicon Valley is one of my favourite shows!"]], + ] + ) + y_pred = tf.constant( + [ + ["He He He eats sweet apple which is a fruit."], + ["I love Silicon Valley, it's one of my favourite shows."], + ] + ) + + bleu_val = bleu(y_true, y_pred) + self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3) + def test_custom_tokenizer(self): byte_tokenizer = ByteTokenizer() bleu = Bleu(tokenizer=byte_tokenizer) From be897dd952137cc6f5d20fed56fe17a541f027b9 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Fri, 8 Jul 2022 19:15:27 +0530 Subject: [PATCH 11/14] Address review comments-III --- keras_nlp/metrics/bleu.py | 130 +++++++++++++++----------------- keras_nlp/metrics/bleu_test.py | 27 ++++++- keras_nlp/utils/tensor_utils.py | 2 - 3 files changed, 88 insertions(+), 71 deletions(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index fee9063e99..3aed002c3d 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -115,26 +115,7 @@ def __init__( f"Received: dtype={dtype}" ) - def default_tokenizer(inputs): - """ - Default tokenizer. Replicates the behaviour of SacreBLEU's - default tokenizer, namely, `tokenizer_13a`. - """ - for pattern, replacement in REPLACE_SUBSTRINGS + REGEX_PATTERNS: - inputs = tf.strings.regex_replace( - input=inputs, - pattern=pattern, - rewrite=replacement, - replace_global=True, - name=None, - ) - inputs = tf.strings.split(inputs) - return inputs - - if tokenizer is None: - self.tokenizer = default_tokenizer - else: - self.tokenizer = tokenizer + self.tokenizer = tokenizer self.max_order = max_order self.smooth = smooth @@ -166,6 +147,25 @@ def default_tokenizer(inputs): dtype=self.dtype, ) + def _tokenizer(self, inputs): + """ + Tokenizes the input strings. By default, replicates the behaviour of + SacreBLEU's default tokenizer, namely, `tokenizer_13a`. + """ + if self.tokenizer: + return self.tokenizer(inputs) + + for pattern, replacement in REPLACE_SUBSTRINGS + REGEX_PATTERNS: + inputs = tf.strings.regex_replace( + input=inputs, + pattern=pattern, + rewrite=replacement, + replace_global=True, + name=None, + ) + inputs = tf.strings.split(inputs) + return inputs + def _get_ngrams(self, segment, max_order): """Extracts all n-grams upto a given maximum order from an input segment. @@ -176,7 +176,7 @@ def _get_ngrams(self, segment, max_order): segment: list. Text segment from which n-grams will be extracted. max_order: int. Maximum length in tokens of the n-grams returned - by this methods. + by this method. """ ngram_counts = collections.Counter() for order in range(1, max_order + 1): @@ -279,6 +279,43 @@ def _corpus_bleu( reference_length, ) + def _calculate_bleu_score(self, references, translation): + if references.dtype == tf.string: + references = tensor_to_string_list(references) + translation = tensor_to_string_list(translation) + else: + references = tensor_to_list(references) + translation = tensor_to_list(translation) + + matches = self._matches.numpy() + possible_matches = self._possible_matches.numpy() + translation_length = self._translation_length.numpy() + reference_length = self._reference_length.numpy() + + ( + bleu_score, + matches, + possible_matches, + translation_length, + reference_length, + ) = self._corpus_bleu( + reference_corpus=references, + translation_corpus=translation, + matches_by_order=matches, + possible_matches_by_order=possible_matches, + translation_length=translation_length, + reference_length=reference_length, + max_order=self.max_order, + smooth=self.smooth, + ) + return ( + tf.constant(bleu_score, dtype=self.dtype), + tf.constant(matches, dtype=self.dtype), + tf.constant(possible_matches, dtype=self.dtype), + tf.constant(translation_length, dtype=self.dtype), + tf.constant(reference_length, dtype=self.dtype), + ) + def update_state(self, y_true, y_pred, sample_weight=None): def validate_and_fix_rank(inputs, tensor_name, base_rank=0): if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): @@ -301,49 +338,12 @@ def validate_and_fix_rank(inputs, tensor_name, base_rank=0): f"or {base_rank+2}. Found rank: {inputs.shape.rank}" ) - def calculate_bleu_score(references, translation): - if references.dtype == tf.string: - references = tensor_to_string_list(references) - translation = tensor_to_string_list(translation) - else: - references = tensor_to_list(references) - translation = tensor_to_list(translation) - - matches = self._matches.numpy().tolist() - possible_matches = self._possible_matches.numpy().tolist() - translation_length = self._translation_length.numpy() - reference_length = self._reference_length.numpy() - - ( - bleu_score, - matches, - possible_matches, - translation_length, - reference_length, - ) = self._corpus_bleu( - reference_corpus=references, - translation_corpus=translation, - matches_by_order=matches, - possible_matches_by_order=possible_matches, - translation_length=translation_length, - reference_length=reference_length, - max_order=self.max_order, - smooth=self.smooth, - ) - return ( - tf.constant(bleu_score, dtype=self.dtype), - tf.constant(matches, dtype=self.dtype), - tf.constant(possible_matches, dtype=self.dtype), - tf.constant(translation_length, dtype=self.dtype), - tf.constant(reference_length, dtype=self.dtype), - ) - y_true = validate_and_fix_rank(y_true, "y_true", 1) y_pred = validate_and_fix_rank(y_pred, "y_pred", 0) # Tokenize the inputs. - y_true = self.tokenizer(y_true) - y_pred = self.tokenizer(y_pred) + y_true = self._tokenizer(y_true) + y_pred = self._tokenizer(y_pred) ( bleu_score, @@ -352,15 +352,9 @@ def calculate_bleu_score(references, translation): translation_length, reference_length, ) = tf.py_function( - func=calculate_bleu_score, + func=self._calculate_bleu_score, inp=[y_true, y_pred], - Tout=[ - self.dtype, - self.dtype, - self.dtype, - self.dtype, - self.dtype, - ], + Tout=[self.dtype, self.dtype, self.dtype, self.dtype, self.dtype], ) self._matches.assign(matches) diff --git a/keras_nlp/metrics/bleu_test.py b/keras_nlp/metrics/bleu_test.py index fd832b9fb0..09196e910b 100644 --- a/keras_nlp/metrics/bleu_test.py +++ b/keras_nlp/metrics/bleu_test.py @@ -15,6 +15,7 @@ """Tests for Bleu.""" import tensorflow as tf +from tensorflow import keras from keras_nlp.metrics import Bleu from keras_nlp.tokenizers import ByteTokenizer @@ -139,6 +140,29 @@ def test_different_order(self): bleu_val = bleu(y_true, y_pred) self.assertAlmostEqual(bleu_val.numpy(), 0.188, delta=1e-3) + def test_model_compile(self): + inputs = keras.Input(shape=(), dtype="string") + outputs = tf.identity(inputs) + model = keras.Model(inputs, outputs) + + model.compile(metrics=[Bleu()]) + + x = tf.constant( + [ + "He He He eats sweet apple which is a fruit.", + "I love Silicon Valley, it's one of my favourite shows.", + ] + ) + y = tf.constant( + [ + ["He eats a sweet apple."], + ["Silicon Valley is one of my favourite shows!"], + ] + ) + + output = model.evaluate(x, y, return_dict=True) + self.assertAlmostEqual(output["bleu"], 0.243, delta=1e-3) + def test_reset_state(self): bleu = Bleu() y_true = tf.ragged.constant( @@ -226,8 +250,9 @@ def test_merge_state_normalize(self): self.assertAlmostEqual(bleu_val.numpy(), 0.495, delta=1e-3) def test_get_config(self): + byte_tokenizer = ByteTokenizer() bleu = Bleu( - tokenizer=None, + tokenizer=byte_tokenizer.tokenize, max_order=8, smooth=True, dtype=tf.float64, diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tensor_utils.py index 30d2288f10..bcab0f24d8 100644 --- a/keras_nlp/utils/tensor_utils.py +++ b/keras_nlp/utils/tensor_utils.py @@ -50,8 +50,6 @@ def tensor_to_string_list(inputs): Args: inputs: Input tensor, or dict/list/tuple of input tensors. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. """ list_outputs = tensor_to_list(inputs) return _decode_strings_to_utf8(list_outputs) From 363da3a515512999853fdb6c1ba1581917d7ca5e Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Fri, 8 Jul 2022 19:39:49 +0530 Subject: [PATCH 12/14] Serialise tokenizer --- keras_nlp/metrics/bleu.py | 10 ++++++++++ keras_nlp/metrics/bleu_test.py | 8 +++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index 3aed002c3d..ef1aa2cd6e 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -116,6 +116,13 @@ def __init__( ) self.tokenizer = tokenizer + try: + self.tokenizer = keras.utils.register_keras_serializable( + package="keras_nlp.metrics.Bleu", name="tokenizer" + )(self.tokenizer) + except: + pass + self.max_order = max_order self.smooth = smooth @@ -381,6 +388,9 @@ def get_config(self): config = super().get_config() config.update( { + "tokenizer": None + if self.tokenizer is None + else keras.utils.serialize_keras_object(self.tokenizer), "max_order": self.max_order, "smooth": self.smooth, } diff --git a/keras_nlp/metrics/bleu_test.py b/keras_nlp/metrics/bleu_test.py index 09196e910b..8f3f086415 100644 --- a/keras_nlp/metrics/bleu_test.py +++ b/keras_nlp/metrics/bleu_test.py @@ -252,7 +252,7 @@ def test_merge_state_normalize(self): def test_get_config(self): byte_tokenizer = ByteTokenizer() bleu = Bleu( - tokenizer=byte_tokenizer.tokenize, + tokenizer=byte_tokenizer, max_order=8, smooth=True, dtype=tf.float64, @@ -260,6 +260,12 @@ def test_get_config(self): ) config = bleu.get_config() + self.assertIsInstance( + keras.utils.deserialize_keras_object( + identifier=config.pop("tokenizer"), module_objects=globals() + ), + ByteTokenizer, + ) expected_config_subset = { "max_order": 8, "smooth": True, From fa2c65818ac12634c492d3f80203e417d9cf6187 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Fri, 8 Jul 2022 23:53:28 +0530 Subject: [PATCH 13/14] Small fixes --- keras_nlp/metrics/bleu.py | 39 ++++++++++++---------------------- keras_nlp/metrics/bleu_test.py | 7 +----- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index ef1aa2cd6e..fcd11c65c4 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -21,7 +21,6 @@ from tensorflow import keras from keras_nlp.utils.tensor_utils import tensor_to_list -from keras_nlp.utils.tensor_utils import tensor_to_string_list REPLACE_SUBSTRINGS = [ ("", ""), @@ -67,18 +66,19 @@ class Bleu(keras.metrics.Metric): https://cloud.google.com/translate/automl/docs/evaluate#bleu. Note on input shapes: - `y_pred` can be a scalar (of shape `()`), or a dense tensor of shape - `(batch_size,)` or `(batch_size, 1)`. `y_true` can either be a dense tensor - of shape `(num_references,)`, or a ragged tensor of shapes - `(batch_size, None)` or `(batch_size, None, 1)`. This is because every - sample can have multiple references. + For unbatched inputs, `y_pred` should be a tensor of shape `()`, and + `y_true` should be a tensor of shape `(num_references,)`. For batched + inputs, `y_pred` should be a tensor of shape `(batch_size,)`, + and `y_true` should be a tensor of shape `(batch_size, num_references)`. In + case of batched inputs, `y_true` can also be of shape `(batch_size, None)` + in case different samples have different number of references. Args: tokenizer: callable. A function that takes a string `tf.RaggedTensor` - (of any shape), and tokenizes the strings in the tensor. This - function should use TensorFlow graph ops. If the tokenizer is not - specified, the default tokenizer is used. The default tokenizer - replicates the behaviour of SacreBLEU's `"tokenizer_13a"` tokenizer + (of any shape), and tokenizes the strings in the tensor. If the + tokenizer is not specified, the default tokenizer is used. The + default tokenizer replicates the behaviour of SacreBLEU's + `"tokenizer_13a"` tokenizer (https://github.com/mjpost/sacrebleu/blob/v2.1.0/sacrebleu/tokenizers/tokenizer_13a.py). max_order: int. The maximum n-gram order to use. For example, if `max_order` is set to 3, unigrams, bigrams, and trigrams will be @@ -116,13 +116,6 @@ def __init__( ) self.tokenizer = tokenizer - try: - self.tokenizer = keras.utils.register_keras_serializable( - package="keras_nlp.metrics.Bleu", name="tokenizer" - )(self.tokenizer) - except: - pass - self.max_order = max_order self.smooth = smooth @@ -287,12 +280,8 @@ def _corpus_bleu( ) def _calculate_bleu_score(self, references, translation): - if references.dtype == tf.string: - references = tensor_to_string_list(references) - translation = tensor_to_string_list(translation) - else: - references = tensor_to_list(references) - translation = tensor_to_list(translation) + references = tensor_to_list(references) + translation = tensor_to_list(translation) matches = self._matches.numpy() possible_matches = self._possible_matches.numpy() @@ -388,9 +377,7 @@ def get_config(self): config = super().get_config() config.update( { - "tokenizer": None - if self.tokenizer is None - else keras.utils.serialize_keras_object(self.tokenizer), + "tokenizer": self.tokenizer, "max_order": self.max_order, "smooth": self.smooth, } diff --git a/keras_nlp/metrics/bleu_test.py b/keras_nlp/metrics/bleu_test.py index 8f3f086415..a2d314c639 100644 --- a/keras_nlp/metrics/bleu_test.py +++ b/keras_nlp/metrics/bleu_test.py @@ -260,13 +260,8 @@ def test_get_config(self): ) config = bleu.get_config() - self.assertIsInstance( - keras.utils.deserialize_keras_object( - identifier=config.pop("tokenizer"), module_objects=globals() - ), - ByteTokenizer, - ) expected_config_subset = { + "tokenizer": byte_tokenizer, "max_order": 8, "smooth": True, } From 0acaae5ece1f46fec6f89cdc568265c700ab0ae5 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Sat, 9 Jul 2022 07:26:12 +0530 Subject: [PATCH 14/14] Doc-string changes --- keras_nlp/metrics/bleu.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py index fcd11c65c4..1626762361 100644 --- a/keras_nlp/metrics/bleu.py +++ b/keras_nlp/metrics/bleu.py @@ -70,8 +70,9 @@ class Bleu(keras.metrics.Metric): `y_true` should be a tensor of shape `(num_references,)`. For batched inputs, `y_pred` should be a tensor of shape `(batch_size,)`, and `y_true` should be a tensor of shape `(batch_size, num_references)`. In - case of batched inputs, `y_true` can also be of shape `(batch_size, None)` - in case different samples have different number of references. + case of batched inputs, `y_true` can also be a ragged tensor of shape + `(batch_size, None)` if different samples have different number of + references. Args: tokenizer: callable. A function that takes a string `tf.RaggedTensor` @@ -167,7 +168,7 @@ def _tokenizer(self, inputs): return inputs def _get_ngrams(self, segment, max_order): - """Extracts all n-grams upto a given maximum order from an input segment. + """Extracts all n-grams up to a given maximum order from an input segment. Uses Python ops. Inspired from https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.