[Model] Support multiple images for qwen-vl (vllm-project#8247)
Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
3 people authored and siddharth9820 committed Sep 30, 2024
1 parent 7f2e55d commit ee6aef1
Showing 4 changed files with 343 additions and 65 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
@@ -254,7 +254,7 @@ Multimodal Language Models
-
* - :code:`QWenLMHeadModel`
- Qwen-VL
- Image\ :sup:`E`
- Image\ :sup:`E+`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
* - :code:`Qwen2VLForConditionalGeneration`
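The change from E to E+ in the table above records that Qwen-VL now accepts multiple images in a single prompt. As a minimal sketch of what this enables (not part of this commit; the image URLs and question are illustrative placeholders), the offline API can be used roughly as follows:

# Sketch (not part of this diff): multi-image generation with Qwen-VL-Chat.
# The image URLs and the question below are illustrative placeholders.
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

image_urls = [
    "https://example.com/first.jpg",
    "https://example.com/second.jpg",
]

llm = LLM(
    model="Qwen/Qwen-VL-Chat",
    trust_remote_code=True,
    # Allow up to len(image_urls) images per prompt (the default limit is 1).
    limit_mm_per_prompt={"image": len(image_urls)},
)

# Qwen-VL expects numbered "Picture i: <img></img>" placeholders in the prompt,
# one per image, in the same order as the image data.
placeholders = "".join(f"Picture {i}: <img></img>\n"
                       for i in range(1, len(image_urls) + 1))
prompt = f"{placeholders}\nWhat is the content of each image?"

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": [fetch_image(url) for url in image_urls]},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)

The example script changed below wires up the same pattern for several multi-image models.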
84 changes: 60 additions & 24 deletions examples/offline_inference_vision_language_multi_image.py
@@ -19,7 +19,39 @@
]


def load_phi3v(question, image_urls: List[str]):
def load_qwenvl_chat(question: str, image_urls: List[str]):
model_name = "Qwen/Qwen-VL-Chat"
llm = LLM(
model=model_name,
trust_remote_code=True,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "".join(f"Picture {i}: <img></img>\n"
for i, _ in enumerate(image_urls, start=1))

# This model does not have a chat_template attribute on its tokenizer,
# so we need to explicitly pass it. We use ChatML since it's used in the
# generation utils of the model:
# https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)

# Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501

messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True,
chat_template=chat_template)

stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return llm, prompt, stop_token_ids, None, chat_template


def load_phi3v(question: str, image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
@@ -30,10 +30,10 @@ def load_phi3v(question, image_urls: List[str]):
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
stop_token_ids = None
return llm, prompt, stop_token_ids, None
return llm, prompt, stop_token_ids, None, None


def load_internvl(question, image_urls: List[str]):
def load_internvl(question: str, image_urls: List[str]):
model_name = "OpenGVLab/InternVL2-2B"

llm = LLM(
@@ -61,7 +93,7 @@ def load_internvl(question, image_urls: List[str]):
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

return llm, prompt, stop_token_ids, None
return llm, prompt, stop_token_ids, None, None


def load_qwen2_vl(question, image_urls: List[str]):
@@ -111,18 +143,19 @@ def load_qwen2_vl(question, image_urls: List[str]):
else:
image_data, _ = process_vision_info(messages)

return llm, prompt, stop_token_ids, image_data
return llm, prompt, stop_token_ids, image_data, None


model_example_map = {
"phi3_v": load_phi3v,
"internvl_chat": load_internvl,
"qwen2_vl": load_qwen2_vl,
"qwen_vl_chat": load_qwenvl_chat,
}


def run_generate(model, question: str, image_urls: List[str]):
llm, prompt, stop_token_ids, image_data = model_example_map[model](
llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
question, image_urls)
if image_data is None:
image_data = [fetch_image(url) for url in image_urls]
@@ -146,29 +179,32 @@ def run_generate(model, question: str, image_urls: List[str]):


def run_chat(model: str, question: str, image_urls: List[str]):
llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)
llm, _, stop_token_ids, _, chat_template = model_example_map[model](
question, image_urls)

sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=stop_token_ids)

outputs = llm.chat([{
"role":
"user",
"content": [
{
"type": "text",
"text": question,
},
*({
"type": "image_url",
"image_url": {
"url": image_url
outputs = llm.chat(
[{
"role":
"user",
"content": [
{
"type": "text",
"text": question,
},
} for image_url in image_urls),
],
}],
sampling_params=sampling_params)
*({
"type": "image_url",
"image_url": {
"url": image_url
},
} for image_url in image_urls),
],
}],
sampling_params=sampling_params,
chat_template=chat_template,
)

for o in outputs:
generated_text = o.outputs[0].text
(The remaining 2 changed files are not shown.)
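Usage note (not part of the diff): because the Qwen-VL-Chat tokenizer ships without a chat_template attribute, load_qwenvl_chat returns an explicit ChatML template, and run_chat now forwards it to llm.chat. A small sketch of what that template renders to, assuming the same template string as in load_qwenvl_chat and an illustrative two-image message:

# Sketch: rendering the ChatML template returned by load_qwenvl_chat to show
# the prompt string that Qwen-VL-Chat ultimately receives.
# The message content here is illustrative.
from transformers import AutoTokenizer

chat_template = (
    "{% if not add_generation_prompt is defined %}"
    "{% set add_generation_prompt = false %}{% endif %}"
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat",
                                          trust_remote_code=True)
messages = [{
    "role": "user",
    "content": "Picture 1: <img></img>\nPicture 2: <img></img>\n\n"
               "What do these two images have in common?",
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True,
                                       chat_template=chat_template)
print(prompt)
# Expected shape of the rendered prompt:
# <|im_start|>user
# Picture 1: <img></img>
# Picture 2: <img></img>
#
# What do these two images have in common?<|im_end|>
# <|im_start|>assistant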
