[Kernel] Remove hard-dependencies of Speculative decode to CUDA workers #10587

Merged: 5 commits, Nov 27, 2024
4 changes: 2 additions & 2 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -595,8 +595,8 @@ def test_init_device(acceptance_sampler_method: str):

target_worker.init_device.assert_called_once()

metrics_collector.init_gpu_tensors.assert_called_once()
spec_decode_sampler.init_gpu_tensors.assert_called_once()
metrics_collector.init_tensors.assert_called_once()
spec_decode_sampler.init_tensors.assert_called_once()


@pytest.mark.parametrize("acceptance_sampler_method",
1 change: 1 addition & 0 deletions vllm/config.py
@@ -990,6 +990,7 @@ class ParallelConfig:
# the full name of the worker class to use. If "auto", the worker class
# will be determined based on the platform.
worker_cls: str = "auto"
sd_worker_cls: str = "auto"

world_size: int = field(init=False)

17 changes: 16 additions & 1 deletion vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -43,6 +43,21 @@ def init_gpu_tensors(self, device: Union[int, str]) -> None:
dtype=torch.long,
device=device)

def init_tensors(self,
device: Union[int, str],
device_type: Union[torch.device, str] = 'cuda') -> None:
assert self.num_accepted_tokens is None
if isinstance(device_type, torch.device):
device_type = device_type.type
if isinstance(device, int):
device = f"{device_type}:{device}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)

@property
def probs_dtype(self):
return torch.float32
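
The new `init_tensors` generalizes the CUDA-only `init_gpu_tensors`: an integer rank plus a device type is resolved to a concrete device string. A standalone sketch of just that resolution logic, mirroring the code above (not an import of the vLLM class):

```python
import torch

def to_device_str(device, device_type="cuda"):
    # Mirrors the resolution in init_tensors: a torch.device collapses to its
    # type, and a bare integer rank becomes "<type>:<rank>".
    if isinstance(device_type, torch.device):
        device_type = device_type.type
    if isinstance(device, int):
        device = f"{device_type}:{device}"
    return device

assert to_device_str(0) == "cuda:0"                      # old init_gpu_tensors behavior
assert to_device_str(0, torch.device("cpu")) == "cpu:0"  # now usable off-GPU as well
```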
@@ -77,7 +92,7 @@ def _create_output(
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze()
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
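
The `squeeze(-1)` change fixes a shape bug when the batch holds a single sequence: an unqualified `squeeze()` also drops the batch dimension, while `squeeze(-1)` removes only the trailing singleton. A minimal standalone illustration (not vLLM code):

```python
import torch

# bonus_token_ids arrives as [batch_size, 1]
bonus = torch.tensor([[7]])          # batch_size == 1

print(bonus.squeeze().shape)         # torch.Size([])  -> batch dim lost
print(bonus.squeeze(-1).shape)       # torch.Size([1]) -> batch dim kept
```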
8 changes: 7 additions & 1 deletion vllm/platforms/cpu.py
@@ -86,4 +86,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend = "mp"
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
if vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.cpu_worker.CPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
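
The CPU hook now mirrors the CUDA one: when a speculative config is present, `worker_cls` points at the spec-decode factory and the new `sd_worker_cls` names the platform worker it should wrap. Both are dotted paths resolved later, at worker construction time; a generic illustration of that kind of string-to-object resolution (a hypothetical `resolve_cls` helper, not vLLM's actual loader):

```python
import importlib

def resolve_cls(qualname: str):
    # Resolve a dotted "module.attr" path to the object it names.
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# The CPU hook above stores paths like these in the parallel config; they are
# only imported when a worker is actually constructed:
#   "vllm.spec_decode.spec_decode_worker.create_spec_worker"
#   "vllm.worker.cpu_worker.CPUWorker"
print(resolve_cls("collections.OrderedDict"))  # stdlib example: <class 'collections.OrderedDict'>
```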
4 changes: 3 additions & 1 deletion vllm/platforms/cuda.py
@@ -106,6 +106,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

@@ -236,4 +238,4 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
if not isinstance(pynvml, _MockModule):
CudaPlatform.log_warnings()
except ModuleNotFoundError:
CudaPlatform.log_warnings()
CudaPlatform.log_warnings()
24 changes: 12 additions & 12 deletions vllm/spec_decode/draft_model_runner.py
@@ -20,8 +20,9 @@
from vllm.logger import init_logger
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerWrapperBase)

logger = init_logger(__name__)

@@ -33,7 +34,7 @@
allow_gpu_advance_step = True


class TP1DraftModelRunner(ModelRunner):
class TP1DraftModelRunner(ModelRunnerWrapperBase):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always execute k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
@@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner):
any broadcasting inside execute_model).
"""

def __init__(self, *args, **kwargs):
if kwargs.get("return_hidden_states"):
def __init__(self, model_runner: ModelRunnerBase):
if hasattr(
model_runner,
"return_hidden_states") and model_runner.return_hidden_states:
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)

super().__init__(*args, **kwargs)
super().__init__(model_runner)

self.indices_of_seq_with_bonus_tokens = None

@@ -73,10 +75,8 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple

def _gpu_advance_step(
self, model_input: ModelInputForGPUWithSamplingMetadata,
last_output: SamplerOutput
) -> ModelInputForGPUWithSamplingMetadata:
def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
last_output: SamplerOutput) -> ModelRunnerInputBase:
# Currently, we expect "decode mode" only
assert not model_input.is_prompt

@@ -168,7 +168,7 @@ def set_indices_of_seq_with_bonus_tokens(self,
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
model_input: ModelRunnerInputBase,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
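
Rebasing `TP1DraftModelRunner` on `ModelRunnerWrapperBase` turns it from a subclass of the CUDA `ModelRunner` into a wrapper around whatever runner the platform provides. A minimal sketch of that wrapper pattern, assuming the base class simply stores the inner runner and forwards unknown attributes (hypothetical classes, not vLLM's implementation):

```python
class ModelRunnerWrapperSketch:
    """Hypothetical stand-in for a wrapper base such as ModelRunnerWrapperBase."""

    def __init__(self, model_runner):
        self.model_runner = model_runner

    def __getattr__(self, name):
        # Only reached when normal lookup fails, so everything not overridden
        # falls through to the wrapped, platform-specific runner.
        return getattr(self.model_runner, name)


class DraftRunnerSketch(ModelRunnerWrapperSketch):
    def __init__(self, model_runner):
        # Same guard as TP1DraftModelRunner above: hidden-state passthrough is
        # not supported by the draft runner.
        if getattr(model_runner, "return_hidden_states", False):
            raise ValueError("return_hidden_states is not supported here.")
        super().__init__(model_runner)
```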
8 changes: 5 additions & 3 deletions vllm/spec_decode/interfaces.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Set
from typing import Optional, Set, Union

import torch

@@ -75,9 +75,11 @@ def get_spec_proposals(

class SpeculativeScorer(ABC):

def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
def __init__(self, scorer_worker: WorkerBase,
device: Union[torch.device, str], vocab_size: int):
self._scorer_worker = scorer_worker
if isinstance(device, torch.device):
device = device.type
self._device = device
self._vocab_size = vocab_size

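
`SpeculativeScorer` now accepts either a `torch.device` or a plain string and keeps only the device type string; the normalization relies on `torch.device.type`, for example:

```python
import torch

assert torch.device("cuda:0").type == "cuda"
assert torch.device("cpu").type == "cpu"
```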
9 changes: 5 additions & 4 deletions vllm/spec_decode/medusa_worker.py
@@ -9,21 +9,22 @@
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
from vllm.worker.worker_base import WorkerWrapperBase


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase):
"""Worker for Medusa.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super().__init__(kwargs.get("vllm_config"))
self.init_worker(*args, **kwargs)

# Lazy initialization list.
self._proposer: Top1Proposer

def init_device(self):
super().init_device()
self.worker.init_device()

self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
15 changes: 14 additions & 1 deletion vllm/spec_decode/metrics.py
@@ -1,11 +1,12 @@
import time
from typing import Callable, Optional
from typing import Callable, Optional, Union

import msgspec
import torch

from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available


@@ -81,8 +82,20 @@ def init_gpu_tensors(self, rank: int) -> None:
self._rank = rank
self._copy_stream = torch.cuda.Stream()

def init_tensors(self,
rank: int,
device_type: Union[torch.device, str] = 'cuda') -> None:
self._rank = rank
if isinstance(device_type, torch.device):
device_type = device_type.type
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()

def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform
if not current_platform.is_cuda_alike():
return None

# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
31 changes: 24 additions & 7 deletions vllm/spec_decode/multi_step_worker.py
@@ -5,17 +5,21 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
Review comment (Member) on lines +12 to +13:

Do we have to put this check in the module import? It would be better if this was only lazy imported within sampler_output

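For reference, a rough sketch of the lazy-import alternative the reviewer describes, pushing the CUDA-only import into the call path that needs it (the `_wants_gpu_multi_step` helper here is hypothetical, not part of the PR):

```python
from vllm.platforms import current_platform


def _wants_gpu_multi_step(model_runner, expanded_request) -> bool:
    # Only touch the CUDA-only draft runner module when the platform can use
    # it; on CPU (or other backends) the import never happens.
    if not current_platform.is_cuda_alike():
        return False
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
    return (isinstance(model_runner, TP1DraftModelRunner)
            and model_runner.supports_gpu_multi_step(expanded_request))
```
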

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
from vllm.worker.worker_base import WorkerWrapperBase


class MultiStepWorker(Worker, ProposerWorkerBase):
class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
@@ -28,13 +32,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super().__init__(kwargs.get("vllm_config"))
self.init_worker(*args, **kwargs)

# Lazy initialization list.
self._proposer: SpeculativeProposer

def init_device(self) -> None:
super().init_device()
self.worker.init_device()

self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
@@ -51,6 +56,18 @@ def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
True)

def determine_num_available_blocks(self) -> Tuple[int, int]:
return self.worker.determine_num_available_blocks()

def get_cache_block_size_bytes(self) -> int:
return self.worker.get_cache_block_size_bytes()

def initialize_cache(self, *args, **kwargs) -> None:
self.worker.initialize_cache(*args, **kwargs)

def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
return self.worker.execute_model(*args, **kwargs)

@torch.inference_mode()
def sampler_output(
self,
@@ -75,7 +92,7 @@ def sampler_output(

# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if isinstance(
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
@@ -92,7 +109,7 @@ def sampler_output(
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
model_output: List[SamplerOutput] = super().execute_model(
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
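
When GPU multi-step is unavailable, the proposer falls back to running the wrapped worker once per speculative token. A condensed sketch of that loop, with request expansion and bonus-token bookkeeping omitted (hypothetical class, not the real MultiStepWorker):

```python
from typing import List


class DraftLoopSketch:
    """Illustrates the per-step fallback loop in MultiStepWorker.sampler_output."""

    def __init__(self, worker):
        self.worker = worker  # platform worker resolved via sd_worker_cls

    def propose(self, expanded_request, sample_len: int) -> List[object]:
        model_outputs: List[object] = []
        for _ in range(sample_len):
            # One decode step on the wrapped worker; the real code also appends
            # the newly sampled token to each sequence before the next step.
            step_output = self.worker.execute_model(
                execute_model_req=expanded_request)
            assert len(step_output) == 1, (
                "composing multistep workers not supported")
            model_outputs.append(step_output[0])
        return model_outputs
```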
3 changes: 2 additions & 1 deletion vllm/spec_decode/ngram_worker.py
@@ -22,6 +22,7 @@ def __init__(self, *args, **kwargs):
# Get local_rank/vocab_size from kwargs attribute
self.local_rank = kwargs["local_rank"]
self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
self.device_type = kwargs.get("device_type", "cuda")

# Lazy initialization list.
self._proposer: Top1Proposer
@@ -34,7 +35,7 @@ def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

def init_device(self):
self.device = torch.device(f"cuda:{self.local_rank}")
self.device = torch.device(f"{self.device_type}:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None

# Current NGramWorker only supports Top1Proposer