Add support for converting Gemma 2 checkpoints (#1700)
* Add support for Gemma 2 checkpoints

This is sadly a little hacky, as flax support for Gemma 2 is not yet
complete. Output checking will therefore not match up, but we can still
convert checkpoints.

* Support newer orbax checkpoints, allow flax failures
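
Once conversion finishes, a rough sanity check is to load the converted preset back into KerasNLP and prompt it directly, since comparing against flax output is not possible until flax gains full Gemma 2 support. A minimal sketch, assuming the script wrote a local preset directory (the directory name below is illustrative):

import keras_nlp

# Load the locally converted preset (path is illustrative).
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("./gemma2_9b_en")
# Eyeball a short completion rather than diffing against flax output.
print(gemma_lm.generate("What is Keras?", max_length=32))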
mattdangerw authored Jul 22, 2024
1 parent b6877df commit 9b20c6e
Showing 1 changed file: tools/checkpoint_conversion/convert_gemma_checkpoints.py (50 additions, 22 deletions)
@@ -86,7 +86,24 @@ def download_flax_model(handle):
     return kagglehub.model_download(handle)
 
 
-def convert_model(flax_config, vocab_size):
+def convert_model(flax_config, flax_params, vocab_size):
+    kwargs = {}
+    # Hack to infer Gemma 2 config options until Flax actually adds support.
+    if "post_attention_norm" in flax_params["transformer"]["layer_0"]:
+        # The 27B parameter model is the only model that does a weird
+        # query normalization.
+        is_gemma2_27b = flax_config.num_heads == 32
+        # We would like to convert these from Flax, but have no way until
+        # flax supports Gemma 2.
+        kwargs = {
+            "query_head_dim_normalize": not is_gemma2_27b,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "final_logit_soft_cap": 30,
+            "attention_logit_soft_cap": 50,
+            "use_sliding_window_attention": True,
+            "sliding_window_size": 4096,
+        }
     return keras_nlp.models.GemmaBackbone(
         vocabulary_size=vocab_size,
         num_layers=flax_config.num_layers,
@@ -95,6 +112,7 @@ def convert_model(flax_config, vocab_size):
         hidden_dim=flax_config.embed_dim,
         intermediate_dim=flax_config.hidden_dim * 2,
         head_dim=flax_config.head_dim,
+        **kwargs,
     )
 
 
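For reference, here is roughly what the constructor call above resolves to for a Gemma 2 checkpoint in the 9B size class, where num_heads != 32 and query_head_dim_normalize therefore stays True. This is only a sketch: the architecture sizes are illustrative stand-ins for values read from flax_config, and only the keyword options mirror the hack above.

import keras_nlp

# Illustrative expansion of convert_model() for a Gemma 2 9B-class checkpoint.
# Architecture sizes are placeholders for values taken from flax_config.
backbone = keras_nlp.models.GemmaBackbone(
    vocabulary_size=256000,
    num_layers=42,
    num_query_heads=16,
    num_key_value_heads=8,
    hidden_dim=3584,
    intermediate_dim=28672,  # flax_config.hidden_dim * 2
    head_dim=256,
    query_head_dim_normalize=True,  # False only for the 27B model
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    final_logit_soft_cap=30,
    attention_logit_soft_cap=50,
    use_sliding_window_attention=True,
    sliding_window_size=4096,
)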
@@ -123,6 +141,15 @@ def convert_weights(keras_model, flax_config, flax_params):
             [flax_block["pre_ffw_norm"]["scale"]]
         )
 
+        if "post_attention_norm" in flax_block:
+            keras_block.post_attention_norm.set_weights(
+                [flax_block["post_attention_norm"]["scale"]]
+            )
+        if "post_ffw_norm" in flax_block:
+            keras_block.post_ffw_norm.set_weights(
+                [flax_block["post_ffw_norm"]["scale"]]
+            )
+
         keras_block.gating_ffw.set_weights(
             [flax_block["mlp"]["gating_einsum"][0]]
         )
@@ -176,27 +203,28 @@ def validate_output(
     )
     keras_output = gemma_lm.generate([input_str], max_length=length)
     keras_output = keras_output[0]
+    print("🔶 KerasNLP output:", keras_output)
 
     # Flax
-    transformer_config = transformer_lib.TransformerConfig.from_params(
-        flax_params,
-        cache_size=length,
-    )
-    transformer = transformer_lib.Transformer(transformer_config)
-    sampler = sampler_lib.Sampler(
-        transformer=transformer,
-        vocab=flax_tokenizer,
-        params=flax_params["transformer"],
-    )
-    flax_output = sampler(
-        input_strings=[input_str],
-        total_generation_steps=length - 5,  # Length of "<bos>What is Keras?"
-    )
-    flax_output = input_str + flax_output.text[0]
-
-    # Comparing the outputs.
-    print("🔶 KerasNLP output:", keras_output)
-    print("🔶 Flax output:", flax_output)
+    try:
+        transformer_config = transformer_lib.TransformerConfig.from_params(
+            flax_params,
+            cache_size=length,
+        )
+        transformer = transformer_lib.Transformer(transformer_config)
+        sampler = sampler_lib.Sampler(
+            transformer=transformer,
+            vocab=flax_tokenizer,
+            params=flax_params["transformer"],
+        )
+        flax_output = sampler(
+            input_strings=[input_str],
+            total_generation_steps=length - 5,  # Length of "<bos>What is Keras?"
+        )
+        flax_output = input_str + flax_output.text[0]
+        print("🔶 Flax output:", flax_output)
+    except Exception as e:
+        print("🔶 Flax could not be run.", e)
 
 
 def main(_):
@@ -223,7 +251,7 @@ def main(_):
 
     checkpoint_dir = None
     for path in os.listdir(flax_dir):
-        checkpoint_file = os.path.join(flax_dir, path, "checkpoint")
+        checkpoint_file = os.path.join(flax_dir, path, "_METADATA")
        if os.path.exists(checkpoint_file):
            checkpoint_dir = os.path.join(flax_dir, path)
    assert checkpoint_dir is not None, "Cannot find orbax checkpoint files"
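
The same scan can be factored into a small helper; a minimal sketch of the logic above, assuming newer orbax checkpoints are identified by a _METADATA file inside the checkpoint folder:

import os


def find_orbax_checkpoint_dir(flax_dir):
    # Newer orbax layouts mark the checkpoint folder with a _METADATA file;
    # older ones used a top-level "checkpoint" file instead.
    for path in os.listdir(flax_dir):
        if os.path.exists(os.path.join(flax_dir, path, "_METADATA")):
            return os.path.join(flax_dir, path)
    raise FileNotFoundError("Cannot find orbax checkpoint files")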
@@ -236,7 +264,7 @@ def main(_):
 
     keras_tokenizer = convert_tokenizer(proto_path)
     vocab_size = keras_tokenizer.vocabulary_size()
-    keras_model = convert_model(flax_config, vocab_size)
+    keras_model = convert_model(flax_config, flax_params, vocab_size)
     print("✅ Keras model loaded")
 
     convert_weights(keras_model, flax_config, flax_params)