Fix Mistral memory consumption with JAX and default dtype bug (#1460)
tirthasheshpatel authored and mattdangerw committed Feb 27, 2024
1 parent 712f172 commit 47c1ab5
Showing 3 changed files with 72 additions and 123 deletions.
1 change: 1 addition & 0 deletions keras_nlp/models/mistral/mistral_causal_lm.py
@@ -190,6 +190,7 @@ def next(prompt, cache, index):
             mask=padding_mask,
             end_token_id=end_token_id,
             hidden_states=hidden_states,
+            model=self,
         )
 
         # Compute an output padding mask with the token ids we updated.
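Why this one-line change matters for memory: under JAX the generation loop is compiled, and weights that the traced function merely closes over can be baked into the compiled computation as constants, effectively holding a second copy of the model. Passing model=self lets the sampler feed the variables in as explicit arguments instead. A minimal sketch of the effect with plain jax.jit follows; the shapes and function names are illustrative, not KerasNLP's actual sampler code.

import jax
import jax.numpy as jnp

weights = jnp.zeros((4096, 4096))  # stand-in for model parameters

# Closure style: `weights` is captured at trace time and may be embedded
# in the compiled computation as a constant, costing an extra copy.
@jax.jit
def decode_step_closure(x):
    return x @ weights

# Argument style: `weights` remains a runtime input that shares the
# caller's buffer, so no duplicate is materialized.
@jax.jit
def decode_step_arg(x, w):
    return x @ w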
4 changes: 2 additions & 2 deletions keras_nlp/models/mistral/mistral_presets.py
@@ -23,7 +23,7 @@
             "path": "mistral",
             "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/3",
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6",
     },
     "mistral_instruct_7b_en": {
         "metadata": {
@@ -33,6 +33,6 @@
             "path": "mistral",
             "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/3",
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6",
     },
 }
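Bumping the kaggle_handle version (3 → 6) points both presets at the re-uploaded fixed weights, so users pick up the fix on their next download; no API change is involved. A minimal loading sketch, assuming a keras-nlp install with Kaggle access configured; the explicit dtype argument is an illustrative assumption, shown to avoid relying on the preset's saved default dtype.

import keras_nlp

# First use downloads the updated preset version from Kaggle.
causal_lm = keras_nlp.models.MistralCausalLM.from_preset(
    "mistral_7b_en",
    dtype="bfloat16",  # assumption: request half precision explicitly
)
print(causal_lm.generate("The capital of France is", max_length=32))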
190 changes: 69 additions & 121 deletions tools/checkpoint_conversion/convert_mistral_checkpoints.py
@@ -11,14 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import datetime
 import gc
-import json
 import os
-import pathlib
+import shutil
+import tempfile
 import traceback
 
-import keras
 import numpy as np
 import requests
 from absl import app
@@ -27,10 +25,10 @@
 from transformers import AutoTokenizer
 from transformers import MistralForCausalLM
 
-import keras_nlp
 from keras_nlp.models import MistralBackbone
 from keras_nlp.models import MistralCausalLMPreprocessor
 from keras_nlp.models import MistralTokenizer
+from keras_nlp.utils.preset_utils import save_to_preset
 
 PRESET_MAP = {
     "mistral_7b_en": "mistralai/Mistral-7B-v0.1",
@@ -227,124 +225,74 @@ def main(_):
     preset = FLAGS.preset
     hf_preset = PRESET_MAP[preset]
 
-    # === Create the save directories ===
-    model_dir = pathlib.Path(__file__).parent / f"{preset}"
-    tokenizer_dir = model_dir / "assets" / "tokenizer"
-    if not model_dir.exists():
-        os.makedirs(model_dir)
-    if not tokenizer_dir.exists():
-        os.makedirs(tokenizer_dir)
+    # === Create the temporary save directories ===
+    temp_dir = tempfile.mkdtemp()
 
-    # === Load the Huggingface model ===
-    hf_model = MistralForCausalLM.from_pretrained(hf_preset)
-    hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
-    hf_model.eval()
-    print("\n-> Huggingface model and tokenizer loaded")
-
-    # === Load the KerasNLP model ===
-    keras_nlp_config = dict(
-        vocabulary_size=hf_model.config.vocab_size,
-        hidden_dim=hf_model.config.hidden_size,
-        num_layers=hf_model.config.num_hidden_layers,
-        num_query_heads=hf_model.config.num_attention_heads,
-        num_key_value_heads=hf_model.config.num_key_value_heads,
-        intermediate_dim=hf_model.config.intermediate_size,
-        sliding_window=hf_model.config.sliding_window,
-        layer_norm_epsilon=hf_model.config.rms_norm_eps,
-        rope_max_wavelength=hf_model.config.rope_theta,
-        dtype="float32",
-    )
-    keras_nlp_model = MistralBackbone(**keras_nlp_config)
-
-    # === Download the tokenizer from Huggingface model card ===
-    spm_path = (
-        f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
-    )
-    response = requests.get(spm_path)
-    if not response.ok:
-        raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
-    tokenizer_path = tokenizer_dir / "vocabulary.spm"
-    with open(tokenizer_path, "wb") as tokenizer_file:
-        tokenizer_file.write(response.content)
-    keras_nlp_tokenizer = MistralTokenizer(str(tokenizer_path.absolute()))
-    print("\n-> Keras 3 model and tokenizer loaded.")
-
-    # === Port the weights ===
-    convert_checkpoints(keras_nlp_model, hf_model)
-    print("\n-> Weight transfer done.")
-
-    # === Check that the models and tokenizers outputs match ===
-    test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
-    test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
-    print("\n-> Tests passed!")
-
-    # === Save the model weights in float32 format ===
-    keras_nlp_model.save_weights(
-        str((model_dir / "model.weights.h5").absolute())
-    )
-    print("\n-> Saved the model weights in float16")
-
-    del keras_nlp_model, hf_model
-    gc.collect()
-
-    keras_nlp_config["dtype"] = "float16"
-
-    # === Save the weights again in float16 ===
-    keras_nlp_model = MistralBackbone(**keras_nlp_config)
-    keras_nlp_model.load_weights(
-        str((model_dir / "model.weights.h5").absolute())
-    )
-    keras_nlp_model.save_weights(
-        str((model_dir / "model.weights.h5").absolute())
-    )
-    print("-> Saved the model weights in float16")
-
-    # === Save the model config ===
-    keras_nlp_config["dtype"] = "bfloat16"
-    model_config = {
-        "module": "keras_nlp.src.models.mistral.mistral_backbone",
-        "class_name": "MistralBackbone",
-        "config": {**keras_nlp_config},
-        "registered_name": "keras_nlp>MistralBackbone",
-        "assets": [],
-        "weights": "model.weights.h5",
-    }
-    model_config_json = json.dumps(model_config)
-    with open(model_dir / "config.json", "w") as model_config_file:
-        model_config_file.write(model_config_json)
-    print("\n-> Saved model config")
-
-    # === Save the tokenizer config ===
-    tokenizer_config = {
-        "module": "keras_nlp.src.models.mistral.Mistral_tokenizer",
-        "class_name": "MistralTokenizer",
-        "config": {
-            "name": "mistral_tokenizer",
-            "trainable": True,
-            "dtype": "int32",
-            "proto": None,
-            "sequence_length": None,
-        },
-        "registered_name": "keras_nlp>MistralTokenizer",
-        "assets": ["assets/tokenizer/vocabulary.spm"],
-        "weights": None,
-    }
-    tokenizer_config_json = json.dumps(tokenizer_config)
-    with open(model_dir / "tokenizer.json", "w") as tokenizer_config_file:
-        tokenizer_config_file.write(tokenizer_config_json)
-    print("\n-> Saved tokenizer config")
+    try:
+        # === Load the Huggingface model ===
+        hf_model = MistralForCausalLM.from_pretrained(hf_preset)
+        hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+        hf_model.eval()
+        print("\n-> Huggingface model and tokenizer loaded")
+
+        # === Load the KerasNLP model ===
+        backbone_kwargs = dict(
+            vocabulary_size=hf_model.config.vocab_size,
+            hidden_dim=hf_model.config.hidden_size,
+            num_layers=hf_model.config.num_hidden_layers,
+            num_query_heads=hf_model.config.num_attention_heads,
+            num_key_value_heads=hf_model.config.num_key_value_heads,
+            intermediate_dim=hf_model.config.intermediate_size,
+            sliding_window=hf_model.config.sliding_window,
+            layer_norm_epsilon=hf_model.config.rms_norm_eps,
+            rope_max_wavelength=hf_model.config.rope_theta,
+            dtype="float32",
+        )
+        keras_nlp_model = MistralBackbone(**backbone_kwargs)
+
-
-    # === Save metadata ===
-    metadata_config = {
-        "keras_version": keras.__version__,
-        "keras_nlp_version": keras_nlp.__version__,
-        "parameter_count": keras_nlp_model.count_params(),
-        "date_saved": datetime.datetime.utcnow().strftime("%Y-%m-%d@%H:%M:%S"),
-    }
-    metadata_config_json = json.dumps(metadata_config)
-    with open(model_dir / "metadata.json", "w") as metadata_config_file:
-        metadata_config_file.write(metadata_config_json)
-    print("\n-> Saved metadata")
+        # === Download the tokenizer from Huggingface model card ===
+        spm_path = (
+            f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
+        )
+        response = requests.get(spm_path)
+        if not response.ok:
+            raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
+        tokenizer_path = os.path.join(temp_dir, "vocabulary.spm")
+        with open(tokenizer_path, "wb") as tokenizer_file:
+            tokenizer_file.write(response.content)
+        keras_nlp_tokenizer = MistralTokenizer(tokenizer_path)
+        print("\n-> Keras 3 model and tokenizer loaded.")
+
+        # === Port the weights ===
+        convert_checkpoints(keras_nlp_model, hf_model)
+        print("\n-> Weight transfer done.")
+
+        # === Check that the models and tokenizers outputs match ===
+        test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
+        test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
+        print("\n-> Tests passed!")
+
+        # === Save the model weights in float32 format ===
+        keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
+        print("\n-> Saved the model weights in float32")
+
+        del keras_nlp_model, hf_model
+        gc.collect()
+
+        # === Save the weights again in float16 ===
+        backbone_kwargs["dtype"] = "float16"
+        keras_nlp_model = MistralBackbone(**backbone_kwargs)
+        keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
+        save_to_preset(keras_nlp_model, preset)
+        print("\n-> Saved the model preset in float16")
+
+        # === Save the tokenizer ===
+        save_to_preset(
+            keras_nlp_tokenizer, preset, config_filename="tokenizer.json"
+        )
+        print("\n-> Saved the tokenizer")
+    finally:
+        shutil.rmtree(temp_dir)


 if __name__ == "__main__":
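A note on the new script structure: tempfile.mkdtemp() paired with try/finally guarantees the scratch directory, including the intermediate float32 model.weights.h5, is deleted even if conversion fails partway, so only the final float16 preset survives. The context-manager form below gives the same guarantee; a minimal sketch, assuming nothing outside the block needs the directory afterwards.

import tempfile

# The directory and everything staged inside it are removed automatically
# when the block exits, whether normally or via an exception.
with tempfile.TemporaryDirectory() as temp_dir:
    ...  # download the tokenizer and stage model.weights.h5 under temp_dir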
