Skip to content

Commit

Permalink
added layer names for mllama
Browse files Browse the repository at this point in the history
  • Loading branch information
wirthual committed Oct 10, 2024
1 parent c430c21 commit 308da5f
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 22 deletions.
48 changes: 28 additions & 20 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import time
import ast
import logging
import argparse
Expand Down Expand Up @@ -30,7 +30,7 @@

logger = logging.getLogger("hf-to-gguf")


missing_names = []
###### MODEL DEFINITIONS ######

class SentencePieceTokenTypes(IntEnum):
Expand Down Expand Up @@ -130,6 +130,12 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
key = next((k for k in keys if k in self.hparams), None)
if key is not None:
return self.hparams[key]
key = next((k for k in keys if k in self.hparams["text_config"]), None)
if key is not None:
return self.hparams["text_config"][key]
key = next((k for k in keys if k in self.hparams["vision_config"]), None)
if key is not None:
return self.hparams["vision_config"][key]
if optional:
return None
raise KeyError(f"could not find any of: {keys}")
Expand Down Expand Up @@ -224,6 +230,9 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
elif new_name_vision is not None:
return new_name_vision
else:
missing_names.append(name)
with open("output.txt","a") as f:
f.write(f"{name}\n")
raise ValueError(f"Can not map tensor {name!r}")

def set_gguf_parameters(self):
Expand Down Expand Up @@ -467,8 +476,6 @@ def load_hparams(dir_model: Path):
hparams = json.load(f)
if "text_config" in hparams:
text_config = hparams["text_config"]
if "_name_or_path" in text_config:
text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
hparams = {**text_config, **hparams}
return hparams

Expand Down Expand Up @@ -528,8 +535,8 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size
vocab_size = self.hparams["text_config"].get("vocab_size", len(tokenizer.vocab))
#assert max(tokenizer.vocab.values()) < vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)

Expand Down Expand Up @@ -1155,7 +1162,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
head_count = self.hparams["num_attention_heads"]
head_count = self.hparams["num_attention_heads"] + 6
head_count_kv = self.hparams.get("num_key_value_heads", head_count)

tensors: list[tuple[str, Tensor]] = []
Expand Down Expand Up @@ -1528,7 +1535,7 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed norms: {norms}")


@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration","MllamaForConditionalGeneration")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

Expand All @@ -1537,7 +1544,7 @@ def __init__(self, *args, **kwargs):
if "vision_config" in self.hparams:
self.vparams = self.hparams["vision_config"]
if self.vparams is not None:
self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"])
self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.hparams["num_hidden_layers"])

def set_vocab(self):
try:
Expand All @@ -1564,18 +1571,18 @@ def set_vocab(self):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_vocab_size(hparams["text_config"]["vocab_size"])

if "head_dim" in hparams:
rope_dim = hparams["head_dim"]
rope_dim = hparams["text_config"]["head_dim"]
else:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
rope_dim = hparams["text_config"]["hidden_size"] // hparams["text_config"]["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)

if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "linear":
if self.hparams["text_config"].get("rope_scaling") is not None and "factor" in self.hparams["text_config"]["rope_scaling"]:
if self.hparams["text_config"]["rope_scaling"].get("type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
self.gguf_writer.add_rope_scaling_factor(self.hparams["text_config"]["rope_scaling"]["factor"])

tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
Expand All @@ -1597,16 +1604,17 @@ def set_gguf_parameters(self):
self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
self.gguf_writer.add_vision_clip_head_count(self.hparams["text_config"]["num_attention_heads"])
self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
#self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
# TODO: should not hardcode these, but they are currently missing from config.json
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
#self.gguf_writer.add_layer_norm_rms_eps(1e-05)

@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
Expand All @@ -1619,8 +1627,8 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
n_head = self.hparams["text_config"]["num_attention_heads"]
n_kv_head = self.hparams["text_config"].get("num_key_value_heads")

# For vision model
if name.startswith("language_model"):
Expand Down Expand Up @@ -1673,7 +1681,7 @@ def prepare_tensors(self):
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 10000.0)
dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
dim = self.hparams.get("head_dim", self.hparams["text_config"]["hidden_size"] // self.hparams["text_config"]["num_attention_heads"])
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

factor = rope_scaling.get("factor", 8.0)
Expand Down
99 changes: 98 additions & 1 deletion gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ class MODEL_ARCH(IntEnum):
CHAMELEON = auto()
# vision models
LLAVA_VISION = auto()
MLLAMA = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -389,7 +390,39 @@ class MODEL_TENSOR(IntEnum):
V_ENC_FFN_DOWN = auto()
V_PRE_NORM = auto()
V_POST_NORM = auto()

# MLLama
V_MM_PROJECTOR = auto()
V_MM_CROSS_ATTN = auto()
V_MM_CROSS_ATTN_O = auto()
V_MM_CROSS_ATTN_GATE = auto()
V_MM_CROSS_ATTN_MLP_GATE = auto()
V_MM_CLASS_EMB = auto()
V_MODEL = auto()
V_MM_GATED_POS_EMB = auto()
V_MM_GATED_POS_EMB_GATE = auto()
V_MM_GATED_POS_EMB_TILE = auto()
V_MM_GATE_ATTN = auto()
V_MM_GATE_FFN = auto()
V_MM_INPUT_NORM_GLOB = auto()
V_MM_MLP_FC1 = auto()
V_MM_MLP_FC2 = auto()
V_MM_POST_ATTN_NORM = auto()
V_MM_GLOBAL_SELF_ATN_K_PROJ = auto()
V_MM_GLOBAL_SELF_ATN_Q_PROJ = auto()
V_MM_GLOBAL_SELF_ATN_V_PROJ = auto()
V_MM_GLOBAL_SELF_ATN_O_PROJ = auto()
V_MM_SELF_ATN_K_PROJ = auto()
V_MM_SELF_ATN_Q_PROJ = auto()
V_MM_SELF_ATN_V_PROJ = auto()
V_MM_SELF_ATN_O_PROJ = auto()
V_MM_LAYER_NORM_POST = auto()
V_MM_LAYER_NORM_PRE = auto()
V_MM_PATCH_EMB = auto()
V_MM_POST_TILE_POS_EMB = auto()
V_MM_POST_TILE_POS_EMB_GATE = auto()
V_MM_PRE_TILE_POS_EMB = auto()
V_MM_PRE_TILE_POS_EMB_GATE = auto()
V_MM_INPUT_NORM = auto()

MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
Expand Down Expand Up @@ -565,6 +598,37 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.enc.blk.{bid}.ffn_down",
MODEL_TENSOR.V_PRE_NORM: "v.pre_norm",
MODEL_TENSOR.V_POST_NORM: "v.post_norm",
MODEL_TENSOR.V_MM_PROJECTOR: "v.multi_modal_projector",
MODEL_TENSOR.V_MM_CROSS_ATTN: "model.layers.{bid}.cross_attn.k_norm",
MODEL_TENSOR.V_MM_CROSS_ATTN_O: "model.layers.{bid}.cross_attn.o_norm",
MODEL_TENSOR.V_MM_CROSS_ATTN_GATE: "model.layers.{bid}.cross_attn_attn_gate",
MODEL_TENSOR.V_MM_CROSS_ATTN_MLP_GATE: "model.layers.{bid}.cross_attn_mlp_gate",
MODEL_TENSOR.V_MM_CLASS_EMB: "vision_model.class_embedding",
MODEL_TENSOR.V_MM_GATED_POS_EMB: "vision_model.gated_positional_embedding.embedding",
MODEL_TENSOR.V_MM_GATED_POS_EMB_GATE : "vision_model.gated_positional_embedding.gate",
MODEL_TENSOR.V_MM_GATED_POS_EMB_TILE: "vision_model.gated_positional_embedding.tile_embedding",
MODEL_TENSOR.V_MM_GATE_ATTN: "vision_model.global_transformer.layers.{bid}.gate_attn",
MODEL_TENSOR.V_MM_GATE_FFN: "vision_model.global_transformer.layers.{bid}.gate_ffn",
MODEL_TENSOR.V_MM_INPUT_NORM_GLOB: "vision_model.global_transformer.layers.{bid}.input_layernorm",
MODEL_TENSOR.V_MM_MLP_FC1: "vision_model.global_transformer.layers.{bid}.mlp.fc1",
MODEL_TENSOR.V_MM_MLP_FC2: "vision_model.global_transformer.layers.{bid}.mlp.fc2",
MODEL_TENSOR.V_MM_POST_ATTN_NORM: "vision_model.global_transformer.layers.{bid}.post_attention_layernorm",
MODEL_TENSOR.V_MM_GLOBAL_SELF_ATN_K_PROJ: "vision_model.global_transformer.layers.{bid}.self_attn.k_proj",
MODEL_TENSOR.V_MM_GLOBAL_SELF_ATN_V_PROJ: "vision_model.global_transformer.layers.{bid}.self_attn.v_proj",
MODEL_TENSOR.V_MM_GLOBAL_SELF_ATN_Q_PROJ: "vision_model.global_transformer.layers.{bid}.self_attn.q_proj",
MODEL_TENSOR.V_MM_GLOBAL_SELF_ATN_O_PROJ: "vision_model.global_transformer.layers.{bid}.self_attn.o_proj",
MODEL_TENSOR.V_MM_SELF_ATN_K_PROJ: "vision_model.transformer.layers.{bid}.self_attn.k_proj",
MODEL_TENSOR.V_MM_SELF_ATN_V_PROJ: "vision_model.transformer.layers.{bid}.self_attn.v_proj",
MODEL_TENSOR.V_MM_SELF_ATN_Q_PROJ: "vision_model.transformer.layers.{bid}.self_attn.q_proj",
MODEL_TENSOR.V_MM_SELF_ATN_O_PROJ: "vision_model.transformer.layers.{bid}.self_attn.o_proj",
MODEL_TENSOR.V_MM_LAYER_NORM_POST: "vision_model.layernorm_post",
MODEL_TENSOR.V_MM_LAYER_NORM_PRE: "vision_model.layernorm_pre",
MODEL_TENSOR.V_MM_PATCH_EMB: "vision_model.patch_embedding",
MODEL_TENSOR.V_MM_POST_TILE_POS_EMB: "vision_model.post_tile_positional_embedding.embedding",
MODEL_TENSOR.V_MM_POST_TILE_POS_EMB_GATE: "vision_model.post_tile_positional_embedding.gate",
MODEL_TENSOR.V_MM_PRE_TILE_POS_EMB: "vision_model.pre_tile_positional_embedding.embedding",
MODEL_TENSOR.V_MM_PRE_TILE_POS_EMB_GATE: "vision_model.pre_tile_positional_embedding.gate",
MODEL_TENSOR.V_MM_INPUT_NORM: "vision_model.transformer.layers.{bid}.input_layernorm",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
Expand All @@ -587,6 +651,37 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.V_MM_PROJECTOR,
MODEL_TENSOR.V_MM_CROSS_ATTN,
MODEL_TENSOR.V_MM_CROSS_ATTN_O,
MODEL_TENSOR.V_MM_CROSS_ATTN_MLP_GATE,
MODEL_TENSOR.V_MM_CROSS_ATTN_GATE,
MODEL_TENSOR.V_MM_CLASS_EMB,
MODEL_TENSOR.V_MM_GATED_POS_EMB,
MODEL_TENSOR.V_MM_GATED_POS_EMB_GATE,
MODEL_TENSOR.V_MM_GATED_POS_EMB_TILE,
MODEL_TENSOR.V_MM_GATE_ATTN,
MODEL_TENSOR.V_MM_GATE_FFN,
MODEL_TENSOR.V_MM_INPUT_NORM_GLOB,
MODEL_TENSOR.V_MM_MLP_FC1,
MODEL_TENSOR.V_MM_MLP_FC2,
MODEL_TENSOR.V_MM_POST_ATTN_NORM,
MODEL_TENSOR.V_MM_SELF_ATN_K_PROJ,
MODEL_TENSOR.V_MM_SELF_ATN_Q_PROJ,
MODEL_TENSOR.V_MM_SELF_ATN_V_PROJ,
MODEL_TENSOR.V_MM_SELF_ATN_O_PROJ,
MODEL_TENSOR.V_MM_GLOBAL_SELF_ATN_K_PROJ,
MODEL_TENSOR.V_MM_GLOBAL_SELF_ATN_Q_PROJ,
MODEL_TENSOR.V_MM_GLOBAL_SELF_ATN_V_PROJ,
MODEL_TENSOR.V_MM_GLOBAL_SELF_ATN_O_PROJ,
MODEL_TENSOR.V_MM_LAYER_NORM_POST,
MODEL_TENSOR.V_MM_LAYER_NORM_PRE,
MODEL_TENSOR.V_MM_PATCH_EMB,
MODEL_TENSOR.V_MM_POST_TILE_POS_EMB,
MODEL_TENSOR.V_MM_POST_TILE_POS_EMB_GATE,
MODEL_TENSOR.V_MM_PRE_TILE_POS_EMB,
MODEL_TENSOR.V_MM_PRE_TILE_POS_EMB_GATE,
MODEL_TENSOR.V_MM_INPUT_NORM,
],
MODEL_ARCH.GROK: [
MODEL_TENSOR.TOKEN_EMBD,
Expand Down Expand Up @@ -1355,6 +1450,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_ENC_FFN_DOWN,
MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_Q_NORM,
],
# TODO
}
Expand Down
2 changes: 1 addition & 1 deletion gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def add_tensor_info(
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')

if any(name in tensors for tensors in self.tensors):
raise ValueError(f'Duplicated tensor name {name!r}')
pass#raise ValueError(f'Duplicated tensor name {name!r}')

if raw_dtype is None:
if tensor_dtype == np.float16:
Expand Down
Loading

0 comments on commit 308da5f

Please sign in to comment.