[Bugfix] Fix feature size calculation for LLaVA-NeXT (vllm-project#6982)
DarkLight1337 authored Jul 31, 2024
1 parent 2f4e108 commit daed30c
Showing 5 changed files with 98 additions and 50 deletions.
88 changes: 70 additions & 18 deletions tests/models/test_llava_next.py
@@ -1,7 +1,7 @@
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type, overload

import pytest
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@@ -50,6 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
return hf_output_ids, hf_output_str, out_logprobs


@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
@@ -62,13 +63,55 @@ def run_test(
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...


@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...


def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
images = [asset.pil_image for asset in image_assets]

-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    if size_factors is not None:
+        inputs_per_image = [(
+            [prompt for _ in size_factors],
+            [rescale_image_size(image, factor) for factor in size_factors],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    elif sizes is not None:
+        inputs_per_image = [(
+            [prompt for _ in sizes],
+            [image.resize(size) for size in sizes],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    else:
+        raise ValueError("You must provide either `size_factors` or `sizes`")

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
@@ -150,15 +193,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
)


-@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144),
-                                                          (183, 488, 776)])
-def test_image_feature_size(height_and_width_and_result):
-    # Avoid initializing CUDA too early in distributed tests
-    from vllm.model_executor.models.llava_next import (
-        get_llava_next_image_feature_size)
-
-    height, width, result = height_and_width_and_result
-    config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
-    assert get_llava_next_image_feature_size(config,
-                                             input_height=height,
-                                             input_width=width) == result
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "sizes",
+    [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
+                            dtype, max_tokens, num_logprobs) -> None:
+    run_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        sizes=sizes,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
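Aside for readers unfamiliar with the pattern used above: the `@overload` stubs only describe the two accepted call signatures to the type checker and are erased at runtime; the single undecorated implementation does the actual dispatch on which keyword was supplied. A minimal standalone sketch of the same idea, using hypothetical names rather than vLLM code:

from typing import List, Optional, Tuple, overload


@overload
def run(*, size_factors: List[float]) -> None: ...


@overload
def run(*, sizes: List[Tuple[int, int]]) -> None: ...


def run(
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
) -> None:
    # Only this body exists at runtime; the stubs above merely tell the
    # type checker that one (and only one) of the two keywords is accepted.
    if size_factors is not None:
        print(f"rescaling by factors {size_factors}")
    elif sizes is not None:
        print(f"resizing to fixed sizes {sizes}")
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")


run(size_factors=[0.25, 0.5, 1.0])
run(sizes=[(183, 488), (488, 183)])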
2 changes: 1 addition & 1 deletion vllm/model_executor/models/fuyu.py
@@ -169,7 +169,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
raise TypeError(f"Invalid image type: {type(image_data)}")

# process prompts
-    prompt = llm_inputs["prompt"]
+    prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(model_config.model)
# dim0 is batch_size, dim1 is subseq_size which will always be 1
6 changes: 3 additions & 3 deletions vllm/model_executor/models/internvl.py
@@ -20,7 +20,7 @@
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -43,7 +43,7 @@

class InternVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
@@ -193,7 +193,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)

-    prompt = llm_inputs["prompt"]
+    prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
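The `llm_inputs["prompt"]` to `llm_inputs.get("prompt")` change in fuyu.py and internvl.py matters because a request may arrive with token IDs only, in which case no prompt string is present. A minimal sketch of the pattern, with a plain dict standing in for the real LLMInputs:

# Token-ID-only request: no "prompt" string is present.
llm_inputs = {"prompt_token_ids": [1, 3087, 278, 1967]}

prompt = llm_inputs.get("prompt")  # None instead of raising KeyError
prompt_token_ids = llm_inputs["prompt_token_ids"]

if prompt is None:
    # Fall back to reconstructing the text, as internvl.py does with
    # prompt = tokenizer.decode(prompt_token_ids).
    prompt = f"<decoded from {len(prompt_token_ids)} token ids>"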
48 changes: 22 additions & 26 deletions vllm/model_executor/models/llava_next.py
@@ -21,7 +21,7 @@
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
@@ -43,7 +43,7 @@

class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
@@ -62,31 +62,26 @@ class LlavaNextImagePixelInputs(TypedDict):
LlavaNextImageInputs = LlavaNextImagePixelInputs


-# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
-# NOTE: new_height and new_width are further incremented to properly invert the
-# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
-    height: int,
-    width: int,
+    original_height: int,
+    original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
-    current_height = torch.tensor(current_height).to("cuda")
-    current_width = torch.tensor(current_width).to("cuda")

-    aspect_ratio: float = width / height
-    current_aspect_ratio: float = current_width / current_height
+    aspect_ratio = original_width / original_height
+    current_aspect_ratio = current_width / current_height

if aspect_ratio > current_aspect_ratio:
-        scale_factor = current_width / width
-        new_height = int(height * scale_factor)
+        new_height = (original_height * current_width) // original_width
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
-        scale_factor = current_height / height
-        new_width = int(width * scale_factor)
+        new_width = (original_width * current_height) // original_height
padding = (current_width - new_width) // 2
current_width -= padding * 2

@@ -95,7 +90,7 @@ def _get_llava_next_num_unpadded_features(
return (unpadded_features, newline_features)


-# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
hf_config: LlavaNextConfig,
*,
@@ -111,9 +106,7 @@ def get_llava_next_image_feature_size(
)
base_feature_size = num_patches * num_patches

-    # Note: We follow the "wrong" width/height order
-    # [ref: PR huggingface/transformers#31588]
-    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
@@ -349,11 +342,12 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
if patch_embeddings.shape[0] > 1:
other_patch_embeds = patch_embeddings[1:]

+            # Move to CPU to avoid floating-point errors
+            orig_height, orig_width = image_size.tolist()
+
             # image_aspect_ratio == "anyres"
-            # Note: We follow the "wrong" width/height order
-            # [ref: PR huggingface/transformers#31588]
-            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                image_size,
+            num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                (orig_height, orig_width),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
@@ -365,7 +359,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
.permute(4, 0, 2, 1, 3).contiguous() \
.flatten(1, 2).flatten(2, 3)
other_patch_embeds = unpad_image(other_patch_embeds,
-                                             image_size)
+                                             (orig_height, orig_width))
other_patch_embeds = torch.cat((
other_patch_embeds,
self.image_newline[:, None, None] \
@@ -398,7 +392,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
def _process_image_pixels(
self,
inputs: LlavaNextImagePixelInputs,
-    ) -> BatchedTensors:
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
assert self.vision_tower is not None

pixel_values = inputs["data"]
@@ -425,7 +419,9 @@ def _process_image_input(
]

def _process_image_input(
-        self, image_input: LlavaNextImageInputs) -> BatchedTensors:
+        self,
+        image_input: LlavaNextImageInputs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
patch_embeddings = self._process_image_pixels(image_input)

image_sizes = image_input.get("image_sizes")
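To make the corrected arithmetic in `_get_llava_next_num_unpadded_features` concrete, here is a standalone sketch that mirrors the new code above, followed by an illustrative call; the grid values in the example are assumptions chosen for illustration, not read from a real model config:

from typing import Tuple


def num_unpadded_features(original_height: int, original_width: int,
                          npatches: int, num_patch_height: int,
                          num_patch_width: int) -> Tuple[int, int]:
    # Patch grid produced by the vision tower for the selected resolution.
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        # Image is wider than the grid: rows of padding are cropped away.
        # Integer arithmetic replaces the old float scale factor.
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
    else:
        # Image is taller than the grid: columns of padding are cropped away.
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width -= padding * 2

    unpadded_features = current_height * current_width
    newline_features = current_height
    return unpadded_features, newline_features


# Illustrative: an image of height 1669 and width 2560 on an assumed 2x2
# tile grid with 24 patches per tile side gives a 48x48 grid, trimmed to
# 32x48 after unpadding -> (1536, 32).
print(num_unpadded_features(1669, 2560, npatches=24,
                            num_patch_height=2, num_patch_width=2))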
4 changes: 2 additions & 2 deletions vllm/model_executor/models/phi3v.py
@@ -36,7 +36,7 @@
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput

@@ -261,7 +261,7 @@ def add_image_newline(self, image_features_hd):

class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
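The recurring `BatchedTensors` to `Union[torch.Tensor, List[torch.Tensor]]` change is an annotation-only change: batched image inputs are either one stacked tensor (when every image yields the same shape) or a list of per-image tensors (when shapes are ragged). A minimal sketch of the annotation, with a hypothetical class mirroring the TypedDicts above:

from typing import List, Literal, TypedDict, Union

import torch


class ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    # One stacked tensor if all images share a shape, otherwise a list
    # with one tensor per image.
    data: Union[torch.Tensor, List[torch.Tensor]]


# Uniform batch: a single (batch_size, 1 + num_patches, C, H, W) tensor.
uniform: ImagePixelInputs = {
    "type": "pixel_values",
    "data": torch.zeros(2, 5, 3, 336, 336),
}

# Ragged batch: per-image tensors with different patch counts.
ragged: ImagePixelInputs = {
    "type": "pixel_values",
    "data": [torch.zeros(5, 3, 336, 336), torch.zeros(3, 3, 336, 336)],
}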
