Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix paligemma checkpoint conversion script #1931

Merged
41 changes: 38 additions & 3 deletions tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
"""
Convert PaliGemma `.npz` checkpoints into KerasHub presets.

Example usage:
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-mix-224.npz \
--image_size=224 --checkpoint_name=pali_gemma_3b_mix_224
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-mix-448.npz \
--image_size=448 --checkpoint_name=pali_gemma_3b_mix_448
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-pt-224.npz \
--image_size=224 --checkpoint_name=pali_gemma_3b_224
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-pt-448.npz \
--image_size=448 --checkpoint_name=pali_gemma_3b_448
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-pt-896.npz \
--image_size=896 --checkpoint_name=pali_gemma_3b_896
"""

import argparse
import os

Expand All @@ -15,6 +33,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)

os.environ["KERAS_BACKEND"] = "jax"

Expand Down Expand Up @@ -308,15 +329,27 @@ def main(args):
pali_gemma_backbone_config = {
"vit_num_layers": 27,
"vit_hidden_dim": 1152,
"vocabulary_size": 257152,
"image_size": args.image_size,
"num_layers": 18,
"num_query_heads": 8,
"num_key_value_heads": 1,
"hidden_dim": 2048,
"intermediate_dim": 32768,
"head_dim": 256,
"vit_patch_size": 14,
"vit_num_heads": 16,
}
pg_image_converter = PaliGemmaImageConverter(
image_size=(args.image_size, args.image_size),
scale=1.0 / 127.5,
offset=-1,
)
tokenizer = PaliGemmaTokenizer(
proto="vocabulary.spm",
)
pg_presprocessor = PaliGemmaCausalLMPreprocessor(
image_converter=pg_image_converter
tokenizer=tokenizer, image_converter=pg_image_converter
)
pg_backbone = PaliGemmaBackbone(**pali_gemma_backbone_config)
keras_model = PaliGemmaCausalLM(
Expand All @@ -325,8 +358,10 @@ def main(args):
# The weights can come from Kaggle, or from a local directory path.
weights = np.load(args.weights_path)
jax_weights = get_weights_as_numpy(weights, **pali_gemma_backbone_config)
keras_model = convert_pali_gemma_weights(
keras_model, jax_weights["params"], **pali_gemma_backbone_config
keras_model.backbone = convert_pali_gemma_weights(
keras_model.backbone,
jax_weights["params"],
**pali_gemma_backbone_config,
)
# Specify preset name
keras_model.save_to_preset(args.checkpoint_name)
Expand Down
Loading