
[Model] Add support for the multi-modal Llama 3.2 model #8811

Merged 82 commits on Sep 25, 2024

Commits (82)
566d57f
add llamav tokenizer and redirect loader to it
heheda12345 Aug 30, 2024
218145a
start to load shape
heheda12345 Sep 2, 2024
1c57f26
copy original model
heheda12345 Sep 2, 2024
5233e2d
add LlamaVLConfig
heheda12345 Sep 2, 2024
72b9a8a
can load weight, attention is ignored
heheda12345 Sep 2, 2024
2dd36f5
skip profile run by hardcode, can start model execution
heheda12345 Sep 2, 2024
ba9507d
Merge branch 'main' of github.com:vllm-project/vllm
heheda12345 Sep 2, 2024
affa9ba
can run text tokenizer now
heheda12345 Sep 3, 2024
f633de5
finish image preprocessor
heheda12345 Sep 3, 2024
de8bbad
can run vision encoder now
heheda12345 Sep 4, 2024
30239ad
run prefill self attention
heheda12345 Sep 5, 2024
6972cbf
run prefill crossattention
heheda12345 Sep 6, 2024
4e1344b
can generate the first token :)
heheda12345 Sep 7, 2024
f3d869d
can perform offline e2e run without decode crossattn, but wrong answer
heheda12345 Sep 7, 2024
6f26a3b
pass mm data in encoder-decoder
heheda12345 Sep 8, 2024
fa0912e
prefill result matches now. Model is speaking human words.
heheda12345 Sep 11, 2024
46634ff
generate correct result for single image
heheda12345 Sep 12, 2024
6b73f4d
can support arbitrary number of images, need better mask for image_cnt<>1
heheda12345 Sep 12, 2024
fb10a70
temp save for profile run
heheda12345 Sep 12, 2024
718f879
can run tp, but wrong answer
heheda12345 Sep 13, 2024
2644349
can run tp for small model with correct result
heheda12345 Sep 13, 2024
ec4cb9c
tp for vision encoder
heheda12345 Sep 14, 2024
fc01266
update image preprocessor
heheda12345 Sep 15, 2024
3e1d249
support text-only input
heheda12345 Sep 15, 2024
c5ba3cf
Merge tag 'v0.6.1.post2' into llamavl
heheda12345 Sep 15, 2024
cac19d5
enable profile run
heheda12345 Sep 16, 2024
7e5eadd
copy mllama from transformers
heheda12345 Sep 17, 2024
7e3fb1e
can init model from vllm
heheda12345 Sep 17, 2024
49b05d6
weight loader
heheda12345 Sep 17, 2024
2e66a5d
run image encoder now
heheda12345 Sep 18, 2024
9770d84
Add API Server Support
simon-mo Sep 18, 2024
c9d612b
run single image requests correctly
heheda12345 Sep 19, 2024
2f54ae3
single image match huggingface result
heheda12345 Sep 19, 2024
9e2d4ea
Merge remote-tracking branch 'origin/meta-ckpt-early-api-server' into…
heheda12345 Sep 19, 2024
8f3989e
small fix
heheda12345 Sep 19, 2024
01621a5
remove old code
heheda12345 Sep 19, 2024
65a470b
hardcode some config to read huggingface's config.json without modify…
heheda12345 Sep 19, 2024
2146716
move prompt to encoder prompt
heheda12345 Sep 19, 2024
062534b
hardcode to match tokenizer result
heheda12345 Sep 19, 2024
23f04b4
update test script
heheda12345 Sep 20, 2024
4ed4e6e
update test script
heheda12345 Sep 20, 2024
c140258
support text-only input
heheda12345 Sep 21, 2024
f662fdd
fix bug in text only prompt
heheda12345 Sep 21, 2024
6cf166a
add unit test
heheda12345 Sep 21, 2024
b7124e5
add complex tests, but cannot run single-gpu and multi-gpu at the sam…
heheda12345 Sep 21, 2024
e69f127
separate encoder/decoder dummy input, support max_image=1
heheda12345 Sep 21, 2024
e0e297c
add mllamaconfig to override some params, simplifying the model code (WIP)
heheda12345 Sep 22, 2024
f6732cf
upd
heheda12345 Sep 22, 2024
228b66b
code cleanup
heheda12345 Sep 22, 2024
f30319c
remove image processing from input processor
heheda12345 Sep 22, 2024
471e79f
fix precision issue of RMSNorm
heheda12345 Sep 22, 2024
2a0cb7e
only keep useful vision encoder layer
heheda12345 Sep 22, 2024
f4a7e1e
Merge remote-tracking branch 'public/main' into llamavl
heheda12345 Sep 22, 2024
efbd9b8
merge main
heheda12345 Sep 22, 2024
a596997
format code
heheda12345 Sep 23, 2024
70b6bb3
try formatter again
heheda12345 Sep 23, 2024
31000d0
try formatter again
heheda12345 Sep 23, 2024
5be8a65
try formatter again again again
heheda12345 Sep 23, 2024
8505a8f
try formatter again again again again
heheda12345 Sep 23, 2024
a32c3ab
update example
heheda12345 Sep 23, 2024
10d1736
fix bug in openai api -> chat template
heheda12345 Sep 23, 2024
0aa61b0
change model based on new hf
heheda12345 Sep 23, 2024
b993988
make formatter happy
heheda12345 Sep 23, 2024
9065770
update model name in example
heheda12345 Sep 23, 2024
bc34aa4
remove mllama chat template, use HF's instead
heheda12345 Sep 23, 2024
a25e383
[Bugfix] Include encoder_prompt_tokens in num_prompt_tokens in UsageInfo
CatherineSue Sep 23, 2024
9b931bf
Merge pull request #6 from vllm-project/chang/num_prompt_tokens
heheda12345 Sep 24, 2024
1eefdc7
update config based on HF update
heheda12345 Sep 24, 2024
ccebf14
Merge branch 'main' of github.com:vllm-project/vllm
heheda12345 Sep 25, 2024
d7750d3
update doc and hf model id
heheda12345 Sep 25, 2024
1ebd6dc
update hf model id again
heheda12345 Sep 25, 2024
3b6fb2b
Merge branch 'main' of github.com:vllm-project/vllm
heheda12345 Sep 25, 2024
c857735
fix format problem
heheda12345 Sep 25, 2024
e4bf803
Apply suggestions from code review
heheda12345 Sep 25, 2024
4d7fe0a
Update vllm/worker/enc_dec_model_runner.py
heheda12345 Sep 25, 2024
4cdc6b5
Update vllm/worker/worker.py
heheda12345 Sep 25, 2024
a6ad79f
Update vllm/worker/worker.py
heheda12345 Sep 25, 2024
8364093
upgrade huggingface
heheda12345 Sep 25, 2024
a12c8d3
Update vllm/transformers_utils/configs/__init__.py
heheda12345 Sep 25, 2024
4065047
update code based on code review
heheda12345 Sep 25, 2024
293f07f
add note
ywang96 Sep 25, 2024
3db294b
format
ywang96 Sep 25, 2024
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
@@ -254,6 +254,11 @@ Multimodal Language Models
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`MllamaForConditionalGeneration`
- Llama 3.2
- Image
- :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`
24 changes: 24 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -242,6 +242,29 @@ def run_qwen2_vl(question, modality):
return llm, prompt, stop_token_ids


# Llama 3.2
def run_mllama(question, modality):
assert modality == "image"

model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Note: The default setting of max_num_seqs (256) and
# max_model_len (131072) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.

# The configuration below has been confirmed to launch on a
# single H100 GPU.
llm = LLM(
model=model_name,
max_num_seqs=16,
enforce_eager=True,
)

prompt = f"<|image|><|begin_of_text|>{question}"
stop_token_ids = None
return llm, prompt, stop_token_ids


model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -256,6 +279,7 @@ def run_qwen2_vl(question, modality):
"internvl_chat": run_internvl,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"mllama": run_mllama,
}


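For orientation, here is a minimal sketch of how the objects returned by run_mllama above are typically consumed for offline inference. The image path, question, and sampling settings are assumptions for illustration and are not part of this diff.

from PIL import Image

from vllm import SamplingParams

# Build the engine and prompt via the example helper added above.
llm, prompt, stop_token_ids = run_mllama("What is in this image?", "image")

# Any local image works here; "example.jpg" is a placeholder.
image = Image.open("example.jpg")

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
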
4 changes: 2 additions & 2 deletions examples/openai_vision_api_client.py
@@ -38,7 +38,7 @@
"content": [
{
"type": "text",
"text": "Whats in this image?"
"text": "What's in this image?"

Reviewer note (Collaborator): the apostrophe is changed from a non-ASCII Unicode character to the plain ASCII ' to avoid encoding errors.

},
{
"type": "image_url",
@@ -75,7 +75,7 @@ def encode_image_base64_from_url(image_url: str) -> str:
"content": [
{
"type": "text",
"text": "Whats in this image?"
"text": "What's in this image?"
},
{
"type": "image_url",
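To illustrate the reviewer note above: if the original text contained a non-ASCII apostrophe (assumed here to be U+2019, right single quotation mark), an ASCII-only encoder fails, while the plain ASCII apostrophe is safe. A minimal sketch:

# Minimal sketch of the encoding pitfall; the exact offending character in the
# original file is an assumption (U+2019).
curly = "What\u2019s in this image?"   # non-ASCII apostrophe
plain = "What's in this image?"        # ASCII apostrophe

plain.encode("ascii")                  # succeeds
try:
    curly.encode("ascii")
except UnicodeEncodeError as exc:
    print("encoding error:", exc)
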
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -4,7 +4,7 @@ numpy < 2.0.0
requests
tqdm
py-cpuinfo
- transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfix.
+ transformers >= 4.45.0 # Required for Llama 3.2.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi < 0.113.0; python_version < '3.9'
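As a sanity check against the bumped requirement, a runtime guard could look like the sketch below; this is illustrative only, not something the PR adds.

# Illustrative guard for the new minimum transformers version (>= 4.45.0)
# required for Llama 3.2 / mllama support.
from importlib.metadata import version

from packaging.version import Version

if Version(version("transformers")) < Version("4.45.0"):
    raise RuntimeError(
        "mllama (Llama 3.2) support requires transformers >= 4.45.0; "
        f"found {version('transformers')}")
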
Empty file.
283 changes: 283 additions & 0 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -0,0 +1,283 @@
from typing import List, Optional, Tuple, Type, overload

import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....utils import multi_gpu_test
from ...utils import check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 1

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is",
"cherry_blossom":
"<|image|><|begin_of_text|>The city is",
})

text_only_prompts = [
"The color of the sky is blue but sometimes it can also be",
]

models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct",
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]

assert output_str[0] == " "
hf_output_str = output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

return hf_output_ids, hf_output_str, out_logprobs


@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...


@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...


def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
images = [asset.pil_image for asset in image_assets]

if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[
prompt if size is not None else text_only_prompts[0]
for size in sizes
],
[
image.resize(size) if size is not None else None
for size in sizes
],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
if len(sizes) == 0:
inputs_per_image.append(
(text_only_prompts, [None] * len(text_only_prompts)))
else:
raise ValueError("You must provide either `size_factors` or `sizes`")

_run_test(hf_runner,
vllm_runner,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend)


def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.

All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_num_seqs=16,
max_model_len=4096,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]

def process(hf_inputs: BatchEncoding):
return hf_inputs

from transformers import AutoConfig
from transformers.models.mllama import MllamaConfig as MllamaConfigHf

# use transformer's MllamaConfig for hf_runner
# and vllm's MllamaConfig for vllm_runner
AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]

from vllm.transformers_utils.configs.mllama import MllamaConfig
AutoConfig.register("mllama", MllamaConfig, exist_ok=True)
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[
# Text only
[],
# Single-size
[(512, 512)],
# Single-size, batched
[(512, 512), (512, 512), (512, 512)],
# Multi-size, batched
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028)],
# Multi-size, batched, including text only
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
# mllama has 8 possible aspect ratios, carefully set the sizes
# to cover all of them
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes,
dtype, max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=2,
)
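
For reference, the new tests can be driven programmatically as sketched below; the keyword filter is only an example and assumes the repository's default pytest configuration.

# Run the single-GPU mllama tests; deselect the 2-GPU distributed test.
import pytest

pytest.main([
    "tests/models/encoder_decoder/vision_language/test_mllama.py",
    "-k", "not distributed",
    "-x",
])
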
4 changes: 3 additions & 1 deletion vllm/config.py
@@ -576,7 +576,9 @@ def get_multimodal_config(self) -> "MultiModalConfig":
@property
def is_encoder_decoder_model(self) -> bool:
"""Extract the HF encoder/decoder model flag."""
- return getattr(self.hf_config, "is_encoder_decoder", False)
+ return getattr(self.hf_config, "is_encoder_decoder", False) or (
+     (hasattr(self.hf_config, "text_config") and getattr(
+         self.hf_config.text_config, "is_encoder_decoder", False)))

@property
def is_embedding_model(self) -> bool:
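A minimal sketch of the flag resolution this config.py change implements, using stand-in objects instead of real HF config classes (an assumption for illustration):

from types import SimpleNamespace

def is_encoder_decoder_model(hf_config) -> bool:
    # Mirrors the updated property: check the top-level flag first, then the
    # flag on a nested text_config (as used by Mllama).
    return getattr(hf_config, "is_encoder_decoder", False) or (
        hasattr(hf_config, "text_config")
        and getattr(hf_config.text_config, "is_encoder_decoder", False))

decoder_only = SimpleNamespace()  # e.g. a plain text-only config
mllama_like = SimpleNamespace(
    text_config=SimpleNamespace(is_encoder_decoder=True))

assert not is_encoder_decoder_model(decoder_only)
assert is_encoder_decoder_model(mllama_like)
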
6 changes: 5 additions & 1 deletion vllm/engine/llm_engine.py
@@ -1734,7 +1734,11 @@ def is_embedding_model(self):

def _validate_model_inputs(self, inputs: Union[LLMInputs,
EncoderDecoderLLMInputs]):
- if self.is_encoder_decoder_model():
+ if self.model_config.is_multimodal_model:
+     # For encoder-decoder multimodal models, the max_prompt_len
+     # restricts the decoder prompt length
+     prompt_ids = inputs.get("prompt_token_ids")
+ elif self.is_encoder_decoder_model():
prompt_ids = inputs.get("encoder_prompt_token_ids")
else:
prompt_ids = inputs.get("prompt_token_ids")