[V1] Logprobs and prompt logprobs support (vllm-project#9880)
This PR adds support for sample logprobs and prompt logprobs to vLLM v1. New behavior:

- During model execution, the model runner computes sample logprobs (if the user-provided logprobs setting is not None) and prompt logprobs (if the user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, and token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure, which is transferred to the engine client. If multiprocessing is enabled, sample and prompt logprobs are (de)serialized along with the EngineCoreOutput data structure.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprob values, and token ranks into the OpenAI-compatible List[Dict[token id, Logprob]] format (for sample and prompt logprobs respectively); a sketch of this transformation follows below.
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor, not the detokenizer.

Signed-off-by: Andrew Feldman <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
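To make the rank and Logprob semantics above concrete, here is a minimal, self-contained sketch. It is illustrative only, not vLLM's actual implementation: the Logprob fields mirror the commit description, while compute_ranks and to_openai_format are hypothetical helper names.

    # Illustrative sketch of the semantics described above; not vLLM's actual code.
    from dataclasses import dataclass
    from typing import Dict, List, Optional

    import torch


    @dataclass
    class Logprob:
        # Fields named in the commit message: log-probability, rank,
        # and detokenized string representation.
        logprob: float
        rank: int
        decoded_token: Optional[str]


    def compute_ranks(logits: torch.Tensor,
                      token_ids: torch.Tensor) -> torch.Tensor:
        """1-indexed rank of each chosen token after sorting the vocabulary
        by log probability in descending order (the argmax token has rank 1)."""
        logprobs = logits.log_softmax(dim=-1)
        chosen = logprobs.gather(-1, token_ids.unsqueeze(-1))
        # Rank = 1 + number of vocabulary entries with strictly higher logprob.
        return (logprobs > chosen).sum(dim=-1) + 1


    def to_openai_format(
        token_ids: List[List[int]],
        logprob_values: List[List[float]],
        ranks: List[List[int]],
        detokenize,  # e.g. tokenizer.decode; injected to keep the sketch generic
    ) -> List[Dict[int, Logprob]]:
        """Fold the three per-position vectors into List[Dict[token id, Logprob]]."""
        return [{
            tok: Logprob(logprob=lp, rank=r, decoded_token=detokenize([tok]))
            for tok, lp, r in zip(ids, lps, rks)
        } for ids, lps, rks in zip(token_ids, logprob_values, ranks)]

Defining rank as 1 plus the count of strictly higher logprobs makes the greedy (argmax) token rank 1, consistent with the 1-indexed convention the commit message describes.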
1 parent 6ecc410 · commit 8479418
Showing 30 changed files with 2,869 additions and 287 deletions.
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Tuple

import pytest
import torch
from transformers import AutoTokenizer

from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
                                   NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN,
                                   TOKENIZER_NAME,
                                   DummyOutputProcessorTestVectors,
                                   generate_dummy_prompt_logprobs_tensors,
                                   generate_dummy_sample_logprobs)
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

from tests.v1.engine.utils import FULL_STRINGS  # isort: skip

EngineCoreSampleLogprobsType = List[Tuple[torch.Tensor, torch.Tensor]]
EngineCorePromptLogprobsType = Tuple[torch.Tensor, torch.Tensor]


def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
    """Generate output processor dummy test vectors, without logprobs.

    Returns:
      DummyOutputProcessorTestVectors instance with no logprobs
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
    # Tokenize prompts under test & create dummy generated tokens
    prompt_tokens = [
        tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
    ]
    generation_tokens = [
        tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
    ]
    # Generate prompt strings
    prompt_strings = [
        tokenizer.decode(prompt_tokens, skip_special_tokens=True)
        for prompt_tokens in prompt_tokens
    ]
    prompt_strings_len = [
        len(prompt_string) for prompt_string in prompt_strings
    ]
    return DummyOutputProcessorTestVectors(
        tokenizer=tokenizer,
        tokenizer_group=init_tokenizer_from_configs(
            vllm_config.model_config, vllm_config.scheduler_config,
            vllm_config.parallel_config, vllm_config.lora_config),
        vllm_config=vllm_config,
        full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
        prompt_tokens=prompt_tokens,
        generation_tokens=generation_tokens,
        prompt_strings=prompt_strings,
        prompt_strings_len=prompt_strings_len,
        generation_strings=[
            text[prompt_len:]
            for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len)
        ],
        prompt_logprobs=[],
        generation_logprobs=[])


@pytest.fixture
def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
    """Generate output processor dummy test vectors, with logprobs.

    Returns:
      DummyOutputProcessorTestVectors instance with logprobs
    """
    # Build dummy test vectors without logprobs
    dtv = _build_test_vectors_no_logprobs()
    # Inject logprobs into the dummy test vectors data structure
    dtv.generation_logprobs = [
        generate_dummy_sample_logprobs(
            sampled_tokens_list=tokens_list,
            num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST,
            tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens
    ]
    dtv.prompt_logprobs = [
        generate_dummy_prompt_logprobs_tensors(
            prompt_tokens_list=tokens_list,
            num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST,
            tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens
    ]
    return dtv
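For orientation, the EngineCoreSampleLogprobsType alias above implies that each sampled position is represented by a (logprob values, token ids) tensor pair. A hypothetical generator consistent with that shape is sketched below; the real generate_dummy_sample_logprobs lives in tests/v1/engine/utils and may differ in detail (for example, in how the sampled token is injected into the top-k list).

    import torch

    def dummy_sample_logprobs_sketch(sampled_token_ids, num_logprobs,
                                     vocab_size=1024):
        """Illustrative only: one (values, ids) tensor pair per sampled token.
        Each pair holds num_logprobs + 1 entries so the sampled token itself
        is always present alongside the top-k alternatives; vocab_size is an
        arbitrary toy value."""
        pairs = []
        for token_id in sampled_token_ids:
            ids = torch.randperm(vocab_size)[:num_logprobs + 1]
            ids[-1] = token_id  # guarantee the sampled token appears
            # Random negative values sorted descending stand in for logprobs.
            values, _ = torch.rand(num_logprobs + 1).log().sort(descending=True)
            pairs.append((values, ids))
        return pairs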
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams


def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
    """Test passes if LLMEngine raises an exception when it is configured
    for automatic prefix caching and it receives a request with
    prompt_logprobs enabled, which is incompatible."""

    monkeypatch.setenv("VLLM_USE_V1", "1")
    # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    with pytest.raises(ValueError) as excinfo:
        LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
            "Hello, my name is",
            SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))

    # Validate exception string is correct
    assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
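Conversely, assuming the only conflict is between automatic prefix caching and prompt_logprobs, the same request should succeed once enable_prefix_caching is turned off; a minimal sketch:

    from vllm import LLM, SamplingParams

    # Sketch: with automatic prefix caching disabled, prompt logprobs are allowed.
    llm = LLM(model="facebook/opt-125m", enable_prefix_caching=False)
    outputs = llm.generate(
        "Hello, my name is",
        SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))

Here prompt_logprobs=5 requests (up to) the top-5 logprobs for each prompt position, in the List[Dict[token id, Logprob]] format described in the commit message.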