-
Notifications
You must be signed in to change notification settings - Fork 251
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RandomSampler to Samplers (#952)
* Add RandomSampler to Samplers
* Fix the example of docstrings of all samplers
* Keep the alphabetical order
* Remove mixing commits
* Edit docstring and code format
- Loading branch information
1 parent
48c649e
commit 9dc8b58
Showing
3 changed files
with
179 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Random Sampler.""" | ||
|
||
import tensorflow as tf | ||
|
||
from keras_nlp.api_export import keras_nlp_export | ||
from keras_nlp.samplers.sampler import Sampler | ||
from keras_nlp.samplers.sampler import call_args_docstring | ||
from keras_nlp.utils.python_utils import format_docstring | ||
|
||
|
||
@format_docstring(call_args=call_args_docstring)
@keras_nlp_export("keras_nlp.samplers.RandomSampler")
class RandomSampler(Sampler):
    """Sampler that draws each next token at random.

    Every token in the distribution is a candidate for selection; the chance
    of a given token being picked is determined by its probability.

    Args:
        seed: int, defaults to None. The random seed, forwarded to
            `tf.random.categorical` on every draw.

    Call Args:
        {{call_args}}

    Examples:
    ```python
    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
    char_lookup = {v: k for k, v in int_lookup.items()}
    batch_size, length, vocab_size = 1, 12, len(int_lookup)

    def next(prompt, state, index):
        # A uniform distribution over our alphabet.
        logits = tf.ones((batch_size, vocab_size))
        return logits, state

    output = keras_nlp.samplers.RandomSampler()(
        next=next,
        prompt=tf.fill((batch_size, length,), char_lookup['z']),
        index=5,
    )
    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
    # >>> ['zzzzzcpnjqij']  (the sampled tail varies between runs)
    ```
    """

    def __init__(self, seed=None):
        super().__init__()
        # Kept on the instance so it can be reused for sampling and reported
        # by `get_config()`.
        self.seed = seed

    def get_next_token(self, probabilities):
        """Draw one token id per row of `probabilities`.

        Args:
            probabilities: tensor of token probabilities; its log is handed to
                `tf.random.categorical` as the logits.

        Returns:
            The sampled int32 token ids with the trailing sample axis removed.
        """
        # `tf.random.categorical` takes logits, so move to log space first.
        log_probabilities = tf.math.log(probabilities)
        sampled = tf.random.categorical(
            log_probabilities, 1, seed=self.seed, dtype="int32"
        )
        # Exactly one sample is drawn per row; drop that size-1 axis.
        return tf.squeeze(sampled, axis=-1)

    def get_config(self):
        """Return the parent config extended with this sampler's `seed`."""
        config = super().get_config()
        config["seed"] = self.seed
        return config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for Random sampler.""" | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from absl.testing import parameterized | ||
|
||
from keras_nlp.samplers.random_sampler import RandomSampler | ||
|
||
|
||
class RandomSamplerTest(tf.test.TestCase, parameterized.TestCase):
    """Tests for `RandomSampler` using heavily peaked logits.

    Each `next` function used here puts a logit of 1e9 on a single token, so
    the tests can assert an exact output string.
    """

    def setUp(self):
        super().setUp()
        # Map the ids 0..25 onto the lowercase letters 'a'..'z' and back.
        self.int_lookup = {idx: chr(ord("a") + idx) for idx in range(26)}
        self.char_lookup = {
            char: idx for idx, char in self.int_lookup.items()
        }
        self.batch_size = 1
        self.length = 12
        self.vocab_size = len(self.int_lookup)

        def next_from_state(prompt, state, index):
            # Return a distribution favoring the next char in state.
            favored = tf.one_hot(state[:, index], self.vocab_size)
            return favored * 1e9, state

        self.next = next_from_state
        self.sampler = RandomSampler()

    def _z_prompt(self):
        """Build a (batch_size, length) prompt filled with the id of 'z'."""
        return tf.fill((self.batch_size, self.length), self.char_lookup["z"])

    def _sequentially_state(self):
        """Build a one-row state tensor spelling "sequentially" as ids."""
        return tf.constant([[self.char_lookup[c] for c in "sequentially"]])

    def join_as_string(self, x):
        """Decode each row of the id tensor `x` into a string."""
        decoded = []
        for row in x.numpy():
            decoded.append("".join(self.int_lookup[token] for token in row))
        return decoded

    def test_stateless_call(self):
        def next(prompt, state, index):
            # Return a distribution favoring the first token in the vocab.
            peaked = np.zeros((self.batch_size, self.vocab_size))
            peaked[:, 0] = 1e9
            return tf.constant(peaked), state

        output = self.sampler(
            next=next,
            prompt=self._z_prompt(),
            index=5,
        )
        # The first five positions keep the prompt; the rest become 'a'.
        self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"])

    def test_stateful_call(self):
        output = self.sampler(
            next=self.next,
            prompt=self._z_prompt(),
            state=self._sequentially_state(),
        )
        self.assertEqual(self.join_as_string(output), ["sequentially"])

    def test_early_stopping(self):
        output = self.sampler(
            next=self.next,
            prompt=self._z_prompt(),
            state=self._sequentially_state(),
            end_token_id=self.char_lookup["t"],
        )
        # Positions after the end token 't' keep the prompt's 'z' fill.
        self.assertEqual(self.join_as_string(output), ["sequentzzzzz"])

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_compilation(self, jit_compile):
        @tf.function(jit_compile=jit_compile)
        def generate(prompt, state):
            return self.sampler(self.next, prompt=prompt, state=state)

        output = generate(self._z_prompt(), self._sequentially_state())
        self.assertEqual(self.join_as_string(output), ["sequentially"])