Add a subtle fix for gemma 2 conversions
Gemma 2 uses different normalization constants for the query
depending on the model size.

9b = head_dim
27b = hidden_dim / num_query_heads

We need to slightly tweak our config conversion to account for this.
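As an illustration, here is a minimal sketch of the check the conversion performs. The config values shown for the two sizes are assumptions based on the published Hugging Face Gemma 2 configs, not part of this commit:

import math

# Assumed (illustrative) excerpts from the Hugging Face Gemma 2 configs.
gemma2_9b = {"hidden_size": 3584, "num_attention_heads": 16, "head_dim": 256, "query_pre_attn_scalar": 256}
gemma2_27b = {"hidden_size": 4608, "num_attention_heads": 32, "head_dim": 128, "query_pre_attn_scalar": 144}

for name, cfg in [("9b", gemma2_9b), ("27b", gemma2_27b)]:
    # Queries are scaled by 1 / sqrt(query_pre_attn_scalar) before the attention softmax.
    scale = 1.0 / math.sqrt(cfg["query_pre_attn_scalar"])
    # True for 9b (scalar == head_dim), False for 27b (scalar == hidden_size / num heads).
    query_head_dim_normalize = cfg["head_dim"] == cfg["query_pre_attn_scalar"]
    print(name, query_head_dim_normalize, round(scale, 4))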
mattdangerw committed Jul 19, 2024
1 parent b0c21b3 commit 8df5959
Showing 1 changed file with 4 additions and 1 deletion.
keras_nlp/src/utils/transformers/convert_gemma.py (4 additions, 1 deletion)
@@ -59,7 +59,10 @@ def load_gemma_backbone(cls, preset, load_weights):
         "hidden_dim": transformers_config["hidden_size"],
         "intermediate_dim": transformers_config["intermediate_size"] * 2,
         "head_dim": transformers_config["head_dim"],
-        "query_head_dim_normalize": False,
+        "query_head_dim_normalize": (
+            transformers_config["head_dim"]
+            == transformers_config["query_pre_attn_scalar"]
+        ),
         "use_post_ffw_norm": True,
         "use_post_attention_norm": True,
         "final_logit_soft_cap": transformers_config[
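Downstream, the converted flag decides which constant the Keras Gemma attention layer uses when scaling queries. A rough sketch of the intent, not the actual keras_nlp implementation:

import math

def query_scale(head_dim, hidden_dim, num_query_heads, query_head_dim_normalize):
    # With query_head_dim_normalize=True, queries are scaled by 1/sqrt(head_dim);
    # otherwise by 1/sqrt(hidden_dim / num_query_heads), as Gemma 2 27b expects.
    scalar = head_dim if query_head_dim_normalize else hidden_dim / num_query_heads
    return 1.0 / math.sqrt(scalar)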
