Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LLaMA 3 tokenizer and preset #1584

Merged
merged 10 commits into from
May 17, 2024
Prev Previous commit
Next Next commit
Add tests for all the components
  • Loading branch information
tirthasheshpatel committed May 2, 2024
commit d407d9caeb1dbbd778de107f335b97817aff27c8
94 changes: 94 additions & 0 deletions keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
Llama3CausalLMPreprocessor,
)
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_nlp.src.tests.test_case import TestCase


class Llama3CausalLMPreprocessorTest(TestCase):
    """Unit tests for `Llama3CausalLMPreprocessor` on a tiny BPE vocabulary."""

    def setUp(self):
        # A minimal BPE vocabulary and merge list sufficient to tokenize
        # "airplane at airport", plus the LLaMA 3 special tokens.
        vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
        vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
        self.vocab = {token: i for i, token in enumerate(vocab)}
        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
        self.merges += ["Ġai r", "Ġa i", "pla ne"]
        self.tokenizer = Llama3Tokenizer(
            vocabulary=self.vocab,
            merges=self.merges,
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = ["airplane at airport"]

    def test_preprocessor_basics(self):
        # Labels and sample weights are the token ids shifted left by one.
        self.run_preprocessor_test(
            cls=Llama3CausalLMPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=(
                {
                    "token_ids": [[6, 1, 3, 4, 2, 5, 0, 0]],
                    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
                },
                [[1, 3, 4, 2, 5, 0, 0, 0]],  # Pass through labels.
                [[1, 1, 1, 1, 1, 0, 0, 0]],  # Pass through sample_weights.
            ),
        )

    def test_with_start_end_token(self):
        # With both flags set, the sequence is wrapped in
        # <|begin_of_text|> (id 6) and <|end_of_text|> (id 7).
        input_data = ["airplane at airport"] * 4

        preprocessor = Llama3CausalLMPreprocessor(
            **self.init_kwargs,
            add_start_token=True,
            add_end_token=True,
        )
        x, y, sw = preprocessor(input_data)
        self.assertAllEqual(x["token_ids"], [[6, 1, 3, 4, 2, 5, 7, 0]] * 4)
        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4)
        self.assertAllEqual(y, [[1, 3, 4, 2, 5, 7, 0, 0]] * 4)
        self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)

    def test_generate_preprocess(self):
        # Generation preprocessing returns features only (no labels/weights).
        input_data = "airplane at airport"
        preprocessor = Llama3CausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_preprocess(input_data)
        self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])

    def test_generate_postprocess(self):
        # Postprocessing strips padding and special tokens and detokenizes.
        input_data = {
            "token_ids": [6, 1, 3, 4, 2, 5, 0, 0],
            "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
        }
        preprocessor = Llama3CausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_postprocess(input_data)
        self.assertAllEqual(x, "airplane at airport")

    @pytest.mark.extra_large
    def test_all_presets(self):
        # Smoke-test every registered preset end to end.
        for preset in Llama3CausalLMPreprocessor.presets:
            self.run_preset_test(
                cls=Llama3CausalLMPreprocessor,
                preset=preset,
                input_data=self.input_data,
            )
130 changes: 130 additions & 0 deletions keras_nlp/src/models/llama3/llama3_causal_lm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

import pytest

from keras_nlp.src.backend import ops
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM
from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
Llama3CausalLMPreprocessor,
)
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_nlp.src.tests.test_case import TestCase


class Llama3CausalLMTest(TestCase):
    """End-to-end tests for `Llama3CausalLM` built on a tiny backbone."""

    def setUp(self):
        # A minimal BPE vocabulary and merge list sufficient to tokenize
        # " airplane at airport", plus the LLaMA 3 special tokens.
        vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
        vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
        self.vocab = {token: i for i, token in enumerate(vocab)}
        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
        self.merges += ["Ġai r", "Ġa i", "pla ne"]
        self.preprocessor = Llama3CausalLMPreprocessor(
            Llama3Tokenizer(vocabulary=self.vocab, merges=self.merges),
            sequence_length=7,
        )
        # Deliberately tiny backbone so the tests stay fast.
        self.backbone = Llama3Backbone(
            vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
            num_layers=2,
            num_query_heads=4,
            num_key_value_heads=2,
            hidden_dim=8,
            intermediate_dim=16,
        )
        self.init_kwargs = {
            "preprocessor": self.preprocessor,
            "backbone": self.backbone,
        }
        self.train_data = ([" airplane at airport", " airplane at airport"],)
        self.input_data = self.preprocessor(*self.train_data)[0]

    def test_causal_lm_basics(self):
        # Output shape is (batch, sequence_length, vocabulary_size).
        self.run_task_test(
            cls=Llama3CausalLM,
            init_kwargs=self.init_kwargs,
            train_data=self.train_data,
            expected_output_shape=(2, 7, 8),
        )

    def test_generate(self):
        causal_lm = Llama3CausalLM(**self.init_kwargs)
        # String input.
        prompt = " airplane at airport"
        output = causal_lm.generate(prompt)
        self.assertIn(prompt, output)
        # Int tensor input.
        prompt_ids = self.preprocessor.generate_preprocess([prompt])
        causal_lm.preprocessor = None
        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
        # Assert the prompt is preserved in token id space.
        self.assertAllEqual(
            outputs["token_ids"][:, :5],
            prompt_ids["token_ids"][:, :5],
        )
        self.assertAllEqual(
            outputs["padding_mask"][:, :5],
            prompt_ids["padding_mask"][:, :5],
        )

    def test_early_stopping(self):
        causal_lm = Llama3CausalLM(**self.init_kwargs)
        call_with_cache = causal_lm.call_with_cache

        def wrapper(*args, **kwargs):
            """Modify output logits to always favor end_token_id"""
            logits, hidden_states, cache = call_with_cache(*args, **kwargs)
            index = self.preprocessor.tokenizer.end_token_id
            update = ops.ones_like(logits)[:, :, index] * 1.0e9
            update = ops.expand_dims(update, axis=-1)
            logits = ops.slice_update(logits, (0, 0, index), update)
            return logits, hidden_states, cache

        with patch.object(causal_lm, "call_with_cache", wraps=wrapper):
            prompt = [" airplane at airport", " airplane"]
            output = causal_lm.generate(prompt)
            # We should immediately abort and output the prompt.
            self.assertEqual(prompt, output)

    def test_generate_compilation(self):
        causal_lm = Llama3CausalLM(**self.init_kwargs)
        # Assert we do not recompile with successive calls.
        causal_lm.generate(" airplane at airport")
        first_fn = causal_lm.generate_function
        causal_lm.generate(" airplane at airport")
        second_fn = causal_lm.generate_function
        self.assertEqual(first_fn, second_fn)
        # Assert we do recompile after compile is called.
        causal_lm.compile(sampler="greedy")
        self.assertIsNone(causal_lm.generate_function)

    @pytest.mark.large
    def test_saved_model(self):
        # Round-trip the model through saving/loading and compare outputs.
        self.run_model_saving_test(
            cls=Llama3CausalLM,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )

    @pytest.mark.extra_large
    def test_all_presets(self):
        # Smoke-test every registered preset end to end.
        for preset in Llama3CausalLM.presets:
            self.run_preset_test(
                cls=Llama3CausalLM,
                preset=preset,
                input_data=self.input_data,
            )
12 changes: 0 additions & 12 deletions keras_nlp/src/models/llama3/llama3_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,3 @@
# Llama 3 reuses the LLaMA preprocessing pipeline; only the tokenizer differs.
@keras_nlp_export("keras_nlp.models.Llama3Preprocessor")
class Llama3Preprocessor(LlamaPreprocessor):
    """Llama 3 preprocessing layer; behavior is inherited from `LlamaPreprocessor`."""

    # Tokenizer class used when constructing this preprocessor from a preset.
    tokenizer_cls = Llama3Tokenizer

    def __init__(
        self,
        tokenizer,
        sequence_length=1024,
        add_start_token=False,
        add_end_token=False,
        **kwargs
    ):
        # NOTE(review): this override only forwards its arguments to the base
        # class unchanged — presumably kept to make the signature explicit; it
        # can likely be removed. TODO: confirm the base-class defaults match.
        super().__init__(
            tokenizer, sequence_length, add_start_token, add_end_token, **kwargs
        )
84 changes: 84 additions & 0 deletions keras_nlp/src/models/llama3/llama3_preprocessor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_nlp.src.tests.test_case import TestCase


class Llama3PreprocessorTest(TestCase):
    """Unit tests for `Llama3Preprocessor` on a tiny BPE vocabulary."""

    def setUp(self):
        # A minimal BPE vocabulary and merge list sufficient to tokenize
        # "airplane at airport". Note the special-token order here:
        # <|end_of_text|> is id 6 and <|begin_of_text|> is id 7.
        vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
        vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
        self.vocab = {token: i for i, token in enumerate(vocab)}
        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
        self.merges += ["Ġai r", "Ġa i", "pla ne"]
        self.tokenizer = Llama3Tokenizer(
            vocabulary=self.vocab,
            merges=self.merges,
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = [
            "airplane at airport",
        ]

    def test_preprocessor_basics(self):
        # Default behavior prepends only the start token (id 7).
        self.run_preprocessor_test(
            cls=Llama3Preprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output={
                "token_ids": [[7, 1, 3, 4, 2, 5, 0, 0]],
                "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
            },
        )

    def test_with_start_end_token(self):
        # With both flags set, the sequence is wrapped in
        # <|begin_of_text|> (id 7) and <|end_of_text|> (id 6).
        input_data = ["airplane at airport"] * 4

        preprocessor = Llama3Preprocessor(
            **self.init_kwargs,
            add_start_token=True,
            add_end_token=True,
        )
        x = preprocessor(input_data)
        self.assertAllEqual(x["token_ids"], [[7, 1, 3, 4, 2, 5, 6, 0]] * 4)
        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4)

    def test_sequence_length_override(self):
        # A per-call sequence_length overrides the constructor value.
        input_data = "airplane at airport"
        preprocessor = Llama3Preprocessor(**self.init_kwargs)
        x = preprocessor(input_data, sequence_length=4)
        self.assertAllEqual(x["token_ids"], [7, 1, 3, 4])

    @pytest.mark.extra_large
    def test_all_presets(self):
        # Smoke-test every registered preset end to end.
        for preset in Llama3Preprocessor.presets:
            self.run_preset_test(
                cls=Llama3Preprocessor,
                preset=preset,
                input_data=self.input_data,
            )
63 changes: 63 additions & 0 deletions keras_nlp/src/models/llama3/llama3_tokenizer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_nlp.src.tests.test_case import TestCase


class Llama3TokenizerTest(TestCase):
    """Unit tests for `Llama3Tokenizer` on a tiny BPE vocabulary."""

    def setUp(self):
        # A minimal BPE vocabulary and merge list sufficient to tokenize
        # "airplane at airport". Note the special-token order here:
        # <|end_of_text|> is id 6 and <|begin_of_text|> is id 7.
        vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
        vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
        self.vocab = {token: i for i, token in enumerate(vocab)}
        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
        self.merges += ["Ġai r", "Ġa i", "pla ne"]
        self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
        self.input_data = [
            "<|begin_of_text|>airplane at airport<|end_of_text|>",
            " airplane airport",
        ]

    def test_tokenizer_basics(self):
        # Special tokens in the input map directly to their ids.
        self.run_preprocessing_layer_test(
            cls=Llama3Tokenizer,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=[[7, 1, 3, 4, 2, 5, 6], [2, 3, 2, 5]],
        )

    def test_errors_missing_special_tokens(self):
        # A vocabulary without the LLaMA 3 special tokens must be rejected.
        with self.assertRaises(ValueError):
            Llama3Tokenizer(vocabulary={"foo": 0, "bar": 1}, merges=["fo o"])

    @pytest.mark.large
    def test_smallest_preset(self):
        self.run_preset_test(
            cls=Llama3Tokenizer,
            preset="llama3_8b_en",
            input_data=["The quick brown fox."],
            expected_output=[[791, 4062, 14198, 39935, 13]],
        )

    @pytest.mark.extra_large
    def test_all_presets(self):
        # Smoke-test every registered preset end to end.
        for preset in Llama3Tokenizer.presets:
            self.run_preset_test(
                cls=Llama3Tokenizer,
                preset=preset,
                input_data=self.input_data,
            )
Loading