diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py
index d432e61d5d..766959c988 100644
--- a/keras_nlp/api/models/__init__.py
+++ b/keras_nlp/api/models/__init__.py
@@ -128,6 +128,13 @@
     GPTNeoXPreprocessor,
 )
 from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
+from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
+from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM
+from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
+    Llama3CausalLMPreprocessor,
+)
+from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
+from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
 from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone
 from keras_nlp.src.models.llama.llama_causal_lm import LlamaCausalLM
 from keras_nlp.src.models.llama.llama_causal_lm_preprocessor import (
diff --git a/keras_nlp/src/models/llama/llama_causal_lm.py b/keras_nlp/src/models/llama/llama_causal_lm.py
index 7dc179347c..6ca3b3d77c 100644
--- a/keras_nlp/src/models/llama/llama_causal_lm.py
+++ b/keras_nlp/src/models/llama/llama_causal_lm.py
@@ -19,7 +19,6 @@
 from keras_nlp.src.models.llama.llama_causal_lm_preprocessor import (
     LlamaCausalLMPreprocessor,
 )
-from keras_nlp.src.utils.python_utils import classproperty
 from keras_nlp.src.utils.tensor_utils import any_equal
 
 
@@ -46,6 +45,9 @@ class LlamaCausalLM(CausalLM):
         should be preprocessed before calling the model.
     """
 
+    backbone_cls = LlamaBackbone
+    preprocessor_cls = LlamaCausalLMPreprocessor
+
     def __init__(self, backbone, preprocessor=None, **kwargs):
         # === Layers ===
         self.backbone = backbone
@@ -61,14 +63,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
             **kwargs,
         )
 
-    @classproperty
-    def backbone_cls(cls):
-        return LlamaBackbone
-
-    @classproperty
-    def preprocessor_cls(cls):
-        return LlamaCausalLMPreprocessor
-
     def call_with_cache(
         self,
         token_ids,
diff --git a/keras_nlp/src/models/llama/llama_tokenizer.py b/keras_nlp/src/models/llama/llama_tokenizer.py
index 74b2e89987..c8147c3836 100644
--- a/keras_nlp/src/models/llama/llama_tokenizer.py
+++ b/keras_nlp/src/models/llama/llama_tokenizer.py
@@ -11,14 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 
 from keras_nlp.src.api_export import keras_nlp_export
-from keras_nlp.src.models.llama.llama_presets import backbone_presets
 from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
     SentencePieceTokenizer,
 )
-from keras_nlp.src.utils.python_utils import classproperty
 
 
 @keras_nlp_export("keras_nlp.models.LlamaTokenizer")
@@ -85,7 +82,3 @@ def set_proto(self, proto):
         self.start_token_id = None
         self.end_token_id = None
         self.pad_token_id = None
-
-    @classproperty
-    def presets(cls):
-        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/src/models/llama3/__init__.py b/keras_nlp/src/models/llama3/__init__.py
new file mode 100644
index 0000000000..cc2764e7a7
--- /dev/null
+++ b/keras_nlp/src/models/llama3/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
+from keras_nlp.src.models.llama3.llama3_presets import backbone_presets
+from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
+from keras_nlp.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (Llama3Backbone, Llama3Tokenizer))
diff --git a/keras_nlp/src/models/llama3/llama3_backbone.py b/keras_nlp/src/models/llama3/llama3_backbone.py
new file mode 100644
index 0000000000..90b3a42a3e
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_backbone.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone
+
+
+# LLaMA 3 shares the same architecture as its predecessors,
+# so we simply create an alias for API consistency.
+@keras_nlp_export("keras_nlp.models.Llama3Backbone")
+class Llama3Backbone(LlamaBackbone):
+    """
+    The Llama 3 Transformer core architecture with hyperparameters.
+
+    This network implements a Transformer-based decoder network,
+    Llama 3, as described in the
+    [Llama 3 release](https://github.com/meta-llama/llama3).
+    It includes the embedding lookups and transformer layers.
+
+    The default constructor gives a fully customizable, randomly initialized
+    Llama 3 model with any number of layers, heads, and embedding
+    dimensions. To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        vocabulary_size (int): The size of the token vocabulary.
+        num_layers (int): The number of transformer layers.
+        num_query_heads (int): The number of query attention heads for
+            each transformer.
+        hidden_dim (int): The size of the transformer encoding and pooling layers.
+        intermediate_dim (int): The output dimension of the first Dense layer in a
+            three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads for
+            each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of the
+            sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for the
+            rotary embedding calculation. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer normalization
+            layers in the transformer decoder. Defaults to `1e-6`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights.
Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+
+    Examples:
+
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained Llama 3 decoder.
+    model = keras_nlp.models.Llama3Backbone.from_preset("llama3_8b_en")
+    model(input_data)
+
+    # Randomly initialized Llama 3 decoder with custom config.
+    model = keras_nlp.models.Llama3Backbone(
+        vocabulary_size=10,
+        hidden_dim=512,
+        num_layers=2,
+        num_query_heads=32,
+        num_key_value_heads=8,
+        intermediate_dim=1024,
+        layer_norm_epsilon=1e-6,
+        dtype="float32"
+    )
+    model(input_data)
+    ```
+    """
+
+    pass
diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm.py b/keras_nlp/src/models/llama3/llama3_causal_lm.py
new file mode 100644
index 0000000000..be61351ab2
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_causal_lm.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.llama.llama_causal_lm import LlamaCausalLM
+from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
+from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
+    Llama3CausalLMPreprocessor,
+)
+
+
+@keras_nlp_export("keras_nlp.models.Llama3CausalLM")
+class Llama3CausalLM(LlamaCausalLM):
+    """An end-to-end Llama 3 model for causal language modeling.
+
+    A causal language model (LM) predicts the next token based on previous
+    tokens. This task setup can be used to train the model unsupervised on
+    plain text input, or to autoregressively generate plain text similar to
+    the data used for training. This task can be used for pre-training or
+    fine-tuning a Llama 3 model, simply by calling `fit()`.
+
+    This model has a `generate()` method, which generates text based on a
+    prompt. The generation strategy used is controlled by an additional
+    `sampler` argument on `compile()`. You can recompile the model with
+    different `keras_nlp.samplers` objects to control the generation. By
+    default, `"top_k"` sampling will be used.
+
+    Args:
+        backbone: A `keras_nlp.models.Llama3Backbone` instance.
+        preprocessor: A `keras_nlp.models.Llama3CausalLMPreprocessor` or `None`.
+            If `None`, this model will not apply preprocessing, and inputs
+            should be preprocessed before calling the model.
+    """
+
+    backbone_cls = Llama3Backbone
+    preprocessor_cls = Llama3CausalLMPreprocessor
diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..b9ed44b527
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py
@@ -0,0 +1,185 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from absl import logging
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.backend import ops
+from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
+from keras_nlp.src.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.src.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.Llama3CausalLMPreprocessor")
+class Llama3CausalLMPreprocessor(Llama3Preprocessor):
+    """Llama 3 Causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_nlp.models.Llama3CausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+    is attached to a `keras_nlp.models.Llama3CausalLM` instance, these methods
+    will be called implicitly in `generate()`. They can also be called
+    standalone (e.g. to precompute preprocessing inputs for generation in a
+    separate process).
+
+    Args:
+        tokenizer: A `keras_nlp.models.Llama3Tokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Defaults to `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Defaults to `False`.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.Llama3CausalLMPreprocessor.from_preset(
+        "llama3_8b_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    sentence = tf.constant("League of legends")
+    preprocessor(sentence)
+    # Same output.
+    preprocessor("League of legends")
+
+    # Tokenize a batch of sentences.
+    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+    preprocessor(sentences)
+    # Same output.
+    preprocessor(["Taco tuesday", "Fish taco please!"])
+
+    # Map a dataset to preprocess a single sentence.
+    features = tf.constant(
+        [
+            "Avatar 2 is amazing!",
+            "Well, I am not sure.",
+        ]
+    )
+    labels = tf.constant([1, 0])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
+ ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`Llama3CausalLMPreprocessor` generates `y` and " + "`sample_weight` based on your input data, but your data " + "already contains `y` or `sample_weight`. Your `y` and " + "`sample_weight` will be ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + # Convert the inputs to numpy arrays if they aren't a tensor already. + if not isinstance(token_ids, tf.Tensor): + token_ids = ops.convert_to_numpy(token_ids) + # Make sure the numpy array has type `int32` since + # `SentencePieceProcessor.detokenize` only accepts `int32` arrays. + token_ids = token_ids.astype("int32") + if not isinstance(padding_mask, tf.Tensor): + padding_mask = ops.convert_to_numpy(padding_mask) + padding_mask = padding_mask.astype("bool") + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. 
+ padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id) + padding_mask = padding_mask & ( + token_ids != self.tokenizer.start_token_id + ) + token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..2b79bd0d4f --- /dev/null +++ b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py @@ -0,0 +1,94 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( + Llama3CausalLMPreprocessor, +) +from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class Llama3CausalLMPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = Llama3Tokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["airplane at airport"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=Llama3CausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[6, 1, 3, 4, 2, 5, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[1, 3, 4, 2, 5, 0, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 1, 0, 0, 0]], # Pass through sample_weights. 
+ ), + ) + + def test_with_start_end_token(self): + input_data = ["airplane at airport"] * 4 + + preprocessor = Llama3CausalLMPreprocessor( + **self.init_kwargs, + add_start_token=True, + add_end_token=True, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[6, 1, 3, 4, 2, 5, 7, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4) + self.assertAllEqual(y, [[1, 3, 4, 2, 5, 7, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "airplane at airport" + preprocessor = Llama3CausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [6, 1, 3, 4, 2, 5, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0], + } + preprocessor = Llama3CausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "airplane at airport") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Llama3CausalLMPreprocessor.presets: + self.run_preset_test( + cls=Llama3CausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm_test.py b/keras_nlp/src/models/llama3/llama3_causal_lm_test.py new file mode 100644 index 0000000000..d513c9390f --- /dev/null +++ b/keras_nlp/src/models/llama3/llama3_causal_lm_test.py @@ -0,0 +1,130 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import patch + +import pytest + +from keras_nlp.src.backend import ops +from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone +from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM +from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( + Llama3CausalLMPreprocessor, +) +from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class Llama3CausalLMTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.preprocessor = Llama3CausalLMPreprocessor( + Llama3Tokenizer(vocabulary=self.vocab, merges=self.merges), + sequence_length=7, + ) + self.backbone = Llama3Backbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = ([" airplane at airport", " airplane at airport"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=Llama3CausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 7, 8), + ) + + def test_generate(self): + causal_lm = Llama3CausalLM(**self.init_kwargs) + # String input. + prompt = " airplane at airport" + output = causal_lm.generate(" airplane at airport") + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_early_stopping(self): + causal_lm = Llama3CausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = [" airplane at airport", " airplane"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = Llama3CausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate(" airplane at airport") + first_fn = causal_lm.generate_function + causal_lm.generate(" airplane at airport") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
+ causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Llama3CausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Llama3CausalLM.presets: + self.run_preset_test( + cls=Llama3CausalLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/llama3/llama3_preprocessor.py b/keras_nlp/src/models/llama3/llama3_preprocessor.py new file mode 100644 index 0000000000..cd005f0388 --- /dev/null +++ b/keras_nlp/src/models/llama3/llama3_preprocessor.py @@ -0,0 +1,21 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_nlp.src.models.llama.llama_preprocessor import LlamaPreprocessor + + +@keras_nlp_export("keras_nlp.models.Llama3Preprocessor") +class Llama3Preprocessor(LlamaPreprocessor): + tokenizer_cls = Llama3Tokenizer diff --git a/keras_nlp/src/models/llama3/llama3_preprocessor_test.py b/keras_nlp/src/models/llama3/llama3_preprocessor_test.py new file mode 100644 index 0000000000..ffbbb9c1da --- /dev/null +++ b/keras_nlp/src/models/llama3/llama3_preprocessor_test.py @@ -0,0 +1,84 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor +from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class Llama3PreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = Llama3Tokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = [ + "airplane at airport", + ] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=Llama3Preprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[7, 1, 3, 4, 2, 5, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + } + ), + ) + + def test_with_start_end_token(self): + input_data = ["airplane at airport"] * 4 + + preprocessor = Llama3Preprocessor( + tokenizer=Llama3Tokenizer( + vocabulary=self.vocab, + merges=self.merges, + ), + sequence_length=8, + add_start_token=True, + add_end_token=True, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[7, 1, 3, 4, 2, 5, 6, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "airplane at airport" + preprocessor = Llama3Preprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [7, 1, 3, 4]) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Llama3Preprocessor.presets: + self.run_preset_test( + cls=Llama3Preprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/llama3/llama3_presets.py b/keras_nlp/src/models/llama3/llama3_presets.py new file mode 100644 index 0000000000..10f7fed1a4 --- /dev/null +++ b/keras_nlp/src/models/llama3/llama3_presets.py @@ -0,0 +1,38 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Llama 3 model preset configurations.""" + +# Metadata for loading pretrained model weights. 
+backbone_presets = { + "llama3_8b_en": { + "metadata": { + "description": "LLaMA 3 8B Base model", + "params": 8030261248, + "official_name": "LLaMA 3", + "path": "llama3", + "model_card": "https://github.com/meta-llama/llama3", + }, + "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/2", + }, + "llama3_instruct_8b_en": { + "metadata": { + "description": "LLaMA 3 8B Instruct model", + "params": 8030261248, + "official_name": "LLaMA 3", + "path": "llama3", + "model_card": "https://github.com/meta-llama/llama3", + }, + "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/2", + }, +} diff --git a/keras_nlp/src/models/llama3/llama3_tokenizer.py b/keras_nlp/src/models/llama3/llama3_tokenizer.py new file mode 100644 index 0000000000..4c841a6d42 --- /dev/null +++ b/keras_nlp/src/models/llama3/llama3_tokenizer.py @@ -0,0 +1,63 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_nlp_export("keras_nlp.models.Llama3Tokenizer") +class Llama3Tokenizer(BytePairTokenizer): + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + self.start_token = "<|begin_of_text|>" + self.end_token = "<|end_of_text|>" + + super().__init__( + vocabulary=vocabulary, + merges=merges, + unsplittable_tokens=[self.start_token, self.end_token], + **kwargs, + ) + + def set_vocabulary_and_merges(self, vocabulary, merges): + super().set_vocabulary_and_merges(vocabulary, merges) + + if vocabulary is not None: + # Check for necessary special tokens. + if self.end_token not in self.get_vocabulary(): + raise ValueError( + f"Cannot find token `'{self.end_token}'` in the provided " + f"`vocabulary`. Please provide `'{self.end_token}'` in " + "your `vocabulary` or use a pretrained `vocabulary` name." + ) + + self.start_token_id = self.token_to_id(self.start_token) + self.end_token_id = self.token_to_id(self.end_token) + self.pad_token_id = 0 + else: + self.end_token_id = None + self.start_token_id = None + self.pad_token_id = None + + def get_config(self): + config = super().get_config() + # In the constructor, we pass the list of special tokens to the + # `unsplittable_tokens` arg of the superclass' constructor. Hence, we + # delete it from the config here. + del config["unsplittable_tokens"] + return config diff --git a/keras_nlp/src/models/llama3/llama3_tokenizer_test.py b/keras_nlp/src/models/llama3/llama3_tokenizer_test.py new file mode 100644 index 0000000000..5ea6c193be --- /dev/null +++ b/keras_nlp/src/models/llama3/llama3_tokenizer_test.py @@ -0,0 +1,63 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class Llama3TokenizerTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.input_data = [ + "<|begin_of_text|>airplane at airport<|end_of_text|>", + " airplane airport", + ] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=Llama3Tokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[7, 1, 3, 4, 2, 5, 6], [2, 3, 2, 5]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + Llama3Tokenizer(vocabulary={"foo": 0, "bar": 1}, merges=["fo o"]) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=Llama3Tokenizer, + preset="llama3_8b_en", + input_data=["The quick brown fox."], + expected_output=[[791, 4062, 14198, 39935, 13]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Llama3Tokenizer.presets: + self.run_preset_test( + cls=Llama3Tokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/tools/checkpoint_conversion/convert_llama3_checkpoints.py b/tools/checkpoint_conversion/convert_llama3_checkpoints.py new file mode 100644 index 0000000000..c4bfde5fbe --- /dev/null +++ b/tools/checkpoint_conversion/convert_llama3_checkpoints.py @@ -0,0 +1,299 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import traceback + +import numpy as np +import torch +from absl import app +from absl import flags +from huggingface_hub import hf_hub_download +from keras import ops +from transformers import AutoTokenizer +from transformers import LlamaForCausalLM + +from keras_nlp import upload_preset +from keras_nlp.models import Llama3Backbone +from keras_nlp.models import Llama3CausalLMPreprocessor +from keras_nlp.models import Llama3Tokenizer + +PRESET_MAP = { + "llama3_8b_en": "meta-llama/Meta-Llama-3-8B", + "llama3_instruct_8b_en": "meta-llama/Meta-Llama-3-8B-Instruct", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' +) + + +def convert_checkpoints(keras_nlp_model, hf_model): + config = hf_model.config + + keras_nlp_model.token_embedding.embeddings.assign( + hf_model.model.embed_tokens.weight.detach().cpu().float().numpy() + ) + + for i in range(keras_nlp_model.num_layers): + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layer._key_dense.set_weights( + [ + hf_model.model.layers[i] + .self_attn.k_proj.weight.T.reshape( + config.hidden_size, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + ) + .detach() + .cpu() + .float() + .numpy() + ] + ) + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layer._query_dense.set_weights( + [ + hf_model.model.layers[i] + .self_attn.q_proj.weight.T.reshape( + config.hidden_size, + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + ) + .detach() + .cpu() + .float() + .numpy() + ] + ) + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layer._value_dense.set_weights( + [ + hf_model.model.layers[i] + .self_attn.v_proj.weight.T.reshape( + config.hidden_size, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + ) + .detach() + .cpu() + .float() + .numpy() + ] + ) + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layer._output_dense.set_weights( + [ + hf_model.model.layers[i] + .self_attn.o_proj.weight.T.reshape( + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + config.hidden_size, + ) + .detach() + .cpu() + .float() + .numpy() + ] + ) + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layernorm.set_weights( + [ + hf_model.model.layers[i] + .input_layernorm.weight.detach() + .cpu() + .float() + .numpy() + ] + ) + keras_nlp_model.transformer_layers[ + i + ]._feedforward_intermediate_dense.set_weights( + [ + hf_model.model.layers[i] + .mlp.up_proj.weight.T.detach() + .cpu() + .float() + .numpy() + ] + ) + keras_nlp_model.transformer_layers[ + i + ]._feedforward_output_dense.set_weights( + [ + hf_model.model.layers[i] + .mlp.down_proj.weight.T.detach() + .cpu() + .float() + .numpy() + ] + ) + keras_nlp_model.transformer_layers[ + i + ]._feedforward_gate_dense.set_weights( + [ + hf_model.model.layers[i] + .mlp.gate_proj.weight.T.detach() + .cpu() + .float() + .numpy() + ] + ) + keras_nlp_model.transformer_layers[ + i + ]._feedforward_layernorm.set_weights( + [ + hf_model.model.layers[i] + .post_attention_layernorm.weight.detach() + .cpu() + .float() + .numpy() + ] + ) + + keras_nlp_model.layer_norm.set_weights( + [hf_model.model.norm.weight.detach().cpu().float().numpy()] + ) + keras_nlp_model.token_embedding.reverse_embeddings.assign( + hf_model.lm_head.weight.T.detach().cpu().float().numpy() + ) + + +def test_model( + keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_model_tokenizer +): + # First, test that 
the parameter counts match
+    keras_nlp_params = keras_nlp_model.count_params()
+    hf_params = hf_model.num_parameters()
+    assert keras_nlp_params == hf_params
+
+    # Test the outputs of both models
+    hf_outputs = hf_model(
+        **hf_model_tokenizer(["What is Keras?"], return_tensors="pt")
+    )
+    hf_output_logits = hf_outputs.logits.detach().cpu().numpy()
+
+    keras_nlp_preprocessor = Llama3CausalLMPreprocessor(keras_nlp_tokenizer)
+    keras_nlp_output = keras_nlp_model(
+        keras_nlp_preprocessor(["What is Keras?"], sequence_length=5)[0]
+    )
+    keras_nlp_logits = keras_nlp_model.token_embedding(
+        keras_nlp_output, reverse=True
+    )
+    keras_nlp_logits = ops.convert_to_numpy(keras_nlp_logits)
+
+    # High tolerance since bfloat16 is used as the default dtype for Llama 3
+    try:
+        np.testing.assert_allclose(
+            keras_nlp_logits, hf_output_logits, atol=1e-4
+        )
+    except AssertionError as err:
+        print("\n")
+        print(traceback.format_exc())
+        print(err.args[0])
+        print("\n")
+
+
+def test_tokenizer(keras_nlp_tokenizer, hf_tokenizer):
+    hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt")
+    hf_output = hf_output["input_ids"].detach().cpu().numpy()
+    keras_nlp_preprocessor = Llama3CausalLMPreprocessor(keras_nlp_tokenizer)
+    keras_nlp_output = keras_nlp_preprocessor(
+        ["What is Keras?"], sequence_length=5
+    )
+    keras_nlp_output = ops.convert_to_numpy(keras_nlp_output[0]["token_ids"])
+
+    np.testing.assert_equal(keras_nlp_output, hf_output)
+
+
+def main(_):
+    # === Get the preset name ===
+    if FLAGS.preset not in PRESET_MAP.keys():
+        raise ValueError(
+            f"Invalid preset {FLAGS.preset}. Must be one "
+            f"of {','.join(PRESET_MAP.keys())}"
+        )
+    preset = FLAGS.preset
+    hf_preset = PRESET_MAP[preset]
+
+    # === Load the Huggingface model ===
+    hf_model = LlamaForCausalLM.from_pretrained(
+        hf_preset, torch_dtype=torch.bfloat16
+    )
+    hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+    hf_model.eval()
+    print("\n-> Huggingface model and tokenizer loaded")
+
+    # === Load the KerasNLP model ===
+    backbone_kwargs = dict(
+        vocabulary_size=hf_model.config.vocab_size,
+        hidden_dim=hf_model.config.hidden_size,
+        num_layers=hf_model.config.num_hidden_layers,
+        num_query_heads=hf_model.config.num_attention_heads,
+        num_key_value_heads=hf_model.config.num_key_value_heads,
+        intermediate_dim=hf_model.config.intermediate_size,
+        layer_norm_epsilon=hf_model.config.rms_norm_eps,
+        rope_max_wavelength=hf_model.config.rope_theta,
+        dtype="bfloat16",
+    )
+    keras_nlp_model = Llama3Backbone(**backbone_kwargs)
+
+    # === Get the tokenizer from the Huggingface model ===
+    tokenizer_path = hf_hub_download(
+        "meta-llama/Meta-Llama-3-8B", "tokenizer.json", token=True
+    )
+    with open(tokenizer_path, "r") as tokenizer_file:
+        tokenizer_content = json.load(tokenizer_file)
+    vocabulary = hf_tokenizer.vocab
+    merges = tokenizer_content["model"]["merges"]
+    keras_nlp_tokenizer = Llama3Tokenizer(vocabulary, merges)
+    print("\n-> Keras 3 model and tokenizer loaded.")
+
+    # === Port the weights ===
+    convert_checkpoints(keras_nlp_model, hf_model)
+    print("\n-> Weight transfer done.")
+
+    # === Check that the models and tokenizers outputs match ===
+    test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
+    test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
+    print("\n-> Tests passed!")
+
+    keras_nlp_model.save_to_preset(preset)
+    print("\n-> Saved the model preset in bfloat16")
+
+    # === Save the tokenizer ===
+    keras_nlp_tokenizer.save_to_preset(preset)
+    print("\n-> Saved the tokenizer")
+
+    # === Upload the
preset ===
+    try:
+        uri = f"kaggle://keras/llama3/keras/{preset}"
+        upload_preset(uri, preset)
+        print("-> Uploaded the preset!")
+    except Exception:
+        print(
+            "-> Failed to upload the preset. Make sure you have the "
+            "correct permissions to upload and/or the page "
+            "you are pushing to exists."
+        )
+
+
+if __name__ == "__main__":
+    flags.mark_flag_as_required("preset")
+    app.run(main)
diff --git a/tools/checkpoint_conversion/convert_llama_checkpoints.py b/tools/checkpoint_conversion/convert_llama_checkpoints.py
index 4e127b2c7d..27be78901b 100644
--- a/tools/checkpoint_conversion/convert_llama_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_llama_checkpoints.py
@@ -11,23 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import gc
-import os
-import shutil
-import tempfile
 import traceback
 
 import numpy as np
+import torch
 from absl import app
 from absl import flags
 from keras import ops
 from transformers import AutoTokenizer
 from transformers import LlamaForCausalLM
 
+from keras_nlp import upload_preset
 from keras_nlp.models import LlamaBackbone
 from keras_nlp.models import LlamaCausalLMPreprocessor
 from keras_nlp.models import LlamaTokenizer
-from keras_nlp.utils.preset_utils import save_to_preset
 
 PRESET_MAP = {
     "llama2_7b_en": "meta-llama/Llama-2-7b-hf",
@@ -224,65 +221,53 @@ def main(_):
     preset = FLAGS.preset
     hf_preset = PRESET_MAP[preset]
 
-    # === Create the temporary save directories ===
-    temp_dir = tempfile.mkdtemp()
-
-    try:
-        # === Load the Huggingface model ===
-        hf_model = LlamaForCausalLM.from_pretrained(hf_preset)
-        hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
-        hf_model.eval()
-        print("\n-> Huggingface model and tokenizer loaded")
-
-        # === Load the KerasNLP model ===
-        backbone_kwargs = dict(
-            vocabulary_size=hf_model.config.vocab_size,
-            hidden_dim=hf_model.config.hidden_size,
-            num_layers=hf_model.config.num_hidden_layers,
-            num_query_heads=hf_model.config.num_attention_heads,
-            num_key_value_heads=hf_model.config.num_key_value_heads,
-            intermediate_dim=hf_model.config.intermediate_size,
-            layer_norm_epsilon=hf_model.config.rms_norm_eps,
-            rope_max_wavelength=hf_model.config.rope_theta,
-            dtype="float32",
-        )
-        keras_nlp_model = LlamaBackbone(**backbone_kwargs)
+    # === Load the Huggingface model ===
+    hf_model = LlamaForCausalLM.from_pretrained(
+        hf_preset, torch_dtype=torch.bfloat16
+    )
+    hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+    hf_model.eval()
+    print("\n-> Huggingface model and tokenizer loaded")
 
-        # === Get the tokenizer from the Huggingface model ===
-        tokenizer_path = hf_tokenizer.vocab_file
-        keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path)
-        print("\n-> Keras 3 model and tokenizer loaded.")
+    # === Load the KerasNLP model ===
+    backbone_kwargs = dict(
+        vocabulary_size=hf_model.config.vocab_size,
+        hidden_dim=hf_model.config.hidden_size,
+        num_layers=hf_model.config.num_hidden_layers,
+        num_query_heads=hf_model.config.num_attention_heads,
+        num_key_value_heads=hf_model.config.num_key_value_heads,
+        intermediate_dim=hf_model.config.intermediate_size,
+        layer_norm_epsilon=hf_model.config.rms_norm_eps,
+        rope_max_wavelength=hf_model.config.rope_theta,
+        dtype="bfloat16",
+    )
+    keras_nlp_model = LlamaBackbone(**backbone_kwargs)
 
-        # === Port the weights ===
-        convert_checkpoints(keras_nlp_model, hf_model)
-        print("\n-> Weight transfer done.")
+    # === Get the tokenizer from the Huggingface
model ===
+    tokenizer_path = hf_tokenizer.vocab_file
+    keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path)
+    print("\n-> Keras 3 model and tokenizer loaded.")
 
-        # === Check that the models and tokenizers outputs match ===
-        test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
-        test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
-        print("\n-> Tests passed!")
+    # === Port the weights ===
+    convert_checkpoints(keras_nlp_model, hf_model)
+    print("\n-> Weight transfer done.")
 
-        # === Save the model weights in float32 format ===
-        keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
-        print("\n-> Saved the model weights in float32")
+    # === Check that the models and tokenizers outputs match ===
+    test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
+    test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
+    print("\n-> Tests passed!")
 
-        del keras_nlp_model, hf_model
-        gc.collect()
+    keras_nlp_model.save_to_preset(preset)
+    print("\n-> Saved the model preset in bfloat16")
 
-        # === Save the weights again in float16 ===
-        backbone_kwargs["dtype"] = "float16"
-        keras_nlp_model = LlamaBackbone(**backbone_kwargs)
-        keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
-        save_to_preset(keras_nlp_model, preset)
-        print("\n-> Saved the model preset in float16")
+    # === Save the tokenizer ===
+    keras_nlp_tokenizer.save_to_preset(preset)
+    print("\n-> Saved the tokenizer")
 
-        # === Save the tokenizer ===
-        save_to_preset(
-            keras_nlp_tokenizer, preset, config_filename="tokenizer.json"
-        )
-        print("\n-> Saved the tokenizer")
-    finally:
-        shutil.rmtree(temp_dir)
+    # === Upload the preset ===
+    uri = f"kaggle://keras/llama2/keras/{preset}"
+    upload_preset(uri, preset)
+    print("-> Uploaded the preset!")
 
 
 if __name__ == "__main__":
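
For completeness, here is a minimal end-to-end sketch of the user-facing API this change adds. This is illustrative only, not part of the diff, and assumes the `llama3_8b_en` preset weights are accessible via Kaggle:

```python
import keras_nlp

# The preset is registered against Llama3Backbone and Llama3Tokenizer in
# keras_nlp/src/models/llama3/__init__.py; task classes like Llama3CausalLM
# resolve it through their backbone_cls.
causal_lm = keras_nlp.models.Llama3CausalLM.from_preset("llama3_8b_en")

# generate() implicitly runs generate_preprocess(), the sampled decode loop,
# and generate_postprocess(), which strips <|begin_of_text|> and
# <|end_of_text|> before detokenizing.
print(causal_lm.generate("What is Keras?", max_length=64))

# The decoding strategy is controlled by recompiling with another sampler;
# "top_k" is the default.
causal_lm.compile(sampler="greedy")
print(causal_lm.generate("What is Keras?", max_length=64))
```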