From e78d57087753ed6257701eeaa7a8e83a1f4902c8 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 6 Nov 2024 22:22:11 +0200 Subject: [PATCH 1/9] Spec Decode - Remove hard-dependency on GPU Signed-off-by: Chendi Xue --- .../layers/spec_decode_base_sampler.py | 15 ++++++- vllm/spec_decode/medusa_worker.py | 21 +++++++-- vllm/spec_decode/metrics.py | 9 ++++ vllm/spec_decode/multi_step_worker.py | 28 +++++++++--- vllm/spec_decode/ngram_worker.py | 20 ++++++++- vllm/spec_decode/spec_decode_worker.py | 42 ++++++++++++----- vllm/spec_decode/target_model_runner.py | 45 +++++++++++++++---- vllm/spec_decode/util.py | 12 +++-- 8 files changed, 159 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 7e750a744e25f..bae5c4aee9784 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -43,6 +43,19 @@ def init_gpu_tensors(self, device: Union[int, str]) -> None: dtype=torch.long, device=device) + def init_tensors(self, + device: Union[int, str], + device_type: str = 'cuda') -> None: + assert self.num_accepted_tokens is None + if isinstance(device, int): + device = f"{device_type}:{device}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + @property def probs_dtype(self): return torch.float32 @@ -77,7 +90,7 @@ def _create_output( tensor is [batch_size, k + num_bonus_tokens] """ batch_size, k = substitute_token_ids.shape - bonus_token_ids = bonus_token_ids.squeeze() + bonus_token_ids = bonus_token_ids.squeeze(-1) # Determine the index of the first False value for each row. limits = (accepted == 0).max(1).indices limits[~(accepted == 0).any(1)] = k diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index 0d233f393cb8c..550d152c15e36 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -5,14 +5,29 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker import Worker - -class MedusaWorker(NonLLMProposerWorkerBase, Worker): +if current_platform.is_neuron(): + from vllm.worker.neuron_worker import NeuronWorker as WorkerCls +elif current_platform.is_hpu(): + from vllm.worker.hpu_worker import HPUWorker as WorkerCls +elif current_platform.is_openvino(): + from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls +elif current_platform.is_cpu(): + from vllm.worker.cpu_worker import CPUWorker as WorkerCls +elif current_platform.is_tpu(): + from vllm.worker.tpu_worker import TPUWorker as WorkerCls +elif current_platform.is_xpu(): + from vllm.worker.xpu_worker import XPUWorker as WorkerCls +else: + from vllm.worker.worker import Worker as WorkerCls + + +class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls): """Worker for Medusa. 
""" diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 89ccaba70e93c..b85af36ddc8f2 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -6,6 +6,7 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) +from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available @@ -81,8 +82,16 @@ def init_gpu_tensors(self, rank: int) -> None: self._rank = rank self._copy_stream = torch.cuda.Stream() + def init_tensors(self, rank: int, device_type: str = 'cuda') -> None: + self._rank = rank + if device_type == 'cuda': + self._copy_stream = torch.cuda.Stream() + def maybe_collect_rejsample_metrics( self, k: int) -> Optional[SpecDecodeWorkerMetrics]: + # currently using cuda.Event, skip for any non_cuda_alike platform + if not current_platform.is_cuda_alike(): + return None # If a copy was initiated in the previous call, collect and return. if self._in_flight_copy is not None: diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index f49b98f5c9528..306f773d24d1e 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -5,17 +5,35 @@ import torch from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, SequenceGroupMetadata) -from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + +if current_platform.is_cuda_alike(): + from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker import Worker - -class MultiStepWorker(Worker, ProposerWorkerBase): +if current_platform.is_neuron(): + from vllm.worker.neuron_worker import NeuronWorker as WorkerBaseCls +elif current_platform.is_hpu(): + from vllm.worker.hpu_worker import HPUWorker as WorkerBaseCls +elif current_platform.is_openvino(): + from vllm.worker.openvino_worker import OpenVINOWorker as WorkerBaseCls +elif current_platform.is_cpu(): + from vllm.worker.cpu_worker import CPUWorker as WorkerBaseCls +elif current_platform.is_tpu(): + from vllm.worker.tpu_worker import TPUWorker as WorkerBaseCls +elif current_platform.is_xpu(): + from vllm.worker.xpu_worker import XPUWorker as WorkerBaseCls +else: + from vllm.worker.worker import Worker as WorkerBaseCls + + +class MultiStepWorker(WorkerBaseCls, ProposerWorkerBase): """The MultiStepWorker is equivalent to a Worker except that it allows multiple forward passes in a single call, assuming the scheduler has allocated enough space to store the additional KV. This reduces overhead @@ -75,7 +93,7 @@ def sampler_output( # Run model sample_len times. 
model_outputs: List[SamplerOutput] = [] - if isinstance( + if current_platform.is_cuda_alike() and isinstance( self.model_runner, TP1DraftModelRunner ) and self.model_runner.supports_gpu_multi_step(expanded_request): # Here we run the draft_model_runner with multi-step prepare diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index debb3b2d5ec30..c759551ad1246 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -4,11 +4,29 @@ import torch from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer +if current_platform.is_cuda_alike(): + DEVICE_TYPE = "cuda" +elif current_platform.is_neuron(): + DEVICE_TYPE = "neuron" +elif current_platform.is_hpu(): + DEVICE_TYPE = "hpu" +elif current_platform.is_openvino(): + DEVICE_TYPE = "openvino" +elif current_platform.is_cpu(): + DEVICE_TYPE = "cpu" +elif current_platform.is_tpu(): + DEVICE_TYPE = "tpu" +elif current_platform.is_xpu(): + DEVICE_TYPE = "xpu" +else: + raise ValueError(f"Unsupported platform: {current_platform}") + class NGramWorker(NonLLMProposerWorkerBase): """NGramWorker provides a light drafter without need for model. @@ -34,7 +52,7 @@ def set_ngram_window_size(self, ngram_prompt_lookup_min: int, self.ngram_prompt_lookup_min = ngram_prompt_lookup_min def init_device(self): - self.device = torch.device(f"cuda:{self.local_rank}") + self.device = torch.device(f"{DEVICE_TYPE}:{self.local_rank}") self.load_model = lambda *args, **kwargs: None # Current NGramWorker only supports Top1Proposer diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index b57742c2ebfdd..caaa32c5323dd 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -14,12 +14,16 @@ SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) +from vllm.platforms import current_platform from vllm.sequence import (VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer -from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + +if current_platform.is_cuda_alike(): + from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.medusa_worker import MedusaWorker @@ -36,9 +40,23 @@ get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) -from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase +if current_platform.is_neuron(): + from vllm.worker.neuron_worker import NeuronWorker as WorkerCls +elif current_platform.is_hpu(): + from vllm.worker.hpu_worker import HPUWorker as WorkerCls +elif current_platform.is_openvino(): + from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls +elif current_platform.is_cpu(): + from vllm.worker.cpu_worker import CPUWorker as WorkerCls +elif current_platform.is_tpu(): + from vllm.worker.tpu_worker import TPUWorker as WorkerCls +elif 
current_platform.is_xpu(): + from vllm.worker.xpu_worker import XPUWorker as WorkerCls +else: + from vllm.worker.worker import Worker as WorkerCls + logger = init_logger(__name__) @@ -53,7 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": draft_worker_kwargs = kwargs.copy() kwargs["model_runner_cls"] = TargetModelRunner - target_worker = Worker(*args, **kwargs) + target_worker = WorkerCls(*args, **kwargs) # Set the disable_logprobs variable in the TargetModelRunner instance # as per its value specified in the SpeculativeConfig. target_worker.model_runner.disable_logprobs =\ @@ -125,7 +143,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): @classmethod def create_worker( cls, - scorer_worker: Worker, + scorer_worker: WorkerCls, draft_worker_kwargs: Dict[str, Any], disable_mqa_scorer: bool, disable_by_batch_size: Optional[int], @@ -158,8 +176,9 @@ def create_worker( proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: - draft_worker_kwargs[ - "model_runner_cls"] = TP1DraftModelRunner + if current_platform.is_cuda_alike(): + draft_worker_kwargs[ + "model_runner_cls"] = TP1DraftModelRunner else: if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( @@ -306,8 +325,9 @@ def init_device(self) -> None: self.scorer_worker.load_model() self.proposer_worker.load_model() - self._metrics.init_gpu_tensors(self.rank) - self.spec_decode_sampler.init_gpu_tensors(self.rank) + self._metrics.init_tensors(self.rank, device_type=self.device.type) + self.spec_decode_sampler.init_tensors(self.rank, + device_type=self.device.type) scorer_cls: Type[SpeculativeScorer] if self.disable_mqa_scorer: @@ -320,7 +340,7 @@ def init_device(self) -> None: "[Speculative Decoding] Use MQA scorer for scoring proposals.") self.scorer = scorer_cls(scorer_worker=self.scorer_worker, - device=self.device, + device=self.device.type, vocab_size=self._vocab_size) self._configure_model_sampler_for_spec_decode() @@ -1090,11 +1110,11 @@ def get_cache_block_size_bytes(self): raise NotImplementedError def start_profile(self): - if isinstance(self.scorer_worker, Worker): + if isinstance(self.scorer_worker, WorkerCls): self.scorer_worker.start_profile() def stop_profile(self): - if isinstance(self.scorer_worker, Worker): + if isinstance(self.scorer_worker, WorkerCls): self.scorer_worker.stop_profile() diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py index e61cde5b17f20..f1c87c7bfda3c 100644 --- a/vllm/spec_decode/target_model_runner.py +++ b/vllm/spec_decode/target_model_runner.py @@ -1,12 +1,42 @@ from typing import List, Optional from vllm.config import VllmConfig +from vllm.platforms import current_platform from vllm.sequence import SequenceGroupMetadata -from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, - ModelRunner) +if current_platform.is_cuda_alike(): + from vllm.worker.model_runner import ( + ModelInputForGPUWithSamplingMetadata as ModelInputCls) # yapf: disable + from vllm.worker.model_runner import ModelRunner as ModelRunnerCls +elif current_platform.is_neuron(): + from vllm.worker.neuron_model_runner import ( + ModelInputForNeuron as ModelInputCls) # yapf: disable + from vllm.worker.neuron_model_runner import ( + NeuronModelRunner as ModelRunnerCls) # yapf: disable +elif current_platform.is_hpu(): + from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunnerCls + from vllm.worker.hpu_model_runner import ( + ModelInputForHPUWithSamplingMetadata as ModelInputCls) # yapf: 
disable +elif current_platform.is_openvino(): + from vllm.worker.openvino_model_runner import ModelInput as ModelInputCls + from vllm.worker.openvino_model_runner import ( + OpenVINOModelRunner as ModelRunnerCls) # yapf: disable +elif current_platform.is_cpu(): + from vllm.worker.cpu_model_runner import CPUModelRunner as ModelRunnerCls + from vllm.worker.cpu_model_runner import ( + ModelInputForCPUWithSamplingMetadata as ModelInputCls) # yapf: disable +elif current_platform.is_tpu(): + from vllm.worker.tpu_model_runner import ModelInputForTPU as ModelInputCls + from vllm.worker.tpu_model_runner import TPUModelRunner as ModelRunnerCls +elif current_platform.is_xpu(): + from vllm.worker.xpu_model_runner import ( + ModelInputForXPUWithSamplingMetadata as ModelInputCls) # yapf: disable + from vllm.worker.xpu_model_runner import XPUModelRunner as ModelRunnerCls +else: + raise ValueError(f"Unsupported platform: {current_platform}") -class TargetModelRunner(ModelRunner): + +class TargetModelRunner(ModelRunnerCls): """Specialized model runner for speculative decoding target model. In speculative decoding, the log probabilities selected finally may not be the same ones as selected by the target model sampling. This means @@ -39,11 +69,10 @@ def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithSamplingMetadata: - model_input: ModelInputForGPUWithSamplingMetadata = super( - ).prepare_model_input(seq_group_metadata_list, virtual_engine, - finished_requests_ids) + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelInputCls: + model_input: ModelInputCls = super().prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids) # If token log probabilities is disabled then skip generating sampler # CPU output. We directly serialize the GPU sampled_token_id tensors # as needed. 
If log probabilities is enabled then synchronize all the diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 193ef870dfceb..da8706658d09a 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -5,6 +5,7 @@ import torch from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SequenceGroupMetadata, SequenceOutput) @@ -247,11 +248,14 @@ def nvtx_range(msg, *args, **kwargs): Arguments: msg (string): message to associate with the range """ - torch.cuda.nvtx.range_push(msg.format(*args, **kwargs)) - try: + if current_platform.is_cuda_alike(): + torch.cuda.nvtx.range_push(msg.format(*args, **kwargs)) + try: + yield + finally: + torch.cuda.nvtx.range_pop() + else: yield - finally: - torch.cuda.nvtx.range_pop() class Timer: From ce2665c76208b3ca70bf6d3c32cf08e3cebad7f4 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 8 Nov 2024 04:53:50 -0500 Subject: [PATCH 2/9] Enable CPU for speculative decoding Signed-off-by: Chendi Xue --- vllm/executor/cpu_executor.py | 8 +++++-- vllm/spec_decode/spec_decode_worker.py | 9 ++++--- vllm/worker/cpu_model_runner.py | 33 ++++++++++++++++++++++++-- vllm/worker/cpu_worker.py | 27 +++++++++++++++++++-- 4 files changed, 68 insertions(+), 9 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 4ceb5a837dd7f..fb96dd435a8a8 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -125,8 +125,12 @@ def _create_worker( local_rank: int = 0, rank: int = 0, ): - worker_module_name = "vllm.worker.cpu_worker" - worker_class_name = "CPUWorker" + if self.speculative_config is not None: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "vllm.worker.cpu_worker" + worker_class_name = "CPUWorker" wrapper = WorkerWrapperBase( worker_module_name=worker_module_name, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index caaa32c5323dd..ea88182e78e0a 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -23,6 +23,10 @@ if current_platform.is_cuda_alike(): from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner +elif current_platform.is_cpu(): + from vllm.spec_decode.cpu_draft_model_runner import (CPUTP1DraftModelRunner + as + TP1DraftModelRunner) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) @@ -176,9 +180,8 @@ def create_worker( proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: - if current_platform.is_cuda_alike(): - draft_worker_kwargs[ - "model_runner_cls"] = TP1DraftModelRunner + draft_worker_kwargs[ + "model_runner_cls"] = TP1DraftModelRunner else: if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d3e1202c15e61..1bd1a3636fc12 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -78,6 +78,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): Used by the ModelRunner. """ sampling_metadata: Optional["SamplingMetadata"] = None + is_prompt: Optional[bool] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -146,6 +147,7 @@ def build(self) -> ModelInputForCPU: # just use seq_lens instead. 
seq_lens=seq_lens, query_lens=seq_lens, + is_prompt=is_prompt, ) def _compute_multi_modal_input( @@ -432,6 +434,7 @@ def __init__( vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + return_hidden_states: bool = False, *args, **kwargs, ): @@ -442,19 +445,25 @@ def __init__( cache_config = self.cache_config self.is_driver_worker = is_driver_worker + self.return_hidden_states = return_hidden_states self.device = self.device_config.device + self.pin_memory = False self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + needs_attn_backend = (num_attn_heads != 0 + or self.model_config.is_attention_free) self.attn_backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, - ) + ) if needs_attn_backend else None # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY @@ -531,6 +540,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + previous_hidden_states: Optional[torch.Tensor] = None, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( @@ -551,7 +561,9 @@ def execute_model( "intermediate_tensors": intermediate_tensors, } - + if previous_hidden_states is not None: + execute_model_kwargs.update( + {"previous_hidden_states": previous_hidden_states}) hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. @@ -567,4 +579,21 @@ def execute_model( logits=logits, sampling_metadata=model_input.sampling_metadata, ) + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + if model_input.is_prompt: + output.prefill_hidden_states = hidden_states + output.hidden_states = hidden_states return [output] + + def generate_proposals(self, *args, **kwargs): + return self.model.generate_proposals(*args, **kwargs) + + # sampler property will be used by spec_decode_worker + @property + def sampler(self): + return self.model.sampler + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index bc9164bd9d5df..ccb59a37996b0 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -128,6 +128,7 @@ def __init__( distributed_init_method: str, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + model_runner_cls: Optional[Type[CPUModelRunnerBase]] = None, ) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) @@ -151,15 +152,29 @@ def __init__( else: self.local_omp_cpuid = omp_cpuids.split("|")[rank] + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_config = self.speculative_config + model_config = self.model_config + speculative_args = {} if speculative_config is None \ + or (speculative_config.draft_model_config.model == + model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator", "eagle"]) \ + else {"return_hidden_states": True} ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner if self.model_config.task == "embedding": ModelRunnerClass = CPUEmbeddingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = CPUEncoderDecoderModelRunner + elif 
model_runner_cls is not None: + ModelRunnerClass = model_runner_cls self.model_runner: CPUModelRunnerBase = ModelRunnerClass( vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) + is_driver_worker=is_driver_worker, + **speculative_args, + ) # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: List[CPUCacheEngine] @@ -197,7 +212,7 @@ def init_device(self) -> None: ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) if ret: logger.info(ret) - + self.device = torch.device("cpu") self.init_distributed_environment() # Set random seed. set_random_seed(self.model_config.seed) @@ -297,6 +312,14 @@ def do_metadata_broadcast(self) -> bool: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: return self.cpu_cache + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + def execute_worker( self, worker_input: WorkerInput, From 2f393a1acbb48563810993c219034c394ac15fe4 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 8 Nov 2024 05:36:22 -0500 Subject: [PATCH 3/9] Fix mypy formatting issue Signed-off-by: Chendi Xue --- vllm/spec_decode/spec_decode_worker.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index ea88182e78e0a..36ea1310dda3b 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -24,9 +24,7 @@ if current_platform.is_cuda_alike(): from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner elif current_platform.is_cpu(): - from vllm.spec_decode.cpu_draft_model_runner import (CPUTP1DraftModelRunner - as - TP1DraftModelRunner) + from vllm.spec_decode.cpu_draft_model_runner import CPUTP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) @@ -180,8 +178,15 @@ def create_worker( proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: - draft_worker_kwargs[ - "model_runner_cls"] = TP1DraftModelRunner + if current_platform.is_cuda_alike(): + draft_worker_kwargs[ + "model_runner_cls"] = TP1DraftModelRunner + elif current_platform.is_cpu(): + draft_worker_kwargs[ + "model_runner_cls"] = CPUTP1DraftModelRunner + else: + raise NotImplementedError( + "current platform does not support EAGLE.") else: if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( From 707c1496b55e5935154810fbbb4048aafa369a8f Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 8 Nov 2024 20:38:54 -0500 Subject: [PATCH 4/9] forget to submit cpu_draft_model_runner. 
add it here Signed-off-by: Chendi Xue --- vllm/spec_decode/cpu_draft_model_runner.py | 48 ++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 vllm/spec_decode/cpu_draft_model_runner.py diff --git a/vllm/spec_decode/cpu_draft_model_runner.py b/vllm/spec_decode/cpu_draft_model_runner.py new file mode 100644 index 0000000000000..9699af2e26224 --- /dev/null +++ b/vllm/spec_decode/cpu_draft_model_runner.py @@ -0,0 +1,48 @@ +from typing import List, Optional + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import IntermediateTensors +from vllm.worker.cpu_model_runner import CPUModelRunner as ModelRunnerBaseCls +from vllm.worker.cpu_model_runner import ModelInputForCPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class CPUTP1DraftModelRunner(ModelRunnerBaseCls): + """Specialized model runner for speculative decoding draft model. + Since the draft model always execute k forward passes consecutively to + generate k speculative tokens in a single speculative decoding step, + we could get rid of most CPU-GPU synchronization and data transfer + overheads by keeping model input and output tensors on GPU all the time. + TODOs: + 1. Support TP > 1 (this requires some designs because we do not expect + any broadcasting inside execute_model). + """ + + def __init__(self, *args, **kwargs): + if kwargs.get("return_hidden_states"): + raise ValueError( + "return_hidden_states is not supported for TP1DraftModelRunner." + ) + super().__init__(*args, **kwargs) + self.indices_of_seq_with_bonus_tokens = None + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForCPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + previous_hidden_states: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + return super().execute_model( + model_input=model_input, + kv_caches=kv_caches, + previous_hidden_states=previous_hidden_states, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + ) From 9a3bd1675e18d202110b656f4d7afb0b13b8b1a7 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Sat, 16 Nov 2024 01:55:06 -0500 Subject: [PATCH 5/9] Fix format Signed-off-by: Chendi Xue --- vllm/worker/cpu_model_runner.py | 18 +++++++++--------- vllm/worker/cpu_worker.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 1bd1a3636fc12..16f675eff2851 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -493,6 +493,15 @@ def _prepare_model_input_tensors( return builder.build() # type: ignore + # sampler property will be used by spec_decode_worker + @property + def sampler(self): + return self.model.sampler + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( @@ -588,12 +597,3 @@ def execute_model( def generate_proposals(self, *args, **kwargs): return self.model.generate_proposals(*args, **kwargs) - - # sampler property will be used by spec_decode_worker - @property - def sampler(self): - return self.model.sampler - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 
ccb59a37996b0..ebe64d49161cf 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -128,7 +128,7 @@ def __init__( distributed_init_method: str, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - model_runner_cls: Optional[Type[CPUModelRunnerBase]] = None, + model_runner_cls: Optional[Type[CPUModelRunner]] = None, ) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) From 17e1aa5e3ccabc8a090202284d7e9717bd050506 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 21 Nov 2024 01:58:05 -0500 Subject: [PATCH 6/9] rebase to main Signed-off-by: Chendi Xue --- vllm/worker/cpu_model_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7b7f819319446..adaaec361d198 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -504,11 +504,12 @@ def execute_model( if model_input.multi_modal_kwargs is not None: multimodal_kwargs = MultiModalKwargs.as_kwargs( model_input.multi_modal_kwargs, device=self.device) - + previous_hidden_states_kwargs = {} if previous_hidden_states is not None: - previous_hidden_states_kwargs = - {"previous_hidden_states": previous_hidden_states} + previous_hidden_states_kwargs = { + "previous_hidden_states": previous_hidden_states + } hidden_states = model_executable( input_ids=model_input.input_tokens, From 3a4e912b1e903bee16c02590103c255e36d239b4 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 21 Nov 2024 02:10:30 -0500 Subject: [PATCH 7/9] extract platform selector into single file Signed-off-by: Chendi Xue --- vllm/spec_decode/medusa_worker.py | 17 +------- vllm/spec_decode/multi_step_worker.py | 20 ++-------- vllm/spec_decode/ngram_worker.py | 19 +-------- vllm/spec_decode/selector.py | 53 +++++++++++++++++++++++++ vllm/spec_decode/target_model_runner.py | 33 +-------------- 5 files changed, 59 insertions(+), 83 deletions(-) create mode 100644 vllm/spec_decode/selector.py diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index 550d152c15e36..584dd85c237fe 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -5,27 +5,12 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.spec_decode.selector import WorkerCls from vllm.spec_decode.top1_proposer import Top1Proposer -if current_platform.is_neuron(): - from vllm.worker.neuron_worker import NeuronWorker as WorkerCls -elif current_platform.is_hpu(): - from vllm.worker.hpu_worker import HPUWorker as WorkerCls -elif current_platform.is_openvino(): - from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls -elif current_platform.is_cpu(): - from vllm.worker.cpu_worker import CPUWorker as WorkerCls -elif current_platform.is_tpu(): - from vllm.worker.tpu_worker import TPUWorker as WorkerCls -elif current_platform.is_xpu(): - from vllm.worker.xpu_worker import XPUWorker as WorkerCls -else: - from vllm.worker.worker import Worker as WorkerCls - class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls): """Worker for Medusa. 
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 306f773d24d1e..d3e2b2da69862 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -15,25 +15,11 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.selector import WorkerCls from vllm.spec_decode.top1_proposer import Top1Proposer -if current_platform.is_neuron(): - from vllm.worker.neuron_worker import NeuronWorker as WorkerBaseCls -elif current_platform.is_hpu(): - from vllm.worker.hpu_worker import HPUWorker as WorkerBaseCls -elif current_platform.is_openvino(): - from vllm.worker.openvino_worker import OpenVINOWorker as WorkerBaseCls -elif current_platform.is_cpu(): - from vllm.worker.cpu_worker import CPUWorker as WorkerBaseCls -elif current_platform.is_tpu(): - from vllm.worker.tpu_worker import TPUWorker as WorkerBaseCls -elif current_platform.is_xpu(): - from vllm.worker.xpu_worker import XPUWorker as WorkerBaseCls -else: - from vllm.worker.worker import Worker as WorkerBaseCls - - -class MultiStepWorker(WorkerBaseCls, ProposerWorkerBase): + +class MultiStepWorker(WorkerCls, ProposerWorkerBase): """The MultiStepWorker is equivalent to a Worker except that it allows multiple forward passes in a single call, assuming the scheduler has allocated enough space to store the additional KV. This reduces overhead diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index c759551ad1246..441894c0ca070 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -4,29 +4,12 @@ import torch from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.spec_decode.selector import DEVICE_TYPE from vllm.spec_decode.top1_proposer import Top1Proposer -if current_platform.is_cuda_alike(): - DEVICE_TYPE = "cuda" -elif current_platform.is_neuron(): - DEVICE_TYPE = "neuron" -elif current_platform.is_hpu(): - DEVICE_TYPE = "hpu" -elif current_platform.is_openvino(): - DEVICE_TYPE = "openvino" -elif current_platform.is_cpu(): - DEVICE_TYPE = "cpu" -elif current_platform.is_tpu(): - DEVICE_TYPE = "tpu" -elif current_platform.is_xpu(): - DEVICE_TYPE = "xpu" -else: - raise ValueError(f"Unsupported platform: {current_platform}") - class NGramWorker(NonLLMProposerWorkerBase): """NGramWorker provides a light drafter without need for model. 
diff --git a/vllm/spec_decode/selector.py b/vllm/spec_decode/selector.py new file mode 100644 index 0000000000000..5a24557c1daa2 --- /dev/null +++ b/vllm/spec_decode/selector.py @@ -0,0 +1,53 @@ +from vllm.platforms import current_platform + +if current_platform.is_neuron(): + from vllm.worker.neuron_model_runner import ( # noqa: F401 + ModelInputForNeuron as ModelInputCls) + from vllm.worker.neuron_model_runner import ( # noqa: F401 + NeuronModelRunner as ModelRunnerCls) + from vllm.worker.neuron_worker import ( # noqa: F401 + NeuronWorker as WorkerCls) + DEVICE_TYPE = "neuron" +elif current_platform.is_hpu(): + from vllm.worker.hpu_model_runner import ( # noqa: F401 + HPUModelRunner as ModelRunnerCls) + from vllm.worker.hpu_model_runner import ( # noqa: F401 + ModelInputForHPUWithSamplingMetadata as ModelInputCls) + from vllm.worker.hpu_worker import HPUWorker as WorkerCls # noqa: F401 + DEVICE_TYPE = "hpu" +elif current_platform.is_openvino(): + from vllm.worker.openvino_model_runner import ( # noqa: F401 + ModelInput as ModelInputCls) + from vllm.worker.openvino_model_runner import ( # noqa: F401 + OpenVINOModelRunner as ModelRunnerCls) + from vllm.worker.openvino_worker import ( # noqa: F401 + OpenVINOWorker as WorkerCls) + DEVICE_TYPE = "openvino" +elif current_platform.is_cpu(): + from vllm.worker.cpu_model_runner import ( # noqa: F401 + CPUModelRunner as ModelRunnerCls) + from vllm.worker.cpu_model_runner import ( # noqa: F401 + ModelInputForCPUWithSamplingMetadata as ModelInputCls) + from vllm.worker.cpu_worker import CPUWorker as WorkerCls # noqa: F401 + DEVICE_TYPE = "cpu" +elif current_platform.is_tpu(): + from vllm.worker.tpu_model_runner import ( # noqa: F401 + ModelInputForTPU as ModelInputCls) + from vllm.worker.tpu_model_runner import ( # noqa: F401 + TPUModelRunner as ModelRunnerCls) + from vllm.worker.tpu_worker import TPUWorker as WorkerCls # noqa: F401 + DEVICE_TYPE = "tpu" +elif current_platform.is_xpu(): + from vllm.worker.xpu_model_runner import ( # noqa: F401 + ModelInputForXPUWithSamplingMetadata as ModelInputCls) + from vllm.worker.xpu_model_runner import ( # noqa: F401 + XPUModelRunner as ModelRunnerCls) + from vllm.worker.xpu_worker import XPUWorker as WorkerCls # noqa: F401 + DEVICE_TYPE = "xpu" +else: + from vllm.worker.model_runner import ( # noqa: F401 + ModelInputForGPUWithSamplingMetadata as ModelInputCls) + from vllm.worker.model_runner import ( # noqa: F401 + ModelRunner as ModelRunnerCls) + from vllm.worker.worker import Worker as WorkerCls # noqa: F401 + DEVICE_TYPE = "cuda" diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py index f1c87c7bfda3c..4b16f5eda4402 100644 --- a/vllm/spec_decode/target_model_runner.py +++ b/vllm/spec_decode/target_model_runner.py @@ -1,39 +1,8 @@ from typing import List, Optional from vllm.config import VllmConfig -from vllm.platforms import current_platform from vllm.sequence import SequenceGroupMetadata - -if current_platform.is_cuda_alike(): - from vllm.worker.model_runner import ( - ModelInputForGPUWithSamplingMetadata as ModelInputCls) # yapf: disable - from vllm.worker.model_runner import ModelRunner as ModelRunnerCls -elif current_platform.is_neuron(): - from vllm.worker.neuron_model_runner import ( - ModelInputForNeuron as ModelInputCls) # yapf: disable - from vllm.worker.neuron_model_runner import ( - NeuronModelRunner as ModelRunnerCls) # yapf: disable -elif current_platform.is_hpu(): - from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunnerCls - from 
vllm.worker.hpu_model_runner import ( - ModelInputForHPUWithSamplingMetadata as ModelInputCls) # yapf: disable -elif current_platform.is_openvino(): - from vllm.worker.openvino_model_runner import ModelInput as ModelInputCls - from vllm.worker.openvino_model_runner import ( - OpenVINOModelRunner as ModelRunnerCls) # yapf: disable -elif current_platform.is_cpu(): - from vllm.worker.cpu_model_runner import CPUModelRunner as ModelRunnerCls - from vllm.worker.cpu_model_runner import ( - ModelInputForCPUWithSamplingMetadata as ModelInputCls) # yapf: disable -elif current_platform.is_tpu(): - from vllm.worker.tpu_model_runner import ModelInputForTPU as ModelInputCls - from vllm.worker.tpu_model_runner import TPUModelRunner as ModelRunnerCls -elif current_platform.is_xpu(): - from vllm.worker.xpu_model_runner import ( - ModelInputForXPUWithSamplingMetadata as ModelInputCls) # yapf: disable - from vllm.worker.xpu_model_runner import XPUModelRunner as ModelRunnerCls -else: - raise ValueError(f"Unsupported platform: {current_platform}") +from vllm.spec_decode.selector import ModelInputCls, ModelRunnerCls class TargetModelRunner(ModelRunnerCls): From d6e2b0500779723298025f37d429cac3fee0e7af Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 21 Nov 2024 23:52:36 -0500 Subject: [PATCH 8/9] use selector to define workerCls Signed-off-by: Chendi Xue --- vllm/spec_decode/spec_decode_worker.py | 31 +++++--------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 36ea1310dda3b..47054d5fa7937 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -24,7 +24,9 @@ if current_platform.is_cuda_alike(): from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner elif current_platform.is_cpu(): - from vllm.spec_decode.cpu_draft_model_runner import CPUTP1DraftModelRunner + from vllm.spec_decode.cpu_draft_model_runner import (CPUTP1DraftModelRunner + as + TP1DraftModelRunner) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) @@ -35,6 +37,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.selector import WorkerCls from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner from vllm.spec_decode.util import (Timer, create_logprobs_output, @@ -44,21 +47,6 @@ split_batch_by_proposal_len) from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase -if current_platform.is_neuron(): - from vllm.worker.neuron_worker import NeuronWorker as WorkerCls -elif current_platform.is_hpu(): - from vllm.worker.hpu_worker import HPUWorker as WorkerCls -elif current_platform.is_openvino(): - from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls -elif current_platform.is_cpu(): - from vllm.worker.cpu_worker import CPUWorker as WorkerCls -elif current_platform.is_tpu(): - from vllm.worker.tpu_worker import TPUWorker as WorkerCls -elif current_platform.is_xpu(): - from vllm.worker.xpu_worker import XPUWorker as WorkerCls -else: - from vllm.worker.worker import Worker as WorkerCls - logger = init_logger(__name__) @@ -178,15 +166,8 @@ def create_worker( proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: - if 
current_platform.is_cuda_alike(): - draft_worker_kwargs[ - "model_runner_cls"] = TP1DraftModelRunner - elif current_platform.is_cpu(): - draft_worker_kwargs[ - "model_runner_cls"] = CPUTP1DraftModelRunner - else: - raise NotImplementedError( - "current platform does not support EAGLE.") + draft_worker_kwargs[ + "model_runner_cls"] = TP1DraftModelRunner else: if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( From 73916b310b778fddd8b58bdbbb90a8a552fe2c2f Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 22 Nov 2024 00:19:17 -0500 Subject: [PATCH 9/9] Remove inflight DEVICE_TYPE in selector Signed-off-by: Chendi Xue --- vllm/spec_decode/ngram_worker.py | 4 ++-- vllm/spec_decode/selector.py | 7 ------- vllm/spec_decode/spec_decode_worker.py | 2 ++ 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 441894c0ca070..f6133f7bf7b20 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -7,7 +7,6 @@ from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase -from vllm.spec_decode.selector import DEVICE_TYPE from vllm.spec_decode.top1_proposer import Top1Proposer @@ -23,6 +22,7 @@ def __init__(self, *args, **kwargs): # Get local_rank/vocab_size from kwargs attribute self.local_rank = kwargs["local_rank"] self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size() + self.device_type = kwargs["device_type"] # Lazy initialization list. self._proposer: Top1Proposer @@ -35,7 +35,7 @@ def set_ngram_window_size(self, ngram_prompt_lookup_min: int, self.ngram_prompt_lookup_min = ngram_prompt_lookup_min def init_device(self): - self.device = torch.device(f"{DEVICE_TYPE}:{self.local_rank}") + self.device = torch.device(f"{self.device_type}:{self.local_rank}") self.load_model = lambda *args, **kwargs: None # Current NGramWorker only supports Top1Proposer diff --git a/vllm/spec_decode/selector.py b/vllm/spec_decode/selector.py index 5a24557c1daa2..ae829ea3c00da 100644 --- a/vllm/spec_decode/selector.py +++ b/vllm/spec_decode/selector.py @@ -7,14 +7,12 @@ NeuronModelRunner as ModelRunnerCls) from vllm.worker.neuron_worker import ( # noqa: F401 NeuronWorker as WorkerCls) - DEVICE_TYPE = "neuron" elif current_platform.is_hpu(): from vllm.worker.hpu_model_runner import ( # noqa: F401 HPUModelRunner as ModelRunnerCls) from vllm.worker.hpu_model_runner import ( # noqa: F401 ModelInputForHPUWithSamplingMetadata as ModelInputCls) from vllm.worker.hpu_worker import HPUWorker as WorkerCls # noqa: F401 - DEVICE_TYPE = "hpu" elif current_platform.is_openvino(): from vllm.worker.openvino_model_runner import ( # noqa: F401 ModelInput as ModelInputCls) @@ -22,32 +20,27 @@ OpenVINOModelRunner as ModelRunnerCls) from vllm.worker.openvino_worker import ( # noqa: F401 OpenVINOWorker as WorkerCls) - DEVICE_TYPE = "openvino" elif current_platform.is_cpu(): from vllm.worker.cpu_model_runner import ( # noqa: F401 CPUModelRunner as ModelRunnerCls) from vllm.worker.cpu_model_runner import ( # noqa: F401 ModelInputForCPUWithSamplingMetadata as ModelInputCls) from vllm.worker.cpu_worker import CPUWorker as WorkerCls # noqa: F401 - DEVICE_TYPE = "cpu" elif current_platform.is_tpu(): from vllm.worker.tpu_model_runner import ( # noqa: F401 ModelInputForTPU as ModelInputCls) from vllm.worker.tpu_model_runner import ( # noqa: F401 TPUModelRunner as ModelRunnerCls) 
from vllm.worker.tpu_worker import TPUWorker as WorkerCls # noqa: F401 - DEVICE_TYPE = "tpu" elif current_platform.is_xpu(): from vllm.worker.xpu_model_runner import ( # noqa: F401 ModelInputForXPUWithSamplingMetadata as ModelInputCls) from vllm.worker.xpu_model_runner import ( # noqa: F401 XPUModelRunner as ModelRunnerCls) from vllm.worker.xpu_worker import XPUWorker as WorkerCls # noqa: F401 - DEVICE_TYPE = "xpu" else: from vllm.worker.model_runner import ( # noqa: F401 ModelInputForGPUWithSamplingMetadata as ModelInputCls) from vllm.worker.model_runner import ( # noqa: F401 ModelRunner as ModelRunnerCls) from vllm.worker.worker import Worker as WorkerCls # noqa: F401 - DEVICE_TYPE = "cuda" diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 47054d5fa7937..678a843fc263c 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -153,6 +153,8 @@ def create_worker( draft_parallel_config: ParallelConfig = draft_worker_kwargs[ 'vllm_config'].parallel_config if ngram_prompt_lookup_max > 0: + draft_worker_kwargs[ + "device_type"] = scorer_worker.device_config.device.type proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max)
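
Note (not part of the patches): a minimal sketch of how the CPU speculative-decoding path enabled by this series would be exercised from user code, using the ngram drafter (the NGramWorker path touched above) so no separate draft model is required. It assumes a vLLM build with the CPU backend, so current_platform.is_cpu() holds and the CPUWorker/CPUModelRunner classes from selector.py are picked up; the model name, prompt, and the engine keyword arguments (speculative_model="[ngram]", num_speculative_tokens, ngram_prompt_lookup_max) follow this version's speculative-decoding examples and may differ in other releases.

    # Hedged sketch: offline ngram speculative decoding on a CPU-only vLLM build.
    # Assumptions: vLLM compiled for CPU (CPUWorker/CPUModelRunner selected), and
    # the keyword arguments below match this version's spec-decode engine options.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-6.7b",       # target model (placeholder choice)
        speculative_model="[ngram]",     # NGramWorker drafter; no draft model needed
        num_speculative_tokens=5,        # k tokens proposed per speculative step
        ngram_prompt_lookup_max=4,       # max n-gram window for prompt lookup
    )

    outputs = llm.generate(["The future of AI is"],
                           SamplingParams(temperature=0.0, max_tokens=64))
    for out in outputs:
        print(out.outputs[0].text)

Using a small draft model instead of "[ngram]" would instead route draft execution through the new CPUTP1DraftModelRunner added in patch 4, via the model_runner_cls plumbing in CPUWorker.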