[Core] Consolidate prompt arguments to LLM engines #4328

Merged · 57 commits · May 28, 2024
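In short, this PR consolidates the separate `prompts`, `prompt_token_ids`, and `multi_modal_data` keyword arguments of the LLM entrypoints into a single prompt-inputs argument. The sketch below is only an illustration of the call-site change, pieced together from the diffs further down; the model name and sampling settings are placeholders, and the exact deprecation behavior of the legacy keywords may differ.

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.8)

# Legacy keyword-based API (kept working for now behind a deprecation flag):
outputs = llm.generate(prompts="Hello, my name is", sampling_params=params)

# Consolidated API: a single inputs argument that may be a string, a dict,
# or a list of such values.
outputs = llm.generate("Hello, my name is", sampling_params=params)
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=params)
```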
Commits (57) · Changes from all commits
5d42800
Combine prompt inputs
DarkLight1337 Apr 24, 2024
5db2c5e
Fix a bunch of tests
DarkLight1337 Apr 25, 2024
74c5905
Fix LLaVA test
DarkLight1337 Apr 25, 2024
cd8917b
Merge branch 'upstream' into llm-inputs
DarkLight1337 Apr 25, 2024
b49aba7
Fix `benchmark_latency` test
DarkLight1337 Apr 25, 2024
bfd7295
Merge branch 'upstream' into llm-inputs
DarkLight1337 Apr 25, 2024
45c7f23
Merge branch 'upstream' into llm-inputs
DarkLight1337 Apr 27, 2024
493e6ed
Merge branch 'upstream' into llm-inputs
DarkLight1337 Apr 28, 2024
20aeceb
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 1, 2024
0f46653
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 3, 2024
c4f3540
Clarify tokenizer usage
DarkLight1337 May 3, 2024
ab8182c
Rename `encode_request -> process_model_inputs`
DarkLight1337 May 3, 2024
eac33e1
Support old API in `LLM.generate`
DarkLight1337 May 3, 2024
703d318
Add tests to ensure old API still works
DarkLight1337 May 3, 2024
19d85f9
Let all entrypoints tests be run at the same time
DarkLight1337 May 3, 2024
baebd99
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 7, 2024
dc9816f
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 8, 2024
1c50600
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 14, 2024
5759dfa
Add tests for LLM.encode and fix corresponding bugs
DarkLight1337 May 14, 2024
cc4bfb5
Apply formatter
DarkLight1337 May 14, 2024
6085b08
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 14, 2024
d5c9731
Rename `_add_requests` to `_validate_and_add_requests` to be more sim…
DarkLight1337 May 14, 2024
4f218a5
Separate `entrypoints` tests into two groups
DarkLight1337 May 14, 2024
a9201d0
Fix memory profiling error
DarkLight1337 May 14, 2024
ceebfa6
Fix memory usage for embedding server
DarkLight1337 May 15, 2024
7d991cd
Update embeddings API to use new inputs
DarkLight1337 May 15, 2024
0e79dfb
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 15, 2024
2c0d58f
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 15, 2024
48e7a4a
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 16, 2024
b6c0e29
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 20, 2024
3097582
Merge `llm` groups back into one by enabling gc
DarkLight1337 May 20, 2024
7bbd123
Improve documentation for LLM/engine
DarkLight1337 May 20, 2024
056eb61
Direct readers to the `PromptInputs` class
DarkLight1337 May 22, 2024
b3b990a
Separate `_run_engine` from `_validate_and_add_requests`
DarkLight1337 May 22, 2024
2169def
Add flag for deprecating legacy API
DarkLight1337 May 22, 2024
3dbded1
Add tests for `deprecate_kwargs`
DarkLight1337 May 22, 2024
8e20317
Apply formatter
DarkLight1337 May 22, 2024
fdccaa2
Rename attribute to be less misleading
DarkLight1337 May 22, 2024
77ee1c8
Re-enable using `'fork'` start method and improve speed by using `torch…
DarkLight1337 May 23, 2024
b1bcdd1
Simplify logic of casting request output
DarkLight1337 May 23, 2024
44b4681
Improve code readability
DarkLight1337 May 23, 2024
50343cb
Fix `multi_modal_data` being a required key
DarkLight1337 May 23, 2024
45aa420
Fix index out of range error
DarkLight1337 May 23, 2024
d4e2589
Use a flag to control whether to check output types
DarkLight1337 May 23, 2024
c07b579
Simplify flags
DarkLight1337 May 23, 2024
9d56eb0
Move output validation to a more appropriate location
DarkLight1337 May 23, 2024
bc05031
Add message to deprecation notice
DarkLight1337 May 23, 2024
95d4130
Apply formatter
DarkLight1337 May 23, 2024
cc84f65
Remove unused parameter in `_validate_and_add_requests` and fix test
DarkLight1337 May 24, 2024
6c5d4a6
Simplify code
DarkLight1337 May 25, 2024
fd2da12
Move attribute assignment outside `_init_tokenizer`
DarkLight1337 May 25, 2024
d78de94
Only emit warning once
DarkLight1337 May 25, 2024
8a86829
Simplify assignment expression
DarkLight1337 May 25, 2024
731ac0e
Place special case at the start
DarkLight1337 May 25, 2024
2d1a0bc
Move API reference under developer doc
ywang96 May 25, 2024
7b8ce2c
Fix links in docs
DarkLight1337 May 26, 2024
fff21a1
Remove unnecessary code to avoid repeated warning
DarkLight1337 May 26, 2024
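Several commits above ("Add flag for deprecating legacy API", "Add tests for `deprecate_kwargs`") refer to a `deprecate_kwargs` helper that warns when callers still pass the old keyword arguments. The following is only a rough sketch of what such a decorator could look like, not the actual implementation added by the PR:

```python
import warnings
from functools import wraps
from typing import Callable


def deprecate_kwargs(*deprecated: str, message: str = "") -> Callable:
    """Warn when any of the named deprecated keyword arguments is passed.

    Illustrative sketch only; the real helper's signature may differ.
    """

    def decorator(fn: Callable) -> Callable:

        @wraps(fn)
        def inner(*args, **kwargs):
            used = [name for name in deprecated if kwargs.get(name) is not None]
            if used:
                warnings.warn(
                    f"Keyword arguments {used} are deprecated. {message}",
                    DeprecationWarning,
                    stacklevel=2,
                )
            return fn(*args, **kwargs)

        return inner

    return decorator


class ExampleLLM:
    """Stand-in class showing how such a decorator would be applied."""

    @deprecate_kwargs("prompts", "prompt_token_ids",
                      message="Please pass a single prompt-inputs argument.")
    def generate(self, inputs=None, *, prompts=None, prompt_token_ids=None):
        # Fall back to the legacy keywords when the new argument is absent.
        return inputs if inputs is not None else prompts or prompt_token_ids
```

Per the commit "Only emit warning once", the real helper also appears to suppress repeated warnings, which this sketch does not do.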
9 changes: 6 additions & 3 deletions .buildkite/test-pipeline.yaml
@@ -62,9 +62,9 @@ steps:
- label: Entrypoints Test
#mirror_hardwares: [amd]
commands:
# these tests have to be separated, because each one will allocate all possible GPU memory
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
- pytest -v -s entrypoints/test_server_oot_registration.py
- pytest -v -s test_inputs.py
- pytest -v -s entrypoints -m llm
- pytest -v -s entrypoints -m openai

- label: Examples Test
working_dir: "/vllm-workspace/examples"
@@ -109,6 +109,9 @@ steps:
mirror_hardwares: [amd]
command: pytest -v -s test_logits_processor.py

- label: Utils Test
command: pytest -v -s test_utils.py

- label: Worker Test
mirror_hardwares: [amd]
command: pytest -v -s worker
11 changes: 7 additions & 4 deletions benchmarks/benchmark_latency.py
@@ -3,13 +3,14 @@
import json
import time
from pathlib import Path
from typing import Optional
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


@@ -48,7 +49,9 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
@@ -59,13 +62,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
@@ -1,5 +1,5 @@
LLM Class
==========
=========

.. autoclass:: vllm.LLM
:members:
14 changes: 14 additions & 0 deletions docs/source/dev/offline_inference/llm_inputs.rst
@@ -0,0 +1,14 @@
LLM Inputs
==========

.. autodata:: vllm.inputs.PromptStrictInputs

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
:members:
:member-order: bysource

.. autoclass:: vllm.inputs.TokensPrompt
:show-inheritance:
:members:
:member-order: bysource
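For orientation, the documented input types correspond to plain dicts: a text prompt carries a `prompt` key, a tokens prompt carries `prompt_token_ids`, and a bare string is also accepted. A minimal usage sketch (placeholder model name), in line with how `benchmark_latency.py` builds its dummy inputs in this PR:

```python
from typing import List

from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs

inputs: List[PromptStrictInputs] = [
    {"prompt": "What is the capital of France?"},  # TextPrompt-style dict
    {"prompt_token_ids": [1, 2, 3]},               # TokensPrompt-style dict
]

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(inputs, sampling_params=SamplingParams(max_tokens=16))
for o in outputs:
    print(o.outputs[0].text)
```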
8 changes: 8 additions & 0 deletions docs/source/dev/offline_inference/offline_index.rst
@@ -0,0 +1,8 @@
Offline Inference
=================================

.. toctree::
:maxdepth: 1

llm
llm_inputs
11 changes: 3 additions & 8 deletions docs/source/index.rst
@@ -68,13 +68,6 @@ Documentation
getting_started/quickstart
getting_started/examples/examples_index

.. toctree::
:maxdepth: 1
:caption: Offline Inference

offline_inference/llm
offline_inference/sampling_params

.. toctree::
:maxdepth: 1
:caption: Serving
@@ -108,7 +101,9 @@ Documentation
.. toctree::
:maxdepth: 2
:caption: Developer Documentation


dev/sampling_params
dev/offline_inference/offline_index
dev/engine/engine_index
dev/kernel/paged_attention
dev/dockerfile/dockerfile
4 changes: 2 additions & 2 deletions docs/source/serving/openai_compatible_server.md
@@ -48,7 +48,7 @@ completion = client.chat.completions.create(
```

### Extra Parameters for Chat API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
@@ -65,7 +65,7 @@ The following extra parameters are supported:
```

### Extra Parameters for Completions API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
25 changes: 16 additions & 9 deletions examples/llava_example.py
@@ -23,11 +23,15 @@ def run_llava_pixel_values():
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_pixel_values.pt")
image = torch.load("images/stop_sign_pixel_values.pt")

outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})

outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
@@ -46,11 +50,14 @@ def run_llava_image_features():
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_image_features.pt")

outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
image = torch.load("images/stop_sign_image_features.pt")

outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data"
[tool.isort]
use_parentheses = true
skip_gitignore = true

[tool.pytest.ini_options]
markers = [
"skip_global_cleanup",
"llm: run tests for vLLM API only",
"openai: run tests for OpenAI API only",
]
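These markers back the split Entrypoints CI step shown at the top of the diff (`pytest -v -s entrypoints -m llm` and `-m openai`). A test module opts in with a module-level `pytestmark`, as the entrypoints test diffs below do; a hypothetical minimal example:

```python
# tests/entrypoints/test_example.py (hypothetical module, for illustration)
import pytest

# Collected only when running `pytest -m openai`.
pytestmark = pytest.mark.openai


def test_smoke():
    assert True
```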
2 changes: 1 addition & 1 deletion tests/async_engine/test_async_llm_engine.py
@@ -25,7 +25,7 @@ async def step_async(self):
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []

async def encode_request_async(self, *args, **kwargs):
async def process_model_inputs_async(self, *args, **kwargs):
pass

def generate(self, request_id):
2 changes: 1 addition & 1 deletion tests/async_engine/test_openapi_server_ray.py
@@ -29,7 +29,7 @@ def server():
ray.shutdown()


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def client():
client = openai.AsyncOpenAI(
base_url="http://localhost:8000/v1",
23 changes: 17 additions & 6 deletions tests/conftest.py
@@ -12,6 +12,7 @@
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.sequence import MultiModalData

@@ -402,12 +403,22 @@ def generate(
) -> List[Tuple[List[int], str]]:
if images is not None:
assert len(prompts) == images.shape[0]
req_outputs = self.model.generate(
prompts,
sampling_params=sampling_params,
multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
data=images)
if images is not None else None)

prompt_inputs: List[PromptInputs] = []
for i, prompt in enumerate(prompts):
image = None if images is None else images[i:i + 1]
mm_data = None if image is None else MultiModalData(
type=MultiModalData.Type.IMAGE,
data=image,
)

prompt_inputs.append({
"prompt": prompt,
"multi_modal_data": mm_data,
})

req_outputs = self.model.generate(prompt_inputs,
sampling_params=sampling_params)
outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt
15 changes: 12 additions & 3 deletions tests/core/test_block_manager.py
@@ -133,8 +133,11 @@ def test_append_slot_cow():

# Allocate prompt to gpu block. There is one slot left in the block.
prompt = Sequence(seq_id=1,
prompt="one two three",
prompt_token_ids=[1, 2, 3],
inputs={
"prompt": "one two three",
"prompt_token_ids": [1, 2, 3],
"multi_modal_data": None
},
block_size=block_size)

# Fork the sequence, such that a COW will be required when we append a new
@@ -304,7 +307,13 @@ def test_sliding_window_multi_seq():

assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks

parent = Sequence(1, "one two three", [0, 1, 2], block_size)
parent = Sequence(seq_id=1,
inputs={
"prompt": "one two three",
"prompt_token_ids": [0, 1, 2],
"multi_modal_data": None
},
block_size=block_size)
seq_group = SequenceGroup(request_id="1",
seqs=[parent],
arrival_time=time.time(),
15 changes: 12 additions & 3 deletions tests/core/utils.py
@@ -21,7 +21,13 @@ def create_dummy_prompt(
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
prompt = Sequence(int(request_id),
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[prompt],
arrival_time=time.time(),
@@ -51,8 +57,11 @@ for seq_id_offset, output_len in enumerate(seq_output_lens):
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
prompt="",
prompt_token_ids=prompt_token_ids,
inputs={
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)

2 changes: 1 addition & 1 deletion tests/engine/test_skip_tokenizer_init.py
@@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str):
with pytest.raises(ValueError) as err:
llm.generate("abc", sampling_params)
assert "prompts must be None if" in str(err.value)
outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
sampling_params=sampling_params)
assert len(outputs) > 0
completions = outputs[0].outputs
4 changes: 4 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -1,11 +1,15 @@
import asyncio
from dataclasses import dataclass

import pytest

from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"

pytestmark = pytest.mark.openai


@dataclass
class MockModelConfig:
2 changes: 2 additions & 0 deletions tests/entrypoints/test_guided_processors.py
@@ -52,6 +52,8 @@
TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

pytestmark = pytest.mark.openai


def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""