Skip to content

Commit

Permalink
Fix paligemma checkpoint conversion script (keras-team#1931)
Browse files Browse the repository at this point in the history
* add back default image resizing

* fix bug in image converter

* fix paligemma checkpoint conversion file

* fix preset name

* remove debug code

* revert unintended changes
  • Loading branch information
divyashreepathihalli authored and ushareng committed Oct 24, 2024
1 parent e26eb9b commit f391856
Showing 1 changed file with 38 additions and 3 deletions.
41 changes: 38 additions & 3 deletions tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
"""
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-mix-224.npz \
--image_size=224 --checkpoint_name=pali_gemma_3b_mix_224
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-mix-448.npz \
--image_size=448 --checkpoint_name=pali_gemma_3b_mix_448
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-pt-224.npz \
--image_size=224 --checkpoint_name=pali_gemma_3b_224
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-pt-448.npz \
--image_size=448 --checkpoint_name=pali_gemma_3b_448
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-pt-896.npz \
--image_size=896 --checkpoint_name=pali_gemma_3b_896
"""

import argparse
import os

Expand All @@ -15,6 +33,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)

os.environ["KERAS_BACKEND"] = "jax"

Expand Down Expand Up @@ -308,15 +329,27 @@ def main(args):
pali_gemma_backbone_config = {
"vit_num_layers": 27,
"vit_hidden_dim": 1152,
"vocabulary_size": 257152,
"image_size": args.image_size,
"num_layers": 18,
"num_query_heads": 8,
"num_key_value_heads": 1,
"hidden_dim": 2048,
"intermediate_dim": 32768,
"head_dim": 256,
"vit_patch_size": 14,
"vit_num_heads": 16,
}
pg_image_converter = PaliGemmaImageConverter(
image_size=(args.image_size, args.image_size),
scale=1.0 / 127.5,
offset=-1,
)
tokenizer = PaliGemmaTokenizer(
proto="vocabulary.spm",
)
pg_presprocessor = PaliGemmaCausalLMPreprocessor(
image_converter=pg_image_converter
tokenizer=tokenizer, image_converter=pg_image_converter
)
pg_backbone = PaliGemmaBackbone(**pali_gemma_backbone_config)
keras_model = PaliGemmaCausalLM(
Expand All @@ -325,8 +358,10 @@ def main(args):
# The weights file may be downloaded from Kaggle, or `--weights_path` may point to a local copy.
weights = np.load(args.weights_path)
jax_weights = get_weights_as_numpy(weights, **pali_gemma_backbone_config)
keras_model = convert_pali_gemma_weights(
keras_model, jax_weights["params"], **pali_gemma_backbone_config
keras_model.backbone = convert_pali_gemma_weights(
keras_model.backbone,
jax_weights["params"],
**pali_gemma_backbone_config,
)
# Specify preset name
keras_model.save_to_preset(args.checkpoint_name)
Expand Down

0 comments on commit f391856

Please sign in to comment.