-
Notifications
You must be signed in to change notification settings - Fork 251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add LLaMA 3 tokenizer and preset #1584
Merged
mattdangerw
merged 10 commits into
keras-team:master
from
tirthasheshpatel:llama3-preset
May 17, 2024
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
a546dd9
Add LLaMA 3 tokenizer and preset
tirthasheshpatel efe00e5
Add a LLaMA 3 backbone and correct presets
tirthasheshpatel a7a85e7
Add docs for LLaMA 3 backbone
tirthasheshpatel 51177cd
Merge branch 'master' of github.com:keras-team/keras-nlp into llama3-…
tirthasheshpatel b78b885
Fix lint failures
tirthasheshpatel 7817cad
Fix the checkpointing scripts
tirthasheshpatel d407d9c
Add tests for all the components
tirthasheshpatel 74d7f28
Merge branch 'master' of github.com:keras-team/keras-nlp into llama3-…
tirthasheshpatel de427bb
Run shell/api_gen.sh
tirthasheshpatel 3fcf0bf
Address review comments; run api_gen.sh
tirthasheshpatel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone | ||
from keras_nlp.src.models.llama3.llama3_presets import backbone_presets | ||
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer | ||
from keras_nlp.src.utils.preset_utils import register_presets | ||
|
||
register_presets(backbone_presets, (Llama3Backbone, Llama3Tokenizer)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_nlp.src.api_export import keras_nlp_export | ||
from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone | ||
|
||
|
||
# LLaMA 3 shares the same architecture as its predecessors | ||
# So, we simply create an alias for API consistency | ||
@keras_nlp_export("keras_nlp.models.Llama3Backbone") | ||
class Llama3Backbone(LlamaBackbone): | ||
""" | ||
The Llama Transformer core architecture with hyperparameters. | ||
|
||
This network implements a Transformer-based decoder network, | ||
Llama, as described in | ||
["Llama 7B"](https://arxiv.org/pdf/2310.06825.pdf). | ||
It includes the embedding lookups and transformer layers. | ||
|
||
The default constructor gives a fully customizable, randomly initialized | ||
Llama model with any number of layers, heads, and embedding | ||
dimensions. To load preset architectures and weights, use the `from_preset` | ||
constructor. | ||
|
||
Args: | ||
vocabulary_size (int): The size of the token vocabulary. | ||
num_layers (int): The number of transformer layers. | ||
num_query_heads (int): The number of query attention heads for | ||
each transformer. | ||
hidden_dim (int): The size of the transformer encoding and pooling layers. | ||
intermediate_dim (int): The output dimension of the first Dense layer in a | ||
three-layer feedforward network for each transformer. | ||
num_key_value_heads (int): The number of key and value attention heads for | ||
each transformer. | ||
rope_max_wavelength (int, optional): The maximum angular wavelength of the | ||
sine/cosine curves, for rotary embeddings. Defaults to `10000`. | ||
rope_scaling_factor (float, optional): The scaling factor for calculation | ||
of roatary embedding. Defaults to `1.0`. | ||
layer_norm_epsilon (float, optional): Epsilon for the layer normalization | ||
layers in the transformer decoder. Defaults to `1e-6`. | ||
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use | ||
for model computations and weights. Note that some computations, | ||
such as softmax and layer normalization, will always be done at | ||
float32 precision regardless of dtype. | ||
|
||
Examples: | ||
|
||
```python | ||
input_data = { | ||
"token_ids": np.ones(shape=(1, 12), dtype="int32"), | ||
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), | ||
} | ||
|
||
# Pretrained Llama decoder. | ||
model = keras_nlp.models.Llama3Backbone.from_preset("llama3_8b_en") | ||
model(input_data) | ||
|
||
# Randomly initialized Llama decoder with custom config. | ||
model = keras_nlp.models.Llama3Backbone( | ||
vocabulary_size=10, | ||
hidden_dim=512, | ||
num_layers=2, | ||
num_query_heads=32, | ||
num_key_value_heads=8, | ||
intermediate_dim=1024, | ||
layer_norm_epsilon=1e-6, | ||
dtype="float32" | ||
) | ||
model(input_data) | ||
``` | ||
""" | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone | ||
from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( | ||
Llama3CausalLMPreprocessor, | ||
) | ||
from keras_nlp.src.models.llama.llama_causal_lm import LlamaCausalLM | ||
|
||
|
||
class Llama3CausalLM(LlamaCausalLM): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems inconsistent with the rest of the classes. Should we copy a docstring over here as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
"""An end-to-end Llama 3 model for causal language modeling. | ||
|
||
A causal language model (LM) predicts the next token based on previous | ||
tokens. This task setup can be used to train the model unsupervised on | ||
plain text input, or to autoregressively generate plain text similar to | ||
the data used for training. This task can be used for pre-training or | ||
fine-tuning a LLaMA 3 model, simply by calling `fit()`. | ||
|
||
This model has a `generate()` method, which generates text based on a | ||
prompt. The generation strategy used is controlled by an additional | ||
`sampler` argument on `compile()`. You can recompile the model with | ||
different `keras_nlp.samplers` objects to control the generation. By | ||
default, `"top_k"` sampling will be used. | ||
|
||
Args: | ||
backbone: A `keras_nlp.models.Llama3Backbone` instance. | ||
preprocessor: A `keras_nlp.models.Llama3CausalLMPreprocessor` or `None`. | ||
If `None`, this model will not apply preprocessing, and inputs | ||
should be preprocessed before calling the model. | ||
""" | ||
|
||
backbone_cls = Llama3Backbone | ||
preprocessor_cls = Llama3CausalLMPreprocessor |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also for consistency, can we copy some tests for llama3 backbone too? a preset numerics test we should probably annotate with extra_large, given how big our presets are. or just leave it off if it will be too much hassle.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should. I don't have compute to locally run the preset tests so feel free to add them in a follow-up if you have time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do as a follow up!