Skip to content

Commit

Permalink
Add LLaMA 3 tokenizer and preset (#1584)
Browse files Browse the repository at this point in the history
* Add LLaMA 3 tokenizer and preset

* Add a LLaMA 3 backbone and correct presets

* Add docs for LLaMA 3 backbone

[skip ci]

* Fix lint failures

* Fix the checkpointing scripts

* Add tests for all the components

* Run shell/api_gen.sh

* Address review comments; run api_gen.sh
  • Loading branch information
tirthasheshpatel authored May 17, 2024
1 parent a675aeb commit 294304b
Show file tree
Hide file tree
Showing 16 changed files with 1,176 additions and 73 deletions.
6 changes: 6 additions & 0 deletions keras_nlp/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@
GPTNeoXPreprocessor,
)
from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
Llama3CausalLMPreprocessor,
)
from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone
from keras_nlp.src.models.llama.llama_causal_lm import LlamaCausalLM
from keras_nlp.src.models.llama.llama_causal_lm_preprocessor import (
Expand Down
12 changes: 3 additions & 9 deletions keras_nlp/src/models/llama/llama_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from keras_nlp.src.models.llama.llama_causal_lm_preprocessor import (
LlamaCausalLMPreprocessor,
)
from keras_nlp.src.utils.python_utils import classproperty
from keras_nlp.src.utils.tensor_utils import any_equal


Expand All @@ -46,6 +45,9 @@ class LlamaCausalLM(CausalLM):
should be preprocessed before calling the model.
"""

backbone_cls = LlamaBackbone
preprocessor_cls = LlamaCausalLMPreprocessor

def __init__(self, backbone, preprocessor=None, **kwargs):
# === Layers ===
self.backbone = backbone
Expand All @@ -61,14 +63,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
**kwargs,
)

@classproperty
def backbone_cls(cls):
return LlamaBackbone

@classproperty
def preprocessor_cls(cls):
return LlamaCausalLMPreprocessor

def call_with_cache(
self,
token_ids,
Expand Down
7 changes: 0 additions & 7 deletions keras_nlp/src/models/llama/llama_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.llama.llama_presets import backbone_presets
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
SentencePieceTokenizer,
)
from keras_nlp.src.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.LlamaTokenizer")
Expand Down Expand Up @@ -85,7 +82,3 @@ def set_proto(self, proto):
self.start_token_id = None
self.end_token_id = None
self.pad_token_id = None

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
20 changes: 20 additions & 0 deletions keras_nlp/src/models/llama3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_presets import backbone_presets
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (Llama3Backbone, Llama3Tokenizer))
84 changes: 84 additions & 0 deletions keras_nlp/src/models/llama3/llama3_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone


# LLaMA 3 shares the same architecture as its predecessors
# So, we simply create an alias for API consistency
@keras_nlp_export("keras_nlp.models.Llama3Backbone")
class Llama3Backbone(LlamaBackbone):
"""
The Llama Transformer core architecture with hyperparameters.
This network implements a Transformer-based decoder network,
Llama, as described in
["Llama 7B"](https://arxiv.org/pdf/2310.06825.pdf).
It includes the embedding lookups and transformer layers.
The default constructor gives a fully customizable, randomly initialized
Llama model with any number of layers, heads, and embedding
dimensions. To load preset architectures and weights, use the `from_preset`
constructor.
Args:
vocabulary_size (int): The size of the token vocabulary.
num_layers (int): The number of transformer layers.
num_query_heads (int): The number of query attention heads for
each transformer.
hidden_dim (int): The size of the transformer encoding and pooling layers.
intermediate_dim (int): The output dimension of the first Dense layer in a
three-layer feedforward network for each transformer.
num_key_value_heads (int): The number of key and value attention heads for
each transformer.
rope_max_wavelength (int, optional): The maximum angular wavelength of the
sine/cosine curves, for rotary embeddings. Defaults to `10000`.
rope_scaling_factor (float, optional): The scaling factor for calculation
of roatary embedding. Defaults to `1.0`.
layer_norm_epsilon (float, optional): Epsilon for the layer normalization
layers in the transformer decoder. Defaults to `1e-6`.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.
Examples:
```python
input_data = {
"token_ids": np.ones(shape=(1, 12), dtype="int32"),
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}
# Pretrained Llama decoder.
model = keras_nlp.models.Llama3Backbone.from_preset("llama3_8b_en")
model(input_data)
# Randomly initialized Llama decoder with custom config.
model = keras_nlp.models.Llama3Backbone(
vocabulary_size=10,
hidden_dim=512,
num_layers=2,
num_query_heads=32,
num_key_value_heads=8,
intermediate_dim=1024,
layer_norm_epsilon=1e-6,
dtype="float32"
)
model(input_data)
```
"""

pass
44 changes: 44 additions & 0 deletions keras_nlp/src/models/llama3/llama3_causal_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
Llama3CausalLMPreprocessor,
)
from keras_nlp.src.models.llama.llama_causal_lm import LlamaCausalLM


class Llama3CausalLM(LlamaCausalLM):
"""An end-to-end Llama 3 model for causal language modeling.
A causal language model (LM) predicts the next token based on previous
tokens. This task setup can be used to train the model unsupervised on
plain text input, or to autoregressively generate plain text similar to
the data used for training. This task can be used for pre-training or
fine-tuning a LLaMA 3 model, simply by calling `fit()`.
This model has a `generate()` method, which generates text based on a
prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_nlp.samplers` objects to control the generation. By
default, `"top_k"` sampling will be used.
Args:
backbone: A `keras_nlp.models.Llama3Backbone` instance.
preprocessor: A `keras_nlp.models.Llama3CausalLMPreprocessor` or `None`.
If `None`, this model will not apply preprocessing, and inputs
should be preprocessed before calling the model.
"""

backbone_cls = Llama3Backbone
preprocessor_cls = Llama3CausalLMPreprocessor
Loading

0 comments on commit 294304b

Please sign in to comment.