From 294304b94f8ab9355beeeae6965e26f6c4fc9286 Mon Sep 17 00:00:00 2001
From: Tirth Patel <tirthasheshpatel@gmail.com>
Date: Fri, 17 May 2024 14:39:50 -0700
Subject: [PATCH] Add LLaMA 3 tokenizer and preset (#1584)

* Add LLaMA 3 tokenizer and preset

* Add a LLaMA 3 backbone and correct presets

* Add docs for LLaMA 3 backbone

[skip ci]

* Fix lint failures

* Fix the checkpointing scripts

* Add tests for all the components

* Run shell/api_gen.sh

* Address review comments; run api_gen.sh
---
 keras_nlp/api/models/__init__.py              |   6 +
 keras_nlp/src/models/llama/llama_causal_lm.py |  12 +-
 keras_nlp/src/models/llama/llama_tokenizer.py |   7 -
 keras_nlp/src/models/llama3/__init__.py       |  20 ++
 .../src/models/llama3/llama3_backbone.py      |  84 +++++
 .../src/models/llama3/llama3_causal_lm.py     |  44 +++
 .../llama3/llama3_causal_lm_preprocessor.py   | 185 +++++++++++
 .../llama3_causal_lm_preprocessor_test.py     |  94 ++++++
 .../models/llama3/llama3_causal_lm_test.py    | 130 ++++++++
 .../src/models/llama3/llama3_preprocessor.py  |  21 ++
 .../models/llama3/llama3_preprocessor_test.py |  84 +++++
 keras_nlp/src/models/llama3/llama3_presets.py |  38 +++
 .../src/models/llama3/llama3_tokenizer.py     |  63 ++++
 .../models/llama3/llama3_tokenizer_test.py    |  63 ++++
 .../convert_llama3_checkpoints.py             | 299 ++++++++++++++++++
 .../convert_llama_checkpoints.py              |  99 +++---
 16 files changed, 1176 insertions(+), 73 deletions(-)
 create mode 100644 keras_nlp/src/models/llama3/__init__.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_backbone.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_causal_lm.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_causal_lm_test.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_preprocessor.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_preprocessor_test.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_presets.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_tokenizer.py
 create mode 100644 keras_nlp/src/models/llama3/llama3_tokenizer_test.py
 create mode 100644 tools/checkpoint_conversion/convert_llama3_checkpoints.py

diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py
index d432e61d5d..766959c988 100644
--- a/keras_nlp/api/models/__init__.py
+++ b/keras_nlp/api/models/__init__.py
@@ -128,6 +128,12 @@
     GPTNeoXPreprocessor,
 )
 from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
+from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
+from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
+    Llama3CausalLMPreprocessor,
+)
+from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
+from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
 from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone
 from keras_nlp.src.models.llama.llama_causal_lm import LlamaCausalLM
 from keras_nlp.src.models.llama.llama_causal_lm_preprocessor import (
diff --git a/keras_nlp/src/models/llama/llama_causal_lm.py b/keras_nlp/src/models/llama/llama_causal_lm.py
index 7dc179347c..6ca3b3d77c 100644
--- a/keras_nlp/src/models/llama/llama_causal_lm.py
+++ b/keras_nlp/src/models/llama/llama_causal_lm.py
@@ -19,7 +19,6 @@
 from keras_nlp.src.models.llama.llama_causal_lm_preprocessor import (
     LlamaCausalLMPreprocessor,
 )
-from keras_nlp.src.utils.python_utils import classproperty
 from keras_nlp.src.utils.tensor_utils import any_equal
 
 
@@ -46,6 +45,9 @@ class LlamaCausalLM(CausalLM):
             should be preprocessed before calling the model.
     """
 
+    backbone_cls = LlamaBackbone
+    preprocessor_cls = LlamaCausalLMPreprocessor
+
     def __init__(self, backbone, preprocessor=None, **kwargs):
         # === Layers ===
         self.backbone = backbone
@@ -61,14 +63,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
             **kwargs,
         )
 
-    @classproperty
-    def backbone_cls(cls):
-        return LlamaBackbone
-
-    @classproperty
-    def preprocessor_cls(cls):
-        return LlamaCausalLMPreprocessor
-
     def call_with_cache(
         self,
         token_ids,
diff --git a/keras_nlp/src/models/llama/llama_tokenizer.py b/keras_nlp/src/models/llama/llama_tokenizer.py
index 74b2e89987..c8147c3836 100644
--- a/keras_nlp/src/models/llama/llama_tokenizer.py
+++ b/keras_nlp/src/models/llama/llama_tokenizer.py
@@ -11,14 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 
 from keras_nlp.src.api_export import keras_nlp_export
-from keras_nlp.src.models.llama.llama_presets import backbone_presets
 from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
     SentencePieceTokenizer,
 )
-from keras_nlp.src.utils.python_utils import classproperty
 
 
 @keras_nlp_export("keras_nlp.models.LlamaTokenizer")
@@ -85,7 +82,3 @@ def set_proto(self, proto):
             self.start_token_id = None
             self.end_token_id = None
             self.pad_token_id = None
-
-    @classproperty
-    def presets(cls):
-        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/src/models/llama3/__init__.py b/keras_nlp/src/models/llama3/__init__.py
new file mode 100644
index 0000000000..cc2764e7a7
--- /dev/null
+++ b/keras_nlp/src/models/llama3/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
+from keras_nlp.src.models.llama3.llama3_presets import backbone_presets
+from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
+from keras_nlp.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (Llama3Backbone, Llama3Tokenizer))
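This single `register_presets` call replaces the per-class `classproperty presets` hooks that the hunks below remove. A minimal sketch of the user-facing effect, assuming the Kaggle-hosted `llama3_8b_en` weights are reachable:

```python
import keras_nlp

# Registration makes the preset names resolvable from both classes.
tokenizer = keras_nlp.models.Llama3Tokenizer.from_preset("llama3_8b_en")
backbone = keras_nlp.models.Llama3Backbone.from_preset("llama3_8b_en")
```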
diff --git a/keras_nlp/src/models/llama3/llama3_backbone.py b/keras_nlp/src/models/llama3/llama3_backbone.py
new file mode 100644
index 0000000000..90b3a42a3e
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_backbone.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone
+
+
+# LLaMA 3 shares its architecture with its predecessors, so we simply
+# create an alias of `LlamaBackbone` for API consistency.
+@keras_nlp_export("keras_nlp.models.Llama3Backbone")
+class Llama3Backbone(LlamaBackbone):
+    """
+    The Llama Transformer core architecture with hyperparameters.
+
+    This network implements a Transformer-based decoder network,
+    Llama, as described in
+    ["Llama 7B"](https://arxiv.org/pdf/2310.06825.pdf).
+    It includes the embedding lookups and transformer layers.
+
+    The default constructor gives a fully customizable, randomly initialized
+    Llama model with any number of layers, heads, and embedding
+    dimensions. To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        vocabulary_size (int): The size of the token vocabulary.
+        num_layers (int): The number of transformer layers.
+        num_query_heads (int): The number of query attention heads for
+            each transformer.
+        hidden_dim (int): The size of the transformer encoding and pooling layers.
+        intermediate_dim (int): The output dimension of the first Dense layer in a
+            three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads for
+            each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of the
+            sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for the
+            calculation of rotary embeddings. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer normalization
+            layers in the transformer decoder. Defaults to `1e-6`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+
+    Examples:
+
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained Llama 3 decoder.
+    model = keras_nlp.models.Llama3Backbone.from_preset("llama3_8b_en")
+    model(input_data)
+
+    # Randomly initialized Llama 3 decoder with custom config.
+    model = keras_nlp.models.Llama3Backbone(
+        vocabulary_size=10,
+        hidden_dim=512,
+        num_layers=2,
+        num_query_heads=32,
+        num_key_value_heads=8,
+        intermediate_dim=1024,
+        layer_norm_epsilon=1e-6,
+        dtype="float32"
+    )
+    model(input_data)
+    ```
+    """
+
+    pass
diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm.py b/keras_nlp/src/models/llama3/llama3_causal_lm.py
new file mode 100644
index 0000000000..be61351ab2
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_causal_lm.py
@@ -0,0 +1,44 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
+from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
+    Llama3CausalLMPreprocessor,
+)
+from keras_nlp.src.models.llama.llama_causal_lm import LlamaCausalLM
+
+
+class Llama3CausalLM(LlamaCausalLM):
+    """An end-to-end Llama 3 model for causal language modeling.
+
+    A causal language model (LM) predicts the next token based on previous
+    tokens. This task setup can be used to train the model unsupervised on
+    plain text input, or to autoregressively generate plain text similar to
+    the data used for training. This task can be used for pre-training or
+    fine-tuning a Llama 3 model, simply by calling `fit()`.
+
+    This model has a `generate()` method, which generates text based on a
+    prompt. The generation strategy used is controlled by an additional
+    `sampler` argument on `compile()`. You can recompile the model with
+    different `keras_nlp.samplers` objects to control the generation. By
+    default, `"top_k"` sampling will be used.
+
+    Args:
+        backbone: A `keras_nlp.models.Llama3Backbone` instance.
+        preprocessor: A `keras_nlp.models.Llama3CausalLMPreprocessor` or `None`.
+            If `None`, this model will not apply preprocessing, and inputs
+            should be preprocessed before calling the model.
+    """
+
+    backbone_cls = Llama3Backbone
+    preprocessor_cls = Llama3CausalLMPreprocessor
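`Llama3CausalLM` carries no example of its own, so here is a usage sketch of the `generate()`/`compile()` flow the docstring describes, assuming the `llama3_8b_en` preset is available:

```python
import keras_nlp

causal_lm = keras_nlp.models.Llama3CausalLM.from_preset("llama3_8b_en")

# `generate()` uses the sampler set at `compile()` time; `"top_k"` is the
# default.
causal_lm.generate("What is Keras?", max_length=64)

# Recompile with a different `keras_nlp.samplers` object to change the
# generation strategy.
causal_lm.compile(sampler=keras_nlp.samplers.GreedySampler())
causal_lm.generate("What is Keras?", max_length=64)
```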
diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..b9ed44b527
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py
@@ -0,0 +1,185 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from absl import logging
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.backend import ops
+from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
+from keras_nlp.src.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.src.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.Llama3CausalLMPreprocessor")
+class Llama3CausalLMPreprocessor(Llama3Preprocessor):
+    """Llama 3 Causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_nlp.models.Llama3CausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+    is attached to a `keras_nlp.models.Llama3CausalLM` instance, these methods
+    will be called implicitly in `generate()`. They can also be called
+    standalone (e.g. to precompute preprocessing inputs for generation in a
+    separate process).
+
+    Args:
+        tokenizer: A `keras_nlp.models.Llama3Tokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Default is `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Default is `False`.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.Llama3CausalLMPreprocessor.from_preset(
+        "llama_base_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    sentence = tf.constant("League of legends")
+    preprocessor(sentence)
+    # Same output.
+    preprocessor("League of legends")
+
+    # Tokenize a batch of sentences.
+    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+    preprocessor(sentences)
+    # Same output.
+    preprocessor(["Taco tuesday", "Fish taco please!"])
+
+    # Map a dataset to preprocess labeled sentences.
+    features = tf.constant(
+        [
+            "Avatar 2 is amazing!",
+            "Well, I am not sure.",
+        ]
+    )
+    labels = tf.constant([1, 0])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ```
+    """
+
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        if y is not None or sample_weight is not None:
+            logging.warning(
+                "`Llama3CausalLMPreprocessor` generates `y` and "
+                "`sample_weight` based on your input data, but your data "
+                "already contains `y` or `sample_weight`. Your `y` and "
+                "`sample_weight` will be ignored."
+            )
+        sequence_length = sequence_length or self.sequence_length
+
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        # Pad with one extra token to account for the truncation below.
+        token_ids, padding_mask = self.packer(
+            x,
+            sequence_length=sequence_length + 1,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        # The last token does not have a next token, so we truncate it out.
+        x = {
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        # Target `y` will be the next token.
+        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+        return pack_x_y_sample_weight(x, y, sample_weight)
+
+    def generate_preprocess(
+        self,
+        x,
+        sequence_length=None,
+    ):
+        """Convert strings to integer token input for generation.
+
+        Similar to calling the layer for training, this method takes in strings
+        or tensor strings, tokenizes and packs the input, and computes a padding
+        mask that masks out all positions not filled with input tokens.
+
+        Unlike calling the layer for training, this method does not compute
+        labels and will never append a `tokenizer.end_token_id` to the end of
+        the sequence (as generation is expected to continue at the end of the
+        input prompt).
+        """
+        if not self.built:
+            self.build(None)
+
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        token_ids, padding_mask = self.packer(
+            x, sequence_length=sequence_length, add_end_value=False
+        )
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def generate_postprocess(
+        self,
+        x,
+    ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()` by first removing all
+        padding and start/end tokens, and then converting the integer sequence
+        back to a string.
+        """
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        # Convert the inputs to numpy arrays if they aren't a tensor already.
+        if not isinstance(token_ids, tf.Tensor):
+            token_ids = ops.convert_to_numpy(token_ids)
+            # Make sure the numpy array has type `int32` since
+            # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
+            token_ids = token_ids.astype("int32")
+        if not isinstance(padding_mask, tf.Tensor):
+            padding_mask = ops.convert_to_numpy(padding_mask)
+            padding_mask = padding_mask.astype("bool")
+        # Strip any special tokens during detokenization (e.g. the start and
+        # end markers). In the future we could make this configurable.
+        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
+        padding_mask = padding_mask & (
+            token_ids != self.tokenizer.start_token_id
+        )
+        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
+        return self.tokenizer.detokenize(token_ids)
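To make the `(x, y, sample_weight)` contract of `call()` concrete, a plain NumPy sketch of the shift-by-one labeling on an already-packed toy sequence (token ids here are arbitrary):

```python
import numpy as np

# Packer output with sequence_length + 1 = 8 positions.
token_ids = np.array([[6, 1, 3, 4, 2, 5, 0, 0]])
padding_mask = np.array([[1, 1, 1, 1, 1, 1, 0, 0]])

# Inputs drop the last position; labels are the ids shifted left by one.
x = {"token_ids": token_ids[..., :-1], "padding_mask": padding_mask[..., :-1]}
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]  # zero weight where labels are padding
```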
diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py
new file mode 100644
index 0000000000..2b79bd0d4f
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py
@@ -0,0 +1,94 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
+    Llama3CausalLMPreprocessor,
+)
+from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class Llama3CausalLMPreprocessorTest(TestCase):
+    def setUp(self):
+        self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
+        self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+        self.merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.tokenizer = Llama3Tokenizer(
+            vocabulary=self.vocab,
+            merges=self.merges,
+        )
+        self.init_kwargs = {
+            "tokenizer": self.tokenizer,
+            "sequence_length": 8,
+        }
+        self.input_data = ["airplane at airport"]
+
+    def test_preprocessor_basics(self):
+        self.run_preprocessor_test(
+            cls=Llama3CausalLMPreprocessor,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=(
+                {
+                    "token_ids": [[6, 1, 3, 4, 2, 5, 0, 0]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
+                },
+                [[1, 3, 4, 2, 5, 0, 0, 0]],  # Labels shifted left by one.
+                [[1, 1, 1, 1, 1, 0, 0, 0]],  # Zero weight on padded labels.
+            ),
+        )
+
+    def test_with_start_end_token(self):
+        input_data = ["airplane at airport"] * 4
+
+        preprocessor = Llama3CausalLMPreprocessor(
+            **self.init_kwargs,
+            add_start_token=True,
+            add_end_token=True,
+        )
+        x, y, sw = preprocessor(input_data)
+        self.assertAllEqual(x["token_ids"], [[6, 1, 3, 4, 2, 5, 7, 0]] * 4)
+        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4)
+        self.assertAllEqual(y, [[1, 3, 4, 2, 5, 7, 0, 0]] * 4)
+        self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)
+
+    def test_generate_preprocess(self):
+        input_data = "airplane at airport"
+        preprocessor = Llama3CausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_preprocess(input_data)
+        self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
+        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])
+
+    def test_generate_postprocess(self):
+        input_data = {
+            "token_ids": [6, 1, 3, 4, 2, 5, 0, 0],
+            "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
+        }
+        preprocessor = Llama3CausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_postprocess(input_data)
+        self.assertAllEqual(x, "airplane at airport")
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in Llama3CausalLMPreprocessor.presets:
+            self.run_preset_test(
+                cls=Llama3CausalLMPreprocessor,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm_test.py b/keras_nlp/src/models/llama3/llama3_causal_lm_test.py
new file mode 100644
index 0000000000..d513c9390f
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_causal_lm_test.py
@@ -0,0 +1,130 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import patch
+
+import pytest
+
+from keras_nlp.src.backend import ops
+from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
+from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM
+from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
+    Llama3CausalLMPreprocessor,
+)
+from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class Llama3CausalLMTest(TestCase):
+    def setUp(self):
+        self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
+        self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+        self.merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.preprocessor = Llama3CausalLMPreprocessor(
+            Llama3Tokenizer(vocabulary=self.vocab, merges=self.merges),
+            sequence_length=7,
+        )
+        self.backbone = Llama3Backbone(
+            vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+            num_layers=2,
+            num_query_heads=4,
+            num_key_value_heads=2,
+            hidden_dim=8,
+            intermediate_dim=16,
+        )
+        self.init_kwargs = {
+            "preprocessor": self.preprocessor,
+            "backbone": self.backbone,
+        }
+        self.train_data = ([" airplane at airport", " airplane at airport"],)
+        self.input_data = self.preprocessor(*self.train_data)[0]
+
+    def test_causal_lm_basics(self):
+        self.run_task_test(
+            cls=Llama3CausalLM,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 7, 8),
+        )
+
+    def test_generate(self):
+        causal_lm = Llama3CausalLM(**self.init_kwargs)
+        # String input.
+        prompt = " airplane at airport"
+        output = causal_lm.generate(" airplane at airport")
+        self.assertTrue(prompt in output)
+        # Int tensor input.
+        prompt_ids = self.preprocessor.generate_preprocess([prompt])
+        causal_lm.preprocessor = None
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
+        # Assert prompt is in output in token id space.
+        self.assertAllEqual(
+            outputs["token_ids"][:, :5],
+            prompt_ids["token_ids"][:, :5],
+        )
+        self.assertAllEqual(
+            outputs["padding_mask"][:, :5],
+            prompt_ids["padding_mask"][:, :5],
+        )
+
+    def test_early_stopping(self):
+        causal_lm = Llama3CausalLM(**self.init_kwargs)
+        call_with_cache = causal_lm.call_with_cache
+
+        def wrapper(*args, **kwargs):
+            """Modify output logits to always favor end_token_id"""
+            logits, hidden_states, cache = call_with_cache(*args, **kwargs)
+            index = self.preprocessor.tokenizer.end_token_id
+            update = ops.ones_like(logits)[:, :, index] * 1.0e9
+            update = ops.expand_dims(update, axis=-1)
+            logits = ops.slice_update(logits, (0, 0, index), update)
+            return logits, hidden_states, cache
+
+        with patch.object(causal_lm, "call_with_cache", wraps=wrapper):
+            prompt = [" airplane at airport", " airplane"]
+            output = causal_lm.generate(prompt)
+            # We should immediately abort and output the prompt.
+            self.assertEqual(prompt, output)
+
+    def test_generate_compilation(self):
+        causal_lm = Llama3CausalLM(**self.init_kwargs)
+        # Assert we do not recompile with successive calls.
+        causal_lm.generate(" airplane at airport")
+        first_fn = causal_lm.generate_function
+        causal_lm.generate(" airplane at airport")
+        second_fn = causal_lm.generate_function
+        self.assertEqual(first_fn, second_fn)
+        # Assert we do recompile after compile is called.
+        causal_lm.compile(sampler="greedy")
+        self.assertIsNone(causal_lm.generate_function)
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=Llama3CausalLM,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in Llama3CausalLM.presets:
+            self.run_preset_test(
+                cls=Llama3CausalLM,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/keras_nlp/src/models/llama3/llama3_preprocessor.py b/keras_nlp/src/models/llama3/llama3_preprocessor.py
new file mode 100644
index 0000000000..cd005f0388
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_preprocessor.py
@@ -0,0 +1,21 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
+from keras_nlp.src.models.llama.llama_preprocessor import LlamaPreprocessor
+
+
+@keras_nlp_export("keras_nlp.models.Llama3Preprocessor")
+class Llama3Preprocessor(LlamaPreprocessor):
+    tokenizer_cls = Llama3Tokenizer
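`Llama3Preprocessor` only swaps in the Llama 3 tokenizer; the packing logic is inherited from `LlamaPreprocessor`. A small sketch using the same toy vocabulary as the tests below:

```python
from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer

vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
vocab = dict((token, i) for i, token in enumerate(vocab))
merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
merges += ["Ġai r", "Ġa i", "pla ne"]

preprocessor = Llama3Preprocessor(
    tokenizer=Llama3Tokenizer(vocabulary=vocab, merges=merges),
    sequence_length=8,
)
# Returns {"token_ids": ..., "padding_mask": ...} with the start token
# prepended by default.
preprocessor(["airplane at airport"])
```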
diff --git a/keras_nlp/src/models/llama3/llama3_preprocessor_test.py b/keras_nlp/src/models/llama3/llama3_preprocessor_test.py
new file mode 100644
index 0000000000..ffbbb9c1da
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_preprocessor_test.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
+from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class Llama3PreprocessorTest(TestCase):
+    def setUp(self):
+        self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
+        self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+        self.merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.tokenizer = Llama3Tokenizer(
+            vocabulary=self.vocab,
+            merges=self.merges,
+        )
+        self.init_kwargs = {
+            "tokenizer": self.tokenizer,
+            "sequence_length": 8,
+        }
+        self.input_data = [
+            "airplane at airport",
+        ]
+
+    def test_preprocessor_basics(self):
+        self.run_preprocessor_test(
+            cls=Llama3Preprocessor,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=(
+                {
+                    "token_ids": [[7, 1, 3, 4, 2, 5, 0, 0]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
+                }
+            ),
+        )
+
+    def test_with_start_end_token(self):
+        input_data = ["airplane at airport"] * 4
+
+        preprocessor = Llama3Preprocessor(
+            tokenizer=Llama3Tokenizer(
+                vocabulary=self.vocab,
+                merges=self.merges,
+            ),
+            sequence_length=8,
+            add_start_token=True,
+            add_end_token=True,
+        )
+        x = preprocessor(input_data)
+        self.assertAllEqual(x["token_ids"], [[7, 1, 3, 4, 2, 5, 6, 0]] * 4)
+        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4)
+
+    def test_sequence_length_override(self):
+        input_data = "airplane at airport"
+        preprocessor = Llama3Preprocessor(**self.init_kwargs)
+        x = preprocessor(input_data, sequence_length=4)
+        self.assertAllEqual(x["token_ids"], [7, 1, 3, 4])
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in Llama3Preprocessor.presets:
+            self.run_preset_test(
+                cls=Llama3Preprocessor,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/keras_nlp/src/models/llama3/llama3_presets.py b/keras_nlp/src/models/llama3/llama3_presets.py
new file mode 100644
index 0000000000..10f7fed1a4
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_presets.py
@@ -0,0 +1,38 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Llama 3 model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "llama3_8b_en": {
+        "metadata": {
+            "description": "LLaMA 3 8B Base model",
+            "params": 8030261248,
+            "official_name": "LLaMA 3",
+            "path": "llama3",
+            "model_card": "https://github.com/meta-llama/llama3",
+        },
+        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/2",
+    },
+    "llama3_instruct_8b_en": {
+        "metadata": {
+            "description": "LLaMA 3 8B Instruct model",
+            "params": 8030261248,
+            "official_name": "LLaMA 3",
+            "path": "llama3",
+            "model_card": "https://github.com/meta-llama/llama3",
+        },
+        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/2",
+    },
+}
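These entries are consumed by the `register_presets` call in `llama3/__init__.py`, after which they can be enumerated and loaded by name. A quick sketch:

```python
import keras_nlp

# Both registered classes expose the same preset names.
print(keras_nlp.models.Llama3Backbone.presets.keys())
# dict_keys(['llama3_8b_en', 'llama3_instruct_8b_en'])

backbone = keras_nlp.models.Llama3Backbone.from_preset("llama3_instruct_8b_en")
```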
diff --git a/keras_nlp/src/models/llama3/llama3_tokenizer.py b/keras_nlp/src/models/llama3/llama3_tokenizer.py
new file mode 100644
index 0000000000..4c841a6d42
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_tokenizer.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_nlp_export("keras_nlp.models.Llama3Tokenizer")
+class Llama3Tokenizer(BytePairTokenizer):
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        **kwargs,
+    ):
+        self.start_token = "<|begin_of_text|>"
+        self.end_token = "<|end_of_text|>"
+
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            unsplittable_tokens=[self.start_token, self.end_token],
+            **kwargs,
+        )
+
+    def set_vocabulary_and_merges(self, vocabulary, merges):
+        super().set_vocabulary_and_merges(vocabulary, merges)
+
+        if vocabulary is not None:
+            # Check for necessary special tokens.
+            if self.end_token not in self.get_vocabulary():
+                raise ValueError(
+                    f"Cannot find token `'{self.end_token}'` in the provided "
+                    f"`vocabulary`. Please provide `'{self.end_token}'` in "
+                    "your `vocabulary` or use a pretrained `vocabulary` name."
+                )
+
+            self.start_token_id = self.token_to_id(self.start_token)
+            self.end_token_id = self.token_to_id(self.end_token)
+            self.pad_token_id = 0
+        else:
+            self.end_token_id = None
+            self.start_token_id = None
+            self.pad_token_id = None
+
+    def get_config(self):
+        config = super().get_config()
+        # In the constructor, we pass the list of special tokens to the
+        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
+        # delete it from the config here.
+        del config["unsplittable_tokens"]
+        return config
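A round-trip sketch of the tokenizer, assuming the `llama3_8b_en` preset is reachable (expected ids taken from the preset test below):

```python
import keras_nlp

tokenizer = keras_nlp.models.Llama3Tokenizer.from_preset("llama3_8b_en")
ids = tokenizer("The quick brown fox.")
print(ids)  # [791, 4062, 14198, 39935, 13]
print(tokenizer.detokenize(ids))  # "The quick brown fox."
```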
diff --git a/keras_nlp/src/models/llama3/llama3_tokenizer_test.py b/keras_nlp/src/models/llama3/llama3_tokenizer_test.py
new file mode 100644
index 0000000000..5ea6c193be
--- /dev/null
+++ b/keras_nlp/src/models/llama3/llama3_tokenizer_test.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class Llama3TokenizerTest(TestCase):
+    def setUp(self):
+        self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
+        self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+        self.merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+        self.input_data = [
+            "<|begin_of_text|>airplane at airport<|end_of_text|>",
+            " airplane airport",
+        ]
+
+    def test_tokenizer_basics(self):
+        self.run_preprocessing_layer_test(
+            cls=Llama3Tokenizer,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=[[7, 1, 3, 4, 2, 5, 6], [2, 3, 2, 5]],
+        )
+
+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            Llama3Tokenizer(vocabulary={"foo": 0, "bar": 1}, merges=["fo o"])
+
+    @pytest.mark.large
+    def test_smallest_preset(self):
+        self.run_preset_test(
+            cls=Llama3Tokenizer,
+            preset="llama3_8b_en",
+            input_data=["The quick brown fox."],
+            expected_output=[[791, 4062, 14198, 39935, 13]],
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in Llama3Tokenizer.presets:
+            self.run_preset_test(
+                cls=Llama3Tokenizer,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/tools/checkpoint_conversion/convert_llama3_checkpoints.py b/tools/checkpoint_conversion/convert_llama3_checkpoints.py
new file mode 100644
index 0000000000..c4bfde5fbe
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_llama3_checkpoints.py
@@ -0,0 +1,299 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import traceback
+
+import numpy as np
+import torch
+from absl import app
+from absl import flags
+from huggingface_hub import hf_hub_download
+from keras import ops
+from transformers import AutoTokenizer
+from transformers import LlamaForCausalLM
+
+from keras_nlp import upload_preset
+from keras_nlp.models import Llama3Backbone
+from keras_nlp.models import Llama3CausalLMPreprocessor
+from keras_nlp.models import Llama3Tokenizer
+
+PRESET_MAP = {
+    "llama3_8b_en": "meta-llama/Meta-Llama-3-8B",
+    "llama3_instruct_8b_en": "meta-llama/Meta-Llama-3-8B-Instruct",
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+)
+
+
+def convert_checkpoints(keras_nlp_model, hf_model):
+    config = hf_model.config
+
+    keras_nlp_model.token_embedding.embeddings.assign(
+        hf_model.model.embed_tokens.weight.detach().cpu().float().numpy()
+    )
+
+    for i in range(keras_nlp_model.num_layers):
+        keras_nlp_model.transformer_layers[
+            i
+        ]._self_attention_layer._key_dense.set_weights(
+            [
+                hf_model.model.layers[i]
+                .self_attn.k_proj.weight.T.reshape(
+                    config.hidden_size,
+                    config.num_key_value_heads,
+                    config.hidden_size // config.num_attention_heads,
+                )
+                .detach()
+                .cpu()
+                .float()
+                .numpy()
+            ]
+        )
+        keras_nlp_model.transformer_layers[
+            i
+        ]._self_attention_layer._query_dense.set_weights(
+            [
+                hf_model.model.layers[i]
+                .self_attn.q_proj.weight.T.reshape(
+                    config.hidden_size,
+                    config.num_attention_heads,
+                    config.hidden_size // config.num_attention_heads,
+                )
+                .detach()
+                .cpu()
+                .float()
+                .numpy()
+            ]
+        )
+        keras_nlp_model.transformer_layers[
+            i
+        ]._self_attention_layer._value_dense.set_weights(
+            [
+                hf_model.model.layers[i]
+                .self_attn.v_proj.weight.T.reshape(
+                    config.hidden_size,
+                    config.num_key_value_heads,
+                    config.hidden_size // config.num_attention_heads,
+                )
+                .detach()
+                .cpu()
+                .float()
+                .numpy()
+            ]
+        )
+        keras_nlp_model.transformer_layers[
+            i
+        ]._self_attention_layer._output_dense.set_weights(
+            [
+                hf_model.model.layers[i]
+                .self_attn.o_proj.weight.T.reshape(
+                    config.num_attention_heads,
+                    config.hidden_size // config.num_attention_heads,
+                    config.hidden_size,
+                )
+                .detach()
+                .cpu()
+                .float()
+                .numpy()
+            ]
+        )
+        keras_nlp_model.transformer_layers[
+            i
+        ]._self_attention_layernorm.set_weights(
+            [
+                hf_model.model.layers[i]
+                .input_layernorm.weight.detach()
+                .cpu()
+                .float()
+                .numpy()
+            ]
+        )
+        keras_nlp_model.transformer_layers[
+            i
+        ]._feedforward_intermediate_dense.set_weights(
+            [
+                hf_model.model.layers[i]
+                .mlp.up_proj.weight.T.detach()
+                .cpu()
+                .float()
+                .numpy()
+            ]
+        )
+        keras_nlp_model.transformer_layers[
+            i
+        ]._feedforward_output_dense.set_weights(
+            [
+                hf_model.model.layers[i]
+                .mlp.down_proj.weight.T.detach()
+                .cpu()
+                .float()
+                .numpy()
+            ]
+        )
+        keras_nlp_model.transformer_layers[
+            i
+        ]._feedforward_gate_dense.set_weights(
+            [
+                hf_model.model.layers[i]
+                .mlp.gate_proj.weight.T.detach()
+                .cpu()
+                .float()
+                .numpy()
+            ]
+        )
+        keras_nlp_model.transformer_layers[
+            i
+        ]._feedforward_layernorm.set_weights(
+            [
+                hf_model.model.layers[i]
+                .post_attention_layernorm.weight.detach()
+                .cpu()
+                .float()
+                .numpy()
+            ]
+        )
+
+    keras_nlp_model.layer_norm.set_weights(
+        [hf_model.model.norm.weight.detach().cpu().float().numpy()]
+    )
+    keras_nlp_model.token_embedding.reverse_embeddings.assign(
+        hf_model.lm_head.weight.T.detach().cpu().float().numpy()
+    )
+
+
+def test_model(
+    keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_model_tokenizer
+):
+    # First, check that the parameter counts match
+    keras_nlp_params = keras_nlp_model.count_params()
+    hf_params = hf_model.num_parameters()
+    assert keras_nlp_params == hf_params
+
+    # Then compare the outputs of the two models
+    hf_outputs = hf_model(
+        **hf_model_tokenizer(["What is Keras?"], return_tensors="pt")
+    )
+    hf_output_logits = hf_outputs.logits.detach().cpu().numpy()
+
+    keras_nlp_preprocessor = Llama3CausalLMPreprocessor(keras_nlp_tokenizer)
+    keras_nlp_output = keras_nlp_model(
+        keras_nlp_preprocessor(["What is Keras?"], sequence_length=5)[0]
+    )
+    keras_nlp_logits = keras_nlp_model.token_embedding(
+        keras_nlp_output, reverse=True
+    )
+    keras_nlp_logits = ops.convert_to_numpy(keras_nlp_logits)
+
+    # High tolerance since bfloat16 is used as the default dtype for Llama
+    try:
+        np.testing.assert_allclose(
+            keras_nlp_logits, hf_output_logits, atol=1e-4
+        )
+    except AssertionError as err:
+        print("\n")
+        print(traceback.format_exc())
+        print(err.args[0])
+        print("\n")
+
+
+def test_tokenizer(keras_nlp_tokenizer, hf_tokenizer):
+    hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt")
+    hf_output = hf_output["input_ids"].detach().cpu().numpy()
+    keras_nlp_preprocessor = Llama3CausalLMPreprocessor(keras_nlp_tokenizer)
+    keras_nlp_output = keras_nlp_preprocessor(
+        ["What is Keras?"], sequence_length=5
+    )
+    keras_nlp_output = ops.convert_to_numpy(keras_nlp_output[0]["token_ids"])
+
+    np.testing.assert_equal(keras_nlp_output, hf_output)
+
+
+def main(_):
+    # === Get the preset name ===
+    if FLAGS.preset not in PRESET_MAP.keys():
+        raise ValueError(
+            f"Invalid preset {FLAGS.preset}. Must be one "
+            f"of {','.join(PRESET_MAP.keys())}"
+        )
+    preset = FLAGS.preset
+    hf_preset = PRESET_MAP[preset]
+
+    # === Load the Huggingface model ===
+    hf_model = LlamaForCausalLM.from_pretrained(
+        hf_preset, torch_dtype=torch.bfloat16
+    )
+    hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+    hf_model.eval()
+    print("\n-> Huggingface model and tokenizer loaded")
+
+    # === Load the KerasNLP model ===
+    backbone_kwargs = dict(
+        vocabulary_size=hf_model.config.vocab_size,
+        hidden_dim=hf_model.config.hidden_size,
+        num_layers=hf_model.config.num_hidden_layers,
+        num_query_heads=hf_model.config.num_attention_heads,
+        num_key_value_heads=hf_model.config.num_key_value_heads,
+        intermediate_dim=hf_model.config.intermediate_size,
+        layer_norm_epsilon=hf_model.config.rms_norm_eps,
+        rope_max_wavelength=hf_model.config.rope_theta,
+        dtype="bfloat16",
+    )
+    keras_nlp_model = Llama3Backbone(**backbone_kwargs)
+
+    # === Get the tokenizer from the Huggingface model ===
+    tokenizer_path = hf_hub_download(
+        "meta-llama/Meta-Llama-3-8B", "tokenizer.json", token=True
+    )
+    with open(tokenizer_path, "r") as tokenizer_file:
+        tokenizer_content = json.load(tokenizer_file)
+    vocabulary = hf_tokenizer.vocab
+    merges = tokenizer_content["model"]["merges"]
+    keras_nlp_tokenizer = Llama3Tokenizer(vocabulary, merges)
+    print("\n-> Keras 3 model and tokenizer loaded.")
+
+    # === Port the weights ===
+    convert_checkpoints(keras_nlp_model, hf_model)
+    print("\n-> Weight transfer done.")
+
+    # === Check that the models and tokenizers outputs match ===
+    test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
+    test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
+    print("\n-> Tests passed!")
+
+    keras_nlp_model.save_to_preset(preset)
+    print("\n-> Saved the model preset in float16")
+
+    # === Save the tokenizer ===
+    keras_nlp_tokenizer.save_to_preset(preset)
+    print("\n-> Saved the tokenizer")
+
+    # === Upload the preset ===
+    try:
+        uri = f"kaggle://keras/llama3/keras/{preset}"
+        upload_preset(uri, preset)
+        print("-> Uploaded the preset!")
+    except Exception:
+        print(
+            "-> Failed to upload the preset. Make sure you have the "
+            "correct premissions to upload and/or the page "
+            "you are pushing to exists."
+        )
+
+
+if __name__ == "__main__":
+    flags.mark_flag_as_required("preset")
+    app.run(main)
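Every attention `set_weights` call in `convert_checkpoints` follows the same pattern: transpose the HF `(out_features, in_features)` projection and split the head dimension out. A standalone NumPy sketch with toy (hypothetical) dimensions:

```python
import numpy as np

hidden_size = 8  # toy config, not the real 8B values
num_attention_heads = 4
head_dim = hidden_size // num_attention_heads

# HF stores q_proj.weight as (out_features, in_features).
hf_q_proj = np.zeros((hidden_size, hidden_size), dtype="float32")

# The KerasNLP query kernel expects (hidden_size, num_heads, head_dim):
# transpose to (in, out), then split `out` into heads.
keras_kernel = hf_q_proj.T.reshape(hidden_size, num_attention_heads, head_dim)
assert keras_kernel.shape == (8, 4, 2)
```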
diff --git a/tools/checkpoint_conversion/convert_llama_checkpoints.py b/tools/checkpoint_conversion/convert_llama_checkpoints.py
index 4e127b2c7d..27be78901b 100644
--- a/tools/checkpoint_conversion/convert_llama_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_llama_checkpoints.py
@@ -11,23 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import gc
-import os
-import shutil
-import tempfile
 import traceback
 
 import numpy as np
+import torch
 from absl import app
 from absl import flags
 from keras import ops
 from transformers import AutoTokenizer
 from transformers import LlamaForCausalLM
 
+from keras_nlp import upload_preset
 from keras_nlp.models import LlamaBackbone
 from keras_nlp.models import LlamaCausalLMPreprocessor
 from keras_nlp.models import LlamaTokenizer
-from keras_nlp.utils.preset_utils import save_to_preset
 
 PRESET_MAP = {
     "llama2_7b_en": "meta-llama/Llama-2-7b-hf",
@@ -224,65 +221,53 @@ def main(_):
     preset = FLAGS.preset
     hf_preset = PRESET_MAP[preset]
 
-    # === Create the temporary save directories ===
-    temp_dir = tempfile.mkdtemp()
-
-    try:
-        # === Load the Huggingface model ===
-        hf_model = LlamaForCausalLM.from_pretrained(hf_preset)
-        hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
-        hf_model.eval()
-        print("\n-> Huggingface model and tokenizer loaded")
-
-        # === Load the KerasNLP model ===
-        backbone_kwargs = dict(
-            vocabulary_size=hf_model.config.vocab_size,
-            hidden_dim=hf_model.config.hidden_size,
-            num_layers=hf_model.config.num_hidden_layers,
-            num_query_heads=hf_model.config.num_attention_heads,
-            num_key_value_heads=hf_model.config.num_key_value_heads,
-            intermediate_dim=hf_model.config.intermediate_size,
-            layer_norm_epsilon=hf_model.config.rms_norm_eps,
-            rope_max_wavelength=hf_model.config.rope_theta,
-            dtype="float32",
-        )
-        keras_nlp_model = LlamaBackbone(**backbone_kwargs)
+    # === Load the Huggingface model ===
+    hf_model = LlamaForCausalLM.from_pretrained(
+        hf_preset, torch_dtype=torch.bfloat16
+    )
+    hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+    hf_model.eval()
+    print("\n-> Huggingface model and tokenizer loaded")
 
-        # === Get the tokenizer from the Huggingface model ===
-        tokenizer_path = hf_tokenizer.vocab_file
-        keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path)
-        print("\n-> Keras 3 model and tokenizer loaded.")
+    # === Load the KerasNLP model ===
+    backbone_kwargs = dict(
+        vocabulary_size=hf_model.config.vocab_size,
+        hidden_dim=hf_model.config.hidden_size,
+        num_layers=hf_model.config.num_hidden_layers,
+        num_query_heads=hf_model.config.num_attention_heads,
+        num_key_value_heads=hf_model.config.num_key_value_heads,
+        intermediate_dim=hf_model.config.intermediate_size,
+        layer_norm_epsilon=hf_model.config.rms_norm_eps,
+        rope_max_wavelength=hf_model.config.rope_theta,
+        dtype="bfloat16",
+    )
+    keras_nlp_model = LlamaBackbone(**backbone_kwargs)
 
-        # === Port the weights ===
-        convert_checkpoints(keras_nlp_model, hf_model)
-        print("\n-> Weight transfer done.")
+    # === Get the tokenizer from the Huggingface model ===
+    tokenizer_path = hf_tokenizer.vocab_file
+    keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path)
+    print("\n-> Keras 3 model and tokenizer loaded.")
 
-        # === Check that the models and tokenizers outputs match ===
-        test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
-        test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
-        print("\n-> Tests passed!")
+    # === Port the weights ===
+    convert_checkpoints(keras_nlp_model, hf_model)
+    print("\n-> Weight transfer done.")
 
-        # === Save the model weights in float32 format ===
-        keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
-        print("\n-> Saved the model weights in float32")
+    # === Check that the models and tokenizers outputs match ===
+    test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
+    test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
+    print("\n-> Tests passed!")
 
-        del keras_nlp_model, hf_model
-        gc.collect()
+    keras_nlp_model.save_to_preset(preset)
+    print("\n-> Saved the model preset in bfloat16")
 
-        # === Save the weights again in float16 ===
-        backbone_kwargs["dtype"] = "float16"
-        keras_nlp_model = LlamaBackbone(**backbone_kwargs)
-        keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
-        save_to_preset(keras_nlp_model, preset)
-        print("\n-> Saved the model preset in float16")
+    # === Save the tokenizer ===
+    keras_nlp_tokenizer.save_to_preset(preset)
+    print("\n-> Saved the tokenizer")
 
-        # === Save the tokenizer ===
-        save_to_preset(
-            keras_nlp_tokenizer, preset, config_filename="tokenizer.json"
-        )
-        print("\n-> Saved the tokenizer")
-    finally:
-        shutil.rmtree(temp_dir)
+    # === Upload the preset ===
+    uri = f"kaggle://keras/llama2/keras/{preset}"
+    upload_preset(uri, preset)
+    print("-> Uploaded the preset!")
 
 
 if __name__ == "__main__":