Commit e255082 (parent: 29acd2c)
Showing 9 changed files with 274 additions and 55 deletions.
examples/offline_inference_vision_language_embedding.py (147 changes: 126 additions & 21 deletions)
@@ -1,22 +1,127 @@
""" | ||
This example shows how to use vLLM for running offline inference with | ||
the correct prompt format on vision language models for multimodal embedding. | ||
For most models, the prompt format should follow corresponding examples | ||
on HuggingFace model repository. | ||
""" | ||
from argparse import Namespace | ||
from typing import List, NamedTuple, Optional, Union | ||
|
||
from PIL.Image import Image | ||
|
||
from vllm import LLM | ||
from vllm.assets.image import ImageAsset | ||
|
||
image = ImageAsset("cherry_blossom").pil_image.convert("RGB") | ||
prompt = "<|image_1|> Represent the given image with the following question: What is in the image" # noqa: E501 | ||
|
||
# Create an LLM. | ||
llm = LLM( | ||
model="TIGER-Lab/VLM2Vec-Full", | ||
task="embedding", | ||
trust_remote_code=True, | ||
max_model_len=4096, | ||
max_num_seqs=2, | ||
mm_processor_kwargs={"num_crops": 16}, | ||
) | ||
|
||
# Generate embedding. The output is a list of EmbeddingRequestOutputs. | ||
outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}}) | ||
|
||
# Print the outputs. | ||
for output in outputs: | ||
print(output.outputs.embedding) # list of 3072 floats | ||
from vllm.multimodal.utils import fetch_image | ||
from vllm.utils import FlexibleArgumentParser | ||
|
||
|
||
class ModelRequestData(NamedTuple): | ||
llm: LLM | ||
prompt: str | ||
stop_token_ids: Optional[List[str]] | ||
image: Optional[Image] | ||
|
||
|
||
def run_e5_v(text_or_image: Union[str, Image]): | ||
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 | ||
|
||
    if isinstance(text_or_image, str):
        prompt = llama3_template.format(
            f"{text_or_image}\nSummary above sentence in one word: ")
        image = None
    else:
        prompt = llama3_template.format(
            "<image>\nSummary above image in one word: ")
        image = text_or_image

    llm = LLM(
        model="royokong/e5-v-2",
        task="embedding",
    )

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image=image,
    )


def run_vlm2vec(text_or_image: Union[str, Image]):
    if isinstance(text_or_image, str):
        prompt = f"Find me an everyday image that matches the given caption: {text_or_image}"  # noqa: E501
        image = None
    else:
        prompt = "<|image_1|> Represent the given image with the following question: What is in the image"  # noqa: E501
        image = text_or_image

    llm = LLM(
        model="TIGER-Lab/VLM2Vec-Full",
        task="embedding",
        trust_remote_code=True,
        mm_processor_kwargs={"num_crops": 4},
    )

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image=image,
    )


def get_text_or_image(modality: str):
    if modality == "text":
        return "A dog sitting in the grass"

    if modality == "image":
        image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg"
        return fetch_image(image_url)

    msg = f"Modality {modality} is not supported."
    raise ValueError(msg)


def run_encode(model: str, modality: str):
    text_or_image = get_text_or_image(modality)
    req_data = model_example_map[model](text_or_image)

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = req_data.llm.encode(
        {
            "prompt": req_data.prompt,
            "multi_modal_data": {
                "image": req_data.image
            },
        }, )

    for output in outputs:
        print(output.outputs.embedding)


def main(args: Namespace):
    run_encode(args.model_type, args.modality)


model_example_map = {
    "e5_v": run_e5_v,
    "vlm2vec": run_vlm2vec,
}

if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="vlm2vec",
                        choices=model_example_map.keys(),
                        help='The name of the embedding model.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['text', 'image'],
                        help='Modality of the input.')
    args = parser.parse_args()
    main(args)
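A usage note (not part of this commit): llm.encode returns one EmbeddingRequestOutput per prompt, and output.outputs.embedding is a flat list of floats, so embeddings produced by the same model for a text query and for an image can be compared directly. A minimal sketch, assuming outputs_text and outputs_image hold the results of two such encode calls:

import torch
import torch.nn.functional as F

# L2-normalize both embedding vectors before comparison.
emb_text = F.normalize(torch.tensor(outputs_text[0].outputs.embedding), dim=-1)
emb_image = F.normalize(torch.tensor(outputs_image[0].outputs.embedding), dim=-1)

# Cosine similarity between the text query and the image.
score = torch.dot(emb_text, emb_image).item()
print(f"text-image similarity: {score:.4f}")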
@@ -0,0 +1,76 @@
import pytest
import torch.nn.functional as F

from ....conftest import IMAGE_ASSETS
from ..utils import check_embeddings_close

llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501

HF_TEXT_PROMPTS = [
    llama3_template.format(
        "The label of the object is stop sign\nSummary above sentence in one word: "  # noqa: E501
    ),
    llama3_template.format(
        "cherry blossom\nSummary above sentence in one word: "),
]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    llama3_template.format("<image>\nSummary above image in one word: "),
    "cherry_blossom":
    llama3_template.format("<image>\nSummary above image in one word: "),
})

MODELS = ["royokong/e5-v-2"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    input_texts_images = [
        *((text, None) for text in HF_TEXT_PROMPTS),
        *((text, image)
          for text, image in zip(HF_IMAGE_PROMPTS, image_assets)),
    ]
    input_texts = [text for text, _ in input_texts_images]
    input_images = [image for _, image in input_texts_images]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(model, task="embedding", dtype=dtype,
                     enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.encode(input_texts, images=input_images)

    with hf_runner(model, dtype=dtype) as hf_model:
        all_inputs = hf_model.get_inputs(input_texts, images=input_images)

        all_outputs = []
        for inputs in all_inputs:
            # Based on: https://huggingface.co/royokong/e5-v
            outputs = hf_model.model(
                **hf_model.wrap_device(inputs,
                                       device=hf_model.model.device.type),
                return_dict=True,
                output_hidden_states=True,
            )
            pooled_output = F.normalize(outputs.hidden_states[-1][:, -1, :],
                                        dim=-1)

            all_outputs.append(pooled_output.tolist())

        hf_outputs = all_outputs

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
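The check_embeddings_close helper comes from the shared test utilities and its implementation is not part of this diff. Conceptually, a check of this kind compares each HF/vLLM embedding pair by cosine similarity against a small tolerance; a minimal sketch of that idea, where the helper name and the 1e-2 tolerance are illustrative assumptions rather than vLLM's actual test code:

import torch
import torch.nn.functional as F

def assert_embedding_pair_close(emb_0, emb_1, tol=1e-2):
    # emb_0 and emb_1 are plain lists of floats, one embedding each.
    sim = F.cosine_similarity(torch.tensor(emb_0), torch.tensor(emb_1), dim=0)
    assert 1.0 - sim.item() <= tol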