Commit
[Bugfix] Fix missing task for speculative decoding (vllm-project#9524)
Signed-off-by: Amit Garg <[email protected]>
DarkLight1337 authored and garg-amit committed Oct 28, 2024
1 parent 198a021 commit 01bc62c
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions vllm/config.py
@@ -33,8 +33,10 @@
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
 
-Task = Literal["generate", "embedding"]
-TaskOption = Literal["auto", Task]
+TaskOption = Literal["auto", "generate", "embedding"]
+
+# "draft" is only used internally for speculative decoding
+_Task = Literal["generate", "embedding", "draft"]
 
 
 class ModelConfig:
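
This first hunk splits the old Task alias into a user-facing TaskOption and an internal _Task that additionally carries "draft". A minimal sketch of the resulting distinction (the two type names come from the diff; the get_args check is illustrative only):

    from typing import Literal, get_args

    TaskOption = Literal["auto", "generate", "embedding"]  # what users may pass
    _Task = Literal["generate", "embedding", "draft"]      # what vLLM resolves to

    # "draft" is deliberately absent from the user-facing type:
    assert "draft" not in get_args(TaskOption)
    assert "draft" in get_args(_Task)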
@@ -115,7 +117,7 @@ class ModelConfig:
 
     def __init__(self,
                  model: str,
-                 task: TaskOption,
+                 task: Union[TaskOption, _Task],
                  tokenizer: str,
                  tokenizer_mode: str,
                  trust_remote_code: bool,
@@ -255,18 +257,21 @@ def _verify_tokenizer_mode(self) -> None:
 
     def _resolve_task(
         self,
-        task_option: TaskOption,
+        task_option: Union[TaskOption, _Task],
         hf_config: PretrainedConfig,
-    ) -> Tuple[Set[Task], Task]:
+    ) -> Tuple[Set[_Task], _Task]:
+        if task_option == "draft":
+            return {"draft"}, "draft"
+
         architectures = getattr(hf_config, "architectures", [])
 
-        task_support: Dict[Task, bool] = {
+        task_support: Dict[_Task, bool] = {
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
             "embedding": ModelRegistry.is_embedding_model(architectures),
         }
-        supported_tasks_lst: List[Task] = [
+        supported_tasks_lst: List[_Task] = [
             task for task, is_supported in task_support.items() if is_supported
         ]
         supported_tasks = set(supported_tasks_lst)
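
The early return is the core of the fix: a draft model used for speculative decoding should not be task-resolved through the model registry, since its architecture may not be registered as a text-generation or embedding model. A simplified, self-contained sketch of the control flow; the is_generation/is_embedding callables stand in for the real ModelRegistry checks and the error handling is an assumption:

    from typing import Callable, Dict, List, Set, Tuple

    def resolve_task(
        task_option: str,
        architectures: List[str],
        is_generation: Callable[[List[str]], bool],
        is_embedding: Callable[[List[str]], bool],
    ) -> Tuple[Set[str], str]:
        # Draft models bypass the registry checks below; without this branch,
        # a speculator architecture unknown to the registry would resolve to
        # an empty task set and fail.
        if task_option == "draft":
            return {"draft"}, "draft"

        task_support: Dict[str, bool] = {
            # Listed from highest to lowest priority
            "generate": is_generation(architectures),
            "embedding": is_embedding(architectures),
        }
        supported: Set[str] = {t for t, ok in task_support.items() if ok}
        if not supported:
            raise ValueError(f"No supported task for {architectures}")
        selected = next(t for t in ("generate", "embedding") if t in supported)
        return supported, selected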
@@ -1002,7 +1007,7 @@ class SchedulerConfig:
     """
 
     def __init__(self,
-                 task: Task,
+                 task: _Task,
                  max_num_batched_tokens: Optional[int],
                  max_num_seqs: int,
                  max_model_len: int,
@@ -1269,7 +1274,7 @@ def maybe_create_spec_config(
                 ngram_prompt_lookup_min = 0
             draft_model_config = ModelConfig(
                 model=speculative_model,
-                task=target_model_config.task,
+                task="draft",
                 tokenizer=target_model_config.tokenizer,
                 tokenizer_mode=target_model_config.tokenizer_mode,
                 trust_remote_code=target_model_config.trust_remote_code,
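
With the draft ModelConfig now constructed with task="draft", it no longer inherits the target model's resolved task and skips registry-based resolution entirely. A hedged usage sketch of the code path this fixes; the argument names follow vLLM's offline API of this period and the model names are placeholders:

    from vllm import LLM, SamplingParams

    # Placeholder target/draft pair; the draft model's ModelConfig is
    # created internally with task="draft".
    llm = LLM(
        model="facebook/opt-6.7b",              # target model
        speculative_model="facebook/opt-125m",  # draft model
        num_speculative_tokens=5,
    )
    outputs = llm.generate(["The future of AI is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)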
