[Bugfix] Fix reloading of pixtral/llava configs (#36077)
* add is_composition flag to LlavaConfig

Signed-off-by: Kyle Sayers <[email protected]>

* WIP: pixtral text config

Signed-off-by: Kyle Sayers <[email protected]>

* fix style

Signed-off-by: Kyle Sayers <[email protected]>

* add test

Signed-off-by: Kyle Sayers <[email protected]>

* use is_composition for pixtral

Signed-off-by: Kyle Sayers <[email protected]>

* Revert "use is_composition for pixtral"

This reverts commit a53d5f9.

* Revert "Revert "use is_composition for pixtral""

This reverts commit 3ab1c99.

---------

Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs authored Feb 14, 2025
1 parent 0c78ef6 commit bcfc9d7
Showing 2 changed files with 71 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/transformers/models/llava/configuration_llava.py
@@ -76,6 +76,7 @@ class LlavaConfig(PretrainedConfig):

     model_type = "llava"
     sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+    is_composition = True
 
     def __init__(
         self,
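
For context: `is_composition` is the `PretrainedConfig` flag that marks a config as a composition of sub-configs, which changes how the config is serialized so that `from_pretrained` restores `text_config`/`vision_config` as saved rather than re-deriving them from class defaults. A minimal sketch of the save/reload round-trip this change targets, with illustrative sub-config values borrowed from the new test below (not a reproduction of the original bug report):

import tempfile

from transformers import LlavaConfig

# Values mirror the new test file added in this commit.
vision_config = {"model_type": "pixtral", "head_dim": 64, "patch_size": 16}

with tempfile.TemporaryDirectory() as tmp_dir:
    config = LlavaConfig(vision_config=vision_config)
    config.save_pretrained(tmp_dir)

    # Before this fix the reloaded pixtral sub-config could drift from what
    # was saved; with is_composition = True the nested configs round-trip.
    reloaded = LlavaConfig.from_pretrained(tmp_dir)
    assert config.to_dict() == reloaded.to_dict()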
70 changes: 70 additions & 0 deletions tests/models/llava/test_configuration_llava.py
@@ -0,0 +1,70 @@
import tempfile
import unittest

from transformers import LlavaConfig


class LlavaConfigTest(unittest.TestCase):
    def test_llava_reload(self):
        """
        Simple test for reloading default llava configs
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig()
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            assert config.to_dict() == reloaded.to_dict()

    def test_pixtral_reload(self):
        """
        Simple test for reloading pixtral configs
        """
        vision_config = {
            "model_type": "pixtral",
            "head_dim": 64,
            "hidden_act": "silu",
            "image_size": 1024,
            "is_composition": True,
            "patch_size": 16,
            "rope_theta": 10000.0,
            "tie_word_embeddings": False,
        }

        text_config = {
            "model_type": "mistral",
            "hidden_size": 5120,
            "head_dim": 128,
            "num_attention_heads": 32,
            "intermediate_size": 14336,
            "is_composition": True,
            "max_position_embeddings": 1024000,
            "num_hidden_layers": 40,
            "num_key_value_heads": 8,
            "rms_norm_eps": 1e-05,
            "rope_theta": 1000000000.0,
            "sliding_window": None,
            "vocab_size": 131072,
        }

        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig(vision_config=vision_config, text_config=text_config)
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            assert config.to_dict() == reloaded.to_dict()

    def test_arbitrary_reload(self):
        """
        Simple test for reloading arbitrarily composed subconfigs
        """
        default_values = LlavaConfig().to_dict()
        default_values["vision_config"]["model_type"] = "qwen2_vl"
        default_values["text_config"]["model_type"] = "opt"

        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig(**default_values)
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            assert config.to_dict() == reloaded.to_dict()
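
The new tests can be run in the usual way from the repository root, e.g.:

pytest tests/models/llava/test_configuration_llava.py -v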
