From 73892efae39f30903ff8b810a351355c1d45726e Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 10 Jan 2025 17:39:15 +0800
Subject: [PATCH 1/3] fix tests

Signed-off-by: youkaichao
---
 tests/conftest.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 917151ddcb8d4..f9be3f27aea44 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -28,7 +28,8 @@
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
+                         TokensPrompt, to_enc_dec_tuple_list,
+                         zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
@@ -886,6 +887,10 @@ def generate_beam_search(
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
+        if isinstance(prompts[0], str):
+            prompts = [TextPrompt(prompt) for prompt in prompts]
+        else:
+            prompts = [TokensPrompt(tokens) for tokens in prompts]
         outputs = self.model.beam_search(
             prompts,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))

From 2dd912b115bd5444233f2db508479e658ba8d1d2 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 10 Jan 2025 17:44:37 +0800
Subject: [PATCH 2/3] fix tests

Signed-off-by: youkaichao
---
 tests/conftest.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index f9be3f27aea44..35fd894654f68 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -888,9 +888,11 @@ def generate_beam_search(
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         if isinstance(prompts[0], str):
-            prompts = [TextPrompt(prompt) for prompt in prompts]
+            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
         else:
-            prompts = [TokensPrompt(tokens) for tokens in prompts]
+            prompts = [
+                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
+            ]
         outputs = self.model.beam_search(
             prompts,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))

From 5435cc1d003536454e1bbd642eb3a7a02f8b1e63 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 10 Jan 2025 18:43:57 +0800
Subject: [PATCH 3/3] use is_list_of

Signed-off-by: youkaichao
---
 .buildkite/test-pipeline.yaml | 1 +
 tests/conftest.py             | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index e288f8f30159a..7d13269540864 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -214,6 +214,7 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
+  - tests/conftest.py
   commands:
     - pytest -v -s samplers
     - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
diff --git a/tests/conftest.py b/tests/conftest.py
index 35fd894654f68..95af4ac1eb17b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -34,7 +34,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        identity)
+                        identity, is_list_of)

 logger = init_logger(__name__)

@@ -887,7 +887,7 @@ def generate_beam_search(
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
-        if isinstance(prompts[0], str):
+        if is_list_of(prompts, str, check="all"):
             prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
         else:
             prompts = [
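
A minimal standalone sketch of the prompt handling the three patches converge on. The helper name normalize_beam_search_prompts is hypothetical; in the patches this logic lives inline in the generate_beam_search method in tests/conftest.py.

    # Sketch only: mirrors the inline logic added to generate_beam_search
    # in tests/conftest.py; the wrapper function itself is hypothetical.
    from typing import List, Union

    from vllm.inputs import TextPrompt, TokensPrompt
    from vllm.utils import is_list_of


    def normalize_beam_search_prompts(
            prompts: Union[List[str], List[List[int]]]):
        # String prompts are wrapped as TextPrompt(prompt=...); token-id
        # prompts as TokensPrompt(prompt_token_ids=...), matching what the
        # patched method passes on to self.model.beam_search.
        if is_list_of(prompts, str, check="all"):
            return [TextPrompt(prompt=prompt) for prompt in prompts]
        return [TokensPrompt(prompt_token_ids=tokens) for tokens in prompts]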