[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980)

Signed-off-by: Andrew Feldman <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
afeldman-nm and njhill authored Feb 24, 2025
1 parent 444b0f0 commit befc402
Showing 5 changed files with 641 additions and 9 deletions.
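For context, parallel sampling means a single request with SamplingParams(n > 1) yields n distinct completions for one prompt. A minimal usage sketch of the synchronous API, mirroring what the tests in this commit exercise (model name, prompt, and parameter values are illustrative):

from vllm import LLM, SamplingParams

# One prompt, n=3 completions per request (parallel sampling).
llm = LLM(model="facebook/opt-125m", dtype="half", enforce_eager=True)
params = SamplingParams(n=3, temperature=0.95, top_p=0.95, seed=42)
outputs = llm.generate(["What is an LLM?"], params)

for request_output in outputs:
    # Each RequestOutput should carry n CompletionOutput entries, indexed 0..n-1.
    for completion in request_output.outputs:
        print(completion.index, completion.text)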
103 changes: 98 additions & 5 deletions tests/v1/engine/test_llm_engine.py
@@ -1,21 +1,114 @@
# SPDX-License-Identifier: Apache-2.0

import random
from typing import Dict, List, Optional, Tuple

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams

MODEL = "facebook/opt-125m"
DTYPE = "half"

def _vllm_model(apc: bool, vllm_runner, monkeypatch):
    """Set up VllmRunner instance."""
    monkeypatch.setenv("VLLM_USE_V1", "1")
    # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    return vllm_runner(
        MODEL,
        dtype=DTYPE,
        max_model_len=128,
        enforce_eager=True,
        enable_prefix_caching=apc,
        gpu_memory_utilization=0.5,
    )


@pytest.fixture(
    # Function scope decouples tests & allows
    # env var adjustment via monkeypatch
    scope="function",
    # Prefix caching
    params=[False, True])
def vllm_model(vllm_runner, request, monkeypatch):
    """VllmRunner test fixture parameterized by APC True/False."""
    with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model:
        yield vllm_model


@pytest.fixture(scope="function")
def vllm_model_apc(vllm_runner, monkeypatch):
    """VllmRunner test fixture with APC."""
    with _vllm_model(True, vllm_runner, monkeypatch) as vllm_model:
        yield vllm_model


def _get_test_sampling_params(
    prompt_list: List[str],
    seed: Optional[int] = 42,
) -> Tuple[List[SamplingParams], List[int]]:
    """Generate random sampling params for a batch."""

    def get_mostly_n_gt1() -> int:
        """Mostly n in [2, 20]; roughly 1/3 of the time n=1."""
        x = random.randint(0, 28)
        if x < 10:
            return 1
        else:
            return x - 8

    n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
    # High temperature to maximize the chance of unique completions
    return [
        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
        for n in n_list
    ], n_list


def test_parallel_sampling(vllm_model, example_prompts) -> None:
    """Test passes if parallel sampling `n>1` yields `n` unique completions.

    Args:
      vllm_model: VllmRunner instance under test.
      example_prompts: test fixture providing prompts for testing.
    """
    sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
    model: LLM = vllm_model.model
    outputs = model.generate(example_prompts, sampling_params_list)

    # Validate each request response
    for out, n in zip(outputs, n_list):
        completion_counts: Dict[str, int] = {}
        # Assert correct number of completions
        assert len(out.outputs) == n, (
            f"{len(out.outputs)} completions; {n} expected.")
        for idx in range(n):
            comp = out.outputs[idx]
            # Assert correct completion indices
            assert comp.index == idx, (f"Index {comp.index}; expected {idx}.")
            text = comp.text
            completion_counts[text] = completion_counts.get(text, 0) + 1
        # Assert unique completions
        if len(completion_counts) != n:
            repeats = {
                txt: num
                for (txt, num) in completion_counts.items() if num > 1
            }
            raise AssertionError(
                f"{len(completion_counts)} unique completions; expected"
                f" {n}. Repeats: {repeats}")


def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
    """Test passes if LLMEngine raises an exception when it is configured
    for automatic prefix caching and it receives a request with
    prompt_logprobs enabled, which is incompatible."""
    model: LLM = vllm_model_apc.model
    with pytest.raises(ValueError) as excinfo:
        model.generate(
            "Hello, my name is",
            SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))

102 changes: 102 additions & 0 deletions tests/v1/entrypoints/openai/test_completion.py
@@ -250,6 +250,108 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
    assert "".join(chunks) == single_output


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
                                     model_name: str):
    """Parallel sampling without streaming.
    A single request output contains a list of completions.
    """

    prompt = "What is an LLM?"
    n = 3
    max_tokens = 5

    # High temperature to maximize chance of unique completions.
    completion = await client.completions.create(model=model_name,
                                                  prompt=prompt,
                                                  max_tokens=max_tokens,
                                                  n=n,
                                                  temperature=0.95,
                                                  stream=False,
                                                  seed=42)

    # Assert `n` completions
    num_completions = len(completion.choices)
    assert num_completions == n, (
        f"Num completions {num_completions} but expected {n}.")
    completion_repeats: Dict[str, int] = {}
    for idx, choice in enumerate(completion.choices):
        # Assert correct completion index & some finish reason.
        assert choice.index == idx, (
            f"Index {choice.index} but expected {idx}.")
        assert choice.finish_reason is not None, (
            "None finish_reason is invalid.")
        text = choice.text
        completion_repeats[text] = completion_repeats.get(text, 0) + 1
    # Assert `n` unique completions
    num_unique = len(completion_repeats)
    if num_unique != n:
        repeats = {
            txt: num
            for (txt, num) in completion_repeats.items() if num > 1
        }
        raise AssertionError(
            f"Expected {n} unique completions, got {num_unique};"
            f" repeats: {repeats}.")


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
    """Streaming for parallel sampling.
    The tokens from multiple samples are flattened into a single stream,
    with an index indicating which sample each token belongs to.
    """

    prompt = "What is an LLM?"
    n = 3
    max_tokens = 5

    stream = await client.completions.create(model=model_name,
                                              prompt=prompt,
                                              max_tokens=max_tokens,
                                              n=n,
                                              temperature=0.95,
                                              stream=True,
                                              seed=42)
    chunks: List[List[str]] = [[] for _ in range(n)]
    finish_reason_count = 0
    async for chunk in stream:
        index = chunk.choices[0].index
        text = chunk.choices[0].text
        chunks[index].append(text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # Assert `n` completions with correct finish reasons
    assert finish_reason_count == n, (
        f"Expected {n} completions with valid indices and finish_reason.")
    completion_repeats: Dict[str, int] = {}
    for chunk in chunks:
        chunk_len = len(chunk)
        # Assert correct number of completion tokens
        assert chunk_len == max_tokens, (
            f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
        text = "".join(chunk)
        completion_repeats[text] = completion_repeats.get(text, 0) + 1
        print(text)
    # Assert `n` unique completions
    num_unique = len(completion_repeats)
    if num_unique != n:
        repeats = {
            txt: num
            for (txt, num) in completion_repeats.items() if num > 1
        }
        raise AssertionError(f"{num_unique} unique completions, expected {n};"
                             f" repeats: {repeats}")


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
27 changes: 26 additions & 1 deletion vllm/v1/engine/async_llm.py
@@ -24,6 +24,7 @@
from vllm.utils import cdiv, kill_process_tree
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
@@ -170,7 +171,7 @@ async def add_request(
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def _generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
@@ -241,6 +242,30 @@ async def generate(
            await self.abort(request_id)
            raise

    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        kwargs = dict(prompt=prompt,
                      sampling_params=sampling_params,
                      request_id=request_id,
                      lora_request=lora_request,
                      trace_headers=trace_headers,
                      prompt_adapter_request=prompt_adapter_request,
                      priority=priority)
        if sampling_params.n is None or sampling_params.n == 1:
            return self._generate(**kwargs)
        else:
            # Special handling for parallel sampling requests
            return generate_parallel_sampling_async(generate=self._generate,
                                                    **kwargs)

    async def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
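The fan-out logic referenced above, generate_parallel_sampling_async, lives in vllm/v1/engine/parallel_sampling.py, which is not rendered on this page. Below is a rough, hypothetical sketch of the general idea (split one n > 1 request into n single-sample child requests and merge their output streams); all names and details are illustrative, not the actual implementation:

import asyncio
from copy import copy

async def fan_out_parallel_sampling(generate, prompt, sampling_params,
                                    request_id, **kwargs):
    """Hypothetical sketch: run n child n=1 requests and merge their
    outputs into one async stream, tagged with the sample index."""
    n = sampling_params.n
    queue: asyncio.Queue = asyncio.Queue()

    async def child(index: int) -> None:
        params = copy(sampling_params)
        params.n = 1
        async for out in generate(prompt=prompt,
                                  sampling_params=params,
                                  request_id=f"{index}_{request_id}",
                                  **kwargs):
            await queue.put((index, out))
        await queue.put((index, None))  # sentinel: this child finished

    tasks = [asyncio.create_task(child(i)) for i in range(n)]
    remaining = n
    while remaining:
        index, out = await queue.get()
        if out is None:
            remaining -= 1
        else:
            # The real code re-tags completion indices and may aggregate
            # child outputs back into a single parent RequestOutput.
            yield index, out
    await asyncio.gather(*tasks)

One design choice visible in generate() above is that requests with n == 1 bypass this path entirely, so the common case pays no extra overhead.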
43 changes: 40 additions & 3 deletions vllm/v1/engine/llm_engine.py
@@ -21,6 +21,7 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor

@@ -48,6 +49,9 @@ def __init__(
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        # Bookkeeping for parallel sampling requests
        self.parallel_manager = SyncParallelSamplingManager()

        # important: init dp group before init the engine_core
        self.parallel_config = vllm_config.parallel_config
        self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
@@ -115,7 +119,8 @@ def from_engine_args(
                   multiprocess_mode=enable_multiprocessing)

    def get_num_unfinished_requests(self) -> int:
        return self.parallel_manager.get_num_unfinished_requests(
            self.output_processor.get_num_unfinished_requests())

    def has_unfinished_requests(self) -> bool:
        has_unfinished = self.output_processor.has_unfinished_requests()
@@ -151,7 +156,36 @@ def add_request(
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add request."""
        kwargs = dict(request_id=request_id,
                      prompt=prompt,
                      params=params,
                      arrival_time=arrival_time,
                      lora_request=lora_request,
                      trace_headers=trace_headers,
                      prompt_adapter_request=prompt_adapter_request,
                      priority=priority)
        # Handle parallel sampling requests differently.
        if params is None or isinstance(params,
                                        PoolingParams) or params.n == 1:
            self._add_request(**kwargs)
        else:
            # Special handling for parallel sampling requests
            self.parallel_manager.add_request_parallel_sampling(
                add_request=self._add_request, **kwargs)

    def _add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add request, `n=1`"""
        # 1) Process raw inputs into the request.
        request = self.processor.process_inputs(request_id, prompt, params,
                                                arrival_time, lora_request,
@@ -182,7 +216,10 @@ def step(self) -> List[RequestOutput]:
        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        request_outputs = processed_outputs.request_outputs

        # 4) Process unfinished parallel sampling requests
        return self.parallel_manager.step(request_outputs)

    def get_model_config(self):
        return self.model_config
The fifth changed file, vllm/v1/engine/parallel_sampling.py (the parallel sampling implementation imported by both engines above), did not render in this view.
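For the synchronous engine, SyncParallelSamplingManager plays the analogous role: add_request_parallel_sampling fans a parent request out into child requests, step() routes finished child outputs back to their parents, and the unfinished-request count reported by get_num_unfinished_requests() is translated from child requests back to parent requests. A hypothetical sketch of the kind of bookkeeping involved, with illustrative names only:

from typing import Dict, List

class ParentRequestBookkeeping:
    """Hypothetical sketch of parent/child request tracking for
    synchronous parallel sampling; not the actual vLLM implementation."""

    def __init__(self) -> None:
        # parent request id -> completions collected from child requests
        self.collected: Dict[str, List] = {}
        # parent request id -> number of child completions still expected
        self.expected: Dict[str, int] = {}

    def register(self, parent_id: str, n: int) -> List[str]:
        """Register a parent with n samples; return n child request ids."""
        self.collected[parent_id] = []
        self.expected[parent_id] = n
        return [f"{i}_{parent_id}" for i in range(n)]

    def collect(self, parent_id: str, completion) -> List:
        """Record one finished child completion. Returns all n completions
        once the parent is complete, otherwise an empty list."""
        self.collected[parent_id].append(completion)
        if len(self.collected[parent_id]) == self.expected[parent_id]:
            del self.expected[parent_id]
            return self.collected.pop(parent_id)
        return []

    def num_unfinished_parents(self) -> int:
        return len(self.expected)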
