[Feature] Add vision language model support. #3042

Merged on Mar 25, 2024

Commits (32)
682e1a0
[Feature] Support Llava.
xwjiang2010 Feb 26, 2024
72f967a
Merge branch 'main' of https://github.com/vllm-project/vllm into xwji…
xwjiang2010 Mar 18, 2024
d2ddd4e
pillow
xwjiang2010 Mar 19, 2024
07ee4d2
formatting
xwjiang2010 Mar 19, 2024
e657634
address comments and fix some tests
xwjiang2010 Mar 19, 2024
afc2078
rest of Phillip's comments
xwjiang2010 Mar 19, 2024
c90faa9
ImageRequest --> MultiModalRequest
xwjiang2010 Mar 19, 2024
4933c98
Merge remote-tracking branch 'origin/main' into xwjiang/llava
xwjiang2010 Mar 21, 2024
7e12364
address comments
xwjiang2010 Mar 21, 2024
096b758
fix test
xwjiang2010 Mar 21, 2024
17453ec
awscli, download image
xwjiang2010 Mar 22, 2024
a2b2f78
Merge remote-tracking branch 'origin' into xwjiang/llava
xwjiang2010 Mar 22, 2024
6fc90cf
lint
xwjiang2010 Mar 22, 2024
c61efd8
lint
xwjiang2010 Mar 22, 2024
e98306d
bash
xwjiang2010 Mar 22, 2024
32a6e3f
working dir
xwjiang2010 Mar 22, 2024
ed78229
Merge remote-tracking branch 'origin/main' into xwjiang/llava
xwjiang2010 Mar 22, 2024
bba6cb2
wget
xwjiang2010 Mar 22, 2024
239c0a9
fix
xwjiang2010 Mar 22, 2024
908798a
fix
xwjiang2010 Mar 22, 2024
408402b
fix
xwjiang2010 Mar 22, 2024
c3ca810
install wget
xwjiang2010 Mar 22, 2024
381559c
Merge branch 'main' of https://github.com/vllm-project/vllm into xwji…
xwjiang2010 Mar 22, 2024
cd3fdb3
fix
xwjiang2010 Mar 23, 2024
1fb1eac
up
xwjiang2010 Mar 23, 2024
4436b68
Merge remote-tracking branch 'origin/main' into xwjiang/llava
xwjiang2010 Mar 25, 2024
4e21f3a
lint
xwjiang2010 Mar 25, 2024
c32905f
isort
xwjiang2010 Mar 25, 2024
4185777
formatting!!
xwjiang2010 Mar 25, 2024
d722b5b
remove type_checking
xwjiang2010 Mar 25, 2024
1d79460
separate out llava test
xwjiang2010 Mar 25, 2024
76d0a3b
Merge branch 'vllm-project:main' into xwjiang/llava
xwjiang2010 Mar 25, 2024
18 changes: 18 additions & 0 deletions .buildkite/download-images.sh
@@ -0,0 +1,18 @@
#!/bin/bash
Collaborator: thanks for doing this!

set -ex
set -o pipefail

(which wget && which curl) || (apt-get update && apt-get install -y wget curl)

# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
mkdir -p images
cd images
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg

cd -
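
As a quick sanity check (not part of this PR), the downloaded tensors can be inspected to confirm they match the input shapes that examples/llava_example.py configures below. A minimal sketch, assuming the script above has populated the images/ directory:

```python
# Optional sanity check (assumption: .buildkite/download-images.sh has already
# populated images/). The expected shapes come from image_input_shape in
# examples/llava_example.py: (1, 3, 336, 336) for pixel values and
# (1, 576, 1024) for precomputed image features.
import torch

pixel_values = torch.load("images/stop_sign_pixel_values.pt")
image_features = torch.load("images/stop_sign_image_features.pt")

print(pixel_values.shape)    # expected: torch.Size([1, 3, 336, 336])
print(image_features.shape)  # expected: torch.Size([1, 576, 1024])
```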
8 changes: 7 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -39,9 +39,15 @@ steps:

- label: Models Test
commands:
- pytest -v -s models --forked
- bash ../.buildkite/download-images.sh
- pytest -v -s models --ignore=models/test_llava.py --forked
soft_fail: true
Collaborator: oof, this is a bit tough because it is soft-failed. Do you think the test can run on a single L4 (with fp16)? If so, maybe we can create another job for the tests that are currently passing.

@simon-mo is Llava support coming?

Contributor (author): Could you say more about this point? I think the current CI should pass. It's still failing on the same Hugging Face error; I am confused about that.


- label: Llava Test
commands:
- bash ../.buildkite/download-images.sh
- pytest -v -s models/test_llava.py

- label: Prefix Caching Test
commands:
- pytest -v -s prefix_caching
84 changes: 84 additions & 0 deletions examples/llava_example.py
@@ -0,0 +1,84 @@
import argparse
import os
import subprocess

import torch

from vllm import LLM
from vllm.sequence import MultiModalData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.


def run_llava_pixel_values():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
)

prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_pixel_values.pt")

outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def run_llava_image_features():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="image_features",
image_token_id=32000,
image_input_shape="1,576,1024",
image_feature_size=576,
)

prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_image_features.pt")

outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def main(args):
if args.type == "pixel_values":
run_llava_pixel_values()
else:
run_llava_image_features()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Demo on Llava")
parser.add_argument("--type",
type=str,
choices=["pixel_values", "image_features"],
default="pixel_values",
help="image input type")
args = parser.parse_args()
# Download from s3
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"

# Make sure the local directory exists or create it
os.makedirs(local_directory, exist_ok=True)

# Use AWS CLI to sync the directory
subprocess.check_call(
["aws", "s3", "sync", s3_bucket_path, local_directory])
main(args)
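
The example can be invoked as "python examples/llava_example.py --type pixel_values" (the default) or with --type image_features; it first syncs the test assets from the public S3 bucket via the AWS CLI. The comment "This should be provided by another online or offline component" leaves image preprocessing outside vLLM; a minimal sketch of how such a component could produce the pixel_values tensor, assuming the Hugging Face processor for the same checkpoint is used (an illustration, not part of this PR):

```python
# Hedged sketch (not part of this PR): build a pixel_values tensor offline from
# a raw image with the Hugging Face processor for the same LLaVA checkpoint.
# The resulting shape (1, 3, 336, 336) matches image_input_shape above.
import torch
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
image = Image.open("images/stop_sign.jpg")

pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"]
torch.save(pixel_values, "images/stop_sign_pixel_values.pt")
```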
4 changes: 4 additions & 0 deletions requirements-dev.txt
@@ -24,6 +24,10 @@ openai
requests
ray
peft
awscli

# Benchmarking
aiohttp

# Multimodal
pillow
126 changes: 112 additions & 14 deletions tests/conftest.py
@@ -3,23 +3,79 @@

import pytest
import torch
from transformers import AutoModelForCausalLM
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor,
LlavaForConditionalGeneration)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.sequence import MultiModalData
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multi modal related
_PIXEL_VALUES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
]
_IMAGE_FEATURES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
]
_IMAGE_FILES = [
os.path.join(_TEST_DIR, "images", filename)
for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
]
_IMAGE_PROMPTS = [
"<image>\nUSER: What's the content of the image?\nASSISTANT:",
"<image>\nUSER: What is the season?\nASSISTANT:"
]
assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
_IMAGE_FILES) == len(_IMAGE_PROMPTS)


def _read_prompts(filename: str) -> List[str]:
with open(filename, "r") as f:
prompts = f.readlines()
return prompts


@pytest.fixture(scope="session")
def hf_image_prompts() -> List[str]:
return _IMAGE_PROMPTS


@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
return [Image.open(filename) for filename in _IMAGE_FILES]


@pytest.fixture()
def vllm_images(request) -> "torch.Tensor":
vision_language_config = request.getfixturevalue("model_and_config")[1]
all_images = []
if vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
filenames = _IMAGE_FEATURES_FILES
else:
filenames = _PIXEL_VALUES_FILES
for filename in filenames:
all_images.append(torch.load(filename))
return torch.concat(all_images, dim=0)


@pytest.fixture()
def vllm_image_prompts(request) -> List[str]:
vision_language_config = request.getfixturevalue("model_and_config")[1]
return [
"<image>" * (vision_language_config.image_feature_size - 1) + p
for p in _IMAGE_PROMPTS
]


@pytest.fixture
def example_prompts() -> List[str]:
prompts = []
@@ -42,6 +98,10 @@ def example_long_prompts() -> List[str]:
"float": torch.float,
}

_VISION_LANGUAGE_MODELS = {
"llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
}


class HfRunner:

@@ -53,25 +113,53 @@ def __init__(
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.model_name = model_name
if model_name not in _VISION_LANGUAGE_MODELS:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.processor = None
else:
self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.processor = AutoProcessor.from_pretrained(
model_name,
torch_dtype=torch_dtype,
)
if tokenizer_name is None:
tokenizer_name = model_name
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

def generate(
self,
prompts: List[str],
images: Optional[List[Image.Image]] = None,
**kwargs,
) -> List[Tuple[List[int], str]]:
outputs: List[Tuple[List[int], str]] = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
if images:
assert len(prompts) == len(images)
for i, prompt in enumerate(prompts):
if self.model_name not in _VISION_LANGUAGE_MODELS:
input_ids = self.tokenizer(prompt,
return_tensors="pt").input_ids
inputs = {"input_ids": input_ids.cuda()}
else:
image = images[i] if images else None
inputs = self.processor(text=prompt,
images=image,
return_tensors="pt")
inputs = {
key: value.cuda() if value is not None else None
for key, value in inputs.items()
}
output_ids = self.model.generate(
input_ids.cuda(),
**inputs,
use_cache=True,
**kwargs,
)
@@ -88,10 +176,12 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional["torch.Tensor"] = None,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens)
max_new_tokens=max_tokens,
images=images)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
outputs[i] = (output_ids[0], output_str[0])
@@ -183,9 +273,16 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional["torch.Tensor"] = None,
) -> List[Tuple[List[int], str]]:
req_outputs = self.model.generate(prompts,
sampling_params=sampling_params)
if images is not None:
assert len(prompts) == images.shape[0]
req_outputs = self.model.generate(
prompts,
sampling_params=sampling_params,
multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
data=images)
if images is not None else None)
outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt
@@ -222,9 +319,10 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[torch.Tensor] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params)
outputs = self.generate(prompts, greedy_params, images=images)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]
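
One detail worth noting in the conftest changes: the vllm_image_prompts fixture prepends image_feature_size - 1 copies of "<image>" because each raw prompt in _IMAGE_PROMPTS already starts with one "<image>" token, so the expanded prompt contains exactly image_feature_size image tokens, matching the "<image>" * 576 prompts used in examples/llava_example.py. A small illustration (the feature size is taken from the example above):

```python
# Illustration only: the vllm_image_prompts fixture expands a raw prompt that
# already contains one "<image>" token with (image_feature_size - 1) more, so
# the final prompt holds exactly image_feature_size image placeholder tokens.
image_feature_size = 576  # value used for llava-1.5-7b in this PR's example
raw_prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
expanded_prompt = "<image>" * (image_feature_size - 1) + raw_prompt
assert expanded_prompt.count("<image>") == image_feature_size
```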
