fix tests
Signed-off-by: Nick Hill <[email protected]>
njhill committed Feb 5, 2025
1 parent 054562d commit 552a875
Showing 4 changed files with 71 additions and 93 deletions.
52 changes: 17 additions & 35 deletions tests/v1/engine/test_async_llm.py
@@ -6,12 +6,12 @@

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.utils import STR_ASYNC_LLM_PROMPT_LP_APC_UNSUPPORTED

if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -22,37 +22,19 @@
disable_log_requests=True)


async def generate(
engine: AsyncLLM,
request_id: str,
output_kind: RequestOutputKind,
max_tokens: int,
sampling_params: Optional[SamplingParams] = None,
) -> Tuple[int, str]:
"""Wrapper for `AsyncLLM` generation.
At least one of `max_tokens` and `sampling_params` must
not be `None`. If `sampling_params` is `None`, `max_tokens`
is used to create a `SamplingParams` instance. If
`sampling_params` is provided, `max_tokens` is not used.
Args:
engine: AsyncLLM instance
request_id: AsyncLLM request ID
output_kind: request output strategy (i.e. delta vs final-only)
max_tokens: (optional) max number of tokens to generate
sampling_params: (optional) request sampling params
Returns:
count: number of returns from engine.generate()
request_id
"""
assert not (max_tokens is None and sampling_params is None), (
"At least one of max_tokens and sampling_params"
" must not be None.")
if sampling_params is None:
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0)
async def generate(engine: AsyncLLM,
request_id: str,
output_kind: RequestOutputKind,
max_tokens: int,
prompt_logprobs: Optional[int] = None) -> Tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)

count = 0
sampling_params = SamplingParams(max_tokens=max_tokens,
output_kind=output_kind,
temperature=0,
prompt_logprobs=prompt_logprobs)
async for out in engine.generate(request_id=request_id,
prompt="Hello my name is Robert and",
sampling_params=sampling_params):
@@ -94,16 +76,16 @@ async def test_async_llm_refuses_prompt_logprobs_with_apc(
"request-0",
output_kind,
10,
sampling_params=SamplingParams(max_tokens=10,
temperature=0,
prompt_logprobs=5)))
prompt_logprobs=5))
# Validate exception string is correct
assert str(excinfo.value) == STR_ASYNC_LLM_PROMPT_LP_APC_UNSUPPORTED
assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
finally:
# Shut down engine
engine.shutdown()


@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_load(monkeypatch, output_kind: RequestOutputKind):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
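Note: the reworked `generate` helper above now builds its `SamplingParams` internally, so callers pass `max_tokens`, the `RequestOutputKind`, and optionally `prompt_logprobs` instead of a full params object. The sketch below shows how a caller might drive it; it assumes the helper is importable from the test module and that `AsyncLLM.from_engine_args` is the construction path (as elsewhere in this test file) — the scaffolding is illustrative, not part of this commit.

import asyncio

from tests.v1.engine.test_async_llm import generate  # helper shown in the diff above
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM


async def _smoke_test() -> None:
    # Assumed construction path; mirrors ENGINE_ARGS in the test module.
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", disable_log_requests=True))
    try:
        # The helper returns (number of outputs received, request id).
        count, request_id = await generate(engine,
                                           "request-0",
                                           RequestOutputKind.DELTA,
                                           max_tokens=10)
        assert request_id == "request-0"
        assert count > 0
    finally:
        engine.shutdown()


if __name__ == "__main__":
    asyncio.run(_smoke_test())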
17 changes: 11 additions & 6 deletions tests/v1/engine/test_llm_engine.py
@@ -2,18 +2,23 @@

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams
from vllm.v1.engine.utils import STR_LLM_ENGINE_PROMPT_LP_APC_UNSUPPORTED


def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
"""Test passes if LLMEngine raises an exception when it is configured
for automatic prefix caching and it receives a request with
prompt_logprobs enabled, which is incompatible."""

monkeypatch.setenv("VLLM_USE_V1", "1")
with pytest.raises(ValueError) as excinfo:
(LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
"Hello, my name is",
SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5)))
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
try:
with pytest.raises(ValueError) as excinfo:
llm.generate("Hello, my name is",
SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))
finally:
del llm

# Validate exception string is correct
assert str(excinfo.value) == STR_LLM_ENGINE_PROMPT_LP_APC_UNSUPPORTED
assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
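Note: the assertion above compares the raised message against the shared `PLP_APC_UNSUPPORTED_MSG` constant now defined in `tests/v1/engine/utils.py`, keeping the sync and async tests consistent. As an equivalent sketch (an illustration, not part of this commit), the same check could use `pytest.raises(match=...)`, escaping the constant because `match` is interpreted as a regular expression:

import re

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams


def test_llm_engine_plp_apc_match_variant(monkeypatch):
    # Hypothetical variant of the test above, shown for illustration only.
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
    try:
        with pytest.raises(ValueError,
                           match=re.escape(PLP_APC_UNSUPPORTED_MSG)):
            llm.generate(
                "Hello, my name is",
                SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))
    finally:
        del llm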
10 changes: 5 additions & 5 deletions tests/v1/engine/test_output_processor.py
@@ -151,7 +151,7 @@ def _validate_logprobs(
# logprob token id tensors associated with this
# position in the completion. Also break out the
# sampled token ranks
(ref_pos_logprob_vals, ref_pos_logprob_toks,
(ref_pos_logprob_toks, ref_pos_logprob_vals,
ref_sampled_token_rank) = ref_logprobs[idx]
# For each position in the completion sequence,
# ensure the actual sampled token is among the
@@ -267,17 +267,17 @@ def _validate_logprobs(
# Break out the reference prompt log prob value &
# logprob token id matrices for the whole prompt.
# Also break out the prompt token rank vector
(ref_prompt_logprob_vals, ref_prompt_logprob_toks,
(ref_prompt_logprob_toks, ref_prompt_logprob_vals,
ref_prompt_token_ranks) = ref_prompt_logprobs
for idx, (prompt_token, pos_logprob_dict) in enumerate(
zip(prompt_token_ids[1:], prompt_logprobs[1:])):

# Break out the reference prompt log prob value
# vector, prompt logprob token id vector, and
# prompt token rank at the current position.
(ref_pos_prompt_logprob_vals, ref_pos_prompt_logprob_toks,
ref_pos_prompt_token_rank) = (ref_prompt_logprob_vals[idx, :],
ref_prompt_logprob_toks[idx, :],
(ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals,
ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :],
ref_prompt_logprob_vals[idx, :],
ref_prompt_token_ranks[idx])

# For each position in the prompt sequence,
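Note: the reordered unpacking above reflects the new per-position convention of (top token ids, logprob values, sampled token rank), matching the updated test utilities. A small self-contained sketch of that ordering, using fabricated values rather than real engine output:

from typing import List, Tuple

# One completion position: top-k token ids, the matching logprob values, and
# the rank of the sampled token (fabricated example data).
ref_entry: Tuple[List[int], List[float], int] = (
    [1037, 2003, 2023],      # top token ids
    [-0.11, -2.53, -4.97],   # corresponding logprob values
    1,                       # sampled token rank (1 == most likely)
)

ref_pos_logprob_toks, ref_pos_logprob_vals, ref_sampled_token_rank = ref_entry
assert len(ref_pos_logprob_toks) == len(ref_pos_logprob_vals)
assert ref_sampled_token_rank >= 1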
85 changes: 38 additions & 47 deletions tests/v1/engine/utils.py
@@ -10,7 +10,8 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

GeneralTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

@@ -29,6 +30,9 @@
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
PROMPT_LEN = 5

PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet "
"supported on VLLM V1.")

random.seed(42)


@@ -205,7 +209,7 @@ def generate_dummy_sample_logprobs(
sampled_tokens_list: List,
num_logprobs: int,
tokenizer: PreTrainedTokenizer,
) -> List[Tuple[List[float], List[int], int]]:
) -> List[Tuple[List[int], List[float], int]]:
"""Generate dummy sample logprobs
Generate a test data structure which imitates the list of sample logprobs
@@ -217,7 +221,7 @@ def generate_dummy_sample_logprobs(
tokenizer: model tokenizer to use for detokenization
Returns
List of (logprobs vector, top token ids vector, sampled token rank)
List of (top token ids vector, logprobs vector, sampled token rank)
Python lists tuples; in each tuple the logprobs and top token ids
vectors have the same length which is either `num_logprobs` or
`num_logprobs+1`. Sampled token rank is the rank (index+1) of the
@@ -234,8 +238,9 @@
sampled_token_id)

res.append(
(_create_random_top_logprob_test_vector(num_logprobs + 1, -100, 0),
token_vector, sampled_token_rank))
(token_vector,
_create_random_top_logprob_test_vector(num_logprobs + 1, -100,
0), sampled_token_rank))

# Convert tensors in the list tuples to Python lists
res_list_format = [
@@ -251,7 +256,7 @@ def generate_dummy_prompt_logprobs_tensors(
prompt_tokens_list: List,
num_logprobs: int,
tokenizer: PreTrainedTokenizer,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> LogprobsTensors:
"""Generate dummy prompt logprobs tensors
Generate a test data structure which imitates the torch Tensors of prompt
@@ -283,9 +288,11 @@
) = _create_random_top_token_test_matrix(
(num_prompt_logprobs, num_logprobs), 0,
len(tokenizer.vocab) - 1, prompt_tokens_list[1:])
return (_create_random_top_logprob_test_matrix(
(num_prompt_logprobs, num_logprobs + 1), -100,
0), token_vector, prompt_token_ranks)
return LogprobsTensors(
token_vector,
_create_random_top_logprob_test_matrix(
(num_prompt_logprobs, num_logprobs + 1), -100, 0),
prompt_token_ranks)


@dataclass
@@ -297,13 +304,13 @@ class DummyOutputProcessorTestVectors:
full_tokens: List[List[int]] # Prompt + generated tokens
prompt_tokens: List[List[int]]
generation_tokens: List[List[int]]
# Each request is associated with a tuple of (top logprobs,top tokens)
# prompt logprobs tensors
prompt_logprobs: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
# Each request is associated with a tuple of
# (top tokens, top logprobs, ranks) prompt logprobs tensors
prompt_logprobs: List[LogprobsTensors]
# Each request is associated with a sample logprobs; a request's
# sample logprobs are a list of (top logprobs,top tokens)
# sample logprobs are a list of (top tokens, top logprobs, ranks)
# sample logprobs tensors at each sequence position
generation_logprobs: List[List[Tuple[List[float], List[int], int]]]
generation_logprobs: List[List[Tuple[List[int], List[float], int]]]
prompt_strings: List[str]
prompt_strings_len: List[int]
generation_strings: List[str]
@@ -317,16 +324,15 @@ def __init__(
tokens_list: List[List[int]],
# For each request, for each sampled token offset,
# a tuple of
# (list of sample logprob vals,list of topk token ids)
generated_logprobs_raw: Optional[List[List[Tuple[List[float],
List[int],
# (list of topk token ids, list of sample logprob vals, rank)
generated_logprobs_raw: Optional[List[List[Tuple[List[int],
List[float],
int]]]] = None,
# For each request, a tuple of
# ( prompt logprob val matrix, prompt logprob tok id matrix );
# (prompt logprob val matrix, prompt logprob tok id matrix);
# each matrix has dimensions
# (num prompt toks) x (num prompt logprobs+1)
prompt_logprobs_raw: Optional[List[Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]]] = None,
prompt_logprobs_raw: Optional[List[LogprobsTensors]] = None,
) -> None:
self.tokens_list = tokens_list
self.current_idx = 0
@@ -345,46 +351,31 @@ def get_outputs(self) -> List[EngineCoreOutput]:
if len(token_ids) > token_idx:
if do_logprobs:
assert self.generated_logprobs_raw is not None
(logprobs_, logprobs_token_ids_, sampled_token_ranks_) = (
(logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
self.generated_logprobs_raw[req_idx][token_idx])
logprobs = [logprobs_]
logprobs_token_ids = [logprobs_token_ids_]
sampled_token_ranks = [sampled_token_ranks_]
logprobs = LogprobsLists(
[logprobs_token_ids_],
[logprobs_],
[sampled_token_ranks_],
)
else:
logprobs = None
logprobs_token_ids = None
sampled_token_ranks = None
if do_prompt_logprobs:
if self.current_idx == 0:
assert self.prompt_logprobs_raw is not None
pos_prompt_logprobs = self.prompt_logprobs_raw[req_idx]
prompt_logprobs = pos_prompt_logprobs[0]
prompt_logprobs_token_ids = pos_prompt_logprobs[1]
prompt_token_ranks = pos_prompt_logprobs[2]
prompt_logprobs = self.prompt_logprobs_raw[req_idx]
else:
(
prompt_logprobs,
prompt_logprobs_token_ids,
prompt_token_ranks,
) = (torch.empty(0, 0), torch.empty(0,
0), torch.empty(0))
prompt_logprobs = None
else:
(
prompt_logprobs,
prompt_logprobs_token_ids,
prompt_token_ranks,
) = (None, None, None)
prompt_logprobs = None
output = EngineCoreOutput(
request_id=f"request-{req_idx}",
new_token_ids=[token_ids[token_idx]],
new_logprobs=logprobs,
new_logprobs_token_ids=logprobs_token_ids,
new_sampled_token_ranks=sampled_token_ranks,
new_prompt_logprobs=prompt_logprobs,
new_prompt_logprobs_token_ids=prompt_logprobs_token_ids,
new_prompt_token_ranks=prompt_token_ranks)
new_prompt_logprobs_tensors=prompt_logprobs,
)
if token_idx == len(token_ids) - 1:
output.finish_reason = "stopped"
output.finish_reason = FinishReason.STOP
outputs.append(output)

self.current_idx += 1
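Note: with these changes, sample logprobs are grouped into a `LogprobsLists`, prompt logprobs into a `LogprobsTensors`, and the finish reason uses the `FinishReason` enum instead of a string. The sketch below constructs a single `EngineCoreOutput` the same way `get_outputs` does above; all values are fabricated and the tensor shapes are only an assumption for illustration.

import torch

from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

# Sample logprobs for one decoded position, constructed positionally as
# (token ids, logprob values, sampled token ranks), matching the code above.
sample_logprobs = LogprobsLists(
    [[11, 42]],        # top token ids for this position
    [[-0.3, -1.9]],    # matching logprob values
    [1],               # rank of the sampled token
)

# Prompt logprobs for a 3-token prompt, again positionally
# (token id matrix, logprob matrix, prompt token ranks); shapes are assumed.
prompt_logprobs = LogprobsTensors(
    torch.randint(0, 1000, (3, 3)),
    -10.0 * torch.rand(3, 3),
    torch.randint(1, 4, (3, )),
)

output = EngineCoreOutput(
    request_id="request-0",
    new_token_ids=[11],
    new_logprobs=sample_logprobs,
    new_prompt_logprobs_tensors=prompt_logprobs,
)
output.finish_reason = FinishReason.STOP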
