From b18e3ee4a55a9ab7f5a455b024e600e77a39c9ff Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 29 Jan 2025 17:24:59 +0800 Subject: [PATCH] [Model] Refactoring of MiniCPM-V and add MiniCPM-o-2.6 support for vLLM (#12069) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: hzh Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: shaochangxu.scx Signed-off-by: DarkLight1337 Signed-off-by: NickLucche Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Roger Wang Signed-off-by: Rafael Vasquez Signed-off-by: Akshat Tripathi Signed-off-by: Oleg Mosalov Signed-off-by: Jee Jee Li Signed-off-by: rshaw@neuralmagic.com Signed-off-by: Yida Wu Signed-off-by: Chenguang Li <757486878@qq.com> Signed-off-by: youkaichao Signed-off-by: Alex-Brooks Signed-off-by: Chen Zhang Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Shanshan Shen <467638484@qq.com> Signed-off-by: elijah Signed-off-by: Yikun Signed-off-by: mgoin Signed-off-by: Woosuk Kwon Signed-off-by: Konrad Zawora Signed-off-by: tjtanaa Signed-off-by: wangxiyuan Signed-off-by: Rui Qiao Co-authored-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Co-authored-by: shaochangxu <85155497+shaochangxu@users.noreply.github.com> Co-authored-by: shaochangxu.scx Co-authored-by: Cyrus Leung Co-authored-by: Nicolò Lucchesi Co-authored-by: sixgod Co-authored-by: Isotr0py <2037008807@qq.com> Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> Co-authored-by: Rafael Vasquez Co-authored-by: Isotr0py Co-authored-by: Cyrus Leung Co-authored-by: Akshat Tripathi Co-authored-by: Oleg Mosalov Co-authored-by: Jee Jee Li Co-authored-by: Avshalom Manevich <12231371+avshalomman@users.noreply.github.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Co-authored-by: Yangcheng Li Co-authored-by: Siyuan Li <94890248+liaoyanqing666@users.noreply.github.com> Co-authored-by: Concurrensee Co-authored-by: Chenguang Li <757486878@qq.com> Co-authored-by: youkaichao Co-authored-by: Alex Brooks Co-authored-by: Chen Zhang Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: elijah <30852919+e1ijah1@users.noreply.github.com> Co-authored-by: Yikun Jiang Co-authored-by: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Co-authored-by: mgoin Co-authored-by: Woosuk Kwon Co-authored-by: Konrad Zawora Co-authored-by: TJian Co-authored-by: tjtanaa Co-authored-by: wangxiyuan Co-authored-by: maang-h <55082429+maang-h@users.noreply.github.com> Co-authored-by: Elfie Guo <164945471+elfiegg@users.noreply.github.com> Co-authored-by: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Co-authored-by: Roger Wang Signed-off-by: Isotr0py <2037008807@qq.com> --- docs/source/models/supported_models.md | 9 +- examples/offline_inference/audio_language.py | 32 +- examples/offline_inference/vision_language.py | 33 +- requirements-cpu.txt | 1 + requirements-cuda.txt | 1 + requirements-test.in | 3 + requirements-test.txt | 37 +- .../vision_language/test_models.py | 14 + .../vision_language/vlm_utils/model_utils.py | 11 + .../multimodal/processing/test_common.py | 2 + tests/models/registry.py | 4 +- vllm/entrypoints/chat_utils.py | 6 +- vllm/model_executor/models/minicpmo.py | 811 +++++++++++++++++ vllm/model_executor/models/minicpmv.py | 843 ++++++++++++++---- 
vllm/model_executor/models/registry.py | 1 + 15 files changed, 1622 insertions(+), 186 deletions(-) create mode 100644 vllm/model_executor/models/minicpmo.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 94f4bd6cadabd..afaad8818bdcb 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -693,9 +693,16 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `MiniCPMO` + * MiniCPM-O + * T + IE+ + VE+ + AE+ + * `openbmb/MiniCPM-o-2_6`, etc. + * ✅︎ + * ✅︎ + * - * `MiniCPMV` * MiniCPM-V - * T + IE+ + * T + IE+ + VE+ * `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. * ✅︎ * ✅︎ diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 6fd74782a9aae..5952ec13ec3cb 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -67,7 +67,37 @@ def run_qwen2_audio(question: str, audio_count: int): return llm, prompt, stop_token_ids -model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio} +def run_minicpmo(question: str, audio_count: int): + model_name = "openbmb/MiniCPM-o-2_6" + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + llm = LLM(model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}) + + stop_tokens = ['<|im_end|>', '<|endoftext|>'] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + + audio_placeholder = "()" * audio_count + audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501 + messages = [{ + 'role': 'user', + 'content': f'{audio_placeholder}\n{question}' + }] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True, + chat_template=audio_chat_template) + return llm, prompt, stop_token_ids + + +model_example_map = { + "ultravox": run_ultravox, + "qwen2_audio": run_qwen2_audio, + "minicpmo": run_minicpmo +} def main(args): diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 415439e88ed59..38c2b13d3f2c7 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -265,8 +265,9 @@ def run_mantis(question: str, modality: str): # MiniCPM-V -def run_minicpmv(question: str, modality: str): - assert modality == "image" +def run_minicpmv_base(question: str, modality: str, model_name): + assert modality in ["image", "video"] + # If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa # 2.0 # The official repo doesn't work yet, so we need to use a fork for now @@ -277,7 +278,15 @@ def run_minicpmv(question: str, modality: str): # model_name = "openbmb/MiniCPM-Llama3-V-2_5" # 2.6 - model_name = "openbmb/MiniCPM-V-2_6" + # model_name = "openbmb/MiniCPM-V-2_6" + # o2.6 + + # modality supports + # 2.0: image + # 2.5: image + # 2.6: image, video + # o2.6: image, video, audio + # model_name = "openbmb/MiniCPM-o-2_6" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) llm = LLM( @@ -294,13 +303,18 @@ def run_minicpmv(question: str, 
modality: str): # 2.5 # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id] - # 2.6 + # 2.6 / o2.6 stop_tokens = ['<|im_end|>', '<|endoftext|>'] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + modality_placeholder = { + "image": "(./)", + "video": "()", + } + messages = [{ 'role': 'user', - 'content': f'(./)\n{question}' + 'content': f'{modality_placeholder[modality]}\n{question}' }] prompt = tokenizer.apply_chat_template(messages, tokenize=False, @@ -308,6 +322,14 @@ def run_minicpmv(question: str, modality: str): return llm, prompt, stop_token_ids +def run_minicpmo(question: str, modality: str): + return run_minicpmv_base(question, modality, "openbmb/MiniCPM-o-2_6") + + +def run_minicpmv(question: str, modality: str): + return run_minicpmv_base(question, modality, "openbmb/MiniCPM-V-2_6") + + # LLama 3.2 def run_mllama(question: str, modality: str): assert modality == "image" @@ -523,6 +545,7 @@ def run_qwen2_vl(question: str, modality: str): "llava-next-video": run_llava_next_video, "llava-onevision": run_llava_onevision, "mantis": run_mantis, + "minicpmo": run_minicpmo, "minicpmv": run_minicpmv, "mllama": run_mllama, "molmo": run_molmo, diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 056fbf5a7adec..ed0d2c9fae0b6 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -4,5 +4,6 @@ # Dependencies for CPUs torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin" torch==2.5.1; platform_machine == "aarch64" or platform_system == "Darwin" +torchaudio; platform_machine != "ppc64le" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch datasets # for benchmark scripts diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 8002fbd8ee5b9..78fa360f2dc96 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -5,6 +5,7 @@ ray[default] >= 2.9 nvidia-ml-py >= 12.560.30 # for pynvml package torch == 2.5.1 +torchaudio==2.5.1 # These must be updated alongside torch torchvision == 0.20.1 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1 diff --git a/requirements-test.in b/requirements-test.in index bc76a91ad5356..13ad17b256734 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -12,6 +12,8 @@ decord # required for video tests einops # required for MPT, qwen-vl and Mamba httpx librosa # required for audio tests +vector_quantize_pytorch # required for minicpmo_26 test +vocos # required for minicpmo_26 test peft pqdm ray[adag]==2.40.0 @@ -19,6 +21,7 @@ sentence-transformers # required for embedding tests soundfile # required for audio tests timm # required for internvl test torch==2.5.1 +torchaudio==2.5.1 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test mistral_common[opencv] >= 1.5.0 # required for pixtral test diff --git a/requirements-test.txt b/requirements-test.txt index 09e009c2e21f4..df7e904bb0d34 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -106,9 +106,17 @@ dnspython==2.7.0 docutils==0.16 # via awscli einops==0.8.0 - # via -r requirements-test.in + # via + # -r requirements-test.in + # encodec + # vector-quantize-pytorch + # vocos +einx==0.3.0 + # via vector-quantize-pytorch email-validator==2.2.0 # via pydantic +encodec==0.1.1 + # via vocos evaluate==0.4.3 # via lm-eval fastparquet==2024.11.0 @@ -125,6 +133,8 @@ filelock==3.16.1 # triton fonttools==4.54.1 # via matplotlib +frozendict==2.4.6 + # via einx frozenlist==1.5.0 # via # aiohttp @@ -159,6 +169,7 @@ huggingface-hub==0.26.2 # timm # tokenizers # transformers + # vocos idna==3.10 # via # anyio @@ -261,6 +272,8 @@ numpy==1.26.4 # cupy-cuda12x # datasets # decord + # einx + # encodec # evaluate # fastparquet # genai-perf @@ -283,6 +296,7 @@ numpy==1.26.4 # torchvision # transformers # tritonclient + # vocos nvidia-cublas-cu12==12.4.5.8 # via # nvidia-cudnn-cu12 @@ -455,6 +469,7 @@ pyyaml==6.0.2 # responses # timm # transformers + # vocos ray[adag]==2.40.0 # via -r requirements-test.in redis==5.2.0 @@ -517,6 +532,7 @@ scipy==1.13.1 # scikit-learn # sentence-transformers # statsmodels + # vocos sentence-transformers==3.2.1 # via -r requirements-test.in sentencepiece==0.2.0 @@ -540,7 +556,9 @@ sqlitedict==2.1.0 statsmodels==0.14.4 # via genai-perf sympy==1.13.1 - # via torch + # via + # einx + # torch tabledata==1.3.3 # via pytablewriter tabulate==0.9.0 @@ -568,12 +586,21 @@ torch==2.5.1 # -r requirements-test.in # accelerate # bitsandbytes + # encodec # lm-eval # peft # sentence-transformers # tensorizer # timm + # torchaudio # torchvision + # vector-quantize-pytorch + # vocos +torchaudio==2.5.1 + # via + # -r requirements-test.in + # encodec + # vocos torchvision==0.20.1 # via timm tqdm==4.66.6 @@ -584,6 +611,7 @@ tqdm==4.66.6 # lm-eval # nltk # peft + # pqdm # sentence-transformers # tqdm-multiprocess # transformers @@ -615,6 +643,7 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common + # pqdm # pydantic # pydantic-core # torch @@ -626,6 +655,10 @@ urllib3==2.2.3 # requests # responses # tritonclient +vector-quantize-pytorch==1.21.2 + # via -r requirements-test.in +vocos==0.1.0 + # via -r requirements-test.in word2number==1.1 # via lm-eval xxhash==3.5.0 diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index d5f0d63288cc1..62c644f73d62d 100644 --- 
a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -350,6 +350,20 @@ postprocess_inputs=model_utils.wrap_inputs_post_processor, hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, ), + "minicpmo_26": VLMTestInfo( + models=["openbmb/MiniCPM-o-2_6"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "(./)\n", + max_model_len=4096, + max_num_seqs=2, + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + postprocess_inputs=model_utils.ignore_inputs_post_processor( + "image_sizes" + ), + hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, + patch_hf_runner=model_utils.minicpmo_patch_hf_runner + ), "minicpmv_26": VLMTestInfo( models=["openbmb/MiniCPM-V-2_6"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 1ca85c7bb2056..07bdb2cee44d2 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -497,6 +497,17 @@ def _generate(self, *args, **kwargs): return hf_model +def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + orig_generate = hf_model.model.generate + + def _generate(self, *args, **kwargs): + return orig_generate(*args, decode_text=False, **kwargs) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + + return hf_model + + def _generate_greedy_logprobs_limit( self, prompts: List[str], diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index b575ec6acbef3..ca28da268fa05 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -152,6 +152,8 @@ def _test_processing_correctness( "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "TIGER-Lab/Mantis-8B-siglip-llama3", "mistral-community/pixtral-12b", + "openbmb/MiniCPM-o-2_6", + "openbmb/MiniCPM-V-2_6", "Qwen/Qwen-VL-Chat", "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", diff --git a/tests/models/registry.py b/tests/models/registry.py index 0bd06dea0ec7f..7952e65aa76a5 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -245,7 +245,9 @@ def check_available_online( "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501 hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501 - "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", + "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", + trust_remote_code=True), + "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6", trust_remote_code=True), "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", trust_remote_code=True), diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 723d6e9085806..97d2561df602a 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -392,7 +392,7 @@ def _placeholder_str(self, modality: ModalityStr, if model_type == 
"phi3_v": # Workaround since this token is not defined in the tokenizer return f"<|image_{current_count}|>" - if model_type == "minicpmv": + if model_type in ("minicpmo", "minicpmv"): return "(./)" if model_type in ("blip-2", "chatglm", "fuyu", "paligemma", "pixtral"): @@ -424,10 +424,14 @@ def _placeholder_str(self, modality: ModalityStr, if model_type == "qwen2_audio": return (f"Audio {current_count}: " f"<|audio_bos|><|AUDIO|><|audio_eos|>") + if model_type == "minicpmo": + return "()" raise TypeError(f"Unknown model type: {model_type}") elif modality == "video": if model_type == "qwen2_vl": return "<|vision_start|><|video_pad|><|vision_end|>" + if model_type in ("minicpmo", "minicpmv"): + return "()" if model_type.startswith("llava"): return self._cached_token_str(self._tokenizer, hf_config.video_token_index) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py new file mode 100644 index 0000000000000..eb4282d62005a --- /dev/null +++ b/vllm/model_executor/models/minicpmo.py @@ -0,0 +1,811 @@ +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" +from functools import partial +from itertools import accumulate +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, + Tuple, TypedDict, Union) + +import torch +import torch.types +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.whisper.modeling_whisper import ( + ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.multimodal.parse import (ModalityData, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser, + VideoItem) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + PromptReplacement) +from vllm.multimodal.profiling import ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, + MiniCPMVEmbeddingItems, MiniCPMVMultiModalDataParser, + MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo) +from .utils import AutoWeightsLoader, maybe_prefix + +CPU_DEVICE = torch.device("cpu") + +MiniCPMOEmbeddingItems = MiniCPMVEmbeddingItems + + +class MiniCPMOAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: torch.Tensor + """ + Shape: `(batch_size * num_audios * num_slices, num_channels, length)` + Slice here means chunk. Audio that is too long will be split into slices, + which is the same as image. + Padding is used therefore `data` is `torch.Tensor`. + """ + + audio_feature_lens: torch.Tensor + """ + Shape: `(batch_size * num_audios * num_slices)` + + This should be feature length of each audio slice, + which equals to `data.shape[-1]` + """ + + audio_bounds: torch.Tensor + """ + Shape: `(batch_size * num_audios * num_slices, 2)` + + This should be in `(start, stop)` format. + """ + + +class MiniCPMOAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: List[torch.Tensor] + """ + Shape: `(batch_size * num_images * num_slices, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + instead of a batched tensor. + Length of each slice may vary, so pass it as a list. + """ + audio_bounds: torch.Tensor + """ + Shape: `(batch_size * num_audios * num_slices, 2)` + + This should be in `(start, stop)` format. 
+ """ + + +MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, + MiniCPMOAudioEmbeddingInputs] + + +class MiniCPMOAudioEmbeddingItems(MiniCPMOEmbeddingItems): + + def __init__(self, data: Dict) -> None: + super().__init__(data, "audio") + audio_embeds = self.data.get("audio_embeds", None) + if audio_embeds is None: + raise ValueError("Incorrect type of video_embeds", + "Got type: None") + self.data["audio_embeds"] = audio_embeds + + def get(self, index: int) -> object: + return self.data["audio_embeds"][index] + + +class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): + + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return MiniCPMOAudioEmbeddingItems(data) + return super()._parse_audio_data(data) + + +class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): + audio_pattern = "()" + + def get_supported_mm_modalities(self) -> List[str]: + return ["image", "video", "audio"] + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None, "audio": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return { + "image": self.get_max_image_tokens(), + "audio": self.get_max_audio_tokens(), + "video": self.get_max_video_tokens(seq_len) + } + + def get_default_audio_pool_step(self) -> int: + return 2 + + def get_default_audio_sampling_rate(self) -> int: + return 16000 + + def get_chunk_length(self) -> int: + return self.get_hf_config().audio_chunk_length + + def get_max_audio_tokens_per_chunk(self) -> int: + pool_step = self.get_default_audio_pool_step() + fbank_feat_in_chunk = 100 + cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 + num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1 + return num_audio_tokens + 2 # + + def get_max_audio_chunks_with_most_features(self) -> int: + return 30 + + def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: + sampling_rate = self.get_default_audio_sampling_rate() + # exclude + num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2 + return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 + + def get_num_frames_with_most_features(self, seq_len: int) -> int: + mm_config = self.ctx.get_mm_config() + max_images = mm_config.limit_per_prompt.get("image", 1) + max_videos = mm_config.limit_per_prompt.get("video", 1) + max_audios = mm_config.limit_per_prompt.get("audio", 1) + + # count tokens + # which are not in get_max_image_tokens + max_image_tokens = self.get_max_image_tokens( + ) * max_images + 4 * max_images + max_audio_tokens = self.get_max_audio_tokens( + ) * max_audios + 2 * max_audios + max_total_frames = self.get_max_video_frames(seq_len - + max_image_tokens - + max_audio_tokens) + + num_frames = max(max_total_frames // max(max_videos, 1), 1) + + return num_frames + + +class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder): + + def get_dummy_processor_inputs( + self, seq_len: int, mm_counts: Mapping[str, + int]) -> ProcessorInputs: + num_audios = mm_counts.get("audio", 0) + audio_len = self.info.get_max_audio_chunks_with_most_features() * \ + self.info.get_default_audio_sampling_rate() + + processor_inputs = super().get_dummy_processor_inputs( + seq_len, mm_counts) + mm_data = { + "image": + processor_inputs.mm_data["image"], + "video": + processor_inputs.mm_data["video"], + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + audio_prompt_texts = 
self.info.audio_pattern * num_audios + + return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \ + audio_prompt_texts, + mm_data=mm_data) + + +class MiniCPMOMultiModalProcessor( + MiniCPMVMultiModalProcessor, + BaseMultiModalProcessor[MiniCPMOProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + return MiniCPMOMultiModalDataParser( + target_sr=self.info.get_default_audio_sampling_rate()) + + def get_audio_prompt_texts(self, + audio_lens: int, + chunk_input: bool = True, + chunk_length: int = 1) -> str: + return self.info.get_hf_processor().get_audio_placeholder( + audio_lens, chunk_input, chunk_length) + + def get_special_tokens(self) -> Dict[str, torch.Tensor]: + tokenizer = self.info.get_tokenizer() + special_tokens = super().get_special_tokens() + if hasattr(tokenizer, "audio_start_id"): + special_tokens["audio_start_id"] = torch.tensor( + tokenizer.audio_start_id) + special_tokens["audio_end_id"] = torch.tensor( + tokenizer.audio_end_id) + return special_tokens + + def process_audios(self, mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object]) -> Dict[str, object]: + audios = mm_data.pop("audios", []) + audio_embeds = mm_data.pop("audio_embeds", []) + if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0: + audio_outputs = { + "audio_lens": [], + "audio_features": [], + "audio_feature_lens": [], + "audio_num_segments": [] + } + for audio in audios: + single_audio_outputs = super().call_base_hf_processor( + prompt=self.info.audio_pattern, + mm_data={ + "audios": audio, + "chunk_input": True + }, + mm_kwargs=mm_kwargs) + audio_outputs["audio_lens"].append(len(audio)) + audio_outputs["audio_features"].append( + single_audio_outputs["audio_features"]) + audio_outputs["audio_num_segments"].append( + len(single_audio_outputs["audio_feature_lens"][0])) + audio_outputs["audio_feature_lens"] += \ + single_audio_outputs["audio_feature_lens"] + audio_outputs["audio_features"] = [ + audio_feature for single_audio_features in \ + audio_outputs["audio_features"] + for audio_feature in single_audio_features + ] + audio_outputs["audio_feature_lens"] = torch.cat( + audio_outputs["audio_feature_lens"]) + elif len(audio_embeds): + audio_outputs = { + "audio_lens": [ + self.info.get_audio_len_by_num_chunks( + sum(chunk_embeds.shape[0] + for chunk_embeds in single_audio_embeds)) + for single_audio_embeds in audio_embeds + ], + "audio_embeds": [ + chunk_embeds for single_audio_embeds in audio_embeds + for chunk_embeds in single_audio_embeds + ], + "audio_num_segments": [ + len(single_audio_embeds) + for single_audio_embeds in audio_embeds + ] + } + else: + audio_outputs = {} + return audio_outputs + + def get_placeholder_match_pattern(self) -> str: + return r"\(<(image|video|audio)>./\)" + + def get_placeholder_split_pattern(self) -> str: + return r"\(<(?:image|video|audio)>./\)" + + def process_mm_inputs(self, mm_data, mm_kwargs) -> object: + return { + "image": self.process_images(mm_data, mm_kwargs), + "video": self.process_videos(mm_data, mm_kwargs), + "audio": self.process_audios(mm_data, mm_kwargs) + } + + def get_modality_num_counter(self, modality: str) -> str: + if modality == "audio": + return "audio_lens" + return super().get_modality_num_counter(modality) + + def get_num_slices_by_modality(self, inputs: Dict[str, object], + modality: str, index: int) -> int: + if modality == "audio": + return inputs["audio"]["audio_num_segments"][index] + return super().get_num_slices_by_modality(inputs, modality, index) + + def 
get_prompt_texts_by_modality(self, inputs: Dict[str, object], + modality: str, index: int) -> str: + if modality == "audio": + return self.get_audio_prompt_texts( + inputs["audio"]["audio_lens"][index]) + return super().get_prompt_texts_by_modality(inputs, modality, index) + + def _get_prompt_replacements( + self, mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]: + placeholder = { + "image": self.info.image_pattern, + "video": self.info.video_pattern, + "audio": self.info.audio_pattern + } + + def get_replacement_minicpmv(item_idx: int, modality: str): + if modality == "image": + return self.get_image_prompt_texts( + mm_items["image"].get_image_size(item_idx), item_idx) + elif modality == "video": + return self.get_video_prompt_texts( + mm_items["video"].get_frame_size(item_idx), + mm_items["video"].get_num_frames(item_idx)) + else: # audio + if isinstance(mm_items["audio"], MiniCPMOAudioEmbeddingItems): + single_audio_embeds = mm_items["audio"].get(item_idx) + audio_len = self.info.get_audio_len_by_num_chunks( + sum(chunk_embeds.shape[0] + for chunk_embeds in single_audio_embeds)) + return self.get_audio_prompt_texts(audio_len) + return self.get_audio_prompt_texts( + len(mm_items["audio"].get(item_idx))) + + return [ + PromptReplacement(modality=modality, + target=placeholder[modality], + replacement=partial(get_replacement_minicpmv, + modality=modality)) + for modality in ("image", "video", "audio") + ] + + def _get_mm_fields_config( + self, + hf_inputs, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + + def get_slices(num_slices: List[int]) -> List[int]: + slice_indices = [0] + list(accumulate(num_slices)) + slices = [(slice_indices[i], slice_indices[i + 1]) + for i in range(len(num_slices))] + return [slice(*slice_item) for slice_item in slices] + + audio_slices = get_slices( + hf_inputs.get("audio_num_slices", torch.empty(0))) + return dict( + **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs), + audio_features=MultiModalFieldConfig.flat("audio", audio_slices), + audio_feature_lens=MultiModalFieldConfig.flat( + "audio", audio_slices), + audio_num_slices=MultiModalFieldConfig.batched("audio"), + audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"), + audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices)) + + +class MultiModalProjector(nn.Module): + + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.linear1 = nn.Linear(in_features=in_dim, + out_features=out_dim, + bias=True) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(in_features=out_dim, + out_features=out_dim, + bias=True) + + def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.relu(self.linear1(audio_features)) + hidden_states = self.linear2(hidden_states) + return hidden_states + + +class MiniCPMWhisperEncoderLayer(nn.Module): + + def __init__(self, config: WhisperConfig, layer_idx: int = None): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = WHISPER_ATTENTION_CLASSES[ + config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + layer_idx=layer_idx, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = 
nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + past_key_values = None + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, past_key_values = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_values, + ) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, + p=self.activation_dropout, + training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + outputs = (hidden_states, ) + + return outputs + + +class MiniCPMWhisperEncoder(WhisperEncoder): + + def __init__(self, config: WhisperConfig): + super().__init__(config) + self.layers = nn.ModuleList([ + MiniCPMWhisperEncoderLayer(config, layer_idx=i) + for i in range(config.encoder_layers) + ]) + + def forward( + self, + input_features: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> BaseModelOutputWithPast: + # Ignore copy + input_features = input_features.to(dtype=self.conv1.weight.dtype, + device=self.conv1.weight.device) + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + + embed_pos = self.embed_positions.weight + + embed_pos = embed_pos[:inputs_embeds.shape[1], :] + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + + encoder_states = () + + for idx, encoder_layer in enumerate(self.layers): + encoder_states = encoder_states + (hidden_states, ) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + # Ignore copy + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.layer_norm(hidden_states) + encoder_states = encoder_states + (hidden_states, ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + MiniCPMOMultiModalProcessor, + info=MiniCPMOProcessingInfo, + dummy_inputs=MiniCPMODummyInputsBuilder) +class MiniCPMO(MiniCPMV2_6): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.apm = 
self.init_audio_module(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "apm")) + + def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Do not use parameters temporarily + audio_config = self.config.audio_config + model = MiniCPMWhisperEncoder(audio_config) + audio_output_dim = int(audio_config.encoder_ffn_dim // 4) + self.audio_avg_pooler = \ + nn.AvgPool1d(self.config.audio_pool_step, + stride=self.config.audio_pool_step) + self.audio_projection_layer = \ + MultiModalProjector(in_dim=audio_output_dim,out_dim=self.embed_dim) + self.audio_encoder_layer = -1 + return model + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["tts"]) + return loader.load_weights(weights) + + def subsequent_chunk_mask( + self, + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = CPU_DEVICE, + num_lookhead: int = 0, + ) -> torch.Tensor: + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, + 0) + ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, + size) + ret[i, start:ending] = True + return ret + + def _get_feat_extract_output_lengths(self, + input_lengths: torch.LongTensor): + input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 + input_lengths_after_pooling = ( + input_lengths_after_cnn - + self.config.audio_pool_step) // self.config.audio_pool_step + 1 + input_lengths_after_pooling = input_lengths_after_pooling.to( + dtype=torch.int32) + + return input_lengths_after_cnn, input_lengths_after_pooling + + # Copied from HF repo of MiniCPM-o-2_6, + # designed for batched inputs and outputs + def get_audio_hidden_states(self, data: MiniCPMOAudioInputs, + chunk_length: int) -> torch.Tensor: + wavforms = data.get( + "data", + []) # (bs, 80, frames) or [], multi audios need filled in advance + audio_feature_lens_raw = [data.get("audio_feature_lens", + [])] # list, [[x1, x2], [y1], [z1]] + + # exist audio + if len(wavforms) > 0: + audio_feature_lens = torch.hstack(audio_feature_lens_raw) + batch_size, _, max_mel_seq_len = wavforms.shape + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = (torch.arange( + 0, + max_seq_len, + dtype=audio_feature_lens.dtype, + device=audio_feature_lens.device).unsqueeze(0).expand( + batch_size, max_seq_len)) + lengths_expand = audio_feature_lens.unsqueeze(1).expand( + batch_size, max_seq_len) + # Create mask + padding_mask = seq_range >= lengths_expand # 1 for padded values + + audio_attention_mask_ = padding_mask.view( + batch_size, 1, 1, max_seq_len).expand(batch_size, 1, + max_seq_len, max_seq_len) + audio_attention_mask = audio_attention_mask_.to( + dtype=self.apm.conv1.weight.dtype, + device=self.apm.conv1.weight.device) + + if chunk_length > 0: + chunk_num_frame = int(chunk_length * 50) + chunk_mask = self.subsequent_chunk_mask( + size=max_seq_len, + chunk_size=chunk_num_frame, + num_left_chunks=-1, + device=audio_attention_mask_.device, + ) + audio_attention_mask_ = torch.logical_or( + audio_attention_mask_, torch.logical_not(chunk_mask)) + + audio_attention_mask[audio_attention_mask_] = float("-inf") + audio_states = self.apm( + wavforms, attention_mask=audio_attention_mask).hidden_states[ + self.audio_encoder_layer] + audio_embeds = self.audio_projection_layer(audio_states) + + audio_embeds = 
audio_embeds.transpose(1, 2) + audio_embeds = self.audio_avg_pooler(audio_embeds) + audio_embeds = audio_embeds.transpose(1, 2) + + _, feature_lens_after_pooling = \ + self._get_feat_extract_output_lengths(audio_feature_lens) + + num_audio_tokens = feature_lens_after_pooling + + final_audio_embeds = [] + idx = 0 + for i in range(len(audio_feature_lens_raw)): + target_audio_embeds = [] + for _ in range(len(audio_feature_lens_raw[i])): + target_audio_embeds.append( + audio_embeds[idx, :num_audio_tokens[idx], :]) + idx += 1 + final_audio_embeds.append(target_audio_embeds) + return final_audio_embeds + else: + return [] + + def get_embedding_with_audios(self, vlm_embedding: torch.Tensor, + audio_inputs: Optional[MiniCPMOAudioInputs], + chunk_length: int) -> torch.Tensor: + device, dtype = vlm_embedding.device, vlm_embedding.dtype + if audio_inputs["type"] == "audio_embeds": + audio_embeddings = audio_inputs["data"] + audio_embeddings = [ + audio_embeddings[i].to(device=device, dtype=dtype) + for i in range(len(audio_embeddings)) + ] + else: + audio_embeddings = self.get_audio_hidden_states( + audio_inputs, chunk_length)[0] + if audio_embeddings is None or len(audio_embeddings) == 0: + return vlm_embedding + audio_bounds = audio_inputs["audio_bounds"] + if self.config.chunk_input: + audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device, + dtype=dtype) + audio_start_pos = 0 + for bound in audio_bounds: + audio_len = bound[1] - bound[0] + vlm_embedding[bound[0]:bound[1]] = audio_embs[ + audio_start_pos:audio_start_pos + audio_len, :] + audio_start_pos += audio_len + else: + for embs, bound in zip(audio_embeddings, audio_bounds): + audio_indices = torch.arange(bound[0], + bound[1], + dtype=torch.long).to(device) + + if embs.shape[0] != len(audio_indices): + raise ValueError( + "Shape mismatch: Trying to assign embeddings " + f"of shape {embs.shape} " + f"to input indices of length {len(audio_indices)}") + vlm_embedding[audio_indices] = embs.to(dtype) + return vlm_embedding + + def _get_audio_bounds(self, input_ids: torch.Tensor, + audio_start_id: torch.Tensor, + audio_end_id: torch.Tensor) -> torch.Tensor: + audio_start_tokens, = torch.where(input_ids == audio_start_id[0]) + audio_start_tokens += 1 + audio_end_tokens, = torch.where(input_ids == audio_end_id[0]) + valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens)) + return torch.hstack([ + audio_start_tokens[:valid_audio_nums].unsqueeze(-1), + audio_end_tokens[:valid_audio_nums].unsqueeze(-1) + ]) + + def _parse_and_validate_audio_inputs( + self, input_ids: torch.Tensor, + **kwargs: object) -> Tuple[MiniCPMOAudioInputs]: + audio_features = kwargs.pop("audio_features", []) + audio_feature_lens = kwargs.pop("audio_feature_lens", []) + audio_embeds = kwargs.pop("audio_embeds", None) + audio_start_id = kwargs.pop("audio_start_id", None) + audio_end_id = kwargs.pop("audio_end_id", None) + if audio_embeds is not None: + audio_embeds = [ + audio_embeds[i][j] for i in range(len(audio_embeds)) + for j in range(len(audio_embeds[i])) + ] + return MiniCPMOAudioEmbeddingInputs( + audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, + audio_end_id), + data=audio_embeds, + type="audio_embeds") + if len(audio_features) > 0: + audio_features_all = [ + i.permute(1, 0) for audio_feature in audio_features + for i in audio_feature + ] + audio_features = torch.nn.utils.rnn.pad_sequence( + audio_features_all, batch_first=True, + padding_value=0.0).permute(0, 2, 1) + audio_feature_lens = torch.cat( + [item for item in 
audio_feature_lens]) + + return MiniCPMOAudioFeatureInputs( + audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, + audio_end_id), + data=audio_features, + audio_feature_lens=audio_feature_lens, + type="audio_features") + return None + + def _parse_and_validate_inputs(self, input_ids: torch.Tensor, + **kwargs: object): + image_inputs = self._parse_and_validate_image_inputs( + input_ids, **kwargs) + if not any("audio" in key for key in kwargs): + return image_inputs, None + audio_inputs = self._parse_and_validate_audio_inputs( + input_ids, **kwargs) + return image_inputs, audio_inputs + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: Any, + ) -> torch.Tensor: + if intermediate_tensors is not None: + vlm_embeddings = None + else: + image_inputs, audio_inputs = \ + self._parse_and_validate_inputs(input_ids, **kwargs) + vlm_embeddings, _ = self.get_embedding_with_vision( + input_ids, image_inputs) + + if audio_inputs is not None: + vlm_embeddings = self.get_embedding_with_audios( + vlm_embeddings, audio_inputs, + self.config.audio_chunk_length) + + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None + + output = self.llm.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=vlm_embeddings, + ) + return output diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 1aa529056893b..bf967d33a3176 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -22,21 +22,21 @@ """Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re +from collections import Counter from functools import cached_property, partial -from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, - Set, Tuple, TypedDict, Union) +from itertools import accumulate +from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, + Optional, Set, Tuple, TypedDict, Union) +import numpy as np import torch import torch.types from PIL import Image from torch import nn -from transformers import PretrainedConfig -from typing_extensions import NotRequired +from transformers import BatchFeature, PretrainedConfig from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, get_2d_sincos_pos_embed) @@ -48,33 +48,30 @@ from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SequenceData +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, PlaceholderRange) +from vllm.multimodal.parse import (ImageItem, ImageSize, ModalityData, + ModalityDataItems, MultiModalDataItems, + MultiModalDataParser, 
VideoItem) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import AutoWeightsLoader, maybe_prefix -RawImageType = Union[Image.Image, torch.Tensor] - - -class MiniCPMVRawImageInput(TypedDict): - """Input mapper input with auxiliary data for computing image bounds.""" - image: RawImageType +CPU_DEVICE = torch.device("cpu") - # Image bounds token ids in 0-dim scaler tensor. - im_start_id: torch.Tensor - im_end_id: torch.Tensor - slice_start_id: NotRequired[torch.Tensor] - slice_end_id: NotRequired[torch.Tensor] +RawImageType = Union[Image.Image, torch.Tensor] class MiniCPMVImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: List[torch.Tensor] """ - Shape: `(batch_size * num_images, num_channels, height, width)` + Shape: `(batch_size * num_images * num_slices, num_channels, height, width)` Note that the image size may vary, so we pass it as a list instead of a batched tensor. @@ -82,14 +79,14 @@ class MiniCPMVImagePixelInputs(TypedDict): image_bounds: torch.Tensor """ - Shape: `(batch_size * num_images, 2)` + Shape: `(batch_size * num_images * num_slices, 2)` This should be in `(start, stop)` format. """ tgt_sizes: torch.Tensor """ - Shape: `(batch_size * num_images, 2)` + Shape: `(batch_size * num_images * num_slices, 2)` This should be in `(height, width)` format. """ @@ -99,7 +96,8 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor """ - Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + Shape: `(batch_size * num_images * num_slices, + image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. instead of a batched tensor. @@ -107,7 +105,7 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): image_bounds: torch.Tensor """ - Shape: `(batch_size * num_images, 2)` + Shape: `(batch_size * num_images * num_slices, 2)` This should be in `(start, stop)` format. """ @@ -116,6 +114,93 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs] + +class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], + dict[str, torch.Tensor]]): + + def __init__(self, data: Dict, modality: str) -> None: + super().__init__(data, modality) + + def get_processor_data(self) -> Mapping[str, object]: + return self.data + + def get_passthrough_data(self) -> Mapping[str, object]: + return {} + + def get_count(self) -> int: + return len(self.data[f"{self.modality}_embeds"]) + + def get(self, index: int) -> Dict[str, torch.Tensor]: + out = {} + for k, v in self.data.items(): + out[k] = v[index] + return out + + +class MiniCPMVImageEmbeddingItems(MiniCPMVEmbeddingItems): + + def __init__(self, data: Dict) -> None: + super().__init__(data, "image") + image_embeds = self.data.get("image_embeds", None) + image_sizes = self.data.get("image_sizes", None) + if image_embeds is None: + raise ValueError("In correct type of image_embeds", + "Got type: None") + if not isinstance(image_embeds[0], torch.Tensor): + raise ValueError("In correct type of image_embeds", + f"Got type: {type(image_embeds[0])}") + if image_sizes is None: + raise ValueError( + "In correct type of image_sizes", "Got type: None." 
+ "If you're using `image_size_list`, " + "please rename it to `image_sizes`") + if len(image_embeds[0].shape) == 2: + image_embeds = [image_embeds] + image_sizes = [image_sizes] + self.data["image_embeds"] = image_embeds + self.data["image_sizes"] = image_sizes + + def get_image_size(self, index: int) -> ImageSize: + image_size = self.data["image_sizes"][index] + return ImageSize(width=image_size[0], height=image_size[1]) + + +class MiniCPMVVideoEmbeddingItems(MiniCPMVEmbeddingItems): + + def __init__(self, data: Dict) -> None: + super().__init__(data, "video") + video_embeds = self.data.get("video_embeds", None) + image_sizes = self.data.get("image_sizes", None) + num_frames = self.data.get("num_frames", None) + if video_embeds is None: + raise ValueError("In correct type of video_embeds", + "Got type: None") + if not isinstance(video_embeds[0], torch.Tensor): + raise ValueError("In correct type of video_embeds", + f"Got type: {type(video_embeds[0])}") + if image_sizes is None: + raise ValueError( + "In correct type of image_sizes", "Got type: None." + "If you're using `image_size_list`, " + "please rename it to `image_sizes`") + if num_frames is None: + raise ValueError("In correct type of numframes", "Got type: None") + if len(video_embeds[0].shape) == 2: + video_embeds = [video_embeds] + image_sizes = [image_sizes] + num_frames = [num_frames] + self.data["video_embeds"] = video_embeds + self.data["image_sizes"] = image_sizes + self.data["num_frames"] = num_frames + + def get_frame_size(self, index: int) -> ImageSize: + frame_size = self.data["image_sizes"][index] + return ImageSize(width=frame_size[0], height=frame_size[1]) + + def get_num_frames(self, index: int) -> int: + return self.data["num_frames"][index] + + DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) @@ -212,25 +297,6 @@ def forward(self, x: torch.Tensor, return x -def _build_image_input(ctx: InputContext, - image: RawImageType) -> MiniCPMVRawImageInput: - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code) - if hasattr(tokenizer, "slice_start_id"): - return MiniCPMVRawImageInput( - image=image, - im_start_id=torch.tensor(tokenizer.im_start_id), - im_end_id=torch.tensor(tokenizer.im_end_id), - slice_start_id=torch.tensor(tokenizer.slice_start_id), - slice_end_id=torch.tensor(tokenizer.slice_end_id)) - else: - return MiniCPMVRawImageInput( - image=image, - im_start_id=torch.tensor(tokenizer.im_start_id), - im_end_id=torch.tensor(tokenizer.im_end_id)) - - def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: version_float = getattr(config, "version", None) @@ -240,129 +306,512 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: if config.hidden_size == 2304 and config.query_num == 64: return (2, 0) return (2, 5) - version_str = str(version_float) return tuple(int(x) for x in version_str.split(".")) -def get_max_minicpmv_image_tokens(ctx: InputContext): - hf_config = ctx.get_hf_config() - return getattr(hf_config, "query_num", 64) +class MiniCPMVMultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return MiniCPMVImageEmbeddingItems(data) + return super()._parse_image_data(data) + + def _parse_video_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return 
MiniCPMVVideoEmbeddingItems(data) + return super()._parse_video_data(data) + + +class MiniCPMVProcessingInfo(BaseProcessingInfo): + image_pattern = "(./)" + video_pattern = "()" + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor( + self, + **kwargs: object, + ): + hf_processor = self.ctx.get_hf_processor() + return hf_processor + + def get_image_processor(self): + hf_processor = self.get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + return image_processor + + def get_model_version(self): + return get_version_by_config(self.get_hf_config()) + + def get_supported_mm_modalities(self) -> List[str]: + if self.get_model_version() == (2, 6): + return ["image", "video"] + else: + return ["image"] + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + if self.get_model_version() == (2, 6): + return {"image": None, "video": None} + else: + return {"image": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + mm_max_tokens = {"image": self.get_max_image_tokens()} + if self.get_model_version() == (2, 6): + mm_max_tokens["video"] = self.get_max_video_tokens(seq_len) + return mm_max_tokens + + def get_max_video_frame_tokens(self) -> int: + frame_size = self.get_video_frame_size_with_most_features() + return self.get_num_image_tokens(frame_size, + self.get_video_max_slice_num()) + + def get_max_video_tokens(self, seq_len: int) -> int: + return self.get_max_video_frame_tokens( + ) * self.get_num_frames_with_most_features(seq_len) + + def get_max_audio_tokens(self) -> int: + return self.get_max_audio_tokens_per_chunk( + ) * self.get_max_audio_chunks_with_most_features() + + def get_slice_query_num(self) -> int: + hf_config = self.get_hf_config() + query_num = getattr(hf_config, "query_num", 64) + return query_num + + def get_max_slice_num(self) -> int: + hf_config = self.get_hf_config() + max_slice_num = getattr(hf_config, "max_slice_num", 9) + return max_slice_num + + def get_sliced_grid(self, image_size: ImageSize, + max_slice_num: int) -> Tuple[int, int]: + if self.get_model_version() == (2, 6): + slice_grid = self.get_image_processor().get_sliced_grid( + image_size, max_slice_num) + else: + slice_grid = self.get_image_processor().get_sliced_grid(image_size) + return slice_grid + + def get_num_image_tokens(self, image_size: ImageSize, + max_slice_num: int) -> int: + slice_grid = self.get_sliced_grid(image_size, max_slice_num) + num_tokens = self.get_slice_query_num( + ) + 2 # ( * query_num) + if slice_grid is not None: + if self.get_model_version() == (2, 6): + num_additional_tokens = 0 + else: + # ( * query_num) + num_additional_tokens = 2 + num_tokens += ((self.get_slice_query_num() + 2) \ + * slice_grid[0] * slice_grid[1]) \ + + slice_grid[1] - 1 + num_additional_tokens + return num_tokens + def get_image_slice_nums(self, image_size: torch.Tensor, + max_slice_nums: int) -> int: + grid = self.get_sliced_grid(image_size, max_slice_nums) + return 1 if grid is None else grid[0] * grid[1] + 1 -def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): - return SequenceData.from_prompt_token_counts((0, seq_len)) + def get_max_image_tokens(self) -> int: + image_size = self.get_image_size_with_most_features() + return self.get_num_image_tokens(image_size, self.get_max_slice_num()) + def get_image_size_with_most_features(self) -> ImageSize: + # Result in the max possible feature size (h:w = 9:1) + return self.get_default_image_sizes(self.get_max_slice_num()) -def 
dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, - num_images: int): - width = height = hf_config.image_size - image = _build_image_input(ctx, - image=Image.new("RGB", (width, height), - color=0)) - return {"image": [image] if num_images == 1 else [image] * num_images} + def get_video_max_slice_num(self) -> int: + return 1 + def get_video_frame_size_with_most_features(self) -> ImageSize: + return self.get_default_image_sizes(self.get_video_max_slice_num()) -def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config() - num_images = mm_counts["image"] + def get_max_video_frames(self, max_tokens: int) -> int: + num_frame_tokens = self.get_max_video_frame_tokens() + num_frames = max_tokens // num_frame_tokens + return num_frames - seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) - mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) + def get_num_frames_with_most_features(self, seq_len: int) -> int: + mm_config = self.ctx.get_mm_config() + max_images = mm_config.limit_per_prompt.get("image", 1) + max_videos = mm_config.limit_per_prompt.get("video", 1) - return DummyData(seq_data, mm_data) + # count tokens + # which are not in get_max_image_tokens + max_image_tokens = self.get_max_image_tokens( + ) * max_images + 4 * max_images + max_total_frames = self.get_max_video_frames(seq_len - + max_image_tokens) + num_frames = max(max_total_frames // max(max_videos, 1), 1) -def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - model_config = ctx.model_config - version = get_version_by_config(model_config.hf_config) - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - image_processor = cached_get_image_processor(model_config.tokenizer) + return num_frames - def get_placeholder(image_size: Tuple[int, int], num_image: int): + def get_default_image_sizes(self, num_slices: int) -> ImageSize: + image_size = getattr(self.get_hf_config(), "image_size", 448) + return ImageSize(width=image_size, height=image_size * num_slices) + + +class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo] + ): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + image_width, image_height = \ + self.info.get_image_size_with_most_features() + video_width, video_height = \ + self.info.get_video_frame_size_with_most_features() + num_video_frames = \ + self.info.get_num_frames_with_most_features(seq_len) + + mm_data = { + "image": + self._get_dummy_images(width=image_width, + height=image_height, + num_images=num_images), + "video": [ + self._get_dummy_images(width=video_width, + height=video_height, + num_images=num_video_frames) + ] * num_videos, + } + + image_prompt_texts = self.info.image_pattern * num_images + video_prompt_texts = self.info.video_pattern * num_videos + + return ProcessorInputs(prompt_text=image_prompt_texts + + video_prompt_texts, + mm_data=mm_data) + + +class MiniCPMVMultiModalProcessor( + BaseMultiModalProcessor[MiniCPMVProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + return MiniCPMVMultiModalDataParser() + + def get_slice_image_placeholder(self, image_size: ImageSize, + **kwargs) -> 
str: + image_processor = self.info.get_image_processor() + version = self.info.get_model_version() if version == (2, 0) or version == (2, 5): return image_processor.get_slice_image_placeholder(image_size) return image_processor.get_slice_image_placeholder( - image_size, num_image) - - prompt = inputs.get("prompt") - token_ids = inputs.get("prompt_token_ids") - if prompt is None: - prompt = tokenizer.decode(token_ids) - - pattern = "(./)" - images = multi_modal_data["image"] - image_tags = re.findall(pattern, prompt) - if len(image_tags) == 0: - new_token_ids = token_ids - new_prompt = prompt - else: - if isinstance(images, dict): - image_size_list = images.get("image_size_list") - images = [images.get("image_embeds")] + image_size, **kwargs) + + def get_image_prompt_texts(self, + image_size: ImageSize, + image_idx: int = 0) -> str: + prompt_texts = self.get_slice_image_placeholder(image_size, + image_idx=image_idx) + return prompt_texts + + def get_video_prompt_texts(self, image_size: ImageSize, + num_frames: int) -> str: + prompt_texts = "".join( + self.get_slice_image_placeholder( + image_size=image_size, + image_idx=0, + max_slice_nums=self.info.get_video_max_slice_num(), + use_image_id=False) for image_idx in range(num_frames)) + return prompt_texts + + def get_special_tokens(self) -> Dict[str, torch.Tensor]: + tokenizer = self.info.get_tokenizer() + special_tokens = { + "im_start_id": torch.tensor(tokenizer.im_start_id), + "im_end_id": torch.tensor(tokenizer.im_end_id) + } + if hasattr(tokenizer, "slice_start_id"): + special_tokens["slice_start_id"] = torch.tensor( + tokenizer.slice_start_id) + special_tokens["slice_end_id"] = torch.tensor( + tokenizer.slice_end_id) + return special_tokens + + @staticmethod + def repack_processor_outputs(outputs: Any) -> BatchFeature: + valid_keys = ["pixel_values", "image_sizes", "tgt_sizes"] + outputs = {key: outputs[key][0] for key in valid_keys} + return outputs + + def process_images(self, mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object]) -> Dict[str, object]: + images = mm_data.pop("images", []) + image_embeds = mm_data.pop("image_embeds", []) + if isinstance(images, Image.Image): + images = [images] + if isinstance(images, (list, torch.Tensor)) and len(images) > 0: + image_outputs = super()._call_hf_processor( + prompt=self.info.image_pattern * len(images), + mm_data={"images": images}, + mm_kwargs=mm_kwargs) + image_outputs = MiniCPMVMultiModalProcessor.\ + repack_processor_outputs(image_outputs) + elif len(image_embeds) > 0: + image_sizes = mm_data.pop("image_sizes", None) + image_outputs = { + "image_embeds": torch.cat(image_embeds), + "image_sizes": image_sizes + } else: - if isinstance(images, Image.Image): - images = [images] - image_size_list = [image.size for image in images] - - text_chunks = prompt.split(pattern) - new_prompt_chunks: List[str] = [] - for i in range(len(image_size_list)): - new_prompt_chunks += [ - text_chunks[i], - get_placeholder(image_size_list[i], i) - ] - new_prompt_chunks.append(text_chunks[-1]) - new_prompt = "".join(new_prompt_chunks) - new_token_ids = tokenizer.encode(new_prompt) - - multi_modal_data["image"] = [ - _build_image_input(ctx, image) for image in images - ] + image_outputs = {} + return image_outputs + + def process_videos(self, mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object]) -> Dict[str, object]: + videos = mm_data.pop("videos", []) + video_embeds = mm_data.pop("video_embeds", []) + if len(videos) > 0 and isinstance(videos[0], Image.Image): + videos = [videos] + 
if isinstance(videos, list) and len(videos) > 0: + video_outputs = { + "video_pixel_values": [], + "video_image_sizes": [], + "video_tgt_sizes": [], + "num_frames": [] + } + for video in videos: + parsed_video = [] + for frame in video: + if isinstance(frame, np.ndarray): + parsed_video.append(Image.fromarray(frame)) + else: + parsed_video.append(frame) + video = parsed_video + single_video_outputs = super()._call_hf_processor( + prompt=self.info.image_pattern * len(video), + mm_data={"images": video}, + mm_kwargs={ + **mm_kwargs, "max_slice_nums": + self.info.get_video_max_slice_num() + }) + video_outputs["num_frames"].append(len(video)) + for key in single_video_outputs: + if "video_" + key in video_outputs: + if key == "image_sizes": + video_outputs["video_" + key].append( + single_video_outputs[key][0][0]) + else: + video_outputs["video_" + + key] += single_video_outputs[key][0] + elif len(video_embeds): + image_sizes = mm_data.pop("image_sizes", None) + num_frames = mm_data.pop("num_frames", None) + video_outputs = { + "video_embeds": torch.cat(video_embeds), + "video_image_sizes": image_sizes, + "num_frames": num_frames + } + else: + video_outputs = {} + return video_outputs - return token_inputs( - prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - ) + def get_placeholder_match_pattern(self) -> str: + return r"\(<(image|video)>./\)" + def get_placeholder_split_pattern(self) -> str: + return r"\(<(?:image|video)>./\)" -def input_mapper_for_minicpmv(ctx: InputContext, data: object): - model_config = ctx.model_config + def process_mm_inputs(self, mm_data, mm_kwargs) -> object: + return { + "image": self.process_images(mm_data, mm_kwargs), + "video": self.process_videos(mm_data, mm_kwargs) + } - image_processor = cached_get_image_processor( - model_config.model, trust_remote_code=model_config.trust_remote_code) - if image_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") + def get_input_modalities(self, mm_data) -> List[str]: + supported_mm_modalities = self.info.get_supported_mm_modalities() + input_modalities = [] + for modality in supported_mm_modalities: + if modality in mm_data and mm_data[modality] != {}: + input_modalities.append(modality) + return input_modalities + + def get_modality_num_counter(self, modality: str) -> str: + if modality == "image": + return "image_sizes" + elif modality == "video": + return "video_image_sizes" + + def get_num_slices_by_modality(self, inputs: Dict[str, object], + modality: str, index: int) -> int: + if modality == "image": + return self.info.get_image_slice_nums( + inputs[modality]["image_sizes"][index], + self.info.get_max_slice_num()) + elif modality == "video": + return self.info.get_image_slice_nums( + inputs[modality]["video_image_sizes"][index], + self.info.get_video_max_slice_num() + ) * inputs[modality]["num_frames"][index] + else: + raise ValueError(f"UnExpected modality: {modality}") + + def check_mm_inputs(self, inputs: Dict[str, object], + matches: List[str]) -> None: + counts = Counter(matches) + for modality, count in counts.items(): + if modality not in inputs or not inputs[modality]: + raise ValueError(f"None input data of {modality}." 
+ "But prompt requires.") + counter_key = self.get_modality_num_counter(modality) + if len(inputs[modality][counter_key]) != count: + raise ValueError(f"The prompt requires {count} " + f"{modality} inputs while you pass " + f"{len(inputs[modality][counter_key])}") + + def get_prompt_texts_by_modality(self, inputs: Dict[str, object], + modality: str, index: int) -> str: + if modality == "image": + return self.get_image_prompt_texts( + inputs["image"]["image_sizes"][index], index) + elif modality == "video": + return self.get_video_prompt_texts( + inputs["video"]["video_image_sizes"][index], + inputs["video"]["num_frames"][index]) + else: + raise ValueError(f"UnExpected modality: {modality}") - if not isinstance(data, list): - raise ValueError( - "Image input must be list of MiniCPMVImageInput, got (%s)", data) + def call_base_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + return super()._call_hf_processor(prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + # Do not support combination inputs of images and videos for now + # Try to handle interleaved multimodal data + tokenizer = self.info.get_tokenizer() + inputs = self.process_mm_inputs(mm_data, mm_kwargs) + mm_input_modalities = self.get_input_modalities(inputs) + num_mm_slices = {modality: [] for modality in mm_input_modalities} + for modality in mm_input_modalities: + num_counter_key = self.get_modality_num_counter(modality) + for index in range(len(inputs[modality][num_counter_key])): + num_mm_slices[modality].append( + self.get_num_slices_by_modality(inputs, modality, index)) + return { + "input_ids": np.array([tokenizer.encode(prompt)]), + **{ + key: value + for modality in inputs + for key, value in inputs[modality].items() + }, + **{ + f"{modality}_num_slices": num_mm_slices[modality] + for modality in mm_input_modalities + } + } - if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor): - batch_data = { - "image_embeds": data[0]['image'], + def _get_prompt_replacements( + self, mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]: + placeholder = { + "image": self.info.image_pattern, + "video": self.info.video_pattern, } - else: - batch_data = image_processor \ - .preprocess([img["image"] for img in data], return_tensors="pt") \ - .data - if len(data) > 0: - batch_data["im_start_id"] = data[0]["im_start_id"] - batch_data["im_end_id"] = data[0]["im_end_id"] - if "slice_start_id" in data[0]: - batch_data["slice_start_id"] = data[0]["slice_start_id"] - batch_data["slice_end_id"] = data[0]["slice_end_id"] + def get_replacement_minicpmv(item_idx: int, modality: str): + if modality == "image": + return self.get_image_prompt_texts( + mm_items["image"].get_image_size(item_idx), item_idx) + else: # video + return self.get_video_prompt_texts( + mm_items["video"].get_frame_size(item_idx), + mm_items["video"].get_num_frames(item_idx)) + + return [ + PromptReplacement(modality=modality, + target=placeholder[modality], + replacement=partial(get_replacement_minicpmv, + modality=modality)) + for modality in ("image", "video") + ] - return MultiModalKwargs(batch_data) + def _get_mm_fields_config( + self, + hf_inputs, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + + def 
get_slices(num_slices: List[int]) -> List[int]: + slice_indices = [0] + list(accumulate(num_slices)) + slices = [(slice_indices[i], slice_indices[i + 1]) + for i in range(len(num_slices))] + return [slice(*slice_item) for slice_item in slices] + + image_slices = get_slices( + hf_inputs.get("image_num_slices", torch.empty(0))) + video_slices = get_slices( + hf_inputs.get("video_num_slices", torch.empty(0))) + + return dict( + pixel_values=MultiModalFieldConfig.flat("image", image_slices), + image_sizes=MultiModalFieldConfig.batched("image"), + tgt_sizes=MultiModalFieldConfig.flat("image", image_slices), + image_num_slices=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.flat("image", image_slices), + video_pixel_values=MultiModalFieldConfig.flat( + "video", video_slices), + video_image_sizes=MultiModalFieldConfig.batched("video"), + video_tgt_sizes=MultiModalFieldConfig.flat("video", video_slices), + video_embeds=MultiModalFieldConfig.flat("video", video_slices), + video_num_slices=MultiModalFieldConfig.batched("video")) + + def apply( + self, + prompt: Union[str, List[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputs: + supported_mm_modalities = self.info.get_supported_mm_modalities() + if isinstance(prompt, list): + prompt = self.info.get_tokenizer().decode(prompt) + matches = re.findall(self.get_placeholder_match_pattern(), prompt) + mm_orders = { + f"{modality}_orders": + torch.tensor( + [index for index, m in enumerate(matches) if m == modality]) + for modality in supported_mm_modalities + } + result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) + # Exclude x from placeholders + if "image" in result["mm_placeholders"] and \ + self.info.get_model_version() == (2, 6): + result["mm_placeholders"]["image"] = [ + PlaceholderRange(offset=p["offset"] + 3 + idx // 10, + length=p["length"] - 3 - idx // 10) + for idx, p in enumerate(result["mm_placeholders"]["image"]) + ] + result["mm_kwargs"].update(**mm_orders) + result["mm_kwargs"].update(**self.get_special_tokens()) + return result class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): @@ -409,7 +858,7 @@ def sampler(self): return get_sampler() - def get_embedding( + def get_embedding_with_vision( self, input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], @@ -471,25 +920,46 @@ def _get_image_bounds( image_end_tokens[:valid_image_nums].unsqueeze(-1), ]) - def _parse_and_validate_inputs( + def _parse_and_validate_image_inputs( self, input_ids: torch.Tensor, **kwargs: object, ) -> Optional[MiniCPMVImageInputs]: - pixel_values = kwargs.pop("pixel_values", []) - tgt_sizes = kwargs.pop("tgt_sizes", []) + mm_data = { + "image": { + key: kwargs.pop(key, []) + for key in ["pixel_values", "tgt_sizes", "image_num_slices"] + }, + "video": { + "pixel_values": kwargs.pop("video_pixel_values", []), + "tgt_sizes": kwargs.pop("video_tgt_sizes", []), + "video_num_slices": kwargs.pop("video_num_slices", []) + } + } im_start_id = kwargs.pop("im_start_id", None) im_end_id = kwargs.pop("im_end_id", None) slice_start_id = kwargs.pop("slice_start_id", None) slice_end_id = kwargs.pop("slice_end_id", None) + mm_orders = { + f"{modality}": kwargs.pop(f"{modality}_orders", None) + for modality in ["image", "video", "audio"] + } + batch_size = max(len(mm_data["image"]["pixel_values"]), + len(mm_data["video"]["pixel_values"])) image_embeds = kwargs.pop("image_embeds", None) - + video_embeds = kwargs.pop("video_embeds", None) + if image_embeds 
is not None and video_embeds is not None: + raise ValueError( + "Incorrect inputs for vision embeddings. " + "Image embeds and video embeds can not exist simultaneously.") + if video_embeds is not None: + image_embeds = video_embeds if image_embeds is not None: if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError(f"Incorrect type of image embeds. " f"Got type: {type(image_embeds)}") - if isinstance(image_embeds, list): - image_embeds = torch.concat(image_embeds) + image_embeds = torch.concat( + [image_embeds[i] for i in range(len(image_embeds))]) return MiniCPMVImageEmbeddingInputs( image_bounds=self._get_image_bounds(input_ids, im_start_id, @@ -498,29 +968,47 @@ def _parse_and_validate_inputs( data=image_embeds, type="image_embeds", ) - - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(tgt_sizes, (torch.Tensor, list)): - raise ValueError("Incorrect type of target sizes. " - f"Got type: {type(tgt_sizes)}") - - if len(pixel_values) != len(tgt_sizes): - raise ValueError("Inconsistent batch lengths, found: " - f"{len(pixel_values)} vs. {len(tgt_sizes)}") + for modality, modality_mm_data in mm_data.items(): + if not isinstance(modality_mm_data["pixel_values"], + (torch.Tensor, list)): + raise ValueError( + "Incorrect type of pixel values. " + f"Got type: {type(modality_mm_data['pixel_values'])}") + + if not isinstance(modality_mm_data["tgt_sizes"], + (torch.Tensor, list)): + raise ValueError( + "Incorrect type of target sizes. " + f"Got type: {type(modality_mm_data['tgt_sizes'])}") + + if len(modality_mm_data["pixel_values"]) != len( + modality_mm_data["tgt_sizes"]): + raise ValueError( + "Inconsistent batch lengths, found: " + f"{len(modality_mm_data['pixel_values'])} vs. 
" + f"{len(modality_mm_data['tgt_sizes'])}") pixel_values_flat: List[torch.Tensor] = [] tgt_sizes_flat: List[torch.Tensor] = [] - for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): - if len(pixel_b) != len(tgt_b): - raise ValueError("Inconsistent N lengths, found: " - f"{len(pixel_b)} vs {len(tgt_b)}") - - for pixel_n, tgt_n in zip(pixel_b, tgt_b): - pixel_values_flat += pixel_n - tgt_sizes_flat += tgt_n + for b in range(batch_size): + mm_counts = {"image": 0, "video": 0} if self.version == (2, 6) \ + else {"image": 0} + mm_slice_counts = {"image": 0, "video": 0} \ + if self.version == (2, 6) else {"image": 0} + mm_orders_b = [(index, modality) for modality in mm_counts + for index in mm_orders[modality][b]] + for _, modality in sorted(mm_orders_b, key=lambda x: x[0]): + pos = mm_counts[modality] + num_slices = mm_data[modality][f"{modality}_num_slices"][b][ + pos] + slice_start_idx = mm_slice_counts[modality] + slice_end_idx = slice_start_idx + num_slices + pixel_values_flat += mm_data[modality]["pixel_values"][b][ + slice_start_idx:slice_end_idx] + tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][ + slice_start_idx:slice_end_idx] + mm_counts[modality] += 1 + mm_slice_counts[modality] += num_slices # NOTE: Input IDs does not contain image tokens during memory profiling, # so we allow it to be empty @@ -544,6 +1032,10 @@ def _parse_and_validate_inputs( type="pixel_values", ) + def _parse_and_validate_inputs(self, input_ids: torch.Tensor, + **kwargs: object): + return self._parse_and_validate_image_inputs(input_ids, **kwargs) + def forward( self, input_ids: torch.Tensor, @@ -556,9 +1048,10 @@ def forward( if intermediate_tensors is not None: vlm_embeddings = None else: - image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) - - vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + image_inputs = \ + self._parse_and_validate_inputs(input_ids, **kwargs) + vlm_embeddings, _ = self.get_embedding_with_vision( + input_ids, image_inputs) # always pass the input via `inputs_embeds` # to make sure the computation graph is consistent @@ -964,15 +1457,15 @@ def get_vision_hidden_states(self, _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, (2, 5): MiniCPMV2_5, - (2, 6): MiniCPMV2_6 + (2, 6): MiniCPMV2_6, } -@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) -@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) -class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): +@MULTIMODAL_REGISTRY.register_processor( + MiniCPMVMultiModalProcessor, + info=MiniCPMVProcessingInfo, + dummy_inputs=MiniCPMVDummyInputsBuilder) +class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): """ Different versions of MiniCPMV use different visual encoders and LLMs, which is not conducive to the current integration logic of LoRA and diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8d71b19060bf4..de05bf2b772f5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -162,6 +162,7 @@ "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 + "MiniCPMO": 
("minicpmo", "MiniCPMO"), "MiniCPMV": ("minicpmv", "MiniCPMV"), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"),