StableLM support #3586

Merged 33 commits from stablelm-support into master on Nov 14, 2023.

Commits
27467a5  Initial support - model loads, generates random stuff (Galunid, Oct 11, 2023)
80b2e72  Remove gpt neox references (Galunid, Oct 11, 2023)
605e701  Fixes suggested by @mmnga (Galunid, Oct 12, 2023)
1ee5cc3  Make stablelm conversion script use .safetensors (Galunid, Oct 18, 2023)
f1dd430  Remove random junk print (Galunid, Oct 19, 2023)
4fbce39  Fix model conversion script (Galunid, Oct 21, 2023)
a71041a  Use ggml_norm not ggml_rms_norm (Galunid, Oct 22, 2023)
76b4495  Fix rope parameters (Galunid, Oct 22, 2023)
e167ebc  Fix formatting in gguf.py (Galunid, Oct 22, 2023)
839a183  Fix formatting in llama.cpp (Galunid, Oct 22, 2023)
db09c02  Merge branch 'master' into stablelm-support (Galunid, Oct 22, 2023)
0153376  batch.seq_id[j] -> batch.seq_id[j][0] (Galunid, Oct 22, 2023)
e399050  Fix added_tokens crashes (Galunid, Oct 22, 2023)
cf5eff3  Merge branch 'master' into stablelm-support (Galunid, Oct 22, 2023)
a92fd2d  Add tests for stablelm tokenizer (Galunid, Oct 22, 2023)
d9c0332  Update readme with stablelm support (Galunid, Oct 22, 2023)
fa2cd7e  Add special token handling to convert script (Galunid, Oct 24, 2023)
27d0c11  Merge branch 'master' into stablelm-support (Galunid, Oct 24, 2023)
51b3b56  Prevent offloading of more than 33 layers (Galunid, Oct 24, 2023)
a00bb06  Make convert script work with pytorch files (Galunid, Oct 26, 2023)
8917767  Merge branch 'master' into stablelm-support (Galunid, Nov 5, 2023)
c959376  Update after #3382 (Galunid, Nov 5, 2023)
698c945  Merge branch 'master' into stablelm-support (Galunid, Nov 7, 2023)
4713a40  LLAMA_BACKEND_OFFLOAD* -> llama_backend_offload* (Galunid, Nov 7, 2023)
2f41552  Merge branch 'master' into stablelm-support (Galunid, Nov 9, 2023)
6be3356  Update conversion script to convert-hf-to-gguf.py (Galunid, Nov 9, 2023)
a371a8b  Use ggml_view_3d (Galunid, Nov 10, 2023)
e87d709  Cleanup for review (Galunid, Nov 10, 2023)
9e035cd  Add vision model support (Galunid, Nov 11, 2023)
047032d  Duh - add llava in another place (Galunid, Nov 12, 2023)
be2ac38  Make qrot, krot contiguous (Galunid, Nov 12, 2023)
beb17a7  Merge branch 'master' into stablelm-support (Galunid, Nov 13, 2023)
853fe04  Fix gguf post merge (Galunid, Nov 13, 2023)
1 change: 1 addition & 0 deletions README.md

@@ -93,6 +93,7 @@ as the main playground for developing new features for the [ggml](https://github
 - [X] [Persimmon 8B](https://github.com/ggerganov/llama.cpp/pull/3410)
 - [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417)
 - [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553)
+- [X] [StableLM-3b-4e1t](https://github.com/ggerganov/llama.cpp/pull/3586)


 **Bindings:**
30 changes: 19 additions & 11 deletions convert-hf-to-gguf.py

@@ -150,8 +150,6 @@ def load_hparams(dir_model):
 
     @staticmethod
     def from_model_architecture(model_architecture):
-        if model_architecture == "StableLMEpochForCausalLM":
-            return StableLMModel
         if model_architecture == "GPTNeoXForCausalLM":
             return GPTNeoXModel
         if model_architecture == "BloomForCausalLM":
@@ -168,6 +166,8 @@ def from_model_architecture(model_architecture):
             return RefactModel
         if model_architecture == "PersimmonForCausalLM":
             return PersimmonModel
+        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
+            return StableLMModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -201,6 +201,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.REFACT
         if arch == "PersimmonForCausalLM":
             return gguf.MODEL_ARCH.PERSIMMON
+        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
+            return gguf.MODEL_ARCH.STABLELM
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -294,15 +296,6 @@ def _set_vocab_sentencepiece(self):
         special_vocab.add_to_gguf(self.gguf_writer)
 
 
-class StableLMModel(Model):
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        self.gguf_writer.add_rope_dimension_count(
-            int(self.hparams["rope_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
-        )
-        self.gguf_writer.add_layer_norm_eps(1e-5)
-
-
 class GPTNeoXModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -824,6 +817,21 @@ def write_tensors(self):
         self.gguf_writer.add_tensor(new_name, data)
 
 
+class StableLMModel(Model):
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
+        self.gguf_writer.add_layer_norm_eps(1e-5)
+
 ###### CONVERSION LOGIC ######
 
 def parse_args() -> argparse.Namespace:
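The relocated StableLMModel.set_gguf_parameters derives the rope dimension count as rope_pct * (hidden_size / num_attention_heads), i.e. a fraction of each head rather than the full head size. As a quick sanity check, here is that arithmetic with the config values published for stablelm-3b-4e1t (hidden_size 2560, num_attention_heads 32, rope_pct 0.25; treat these as assumptions taken from the model's config.json rather than anything this PR pins down):

# Assumed stablelm-3b-4e1t hyperparameters (from its config.json).
hparams = {"hidden_size": 2560, "num_attention_heads": 32, "rope_pct": 0.25}

head_dim = hparams["hidden_size"] // hparams["num_attention_heads"]  # 2560 // 32 = 80
n_rot = int(hparams["rope_pct"] * head_dim)                          # 0.25 * 80 = 20
assert (head_dim, n_rot) == (80, 20)  # only a quarter of each head gets rotated

Converting a local checkout of the model would then be something like python convert-hf-to-gguf.py <model-dir> --outtype f16 (invocation hedged; check the script's --help for the exact flags).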
17 changes: 17 additions & 0 deletions gguf-py/gguf/constants.py

@@ -90,6 +90,7 @@ class MODEL_ARCH(IntEnum):
     REFACT = auto()
     BERT = auto()
     BLOOM = auto()
+    STABLELM = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -129,6 +130,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.REFACT: "refact",
     MODEL_ARCH.BERT: "bert",
     MODEL_ARCH.BLOOM: "bloom",
+    MODEL_ARCH.STABLELM: "stablelm",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -299,6 +301,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.STABLELM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.GPT2: [
         # TODO
     ],
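Registering the architecture in this tensor table (against the shared TENSOR_NAMES templates) is what tells gguf-py which canonical tensor names a stablelm GGUF file may contain. A small sketch of how a consumer could enumerate them, assuming a gguf-py checkout that includes this PR and its MODEL_TENSORS/TENSOR_NAMES names:

from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

# Templates such as "blk.{bid}.attn_q" are instantiated per block index;
# per-model names like "token_embd" simply ignore the bid argument.
for tensor in MODEL_TENSORS[MODEL_ARCH.STABLELM]:
    print(TENSOR_NAMES[tensor].format(bid=0))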