Skip to content

Commit

Permalink
py : fix StableLM conversion after config.json changes (ggml-org#5703)
Browse files Browse the repository at this point in the history
* Fix issues during StableLM model conversion

* Fix hard coded layer_norm_eps

* Support layer_norm_eps for LlavaStableLM

Co-authored-by: Jared Van Bortel <[email protected]>

* Add missing parenthesis

Co-authored-by: Jared Van Bortel <[email protected]>

* Support rotary_factor for LlavaStableLM

Co-authored-by: Jared Van Bortel <[email protected]>

* Fix typo

* Add StableLMEpochForCausalLM for safety

Co-authored-by: compilade <[email protected]>

* Add StableLMEpochForCausalLM for safety 2

Co-authored-by: compilade <[email protected]>

---------

Co-authored-by: Jared Van Bortel <[email protected]>
Co-authored-by: Jared Van Bortel <[email protected]>
Co-authored-by: compilade <[email protected]>
  • Loading branch information
4 people authored Feb 25, 2024
1 parent 9e359a4 commit 69917df
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def from_model_architecture(model_architecture):
return RefactModel
if model_architecture == "PersimmonForCausalLM":
return PersimmonModel
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
if model_architecture in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return StableLMModel
if model_architecture == "QWenLMHeadModel":
return QwenModel
Expand Down Expand Up @@ -253,7 +253,7 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.REFACT
if arch == "PersimmonForCausalLM":
return gguf.MODEL_ARCH.PERSIMMON
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
if arch in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN
Expand Down Expand Up @@ -1074,10 +1074,11 @@ def set_gguf_parameters(self):
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
self.gguf_writer.add_layer_norm_eps(1e-5)
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))


class MixtralModel(Model):
Expand Down

0 comments on commit 69917df

Please sign in to comment.