diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md
index 6c6f3b701cd28..66a7554da8463 100644
--- a/docs/source/contributing/model/multimodal.md
+++ b/docs/source/contributing/model/multimodal.md
@@ -250,7 +250,11 @@ def get_max_image_tokens(self) -> int:
And thus, we can override the method as:
```python
-def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
```
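The extra `mm_counts` argument lets `get_mm_max_tokens_per_item` scale its answer by how many items of each modality a prompt may carry. As a minimal sketch (not part of this patch) mirroring the `H2OVLProcessingInfo` override further down in this diff, which only budgets for MSAC when at most one image is present:

```python
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    # Sketch only: get_max_image_tokens(use_msac=...) is the H2OVL-specific
    # helper defined later in this diff; other models use their own budget.
    if mm_counts.get("image", 0) <= 1:
        return {"image": self.get_max_image_tokens(use_msac=None)}
    return {"image": self.get_max_image_tokens(use_msac=False)}
```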
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 4a099646964f2..fbdca189af620 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -726,7 +726,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc.
*
* ✅︎
- *
+ * \*
- * `Idefics3ForConditionalGeneration`
* Idefics3
* T + I
@@ -799,7 +799,7 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
- * `NVLM_D_Model`
* NVLM-D 1.0
- * T + IE+
+ * T + I+
* `nvidia/NVLM-D-72B`, etc.
*
* ✅︎
@@ -859,7 +859,11 @@ See [this page](#generative-models) for more information on how to use generativ
+ Multiple items can be inputted per text prompt for this modality.
:::{note}
-To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
+To use DeepSeek-VL2 series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
+:::
+
+:::{note}
+H2O-VL series models will be available in V1 once we support backends other than FlashAttention.
:::
:::{note}
diff --git a/tests/models/decoder_only/vision_language/test_h2ovl.py b/tests/models/decoder_only/vision_language/test_h2ovl.py
deleted file mode 100644
index 9590adf6f73c8..0000000000000
--- a/tests/models/decoder_only/vision_language/test_h2ovl.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Optional, Tuple
-
-import pytest
-import torch
-from PIL.Image import Image
-from transformers import AutoConfig
-
-# Import the functions to test
-from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
- image_to_pixel_values_wrapper)
-from vllm.multimodal.image import rescale_image_size
-
-models = [
- "h2oai/h2ovl-mississippi-800m", # Replace with your actual model names
- "h2oai/h2ovl-mississippi-2b",
-]
-
-
-def run_preprocessing_test(
- image: Image,
- config,
- max_dynamic_patch: Optional[int] = None,
-) -> Tuple[torch.Tensor, int]:
- """Test the image preprocessing and calculate expected blocks."""
-
- if max_dynamic_patch is None:
- max_dynamic_patch = config.max_dynamic_patch
-
- width, height = image.size
- use_MSAC = config.use_msac
-
- # Create the mapper function with the provided configuration
- mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC)
- pixel_values = mapper(image)
-
- # Calculate the expected number of blocks
- if use_MSAC:
- # First pass
- blocks1, _, _, aspect_ratio = calculate_num_blocks(
- width,
- height,
- config.min_dynamic_patch,
- max_dynamic_patch,
- config.vision_config.image_size,
- use_thumbnail=False, # Thumbnail is handled separately
- prior_aspect_ratio=None,
- )
-
- # Second pass
- blocks2, _, _, _ = calculate_num_blocks(
- width,
- height,
- config.min_dynamic_patch,
- max_dynamic_patch,
- config.vision_config.image_size,
- use_thumbnail=False,
- prior_aspect_ratio=aspect_ratio,
- )
-
- # Add thumbnail if use_thumbnail is True and total_blocks > 1
- if config.use_thumbnail:
- blocks1 += 1 if blocks1 > 1 else 0
- blocks2 += 1 if blocks2 > 1 else 0
-
- # Total blocks is the sum of blocks from both passes minus overlapping
- total_blocks = blocks1 + blocks2 - 1
-
- expected_blocks = total_blocks
-
- else:
- blocks, _, _, _ = calculate_num_blocks(
- width,
- height,
- config.min_dynamic_patch,
- max_dynamic_patch,
- config.vision_config.image_size,
- use_thumbnail=False,
- prior_aspect_ratio=None,
- )
- expected_blocks = blocks
-
- if config.use_thumbnail and expected_blocks > 1:
- expected_blocks += 1
-
- return pixel_values, expected_blocks
-
-
-@pytest.mark.parametrize("model_name", models)
-@pytest.mark.parametrize(
- "size_factors",
- [
- # Single-scale
- [1.0],
- # Single-scale, batched
- [1.0, 1.0, 1.0],
- # Multi-scale
- [0.25, 0.5, 1.0],
- ],
-)
-@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8])
-def test_image_preprocessing(image_assets, model_name, size_factors,
- max_dynamic_patch):
- """Test image preprocessing pipeline with different configurations."""
- # Load the configuration from the model
- config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-
- for asset in image_assets:
- image = asset.pil_image
- for factor in size_factors:
- scaled_image = rescale_image_size(image, factor)
-
- # Test preprocessing and get expected number of blocks
- pixel_values, expected_blocks = run_preprocessing_test(
- scaled_image, config, max_dynamic_patch)
-
- # Verify output shapes and properties
- actual_blocks = pixel_values.shape[0]
- assert actual_blocks == expected_blocks, (
- f"Expected {expected_blocks} blocks, got {actual_blocks}")
-
- # Check image dimensions
- expected_size = (
- 3, # Number of channels (C, H, W)
- config.vision_config.image_size,
- config.vision_config.image_size,
- )
- for img in pixel_values:
- assert img.shape == expected_size, (
- f"Expected image size {expected_size}, got {img.shape}")
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index e3cda8971b785..7a14ba2f3b60a 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -250,6 +250,7 @@
max_model_len=8192,
dtype="bfloat16",
use_tokenizer_eos=True,
+ num_logprobs=10,
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
),
"idefics3": VLMTestInfo(
@@ -282,7 +283,6 @@
dtype="bfloat16",
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
- marks=[large_gpu_mark(min_gb=32)],
),
"llava_next": VLMTestInfo(
models=["llava-hf/llava-v1.6-mistral-7b-hf"],
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index b0a88161c4c98..d2401b222558e 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -334,12 +334,12 @@ class H2OVLProcessor:
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
- self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
+ self.use_msac = self.config.use_msac
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
@@ -348,18 +348,19 @@ def __call__(self, text: str, images: Union[Image, List[Image]],
**kwargs):
# yapf: disable
from vllm.model_executor.models.h2ovl import (
- IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
+ IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values_h2ovl)
# yapf: enable
images = [images] if isinstance(images, Image) else images
pixel_values = [
- image_to_pixel_values(image,
- self.image_size,
- self.min_num,
- self.max_num,
- self.use_thumbnail,
- use_MSAC=self.config.use_msac).to(
- self.dtype) for image in images
+ image_to_pixel_values_h2ovl(
+ image,
+ input_size=self.image_size,
+ min_num=self.min_num,
+ max_num=self.max_num,
+ use_thumbnail=self.use_thumbnail,
+ use_msac=self.use_msac,
+ ) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
@@ -394,7 +395,6 @@ class InternVLProcessor:
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
- self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
@@ -407,13 +407,17 @@ def __init__(self, hf_runner: HfRunner):
def __call__(self, text: str, images: Union[Image, List[Image]],
**kwargs):
from vllm.model_executor.models.internvl import (
- IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
+ IMG_CONTEXT, IMG_END, IMG_START,
+ image_to_pixel_values_internvl)
images = [images] if isinstance(images, Image) else images
pixel_values = [
- image_to_pixel_values(image, self.image_size, self.min_num,
- self.max_num,
- self.use_thumbnail).to(self.dtype)
- for image in images
+ image_to_pixel_values_internvl(
+ image,
+ input_size=self.image_size,
+ min_num=self.min_num,
+ max_num=self.max_num,
+ use_thumbnail=self.use_thumbnail,
+ ) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
@@ -448,7 +452,8 @@ def _internvl_generate(
) -> torch.LongTensor:
"""Generate method for InternVL2 model without fixed use_cache."""
assert self.img_context_token_id is not None
- vit_embeds = self.extract_feature(pixel_values)
+ target_dtype = next(self.parameters()).dtype
+ vit_embeds = self.extract_feature(pixel_values.to(target_dtype))
input_embeds = self.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 3921d4e19dd2b..07906a71d06e4 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -141,13 +141,14 @@ def _test_processing_correctness(
# yapf: disable
-# True if the model supports multiple data items of the modality per request
@pytest.mark.parametrize("model_id", [
"rhymes-ai/Aria",
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
"adept/fuyu-8b",
+ "h2oai/h2ovl-mississippi-800m",
+ "OpenGVLab/InternVL2-1B",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
@@ -156,6 +157,7 @@ def _test_processing_correctness(
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
+ "nvidia/NVLM-D-72B",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py
new file mode 100644
index 0000000000000..767ac5eb9ef9a
--- /dev/null
+++ b/tests/models/multimodal/processing/test_h2ovl.py
@@ -0,0 +1,142 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for H2OVL's multimodal preprocessing kwargs."""
+from typing import Optional
+
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import rescale_image_size
+from vllm.multimodal.utils import cached_get_tokenizer
+
+from ....conftest import _ImageAssets
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id", [
+ "h2oai/h2ovl-mississippi-800m",
+ "h2oai/h2ovl-mississippi-2b",
+])
+@pytest.mark.parametrize(
+ "size_factors",
+ [
+ # Single-scale
+ [1.0],
+ # Single-scale, batched
+ [1.0, 1.0, 1.0],
+ # Multi-scale
+ [0.25, 0.5, 1.0],
+ ],
+)
+@pytest.mark.parametrize("max_dynamic_patch", [1, 2, 4, 8])
+@pytest.mark.parametrize("dynamic_image_size", [True, False])
+@pytest.mark.parametrize("num_imgs", [1, 2])
+def test_processor_override(
+ model_id: str,
+ image_assets: _ImageAssets,
+    size_factors: list[float],
+ max_dynamic_patch: int,
+ dynamic_image_size: Optional[bool],
+ num_imgs: int,
+):
+ from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets,
+ get_h2ovl_target_ratios)
+
+ ctx = build_model_context(
+ model_name=model_id,
+ tokenizer_name=model_id,
+ trust_remote_code=True,
+ mm_processor_kwargs=None,
+ limit_mm_per_prompt={"image": num_imgs},
+ )
+ tokenizer = cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ )
+ processor = MULTIMODAL_REGISTRY.create_processor(
+ ctx.model_config,
+ tokenizer=tokenizer,
+ )
+
+ config = processor.info.get_hf_config()
+ use_msac = config.use_msac
+
+ mm_processor_kwargs = {
+ "max_dynamic_patch": max_dynamic_patch,
+ }
+ if dynamic_image_size is not None:
+ mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
+
+ min_num = config.min_dynamic_patch
+ max_num = max_dynamic_patch if dynamic_image_size else 1
+
+ # Build the image str / prompt based on the number of images we pass
+ prompt = "" * num_imgs
+
+ for asset in image_assets:
+ for factor in size_factors:
+ image = rescale_image_size(asset.pil_image, factor)
+ mm_data = {"image": [image] * num_imgs}
+
+ width, height = image.size
+
+ # Calculate the expected number of blocks
+ if num_imgs == 1 and use_msac:
+ # First pass
+ blocks1, _, _, aspect_ratio = calculate_h2ovl_targets(
+ orig_width=width,
+ orig_height=height,
+ target_ratios=get_h2ovl_target_ratios(
+ min_num,
+ max_num,
+ prior_aspect_ratio=None,
+ ),
+ image_size=config.vision_config.image_size,
+ use_thumbnail=False, # Thumbnail is handled separately
+ )
+
+ # Second pass
+ blocks2, _, _, _ = calculate_h2ovl_targets(
+ orig_width=width,
+ orig_height=height,
+ target_ratios=get_h2ovl_target_ratios(
+ min_num,
+ max_num,
+ prior_aspect_ratio=aspect_ratio,
+ ),
+ image_size=config.vision_config.image_size,
+ use_thumbnail=False,
+ )
+
+ # Add thumbnail if use_thumbnail is True and total_blocks > 1
+ if config.use_thumbnail:
+ blocks1 += 1 if blocks1 > 1 else 0
+ blocks2 += 1 if blocks2 > 1 else 0
+
+ # Total blocks is the sum of blocks from both passes minus
+ # overlapping
+ total_blocks = blocks1 + blocks2 - 1
+
+ expected_num_patches = total_blocks
+ else:
+ blocks, _, _, _ = calculate_h2ovl_targets(
+ orig_width=width,
+ orig_height=height,
+ target_ratios=get_h2ovl_target_ratios(
+ min_num,
+ max_num,
+ prior_aspect_ratio=None,
+ ),
+ image_size=config.vision_config.image_size,
+ use_thumbnail=False,
+ )
+ expected_num_patches = blocks
+
+ if config.use_thumbnail and expected_num_patches != 1:
+ expected_num_patches += 1
+
+ processed_inputs = processor.apply(prompt, mm_data,
+ mm_processor_kwargs)
+ pixel_shape = (
+ processed_inputs["mm_kwargs"]["pixel_values_flat"].shape)
+
+ assert pixel_shape[0] == expected_num_patches * num_imgs
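For quick manual verification outside pytest, the same patch count can be read off the preprocessing helper introduced in this diff. A minimal sketch, assuming a local RGB test image and the usual 448-pixel InternViT input size (the path and the config values below are illustrative, not taken from this patch):

```python
from PIL import Image

from vllm.model_executor.models.h2ovl import image_to_pixel_values_h2ovl

# "sample.jpg" is a placeholder path used only for illustration.
image = Image.open("sample.jpg").convert("RGB")

pixel_values = image_to_pixel_values_h2ovl(
    image,
    input_size=448,   # assumed vision_config.image_size
    min_num=1,        # assumed config.min_dynamic_patch
    max_num=6,        # assumed config.max_dynamic_patch
    use_thumbnail=True,
    use_msac=True,    # single-image path runs both MSAC passes
)

# The first dimension is the number of patches fed to the vision tower; the
# test above derives the same quantity from calculate_h2ovl_targets. The
# remaining dims are (3, input_size, input_size).
print(pixel_values.shape)
```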
diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py
index 0d921e9d32961..ede961225be7b 100644
--- a/tests/models/multimodal/processing/test_internvl.py
+++ b/tests/models/multimodal/processing/test_internvl.py
@@ -1,207 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for InternVL's multimodal preprocessing kwargs."""
-from typing import Callable, Optional
+from typing import Optional
import pytest
-from transformers import AutoTokenizer
-from vllm.inputs import InputContext, token_inputs
-from vllm.multimodal import MultiModalRegistry
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import cached_get_tokenizer
from ....conftest import _ImageAssets
from ...utils import build_model_context
-models = ["OpenGVLab/InternVL2-2B"]
-
-# Wrap lazy imports to avoid initializing CUDA during test collection
-@pytest.fixture()
-def input_processor_for_internvl():
- from vllm.model_executor.models.internvl import InternVLInputPipeline
-
-    pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
- return pipeline.input_processor
-
-
-@pytest.fixture()
-def dummy_data_for_internvl():
- from vllm.model_executor.models.internvl import InternVLInputPipeline
-
-    pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
- return pipeline.dummy_data
-
-
-@pytest.fixture()
-def get_max_internvl_image_tokens():
- from vllm.model_executor.models.internvl import (
- get_max_internvl_image_tokens)
- return get_max_internvl_image_tokens
-
-
-@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("model_id", ["OpenGVLab/InternVL2-2B"])
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
-def test_input_mapper_override(
- model: str,
+@pytest.mark.parametrize("num_imgs", [1, 2])
+def test_processor_override(
+ model_id: str,
image_assets: _ImageAssets,
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
-):
- mm_processor_kwargs = {
- "max_dynamic_patch": max_dynamic_patch,
- }
- if dynamic_image_size is not None:
- mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
-
- expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
- if dynamic_image_size is False:
- expected_num_patches = 1
-
- ctx = build_model_context(
- model_name=model,
- tokenizer_name=model,
- trust_remote_code=True,
- mm_processor_kwargs=mm_processor_kwargs,
- )
-
- mm_registry = MultiModalRegistry()
- mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-
- image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
- vllm_result = mm_registry.map_input(
- ctx.model_config,
- {"image": image},
- )
- assert vllm_result["pixel_values"].size(1) == expected_num_patches
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
-@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
-def test_max_tokens_override(
- get_max_internvl_image_tokens: Callable,
- model: str,
- max_dynamic_patch: Optional[int],
- dynamic_image_size: Optional[bool],
-):
- """Ensure get_max_internvl_image_tokens handles mm_processor_kwargs."""
- ctx = build_model_context(
- model_name=model,
- tokenizer_name=model,
- trust_remote_code=True,
- mm_processor_kwargs=None,
- )
-
- if max_dynamic_patch is None:
- max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
- expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
- if dynamic_image_size is False:
- expected_num_patches = 1
- expected_max_tokens = 256 * expected_num_patches
-
- actual_max_tokens = get_max_internvl_image_tokens(
- ctx=InputContext(ctx.model_config),
- max_dynamic_patch=max_dynamic_patch,
- dynamic_image_size=dynamic_image_size,
- )
- assert expected_max_tokens == actual_max_tokens
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_imgs", [1, 2])
-@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
-@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
-def test_dummy_data_override(
- dummy_data_for_internvl: Callable,
- model: str,
num_imgs: int,
- max_dynamic_patch: Optional[int],
- dynamic_image_size: Optional[bool],
):
- """Ensure dummy_data_for_internvl handles kwargs properly."""
- # Same as the previous test - don't initialize mm_processor_kwargs
- # in this test and assume that the kwargs will be correctly expanded by
- # the partial when calling the dummy data func.
ctx = build_model_context(
- model_name=model,
- tokenizer_name=model,
+ model_name=model_id,
+ tokenizer_name=model_id,
trust_remote_code=True,
mm_processor_kwargs=None,
+ limit_mm_per_prompt={"image": num_imgs},
)
-
- if max_dynamic_patch is None:
- max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
- expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
- if dynamic_image_size is False:
- expected_num_patches = 1
- expected_max_tokens = 256 * expected_num_patches
-
- dummy_data = dummy_data_for_internvl(
- ctx=ctx,
- seq_len=8192, # Should be bigger than num_imgs * toks_per_img
- mm_counts={"image": num_imgs},
- max_dynamic_patch=max_dynamic_patch,
- dynamic_image_size=dynamic_image_size,
+ tokenizer = cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ )
+ processor = MULTIMODAL_REGISTRY.create_processor(
+ ctx.model_config,
+ tokenizer=tokenizer,
)
- sequence_data = dummy_data.seq_data
-
- tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
-    image_token_id = tokenizer.encode('<IMG_CONTEXT>',
-                                      add_special_tokens=False)[0]
- # Ensure we have the right number of placeholders per size
- img_tok_count = sequence_data.get_token_ids().count(image_token_id)
- assert img_tok_count == expected_max_tokens * num_imgs
+ mm_processor_kwargs = {
+ "max_dynamic_patch": max_dynamic_patch,
+ }
+ if dynamic_image_size is not None:
+ mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
+ # Build the image str / prompt based on the number of images we pass
+ prompt = "" * num_imgs
+ image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
+ mm_data = {"image": [image] * num_imgs}
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
-@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
-@pytest.mark.parametrize("num_imgs", [1, 2])
-def test_input_processor_override(
- input_processor_for_internvl: Callable,
- image_assets: _ImageAssets,
- model: str,
- num_imgs: int,
- max_dynamic_patch: int,
- dynamic_image_size: Optional[bool],
-):
- """Ensure input_processor_for_internvl handles kwargs properly."""
- # Same as the previous test - don't initialize mm_processor_kwargs
- # in this test and assume that the kwargs will be correctly expanded by
- # the partial when calling the custom input processor.
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
- ctx = build_model_context(
- model_name=model,
- tokenizer_name=model,
- trust_remote_code=True,
- mm_processor_kwargs=None,
- )
- expected_toks_per_img = 256 * expected_num_patches
-
- # Build the image str / prompt based on the number of images we pass
- tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
- placeholders = "" if num_imgs == 1 else "\n".join(
- f"Image-{i}: \n" for i in range(1, num_imgs + 1))
- prompt = placeholders
- images = [image_assets[0].pil_image.resize((448 * 2, 448 * 2))] * num_imgs
-
- inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
- prompt=prompt,
- multi_modal_data={"image": images})
-
- processed_inputs = input_processor_for_internvl(
- ctx,
- inputs,
- max_dynamic_patch=max_dynamic_patch,
- dynamic_image_size=dynamic_image_size,
- )
+ processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
# Ensure we have the right number of placeholders per num_crops size
-    image_token_id = tokenizer.encode('<IMG_CONTEXT>',
-                                      add_special_tokens=False)[0]
+    image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
- assert img_tok_count == expected_toks_per_img * num_imgs
+ pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
+
+ assert img_tok_count == 256 * expected_num_patches * num_imgs
+ assert pixel_shape[0] == expected_num_patches * num_imgs
diff --git a/tests/models/multimodal/processing/test_llava_next.py b/tests/models/multimodal/processing/test_llava_next.py
index d2497e62d91b2..fe4754c2ef6f6 100644
--- a/tests/models/multimodal/processing/test_llava_next.py
+++ b/tests/models/multimodal/processing/test_llava_next.py
@@ -43,7 +43,10 @@ def test_processor_max_tokens(model_id):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
- tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+ tokenizer=cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ ),
)
info = processor.info
@@ -143,7 +146,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
- tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+ tokenizer=cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ ),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
@@ -173,7 +179,10 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
- tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+ tokenizer=cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ ),
)
seen_aspect_ratios = set[float]()
diff --git a/tests/models/multimodal/processing/test_llava_onevision.py b/tests/models/multimodal/processing/test_llava_onevision.py
index bd4dbd46da4c2..fb650d9e0995f 100644
--- a/tests/models/multimodal/processing/test_llava_onevision.py
+++ b/tests/models/multimodal/processing/test_llava_onevision.py
@@ -44,7 +44,10 @@ def test_processor_max_tokens(model_id):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
- tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+ tokenizer=cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ ),
)
info = processor.info
@@ -143,7 +146,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
- tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+ tokenizer=cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ ),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
@@ -174,7 +180,10 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
- tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+ tokenizer=cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ ),
)
seen_aspect_ratios = set[float]()
diff --git a/tests/models/multimodal/processing/test_phi3v.py b/tests/models/multimodal/processing/test_phi3v.py
index 44edec457a662..dde8904f2ef65 100644
--- a/tests/models/multimodal/processing/test_phi3v.py
+++ b/tests/models/multimodal/processing/test_phi3v.py
@@ -38,7 +38,10 @@ def test_processor_override(
trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs},
)
- tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+ tokenizer = cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ )
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py
index 47c9b0add55ab..ef8e97f82d0bc 100644
--- a/tests/models/multimodal/processing/test_qwen2_vl.py
+++ b/tests/models/multimodal/processing/test_qwen2_vl.py
@@ -33,7 +33,10 @@ def test_processor_override(
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
- tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+ tokenizer = cached_get_tokenizer(
+ ctx.model_config.tokenizer,
+ trust_remote_code=ctx.model_config.trust_remote_code,
+ )
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index 97502c38b9f00..98df532aa0a83 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -399,7 +399,11 @@ def get_hf_processor(self):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 2b04522223d0e..0463a0b97d40a 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -407,7 +407,11 @@ def get_hf_config(self):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 9061a31280e64..b29dd65a8e357 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -64,7 +64,11 @@ def get_hf_processor(self):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index 1343b9762874b..0eaf3a6201f6b 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -165,7 +165,11 @@ def get_image_size_with_most_features(self) -> ImageSize:
image_width=x[1], image_height=x[0]))
return ImageSize(width=width, height=height)
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
max_image_size = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_height=max_image_size.height,
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 6d8c829687ee2..50b5ef35d2cd1 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -80,7 +80,11 @@ def get_image_processor(self) -> FuyuImageProcessor:
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_ncols, max_nrows = self.get_image_feature_grid_size(
diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py
index 91c89b159ca92..cf3e777a2027f 100644
--- a/vllm/model_executor/models/h2ovl.py
+++ b/vllm/model_executor/models/h2ovl.py
@@ -7,43 +7,55 @@
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
-from functools import partial
-from typing import List, Optional, Tuple
+from typing import Mapping, Optional
import torch
from PIL import Image
from transformers import PretrainedConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
- token_inputs)
+from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.utils import is_list_of
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalKwargs
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+ MultiModalDataItems)
+from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
+ PromptReplacementDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel
-from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel,
- InternVLInputPipeline, build_transform,
- find_closest_aspect_ratio, get_internvl_num_patches)
+from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
+ BaseInternVLProcessingInfo, BaseInternVLProcessor,
+ InternVLChatModel, InternVLDummyInputsBuilder,
+ InternVLMultiModalProcessor, build_transform,
+ find_closest_aspect_ratio, get_internvl_target_ratios)
+logger = init_logger(__name__)
-# modified to include blocks generated in second pass
-def calculate_num_blocks(
- orig_width: int,
- orig_height: int,
- min_num: int,
- max_num: int,
- image_size: int,
+
+def resolve_h2ovl_min_max_num(
+ *,
+ min_dynamic_patch: int,
+ max_dynamic_patch: int,
+ dynamic_image_size: bool,
use_thumbnail: bool,
- prior_aspect_ratio=None,
-) -> Tuple[int, int, int, Tuple[int, int]]:
- aspect_ratio = orig_width / orig_height
+) -> tuple[int, int]:
+ max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
+
+ if use_thumbnail and max_dynamic_patch != 1:
+ max_dynamic_patch += 1
- # calculate the existing image aspect ratio
- target_ratios = set((i, j) for n in range(min_num, max_num + 1)
- for i in range(1, n + 1) for j in range(1, n + 1)
- if i * j <= max_num and i * j >= min_num)
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+ return min_dynamic_patch, max_dynamic_patch
+
+
+def get_h2ovl_target_ratios(
+ min_num: int,
+ max_num: int,
+ *,
+ prior_aspect_ratio: Optional[tuple[int, int]],
+) -> list[tuple[int, int]]:
+ target_ratios = get_internvl_target_ratios(min_num, max_num)
# if prior_aspect_ratio is provided, filter the target ratios
if prior_aspect_ratio is not None:
@@ -52,44 +64,66 @@ def calculate_num_blocks(
ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0
]
+ return target_ratios
+
+
+# modified to include blocks generated in second pass
+def calculate_h2ovl_targets(
+ *,
+ orig_width: int,
+ orig_height: int,
+ target_ratios: list[tuple[int, int]],
+ image_size: int,
+ use_thumbnail: bool,
+) -> tuple[int, int, int, tuple[int, int]]:
+ aspect_ratio = orig_width / orig_height
+
# find the closest aspect ratio to the target
- target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
- target_ratios, orig_width,
- orig_height, image_size)
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio,
+ target_ratios,
+ width=orig_width,
+ height=orig_height,
+ image_size=image_size,
+ )
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
- # add thumbnail image if num_blocks > 1
- if use_thumbnail and blocks > 1:
+
+ # add thumbnail image if num_blocks != 1
+ if use_thumbnail and blocks != 1:
blocks += 1
+
return blocks, target_width, target_height, target_aspect_ratio
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-# refactored to handle prior_aspect_ratio as optional
-def dynamic_preprocess(
+# refactored to handle prior_aspect_ratio
+def dynamic_preprocess_h2ovl(
image: Image.Image,
- min_num: int,
- max_num: int,
+ *,
+ target_ratios: list[tuple[int, int]],
image_size: int,
use_thumbnail: bool,
- prior_aspect_ratio: Optional[Tuple[int, int]] = None,
-) -> Tuple[List[Image.Image], Tuple[int, int]]:
+) -> tuple[list[Image.Image], tuple[int, int]]:
orig_width, orig_height = image.size
- # calculate the number of blocks based on prior aspect ratio if available
- blocks, target_width, target_height, target_aspect_ratio = (
- calculate_num_blocks(
- orig_width,
- orig_height,
- min_num,
- max_num,
- image_size,
- use_thumbnail=False,
- prior_aspect_ratio=prior_aspect_ratio,
- ))
+ # calculate the number of blocks without thumbnail
+ (
+ blocks,
+ target_width,
+ target_height,
+ target_aspect_ratio,
+ ) = calculate_h2ovl_targets(
+ orig_width=orig_width,
+ orig_height=orig_height,
+ target_ratios=target_ratios,
+ image_size=image_size,
+ use_thumbnail=False,
+ )
+
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
@@ -103,276 +137,393 @@ def dynamic_preprocess(
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
+
assert len(processed_images) == blocks
+
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
+
return processed_images, target_aspect_ratio
-def load_image(
+def _preprocess_image(
image: Image.Image,
- input_size=448,
- min_num=1,
- max_num=6,
- use_thumbnail=True,
- prior_aspect_ratio: Optional[Tuple[int, int]] = None,
-) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ *,
+ input_size: int,
+ min_num: int,
+ max_num: int,
+ use_thumbnail: bool,
+ prior_aspect_ratio: Optional[tuple[int, int]],
+) -> tuple[torch.Tensor, tuple[int, int]]:
+ target_ratios = get_h2ovl_target_ratios(
+ min_num,
+ max_num,
+ prior_aspect_ratio=prior_aspect_ratio,
+ )
+
transform = build_transform(input_size=input_size)
- images, target_aspect_ratio = dynamic_preprocess(
+ images, target_aspect_ratio = dynamic_preprocess_h2ovl(
image,
image_size=input_size,
use_thumbnail=use_thumbnail,
- min_num=min_num,
- max_num=max_num,
- prior_aspect_ratio=prior_aspect_ratio,
+ target_ratios=target_ratios,
)
- pixel_values = [transform(image) for image in images]
- pixel_values = torch.stack(pixel_values)
+
+ pixel_values = torch.stack([transform(image) for image in images])
return pixel_values, target_aspect_ratio
-# refactored to use the combined load_image function
-def image_to_pixel_values(
+# refactored to use the _preprocess_image function
+def image_to_pixel_values_h2ovl(
image: Image.Image,
+ *,
input_size: int,
min_num: int,
max_num: int,
use_thumbnail: bool,
- use_MSAC: bool,
+ use_msac: bool,
) -> torch.Tensor:
# when MSAC is turned on, we need to process the image twice
- if use_MSAC:
+ if use_msac:
# first pass
- pixel_values, target_aspect_ratio = load_image(
+ pixel_values1, aspect_ratio1 = _preprocess_image(
image,
input_size=input_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=True,
+ prior_aspect_ratio=None,
)
# second pass
- pixel_values2, _ = load_image(
+ pixel_values2, _ = _preprocess_image(
image,
input_size=input_size,
- min_num=min_num,
+        min_num=3,  # hardcoded value for the second MSAC pass
max_num=max_num,
- prior_aspect_ratio=target_aspect_ratio,
+ use_thumbnail=True,
+ prior_aspect_ratio=aspect_ratio1,
)
# combine pixel values
pixel_values = torch.cat(
- [pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)
+ [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0)
else:
- pixel_values, _ = load_image(
+ pixel_values, _ = _preprocess_image(
image,
input_size=input_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=use_thumbnail,
+ prior_aspect_ratio=None,
)
return pixel_values
-def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
- max_dynamic_patch: Optional[int] = None,
- use_MSAC: Optional[bool] = None):
- image_size = hf_config.vision_config.image_size
- min_num = hf_config.min_dynamic_patch
- if max_dynamic_patch is None:
- max_dynamic_patch = hf_config.max_dynamic_patch
- if use_MSAC is None:
- use_MSAC = hf_config.use_msac
- use_thumbnail = hf_config.use_thumbnail
- return partial(
- image_to_pixel_values,
- input_size=image_size,
- min_num=min_num,
- max_num=max_dynamic_patch,
- use_thumbnail=use_thumbnail,
- use_MSAC=use_MSAC,
- )
-
-
-def get_max_internvl_image_tokens(ctx: InputContext,
- *,
- max_dynamic_patch: Optional[int] = None):
- """
- Calculate the maximum number of tokens with/without MSAC and thumbnail
- """
- hf_config = ctx.get_hf_config()
- use_thumbnail = hf_config.use_thumbnail
- use_MSAC = hf_config.use_msac
+class H2OVLProcessor(BaseInternVLProcessor):
- if max_dynamic_patch is None:
- max_dynamic_patch = hf_config.max_dynamic_patch
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ tokenizer: AnyTokenizer,
+ *,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ use_msac: Optional[bool] = None,
+ ) -> None:
+ super().__init__(
+ config,
+ tokenizer,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ )
- num_patches = get_internvl_num_patches(hf_config)
+ if use_msac is None:
+ use_msac = config.use_msac
+ assert isinstance(use_msac, bool)
- coefficient = 2 if use_MSAC else 1
- num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)
+ self.use_msac = use_msac
- return num_blocks * num_patches
+ @property
+ def image_token_id(self) -> int:
+ return self.tokenizer.get_vocab()[IMG_CONTEXT]
+ def get_image_repl_features(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> str:
+ return IMG_CONTEXT * feature_size
-class H2OVLInputPipeline(InternVLInputPipeline):
- """
- Input pipeline for processing image and text data for the H2OVL model.
- """
+ def get_image_repl_full(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> str:
+ features = self.get_image_repl_features(feature_size, num_patches)
+ return IMG_START + features + IMG_END
- def input_processor(
+ def resolve_min_max_num(
self,
- ctx: InputContext,
- inputs: DecoderOnlyInputs,
*,
max_dynamic_patch: Optional[int] = None,
- ) -> DecoderOnlyInputs:
- # get multi_modal_data
- multi_modal_data = inputs.get("multi_modal_data")
- if multi_modal_data is None or "image" not in multi_modal_data:
- return inputs
-
- model_config = ctx.model_config
- hf_config = ctx.get_hf_config()
- use_MSAC = hf_config.use_msac
-
- image_data = multi_modal_data["image"]
- num_patches = get_internvl_num_patches(hf_config)
-
- image_pixel_values_mapper = image_to_pixel_values_wrapper(
- hf_config, max_dynamic_patch=max_dynamic_patch)
-
- # single image
- if isinstance(image_data, Image.Image):
- pixel_values = image_pixel_values_mapper(image_data,
- use_MSAC=use_MSAC)
- num_blocks = pixel_values.shape[0]
- image_feature_sizes = [num_blocks * num_patches]
- pixel_values = pixel_values.unsqueeze(0)
-
- # multi images
- elif is_list_of(image_data, Image.Image):
- # Do not use MSAC for multi images
- image_feature_sizes = []
- pixel_values = [
- image_pixel_values_mapper(image, use_MSAC=False)
- for image in image_data
- ]
- for pixel_value in pixel_values:
- num_blocks = pixel_value.shape[0]
- image_feature_sizes.append(num_blocks * num_patches)
-
- # image embeddings as input
- elif isinstance(image_data, torch.Tensor):
- _, image_feature_size, _ = image_data.shape
- image_feature_sizes = [image_feature_size]
- pixel_values = None
-
- # multi-image image embeddings
- elif is_list_of(image_data, torch.Tensor):
-
- image_feature_sizes = []
- for image_embed in image_data:
- _, image_feature_size, _ = image_embed.shape
- image_feature_sizes.append(image_feature_size)
- pixel_values = None
+ dynamic_image_size: Optional[bool] = None,
+ use_thumbnail: Optional[bool] = None,
+ ) -> tuple[int, int]:
+ min_dynamic_patch = self.min_dynamic_patch
+ max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
+ is None else max_dynamic_patch)
+ dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
+ is None else dynamic_image_size)
+ use_thumbnail = (self.use_thumbnail
+ if use_thumbnail is None else use_thumbnail)
+
+ return resolve_h2ovl_min_max_num(
+ min_dynamic_patch=min_dynamic_patch,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ use_thumbnail=use_thumbnail,
+ )
- else:
- raise TypeError(f"Invalid image type: {type(image_data)}")
+ def resolve_target_ratios(
+ self,
+ *,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ use_thumbnail: Optional[bool] = None,
+ prior_aspect_ratio: Optional[tuple[int, int]] = None,
+ ) -> list[tuple[int, int]]:
+ min_num, max_num = self.resolve_min_max_num(
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ use_thumbnail=use_thumbnail,
+ )
+ if prior_aspect_ratio: # hardcoded value for second pass of use_msac
+ min_num = 3
- tokenizer = cached_get_tokenizer(
- model_config.tokenizer,
- trust_remote_code=model_config.trust_remote_code,
+ return get_h2ovl_target_ratios(
+ min_num,
+ max_num,
+ prior_aspect_ratio=prior_aspect_ratio,
)
- prompt = inputs.get("prompt")
- prompt_token_ids = inputs["prompt_token_ids"]
- if prompt is None:
- prompt = tokenizer.decode(prompt_token_ids)
-
- new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
- num_patches)
- new_prompt_token_ids = tokenizer.encode(new_prompt)
-
- # Wrap image processing in input_processor to avoid duplication
- image_token_id = tokenizer.encode(
- self.img_context_token,
- add_special_tokens=False,
- return_tensors="pt",
- )[0]
-
- # Update multi_modal_data to return
- if pixel_values is not None:
- multi_modal_data = {
- "image": {
- "pixel_values": pixel_values,
- "image_token_id": image_token_id,
- }
- }
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ use_msac: Optional[bool] = None,
+ ) -> int:
+ use_msac = (self.use_msac if use_msac is None else use_msac)
+
+ use_thumbnail = self.use_thumbnail
+
+ if use_msac:
+ target_ratios_1 = self.resolve_target_ratios(
+ use_thumbnail=False, # Applied in calculate_targets
+ )
+ num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets(
+ orig_width=image_width,
+ orig_height=image_height,
+ image_size=self.image_size,
+ target_ratios=target_ratios_1,
+ use_thumbnail=True,
+ )
+
+ target_ratios_2 = self.resolve_target_ratios(
+ use_thumbnail=False, # Applied in calculate_targets
+ prior_aspect_ratio=aspect_ratio_1,
+ )
+ num_patches_2, _, _, _ = calculate_h2ovl_targets(
+ orig_width=image_width,
+ orig_height=image_height,
+ image_size=self.image_size,
+ target_ratios=target_ratios_2,
+ use_thumbnail=True,
+ )
+
+ num_patches = num_patches_1 + num_patches_2 - 1
else:
- multi_modal_data = {"image": {"image_embeds": image_data}}
+ target_ratios = self.resolve_target_ratios(
+ use_thumbnail=False, # Applied in calculate_targets
+ )
+ num_patches, _, _, _ = calculate_h2ovl_targets(
+ orig_width=image_width,
+ orig_height=image_height,
+ image_size=self.image_size,
+ target_ratios=target_ratios,
+ use_thumbnail=use_thumbnail,
+ )
+
+ return num_patches * self.num_image_token
- return token_inputs(
- prompt=prompt,
- prompt_token_ids=new_prompt_token_ids,
- multi_modal_data=multi_modal_data,
+ def _images_to_pixel_values_lst(
+ self,
+ images: list[Image.Image],
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> list[torch.Tensor]:
+ use_msac = self.use_msac if len(images) == 1 else False
+
+ min_num, max_num = self.resolve_min_max_num(
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ use_thumbnail=False, # Applied in image_to_pixel_values
)
- def input_mapper(
+ return [
+ image_to_pixel_values_h2ovl(
+ image,
+ input_size=self.image_size,
+ min_num=min_num,
+ max_num=max_num,
+ use_thumbnail=self.use_thumbnail,
+ use_msac=use_msac,
+ ) for image in images
+ ]
+
+
+class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
+
+ def get_hf_processor(
self,
- ctx: InputContext,
- data: object,
*,
max_dynamic_patch: Optional[int] = None,
- ) -> MultiModalKwargs:
+ dynamic_image_size: Optional[bool] = None,
+ ) -> H2OVLProcessor:
+ return H2OVLProcessor(
+ self.get_hf_config(),
+ self.get_tokenizer(),
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ )
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ max_tokens_one_image = self.get_max_image_tokens(use_msac=None)
+ if mm_counts.get("image", 0) <= 1:
+ max_tokens_per_image = max_tokens_one_image
+ else:
+ max_tokens_per_image = self.get_max_image_tokens(use_msac=False)
+
+ return {"image": max_tokens_per_image}
- # NOTE: Preprocessing for the image data is done in the
- # 'input_processor' function during actual inference.
- if isinstance(data, dict):
- return MultiModalKwargs(data)
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ processor: Optional[H2OVLProcessor],
+ use_msac: Optional[bool] = None,
+ ) -> int:
+ if processor is None:
+ processor = self.get_hf_processor()
+
+ return processor.get_num_image_tokens(
+ image_width=image_width,
+ image_height=image_height,
+ use_msac=use_msac,
+ )
- # The section below is only used with dummy data during
- # memory profiling.
- hf_config = ctx.get_hf_config()
+ def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
- image_pixel_values_mapper = image_to_pixel_values_wrapper(
- hf_config, max_dynamic_patch)
+ return self.get_num_image_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ processor=None,
+ use_msac=use_msac,
+ )
- if isinstance(data, Image.Image):
- pixel_values = image_pixel_values_mapper(data)
- pixel_values = pixel_values.unsqueeze(0)
- elif is_list_of(data, Image.Image):
- hf_config.use_msac = False
- pixel_values = [image_pixel_values_mapper(img) for img in data]
+class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
+ ):
+
+ def __init__(self,
+ info: H2OVLProcessingInfo,
+ dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
+ *,
+ cache: Optional[ProcessingCache] = None,
+ enable_sanity_checks: bool = True) -> None:
+ super().__init__(
+ info,
+ dummy_inputs,
+ cache=cache,
+ enable_sanity_checks=enable_sanity_checks,
+ )
+
+ if self.cache is not None:
+ # The processor output depends on the number of images passed,
+ # making it incompatible with processing cache which is supposed
+ # to be invariant of how many images are passed per prompt
+ self.cache = None
+ logger.warning_once(
+ f"{type(self).__name__} does not support processing cache.")
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+ if "image_num_patches" in out_mm_kwargs:
+ image_num_patches = out_mm_kwargs["image_num_patches"]
+ assert isinstance(image_num_patches, torch.Tensor)
+ image_num_patches = image_num_patches.tolist()
+ elif "image_embeds" in out_mm_kwargs:
+ # TODO: Use image size information in dictionary embedding inputs
+ # to compute num_patches (similar to Qwen2-VL)
+ image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
else:
- return MultiModalKwargs({"image_embeds": data})
- model_config = ctx.model_config
- tokenizer = cached_get_tokenizer(
- model_config.tokenizer,
- trust_remote_code=model_config.trust_remote_code,
- )
- image_token_id = tokenizer.encode(
- self.img_context_token,
- add_special_tokens=False,
- return_tensors="pt",
- )[0]
+ image_num_patches = []
+
+ num_images = len(image_num_patches)
- return MultiModalKwargs({
- "pixel_values": pixel_values,
- "image_token_id": image_token_id
- })
+ def get_replacement_internvl(item_idx: int):
+ images = mm_items.get_items(
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
+ if isinstance(images, ImageEmbeddingItems):
+ feature_size = images.get_feature_size(item_idx)
+ else:
+ image_size = images.get_image_size(item_idx)
+ feature_size = self.info.get_num_image_tokens(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ processor=hf_processor,
+ use_msac=None if num_images == 1 else False,
+ )
+
+ num_patches = image_num_patches[item_idx]
+ if num_patches is not None:
+ assert isinstance(num_patches, int)
+
+ return PromptReplacementDetails(
+ full=hf_processor.get_image_repl_full(feature_size,
+ num_patches),
+ features=hf_processor.get_image_repl_features(
+ feature_size, num_patches),
+ )
-input_pipeline = H2OVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
+ return [
+ PromptReplacement(
+ modality="image",
+ target="",
+ replacement=get_replacement_internvl,
+ )
+ ]
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
-@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
+@MULTIMODAL_REGISTRY.register_processor(
+ H2OVLMultiModalProcessor,
+ info=H2OVLProcessingInfo,
+ dummy_inputs=InternVLDummyInputsBuilder)
class H2OVLChatModel(InternVLChatModel):
def _init_vision_model(
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index c46a867a76832..08fc659ab610f 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -6,35 +6,37 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
-import re
-from functools import cached_property, partial
+from abc import ABC, abstractmethod
+from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
- TypedDict, Union)
+ TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
-from transformers import PretrainedConfig
+from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
- InputContext, token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
-from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+ NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+ ImageSize, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptReplacementDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
+from vllm.transformers_utils.tokenizer import AnyTokenizer
-from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
- get_clip_num_patches)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@@ -75,22 +77,27 @@ class InternVLImageEmbeddingInputs(TypedDict):
InternVLImageEmbeddingInputs]
-# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
-def build_transform(input_size):
+# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
+def build_transform(input_size: int):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
- transform = T.Compose([
+ return T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
- return transform
-# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
- image_size):
+# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
+def find_closest_aspect_ratio(
+ aspect_ratio: float,
+ target_ratios: list[tuple[int, int]],
+ *,
+ width: int,
+ height: int,
+ image_size: int,
+) -> tuple[int, int]:
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
@@ -106,67 +113,82 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
return best_ratio
-def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
- max_num: int, image_size: int,
- use_thumbnail: bool) -> Tuple[int, int, int]:
- aspect_ratio = orig_width / orig_height
+def resolve_internvl_min_max_num(
+ *,
+ min_dynamic_patch: int,
+ max_dynamic_patch: int,
+ dynamic_image_size: bool,
+ use_thumbnail: bool,
+) -> tuple[int, int]:
+ max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
+
+ if use_thumbnail and max_dynamic_patch != 1:
+ max_dynamic_patch += 1
+
+ return min_dynamic_patch, max_dynamic_patch
+
- # calculate the existing image aspect ratio
- target_ratios = set((i, j) for n in range(min_num, max_num + 1)
- for i in range(1, n + 1) for j in range(1, n + 1)
- if i * j <= max_num and i * j >= min_num)
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+def get_internvl_target_ratios(
+ min_num: int,
+ max_num: int,
+) -> list[tuple[int, int]]:
+ target_ratios = {(i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1) if min_num <= i * j <= max_num}
+ return sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+
+def calculate_internvl_targets(
+ *,
+ orig_width: int,
+ orig_height: int,
+ target_ratios: list[tuple[int, int]],
+ image_size: int,
+ use_thumbnail: bool,
+) -> tuple[int, int, int]:
+ aspect_ratio = orig_width / orig_height
# find the closest aspect ratio to the target
- target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
- target_ratios, orig_width,
- orig_height, image_size)
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio,
+ target_ratios,
+ width=orig_width,
+ height=orig_height,
+ image_size=image_size,
+ )
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
- # add thumbnail image if num_blocks > 1
- if use_thumbnail and blocks > 1:
- blocks += 1
- return blocks, target_width, target_height
-
-def calculate_num_blocks_wrapper(
- hf_config: PretrainedConfig,
- max_dynamic_patch: Optional[int] = None,
- dynamic_image_size: Optional[bool] = None,
-):
- if dynamic_image_size is None:
- dynamic_image_size = hf_config.dynamic_image_size
+ # add thumbnail image if num_blocks != 1
+ if use_thumbnail and blocks != 1:
+ blocks += 1
- max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
- if max_dynamic_patch is None:
- max_dynamic_patch = hf_config.max_dynamic_patch
- min_num = hf_config.min_dynamic_patch
- image_size = hf_config.vision_config.image_size
- use_thumbnail = hf_config.use_thumbnail
- return partial(calculate_num_blocks,
- min_num=min_num,
- max_num=max_dynamic_patch,
- image_size=image_size,
- use_thumbnail=use_thumbnail)
+ return blocks, target_width, target_height
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
- image_size: int,
- use_thumbnail: bool) -> List[Image.Image]:
+def dynamic_preprocess_internvl(
+ image: Image.Image,
+ *,
+ target_ratios: list[tuple[int, int]],
+ image_size: int,
+ use_thumbnail: bool,
+) -> list[Image.Image]:
orig_width, orig_height = image.size
# calculate the number of blocks without thumbnail
- blocks, target_width, target_height = calculate_num_blocks(
- orig_width,
- orig_height,
- min_num,
- max_num,
- image_size,
- use_thumbnail=False)
+ blocks, target_width, target_height = calculate_internvl_targets(
+ orig_width=orig_width,
+ orig_height=orig_height,
+ target_ratios=target_ratios,
+ image_size=image_size,
+ use_thumbnail=False,
+ )
+
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
@@ -178,301 +200,463 @@ def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
+
assert len(processed_images) == blocks
+
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
+
return processed_images
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
- max_num: int, use_thumbnail: bool) -> torch.Tensor:
+def image_to_pixel_values_internvl(
+ image: Image.Image,
+ *,
+ input_size: int,
+ min_num: int,
+ max_num: int,
+ use_thumbnail: bool,
+) -> torch.Tensor:
+ target_ratios = get_internvl_target_ratios(min_num, max_num)
+
transform = build_transform(input_size=input_size)
- images = dynamic_preprocess(image,
- min_num=min_num,
- max_num=max_num,
- image_size=input_size,
- use_thumbnail=use_thumbnail)
- pixel_values = [transform(image) for image in images]
- pixel_values = torch.stack(pixel_values)
+ images = dynamic_preprocess_internvl(
+ image,
+ target_ratios=target_ratios,
+ image_size=input_size,
+ use_thumbnail=use_thumbnail,
+ )
+
+ pixel_values = torch.stack([transform(image) for image in images])
return pixel_values
-def image_to_pixel_values_wrapper(
- hf_config: PretrainedConfig,
- max_dynamic_patch: Optional[int] = None,
- dynamic_image_size: Optional[bool] = None,
-):
- image_size = hf_config.vision_config.image_size
- min_num = hf_config.min_dynamic_patch
- if dynamic_image_size is None:
- dynamic_image_size = hf_config.dynamic_image_size
+class BaseInternVLProcessor(ABC):
+ """
+ This model doesn't define its own HF processor,
+ so we implement our own one here.
- max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
- if max_dynamic_patch is None:
- max_dynamic_patch = hf_config.max_dynamic_patch
- use_thumbnail = hf_config.use_thumbnail
- return partial(image_to_pixel_values,
- input_size=image_size,
- min_num=min_num,
- max_num=max_dynamic_patch,
- use_thumbnail=use_thumbnail)
-
-
-def get_internvl_num_patches(hf_config: PretrainedConfig):
- vision_config = hf_config.vision_config
- downsample_ratio = hf_config.downsample_ratio
- image_size = vision_config.image_size
- patch_size = vision_config.patch_size
- return int(
- get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
- (downsample_ratio**2))
-
-
-def get_max_internvl_image_tokens(
- ctx: InputContext,
- *,
- max_dynamic_patch: Optional[int] = None,
- dynamic_image_size: Optional[bool] = None,
-):
- hf_config = ctx.get_hf_config()
- if dynamic_image_size is None:
- dynamic_image_size = hf_config.dynamic_image_size
+ The code to insert image tokens is based on:
+ https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
+ """
- max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
- if max_dynamic_patch is None:
- max_dynamic_patch = hf_config.max_dynamic_patch
- use_thumbnail = hf_config.use_thumbnail
- if use_thumbnail and max_dynamic_patch > 1:
- max_dynamic_patch += 1
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ tokenizer: AnyTokenizer,
+ *,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> None:
+ super().__init__()
- num_patches = get_internvl_num_patches(hf_config)
- return num_patches * max_dynamic_patch
+ self.config = config
+ self.tokenizer = tokenizer
+ image_size: int = config.vision_config.image_size
+ patch_size: int = config.vision_config.patch_size
-def get_max_internvl_image_size(
- ctx: InputContext,
- *,
- max_dynamic_patch: Optional[int] = None,
- dynamic_image_size: Optional[bool] = None,
-):
- hf_config = ctx.get_hf_config()
- image_size = hf_config.vision_config.image_size
- if dynamic_image_size is None:
- dynamic_image_size = hf_config.dynamic_image_size
+ if dynamic_image_size is None:
+ dynamic_image_size = config.dynamic_image_size
+ assert isinstance(dynamic_image_size, bool)
- max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
- if max_dynamic_patch is None:
- max_dynamic_patch = hf_config.max_dynamic_patch
- use_thumbnail = hf_config.use_thumbnail
- if use_thumbnail and max_dynamic_patch > 1:
- max_dynamic_patch += 1
- width = image_size * max_dynamic_patch
- height = image_size
- return width, height
+ if max_dynamic_patch is None:
+ max_dynamic_patch = config.max_dynamic_patch
+ assert isinstance(max_dynamic_patch, int)
+ self.num_image_token = int(
+ (image_size // patch_size)**2 * (config.downsample_ratio**2))
+ self.image_size = image_size
+ self.min_dynamic_patch: int = config.min_dynamic_patch
+ self.max_dynamic_patch = max_dynamic_patch
+ self.dynamic_image_size = dynamic_image_size
+ self.use_thumbnail: bool = config.use_thumbnail
+
+ @property
+ @abstractmethod
+ def image_token_id(self) -> int:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_image_repl_features(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> str:
+ raise NotImplementedError
-class InternVLInputPipeline:
+ @abstractmethod
+ def get_image_repl_full(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> str:
+ raise NotImplementedError
- def __init__(
+ def resolve_min_max_num(
self,
- img_start_token: str,
- img_end_token: str,
- img_context_token: str,
- ) -> None:
- super().__init__()
+ *,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ use_thumbnail: Optional[bool] = None,
+ ) -> tuple[int, int]:
+ min_dynamic_patch = self.min_dynamic_patch
+ max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
+ is None else max_dynamic_patch)
+ dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
+ is None else dynamic_image_size)
+ use_thumbnail = (self.use_thumbnail
+ if use_thumbnail is None else use_thumbnail)
+
+ return resolve_internvl_min_max_num(
+ min_dynamic_patch=min_dynamic_patch,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ use_thumbnail=use_thumbnail,
+ )
- self.img_start_token = img_start_token
- self.img_end_token = img_end_token
- self.img_context_token = img_context_token
+ def resolve_target_ratios(
+ self,
+ *,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ use_thumbnail: Optional[bool] = None,
+ ) -> list[tuple[int, int]]:
+ min_num, max_num = self.resolve_min_max_num(
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ use_thumbnail=use_thumbnail,
+ )
- def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
- return (self.img_start_token + self.img_context_token * feature_size +
- self.img_end_token)
+ return get_internvl_target_ratios(min_num, max_num)
- def _expand_image_prompt(
+ def get_num_image_tokens(
self,
- prompt: str,
- feature_sizes: List[int],
- num_patches: int,
- ) -> str:
- image_idx = sorted(
-            map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
+ *,
+ image_width: int,
+ image_height: int,
+ ) -> int:
+ target_ratios = self.resolve_target_ratios(
+ use_thumbnail=False, # Applied in calculate_targets
+ )
+
+ num_patches, _, _ = calculate_internvl_targets(
+ orig_width=image_width,
+ orig_height=image_height,
+ image_size=self.image_size,
+ target_ratios=target_ratios,
+ use_thumbnail=self.use_thumbnail,
+ )
- new_prompt = prompt
- for idx, feature_size in enumerate(feature_sizes, start=1):
- image_prompt = self._create_image_prompt(feature_size, num_patches)
- if not image_idx:
- image_prompt = f"Image-{idx}: {image_prompt}"
+ return num_patches * self.num_image_token
-            new_prompt = new_prompt.replace('<image>', image_prompt, 1)
+ def _images_to_pixel_values_lst(
+ self,
+ images: list[Image.Image],
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> list[torch.Tensor]:
+ min_num, max_num = self.resolve_min_max_num(
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ use_thumbnail=False, # Applied in image_to_pixel_values
+ )
- return new_prompt
+ return [
+ image_to_pixel_values_internvl(
+ image,
+ input_size=self.image_size,
+ min_num=min_num,
+ max_num=max_num,
+ use_thumbnail=self.use_thumbnail,
+ ) for image in images
+ ]
- def input_processor(
+ def __call__(
self,
- ctx: InputContext,
- inputs: DecoderOnlyInputs,
- *,
+ text: Optional[Union[str, list[str]]] = None,
+ images: Optional[Union[Image.Image, list[Image.Image]]] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
- ) -> DecoderOnlyInputs:
- multi_modal_data = inputs.get("multi_modal_data")
- if multi_modal_data is None or "image" not in multi_modal_data:
- return inputs
-
- model_config = ctx.model_config
- hf_config = ctx.get_hf_config()
-
- image_data = multi_modal_data["image"]
- num_patches = get_internvl_num_patches(hf_config)
- num_blocks_calculator = calculate_num_blocks_wrapper(
- hf_config, max_dynamic_patch, dynamic_image_size)
- if isinstance(image_data, Image.Image):
- width, height = image_data.size
- num_blocks, _, _ = num_blocks_calculator(width, height)
- image_feature_sizes = [num_blocks * num_patches]
- elif is_list_of(image_data, Image.Image):
- image_feature_sizes = []
- for image in image_data:
- width, height = image.size
- num_blocks, _, _ = num_blocks_calculator(width, height)
- image_feature_sizes.append(num_blocks * num_patches)
- elif isinstance(image_data, torch.Tensor):
- num_images, image_feature_size, hidden_size = image_data.shape
- image_feature_sizes = [image_feature_size]
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ ) -> BatchFeature:
+ if text is None:
+ text = []
+ if not isinstance(text, list):
+ text = [text]
+ if images is None:
+ images = []
+ if not isinstance(images, list):
+ images = [images]
+
+ if len(images) == 0:
+ image_inputs = {}
else:
- raise TypeError(f"Invalid image type: {type(image_data)}")
-
- tokenizer = cached_get_tokenizer(
- model_config.tokenizer,
- trust_remote_code=model_config.trust_remote_code)
-
- prompt = inputs.get("prompt")
- prompt_token_ids = inputs["prompt_token_ids"]
- if prompt is None:
- prompt = tokenizer.decode(prompt_token_ids)
-
- new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
- num_patches)
- new_prompt_token_ids = tokenizer.encode(new_prompt)
- img_context_token_id = tokenizer.encode(self.img_context_token,
- add_special_tokens=False)
- assert len(img_context_token_id) == 1, \
- (f"Invalid image token '{self.img_context_token}': A valid image "
- f"token encodes to a single token ID, got {img_context_token_id}.")
- img_context_token_id = img_context_token_id[0]
-
- # Get precise tracking of placeholder positions
- token_idx = image_idx = 0
- placeholder_ranges = []
- while token_idx < len(new_prompt_token_ids):
- if new_prompt_token_ids[token_idx] == img_context_token_id:
- curr_image_featue_size = image_feature_sizes[image_idx]
- placeholder_ranges.append(
- PlaceholderRange(offset=token_idx,
- length=curr_image_featue_size))
- image_idx += 1
- token_idx += curr_image_featue_size
- else:
- token_idx += 1
+ pixel_values_lst = self._images_to_pixel_values_lst(
+ images,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ )
+ image_inputs = {
+ "pixel_values_flat": torch.cat(pixel_values_lst),
+ "image_num_patches": list(map(len, pixel_values_lst)),
+ }
+
+ for pixel_values in pixel_values_lst:
+ num_patches = pixel_values.shape[0]
+ feature_size = num_patches * self.num_image_token
+
+ image_repl = self.get_image_repl_full(feature_size,
+ num_patches)
+                text = [t.replace('<image>', image_repl, 1) for t in text]
+
+ text_inputs = self.tokenizer(text)
+
+ return BatchFeature(
+ {
+ **text_inputs,
+ **image_inputs,
+ },
+ tensor_type=return_tensors,
+ )
- return token_inputs(
- prompt=prompt,
- prompt_token_ids=new_prompt_token_ids,
- multi_modal_data=multi_modal_data,
- multi_modal_placeholders={"image": placeholder_ranges})
- def input_mapper(
+class InternVLProcessor(BaseInternVLProcessor):
+
+ @property
+ def image_token_id(self) -> int:
+ return self.tokenizer.get_vocab()[IMG_CONTEXT]
+
+ def get_image_repl_features(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> str:
+ return IMG_CONTEXT * feature_size
+
+ def get_image_repl_full(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> str:
+ features = self.get_image_repl_features(feature_size, num_patches)
+ return IMG_START + features + IMG_END
+
+
+class BaseInternVLProcessingInfo(BaseProcessingInfo):
+
+ @abstractmethod
+ def get_hf_processor(
self,
- ctx: InputContext,
- data: object,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
- ):
- hf_config = ctx.get_hf_config()
-
- image_pixel_values_mapper = image_to_pixel_values_wrapper(
- hf_config, max_dynamic_patch, dynamic_image_size)
- if isinstance(data, Image.Image):
- data = image_pixel_values_mapper(data)
- # Add an N dimension for number of images per prompt (currently 1).
- data = data.unsqueeze(0)
- elif is_list_of(data, Image.Image):
- # we can't stack here because images may have different num_patches
- data = [image_pixel_values_mapper(img) for img in data]
- else:
- return MultiModalKwargs({"image_embeds": data})
- model_config = ctx.model_config
- tokenizer = cached_get_tokenizer(
- model_config.tokenizer,
- trust_remote_code=model_config.trust_remote_code)
- image_token_id = tokenizer.encode(self.img_context_token,
- add_special_tokens=False,
- return_tensors="pt")[0]
-
- return MultiModalKwargs({
- "pixel_values": data,
- "image_token_id": image_token_id
- })
-
- def dummy_data(
+ ) -> BaseInternVLProcessor:
+ raise NotImplementedError
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None}
+
+ def get_mm_max_tokens_per_item(
self,
- ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ return {"image": self.get_max_image_tokens()}
+
+ def get_num_image_tokens(
+ self,
*,
- max_dynamic_patch: Optional[int] = None,
- dynamic_image_size: Optional[bool] = None,
- ):
- num_images = mm_counts["image"]
+ image_width: int,
+ image_height: int,
+ processor: Optional[BaseInternVLProcessor],
+ ) -> int:
+ if processor is None:
+ processor = self.get_hf_processor()
+
+ return processor.get_num_image_tokens(
+ image_width=image_width,
+ image_height=image_height,
+ )
- hf_config = ctx.get_hf_config()
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
- image_feature_size = get_max_internvl_image_tokens(
- ctx,
- max_dynamic_patch=max_dynamic_patch,
- dynamic_image_size=dynamic_image_size,
+ return self.get_num_image_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ processor=None,
)
- model_config = ctx.model_config
- tokenizer = cached_get_tokenizer(
- model_config.tokenizer,
- trust_remote_code=model_config.trust_remote_code)
-
- seq_data, ranges = dummy_seq_data_for_clip(
- hf_config.vision_config,
- seq_len,
- num_images,
- image_token_id=tokenizer.encode(self.img_context_token,
- add_special_tokens=False)[0],
- image_feature_size_override=image_feature_size,
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ processor = self.get_hf_processor()
+
+ base_size = processor.image_size
+ target_ratios = processor.resolve_target_ratios()
+
+ largest_feature_size, largest_feature_pinpoint = 0, None
+ for wr, hr in target_ratios:
+ width, height = base_size * wr, base_size * hr
+
+ feat_size = self.get_num_image_tokens(
+ image_width=width,
+ image_height=height,
+ processor=processor,
+ )
+ if feat_size > largest_feature_size:
+ largest_feature_size = feat_size
+ largest_feature_pinpoint = ImageSize(width=width,
+ height=height)
+
+ if largest_feature_size == 0 or largest_feature_pinpoint is None:
+ raise ValueError("Cannot have a largest feature size of 0!")
+
+ return largest_feature_pinpoint
+
+
+_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
+
+
+class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
+ num_images = mm_counts.get("image", 0)
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+            prompt_text="<image>" * num_images,
+ mm_data=mm_data,
)
- max_image_width, max_image_height = get_max_internvl_image_size(
- ctx,
- max_dynamic_patch=max_dynamic_patch,
- dynamic_image_size=dynamic_image_size,
+
+class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ processed_outputs = super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
)
- mm_data = dummy_image_for_clip(
- hf_config.vision_config,
- num_images,
- image_width_override=max_image_width,
- image_height_override=max_image_height,
+ image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
+ image_data = mm_data.get("images", [])
+ assert isinstance(image_data, list)
+
+ # Since there may be extra tokens in the feature placeholders,
+ # we need to pass the image token ID to the model to select the
+ # tokens to merge from the vision encoder outputs
+ processed_outputs["image_token_id"] = [image_token_id
+ ] * len(image_data)
+
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
+
+ return dict(
+ pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_num_patches),
+ image_num_patches=MultiModalFieldConfig.batched("image"),
+ image_embeds=MultiModalFieldConfig.batched("image"),
+ image_token_id=MultiModalFieldConfig.batched("image"),
)
- return DummyData(seq_data, mm_data, ranges)
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+ if "image_num_patches" in out_mm_kwargs:
+ image_num_patches = out_mm_kwargs["image_num_patches"]
+ assert isinstance(image_num_patches, torch.Tensor)
+ image_num_patches = image_num_patches.tolist()
+ elif "image_embeds" in out_mm_kwargs:
+ # TODO: Use image size information in dictionary embedding inputs
+ # to compute num_patches (similar to Qwen2-VL)
+ image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
+ else:
+ image_num_patches = []
+
+ def get_replacement_internvl(item_idx: int):
+ images = mm_items.get_items(
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+ if isinstance(images, ImageEmbeddingItems):
+ feature_size = images.get_feature_size(item_idx)
+ else:
+ image_size = images.get_image_size(item_idx)
+ feature_size = self.info.get_num_image_tokens(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ processor=hf_processor,
+ )
+
+ num_patches = image_num_patches[item_idx]
+ if num_patches is not None:
+ assert isinstance(num_patches, int)
+
+ return PromptReplacementDetails(
+ full=hf_processor.get_image_repl_full(feature_size,
+ num_patches),
+ features=hf_processor.get_image_repl_features(
+ feature_size, num_patches),
+ )
+ return [
+ PromptReplacement(
+ modality="image",
+                target="<image>",
+ replacement=get_replacement_internvl,
+ )
+ ]
-input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
+class InternVLProcessingInfo(BaseInternVLProcessingInfo):
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
-@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
+ def get_hf_processor(
+ self,
+ *,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> InternVLProcessor:
+ return InternVLProcessor(
+ self.get_hf_config(),
+ self.get_tokenizer(),
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ )
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ InternVLMultiModalProcessor,
+ info=InternVLProcessingInfo,
+ dummy_inputs=InternVLDummyInputsBuilder)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
@@ -621,11 +805,11 @@ def _validate_shape(d: torch.Tensor):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImageInputs]:
- pixel_values = kwargs.pop("pixel_values", None)
- image_token_id = kwargs.pop("image_token_id", None)
+ pixel_values_flat = kwargs.pop("pixel_values_flat", None)
+ image_num_patches = kwargs.pop("image_num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
- if pixel_values is None and image_embeds is None:
+ if pixel_values_flat is None and image_embeds is None:
return None
if image_embeds is not None:
@@ -638,31 +822,30 @@ def _parse_and_validate_image_input(
data=flatten_bn(image_embeds),
)
- self.img_context_token_id = image_token_id[0]
+ image_token_id = kwargs["image_token_id"]
+ assert isinstance(image_token_id, torch.Tensor)
+ self.img_context_token_id = image_token_id.flatten().unique().item()
- if pixel_values is not None:
- if not isinstance(pixel_values, (torch.Tensor, list)):
+ if pixel_values_flat is not None:
+ if not isinstance(pixel_values_flat, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
- f"Got type: {type(pixel_values)}")
-
- patches_per_image = []
- for request_pixel_values in pixel_values:
- for image_pixel_values in request_pixel_values:
- patches_per_image.append(image_pixel_values.shape[0])
- # We need to flatten (B, N, P) to (B*N*P),
- # so we call flatten_bn twice.
+ f"Got type: {type(pixel_values_flat)}")
+
+ assert isinstance(image_num_patches, (torch.Tensor, list))
+
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
- flatten_bn(flatten_bn(pixel_values), concat=True)),
- patches_per_image=patches_per_image)
+ flatten_bn(pixel_values_flat, concat=True)),
+ patches_per_image=flatten_bn(image_num_patches,
+ concat=True).tolist())
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: InternVLImageInputs,
- ) -> Tuple[torch.Tensor]:
+ ) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
return image_input["data"]
@@ -689,7 +872,7 @@ def _process_image_input(
image_embeds = image_embeds.split(image_feature_sizes)
return image_embeds
- def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
+ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
if self.is_mono:
self.visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1)
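
For readers following the new tiling helpers in `internvl.py`, here is a minimal sketch of the arithmetic they implement, assuming typical InternVL2-style settings (448-pixel tiles, 14-pixel vision patches, a 0.5 downsample ratio, at most 6 dynamic patches plus a thumbnail); the values are illustrative and not read from a real checkpoint:

```python
# Illustrative only: mirrors how get_num_image_tokens() combines the helpers
# added above. Config values below are assumed InternVL2-style defaults.
from vllm.model_executor.models.internvl import (
    calculate_internvl_targets, get_internvl_target_ratios,
    resolve_internvl_min_max_num)

min_num, max_num = resolve_internvl_min_max_num(
    min_dynamic_patch=1,
    max_dynamic_patch=6,
    dynamic_image_size=True,
    use_thumbnail=False,  # thumbnail is accounted for in the call below
)
target_ratios = get_internvl_target_ratios(min_num, max_num)

# Number of 448x448 tiles chosen for an 800x600 image (thumbnail included)
blocks, target_w, target_h = calculate_internvl_targets(
    orig_width=800,
    orig_height=600,
    target_ratios=target_ratios,
    image_size=448,
    use_thumbnail=True,
)

# Tokens per tile: (448 // 14) ** 2 * 0.5 ** 2 = 256
num_image_token = int((448 // 14) ** 2 * 0.5 ** 2)
print(blocks, target_w, target_h, blocks * num_image_token)
```

With these assumed values the sketch selects a 3x2 tile grid plus a thumbnail, i.e. 7 blocks and 7 * 256 = 1792 image tokens.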
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 19effcbfc5512..63d308ef6d191 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -125,7 +125,11 @@ def get_hf_processor(self) -> LlavaLikeProcessor:
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _apply_feature_select_strategy(
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index d70ae2f148ff9..817edcef4ba14 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -62,7 +62,11 @@ def get_hf_processor(self):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index f1c06cd85967c..2889426283f84 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -103,7 +103,11 @@ def get_hf_processor(self):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py
index f1c1680768b8d..ab697fb8cc645 100644
--- a/vllm/model_executor/models/minicpmo.py
+++ b/vllm/model_executor/models/minicpmo.py
@@ -23,7 +23,6 @@
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from functools import partial
-from itertools import accumulate
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
@@ -138,11 +137,15 @@ def get_supported_mm_modalities(self) -> List[str]:
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None, "audio": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"audio": self.get_max_audio_tokens(),
- "video": self.get_max_video_tokens(seq_len)
+ "video": self.get_max_video_tokens(seq_len),
}
def get_default_audio_pool_step(self) -> int:
@@ -369,23 +372,18 @@ def _get_mm_fields_config(
hf_inputs,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
+ audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
- def get_slices(num_slices: List[int]) -> List[int]:
- slice_indices = [0] + list(accumulate(num_slices))
- slices = [(slice_indices[i], slice_indices[i + 1])
- for i in range(len(num_slices))]
- return [slice(*slice_item) for slice_item in slices]
-
- audio_slices = get_slices(
- hf_inputs.get("audio_num_slices", torch.empty(0)))
return dict(
**super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
- audio_features=MultiModalFieldConfig.flat("audio", audio_slices),
- audio_feature_lens=MultiModalFieldConfig.flat(
- "audio", audio_slices),
+ audio_features=MultiModalFieldConfig.flat_from_sizes(
+ "audio", audio_num_slices),
+ audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
+ "audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
- audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices))
+ audio_embeds=MultiModalFieldConfig.flat_from_sizes(
+ "audio", audio_num_slices))
class MultiModalProjector(nn.Module):
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 6964d6bdce9f7..3d16d635b578a 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -26,7 +26,6 @@
import re
from collections import Counter
from functools import cached_property, partial
-from itertools import accumulate
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, TypedDict, Union)
@@ -365,7 +364,11 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
else:
return {"image": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
@@ -761,30 +764,25 @@ def _get_mm_fields_config(
hf_inputs,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
-
- def get_slices(num_slices: List[int]) -> List[int]:
- slice_indices = [0] + list(accumulate(num_slices))
- slices = [(slice_indices[i], slice_indices[i + 1])
- for i in range(len(num_slices))]
- return [slice(*slice_item) for slice_item in slices]
-
- image_slices = get_slices(
- hf_inputs.get("image_num_slices", torch.empty(0)))
- video_slices = get_slices(
- hf_inputs.get("video_num_slices", torch.empty(0)))
-
- return dict(
- pixel_values=MultiModalFieldConfig.flat("image", image_slices),
- image_sizes=MultiModalFieldConfig.batched("image"),
- tgt_sizes=MultiModalFieldConfig.flat("image", image_slices),
- image_num_slices=MultiModalFieldConfig.batched("image"),
- image_embeds=MultiModalFieldConfig.flat("image", image_slices),
- video_pixel_values=MultiModalFieldConfig.flat(
- "video", video_slices),
- video_image_sizes=MultiModalFieldConfig.batched("video"),
- video_tgt_sizes=MultiModalFieldConfig.flat("video", video_slices),
- video_embeds=MultiModalFieldConfig.flat("video", video_slices),
- video_num_slices=MultiModalFieldConfig.batched("video"))
+ image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
+ video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
+
+ return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_num_slices),
+ image_sizes=MultiModalFieldConfig.batched("image"),
+ tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_num_slices),
+ image_num_slices=MultiModalFieldConfig.batched("image"),
+ image_embeds=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_num_slices),
+ video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "video", video_num_slices),
+ video_image_sizes=MultiModalFieldConfig.batched("video"),
+ video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
+ "video", video_num_slices),
+ video_embeds=MultiModalFieldConfig.flat_from_sizes(
+ "video", video_num_slices),
+ video_num_slices=MultiModalFieldConfig.batched("video"))
def apply(
self,
diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
index 2aa04bd717260..9c674ab464463 100644
--- a/vllm/model_executor/models/nvlm_d.py
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -6,44 +6,190 @@
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
-from typing import Optional
+from typing import Mapping, Optional
+import torch
import torch.nn as nn
from transformers import PretrainedConfig
-from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalKwargs
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+ MultiModalDataItems)
+from vllm.multimodal.processing import (PromptReplacement,
+ PromptReplacementDetails)
+from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel
-from .internvl import (InternVLChatModel, InternVLInputPipeline,
- get_max_internvl_image_tokens)
+from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
+ InternVLChatModel, InternVLDummyInputsBuilder,
+ InternVLMultiModalProcessor)
-IMG_START = '<|vision_start|>'
-IMG_END = '<|vision_end|>'
-IMG_CONTEXT = '<|vision_pad|>'
+IMG_PAD = "<|vision_pad|>"
-class NVLMInputPipeline(InternVLInputPipeline):
+class NVLMProcessor(BaseInternVLProcessor):
+
+ @property
+ def image_token_id(self) -> int:
+ return self.tokenizer.get_vocab()[IMG_PAD]
+
+ def get_image_repl_features(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> str:
+ if num_patches is None:
+ raise NotImplementedError("Embedding inputs are not supported")
+
+ tile_pos_identifiers = [f"" for i in range(1, num_patches)]
+ if self.use_thumbnail and num_patches != 1:
+ tile_pos_identifiers += [""]
- def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
- tile_pos_identifiers = ([f""
- for i in range(1, num_patches)] +
- [""])
context_size = feature_size // num_patches
+ features = "".join(identifier + IMG_PAD * context_size
+ for identifier in tile_pos_identifiers)
+
+        # We include the start and end as well because "<tile_1>" is
+        # tokenized into ["<", "tile"], resulting in assertion error
+        # when trying to find "<tile_1>" in the prompt
+        return "<Image>" + features + "</Image>"
+
+ def get_image_repl_full(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> str:
+ return self.get_image_repl_features(feature_size, num_patches)
+
+
+class NVLMProcessingInfo(BaseInternVLProcessingInfo):
+
+ def get_hf_processor(
+ self,
+ *,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> NVLMProcessor:
+ return NVLMProcessor(
+ self.get_hf_config(),
+ self.get_tokenizer(),
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ )
+
+ def get_max_image_tokens(self) -> int:
+ hf_processor = self.get_hf_processor()
+ tokenizer = hf_processor.tokenizer
+
+ max_num_patches = hf_processor.max_dynamic_patch
+ # we need +1 here because max_dynamic_patch in config doesn't
+ # include the thumbnail patch
+        tile_pos_identifiers = [
+            f"<tile_{i + 1}>" for i in range(max_num_patches)
+        ]
+        if hf_processor.use_thumbnail and max_num_patches != 1:
+            tile_pos_identifiers += ["<tile_global_thumbnail>"]
+
+ # "<", "tile"]
+ # so we include in the start_str
+ start_str = "" + tile_pos_identifiers.pop(0)
+ end_str = ""
+ start_token_len = len(tokenizer.encode(start_str))
+ end_token_len = len(tokenizer.encode(end_str))
+ tile_token_len = sum(
+ len(tokenizer.encode(identifier))
+ for identifier in tile_pos_identifiers)
+ non_image_tokens_num = start_token_len + end_token_len + tile_token_len
+ return super().get_max_image_tokens() + non_image_tokens_num
+
+
+class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
+ num_images = mm_counts.get("image", 0)
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images)
+ }
+
+        return ProcessorInputs(
+            # The newline is necessary to separate ">" of the current item
+            # and "<" of the next item
+            prompt_text="<image>\n" * num_images,
+ mm_data=mm_data,
+ )
+
+
+class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
-        return '<Image>' + ''.join(
-            tile_pos_identifier + self.img_context_token * context_size
-            for tile_pos_identifier in tile_pos_identifiers) + '</Image>'
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+ if "image_num_patches" in out_mm_kwargs:
+ image_num_patches = out_mm_kwargs["image_num_patches"]
+ assert isinstance(image_num_patches, torch.Tensor)
+ image_num_patches = image_num_patches.tolist()
+ elif "image_embeds" in out_mm_kwargs:
+ # TODO: Use image size information in dictionary embedding inputs
+ # to compute num_patches (similar to Qwen2-VL)
+ image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
+ else:
+ image_num_patches = []
+
+ def get_replacement_nvlm(item_idx: int):
+ images = mm_items.get_items(
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
+ if isinstance(images, ImageEmbeddingItems):
+ feature_size = images.get_feature_size(item_idx)
+ else:
+ image_size = images.get_image_size(item_idx)
+ feature_size = self.info.get_num_image_tokens(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ processor=hf_processor,
+ )
+
+ num_patches = image_num_patches[item_idx]
+ if num_patches is not None:
+ assert isinstance(num_patches, int)
+
+ return PromptReplacementDetails(
+ full=hf_processor.get_image_repl_full(feature_size,
+ num_patches) + "\n",
+ features=hf_processor.get_image_repl_features(
+ feature_size, num_patches) + "\n",
+ )
-input_pipeline = NVLMInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
+ # See note in dummy data regarding why we have the extra newline
+ return [
+ PromptReplacement(
+ modality="image",
+                target="<image>\n",
+ replacement=get_replacement_nvlm,
+ )
+ ]
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
-@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
+@MULTIMODAL_REGISTRY.register_processor(NVLMMultiModalProcessor,
+ info=NVLMProcessingInfo,
+ dummy_inputs=NVLMDummyInputsBuilder)
class NVLM_D_Model(InternVLChatModel):
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
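
As a rough illustration of the tile-tagged replacement text that `NVLMProcessor.get_image_repl_features` builds, the standalone sketch below mirrors that construction using the `<Image>`/`<tile_*>` markers and the `<|vision_pad|>` context token; the 2-patch, 512-feature example is made up and no vLLM code or tokenizer is involved:

```python
# Standalone sketch of the NVLM replacement-string construction; sizes are
# arbitrary and chosen only to keep the printed output small.
IMG_PAD = "<|vision_pad|>"


def sketch_image_repl(feature_size: int, num_patches: int,
                      use_thumbnail: bool = True) -> str:
    tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)]
    if use_thumbnail and num_patches != 1:
        tile_pos_identifiers += ["<tile_global_thumbnail>"]

    context_size = feature_size // num_patches
    features = "".join(identifier + IMG_PAD * context_size
                       for identifier in tile_pos_identifiers)
    return "<Image>" + features + "</Image>"


# "<Image><tile_1><|vision_pad|>...<tile_global_thumbnail><|vision_pad|>...</Image>"
print(sketch_image_repl(feature_size=512, num_patches=2))
```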
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index f089fa5d295eb..053390c521fc2 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -322,7 +322,11 @@ def get_hf_processor(
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index d7f6662bc9a97..327fad0f5702d 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -779,7 +779,11 @@ def get_hf_processor(self) -> QWenVLProcessor:
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
@@ -799,13 +803,13 @@ def get_dummy_processor_inputs(
vision_config = hf_config.visual
- max_image_size = vision_config["image_size"]
+ target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
- self._get_dummy_images(width=max_image_size,
- height=max_image_size,
+ self._get_dummy_images(width=target_width,
+ height=target_height,
num_images=num_images)
}
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index cf104ab008722..f09529ca4bd1f 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -110,7 +110,11 @@ def get_feature_extractor(
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
hf_config = self.get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 189ac41e8a6c1..2b2638cf68fc7 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -758,7 +758,11 @@ def get_image_processor(
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
@@ -989,26 +993,21 @@ def _get_mm_fields_config(
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
- image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist()
- image_slices = [
- slice(image_slice_idxs[i], image_slice_idxs[i + 1])
- for i in range(len(image_grid_thw))
- ]
+ image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
- video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist()
- video_slices = [
- slice(video_slice_idxs[i], video_slice_idxs[i + 1])
- for i in range(len(video_grid_thw))
- ]
+ video_grid_sizes = video_grid_thw.prod(-1)
return dict(
- pixel_values=MultiModalFieldConfig.flat("image", image_slices),
- image_embeds=MultiModalFieldConfig.flat("image", image_slices),
+ pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_grid_sizes),
+ image_embeds=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
- pixel_values_videos=MultiModalFieldConfig.flat(
- "video", video_slices),
- video_embeds=MultiModalFieldConfig.flat("video", video_slices),
+ pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
+ "video", video_grid_sizes),
+ video_embeds=MultiModalFieldConfig.flat_from_sizes(
+ "video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 5e86b15db7a8f..52a4d798f4bff 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -92,7 +92,11 @@ def get_feature_extractor(
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
feature_extractor = self.get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index eb52551bbdb7b..fe24c7282f3cf 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -4,6 +4,7 @@
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
+from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final)
@@ -258,6 +259,16 @@ def flat(modality: str, slices: Sequence[slice]):
slices=slices,
)
+ @staticmethod
+ def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
+ slice_idxs = [0, *accumulate(size_per_item)]
+ slices = [
+ slice(slice_idxs[i], slice_idxs[i + 1])
+ for i in range(len(size_per_item))
+ ]
+
+ return MultiModalFieldConfig.flat(modality, slices)
+
def __init__(
self,
field_cls: type[BaseMultiModalField],
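
The `flat_from_sizes` helper added here simply converts per-item sizes along the flattened (concatenated) dimension into the slices that `MultiModalFieldConfig.flat` already expects. A small standalone sketch of that arithmetic (the sizes are made up; in the model changes above they come from fields such as `image_num_patches` or `image_grid_thw.prod(-1)`):

```python
# Standalone sketch of the slice computation performed by flat_from_sizes.
from itertools import accumulate

import torch

size_per_item = torch.tensor([7, 3, 5])  # e.g. patches per image

slice_idxs = [0, *accumulate(size_per_item.tolist())]
slices = [
    slice(slice_idxs[i], slice_idxs[i + 1])
    for i in range(len(size_per_item))
]
print(slices)  # [slice(0, 7), slice(7, 10), slice(10, 15)]
```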
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 2ad42d1c1c057..d704fa59b96af 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -680,7 +680,11 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
raise NotImplementedError
@abstractmethod
- def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
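
Since `get_mm_max_tokens_per_item` now also receives the per-prompt item limits, an override can take the allowed item count into account. A hypothetical sketch of what such an override might look like (not taken from this change; the class name is made up and stands in for a `BaseProcessingInfo` subclass):

```python
# Hypothetical example only: splits the sequence budget across the maximum
# number of audio items allowed per prompt.
from collections.abc import Mapping


class ExampleAudioProcessingInfo:

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        max_audios = max(mm_counts.get("audio", 1), 1)
        # Budget the context roughly evenly across the allowed audio items
        return {"audio": seq_len // max_audios}
```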
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 953c010003250..5dd7548540448 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -151,7 +151,8 @@ def get_dummy_data(self, seq_len: int) -> DummyData:
mm_counts = self.get_mm_limits()
info = self.processing_info
- mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
+ mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
+ seq_len, mm_counts)
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 29036691bfa49..04141114288c9 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -264,7 +264,9 @@ def get_max_tokens_per_item_by_modality(
)
processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len
- return processor.info.get_mm_max_tokens_per_item(seq_len)
+ mm_limits = self.get_mm_limits_per_prompt(model_config)
+ return processor.info.get_mm_max_tokens_per_item(
+ seq_len, mm_limits)
return {
key: plugin.get_max_multimodal_tokens(model_config)