Add causal lm preprocessor for the Llama backbone
tirthasheshpatel committed Mar 13, 2024
1 parent ba5913a commit 0738f3e
Showing 2 changed files with 275 additions and 0 deletions.
185 changes: 185 additions & 0 deletions keras_nlp/models/llama/llama_causal_lm_preprocessor.py
@@ -0,0 +1,185 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.LlamaCausalLMPreprocessor")
class LlamaCausalLMPreprocessor(LlamaPreprocessor):
"""Llama Causal LM preprocessor.
This preprocessing layer is meant for use with
`keras_nlp.models.LlamaCausalLM`. By default, it will take in batches of
strings, and return outputs in a `(x, y, sample_weight)` format, where the
`y` label is the next token id in the `x` sequence.
For use with generation, the layer also exposes two methods
`generate_preprocess()` and `generate_postprocess()`. When this preprocessor
is attached to a `keras_nlp.models.LlamaCausalLM` instance, these methods
will be called implicitly in `generate()`. They can also be called
standalone (e.g. to precompute preprocessing inputs for generation in a
separate process).
Args:
tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
sequence_length: The length of the packed inputs.
add_start_token: If `True`, the preprocessor will prepend the tokenizer
start token to each input sequence. Default is `True`.
add_end_token: If `True`, the preprocessor will append the tokenizer
end token to each input sequence. Default is `False`.
Call arguments:
x: A string, `tf.Tensor` or list of python strings.
y: Label data. Should always be `None` as the layer generates labels.
sample_weight: Label weights. Should always be `None` as the layer
generates label weights.
sequence_length: Pass to override the configured `sequence_length` of
the layer.
Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset(
"llama_base_en"
)
# Tokenize and pack a single sentence.
sentence = tf.constant("League of legends")
preprocessor(sentence)
# Same output.
preprocessor("League of legends")
# Tokenize a batch of sentences.
sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
preprocessor(sentences)
# Same output.
preprocessor(["Taco tuesday", "Fish taco please!"])
# Map a dataset to preprocess a single sentence.
features = tf.constant(
[
"Avatar 2 is amazing!",
"Well, I am not sure.",
]
)
labels = tf.constant([1, 0])
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Map a dataset to preprocess unlabled sentences.
ds = tf.data.Dataset.from_tensor_slices(features)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
```
"""

def call(
self,
x,
y=None,
sample_weight=None,
sequence_length=None,
):
if y is not None or sample_weight is not None:
logging.warning(
"`LlamaCausalLMPreprocessor` generates `y` and "
"`sample_weight` based on your input data, but your data "
"already contains `y` or `sample_weight`. Your `y` and "
"`sample_weight` will be ignored."
)
sequence_length = sequence_length or self.sequence_length

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
# Pad with one extra token to account for the truncation below.
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length + 1,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)

def generate_preprocess(
self,
x,
sequence_length=None,
):
"""Convert strings to integer token input for generation.
Similar to calling the layer for training, this method takes in strings
or tensor strings, tokenizes and packs the input, and computes a padding
mask masking all inputs not filled in with a padded value.
Unlike calling the layer for training, this method does not compute
labels and will never append a `tokenizer.end_token_id` to the end of
the sequence (as generation is expected to continue at the end of the
inputted prompt).
"""
if not self.built:
self.build(None)

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
token_ids, padding_mask = self.packer(
x, sequence_length=sequence_length, add_end_value=False
)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}

def generate_postprocess(
self,
x,
):
"""Convert integer token output to strings for generation.
This method reverses `generate_preprocess()`, by first removing all
padding and start/end tokens, and then converting the integer sequence
back to a string.
"""
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# Convert the inputs to numpy arrays if they aren't a tensor already.
if not isinstance(token_ids, tf.Tensor):
token_ids = ops.convert_to_numpy(token_ids)
# Make sure the numpy array has type `int32` since
# `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
token_ids = token_ids.astype("int32")
if not isinstance(padding_mask, tf.Tensor):
padding_mask = ops.convert_to_numpy(padding_mask)
padding_mask = padding_mask.astype("bool")
# Strip any special tokens during detokenization (e.g. the start and
# end markers). In the future we could make this configurable.
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
padding_mask = padding_mask & (
token_ids != self.tokenizer.start_token_id
)
token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
return self.tokenizer.detokenize(token_ids)
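
To make the shift-by-one labeling in `call()` above concrete, here is a small standalone sketch (not part of the commit) using plain NumPy. The toy token ids mirror the ones used in the unit tests below and are illustrative only.

```python
import numpy as np

# Hypothetical packer output for sequence_length + 1 = 9 positions,
# mirroring the ids used in the unit tests below.
token_ids = np.array([[1, 3, 8, 4, 6, 0, 0, 0, 0]])
padding_mask = np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0]])

# The last position has no "next token" to predict, so it is dropped
# from the model inputs.
x = {
    "token_ids": token_ids[..., :-1],
    "padding_mask": padding_mask[..., :-1],
}
# Labels are the same ids shifted left by one; sample weights reuse the
# shifted padding mask so no loss is computed on padding positions.
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]

print(x["token_ids"])   # [[1 3 8 4 6 0 0 0]]
print(y)                # [[3 8 4 6 0 0 0 0]]
print(sample_weight)    # [[1 1 1 1 0 0 0 0]]
```
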
90 changes: 90 additions & 0 deletions keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py
@@ -0,0 +1,90 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from keras_nlp.models.llama.llama_causal_lm_preprocessor import (
LlamaCausalLMPreprocessor,
)
from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
from keras_nlp.tests.test_case import TestCase


class LlamaCausalLMPreprocessorTest(TestCase):
def setUp(self):
self.tokenizer = LlamaTokenizer(
# Generated using create_llama_test_proto.py
proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm")
)
self.init_kwargs = {
"tokenizer": self.tokenizer,
"sequence_length": 8,
}
self.input_data = (["the quick brown fox"],)

def test_preprocessor_basics(self):
self.run_preprocessor_test(
cls=LlamaCausalLMPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output=(
{
"token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
},
                [[3, 8, 4, 6, 0, 0, 0, 0]],  # Generated labels (next tokens).
                [[1, 1, 1, 1, 0, 0, 0, 0]],  # Generated sample weights.
),
)

def test_no_start_end_token(self):
input_data = ["the quick brown fox"] * 4

preprocessor = LlamaCausalLMPreprocessor(
**self.init_kwargs,
add_start_token=False,
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4)

def test_generate_preprocess(self):
input_data = "the quick brown fox"
preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])

def test_generate_postprocess(self):
input_data = {
"token_ids": [1, 3, 8, 4, 6, 0, 0, 0],
"padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
}
preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_postprocess(input_data)
self.assertAllEqual(x, "the quick brown fox")

@pytest.mark.extra_large
def test_all_presets(self):
for preset in LlamaCausalLMPreprocessor.presets:
self.run_preset_test(
cls=LlamaCausalLMPreprocessor,
preset=preset,
input_data=self.input_data,
)
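
For intuition on what the `test_generate_postprocess` case above exercises, here is a minimal, self-contained sketch of the special-token stripping that `generate_postprocess()` performs before detokenizing. The start/end token ids (1 and 2) are assumed values for illustration; the real ids come from the SentencePiece proto.

```python
import tensorflow as tf

# Assumed special token ids, for illustration only.
start_token_id, end_token_id = 1, 2

token_ids = tf.constant([[1, 3, 8, 4, 6, 2, 0, 0]])
padding_mask = tf.constant([[1, 1, 1, 1, 1, 1, 0, 0]], dtype=tf.bool)

# Keep only positions that are real tokens and not start/end markers.
keep = (
    padding_mask
    & (token_ids != start_token_id)
    & (token_ids != end_token_id)
)
trimmed = tf.ragged.boolean_mask(token_ids, keep)
print(trimmed)  # <tf.RaggedTensor [[3, 8, 4, 6]]>
# `tokenizer.detokenize(trimmed)` would then map the remaining ids to text.
```
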
