Skip to content

Commit

Permalink
Add RandomSampler to Samplers (#952)
Browse files Browse the repository at this point in the history
* Add RandomSampler to Samplers

* Fix the example of docstrings of all samplers

* Keep the alphabetical order

* Remove mixing commits

* Edit docstring and code format
  • Loading branch information
abuelnasr0 authored Apr 4, 2023
1 parent 48c649e commit 9dc8b58
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 0 deletions.
2 changes: 2 additions & 0 deletions keras_nlp/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from keras_nlp.samplers.beam_sampler import BeamSampler
from keras_nlp.samplers.greedy_sampler import GreedySampler
from keras_nlp.samplers.random_sampler import RandomSampler
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.top_k_sampler import TopKSampler
from keras_nlp.samplers.top_p_sampler import TopPSampler
Expand All @@ -30,6 +31,7 @@ def deserialize(config, custom_objects=None):
all_classes = {
"beam": BeamSampler,
"greedy": GreedySampler,
"random": RandomSampler,
"top_k": TopKSampler,
"top_p": TopPSampler,
}
Expand Down
82 changes: 82 additions & 0 deletions keras_nlp/samplers/random_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Random Sampler."""

import tensorflow as tf

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(call_args=call_args_docstring)
@keras_nlp_export("keras_nlp.samplers.RandomSampler")
class RandomSampler(Sampler):
    """Random Sampler class.

    This sampler implements random sampling: the next token is drawn at
    random from the entire token distribution, where the chance of picking a
    token is given by that token's probability.

    Args:
        seed: int, defaults to None. The random seed.

    Call Args:
        {{call_args}}

    Examples:
    ```python
    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
    char_lookup = {v: k for k, v in int_lookup.items()}
    batch_size, length, vocab_size = 1, 12, len(int_lookup)

    def next(prompt, state, index):
        # A uniform distribution over our alphabet.
        logits = tf.ones((batch_size, vocab_size))
        return logits, state

    output = keras_nlp.samplers.RandomSampler()(
        next=next,
        prompt=tf.fill((batch_size, length,), char_lookup['z']),
        index=5,
    )
    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
    # >>> ['zzzzzcpnjqij']
    ```
    """

    def __init__(self, seed=None):
        super().__init__()
        # Stored so it can be forwarded to the TF sampling op and serialized
        # by `get_config()`.
        self.seed = seed

    def get_next_token(self, probabilities):
        """Draw one token id per row of `probabilities`.

        Args:
            probabilities: tensor of per-token probabilities (presumably of
                shape `(batch_size, vocab_size)` -- confirm against the base
                `Sampler`).

        Returns:
            An int32 tensor of sampled token ids with the trailing
            sample axis squeezed out.
        """
        # `tf.random.categorical` is fed log-space values, so convert the
        # probabilities before sampling.
        log_probabilities = tf.math.log(probabilities)
        sampled_ids = tf.random.categorical(
            log_probabilities, 1, seed=self.seed, dtype="int32"
        )
        # One sample was requested per row; drop that singleton axis.
        return tf.squeeze(sampled_ids, axis=-1)

    def get_config(self):
        """Return the parent config extended with this sampler's `seed`."""
        config = super().get_config()
        config.update(seed=self.seed)
        return config
95 changes: 95 additions & 0 deletions keras_nlp/samplers/random_sampler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Random sampler."""

import numpy as np
import tensorflow as tf
from absl.testing import parameterized

from keras_nlp.samplers.random_sampler import RandomSampler


class RandomSamplerTest(tf.test.TestCase, parameterized.TestCase):
    """Tests for `RandomSampler` using near-deterministic distributions.

    Every `next` function used here puts a logit of 1e9 on a single token,
    so the tests assert exact output strings.
    """

    def setUp(self):
        super().setUp()
        # Map token ids [0, 25] to lowercase letters, and letters back to ids.
        self.int_lookup = {idx: chr(ord("a") + idx) for idx in range(26)}
        self.char_lookup = {
            char: idx for idx, char in self.int_lookup.items()
        }
        self.batch_size, self.length = 1, 12
        self.vocab_size = len(self.int_lookup)

        def next_from_state(prompt, state, index):
            # Return a distribution favoring the next char in state.
            favored = tf.one_hot(state[:, index], self.vocab_size)
            return favored * 1e9, state

        self.next = next_from_state
        self.sampler = RandomSampler()

    def _z_prompt(self):
        """Build a `(batch_size, length)` prompt filled with the id of "z"."""
        shape = (self.batch_size, self.length)
        return tf.fill(shape, self.char_lookup["z"])

    def _sequential_state(self):
        """Build a single-row state holding the ids of "sequentially"."""
        token_ids = [self.char_lookup[c] for c in list("sequentially")]
        return tf.constant([token_ids])

    def join_as_string(self, x):
        """Decode each row of the id tensor `x` into a string."""
        decoded = []
        for row in x.numpy():
            decoded.append("".join([self.int_lookup[i] for i in row]))
        return decoded

    def test_stateless_call(self):
        def next_first_token(prompt, state, index):
            # Return a distribution favoring the first token in the vocab.
            logits = np.zeros((self.batch_size, self.vocab_size))
            logits[:, 0] = 1e9
            return tf.constant(logits), state

        output = self.sampler(
            next=next_first_token,
            prompt=self._z_prompt(),
            index=5,
        )
        # Positions before index 5 keep the prompt; the rest become "a".
        self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"])

    def test_stateful_call(self):
        state = self._sequential_state()
        prompt = self._z_prompt()
        output = self.sampler(
            next=self.next,
            prompt=prompt,
            state=state,
        )
        self.assertEqual(self.join_as_string(output), ["sequentially"])

    def test_early_stopping(self):
        state = self._sequential_state()
        prompt = self._z_prompt()
        output = self.sampler(
            next=self.next,
            prompt=prompt,
            state=state,
            end_token_id=self.char_lookup["t"],
        )
        # Everything after the end token "t" is left as the prompt's "z".
        self.assertEqual(self.join_as_string(output), ["sequentzzzzz"])

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_compilation(self, jit_compile):
        state = self._sequential_state()
        prompt = self._z_prompt()

        # Run the sampler inside a `tf.function`, with and without
        # `jit_compile`.
        @tf.function(jit_compile=jit_compile)
        def generate(prompt, state):
            return self.sampler(self.next, prompt=prompt, state=state)

        output = generate(prompt, state)
        self.assertEqual(self.join_as_string(output), ["sequentially"])

0 comments on commit 9dc8b58

Please sign in to comment.