
Implement TopP, TopK and Beam samplers #652

Merged
13 changes: 10 additions & 3 deletions keras_nlp/samplers/__init__.py
@@ -14,7 +14,11 @@

 from tensorflow import keras

-from keras_nlp.samplers.greedy import Greedy
+from keras_nlp.samplers.beam_sampler import BeamSampler
+from keras_nlp.samplers.greedy_sampler import GreedySampler
+from keras_nlp.samplers.greedy_sampler import Sampler
+from keras_nlp.samplers.top_k_sampler import TopKSampler
+from keras_nlp.samplers.top_p_sampler import TopPSampler


 def serialize(sampler):
@@ -24,7 +28,10 @@ def serialize(sampler):
 def deserialize(config, custom_objects=None):
     """Return a `Sampler` object from its config."""
     all_classes = {
-        "greedy": Greedy,
+        "beam": BeamSampler,
+        "greedy": GreedySampler,
+        "top_k": TopKSampler,
+        "top_p": TopPSampler,
     }
     return keras.utils.deserialize_keras_object(
         config,
@@ -46,7 +53,7 @@ def get(identifier):
     dict containing `class_name` and `config` as an identifier. Also note that
     the `class_name` must map to a `Sampler` class.

-    >>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}}
+    >>> cfg = {'class_name': 'keras_nlp>GreedySampler', 'config': {}}
     >>> sampler = keras_nlp.samplers.get(cfg)

     In the case that the `identifier` is a class, this method will return a new
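For context on the registry above: string keys, config dicts, and sampler instances should all resolve through the same lookup. A usage sketch, assuming the public `keras_nlp.samplers.get` API this diff documents (the `k=5` argument is illustrative, not from the PR):

```python
import keras_nlp

# A string shortcut resolves through the `all_classes` registry above.
sampler = keras_nlp.samplers.get("top_k")

# A config dict round-trips through `deserialize`, as in the doctest above.
cfg = {"class_name": "keras_nlp>GreedySampler", "config": {}}
sampler = keras_nlp.samplers.get(cfg)

# An existing `Sampler` instance passes through unchanged.
sampler = keras_nlp.samplers.get(keras_nlp.samplers.TopKSampler(k=5))
```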
215 changes: 215 additions & 0 deletions keras_nlp/samplers/beam_sampler.py
@@ -0,0 +1,215 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam Sampler."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import base_sampler_args_docstring
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.samplers.sampler import sample_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(
    base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
)
@keras.utils.register_keras_serializable(package="keras_nlp")
class BeamSampler(Sampler):
    """Beam Sampler class.

    This sampler implements the beam search algorithm. At each time-step, beam
    search keeps the beams (sequences) with the top `num_beams` highest
    accumulated probabilities, and uses each one of the beams to predict
    candidate next tokens.

    Args:
        num_beams: int. The number of beams that should be kept at each
            time-step. `num_beams` should be strictly positive.
        {{base_sampler_args}}

    Call Args:
        {{call_args}}

    Examples:
    ```python
    BATCH_SIZE = 8
Contributor: @chenmoneygithub I really don't think these arg blocks match the rest of the library's style. I don't think there's a perfect answer but I'd prefer it not to be obvious who wrote what 🥗

Contributor Author: Thanks! I checked our code base; it appears we have a few pieces of code in this style:

  • MaskedLMHead: link
  • PositionEmbedding: link
  • SinePositionEmbedding: link
  • Generation utils: link

I feel like giving these numbers an explicit meaning makes the code longer but kinda easier to understand.

Contributor: It's a totally valid opinion @chenmoneygithub, but it's also important for the code to have a unified style rather than each contributor producing different-looking code. If there's an example that's unclear without named params I'm open to trying something different, but otherwise I'm hoping we can compromise!

Member: Taking steps towards a more unified style (and then reflecting that in our style guide) sgtm. What are the main places this differs, besides the constants at the top?

Contributor: Basically what I asked for in #658 was the current thought. A bit less script-like and a little more "drop this line in colab and see what we're talking about".

Member: I could see us replacing the "model" with something like tf.random.uniform(shape, minval=-1, maxval=1). It is kind of weird to me that we show a whole model that is trainable but randomly initialized (so results will be random anyway), and not even sequence-aware, so it would never really perform even if you trained it. For a new user this seems a bit of a red herring.

Would be more concise to do something like:

    def token_probability_fn(inputs, mask):
        return tf.random.uniform(...) # Replace with a real model!
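
A minimal runnable version of that suggestion might look like the sketch below (the shape handling and `VOCAB_SIZE` constant are illustrative assumptions, not code from this PR):

    import tensorflow as tf

    VOCAB_SIZE = 10

    def token_probability_fn(inputs, mask):
        # Replace with a real model! Uniform pseudo-probabilities of shape
        # (batch_size, sequence_length, vocab_size) are enough to exercise
        # the sampler's control flow.
        batch_size = tf.shape(inputs)[0]
        seq_length = tf.shape(inputs)[1]
        return tf.random.uniform((batch_size, seq_length, VOCAB_SIZE))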

Contributor Author: I would like to keep the model part so that the example is closer to real use cases.

I will unify the docstring to move those hypers inline.

Member: To me the model falls into an "uncanny valley" of code examples. It's not something that will actually work, yet also not clearly random dummy data. As a newbie I worry I would not understand, first, that results will be random, and second, that this is a "bad model" for the task.

Fine to merge as is, but I hope we can play around with some improvements down the road.
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16
    START_ID = 1

    # Create a dummy model to predict the next token.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability for each
    # token in the input sequence.
    def token_probability_fn(inputs, mask):
        return model(inputs)

    prompt = tf.fill((BATCH_SIZE, 1), START_ID)

    sampler = keras_nlp.samplers.BeamSampler(num_beams=3)
    # Print the generated sequence (token ids).
    print(sampler(prompt, token_probability_fn, 10))
    ```
    """

def __init__(
self,
num_beams=5,
jit_compile=True,
run_eagerly=False,
):
self.num_beams = num_beams
super().__init__(jit_compile, run_eagerly)

@format_docstring(sample_args=sample_args_docstring)
def sample(
self, prompt, token_probability_fn, mask, num_steps, from_logits=True
):
"""Sampling logic implementation.

Args:
{{sample_args}}
"""
batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
max_length = tf.cast(max_length, num_steps.dtype)
length = max_length - num_steps
dummy_preds = token_probability_fn(prompt, mask=mask)
vocab_size = tf.shape(dummy_preds)[-1]
pred_dtype = dummy_preds.dtype

num_beams = self.num_beams

# Initialize beam with shape `(batch_size, num_beams, length)`.
beams = tf.repeat(tf.expand_dims(prompt, axis=1), num_beams, axis=1)
# Initialize `beams_prob` with shape `(batch_size, num_beams)`.
beams_prob = tf.zeros([batch_size, 1], dtype=pred_dtype)
beams_prob = tf.concat(
[beams_prob, tf.fill((batch_size, num_beams - 1), pred_dtype.min)],
axis=-1,
)

def one_step(beams, beams_prob, length, mask):

flattened_beams = tf.reshape(
beams, shape=[batch_size * num_beams, -1]
)
repeated_mask = tf.tile(mask, [num_beams, 1])
probs = token_probability_fn(flattened_beams, repeated_mask)
preds = tf.gather(
probs,
tf.repeat(length - 1, batch_size * num_beams),
axis=1,
batch_dims=1,
)
if from_logits:
preds = keras.activations.softmax(preds, axis=-1)
# Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`.

preds = tf.reshape(preds, shape=[batch_size, -1])

cum_probs = tf.math.log(preds) + tf.repeat(
beams_prob, repeats=vocab_size, axis=1
)

candidate_prob, candidate_indexes = tf.math.top_k(
cum_probs, k=num_beams, sorted=False
)

candidate_beam_indexes = candidate_indexes // vocab_size
next_token = candidate_indexes % vocab_size

beams = tf.gather(
beams, candidate_beam_indexes, axis=1, batch_dims=1
)

# Build a new column of updates to scatter into the beam tensor.
next_token = tf.where(
condition=mask[..., length, tf.newaxis],
x=beams[..., length],
y=next_token,
)
next_token = tf.reshape(next_token, shape=[-1])

mask = tf.tensor_scatter_nd_update(
tensor=mask,
indices=tf.stack(
(
tf.cast(tf.range(batch_size), dtype=length.dtype),
tf.repeat(length, batch_size),
),
axis=1,
),
updates=tf.repeat(True, batch_size),
)

# Generate `(batch_index, beam_index)` tuples for each beam.
beam_indices = tf.where(tf.ones((batch_size, num_beams), tf.bool))
beam_indices = tf.cast(beam_indices, dtype=length.dtype)
# Build a tensor of repeated `length` values.
length_indices = tf.fill((batch_size * num_beams, 1), length)
# Concatenate to a triplet of `(batch_index, beam_index, length)`.
indices = tf.concat([beam_indices, length_indices], axis=-1)

# Update `beams[:, :, length]` with `next_token`.
beams = tf.tensor_scatter_nd_update(
tensor=beams,
indices=indices,
updates=next_token,
)

beams_prob = candidate_prob

length = tf.add(length, 1)
return beams, beams_prob, length, mask

# Run a while loop till text of length `max_length` has been generated.
beams, beams_prob, length, mask = tf.while_loop(
cond=lambda beams, beams_prob, length, mask: tf.less(
length, max_length
),
body=one_step,
loop_vars=[beams, beams_prob, length, mask],
# There is a strange issue that when `batch_size=1`, the first loop
# iteration changes `beams_prob`'s shape from [1, None] to
# [None, None], which does not happen for `batch_size>1`.
# As a workaround, we set shape invariants.
shape_invariants=[
beams.get_shape(),
tf.TensorShape([None, None]),
length.get_shape(),
mask.get_shape(),
],
)

# Get the beam with the maximum probability.
max_indexes = tf.math.argmax(beams_prob, axis=-1)
max_beams = tf.gather(
beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
)

prompt = tf.squeeze(max_beams, axis=1)
Member: return immediately

        return prompt
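
The heart of `one_step` above is the flatten-and-top-k trick: each beam's accumulated log-probability is broadcast-added to the log-probabilities of all its candidate next tokens, the result is flattened to `(batch_size, num_beams * vocab_size)`, and a single `tf.math.top_k` call ranks every (beam, token) pair jointly. A standalone sketch with toy numbers (values are illustrative only, not from the PR):

```python
import tensorflow as tf

# Toy setup: batch_size=1, num_beams=2, vocab_size=3.
next_log_probs = tf.math.log(
    tf.constant([[[0.5, 0.3, 0.2],    # next-token probs for beam 0
                  [0.1, 0.1, 0.8]]])  # next-token probs for beam 1
)
beams_prob = tf.constant([[0.0, -1.0]])  # accumulated log-prob per beam

# Add each beam's score to its candidates, then flatten so `top_k`
# ranks all `num_beams * vocab_size` continuations at once.
cum_probs = tf.reshape(next_log_probs + beams_prob[..., tf.newaxis], [1, -1])
candidate_prob, candidate_indexes = tf.math.top_k(cum_probs, k=2)

# Recover the source beam and chosen token, exactly as `one_step` does.
vocab_size = 3
print(candidate_indexes // vocab_size)  # -> [[0 0]]: both survivors extend beam 0
print(candidate_indexes % vocab_size)   # -> [[0 1]]: with next tokens 0 and 1
```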
137 changes: 137 additions & 0 deletions keras_nlp/samplers/beam_sampler_test.py
@@ -0,0 +1,137 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Beam sampler."""

import random

import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.samplers.beam_sampler import BeamSampler
from keras_nlp.samplers.greedy_sampler import GreedySampler


class BeamSamplerTest(tf.test.TestCase, parameterized.TestCase):
    def setUp(self):
        super().setUp()
        self.vocab_size = 10
        self.feature_size = 16

        # Create a dummy model to predict the next token.
        model = keras.Sequential(
            [
                keras.Input(shape=[None]),
                keras.layers.Embedding(
                    input_dim=self.vocab_size,
                    output_dim=self.feature_size,
                ),
                keras.layers.Dense(self.vocab_size),
                keras.layers.Softmax(),
            ]
        )

        def token_probability_fn(inputs, mask):
            return model(inputs)

        self.token_probability_fn = token_probability_fn
        self.sampler = BeamSampler(num_beams=2)

    def test_generate_with_1d_prompt(self):
        inputs = tf.constant([1])
        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
        self.assertEqual(outputs.shape, [5])

    def test_generate_with_2d_prompt(self):
        inputs = tf.constant([[1], [1]])
        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
        self.assertEqual(outputs.shape, [2, 5])

    def test_generate_with_list_prompt(self):
        inputs = [[1], [1]]
        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
        self.assertEqual(outputs.shape, [2, 5])

    def test_generate_with_ragged_prompt(self):
        def token_probability_fn(inputs, mask):
            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
            return tf.tile(prob, [batch_size, seq_length, 1])

        inputs = tf.ragged.constant([[1], [2, 1, 2]])
        outputs = self.sampler(inputs, token_probability_fn, max_length=5)
        self.assertEqual(outputs.shape, [2, 5])

    def test_one_beam_generation(self):
        for _ in range(5):
            inputs = tf.constant([random.randint(0, 9)])
            beam_sampler = BeamSampler(num_beams=1)
            greedy_sampler = GreedySampler()
            beam_output = beam_sampler(
                inputs,
                self.token_probability_fn,
                max_length=5,
            )
            greedy_output = greedy_sampler(
                inputs,
                self.token_probability_fn,
                max_length=5,
            )
            self.assertAllEqual(beam_output, greedy_output)

    @parameterized.named_parameters(
        ("xla_graph", True, False),
        ("non_xla_graph", False, False),
        ("eager", False, True),
    )
    def test_assert_generation_is_correct(self, jit_compile, run_eagerly):
        def token_probability_fn(inputs, mask):
            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
            return tf.tile(prob, [batch_size, seq_length, 1])

        batch_size = 10
        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
        max_length = 3
        for i in range(1, 5):
            sampler = BeamSampler(
                num_beams=i,
                jit_compile=jit_compile,
                run_eagerly=run_eagerly,
            )
            outputs = sampler(
                inputs,
                token_probability_fn,
                max_length=max_length,
            )
            self.assertAllEqual(
                outputs, 3 * tf.ones(shape=[batch_size, max_length])
            )

    def test_end_token_id(self):
        def token_probability_fn(inputs, mask):
            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
            return tf.tile(prob, [batch_size, seq_length, 1])

        max_length = 5
        inputs = tf.constant([[0, 1], [1, 2]])
        outputs = self.sampler(
            inputs,
            token_probability_fn,
            max_length=max_length,
            end_token_id=2,
        )
        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
        self.assertAllEqual(outputs, expected_outputs)