From 9675fb6851062dccd3f93672c67bc3a16d94dfce Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Fri, 6 Sep 2024 13:40:02 -0400 Subject: [PATCH 1/7] Support multiple images for qwen-vl Signed-off-by: Alex-Brooks --- docs/source/models/supported_models.rst | 4 ++-- vllm/model_executor/models/qwen.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index ec8acb224fdf3..3afe60be2b87d 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -248,8 +248,8 @@ Multimodal Language Models - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - * - :code:`QWenLMHeadModel` - - Qwen-VL - - Image\ :sup:`E` + - Qwen + - Image\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - * - :code:`Qwen2VLForConditionalGeneration` diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a726ec10984c0..18bc6b303f485 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -47,6 +47,7 @@ from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) +from vllm.utils import is_list_of from .utils import flatten_bn, is_pp_missing_parameter, make_layers @@ -684,9 +685,12 @@ def input_processor_for_qwen(ctx: InputContext, raise ValueError( f"Expected img embeds to be have 3 dimensions, got {num_dims}") num_images = 1 if num_dims == 2 else image_data.shape[0] - else: - # TODO - handle multiple image inputs once the API is solidified + elif isinstance(image_data, Image.Image): num_images = 1 + elif is_list_of(image_data, Image.Image): + num_images = len(image_data) + else: + raise TypeError(f"Invalid image type: {type(image_data)}") if prompt is None: prompt = tokenizer.decode(prompt_token_ids) @@ -767,11 +771,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but " f"received shape [{data.shape}]") pixel_values = data - else: transform = build_normalization_transform(image_size) - # TODO - handle multiple image inputs once the API is solidified - transformed_images = [transform(data)] + if not isinstance(data, (list, tuple)): + data = [data] + transformed_images = [transform(datum) for datum in data] pixel_values = torch.stack(transformed_images, dim=0) return MultiModalInputs({"pixel_values": pixel_values}) From 68417ac4a0a1b950a0ac4960a38e611913f2f922 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 9 Sep 2024 04:41:46 -0400 Subject: [PATCH 2/7] Add tests for qwen mm input processor/mapper Signed-off-by: Alex-Brooks --- docs/source/models/supported_models.rst | 2 +- tests/models/test_qwen.py | 152 +++++++++++++++++++++++- 2 files changed, 150 insertions(+), 4 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 3afe60be2b87d..62f9359fe1e2a 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -248,7 +248,7 @@ Multimodal Language Models - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - * - :code:`QWenLMHeadModel` - - Qwen + - Qwen-VL - Image\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. 
- diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 05f5cbf8c3435..c4336b9e157e1 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -1,9 +1,16 @@ import pathlib -from typing import List, Optional, Type +from typing import Dict, List, Optional, Type, Union import pytest +import torch +from PIL.Image import Image -from vllm.multimodal.utils import rescale_image_size +from vllm.config import ModelConfig +from vllm.inputs import InputContext, LLMInputs +from vllm.model_executor.models.qwen import (input_mapper_for_qwen, + input_processor_for_qwen) +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from .utils import check_logprobs_close @@ -23,8 +30,147 @@ "Picture 1: \nWhat is the season?: ", }) +### Multimodal preprocessing tests +SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image +# These values are specific to Qwen-VL/Chat; we can get these from the model +# config also, but they are hardcoded here to keep the parameterize/fixtures +# easy to read. +IMG_START_ID = 151857 +IMG_END_ID = 151858 +IMG_PAD_ID = 151859 +TOKS_PER_IMG = 256 +VIS_ENC_DIM = 4096 +IMG_SIZE = 448 -### Tests for multimodal Qwen models + +def build_model_context(model_name: str, + tokenizer_name: Optional[str] = None, + trust_remote_code: bool = False): + """Creates an InputContext for a given model. + + Args: + model_name: Name of the model being considered. + tokenizer_name: Name of the tokenizer being considered. + trust_remote_code: Whether or not to allow loading remote code. + + Returns: + InputContext for the model being considered. + """ + if tokenizer_name is None: + tokenizer_name = model_name + model_config = ModelConfig( + model_name, + tokenizer_name, + tokenizer_mode="auto", + trust_remote_code=trust_remote_code, + dtype="float32", + seed=0, + ) + return InputContext(model_config) + + +@pytest.fixture() +def qwen_vl_context() -> InputContext: + """Get an InputContext for Qwen-VL.""" + return build_model_context(model_name="Qwen/Qwen-VL", + trust_remote_code=True) + + +# Happy path tests for single/multi-image scenarios for the multimodal +# input processor and mapper, respectively +@pytest.mark.parametrize("num_images", [1, 2]) +def test_input_processor_valid_mm_data(qwen_vl_context: InputContext, + num_images: int): + """Happy cases for image inputs to Qwen's multimodal input processor.""" + prompt = "".join( + [f"Picture {num}: \n" for num in range(1, num_images + 1)]) + inputs = LLMInputs( + prompt=prompt, + # When processing multimodal data for a multimodal model, the qwen + # input processor will overwrite the provided prompt_token_ids with + # the image prompts + prompt_token_ids=None, + multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)}, + ) + proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs) + assert isinstance(proc_inputs, dict) + + # Each image should have one start / stop and a fixed context of 256 + proc_tokens = proc_inputs["prompt_token_ids"] + assert proc_tokens.count(IMG_START_ID) == num_images + assert proc_tokens.count(IMG_END_ID) == num_images + assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG + + +@pytest.mark.parametrize( + "img_data,expected_shape", + [ + # single / multi-image + (SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)), + (2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)), + # single / multi-image embeddings + (torch.rand( + (TOKS_PER_IMG, VIS_ENC_DIM)), 
(1, TOKS_PER_IMG, VIS_ENC_DIM)), + (torch.rand( + (1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)), + (torch.rand( + (2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)), + ]) +def test_input_mapper_valid_mm_data(qwen_vl_context: InputContext, + img_data: Union[torch.Tensor, List[Image], + Image], + expected_shape: List[int]): + """Happy cases for image inputs to Qwen's multimodal input mapper.""" + mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data) + # Ensure that we get the appropriately shaped pixel_values + # for images and image embeddings, respectively. + assert isinstance(mapped_img_data, MultiModalInputs) + assert "pixel_values" in mapped_img_data + assert mapped_img_data["pixel_values"].shape == expected_shape + + +# Sad path tests for the multimodal input processor and mapper, respectively +@pytest.mark.parametrize("mm_data", [ + { + "image": torch.rand((5)) + }, + { + "image": torch.rand((5, 5, 5, 5, 5)) + }, +]) +def test_input_processor_invalid_mm_data(qwen_vl_context: InputContext, + mm_data: Dict[str, torch.Tensor]): + """Test sad cases validated in Qwen's multimodal input processor.""" + tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer, + trust_remote_code=True) + prompt = "Picture 1: \n" + prompt_token_ids = tokenizer.encode(prompt) + inputs = LLMInputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=mm_data) + # Should fail since we have too many or too few dimensions for embeddings + with pytest.raises(ValueError): + input_processor_for_qwen(qwen_vl_context, inputs) + + +@pytest.mark.parametrize( + "img_data", + [ + # Wrong context length + torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)), + # Wrong visual encoder output size + torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)), + ]) +def test_input_mapper_invalid_mm_data( + qwen_vl_context: InputContext, + img_data: Union[torch.Tensor, List[Image], Image], +): + """Sad cases validated in Qwen VL's multimodal input mapper.""" + with pytest.raises(ValueError): + input_mapper_for_qwen(qwen_vl_context, img_data) + + +### End-to-end generation tests def run_test( tmp_path: pathlib.PosixPath, hf_runner: Type[HfRunner], From c8de1f88540f2766a4b79749ea7b59f7aabad642 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 9 Sep 2024 05:56:41 -0400 Subject: [PATCH 3/7] Add multi-image example for qwenvl chat Signed-off-by: Alex-Brooks --- ...e_inference_vision_language_multi_image.py | 70 ++++++++++++++----- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index ed7e886d57806..a267af3e819f7 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -19,7 +19,39 @@ ] -def load_phi3v(question, image_urls: List[str]): +def load_qwenvl_chat(question: str, image_urls: List[str]): + model_name = "Qwen/Qwen-VL-Chat" + llm = LLM( + model=model_name, + trust_remote_code=True, + max_num_seqs=5, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + placeholders = "".join(f"Picture {i}: \n" + for i, _ in enumerate(image_urls, start=1)) + + # This model does not have a chat_template attribute on its tokenizer, + # so we need to explicitly pass it. 
We use ChatML since it's used in the + # generation utils of the model: + # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265 + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + + # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating + chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501 + + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True, + chat_template=chat_template) + + stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + return llm, prompt, stop_token_ids, None, chat_template + + +def load_phi3v(question: str, image_urls: List[str]): llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, @@ -30,10 +62,10 @@ def load_phi3v(question, image_urls: List[str]): for i, _ in enumerate(image_urls, start=1)) prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" stop_token_ids = None - return llm, prompt, stop_token_ids, None + return llm, prompt, stop_token_ids, None, None -def load_internvl(question, image_urls: List[str]): +def load_internvl(question: str, image_urls: List[str]): model_name = "OpenGVLab/InternVL2-2B" llm = LLM( @@ -61,7 +93,7 @@ def load_internvl(question, image_urls: List[str]): stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompt, stop_token_ids, None + return llm, prompt, stop_token_ids, None, None def load_qwen2_vl(question, image_urls: List[str]): @@ -111,18 +143,19 @@ def load_qwen2_vl(question, image_urls: List[str]): else: image_data, _ = process_vision_info(messages) - return llm, prompt, stop_token_ids, image_data + return llm, prompt, stop_token_ids, image_data, None model_example_map = { "phi3_v": load_phi3v, "internvl_chat": load_internvl, "qwen2_vl": load_qwen2_vl, + "qwen_vl_chat": load_qwenvl_chat, } def run_generate(model, question: str, image_urls: List[str]): - llm, prompt, stop_token_ids, image_data = model_example_map[model]( + llm, prompt, stop_token_ids, image_data, _ = model_example_map[model]( question, image_urls) if image_data is None: image_data = [fetch_image(url) for url in image_urls] @@ -146,7 +179,8 @@ def run_generate(model, question: str, image_urls: List[str]): def run_chat(model: str, question: str, image_urls: List[str]): - llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls) + llm, _, stop_token_ids, _, chat_template = model_example_map[model]( + question, image_urls) sampling_params = SamplingParams(temperature=0.0, max_tokens=128, @@ -155,20 +189,18 @@ def run_chat(model: str, question: str, image_urls: List[str]): outputs = llm.chat([{ "role": "user", - "content": [ - { - "type": "text", - "text": question, + "content": [{ + "type": "text", + "text": question, + }, *({ + "type": "image_url", + "image_url": { + "url": image_url }, - *({ - "type": "image_url", - "image_url": { - "url": image_url - }, - } for image_url in image_urls), - ], + } for image_url in image_urls)], }], - 
sampling_params=sampling_params) + sampling_params=sampling_params, + chat_template=chat_template) for o in outputs: generated_text = o.outputs[0].text From 49dfb55405f829dca53354b85f2231e5e0911c38 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 9 Sep 2024 07:12:46 -0400 Subject: [PATCH 4/7] Fix formatting Signed-off-by: Alex-Brooks --- ...e_inference_vision_language_multi_image.py | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index a267af3e819f7..454872c628373 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -185,22 +185,26 @@ def run_chat(model: str, question: str, image_urls: List[str]): sampling_params = SamplingParams(temperature=0.0, max_tokens=128, stop_token_ids=stop_token_ids) - - outputs = llm.chat([{ - "role": - "user", - "content": [{ - "type": "text", - "text": question, - }, *({ - "type": "image_url", - "image_url": { - "url": image_url - }, - } for image_url in image_urls)], - }], - sampling_params=sampling_params, - chat_template=chat_template) + outputs = llm.chat( + [{ + "role": + "user", + "content": [ + { + "type": "text", + "text": question, + }, + *({ + "type": "image_url", + "image_url": { + "url": image_url + }, + } for image_url in image_urls), + ], + }], + sampling_params=sampling_params, + chat_template=chat_template, + ) for o in outputs: generated_text = o.outputs[0].text From a3607c68f9eed2a1ddffef48896109bbbe4cb207 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 10 Sep 2024 04:59:46 -0400 Subject: [PATCH 5/7] Refactor qwen e2e generation test into resusable parts Signed-off-by: Alex-Brooks --- tests/models/test_qwen.py | 89 +++++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 31 deletions(-) diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index c4336b9e157e1..298b1eac16981 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -1,5 +1,5 @@ import pathlib -from typing import Dict, List, Optional, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import pytest import torch @@ -12,7 +12,8 @@ from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ..conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput, + VllmRunner, _ImageAssets) from .utils import check_logprobs_close pytestmark = pytest.mark.vlm @@ -30,6 +31,8 @@ "Picture 1: \nWhat is the season?: ", }) +HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: \nPicture 2: \nDescribe the two images in detail.\n" # noqa: E501 + ### Multimodal preprocessing tests SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image # These values are specific to Qwen-VL/Chat; we can get these from the model @@ -171,14 +174,40 @@ def test_input_mapper_invalid_mm_data( ### End-to-end generation tests +def get_prompt_with_path(tmp_path: pathlib.PosixPath, prompt: str, + assets: List[ImageAsset]) -> str: + """Given a temporary dir path, export one or more image assets into the + tempdir & replace its contents with the local path to the string so that + the HF version of Qwen-VL can resolve the path and load the image ni its + forward() call. + + Args: + tmp_path: Tempdir for test under consideration. + prompt: Prompt with image placeholders. 
+ assets: List of image assets whose len equals the num placeholders. + """ + # Ensure that the number of placeholders matches the number of assets; + # If this is not true, the test is probably written incorrectly. + assert prompt.count("") == len(assets) + + # Replace the placeholders with local paths to the exported assets + for asset in assets: + image_tmp_path = tmp_path / f"{asset.name}.jpg" + asset.pil_image.save(image_tmp_path) + prompt = prompt.replace( + "", + f"{image_tmp_path}", + 1, + ) + return prompt + + def run_test( - tmp_path: pathlib.PosixPath, hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, + inputs: List[Tuple[List[str], PromptImageInput]], model: str, *, - size_factors: List[float], dtype: str, max_tokens: int, num_logprobs: int, @@ -194,23 +223,6 @@ def run_test( Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ - images = [asset.pil_image for asset in image_assets] - - # Export the images to a tempdir and substitute it into the hf prompt; - # the contents between / will be ignored by VLLM, but the - # transformers implementation for the visual transformer parses this to - # reload it in the forward call; the contents are treated as a URL or a - # local path. - for idx, asset in enumerate(image_assets): - image_tmp_path = tmp_path / f"{asset.name}.jpg" - asset.pil_image.save(image_tmp_path) - HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace( - "", f"{image_tmp_path}") - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. 
@@ -231,7 +243,7 @@ def run_test( max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] with hf_runner(model, dtype=dtype) as hf_model: @@ -240,7 +252,7 @@ def run_test( max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -271,16 +283,31 @@ def run_test( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [8]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets, - model, size_factors, dtype, max_tokens, - num_logprobs) -> None: +def test_multimodal_models_single_image(tmp_path: pathlib.PosixPath, + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, model: str, + size_factors: List[float], dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + """Tests multimodal models with single image prompts.""" + images = [asset.pil_image for asset in image_assets] + + prompts = [ + get_prompt_with_path(tmp_path, prompt, [asset]) + for prompt, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, prompts)] + run_test( - tmp_path, hf_runner, vllm_runner, - image_assets, + inputs_per_image, model, - size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, @@ -296,7 +323,7 @@ def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("num_logprobs", [5]) def test_text_only_qwen_model_can_be_loaded_and_run( vllm_runner: Type[VllmRunner], - example_prompts, + example_prompts: List[str], model: str, *, dtype: str, From 7f1f69bb2d679790fbcd06a0e83853b9dd28764e Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 10 Sep 2024 16:50:50 -0400 Subject: [PATCH 6/7] Add qwen multi-image test Signed-off-by: Alex-Brooks --- tests/models/test_qwen.py | 61 +++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 6 deletions(-) diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 298b1eac16981..dfcf1b0c340e5 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -31,8 +31,8 @@ "Picture 1: \nWhat is the season?: ", }) +HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: \nPicture 2: \nCan you compare these images?\n" # noqa: E501 HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: \nPicture 2: \nDescribe the two images in detail.\n" # noqa: E501 - ### Multimodal preprocessing tests SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image # These values are specific to Qwen-VL/Chat; we can get these from the model @@ -175,7 +175,7 @@ def test_input_mapper_invalid_mm_data( ### End-to-end generation tests def get_prompt_with_path(tmp_path: pathlib.PosixPath, prompt: str, - assets: List[ImageAsset]) -> str: + assets: Union[_ImageAssets, List[ImageAsset]]) -> str: """Given a temporary dir path, export one or more image assets into the tempdir & replace its contents with the local path to the string so that the HF version of Qwen-VL can resolve the path and load the image ni its @@ -211,6 +211,7 @@ def run_test( dtype: str, max_tokens: int, num_logprobs: int, + mm_limit: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ): @@ -230,11 +231,12 @@ def run_test( # will hurt multiprocessing 
backend with fork method (the default method). # max_model_len should be greater than image_feature_size - # Qwen encodes images into a fixed content size of 256 + # Qwen encodes each image into a fixed content size of 256 with vllm_runner(model, - max_model_len=300, + max_model_len=1024, max_num_seqs=1, dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: @@ -298,7 +300,7 @@ def test_multimodal_models_single_image(tmp_path: pathlib.PosixPath, for prompt, asset in zip(HF_IMAGE_PROMPTS, image_assets) ] - inputs_per_image = [( + inputs = [( [prompt for _ in size_factors], [rescale_image_size(image, factor) for factor in size_factors], ) for image, prompt in zip(images, prompts)] @@ -306,11 +308,58 @@ def test_multimodal_models_single_image(tmp_path: pathlib.PosixPath, run_test( hf_runner, vllm_runner, - inputs_per_image, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +@pytest.mark.parametrize("model", multimodal_models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multimodal_models_multi_image(tmp_path: pathlib.PosixPath, + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, model: str, + size_factors: List[float], dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + """Tests multimodal models with multi-image prompts.""" + images = [asset.pil_image for asset in image_assets] + # Put all of the images into one prompt. 
+ prompt = get_prompt_with_path(tmp_path, HF_MULTIIMAGE_IMAGE_PROMPT, + image_assets) + inputs = [([prompt for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors])] + + run_test( + hf_runner, + vllm_runner, + inputs, model, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, + mm_limit=2, tensor_parallel_size=1, ) From bbb9761bfde541d911b2c0cfb79d6f069f18f1df Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 12 Sep 2024 10:19:55 +0800 Subject: [PATCH 7/7] Avoid CUDA re-initialization error --- tests/models/test_qwen.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index dfcf1b0c340e5..5e7f1de99d6c3 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -7,8 +7,6 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs -from vllm.model_executor.models.qwen import (input_mapper_for_qwen, - input_processor_for_qwen) from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size @@ -72,6 +70,20 @@ def build_model_context(model_name: str, return InputContext(model_config) +@pytest.fixture() +def input_mapper_for_qwen(): + # Lazy import to avoid initializing CUDA during test collection + from vllm.model_executor.models.qwen import input_mapper_for_qwen + return input_mapper_for_qwen + + +@pytest.fixture() +def input_processor_for_qwen(): + # Lazy import to avoid initializing CUDA during test collection + from vllm.model_executor.models.qwen import input_processor_for_qwen + return input_processor_for_qwen + + @pytest.fixture() def qwen_vl_context() -> InputContext: """Get an InputContext for Qwen-VL.""" @@ -82,7 +94,8 @@ def qwen_vl_context() -> InputContext: # Happy path tests for single/multi-image scenarios for the multimodal # input processor and mapper, respectively @pytest.mark.parametrize("num_images", [1, 2]) -def test_input_processor_valid_mm_data(qwen_vl_context: InputContext, +def test_input_processor_valid_mm_data(input_processor_for_qwen, + qwen_vl_context: InputContext, num_images: int): """Happy cases for image inputs to Qwen's multimodal input processor.""" prompt = "".join( @@ -119,7 +132,8 @@ def test_input_processor_valid_mm_data(qwen_vl_context: InputContext, (torch.rand( (2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)), ]) -def test_input_mapper_valid_mm_data(qwen_vl_context: InputContext, +def test_input_mapper_valid_mm_data(input_mapper_for_qwen, + qwen_vl_context: InputContext, img_data: Union[torch.Tensor, List[Image], Image], expected_shape: List[int]): @@ -141,7 +155,8 @@ def test_input_mapper_valid_mm_data(qwen_vl_context: InputContext, "image": torch.rand((5, 5, 5, 5, 5)) }, ]) -def test_input_processor_invalid_mm_data(qwen_vl_context: InputContext, +def test_input_processor_invalid_mm_data(input_processor_for_qwen, + qwen_vl_context: InputContext, mm_data: Dict[str, torch.Tensor]): """Test sad cases validated in Qwen's multimodal input processor.""" tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer, @@ -165,6 +180,7 @@ def test_input_processor_invalid_mm_data(qwen_vl_context: InputContext, torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)), ]) def test_input_mapper_invalid_mm_data( + input_mapper_for_qwen, qwen_vl_context: InputContext, img_data: Union[torch.Tensor, List[Image], Image], ):
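The first patch teaches input_processor_for_qwen to accept image embeddings (2-D for one image, 3-D for a batch), a single PIL image, or a list of PIL images. A standalone sketch of that branching is below so the accepted input shapes are easy to see; vLLM's is_list_of helper is approximated with a plain all(...) check, and the 4096 hidden size in the demo call is simply the value the new tests hardcode.

from typing import List, Union

import torch
from PIL import Image


def infer_num_images(
        image_data: Union[torch.Tensor, Image.Image,
                          List[Image.Image]]) -> int:
    """Count how many images a multi_modal_data["image"] value represents."""
    if isinstance(image_data, torch.Tensor):
        # Pre-computed visual embeddings: [256, hidden] for a single image,
        # [num_images, 256, hidden] for several.
        num_dims = image_data.ndim
        if num_dims not in (2, 3):
            raise ValueError(
                f"Expected image embeds with 2 or 3 dims, got {num_dims}")
        return 1 if num_dims == 2 else image_data.shape[0]
    if isinstance(image_data, Image.Image):
        return 1
    if isinstance(image_data, list) and all(
            isinstance(img, Image.Image) for img in image_data):
        return len(image_data)
    raise TypeError(f"Invalid image type: {type(image_data)}")


# Two images' worth of embeddings: 256 tokens each, assumed hidden size 4096.
assert infer_num_images(torch.rand(2, 256, 4096)) == 2
assert infer_num_images(Image.new("RGB", (448, 448))) == 1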
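On the mapper side, the same patch promotes a bare image to a one-element list before normalizing and stacking, so pixel_values always comes out as [num_images, 3, 448, 448], the shapes the new mapper tests assert. The sketch below stands in for vLLM's build_normalization_transform with a torchvision pipeline; the resize-plus-CLIP-style-normalization recipe and its mean/std constants are assumptions for illustration, not values taken from the patch.

from typing import List, Union

import torch
from PIL import Image
from torchvision import transforms

IMG_SIZE = 448  # Qwen-VL visual input resolution used by the tests above

# Hypothetical stand-in for vLLM's build_normalization_transform(IMG_SIZE).
normalize = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])


def to_pixel_values(
        data: Union[Image.Image, List[Image.Image]]) -> torch.Tensor:
    """Stack one or more PIL images into a [num_images, 3, 448, 448] tensor,
    mirroring the single-image -> list promotion added in the mapper."""
    if not isinstance(data, (list, tuple)):
        data = [data]
    return torch.stack([normalize(img) for img in data], dim=0)


images = [Image.new("RGB", (640, 480)), Image.new("RGB", (800, 600))]
assert to_pixel_values(images).shape == (2, 3, IMG_SIZE, IMG_SIZE)
assert to_pixel_values(images[0]).shape == (1, 3, IMG_SIZE, IMG_SIZE)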
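The processor tests pin down the contract for the expanded prompt: every image contributes exactly one image-start id, one image-end id, and a fixed block of 256 image-pad ids (151857, 151858, and 151859 in the test constants). A small checker expressing that invariant, using only the constants from the test file:

from typing import List

# Qwen-VL special token ids and per-image context length, as hardcoded in the
# new tests.
IMG_START_ID = 151857
IMG_END_ID = 151858
IMG_PAD_ID = 151859
TOKS_PER_IMG = 256


def check_image_expansion(prompt_token_ids: List[int],
                          num_images: int) -> None:
    """Assert the per-image token budget produced by the input processor."""
    assert prompt_token_ids.count(IMG_START_ID) == num_images
    assert prompt_token_ids.count(IMG_END_ID) == num_images
    assert prompt_token_ids.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG


# A synthetic two-image prompt: each image is start + 256 pads + end.
fake_ids = ([IMG_START_ID] + [IMG_PAD_ID] * TOKS_PER_IMG + [IMG_END_ID]) * 2
check_image_expansion(fake_ids, num_images=2)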
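One way to exercise the new multi-image path end to end, condensed from the example script added in patches 3 and 4: create the engine with limit_mm_per_prompt, build the ChatML stop tokens and template explicitly (the Qwen-VL-Chat tokenizer ships without a chat_template attribute), and send a single user turn with several image_url entries through llm.chat. The question and image URLs below are placeholders; only the engine arguments, stop tokens, and template string come from the example script.

from typing import List

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

# Placeholder inputs -- substitute any reachable image URLs.
QUESTION = "Can you compare these images?"
IMAGE_URLS: List[str] = [
    "https://example.com/image_1.jpg",
    "https://example.com/image_2.jpg",
]

MODEL_NAME = "Qwen/Qwen-VL-Chat"

llm = LLM(
    model=MODEL_NAME,
    trust_remote_code=True,
    max_num_seqs=5,
    limit_mm_per_prompt={"image": len(IMAGE_URLS)},
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in stop_tokens]

# ChatML template passed explicitly, as in the example script.
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=128,
                                 stop_token_ids=stop_token_ids)

outputs = llm.chat(
    [{
        "role": "user",
        "content": [
            {"type": "text", "text": QUESTION},
            *({"type": "image_url", "image_url": {"url": url}}
              for url in IMAGE_URLS),
        ],
    }],
    sampling_params=sampling_params,
    chat_template=chat_template,
)

for o in outputs:
    print(o.outputs[0].text)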