
[Bugfix] Fix missing task for speculative decoding #9524

Merged: 1 commit, on Oct 19, 2024
23 changes: 14 additions & 9 deletions vllm/config.py
@@ -33,8 +33,10 @@
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120

-Task = Literal["generate", "embedding"]
-TaskOption = Literal["auto", Task]
+TaskOption = Literal["auto", "generate", "embedding"]
+
+# "draft" is only used internally for speculative decoding
+_Task = Literal["generate", "embedding", "draft"]


 class ModelConfig:
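
Aside (not part of the diff): the effect of this hunk is that "draft" exists as an internal task without ever becoming a user-selectable option. A minimal, self-contained sketch of the same pattern, using hypothetical names rather than the actual vLLM types:

from typing import Literal, get_args

# Hypothetical mirror of TaskOption: what a user may request.
UserTask = Literal["auto", "generate", "embedding"]
# Hypothetical mirror of _Task: what the engine tracks internally.
# "draft" is reserved for speculative-decoding draft models.
InternalTask = Literal["generate", "embedding", "draft"]

def validate_user_task(value: str) -> None:
    # Rejects "draft" even though it is a valid InternalTask.
    if value not in get_args(UserTask):
        raise ValueError(f"invalid task: {value!r}")

validate_user_task("generate")  # ok
# validate_user_task("draft")   # raises ValueError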
@@ -115,7 +117,7 @@ class ModelConfig:

     def __init__(self,
                  model: str,
-                 task: TaskOption,
+                 task: Union[TaskOption, _Task],
                  tokenizer: str,
                  tokenizer_mode: str,
                  trust_remote_code: bool,
@@ -255,18 +257,21 @@ def _verify_tokenizer_mode(self) -> None:

     def _resolve_task(
         self,
-        task_option: TaskOption,
+        task_option: Union[TaskOption, _Task],
         hf_config: PretrainedConfig,
-    ) -> Tuple[Set[Task], Task]:
+    ) -> Tuple[Set[_Task], _Task]:
+        if task_option == "draft":
+            return {"draft"}, "draft"
+
         architectures = getattr(hf_config, "architectures", [])

-        task_support: Dict[Task, bool] = {
+        task_support: Dict[_Task, bool] = {
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
             "embedding": ModelRegistry.is_embedding_model(architectures),
         }
-        supported_tasks_lst: List[Task] = [
+        supported_tasks_lst: List[_Task] = [
             task for task, is_supported in task_support.items() if is_supported
         ]
         supported_tasks = set(supported_tasks_lst)
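
Why the early return matters (a reading of the hunk above, not text from the PR): a draft model's architecture need not be registered as a text-generation or embedding model, so running the registry checks on it could reject an otherwise valid draft model. A hedged sketch of the resulting control flow, with a simplified stand-in for ModelRegistry:

from typing import Dict, List, Set, Tuple

def resolve_task_sketch(
    task_option: str,
    architectures: List[str],
    registry: Dict[str, Set[str]],  # stand-in: architecture -> supported tasks
) -> Tuple[Set[str], str]:
    # Internal draft models bypass architecture-based resolution entirely.
    if task_option == "draft":
        return {"draft"}, "draft"

    arch = architectures[0] if architectures else ""
    supported = registry.get(arch, set())
    if task_option == "auto":
        # Highest-priority supported task wins ("generate" before "embedding").
        for candidate in ("generate", "embedding"):
            if candidate in supported:
                return supported, candidate
        raise ValueError(f"no supported task for {architectures}")
    if task_option not in supported:
        raise ValueError(f"{task_option!r} not supported by {architectures}")
    return supported, task_option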
@@ -1002,7 +1007,7 @@ class SchedulerConfig:
     """

     def __init__(self,
-                 task: Task,
+                 task: _Task,
                  max_num_batched_tokens: Optional[int],
                  max_num_seqs: int,
                  max_model_len: int,
@@ -1269,7 +1274,7 @@ def maybe_create_spec_config(
             ngram_prompt_lookup_min = 0
         draft_model_config = ModelConfig(
             model=speculative_model,
-            task=target_model_config.task,
+            task="draft",
             tokenizer=target_model_config.tokenizer,
             tokenizer_mode=target_model_config.tokenizer_mode,
             trust_remote_code=target_model_config.trust_remote_code,
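
Taken together, the one-line change above is the actual fix: the draft ModelConfig previously inherited the target model's resolved task, which sent the draft model through the same architecture-based task resolution as the target. With task="draft", resolution short-circuits. A usage sketch built on resolve_task_sketch from above (architecture names are illustrative only):

registry = {"LlamaForCausalLM": {"generate"}}

# Target model resolves normally via the registry.
supported, task = resolve_task_sketch("auto", ["LlamaForCausalLM"], registry)
assert (supported, task) == ({"generate"}, "generate")

# Draft model: its architecture may be absent from the registry,
# but task="draft" skips the lookup entirely.
supported, task = resolve_task_sketch("draft", ["MedusaModel"], registry)
assert (supported, task) == ({"draft"}, "draft")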