llama : add support for GLM-Edge and GLM-Edge-V series models (ggerganov#10573)

* add glm edge chat model

* use config partial_rotary_factor as rope ratio

* support for glm edge model

* vision model support

* remove debug info

* fix format

* llava.cpp trailing whitespace

* remove unused AutoTokenizer

* Update src/llama.cpp so it does not contain <|end|> or </s>

Co-authored-by: Xuan Son Nguyen <[email protected]>

* add edge template

* fix chat template

* fix conflict

* fix conflict

* fix ci err

* fix format err

* fix template err

* 9b hf chat support

* format

* format clip.cpp

* fix format

* Apply suggestions from code review

* Apply suggestions from code review

* Update examples/llava/clip.cpp

* fix format

* minor : style

---------

Co-authored-by: liyuhang <[email protected]>
Co-authored-by: piDack <[email protected]>
Co-authored-by: Xuan Son Nguyen <[email protected]>
Co-authored-by: liyuhang <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
6 people authored Feb 2, 2025
1 parent 53debe6 commit 0cec062
Showing 15 changed files with 568 additions and 67 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -96,7 +96,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Bitnet b1.58 models](https://huggingface.co/1bitLLM)
- [x] [Flan T5](https://huggingface.co/models?search=flan-t5)
- [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) + [GLMEdge-1.5b](https://huggingface.co/THUDM/glm-edge-1.5b-chat) + [GLMEdge-4b](https://huggingface.co/THUDM/glm-edge-4b-chat)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
@@ -117,6 +117,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM)
- [x] [Moondream](https://huggingface.co/vikhyatk/moondream2)
- [x] [Bunny](https://github.com/BAAI-DCAI/Bunny)
- [x] [GLM-EDGE](https://huggingface.co/models?search=glm-edge)
- [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)

</details>
58 changes: 15 additions & 43 deletions convert_hf_to_gguf.py
@@ -648,7 +648,7 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
res = "jina-v2-code"
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b" or chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
@@ -4513,7 +4513,7 @@ def prepare_tensors(self):
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)


@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
@Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(Model):
model_arch = gguf.MODEL_ARCH.CHATGLM

Expand Down Expand Up @@ -4619,47 +4619,15 @@ def set_vocab(self):

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams["padded_vocab_size"]
vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"])
assert max(tokenizer.get_vocab().values()) < vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)

merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
assert len(merged) >= 2 and len(merged) <= 7
merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))

# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
added_vocab = tokenizer.get_added_vocab()
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}

for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
special_vocab.merges = merges
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
# only add special tokens when they were not already loaded from config.json
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
@@ -4670,16 +4638,20 @@ def set_vocab(self):
def set_gguf_parameters(self):
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
n_head_kv = self.hparams.get("multi_query_group_num", self.hparams.get("num_key_value_heads", n_head))
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
self.gguf_writer.add_block_count(self.hparams["num_layers"])
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_dimension_count(64)
if "attention_dim" in self.hparams:
rope_dim = self.hparams["attention_dim"]
else:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_add_bos_token(False)
rope_freq = 10000
if "rope_ratio" in self.hparams:
@@ -4689,7 +4661,7 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if name.endswith(".rotary_pos_emb.inv_freq"):
if name.endswith(".rotary_pos_emb.inv_freq") or name.startswith("model.vision."):
return []

name = name.removeprefix("transformer.")
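The converter change above derives the RoPE dimension count from the model config instead of hard-coding 64. Below is a minimal standalone sketch of that logic (not the converter itself), assuming a GLM/GLM-Edge-style hparams dict with the field names shown in the diff:

```python
# Sketch of the RoPE dimension-count logic from the diff above.
# Assumes a GLM/GLM-Edge-style config dict; field names follow the diff.
def rope_dimension_count(hparams: dict) -> int:
    if "attention_dim" in hparams:
        rope_dim = hparams["attention_dim"]
    else:
        rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
    # partial_rotary_factor defaults to 0.5, which reproduces the previously
    # hard-coded 64 for ChatGLM-style shapes (e.g. 4096 / 32 = 128; 128 * 0.5 = 64).
    return int(rope_dim * hparams.get("partial_rotary_factor", 0.5))

# Example with ChatGLM-like dimensions:
print(rope_dimension_count({"hidden_size": 4096, "num_attention_heads": 32}))  # 64
```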
43 changes: 43 additions & 0 deletions examples/llava/README-glmedge.md
@@ -0,0 +1,43 @@
# GLMV-EDGE

Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b).

## Usage
Build with cmake or run `make llama-llava-cli` to build it.

After building, run: `./llama-llava-cli` to see the usage. For example:

```sh
./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf --image img_path/image.jpg -p "<|system|>\n system prompt <image><|user|>\n prompt <|assistant|>\n"
```

**note**: A lower temperature like 0.1 is recommended for better quality; add `--temp 0.1` to the command to do so.
**note**: For GPU offloading, use the `-ngl` flag as usual.
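
For example, a hypothetical invocation combining both flags (paths and the layer count are placeholders; adjust for your setup):

```sh
./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf \
    --image img_path/image.jpg --temp 0.1 -ngl 99 \
    -p "<|system|>\n system prompt <image><|user|>\n prompt <|assistant|>\n"
```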

## GGUF conversion

1. Clone a GLMV-EDGE model ([2B](https://huggingface.co/THUDM/glm-edge-v-2b) or [5B](https://huggingface.co/THUDM/glm-edge-v-5b)). For example:

```sh
git clone https://huggingface.co/THUDM/glm-edge-v-5b
# or
git clone https://huggingface.co/THUDM/glm-edge-v-2b
```

2. Use `glmedge-surgery.py` to split the GLMV-EDGE model into its LLM and multimodal projector constituents:

```sh
python ./examples/llava/glmedge-surgery.py -m ../model_path
```

3. Use `glmedge-convert-image-encoder-to-gguf.py` to convert the GLMV-EDGE image encoder to GGUF:

```sh
python ./examples/llava/glmedge-convert-image-encoder-to-gguf.py -m ../model_path --llava-projector ../model_path/glm.projector --output-dir ../model_path
```

4. Use `convert_hf_to_gguf.py` to convert the LLM part of GLMV-EDGE to GGUF:

```sh
python convert_hf_to_gguf.py ../model_path
```

Now both the LLM part and the image encoder are in the `model_path` directory.
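
As a quick check, the converted pair can then be passed to `llama-llava-cli` as in the Usage section above (the exact GGUF filenames depend on the conversion scripts and may differ):

```sh
./llama-llava-cli -m ../model_path/ggml-model-f16.gguf --mmproj ../model_path/mmproj-model-f16.gguf \
    --image img_path/image.jpg --temp 0.1 \
    -p "<|system|>\n system prompt <image><|user|>\n prompt <|assistant|>\n"
```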
