* Add a preprocessor for the Llama backbone
* Add causal lm preprocessor for the Llama backbone
Commit e81daa0, 1 parent a59a26f
Showing 4 changed files with 523 additions and 0 deletions.
185 changes: 185 additions & 0 deletions
keras_nlp/models/llama/llama_causal_lm_preprocessor.py
@@ -0,0 +1,185 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
from keras_nlp.utils.keras_utils import (
    convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.LlamaCausalLMPreprocessor")
class LlamaCausalLMPreprocessor(LlamaPreprocessor):
    """Llama Causal LM preprocessor.

    This preprocessing layer is meant for use with
    `keras_nlp.models.LlamaCausalLM`. By default, it will take in batches of
    strings, and return outputs in a `(x, y, sample_weight)` format, where the
    `y` label is the next token id in the `x` sequence.

    For use with generation, the layer also exposes two methods
    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
    is attached to a `keras_nlp.models.LlamaCausalLM` instance, these methods
    will be called implicitly in `generate()`. They can also be called
    standalone (e.g. to precompute preprocessing inputs for generation in a
    separate process).

    Args:
        tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
        sequence_length: The length of the packed inputs.
        add_start_token: If `True`, the preprocessor will prepend the tokenizer
            start token to each input sequence. Default is `True`.
        add_end_token: If `True`, the preprocessor will append the tokenizer
            end token to each input sequence. Default is `False`.

    Call arguments:
        x: A string, `tf.Tensor` or list of python strings.
        y: Label data. Should always be `None` as the layer generates labels.
        sample_weight: Label weights. Should always be `None` as the layer
            generates label weights.
        sequence_length: Pass to override the configured `sequence_length` of
            the layer.

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset(
        "llama_base_en"
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("League of legends")
    preprocessor(sentence)
    # Same output.
    preprocessor("League of legends")

    # Tokenize a batch of sentences.
    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
    preprocessor(sentences)
    # Same output.
    preprocessor(["Taco tuesday", "Fish taco please!"])

    # Map a dataset to preprocess a single sentence.
    features = tf.constant(
        [
            "Avatar 2 is amazing!",
            "Well, I am not sure.",
        ]
    )
    labels = tf.constant([1, 0])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentences.
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """

    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        if y is not None or sample_weight is not None:
            logging.warning(
                "`LlamaCausalLMPreprocessor` generates `y` and "
                "`sample_weight` based on your input data, but your data "
                "already contains `y` or `sample_weight`. Your `y` and "
                "`sample_weight` will be ignored."
            )
        sequence_length = sequence_length or self.sequence_length

        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        # Pad with one extra token to account for the truncation below.
        token_ids, padding_mask = self.packer(
            x,
            sequence_length=sequence_length + 1,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        # The last token does not have a next token, so we truncate it out.
        x = {
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
        }
        # Target `y` will be the next token.
        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)

    def generate_preprocess(
        self,
        x,
        sequence_length=None,
    ):
        """Convert strings to integer token input for generation.

        Similar to calling the layer for training, this method takes in strings
        or tensor strings, tokenizes and packs the input, and computes a padding
        mask masking all inputs not filled in with a padded value.

        Unlike calling the layer for training, this method does not compute
        labels and will never append a `tokenizer.end_token_id` to the end of
        the sequence (as generation is expected to continue at the end of the
        inputted prompt).
        """
        if not self.built:
            self.build(None)

        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        token_ids, padding_mask = self.packer(
            x, sequence_length=sequence_length, add_end_value=False
        )
        return {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

    def generate_postprocess(
        self,
        x,
    ):
        """Convert integer token output to strings for generation.

        This method reverses `generate_preprocess()`, by first removing all
        padding and start/end tokens, and then converting the integer sequence
        back to a string.
        """
        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        # Convert the inputs to numpy arrays if they aren't a tensor already.
        if not isinstance(token_ids, tf.Tensor):
            token_ids = ops.convert_to_numpy(token_ids)
            # Make sure the numpy array has type `int32` since
            # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
            token_ids = token_ids.astype("int32")
        if not isinstance(padding_mask, tf.Tensor):
            padding_mask = ops.convert_to_numpy(padding_mask)
            padding_mask = padding_mask.astype("bool")
        # Strip any special tokens during detokenization (e.g. the start and
        # end markers). In the future we could make this configurable.
        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
        padding_mask = padding_mask & (
            token_ids != self.tokenizer.start_token_id
        )
        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
        return self.tokenizer.detokenize(token_ids)
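The shift that `call()` performs is easiest to see on concrete values. Below is a minimal NumPy sketch of that slicing, using the same toy token ids the test file below expects ([1, 3, 8, 4, 6] packed to `sequence_length + 1 = 9`); the hand-written `packed_ids`/`packed_mask` arrays stand in for the real tokenizer and packer output, so treat it as an illustration rather than the layer itself.

```python
import numpy as np

# Stand-in for the packer output: start token (1), the token ids for
# "the quick brown fox", then zero padding out to sequence_length + 1 = 9.
packed_ids = np.array([1, 3, 8, 4, 6, 0, 0, 0, 0])
packed_mask = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0])

# Mirror of `call()`: the model input drops the final position...
x = {"token_ids": packed_ids[:-1], "padding_mask": packed_mask[:-1]}
# ...and the labels/sample weights are the same sequence shifted left by one,
# so each position is trained to predict the token that follows it.
y, sample_weight = packed_ids[1:], packed_mask[1:]

print(x["token_ids"])  # [1 3 8 4 6 0 0 0]
print(y)               # [3 8 4 6 0 0 0 0]
print(sample_weight)   # [1 1 1 1 0 0 0 0]
```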
90 changes: 90 additions & 0 deletions
keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py
@@ -0,0 +1,90 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from keras_nlp.models.llama.llama_causal_lm_preprocessor import (
    LlamaCausalLMPreprocessor,
)
from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
from keras_nlp.tests.test_case import TestCase


class LlamaCausalLMPreprocessorTest(TestCase):
    def setUp(self):
        self.tokenizer = LlamaTokenizer(
            # Generated using create_llama_test_proto.py
            proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm")
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = (["the quick brown fox"],)

    def test_preprocessor_basics(self):
        self.run_preprocessor_test(
            cls=LlamaCausalLMPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=(
                {
                    "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]],
                    "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
                },
                [[3, 8, 4, 6, 0, 0, 0, 0]],  # Pass through labels.
                [[1, 1, 1, 1, 0, 0, 0, 0]],  # Pass through sample_weights.
            ),
        )

    def test_no_start_end_token(self):
        input_data = ["the quick brown fox"] * 4

        preprocessor = LlamaCausalLMPreprocessor(
            **self.init_kwargs,
            add_start_token=False,
            add_end_token=False,
        )
        x, y, sw = preprocessor(input_data)
        self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4)
        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
        self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4)
        self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4)

    def test_generate_preprocess(self):
        input_data = "the quick brown fox"
        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_preprocess(input_data)
        self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0])
        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])

    def test_generate_postprocess(self):
        input_data = {
            "token_ids": [1, 3, 8, 4, 6, 0, 0, 0],
            "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
        }
        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_postprocess(input_data)
        self.assertAllEqual(x, "the quick brown fox")

    @pytest.mark.extra_large
    def test_all_presets(self):
        for preset in LlamaCausalLMPreprocessor.presets:
            self.run_preset_test(
                cls=LlamaCausalLMPreprocessor,
                preset=preset,
                input_data=self.input_data,
            )
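Putting the two generation hooks together, the sketch below walks the round trip that `test_generate_preprocess` and `test_generate_postprocess` cover. It assumes the test SentencePiece proto (`llama_test_vocab.spm`) is reachable at a hypothetical local path and that the new classes are exported under `keras_nlp.models` as the decorator above indicates; with a real preset, `LlamaCausalLM.generate()` would call both methods implicitly.

```python
import keras_nlp

# Hypothetical local path to the test vocab generated by
# create_llama_test_proto.py; substitute your own proto or preset.
tokenizer = keras_nlp.models.LlamaTokenizer(proto="llama_test_vocab.spm")
preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor(
    tokenizer=tokenizer,
    sequence_length=8,
)

# Strings -> padded token ids and mask. No end token is appended, so a model
# can keep generating from the end of the prompt.
inputs = preprocessor.generate_preprocess("the quick brown fox")
# {"token_ids": [1, 3, 8, 4, 6, 0, 0, 0], "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0]}

# Token ids (here just the prompt itself) -> strings, with padding and the
# start/end markers stripped before detokenization.
print(preprocessor.generate_postprocess(inputs))  # "the quick brown fox"
```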