diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index cdd50670f3..02f9c1222f 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -75,6 +75,8 @@
 )
 from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
 from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
+from keras_nlp.models.falcon.falcon_backbone import FalconBackbone
+from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone
 from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM
 from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import (
diff --git a/keras_nlp/models/falcon/falcon_presets.py b/keras_nlp/models/falcon/falcon_presets.py
new file mode 100644
index 0000000000..b0bb6aa54e
--- /dev/null
+++ b/keras_nlp/models/falcon/falcon_presets.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Falcon model preset configurations."""
+
+backbone_presets = {
+    "falcon_refinedweb_1b_en": {
+        "metadata": {
+            "description": (
+                "24-layer Falcon model (Falcon with 1B parameters), trained "
+                "on 350B tokens of the RefinedWeb dataset."
+            ),
+            "params": 1311625216,
+            "official_name": "Falcon",
+            "path": "falcon",
+            "model_card": "https://huggingface.co/tiiuae/falcon-rw-1b",
+        },
+        "kaggle_handle": "kaggle://keras/falcon/keras/falcon_refinedweb_1b_en/1",
+    },
+}
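Editor's note: once the Kaggle asset behind `kaggle_handle` is published, this preset should be loadable through the standard KerasNLP entry points. A minimal sketch, not part of the diff, assuming the asset is uploaded and that `FalconBackbone` registers these presets the same way other backbones do:

```python
import keras_nlp

# Downloads the config, weights, and vocabulary registered under the
# "kaggle://keras/falcon/keras/falcon_refinedweb_1b_en/1" handle above.
backbone = keras_nlp.models.FalconBackbone.from_preset("falcon_refinedweb_1b_en")
tokenizer = keras_nlp.models.FalconTokenizer.from_preset("falcon_refinedweb_1b_en")
```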
diff --git a/keras_nlp/models/falcon/falcon_tokenizer.py b/keras_nlp/models/falcon/falcon_tokenizer.py
new file mode 100644
index 0000000000..3201d27a63
--- /dev/null
+++ b/keras_nlp/models/falcon/falcon_tokenizer.py
@@ -0,0 +1,117 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.models.falcon.falcon_presets import backbone_presets
+from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.FalconTokenizer")
+class FalconTokenizer(BytePairTokenizer):
+    """Falcon tokenizer based on BytePairTokenizer.
+
+    This tokenizer class will tokenize raw strings into integer sequences and
+    is based on `keras_nlp.tokenizers.BytePairTokenizer`. Unlike the
+    underlying tokenizer, it checks for all special tokens needed by Falcon
+    models and provides a `from_preset()` method to automatically download
+    a matching vocabulary for a Falcon preset.
+
+    This tokenizer does not provide truncation or padding of inputs.
+
+    If input is a batch of strings (rank > 0), the layer will output a
+    `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+    If input is a scalar string (rank == 0), the layer will output a dense
+    `tf.Tensor` with static shape `[None]`.
+
+    Args:
+        vocabulary: string or dict, maps token to integer ids. If it is a
+            string, it should be the file path to a json file.
+        merges: string or list, contains the merge rule. If it is a string,
+            it should be the file path to merge rules. The merge rule file
+            should have one merge rule per line. Every merge rule contains
+            merge entities separated by a space.
+
+    Examples:
+
+    ```python
+    # Unbatched input.
+    tokenizer = keras_nlp.models.FalconTokenizer.from_preset("falcon_refinedweb_1b_en")
+    tokenizer("The quick brown fox jumped.")
+
+    # Batched input.
+    tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+    # Detokenization.
+    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+
+    # Custom vocabulary.
+    vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6}
+    merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+    merges += ["Ġ f", "o x", "Ġf ox"]
+    tokenizer = keras_nlp.models.FalconTokenizer(vocabulary=vocab, merges=merges)
+    tokenizer("a quick fox.")
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        **kwargs,
+    ):
+        # Falcon uses "<|endoftext|>" as both its start and end token.
+        self.end_token = self.start_token = "<|endoftext|>"
+
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            unsplittable_tokens=[self.end_token],
+            **kwargs,
+        )
+
+    def set_vocabulary_and_merges(self, vocabulary, merges):
+        super().set_vocabulary_and_merges(vocabulary, merges)
+
+        if vocabulary is not None:
+            # Check for necessary special tokens.
+            if self.end_token not in self.get_vocabulary():
+                raise ValueError(
+                    f"Cannot find token `'{self.end_token}'` in the provided "
+                    f"`vocabulary`. Please provide `'{self.end_token}'` in "
+                    "your `vocabulary` or use a pretrained `vocabulary` name."
+                )
+
+            self.end_token_id = self.token_to_id(self.end_token)
+            self.start_token_id = self.end_token_id
+            self.pad_token_id = 0
+        else:
+            self.end_token_id = None
+            self.start_token_id = None
+            self.pad_token_id = None
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
+
+    def get_config(self):
+        config = super().get_config()
+        # In the constructor, we pass the list of special tokens to the
+        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
+        # delete it from the config here.
+        del config["unsplittable_tokens"]
+        return config
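Editor's note: a quick way to see the special-token bookkeeping above in action is the toy vocabulary from the docstring. A minimal sketch (illustrative vocab and merges, not a real Falcon vocabulary):

```python
import keras_nlp

vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]
tokenizer = keras_nlp.models.FalconTokenizer(vocabulary=vocab, merges=merges)

# "<|endoftext|>" is unsplittable and doubles as the start token;
# `pad_token_id` is hard-coded to 0.
assert tokenizer.start_token_id == tokenizer.end_token_id == 0
assert tokenizer.pad_token_id == 0

# A vocabulary that lacks "<|endoftext|>" is rejected by
# `set_vocabulary_and_merges`.
try:
    keras_nlp.models.FalconTokenizer(vocabulary={"a": 0}, merges=[])
except ValueError as err:
    print(err)
```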
diff --git a/keras_nlp/models/falcon/falcon_tokenizer_test.py b/keras_nlp/models/falcon/falcon_tokenizer_test.py
new file mode 100644
index 0000000000..735bcac4b6
--- /dev/null
+++ b/keras_nlp/models/falcon/falcon_tokenizer_test.py
@@ -0,0 +1,62 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+class FalconTokenizerTest(TestCase):
+    def setUp(self):
+        self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
+        self.vocab += ["<|endoftext|>"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+        self.merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+        self.input_data = [
+            " airplane at airport<|endoftext|>",
+            " airplane airport",
+        ]
+
+    def test_tokenizer_basics(self):
+        self.run_preprocessing_layer_test(
+            cls=FalconTokenizer,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=[[2, 3, 4, 2, 5, 6], [2, 3, 2, 5]],
+        )
+
+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            FalconTokenizer(vocabulary=["a", "b", "c"], merges=[])
+
+    @pytest.mark.large
+    def test_smallest_preset(self):
+        self.run_preset_test(
+            cls=FalconTokenizer,
+            preset="falcon_refinedweb_1b_en",
+            input_data=["The quick brown fox."],
+            expected_output=[[464, 2068, 7586, 21831, 13]],
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in FalconTokenizer.presets:
+            self.run_preset_test(
+                cls=FalconTokenizer,
+                preset=preset,
+                input_data=self.input_data,
+            )
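Editor's note: the `large` and `extra_large` markers mean the preset tests above are skipped by default; assuming the repository's usual pytest flags, they run with something like:

```
pytest keras_nlp/models/falcon/falcon_tokenizer_test.py --run_large
```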
diff --git a/tools/checkpoint_conversion/convert_falcon_checkpoints.py b/tools/checkpoint_conversion/convert_falcon_checkpoints.py
index 90a06503dc..fdbdffd670 100644
--- a/tools/checkpoint_conversion/convert_falcon_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_falcon_checkpoints.py
@@ -11,51 +11,110 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Falcon weight conversion script.
+
+To run, install the CPU-only development environment and the Hugging Face
+libraries:
+```
+pip install -r requirements.txt
+pip install transformers huggingface_hub
+```
+
+Log in to Hugging Face:
+```
+huggingface-cli login
+```
+
+Finally, run this script to convert and validate the weights and save them as
+a local Keras preset:
+```
+python tools/checkpoint_conversion/convert_falcon_checkpoints.py \
+    --preset falcon_refinedweb_1b_en
+```
+"""
+
+import json
 import os
-import tempfile
-
-import keras
-import numpy as np
-import tensorflow as tf
-from transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
 
-from keras_nlp.models.falcon.falcon_backbone import FalconBackbone
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
-keras.config.disable_traceback_filtering()
+import absl  # noqa: E402
+import huggingface_hub  # noqa: E402
+import numpy as np  # noqa: E402
+import torch  # noqa: E402
+import transformers  # noqa: E402
+
+import keras_nlp  # noqa: E402
 
+PRESET_MAP = {
+    "falcon_refinedweb_1b_en": "tiiuae/falcon-rw-1b",
+}
 
-def convert_checkpoints(hf_model):
+EXTRACT_DIR = "./model"
+
+FLAGS = absl.flags.FLAGS
+absl.flags.DEFINE_string(
+    "preset",
+    "falcon_refinedweb_1b_en",
+    f'Must be one of {",".join(PRESET_MAP.keys())}.',
+)
+
+
+def download_hf_model(hf_model_name):
+    hf_model_dir = huggingface_hub.snapshot_download(
+        repo_id=hf_model_name,
+        allow_patterns=["*.json", "*.bin"],
+        ignore_patterns=["onnx/*"],
+        local_dir=EXTRACT_DIR,
+    )
+
+    return hf_model_dir
+
+
+def convert_model(hf_model):
     hf_config = hf_model.config.to_dict()
-    cfg = {}
-    cfg["vocabulary_size"] = hf_config["vocab_size"]
-    cfg["num_layers"] = hf_config["num_hidden_layers"]
-    cfg["num_attention_heads"] = hf_config["num_attention_heads"]
-    cfg["hidden_dim"] = hf_config["hidden_size"]
-    cfg["intermediate_dim"] = 4 * cfg["hidden_dim"]
-    cfg["feedforward_dropout_rate"] = hf_config["hidden_dropout"]
-    cfg["attention_dropout_rate"] = hf_config["attention_dropout"]
+    kwargs = {}
+    kwargs["vocabulary_size"] = hf_config["vocab_size"]
+    kwargs["num_layers"] = hf_config["num_hidden_layers"]
+    kwargs["num_attention_heads"] = hf_config["num_attention_heads"]
+    kwargs["hidden_dim"] = hf_config["hidden_size"]
+    kwargs["intermediate_dim"] = 4 * kwargs["hidden_dim"]
+    kwargs["feedforward_dropout_rate"] = hf_config["hidden_dropout"]
+    kwargs["attention_dropout_rate"] = hf_config["attention_dropout"]
 
-    keras_model = FalconBackbone(**cfg)
+    return keras_nlp.models.FalconBackbone(**kwargs)
 
+
+def convert_tokenizer(hf_model_dir):
+    tokenizer_file_path = os.path.join(hf_model_dir, "tokenizer.json")
+    with open(tokenizer_file_path) as tokenizer_file:
+        hf_tokenizer = json.load(tokenizer_file)
+
+    vocab = hf_tokenizer["model"]["vocab"]
+    merges = hf_tokenizer["model"]["merges"]
+    return keras_nlp.models.FalconTokenizer(vocabulary=vocab, merges=merges)
+
+
+def convert_weights(keras_model, hf_model):
+    hf_model.eval()
     hf_wts = hf_model.state_dict()
 
-    # transformer.word_embeddings.weight
+    # token_embedding.
     keras_model.get_layer("token_embedding").embeddings.assign(
-        hf_wts["transformer.word_embeddings.weight"]
+        hf_wts["word_embeddings.weight"]
    )
 
-    for i in range(keras_model.num_layers):
-        # split key query value
+    for ilayer in range(keras_model.num_layers):
+        # Split the fused QKV weights into query, key, and value.
         fused_qkv = (
-            hf_wts[f"transformer.h.{i}.self_attention.query_key_value.weight"]
+            hf_wts[f"h.{ilayer}.self_attention.query_key_value.weight"]
             .numpy()
             .T
         )
         seq_length, _ = fused_qkv.shape
-        head_dim = cfg["hidden_dim"] // cfg["num_attention_heads"]
+        head_dim = keras_model.hidden_dim // keras_model.num_attention_heads
         fused_qkv = fused_qkv.reshape(
-            seq_length, cfg["num_attention_heads"], 3, head_dim
+            seq_length, keras_model.num_attention_heads, 3, head_dim
        )
         query, key, value = (
             fused_qkv[..., 0, :],
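Editor's note: the fused-QKV unpacking above is the subtle step in this conversion, and the `seq_length` variable actually holds the hidden dimension (the transposed kernel has shape `[hidden_dim, 3 * hidden_dim]`). Falcon stores one fused attention projection per layer whose columns are laid out per head as `[query, key, value]`, so reshaping to `[hidden_dim, num_heads, 3, head_dim]` lets axis 2 index Q/K/V. A minimal numpy sketch of the same slicing, with toy sizes rather than the real checkpoint:

```python
import numpy as np

hidden_dim, num_heads = 8, 2  # falcon-rw-1b uses 2048 and 32.
head_dim = hidden_dim // num_heads

# HF stores the fused kernel as [3 * hidden_dim, hidden_dim]; transpose it
# to [hidden_dim, 3 * hidden_dim] as the script does.
fused = np.arange(3 * hidden_dim * hidden_dim)
fused = fused.reshape(3 * hidden_dim, hidden_dim).T

# After this reshape, axis 2 separates the query, key, and value slices of
# each head.
fused = fused.reshape(hidden_dim, num_heads, 3, head_dim)
query, key, value = fused[..., 0, :], fused[..., 1, :], fused[..., 2, :]
assert query.shape == (hidden_dim, num_heads, head_dim)
```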
@@ -64,9 +123,11 @@ def convert_checkpoints(hf_model):
         )
 
         fused_bias = hf_wts[
-            f"transformer.h.{i}.self_attention.query_key_value.bias"
+            f"h.{ilayer}.self_attention.query_key_value.bias"
         ].numpy()
-        fused_bias = fused_bias.reshape(cfg["num_attention_heads"], 3, head_dim)
+        fused_bias = fused_bias.reshape(
+            keras_model.num_attention_heads, 3, head_dim
+        )
         query_bias, key_bias, value_bias = (
             fused_bias[..., 0, :],
             fused_bias[..., 1, :],
@@ -74,132 +135,118 @@ def convert_checkpoints(hf_model):
         )
 
         # TODO: check if bias is true before assigning bias.
-        # transformer.h.0.self_attention.query_key_value.weight
-        # transformer.h.0.self_attention.query_key_value.bias
+        # Attention/query.
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).attention_layer.query_dense.kernel.assign(query)
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).attention_layer.query_dense.bias.assign(query_bias)
 
+        # Attention/key.
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).attention_layer.key_dense.kernel.assign(key)
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).attention_layer.key_dense.bias.assign(key_bias)
 
+        # Attention/value.
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).attention_layer.value_dense.kernel.assign(value)
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).attention_layer.value_dense.bias.assign(value_bias)
 
-        # transformer.h.0.self_attention.dense.weight
-        # transformer.h.0.self_attention.dense.bias
+        # Attention/dense.
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).attention_layer.output_dense.kernel.assign(
-            hf_wts[f"transformer.h.{i}.self_attention.dense.weight"].T.numpy()
+            hf_wts[f"h.{ilayer}.self_attention.dense.weight"].T.numpy()
         )
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).attention_layer.output_dense.bias.assign(
-            hf_wts[f"transformer.h.{i}.self_attention.dense.bias"].numpy()
+            hf_wts[f"h.{ilayer}.self_attention.dense.bias"].numpy()
         )
 
-        # transformer.h.0.mlp.dense_h_to_4h.weight
-        # transformer.h.0.mlp.dense_h_to_4h.bias
+        # MLP/dense_h_to_4h.
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).dense_h_to_4h.kernel.assign(
-            hf_wts[f"transformer.h.{i}.mlp.dense_h_to_4h.weight"].T.numpy()
+            hf_wts[f"h.{ilayer}.mlp.dense_h_to_4h.weight"].T.numpy()
        )
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).dense_h_to_4h.bias.assign(
-            hf_wts[f"transformer.h.{i}.mlp.dense_h_to_4h.bias"].numpy()
+            hf_wts[f"h.{ilayer}.mlp.dense_h_to_4h.bias"].numpy()
         )
 
-        # transformer.h.0.mlp.dense_4h_to_h.weight
-        # transformer.h.0.mlp.dense_4h_to_h.bias
+        # MLP/dense_4h_to_h.
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).dense_4h_to_h.kernel.assign(
-            hf_wts[f"transformer.h.{i}.mlp.dense_4h_to_h.weight"].T.numpy()
+            hf_wts[f"h.{ilayer}.mlp.dense_4h_to_h.weight"].T.numpy()
         )
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).dense_4h_to_h.bias.assign(
-            hf_wts[f"transformer.h.{i}.mlp.dense_4h_to_h.bias"].numpy()
+            hf_wts[f"h.{ilayer}.mlp.dense_4h_to_h.bias"].numpy()
         )
 
-        # transformer.h.0.input_layernorm.weight
-        # transformer.h.0.input_layernorm.bias
+        # input_layernorm.
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).input_layernorm.gamma.assign(
-            hf_wts[f"transformer.h.{i}.input_layernorm.weight"]
+            hf_wts[f"h.{ilayer}.input_layernorm.weight"]
         )
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).input_layernorm.beta.assign(
-            hf_wts[f"transformer.h.{i}.input_layernorm.bias"]
+            hf_wts[f"h.{ilayer}.input_layernorm.bias"]
         )
 
-        # transformer.h.0.post_attention_layernorm.weight
-        # transformer.h.0.post_attention_layernorm.bias
+        # post_attention_layernorm.
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).post_attention_layernorm.gamma.assign(
-            hf_wts[f"transformer.h.{i}.post_attention_layernorm.weight"].numpy()
+            hf_wts[f"h.{ilayer}.post_attention_layernorm.weight"].numpy()
         )
         keras_model.get_layer(
-            f"transformer_layer_{i}"
+            f"transformer_layer_{ilayer}"
         ).post_attention_layernorm.beta.assign(
-            hf_wts[f"transformer.h.{i}.post_attention_layernorm.bias"].numpy()
+            hf_wts[f"h.{ilayer}.post_attention_layernorm.bias"].numpy()
         )
 
-    # transformer.ln_f.weight
-    # transformer.ln_f.bias
+    # final_layernorm.
     keras_model.get_layer("final_layernorm").gamma.assign(
-        hf_wts["transformer.ln_f.weight"].numpy()
+        hf_wts["ln_f.weight"].numpy()
     )
     keras_model.get_layer("final_layernorm").beta.assign(
-        hf_wts["transformer.ln_f.bias"].numpy()
+        hf_wts["ln_f.bias"].numpy()
     )
 
-    # TODO: Assign lm_head weights for CausalLM.
-    # # lm_head.weight
-    # keras_model.get_layer("lm_head").kernel.assign(
-    #     hf_wts["lm_head.weight"].T.numpy()
-    # )
-
-    # Save the model.
-    print("Save KerasNLP model weights.")
-    temp_dir = tempfile.mkdtemp()
-    keras_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
-
-    return keras_model
-
-
-def check_output(keras_model, hf_model, hf_model_name):
-    sample_text = ["I am so happy today!"]
-    hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
-    hf_tokenizer.pad_token = hf_tokenizer.eos_token
-    hf_sample_input = hf_tokenizer(
-        sample_text, padding="max_length", return_tensors="pt"
-    )
-    sample_input = {
-        "token_ids": tf.constant(hf_sample_input["input_ids"].numpy()),
-        "padding_mask": tf.constant(hf_sample_input["attention_mask"].numpy()),
+def validate_output(
+    hf_model,
+    keras_model,
+    hf_tokenizer,
+    keras_tokenizer,
+):
+    input_str = ["the quick brown fox ran, galloped and jumped."]
+
+    # KerasNLP model.
+    token_ids = torch.tensor(keras_tokenizer(input_str))
+    padding_mask = token_ids != 3
+    keras_model_input = {
+        "token_ids": token_ids,
+        "padding_mask": padding_mask,
     }
-    print("token_ids: ", sample_input["token_ids"][0, :7])
-    print("padding_mask", sample_input["padding_mask"][0, :7])
+    keras_model_outputs = keras_model.predict(keras_model_input)
 
-    keras_output = keras_model.predict(sample_input)
+    # HuggingFace model.
+    hf_model_input = hf_tokenizer(input_str, return_tensors="pt")
 
     activation = {}
 
@@ -209,30 +256,52 @@ def hook(hf_model, input, output):
 
         return hook
 
-    hf_model.transformer.register_forward_hook(
-        get_activation("transformer.ln_f")
-    )
-    hf_model(**hf_sample_input)
-    hf_output = activation["transformer.ln_f"]
-    print("Keras shape: ", keras_output.shape)
-    print("HF shape: ", hf_output.shape)
-
-    print("KerasNLP output:", keras_output[0, 1, :5])
-    print("HF output:", hf_output[0, 1, :5])
-    print(
-        "Difference:",
-        np.mean(
-            abs(keras_output[:, :6, :] - hf_output.detach().numpy()[:, :6, :])
-        ),
-    )
+    hf_model.register_forward_hook(get_activation("ln_f"))
+    hf_model(**hf_model_input)
+    hf_model_outputs = activation["ln_f"].detach().numpy()
+
+    # Compare the outputs.
+    print("🔶 KerasNLP token ids:", keras_model_input["token_ids"])
+    print("🔶 HF token ids:", hf_model_input["input_ids"])
+    print("🔶 KerasNLP output:", keras_model_outputs[0, 1, :10])
+    print("🔶 HF output:", hf_model_outputs[0, 1, :10])
+    print(
+        "🔶 Difference:",
+        np.mean(np.abs(keras_model_outputs - hf_model_outputs)),
+    )
+
+
+def main(_):
+    preset = FLAGS.preset
+    print(f"✅ Converting {preset}")
+    hf_model_name = PRESET_MAP[preset]
+    hf_model_dir = download_hf_model(hf_model_name)
+    print("✅ Huggingface model downloaded from hub")
 
-def main():
-    hf_model_name = "tiiuae/falcon-rw-1b"
-    hf_model = AutoModelForCausalLM.from_pretrained(hf_model_name)
-    keras_model = convert_checkpoints(hf_model)
-    check_output(keras_model, hf_model, hf_model_name)
+    hf_model = transformers.FalconModel.from_pretrained(hf_model_dir)
+    # Falcon uses the GPT-2 tokenizer.
+    hf_tokenizer = transformers.GPT2TokenizerFast.from_pretrained(hf_model_dir)
+    print("✅ Huggingface model loaded")
+
+    keras_model = convert_model(hf_model)
+    keras_tokenizer = convert_tokenizer(hf_model_dir)
+    print("✅ Keras model loaded")
+
+    convert_weights(keras_model, hf_model)
+    print("✅ Weights converted")
+
+    validate_output(
+        hf_model,
+        keras_model,
+        hf_tokenizer,
+        keras_tokenizer,
+    )
+    print("✅ Numerics validated")
+
+    keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset)
+    keras_nlp.src.utils.preset_utils.save_to_preset(
+        keras_tokenizer, preset, config_filename="tokenizer.json"
+    )
+    print("✅ Preset saved")
 
 
 if __name__ == "__main__":
-    main()
+    absl.app.run(main)
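Editor's note: after a successful run, the preset lands in a local `falcon_refinedweb_1b_en/` directory. A minimal smoke test of the saved artifacts, mirroring the input pipeline of `validate_output` above and assuming `from_preset` accepts a local preset directory:

```python
import torch

import keras_nlp

backbone = keras_nlp.models.FalconBackbone.from_preset("./falcon_refinedweb_1b_en")
tokenizer = keras_nlp.models.FalconTokenizer.from_preset("./falcon_refinedweb_1b_en")

# Tokenize a single string and run one forward pass, as the script does.
token_ids = torch.tensor(tokenizer(["All that glitters is not gold."]))
outputs = backbone.predict(
    {
        "token_ids": token_ids,
        "padding_mask": torch.ones_like(token_ids, dtype=torch.bool),
    }
)
print(outputs.shape)  # (1, sequence_length, 2048)
```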