Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Feb 18, 2025
1 parent 901e4da commit 1ced593
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions swift/llm/model/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,6 @@ def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:


def _get_arch_mapping():
from .register import MODEL_MAPPING
res = {}
for model_type, model_meta in MODEL_MAPPING.items():
architectures = model_meta.architectures
Expand All @@ -372,7 +371,7 @@ def get_matched_model_types(architectures: Optional[List[str]]) -> List[str]:
"""Get possible model_type."""
architectures = architectures or ['nulll']
if architectures:
arch = architectures[0]
architectures = architectures[0]
arch_mapping = _get_arch_mapping()
return arch_mapping.get(architectures) or []

Expand All @@ -385,10 +384,10 @@ def _get_model_info(model_dir: str, model_type: Optional[str], quantization_conf
torch_dtype = HfConfigFactory.get_torch_dtype(config_dict, quant_info)
max_model_len = HfConfigFactory.get_max_model_len(config_dict)
rope_scaling = HfConfigFactory.get_config_attr(config_dict, 'rope_scaling')
architectures = HfConfigFactory.get_config_attr(config, 'architectures')
architectures = HfConfigFactory.get_config_attr(config_dict, 'architectures')

if model_type is None:
model_types = get_matched_model_types(architectures) # config.json
model_types = get_matched_model_types(architectures)
if len(model_types) > 1:
raise ValueError('Please explicitly pass the model_type. For reference, '
f'the available model_types: {model_types}.')
Expand Down

0 comments on commit 1ced593

Please sign in to comment.