Commit 1e9605b

Tests may use outdated rope scalings, so we patch them as well

Author: DarkLight1337
Committed: Oct 11, 2024
Parent: 530d8a0

Showing 3 changed files with 39 additions and 23 deletions.

vllm/config.py (6 changes: 5 additions & 1 deletion)

@@ -15,7 +15,8 @@
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
-                                            get_hf_text_config)
+                                            get_hf_text_config,
+                                            patch_rope_scaling)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                         is_hip, is_neuron, is_openvino, is_xpu,
                         print_warning_once)

@@ -1721,6 +1722,9 @@ def _get_and_verify_max_len(
                             default_max_len)
         derived_max_model_len = default_max_len
 
+    # Backwards compatibility
+    patch_rope_scaling(hf_config)
+
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
         rope_type = rope_scaling["rope_type"]
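
The patch_rope_scaling(hf_config) call matters here because _get_and_verify_max_len now reads rope_scaling["rope_type"] unconditionally. A minimal sketch of the failure mode being fixed, using types.SimpleNamespace as a hypothetical stand-in for a PretrainedConfig loaded from an outdated test fixture:

```python
from types import SimpleNamespace

from vllm.transformers_utils.config import patch_rope_scaling

# Hypothetical legacy config: only the old "type" key is present.
hf_config = SimpleNamespace(rope_scaling={"type": "linear", "factor": 2.0})

# Without this call, hf_config.rope_scaling["rope_type"] raises KeyError.
patch_rope_scaling(hf_config)
assert hf_config.rope_scaling["rope_type"] == "linear"
```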

vllm/model_executor/layers/rotary_embedding.py (7 changes: 5 additions & 2 deletions)

@@ -28,6 +28,7 @@
 import torch.nn as nn
 
 from vllm.model_executor.custom_op import CustomOp
+from vllm.transformers_utils.config import patch_rope_scaling_dict
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:

@@ -901,6 +902,9 @@ def get_rope(
     if dtype is None:
         dtype = torch.get_default_dtype()
     if rope_scaling is not None:
+        # Backwards compatibility
+        patch_rope_scaling_dict(rope_scaling)
+
         # Transforms every value that is a list into a tuple for caching calls
         rope_scaling_tuple = {
             k: tuple(v) if isinstance(v, list) else v

@@ -920,8 +924,7 @@
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
-        scaling_type = rope_scaling[
-            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
+        scaling_type = rope_scaling["rope_type"]
 
         if scaling_type == "llama3":
             scaling_factor = rope_scaling["factor"]
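
With the dict normalized on entry, get_rope can replace the "type"/"rope_type" fallback with a direct "rope_type" lookup. A short sketch of the normalization, with illustrative dict values:

```python
from vllm.transformers_utils.config import patch_rope_scaling_dict

# Illustrative legacy dict that only uses the old "type" spelling.
rope_scaling = {"type": "yarn", "factor": 4.0}
patch_rope_scaling_dict(rope_scaling)

# Both spellings are now populated, so the simplified lookup is safe.
assert rope_scaling["rope_type"] == "yarn"
assert rope_scaling["type"] == "yarn"
```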

vllm/transformers_utils/config.py (49 changes: 29 additions & 20 deletions)

@@ -90,6 +90,34 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
     return False
 
 
+def patch_rope_scaling(config: PretrainedConfig) -> None:
+    """Provide backwards compatibility for RoPE."""
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if rope_scaling is None:
+        return
+
+    patch_rope_scaling_dict(rope_scaling)
+
+
+def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
+    # Although HF prefers "rope_type", we have code that accesses "type",
+    # so we populate both keys
+    if "type" in rope_scaling:
+        rope_type = rope_scaling["rope_type"] = rope_scaling["type"]
+    elif "rope_type" in rope_scaling:
+        rope_type = rope_scaling["type"] = rope_scaling["rope_type"]
+    else:
+        raise ValueError("rope_scaling must have a 'type' or 'rope_type' key")
+
+    if rope_type == "su":
+        rope_scaling["type"] = rope_scaling["rope_type"] = "longrope"
+        logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
+    elif rope_type == "mrope":
+        assert "mrope_section" in rope_scaling
+        rope_scaling["type"] = rope_scaling["rope_type"] = "default"
+        logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,

@@ -177,26 +205,7 @@ def get_config(
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
         config.update({"architectures": [model_type]})
 
-    # Backwards compatibility for RoPE
-    rope_scaling = getattr(config, "rope_scaling", None)
-    if rope_scaling is not None:
-        # Although HF prefers "rope_type", we have code that accesses "type",
-        # so we populate both keys
-        if "type" in rope_scaling:
-            rope_type = rope_scaling["rope_type"] = rope_scaling["type"]
-        elif "rope_type" in rope_scaling:
-            rope_type = rope_scaling["type"] = rope_scaling["rope_type"]
-        else:
-            raise ValueError(
-                "rope_scaling must have a 'type' or 'rope_type' key.")
-
-        if rope_type == "su":
-            rope_scaling["rope_type"] = rope_type = "longrope"
-            logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
-        elif rope_type == "mrope":
-            assert "mrope_section" in rope_scaling
-            rope_scaling["rope_type"] = rope_type = "default"
-            logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
+    patch_rope_scaling(config)
 
     for key, value in [
         ("rope_scaling", rope_scaling),
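
The helpers extracted here also rewrite retired type names, which is what lets outdated rope scalings in tests keep working. A sketch of the two legacy renames and the error path, with illustrative values:

```python
from vllm.transformers_utils.config import patch_rope_scaling_dict

# Legacy "su" is renamed to "longrope" (and a warning is logged).
su_scaling = {"type": "su", "factor": 1.0}  # illustrative values
patch_rope_scaling_dict(su_scaling)
assert su_scaling == {"type": "longrope", "rope_type": "longrope", "factor": 1.0}

# Legacy "mrope" is folded into "default"; "mrope_section" must be present.
mrope_scaling = {"rope_type": "mrope", "mrope_section": [16, 24, 24]}  # illustrative
patch_rope_scaling_dict(mrope_scaling)
assert mrope_scaling["type"] == mrope_scaling["rope_type"] == "default"

# A dict with neither key is rejected.
try:
    patch_rope_scaling_dict({})
except ValueError as err:
    print(err)  # rope_scaling must have a 'type' or 'rope_type' key
```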
