Add Perplexity Metric #68
Changes from 22 commits
@@ -0,0 +1,15 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.metrics.perplexity import Perplexity
@@ -0,0 +1,166 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Perplexity metric implementation based on `keras.metrics.Metric`."""

import tensorflow as tf
from tensorflow import keras


class Perplexity(keras.metrics.Metric):
Review comment: Please make the metric serializable by adding a `get_config()` method.
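A minimal sketch of what that could look like, assuming the constructor arguments are also stored as attributes (the `__init__` shown below keeps `pad_token_id` but not `from_logits`, so that line is hypothetical):

    def get_config(self):
        # Hypothetical override; serializes the constructor arguments so the
        # metric can be re-created with keras.metrics.deserialize().
        config = super().get_config()
        config.update(
            {
                "from_logits": self.from_logits,  # assumes this attribute exists
                "pad_token_id": self.pad_token_id,
            }
        )
        return config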
"""Perplexity metric. | ||
|
||
This class implements the perplexity metric. In short, this class calculates | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should add a lot of returns here. Blank line after the one liner, blank line after paragraph, blank line before Args: and Examples: |
||
the cross entropy loss and takes its exponent. | ||
Note: This implementation is not suitable for fixed-size windows. | ||
chenmoneygithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
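A sketch of the layout the reviewer is describing (placeholder text, not from the PR):

    """One-line summary ending in a period.

    A longer description paragraph.

    Args:
        ...

    Examples:
        ...
    """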
    Args:
        name: string. Name of the metric instance.
        dtype: string or tf.dtypes.DType. Precision of metric computation. If
            not specified, it defaults to tf.float32.
        from_logits: bool. If True, `y_pred` (input to `update_state()`) should
            be the logits as returned by the model. Otherwise, `y_pred` is a
            tensor of probabilities.
        pad_token_id: int. Token ID of the padding token. If provided, the mask
            is computed by this class (all padding tokens are masked while
            computing the cross entropy loss). Note that if this field is
            provided, the `sample_weight` field in `update_state()` is ignored.
        **kwargs: Other keyword arguments.

Review comment (on `pad_token_id`): Also prefer "mask_token_id" over "pad_token_id".

Review comment (on "is computed by this class"): by -> for?

Review thread (on ignoring `sample_weight`):
- This behavior is problematic; we should combine the masks, not drop one of them.
- What do you think is the best way to combine the masks? Element-wise maximum or element-wise addition (if both are not None)? Or do you have something else in mind?
- I think we can just multiply the masks together. If the padding token is set, that will give a mask of 1s and 0s, which can be multiplied with sample_weight. Put one way... if a padding token has sample weight …
- Great! Done 👍🏼
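A rough sketch of the multiply-the-masks resolution from the thread above, as it might look inside `update_state()` (the variable names follow the code below; the combination itself is the reviewer's suggestion, not the diff as shown):

        # Sketch: derive a 0/1 mask from pad_token_id and multiply it with any
        # user-supplied sample_weight instead of silently discarding one mask.
        if self.pad_token_id is not None:
            pad_mask = tf.cast(
                tf.math.logical_not(tf.equal(y_true, self.pad_token_id)),
                self.dtype,
            )
            sample_weight = (
                pad_mask
                if sample_weight is None
                else tf.cast(sample_weight, self.dtype) * pad_mask
            )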
    Examples:

    1. Calculate perplexity by calling update_state() and result().
    1.1. `sample_weight` and `pad_token_id` are not provided.
    >>> tf.random.set_seed(42)
    >>> perplexity = keras_nlp.metrics.Perplexity(name="perplexity")
    >>> target = tf.random.uniform(
    ...     shape=[2, 5], maxval=10, dtype=tf.int32, seed=42)
    >>> logits = tf.random.uniform(shape=(2, 5, 10), seed=42)
    >>> perplexity.update_state(target, logits)
    >>> perplexity.result()
    <tf.Tensor: shape=(), dtype=float32, numpy=11.8781595>

    1.2. `sample_weight` specified (masking token with ID 0).
    >>> tf.random.set_seed(42)
    >>> perplexity = keras_nlp.metrics.Perplexity(name="perplexity")
    >>> target = tf.random.uniform(
    ...     shape=[2, 5], maxval=10, dtype=tf.int32, seed=42)
    >>> logits = tf.random.uniform(shape=(2, 5, 10), seed=42)
    >>> sample_weight = tf.cast(
    ...     tf.math.logical_not(tf.equal(target, 0)), tf.float32)
    >>> perplexity.update_state(target, logits, sample_weight)
    >>> perplexity.result()
    <tf.Tensor: shape=(), dtype=float32, numpy=13.1128>

    2. Call perplexity directly.
    >>> tf.random.set_seed(42)
    >>> perplexity = keras_nlp.metrics.Perplexity(name="perplexity")
    >>> target = tf.random.uniform(
    ...     shape=[2, 5], maxval=10, dtype=tf.int32, seed=42)
    >>> logits = tf.random.uniform(shape=(2, 5, 10), seed=42)
    >>> perplexity(target, logits)
    <tf.Tensor: shape=(), dtype=float32, numpy=11.8781595>

    3. Provide the padding token ID and let the class compute the mask on its
       own.
    >>> tf.random.set_seed(42)
    >>> perplexity = keras_nlp.metrics.Perplexity(
    ...     name="perplexity", pad_token_id=0)
    >>> target = tf.random.uniform(
    ...     shape=[2, 5], maxval=10, dtype=tf.int32, seed=42)
    >>> logits = tf.random.uniform(shape=(2, 5, 10), seed=42)
    >>> perplexity(target, logits)
    <tf.Tensor: shape=(), dtype=float32, numpy=13.1128>
    """
    def __init__(
        self,
        name="perplexity",
        dtype=None,
        from_logits=False,
        pad_token_id=None,
        **kwargs,
    ):
        super().__init__(name=name, dtype=dtype, **kwargs)

        if not tf.as_dtype(self.dtype).is_floating:
            raise ValueError(
                "`dtype` must be a floating point type. "
                f"Received: dtype={dtype}"
            )

        self._cross_entropy = keras.losses.SparseCategoricalCrossentropy(
            from_logits=from_logits, reduction="sum"
        )

        self.pad_token_id = pad_token_id

        self._aggregate_cross_entropy = self.add_weight(
            name="aggregate_cross_entropy",
            initializer="zeros",
            dtype=self.dtype,
        )
        self._number_of_samples = self.add_weight(
            name="number_of_samples", initializer="zeros", dtype=self.dtype
        )

Review comment (on `dtype=None`): Same for dtype (which comes before name).

Review comment (on the weight name): Nit: Spell "crossentropy" as a single word, for consistency. This applies to the weight name and also to Python variable names.
    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_true shape: (batch_size, seq_len)
        # y_pred shape: (batch_size, seq_len, vocab_size)
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        batch_size = tf.cast(tf.shape(y_true)[0], self.dtype)

        if self.pad_token_id is not None:
            sample_weight = tf.cast(
                tf.math.logical_not(tf.equal(y_true, self.pad_token_id)),
                self.dtype,
            )

        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)

        # Calculate the Cross Entropy Loss.
        cross_entropy_value = tf.cast(
            self._cross_entropy(y_true, y_pred, sample_weight=sample_weight),
            self.dtype,
        )  # scalar

        # Divide the loss by the number of non-masked tokens.
        if sample_weight is not None:
            cross_entropy_value = cross_entropy_value / tf.reduce_sum(
                sample_weight
            )  # scalar
        else:
            cross_entropy_value = cross_entropy_value / (
                tf.cast(tf.shape(y_true)[0], self.dtype)
                * tf.cast(tf.shape(y_true)[1], self.dtype)
            )  # scalar

        self._aggregate_cross_entropy.assign_add(
            batch_size * cross_entropy_value
        )
        self._number_of_samples.assign_add(batch_size)

Review thread (on the cross entropy computation):
- I checked out the source code of tf.keras.metrics.SparseCategoricalCrossentropy, and it is doing WEIGHTED_MEAN reduction (https://github.com/keras-team/keras/blob/d8fcb9d4d4dad45080ecfdd575483653028f8eda/keras/metrics.py#L583), which should automatically set the divisor to the sum over masks. Could you help verify it with your unit test? Thanks!
- Sure. Will try this out! Thanks!
- @chenmoneygithub, this particular UT failed: …
- I did a quick analysis on Colab. Apparently, when … Let me know what the correct course of action is. Thanks!
- This is pretty odd. I guess we can stick to the Loss function and open an issue for future investigation.
- Coolio :). Thanks! 👍🏼
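For the thread above, a self-contained sketch (illustrative shapes and values, not the PR's unit test) comparing the metric's built-in WEIGHTED_MEAN reduction with the manual sum-then-divide used in this implementation:

import tensorflow as tf
from tensorflow import keras

y_true = tf.constant([[1, 2, 0]])              # (batch_size, seq_len)
y_pred = tf.random.uniform(shape=(1, 3, 5))    # (batch_size, seq_len, vocab_size)
weights = tf.constant([[1.0, 1.0, 0.0]])       # mask out the last (padding) token

# Built-in metric: WEIGHTED_MEAN reduction over the weighted losses.
metric = keras.metrics.SparseCategoricalCrossentropy()
metric.update_state(y_true, y_pred, sample_weight=weights)
print(metric.result())

# Manual route used in this PR: sum-reduced loss divided by the weight sum.
loss_fn = keras.losses.SparseCategoricalCrossentropy(reduction="sum")
manual = loss_fn(y_true, y_pred, sample_weight=weights) / tf.reduce_sum(weights)
print(manual)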
    def result(self):
        if self._number_of_samples == 0:
            return 0.0
        perplexity_score = tf.exp(
            self._aggregate_cross_entropy / self._number_of_samples
        )
        return perplexity_score

    def reset_state(self):
        self._aggregate_cross_entropy.assign(0.0)
        self._number_of_samples.assign(0.0)
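To tie `result()` back to the one-line definition in the docstring, a small sanity-check sketch (assuming the default `from_logits=False`, and with `Perplexity` importable as in the docstring examples):

import tensorflow as tf
from tensorflow import keras

y_true = tf.constant([[1, 2, 3]])            # one sequence of 3 tokens
y_pred = tf.random.uniform(shape=(1, 3, 5))  # per-token class scores

# Mean cross entropy per token, computed by hand...
sum_loss = keras.losses.SparseCategoricalCrossentropy(reduction="sum")(
    y_true, y_pred
)
mean_cross_entropy = sum_loss / 3.0  # batch_size * seq_len = 1 * 3

# ...and its exponent, which should match Perplexity()(y_true, y_pred).
print(tf.exp(mean_cross_entropy))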
Review comment: You need to add an import of metrics from the init file one directory up, otherwise the imports will not work on the exported package.
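A sketch of the requested change, assuming the standard layout where the package-level `__init__.py` sits one directory above `keras_nlp/metrics/` (that file is not part of this diff):

# In the package-level __init__.py (hypothetical path, not shown in the diff):
from keras_nlp import metrics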