[Misc] Standardize RoPE handling for Qwen2-VL #9250

Merged · 14 commits · Oct 16, 2024
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -4,7 +4,7 @@
numpy < 2.0.0
requests >= 2.26.0
tqdm
py-cpuinfo
transformers >= 4.45.0 # Required for Llama 3.2.
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
25 changes: 10 additions & 15 deletions vllm/config.py
@@ -15,7 +15,8 @@
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
get_hf_text_config,
patch_rope_scaling)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, is_neuron, is_openvino, is_xpu,
print_warning_once)
@@ -1721,18 +1722,13 @@ def _get_and_verify_max_len(
default_max_len)
derived_max_model_len = default_max_len

# Backwards compatibility
patch_rope_scaling(hf_config)

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
if "type" in rope_scaling:
rope_type = rope_scaling["type"]
elif "rope_type" in rope_scaling:
rope_type = rope_scaling["rope_type"]
else:
raise ValueError(
"rope_scaling must have a 'type' or 'rope_type' key.")
rope_type = rope_scaling["rope_type"]

# The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_type not in ("su", "longrope", "llama3"):
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
@@ -1742,11 +1738,10 @@
"with rope_scaling. Please raise an issue so we can "
"investigate.")

if rope_type == "mrope":
scaling_factor = 1
else:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
# NOTE: rope_type == "default" does not define factor
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
scaling_factor = rope_scaling.get("factor", 1.0)

if rope_type == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
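As a minimal sketch of the simplified flow in _get_and_verify_max_len() (the mrope_section values below are illustrative of a Qwen2-VL-style checkpoint, not copied from this diff): once patch_rope_scaling() has run, the dict always carries "rope_type", and rope types that define no "factor" fall back to a scaling factor of 1.0.

# Hypothetical, already-patched Qwen2-VL-style entry: patch_rope_scaling() has
# rewritten the legacy "mrope" type to "default" and populated both keys.
rope_scaling = {"type": "default", "rope_type": "default",
                "mrope_section": [16, 24, 24]}

rope_type = rope_scaling["rope_type"]             # no fallback to "type" needed
scaling_factor = rope_scaling.get("factor", 1.0)  # "default" defines no factor
# With a factor of 1.0, derived_max_model_len is left unchanged for such models.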
51 changes: 32 additions & 19 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -28,6 +28,7 @@
import torch.nn as nn

from vllm.model_executor.custom_op import CustomOp
from vllm.transformers_utils.config import patch_rope_scaling_dict


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -901,6 +902,9 @@ def get_rope(
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Backwards compatibility
patch_rope_scaling_dict(rope_scaling)

# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v
@@ -920,13 +924,10 @@
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:
scaling_type = rope_scaling[
"type"] if "type" in rope_scaling else rope_scaling["rope_type"]
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type not in {"su", "longrope"}:
scaling_factor = rope_scaling.get("factor", 1.0)
Comment on lines -927 to -928 (Member, PR author):
It is not clear which RoPE implementations need this, so I've moved this code into the individual if blocks.

scaling_type = rope_scaling["rope_type"]

if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling[
@@ -937,16 +938,39 @@
scaling_factor, low_freq_factor,
high_freq_factor,
original_max_position)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor, dtype)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
@@ -961,6 +985,7 @@
scaling_factor, dtype,
**extra_kwargs)
elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
@@ -973,9 +998,7 @@
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size, rotary_dim, original_max_position, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs)
# The correct one should be "longrope" but keep "su" here
# for backward compatible
elif scaling_type == "su" or scaling_type == "longrope":
elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
Expand All @@ -989,16 +1012,6 @@ def get_rope(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
elif scaling_type == "mrope":
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
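A short sketch of how the reworked dispatch in get_rope() behaves for Qwen2-VL (head size, base, and mrope_section values are assumed for illustration, not taken from this PR): since legacy "mrope" configs are rewritten to "default", MRotaryEmbedding is selected purely by the presence of "mrope_section".

from vllm.model_executor.layers.rotary_embedding import (MRotaryEmbedding,
                                                          get_rope)

rotary_emb = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=32768,
    base=1000000,
    is_neox_style=True,
    rope_scaling={"rope_type": "default", "mrope_section": [16, 24, 24]},
)
assert isinstance(rotary_emb, MRotaryEmbedding)

# The same "default" type without "mrope_section" falls through to a plain
# RotaryEmbedding instead.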
4 changes: 2 additions & 2 deletions vllm/model_executor/models/qwen2_vl.py
@@ -34,6 +34,8 @@
from transformers.image_utils import (get_image_size,
infer_channel_dimension_format,
to_numpy_array)
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLConfig, Qwen2VLVisionConfig)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
make_batched_images, make_batched_videos, smart_resize)

@@ -62,8 +64,6 @@
from vllm.multimodal.image import cached_get_image_processor
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
Qwen2VLVisionConfig)
from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu

35 changes: 32 additions & 3 deletions vllm/transformers_utils/config.py
@@ -23,8 +23,8 @@
MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
Qwen2VLConfig, RWConfig,
SolarConfig, UltravoxConfig)
RWConfig, SolarConfig,
UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file

@@ -57,7 +57,6 @@
"NVLM_D": NVLM_D_Config,
"solar": SolarConfig,
"ultravox": UltravoxConfig,
"qwen2_vl": Qwen2VLConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}

@@ -91,6 +90,34 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
return False


def patch_rope_scaling(config: PretrainedConfig) -> None:
"""Provide backwards compatibility for RoPE."""
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is None:
return

patch_rope_scaling_dict(rope_scaling)


def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
# Although HF prefers "rope_type", we have code that accesses "type",
# so we populate both keys
if "type" in rope_scaling:
rope_type = rope_scaling["rope_type"] = rope_scaling["type"]
elif "rope_type" in rope_scaling:
rope_type = rope_scaling["type"] = rope_scaling["rope_type"]
else:
raise ValueError("rope_scaling must have a 'type' or 'rope_type' key")

if rope_type == "su":
rope_scaling["type"] = rope_scaling["rope_type"] = "longrope"
logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
elif rope_type == "mrope":
assert "mrope_section" in rope_scaling
rope_scaling["type"] = rope_scaling["rope_type"] = "default"
logger.warning("Replacing legacy rope_type 'mrope' with 'default'")


def get_config(
model: Union[str, Path],
trust_remote_code: bool,
@@ -178,6 +205,8 @@ def get_config(
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})

patch_rope_scaling(config)

for key, value in [
("rope_scaling", rope_scaling),
("rope_theta", rope_theta),
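For reference, a minimal sketch of the new backwards-compatibility helper applied to a legacy Phi-3-style entry (the factor lists are placeholders, not real model values):

from vllm.transformers_utils.config import patch_rope_scaling_dict

# Legacy config that still uses the old "su" name for longrope.
rope_scaling = {
    "type": "su",
    "short_factor": [1.0, 1.0],
    "long_factor": [2.0, 4.0],
}
patch_rope_scaling_dict(rope_scaling)  # warns about the legacy name

assert rope_scaling["type"] == "longrope"
assert rope_scaling["rope_type"] == "longrope"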
4 changes: 0 additions & 4 deletions vllm/transformers_utils/configs/__init__.py
@@ -14,8 +14,6 @@
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
Qwen2VLVisionConfig)
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

@@ -35,6 +33,4 @@
"NVLM_D_Config",
"SolarConfig",
"UltravoxConfig",
"Qwen2VLConfig",
"Qwen2VLVisionConfig",
]
131 changes: 0 additions & 131 deletions vllm/transformers_utils/configs/qwen2vl.py

This file was deleted.