-
Notifications
You must be signed in to change notification settings - Fork 261
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add BLEU Score #222
Add BLEU Score #222
Changes from 1 commit
f151982
7708cd9
0f757d5
e90bef3
eface1e
dc2110e
e18cc50
d59058b
5ddcfa7
b2d0822
45136fb
23f9a2f
0217b71
0b6ebfa
be897dd
363da3a
fa2c658
0acaae5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,6 @@ | |
from tensorflow import keras | ||
|
||
from keras_nlp.utils.tensor_utils import tensor_to_list | ||
from keras_nlp.utils.tensor_utils import tensor_to_string_list | ||
|
||
REPLACE_SUBSTRINGS = [ | ||
("<skipped>", ""), | ||
|
@@ -67,18 +66,19 @@ class Bleu(keras.metrics.Metric): | |
https://cloud.google.com/translate/automl/docs/evaluate#bleu. | ||
|
||
Note on input shapes: | ||
`y_pred` can be a scalar (of shape `()`), or a dense tensor of shape | ||
`(batch_size,)` or `(batch_size, 1)`. `y_true` can either be a dense tensor | ||
of shape `(num_references,)`, or a ragged tensor of shapes | ||
`(batch_size, None)` or `(batch_size, None, 1)`. This is because every | ||
sample can have multiple references. | ||
For unbatched inputs, `y_pred` should be a tensor of shape `()`, and | ||
`y_true` should be a tensor of shape `(num_references,)`. For batched | ||
inputs, `y_pred` should be a tensor of shape `(batch_size,)`, | ||
and `y_true` should be a tensor of shape `(batch_size, num_references)`. In | ||
case of batched inputs, `y_true` can also be of shape `(batch_size, None)` | ||
in case different samples have different number of references. | ||
|
||
Args: | ||
tokenizer: callable. A function that takes a string `tf.RaggedTensor` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens it you pass a tokenizer layer here, will that work? Say byte tokenizer for simplicity. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm, it won't work with byte tokeniser because we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should either support our tokenizers or not name this argument to something else. Tokenizer means something specific in our library now, if we use that name but don't support our tokenizer class that is a bad look. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We do support our tokenisers. I've added a unit test here: https://github.com/keras-team/keras-nlp/blob/0b6ebfafe2a819bf39061d07f6382d4f0727d55e/keras_nlp/metrics/bleu_test.py#L105 |
||
(of any shape), and tokenizes the strings in the tensor. This | ||
function should use TensorFlow graph ops. If the tokenizer is not | ||
specified, the default tokenizer is used. The default tokenizer | ||
replicates the behaviour of SacreBLEU's `"tokenizer_13a"` tokenizer | ||
(of any shape), and tokenizes the strings in the tensor. If the | ||
tokenizer is not specified, the default tokenizer is used. The | ||
default tokenizer replicates the behaviour of SacreBLEU's | ||
`"tokenizer_13a"` tokenizer | ||
(https://github.com/mjpost/sacrebleu/blob/v2.1.0/sacrebleu/tokenizers/tokenizer_13a.py). | ||
max_order: int. The maximum n-gram order to use. For example, if | ||
`max_order` is set to 3, unigrams, bigrams, and trigrams will be | ||
|
@@ -116,13 +116,6 @@ def __init__( | |
) | ||
|
||
self.tokenizer = tokenizer | ||
try: | ||
self.tokenizer = keras.utils.register_keras_serializable( | ||
package="keras_nlp.metrics.Bleu", name="tokenizer" | ||
)(self.tokenizer) | ||
except: | ||
pass | ||
|
||
self.max_order = max_order | ||
self.smooth = smooth | ||
|
||
|
@@ -287,12 +280,8 @@ def _corpus_bleu( | |
) | ||
|
||
def _calculate_bleu_score(self, references, translation): | ||
if references.dtype == tf.string: | ||
references = tensor_to_string_list(references) | ||
translation = tensor_to_string_list(translation) | ||
else: | ||
references = tensor_to_list(references) | ||
translation = tensor_to_list(translation) | ||
references = tensor_to_list(references) | ||
translation = tensor_to_list(translation) | ||
|
||
matches = self._matches.numpy() | ||
possible_matches = self._possible_matches.numpy() | ||
|
@@ -388,9 +377,7 @@ def get_config(self): | |
config = super().get_config() | ||
config.update( | ||
{ | ||
"tokenizer": None | ||
if self.tokenizer is None | ||
else keras.utils.serialize_keras_object(self.tokenizer), | ||
"tokenizer": self.tokenizer, | ||
"max_order": self.max_order, | ||
mattdangerw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"smooth": self.smooth, | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a ragged tensor with shape
(batch_size, None)