[Kernel] Generalize Speculative decode from CUDA #10094

Closed
wants to merge 10 commits
8 changes: 6 additions & 2 deletions vllm/executor/cpu_executor.py
@@ -115,8 +115,12 @@ def _create_worker(
local_rank: int = 0,
rank: int = 0,
):
worker_module_name = "vllm.worker.cpu_worker"
worker_class_name = "CPUWorker"
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.cpu_worker"
worker_class_name = "CPUWorker"

wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
15 changes: 14 additions & 1 deletion vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -43,6 +43,19 @@ def init_gpu_tensors(self, device: Union[int, str]) -> None:
dtype=torch.long,
device=device)

def init_tensors(self,
device: Union[int, str],
device_type: str = 'cuda') -> None:
assert self.num_accepted_tokens is None
if isinstance(device, int):
device = f"{device_type}:{device}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)

@property
def probs_dtype(self):
return torch.float32
@@ -77,7 +90,7 @@ def _create_output(
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze()
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
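For context on the squeeze(-1) change above: with a batch of one, bonus_token_ids has shape [1, 1], and a full squeeze() also drops the batch dimension, leaving a 0-D tensor that breaks the indexing that follows. A minimal illustration (not part of the diff):

import torch

bonus_token_ids = torch.tensor([[7]])        # shape [batch_size=1, 1]
full_squeeze = bonus_token_ids.squeeze()     # shape [] -- batch dim lost
last_dim_only = bonus_token_ids.squeeze(-1)  # shape [1] -- batch dim kept

assert full_squeeze.dim() == 0
assert last_dim_only.shape == (1,)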
48 changes: 48 additions & 0 deletions vllm/spec_decode/cpu_draft_model_runner.py
@@ -0,0 +1,48 @@
from typing import List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import IntermediateTensors
from vllm.worker.cpu_model_runner import CPUModelRunner as ModelRunnerBaseCls
from vllm.worker.cpu_model_runner import ModelInputForCPUWithSamplingMetadata

logger = init_logger(__name__)


class CPUTP1DraftModelRunner(ModelRunnerBaseCls):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always executes k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
TODOs:
1. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""

def __init__(self, *args, **kwargs):
if kwargs.get("return_hidden_states"):
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)
super().__init__(*args, **kwargs)
self.indices_of_seq_with_bonus_tokens = None

@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForCPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
return super().execute_model(
model_input=model_input,
kv_caches=kv_caches,
previous_hidden_states=previous_hidden_states,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
)
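As a conceptual aside on the docstring above (a toy sketch of the k consecutive draft forward passes; this is illustrative only, not vLLM's actual runner API -- draft_model here is assumed to map token ids to next-token logits):

import torch

def propose_draft_tokens(draft_model, input_ids: torch.Tensor, k: int) -> torch.Tensor:
    # Run the draft model k times, feeding each sampled token back in,
    # to produce k speculative tokens for a single spec-decode step.
    draft_tokens = []
    tokens = input_ids
    for _ in range(k):
        logits = draft_model(tokens)                       # [batch, vocab]
        next_token = logits.argmax(dim=-1, keepdim=True)   # greedy for simplicity
        draft_tokens.append(next_token)
        tokens = torch.cat([tokens, next_token], dim=-1)
    return torch.cat(draft_tokens, dim=-1)                 # [batch, k]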
4 changes: 2 additions & 2 deletions vllm/spec_decode/medusa_worker.py
@@ -8,11 +8,11 @@
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.selector import WorkerCls
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls):
"""Worker for Medusa.
"""

9 changes: 9 additions & 0 deletions vllm/spec_decode/metrics.py
@@ -6,6 +6,7 @@

from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available


@@ -81,8 +82,16 @@ def init_gpu_tensors(self, rank: int) -> None:
self._rank = rank
self._copy_stream = torch.cuda.Stream()

def init_tensors(self, rank: int, device_type: str = 'cuda') -> None:
self._rank = rank
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()

def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# Currently using cuda.Event; skip for any non-CUDA-alike platform.
if not current_platform.is_cuda_alike():
return None

# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
12 changes: 8 additions & 4 deletions vllm/spec_decode/multi_step_worker.py
@@ -5,17 +5,21 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.selector import WorkerCls
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker


class MultiStepWorker(Worker, ProposerWorkerBase):
class MultiStepWorker(WorkerCls, ProposerWorkerBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
@@ -75,7 +79,7 @@ def sampler_output(

# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if isinstance(
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
3 changes: 2 additions & 1 deletion vllm/spec_decode/ngram_worker.py
@@ -7,6 +7,7 @@
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.selector import DEVICE_TYPE
from vllm.spec_decode.top1_proposer import Top1Proposer


@@ -34,7 +35,7 @@ def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

def init_device(self):
self.device = torch.device(f"cuda:{self.local_rank}")
self.device = torch.device(f"{DEVICE_TYPE}:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None

# Current NGramWorker only supports Top1Proposer
53 changes: 53 additions & 0 deletions vllm/spec_decode/selector.py
Review comment (Collaborator):
base_cls_selector.py may be a better name for this.

Can we wrap the logic in an API? For example:

def get_worker_cls_by_platform():
    ...

In general this is still not the best practice, but I don't have a better solution atm.
cc @youkaichao
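A minimal sketch of what such a wrapper could look like, reusing the platform predicates this PR already relies on (illustrative only, not part of the diff):

def get_worker_cls_by_platform():
    from vllm.platforms import current_platform
    if current_platform.is_cpu():
        from vllm.worker.cpu_worker import CPUWorker
        return CPUWorker
    if current_platform.is_tpu():
        from vllm.worker.tpu_worker import TPUWorker
        return TPUWorker
    # ... remaining platforms elided; fall back to the CUDA worker.
    from vllm.worker.worker import Worker
    return Worker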

@@ -0,0 +1,53 @@
from vllm.platforms import current_platform

if current_platform.is_neuron():
from vllm.worker.neuron_model_runner import ( # noqa: F401
ModelInputForNeuron as ModelInputCls)
from vllm.worker.neuron_model_runner import ( # noqa: F401
NeuronModelRunner as ModelRunnerCls)
from vllm.worker.neuron_worker import ( # noqa: F401
Review comment (Member):

Oh, I actually plan to add an argument like --worker-cls auto and let every platform select its own worker class. We should do that. (A rough sketch of this idea appears after this file's diff.)

Reply (Contributor Author):

@youkaichao, is there something I can refer to? Or does this file work for now? Currently I put it under the spec_decode folder; it could also make sense to put it under the worker folder.

NeuronWorker as WorkerCls)
DEVICE_TYPE = "neuron"
elif current_platform.is_hpu():
from vllm.worker.hpu_model_runner import ( # noqa: F401
HPUModelRunner as ModelRunnerCls)
from vllm.worker.hpu_model_runner import ( # noqa: F401
ModelInputForHPUWithSamplingMetadata as ModelInputCls)
from vllm.worker.hpu_worker import HPUWorker as WorkerCls # noqa: F401
DEVICE_TYPE = "hpu"
elif current_platform.is_openvino():
from vllm.worker.openvino_model_runner import ( # noqa: F401
ModelInput as ModelInputCls)
from vllm.worker.openvino_model_runner import ( # noqa: F401
OpenVINOModelRunner as ModelRunnerCls)
from vllm.worker.openvino_worker import ( # noqa: F401
OpenVINOWorker as WorkerCls)
DEVICE_TYPE = "openvino"
elif current_platform.is_cpu():
from vllm.worker.cpu_model_runner import ( # noqa: F401
CPUModelRunner as ModelRunnerCls)
from vllm.worker.cpu_model_runner import ( # noqa: F401
ModelInputForCPUWithSamplingMetadata as ModelInputCls)
from vllm.worker.cpu_worker import CPUWorker as WorkerCls # noqa: F401
DEVICE_TYPE = "cpu"
elif current_platform.is_tpu():
from vllm.worker.tpu_model_runner import ( # noqa: F401
ModelInputForTPU as ModelInputCls)
from vllm.worker.tpu_model_runner import ( # noqa: F401
TPUModelRunner as ModelRunnerCls)
from vllm.worker.tpu_worker import TPUWorker as WorkerCls # noqa: F401
DEVICE_TYPE = "tpu"
elif current_platform.is_xpu():
from vllm.worker.xpu_model_runner import ( # noqa: F401
ModelInputForXPUWithSamplingMetadata as ModelInputCls)
from vllm.worker.xpu_model_runner import ( # noqa: F401
XPUModelRunner as ModelRunnerCls)
from vllm.worker.xpu_worker import XPUWorker as WorkerCls # noqa: F401
DEVICE_TYPE = "xpu"
else:
from vllm.worker.model_runner import ( # noqa: F401
ModelInputForGPUWithSamplingMetadata as ModelInputCls)
from vllm.worker.model_runner import ( # noqa: F401
ModelRunner as ModelRunnerCls)
from vllm.worker.worker import Worker as WorkerCls # noqa: F401
DEVICE_TYPE = "cuda"
50 changes: 39 additions & 11 deletions vllm/spec_decode/spec_decode_worker.py
@@ -14,12 +14,18 @@
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
elif current_platform.is_cpu():
from vllm.spec_decode.cpu_draft_model_runner import CPUTP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
@@ -36,9 +42,23 @@
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

if current_platform.is_neuron():
from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
elif current_platform.is_hpu():
from vllm.worker.hpu_worker import HPUWorker as WorkerCls
elif current_platform.is_openvino():
from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
elif current_platform.is_cpu():
from vllm.worker.cpu_worker import CPUWorker as WorkerCls
elif current_platform.is_tpu():
from vllm.worker.tpu_worker import TPUWorker as WorkerCls
elif current_platform.is_xpu():
from vllm.worker.xpu_worker import XPUWorker as WorkerCls
else:
from vllm.worker.worker import Worker as WorkerCls

logger = init_logger(__name__)


Expand All @@ -53,7 +73,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_kwargs = kwargs.copy()

kwargs["model_runner_cls"] = TargetModelRunner
target_worker = Worker(*args, **kwargs)
target_worker = WorkerCls(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
@@ -125,7 +145,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@classmethod
def create_worker(
cls,
scorer_worker: Worker,
scorer_worker: WorkerCls,
draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int],
@@ -158,8 +178,15 @@ def create_worker(
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
elif current_platform.is_cpu():
draft_worker_kwargs[
"model_runner_cls"] = CPUTP1DraftModelRunner
else:
raise NotImplementedError(
"current platform does not support EAGLE.")
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
@@ -306,8 +333,9 @@ def init_device(self) -> None:
self.scorer_worker.load_model()
self.proposer_worker.load_model()

self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank)
self._metrics.init_tensors(self.rank, device_type=self.device.type)
self.spec_decode_sampler.init_tensors(self.rank,
device_type=self.device.type)

scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer:
@@ -320,7 +348,7 @@ def init_device(self) -> None:
"[Speculative Decoding] Use MQA scorer for scoring proposals.")

self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
device=self.device,
device=self.device.type,
Review comment (Collaborator):

The argument is device, so you shouldn't pass the device type here. You could derive the device type inside scorer_cls instead and leave this line unchanged. (See the sketch after this file's diff.)

Reply (Contributor Author):

Hi @comaniac, the reason I changed this is that the device type is annotated as str in the scorer's __init__, but a device was being passed here, so it failed the mypy check:

https://github.com/vllm-project/vllm/blob/main/vllm/spec_decode/interfaces.py#L78-L79

vocab_size=self._vocab_size)

self._configure_model_sampler_for_spec_decode()
@@ -1090,11 +1118,11 @@ def get_cache_block_size_bytes(self):
raise NotImplementedError

def start_profile(self):
if isinstance(self.scorer_worker, Worker):
if isinstance(self.scorer_worker, WorkerCls):
self.scorer_worker.start_profile()

def stop_profile(self):
if isinstance(self.scorer_worker, Worker):
if isinstance(self.scorer_worker, WorkerCls):
self.scorer_worker.stop_profile()


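On the device vs. device type exchange above: one way to keep passing the actual device and normalize inside the scorer instead (a sketch under assumed names; the real SpeculativeScorer signature may differ):

import torch

class DeviceAwareScorer:
    # Hypothetical scorer base that accepts either a torch.device or a plain
    # string and stores a device-type string internally, so callers can keep
    # passing self.device unchanged.
    def __init__(self, scorer_worker, device, vocab_size: int):
        if isinstance(device, torch.device):
            device = device.type  # e.g. torch.device("cuda:0") -> "cuda"
        self._scorer_worker = scorer_worker
        self._device_type = device
        self._vocab_size = vocab_size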
14 changes: 6 additions & 8 deletions vllm/spec_decode/target_model_runner.py
@@ -2,11 +2,10 @@

from vllm.config import VllmConfig
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
from vllm.spec_decode.selector import ModelInputCls, ModelRunnerCls


class TargetModelRunner(ModelRunner):
class TargetModelRunner(ModelRunnerCls):
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities selected finally may not
be the same ones as selected by the target model sampling. This means
@@ -39,11 +38,10 @@ def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
model_input: ModelInputForGPUWithSamplingMetadata = super(
).prepare_model_input(seq_group_metadata_list, virtual_engine,
finished_requests_ids)
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputCls:
model_input: ModelInputCls = super().prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the