From daed30c4a917c870f8fbddf45e3b027710c0842b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 31 Jul 2024 23:46:17 +0800 Subject: [PATCH] [Bugfix] Fix feature size calculation for LLaVA-NeXT (#6982) --- tests/models/test_llava_next.py | 88 +++++++++++++++++++----- vllm/model_executor/models/fuyu.py | 2 +- vllm/model_executor/models/internvl.py | 6 +- vllm/model_executor/models/llava_next.py | 48 ++++++------- vllm/model_executor/models/phi3v.py | 4 +- 5 files changed, 98 insertions(+), 50 deletions(-) diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 9c64f39eb6d08..b6d72dee5c5b5 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -1,7 +1,7 @@ -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type, overload import pytest -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -50,6 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, return hf_output_ids, hf_output_str, out_logprobs +@overload def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], @@ -62,13 +63,55 @@ def run_test( num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, +): + ... + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + sizes: List[Tuple[int, int]], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: Optional[List[float]] = None, + sizes: Optional[List[Tuple[int, int]]] = None, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, ): images = [asset.pil_image for asset in image_assets] - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + if size_factors is not None: + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + elif sizes is not None: + inputs_per_image = [( + [prompt for _ in sizes], + [image.resize(size) for size in sizes], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + else: + raise ValueError("You must provide either `size_factors` or `sizes`") # max_model_len should be greater than image_feature_size with vllm_runner(model, @@ -150,15 +193,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ) -@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144), - (183, 488, 776)]) -def test_image_feature_size(height_and_width_and_result): - # Avoid initializing CUDA too early in distributed tests - from vllm.model_executor.models.llava_next import ( - get_llava_next_image_feature_size) - - height, width, result = height_and_width_and_result - config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") - assert get_llava_next_image_feature_size(config, - input_height=height, - input_width=width) == result +@pytest.mark.parametrize("model", models) 
+@pytest.mark.parametrize( + "sizes", + [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]], +) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes, + dtype, max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index fdea8ee30ce68..c4738263c3056 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -169,7 +169,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): raise TypeError(f"Invalid image type: {type(image_data)}") # process prompts - prompt = llm_inputs["prompt"] + prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] tokenizer = cached_get_tokenizer(model_config.model) # dim0 is batch_size, dim1 is subseq_size which will always be 1 diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index f64c78c15f8ee..eabc283b1efdb 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -20,7 +20,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput @@ -43,7 +43,7 @@ class InternVLImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: BatchedTensors + data: Union[torch.Tensor, List[torch.Tensor]] """ Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` @@ -193,7 +193,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) - prompt = llm_inputs["prompt"] + prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] if prompt is None: prompt = tokenizer.decode(prompt_token_ids) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 5abb55c2cc415..4a67b9a583ea8 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -21,7 +21,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, @@ -43,7 +43,7 @@ class LlavaNextImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: BatchedTensors + data: Union[torch.Tensor, List[torch.Tensor]] """ Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` @@ -62,31 +62,26 @@ class LlavaNextImagePixelInputs(TypedDict): LlavaNextImageInputs = LlavaNextImagePixelInputs -# Taken from: 
https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91 -# NOTE: new_height and new_width are further incremented to properly invert the -# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133 +# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79 def _get_llava_next_num_unpadded_features( - height: int, - width: int, + original_height: int, + original_width: int, npatches: int, num_patch_height: int, num_patch_width: int, ) -> Tuple[int, int]: current_height = npatches * num_patch_height current_width = npatches * num_patch_width - current_height = torch.tensor(current_height).to("cuda") - current_width = torch.tensor(current_width).to("cuda") - aspect_ratio: float = width / height - current_aspect_ratio: float = current_width / current_height + aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + if aspect_ratio > current_aspect_ratio: - scale_factor = current_width / width - new_height = int(height * scale_factor) + new_height = (original_height * current_width) // original_width padding = (current_height - new_height) // 2 current_height -= padding * 2 else: - scale_factor = current_height / height - new_width = int(width * scale_factor) + new_width = (original_width * current_height) // original_height padding = (current_width - new_width) // 2 current_width -= padding * 2 @@ -95,7 +90,7 @@ def _get_llava_next_num_unpadded_features( return (unpadded_features, newline_features) -# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111 +# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106 def get_llava_next_image_feature_size( hf_config: LlavaNextConfig, *, @@ -111,9 +106,7 @@ def get_llava_next_image_feature_size( ) base_feature_size = num_patches * num_patches - # Note: We follow the "wrong" width/height order - # [ref: PR huggingface/transformers#31588] - num_patch_width, num_patch_height = get_anyres_image_grid_shape( + num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_size=(input_height, input_width), grid_pinpoints=hf_config.image_grid_pinpoints, patch_size=vision_config.image_size, @@ -349,11 +342,12 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, if patch_embeddings.shape[0] > 1: other_patch_embeds = patch_embeddings[1:] + # Move to CPU to avoid floating-point errors + orig_height, orig_width = image_size.tolist() + # image_aspect_ratio == "anyres" - # Note: We follow the "wrong" width/height order - # [ref: PR huggingface/transformers#31588] - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_size, + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + (orig_height, orig_width), self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) @@ -365,7 +359,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, .permute(4, 0, 2, 1, 3).contiguous() \ .flatten(1, 2).flatten(2, 3) other_patch_embeds = unpad_image(other_patch_embeds, - image_size) + (orig_height, orig_width)) other_patch_embeds = torch.cat(( other_patch_embeds, self.image_newline[:, None, None] \ @@ -398,7 +392,7 @@ def _merge_image_patch_embeddings(self, 
image_size: torch.Tensor, def _process_image_pixels( self, inputs: LlavaNextImagePixelInputs, - ) -> BatchedTensors: + ) -> Union[torch.Tensor, List[torch.Tensor]]: assert self.vision_tower is not None pixel_values = inputs["data"] @@ -425,7 +419,9 @@ def _process_image_pixels( ] def _process_image_input( - self, image_input: LlavaNextImageInputs) -> BatchedTensors: + self, + image_input: LlavaNextImageInputs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: patch_embeddings = self._process_image_pixels(image_input) image_sizes = image_input.get("image_sizes") diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 75e2f5fc95cb7..823c34b101870 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -36,7 +36,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput @@ -261,7 +261,7 @@ def add_image_newline(self, image_features_hd): class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: BatchedTensors + data: Union[torch.Tensor, List[torch.Tensor]] """ Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
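
Below the patch proper: a minimal standalone sketch of the corrected feature-size arithmetic, assuming nothing from vLLM or transformers. It mirrors the integer-only `_get_llava_next_num_unpadded_features` path introduced above (the patch also swaps the `get_anyres_image_grid_shape` result to `(num_patch_height, num_patch_width)` order). The helper name, the 24-patches-per-side figure (a 336 px CLIP input with 14 px patches), and the 2x2 example grid are illustrative assumptions, not values taken from the diff.

# Standalone sketch (not part of the patch): integer-only unpadded-feature
# computation for LLaVA-NeXT, mirroring the fixed helper above.
from typing import Tuple


def num_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    # Feature-map size of the tiled (anyres) patches before unpadding.
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    # Crop whichever dimension was padded during preprocessing; floor
    # division keeps the result consistent with the HF modeling code.
    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width -= padding * 2

    unpadded_features = current_height * current_width
    newline_features = current_height  # one newline token per output row
    return unpadded_features, newline_features


if __name__ == "__main__":
    # Assumed example inputs: 24 patches per side and a 2x2 tile grid for a
    # 183x488 image; the total adds the 24 * 24 base tile on top.
    unpadded, newline = num_unpadded_features(183, 488, 24, 2, 2)
    print(unpadded, newline, unpadded + newline + 24 * 24)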
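
Likewise outside the patch: a toy illustration of why `input_processor_for_fuyu` and `input_processor_for_internvl` now read the prompt with `llm_inputs.get("prompt")`. Inputs built directly from token IDs may not carry a prompt string at all, so plain indexing would raise `KeyError`; the dict shape below is a simplified assumption, not the real `LLMInputs` definition.

# Toy illustration (assumed, simplified input shape): a token-ID-only request
# omits the "prompt" key entirely.
llm_inputs = {"prompt_token_ids": [1, 733, 16289, 28793]}

prompt = llm_inputs.get("prompt")  # llm_inputs["prompt"] would raise KeyError
if prompt is None:
    # The InternVL processor, for example, falls back to
    # tokenizer.decode(prompt_token_ids); a placeholder stands in here.
    prompt = "<decoded from prompt_token_ids>"
print(prompt)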