Remove Dynamic WorkerCls, use WorkerWrapperBase instead
Signed-off-by: Chendi Xue <[email protected]>
xuechendi committed Nov 23, 2024
1 parent d0c4d49 commit 9b49363
Showing 12 changed files with 81 additions and 142 deletions.
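All of the diffs below move in the same direction: rather than selecting a platform-specific worker or model-runner class dynamically, the speculative-decoding code wraps a concrete instance and delegates to it. A minimal sketch of that delegation pattern, with illustrative names (not vLLM's actual classes):

class DelegatingWrapper:
    """Holds a concrete worker and forwards anything it does not define."""

    def __init__(self, worker):
        self.worker = worker

    def __getattr__(self, attr):
        # Invoked only when normal lookup fails, so wrapper-level
        # attributes and methods always take precedence.
        return getattr(self.worker, attr)


class GpuWorker:
    """Illustrative concrete worker."""

    def determine_num_available_blocks(self):
        return (1024, 256)


wrapped = DelegatingWrapper(GpuWorker())
print(wrapped.determine_num_available_blocks())  # (1024, 256), via delegation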
1 change: 1 addition & 0 deletions vllm/config.py
@@ -966,6 +966,7 @@ class ParallelConfig:
# the full name of the worker class to use. If "auto", the worker class
# will be determined based on the platform.
worker_cls: str = "auto"
+actual_worker_cls: str = "auto"

world_size: int = field(init=False)

2 changes: 2 additions & 0 deletions vllm/platforms/cuda.py
@@ -175,5 +175,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
+parallel_config.actual_worker_cls = \
+    "vllm.worker.worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
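With this change, worker_cls keeps pointing at the speculative-decoding entry point while actual_worker_cls records the concrete worker class that the wrapper should instantiate. A hedged sketch of how such dotted class paths can be kept in config fields and resolved (the resolver helper below is an assumption for illustration, not vLLM's loader):

import importlib
from dataclasses import dataclass


@dataclass
class ParallelConfigSketch:
    # Mirrors the two fields touched in the diff above.
    worker_cls: str = "auto"
    actual_worker_cls: str = "auto"


def resolve(path: str):
    """Import 'package.module.attr' and return the attribute (class or function)."""
    module_name, _, attr_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr_name)


cfg = ParallelConfigSketch(
    worker_cls="vllm.spec_decode.spec_decode_worker.create_spec_worker",
    actual_worker_cls="vllm.worker.worker.Worker",
)
# resolve(cfg.worker_cls) would return the spec-decode factory and
# resolve(cfg.actual_worker_cls) the plain Worker class, assuming vLLM is
# installed; resolve("collections.OrderedDict") works anywhere.
print(resolve("collections.OrderedDict"))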
48 changes: 0 additions & 48 deletions vllm/spec_decode/cpu_draft_model_runner.py

This file was deleted.

24 changes: 12 additions & 12 deletions vllm/spec_decode/draft_model_runner.py
@@ -20,8 +20,9 @@
from vllm.logger import init_logger
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
-ModelRunner)
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+ModelRunnerInputBase,
+ModelRunnerWrapperBase)

logger = init_logger(__name__)

@@ -33,7 +34,7 @@
allow_gpu_advance_step = True


-class TP1DraftModelRunner(ModelRunner):
+class TP1DraftModelRunner(ModelRunnerWrapperBase):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always execute k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
@@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner):
any broadcasting inside execute_model).
"""

-def __init__(self, *args, **kwargs):
-if kwargs.get("return_hidden_states"):
+def __init__(self, model_runner: ModelRunnerBase):
+if hasattr(
+model_runner,
+"return_hidden_states") and model_runner.return_hidden_states:
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)

-super().__init__(*args, **kwargs)
+super().__init__(model_runner)

self.indices_of_seq_with_bonus_tokens = None

@@ -73,10 +75,8 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple

-def _gpu_advance_step(
-self, model_input: ModelInputForGPUWithSamplingMetadata,
-last_output: SamplerOutput
-) -> ModelInputForGPUWithSamplingMetadata:
+def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
+last_output: SamplerOutput) -> ModelRunnerInputBase:
# Currently, we expect "decode mode" only
assert not model_input.is_prompt

@@ -168,7 +168,7 @@ def set_indices_of_seq_with_bonus_tokens(self,
@torch.inference_mode()
def execute_model(
self,
-model_input: ModelInputForGPUWithSamplingMetadata,
+model_input: ModelRunnerInputBase,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
8 changes: 5 additions & 3 deletions vllm/spec_decode/medusa_worker.py
@@ -9,20 +9,22 @@
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
+from vllm.worker.worker_base import WorkerWrapperBase


-class MedusaWorker(NonLLMProposerWorkerBase):
+class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase):
"""Worker for Medusa.
"""

def __init__(self, *args, **kwargs):
-super().__init__(*args, **kwargs)
+super().__init__(kwargs.get("vllm_config"))
+self.init_worker(*args, **kwargs)

# Lazy initialization list.
self._proposer: Top1Proposer

def init_device(self):
-super().init_device()
+self.worker.init_device()

self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
22 changes: 18 additions & 4 deletions vllm/spec_decode/multi_step_worker.py
@@ -16,9 +16,10 @@
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
+from vllm.worker.worker_base import WorkerWrapperBase


-class MultiStepWorker(ProposerWorkerBase):
+class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
@@ -31,13 +32,14 @@ class MultiStepWorker(ProposerWorkerBase):
"""

def __init__(self, *args, **kwargs):
-super().__init__(*args, **kwargs)
+super().__init__(kwargs.get("vllm_config"))
+self.init_worker(*args, **kwargs)

# Lazy initialization list.
self._proposer: SpeculativeProposer

def init_device(self) -> None:
-super().init_device()
+self.worker.init_device()

self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
@@ -54,6 +56,18 @@ def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
True)

+def determine_num_available_blocks(self) -> Tuple[int, int]:
+return self.worker.determine_num_available_blocks()
+
+def get_cache_block_size_bytes(self) -> int:
+return self.worker.get_cache_block_size_bytes()
+
+def initialize_cache(self, *args, **kwargs) -> None:
+self.worker.initialize_cache(*args, **kwargs)
+
+def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
+return self.worker.execute_model(*args, **kwargs)

@torch.inference_mode()
def sampler_output(
self,
@@ -95,7 +109,7 @@ def sampler_output(
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
-model_output: List[SamplerOutput] = super().execute_model(
+model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
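MedusaWorker and MultiStepWorker above now inherit from both their proposer base and WorkerWrapperBase: the wrapper is initialized with only the vllm_config, the wrapped worker is built by init_worker(), and calls such as execute_model are forwarded to self.worker explicitly. A simplified, self-contained sketch of that construction order (all class names below are stand-ins, assuming a WorkerWrapperBase-like base that exposes init_worker()):

class ProposerBase:
    """Stand-in for the proposer-side interface (e.g. ProposerWorkerBase)."""


class WrapperBase:
    """Stand-in for WorkerWrapperBase: builds and owns the wrapped worker."""

    def __init__(self, vllm_config):
        self.vllm_config = vllm_config
        self.worker = None

    def init_worker(self, *args, **kwargs):
        # In vLLM this would instantiate the class named by the config;
        # here we just build a dummy worker.
        self.worker = DummyWorker(*args, **kwargs)


class DummyWorker:
    def __init__(self, *args, **kwargs):
        self.kwargs = kwargs

    def init_device(self):
        return "device ready"

    def execute_model(self, execute_model_req=None):
        return [f"output for {execute_model_req}"]


class MultiStepLikeWorker(ProposerBase, WrapperBase):
    def __init__(self, *args, **kwargs):
        # The wrapper gets only the config; the full argument set builds the worker.
        super().__init__(kwargs.get("vllm_config"))
        self.init_worker(*args, **kwargs)

    def init_device(self):
        return self.worker.init_device()

    def execute_model(self, *args, **kwargs):
        # Explicit forwarding keeps the call path obvious for readers.
        return self.worker.execute_model(*args, **kwargs)


w = MultiStepLikeWorker(vllm_config={"model": "draft"})
w.init_device()
print(w.execute_model(execute_model_req="req-1"))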
46 changes: 0 additions & 46 deletions vllm/spec_decode/selector.py

This file was deleted.

20 changes: 12 additions & 8 deletions vllm/spec_decode/spec_decode_worker.py
@@ -23,10 +23,6 @@

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
-elif current_platform.is_cpu():
-from vllm.spec_decode.cpu_draft_model_runner import (CPUTP1DraftModelRunner
-as
-TP1DraftModelRunner)

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
@@ -44,7 +40,8 @@
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase, WorkerWrapperBase
+from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase,
+WorkerWrapperBase)

logger = init_logger(__name__)

@@ -60,7 +57,11 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_kwargs = kwargs.copy()

kwargs["model_runner_cls"] = TargetModelRunner
-target_worker = WorkerWrapperBase(*args, **kwargs)
+target_worker_config = copy.deepcopy(vllm_config)
+target_worker_config.parallel_config.worker_cls =\
+target_worker_config.parallel_config.actual_worker_cls
+target_worker = WorkerWrapperBase(vllm_config=target_worker_config)
+target_worker.init_worker(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
@@ -72,6 +73,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_config.model_config,
vllm_config.load_config,
)
+speculative_config.draft_parallel_config.worker_cls =\
+draft_worker_config.parallel_config.actual_worker_cls
draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa
# TODO allow draft-model specific load config.

@@ -167,8 +170,9 @@ def create_worker(
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
-draft_worker_kwargs[
-"model_runner_cls"] = TP1DraftModelRunner
+if current_platform.is_cuda_alike():
+draft_worker_kwargs[
+"model_runner_cls"] = TP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
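In create_spec_worker above, the target worker now receives a deep copy of the config with actual_worker_cls swapped into worker_cls, so instantiating the wrapped worker cannot recurse back into the spec-decode entry point. A minimal, self-contained sketch of that substitution (the dataclass names are illustrative; the field names follow the diff):

import copy
from dataclasses import dataclass, field


@dataclass
class ParallelCfg:
    worker_cls: str = "auto"
    actual_worker_cls: str = "auto"


@dataclass
class VllmCfg:
    parallel_config: ParallelCfg = field(default_factory=ParallelCfg)


def build_target_config(vllm_config: VllmCfg) -> VllmCfg:
    # Deep copy so the caller's config still names the spec-decode entry point.
    target_cfg = copy.deepcopy(vllm_config)
    target_cfg.parallel_config.worker_cls = (
        target_cfg.parallel_config.actual_worker_cls)
    return target_cfg


cfg = VllmCfg(ParallelCfg(
    worker_cls="vllm.spec_decode.spec_decode_worker.create_spec_worker",
    actual_worker_cls="vllm.worker.worker.Worker",
))
print(build_target_config(cfg).parallel_config.worker_cls)
# -> vllm.worker.worker.Worker (the original cfg is left untouched)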
27 changes: 9 additions & 18 deletions vllm/spec_decode/target_model_runner.py
@@ -1,11 +1,12 @@
from typing import List, Optional

-from vllm.config import VllmConfig
from vllm.sequence import SequenceGroupMetadata
-from vllm.spec_decode.selector import ModelInputCls, ModelRunnerCls
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+ModelRunnerInputBase,
+ModelRunnerWrapperBase)


-class TargetModelRunner(ModelRunnerCls):
+class TargetModelRunner(ModelRunnerWrapperBase):
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities selected finally may not
be the same ones as selected by the target model sampling. This means
@@ -17,30 +18,20 @@ class TargetModelRunner(ModelRunnerCls):
requested or not.
"""

-def __init__(
-self,
-vllm_config: VllmConfig,
-kv_cache_dtype: Optional[str] = "auto",
-is_driver_worker: bool = False,
-return_hidden_states: bool = False,
-):
+def __init__(self, model_runner: ModelRunnerBase):
# An internal boolean member variable to indicate if token log
# probabilities are needed or not.
+super().__init__(model_runner)
self.disable_logprobs = True
-super().__init__(
-vllm_config=vllm_config,
-kv_cache_dtype=kv_cache_dtype,
-is_driver_worker=is_driver_worker,
-return_hidden_states=return_hidden_states,
-)

def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
-) -> ModelInputCls:
-model_input: ModelInputCls = super().prepare_model_input(
+) -> ModelRunnerInputBase:
+model_input: ModelRunnerInputBase =\
+self.model_runner.prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
15 changes: 15 additions & 0 deletions vllm/worker/model_runner_base.py
@@ -289,3 +289,18 @@ def get_generators(self, finished_request_ids: Optional[List[str]] = None):
self.generators.pop(request_id, None)

return self.generators


+class ModelRunnerWrapperBase:
+"""
+The whole point of this class is to lazily initialize the model_runner.
+"""
+
+def __init__(
+self,
+model_runner: ModelRunnerBase,
+) -> None:
+self.model_runner: ModelRunnerBase = model_runner
+
+def __getattr__(self, attr):
+return getattr(self.model_runner, attr)
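Because of the __getattr__ hook, a wrapper subclass only has to override what it changes; every other attribute falls through to the wrapped runner, which is how TargetModelRunner above can adjust prepare_model_input alone. A self-contained usage sketch (the classes below mirror the idea locally rather than importing vLLM):

class RunnerWrapper:
    """Local mirror of the ModelRunnerWrapperBase pattern shown above."""

    def __init__(self, model_runner) -> None:
        self.model_runner = model_runner

    def __getattr__(self, attr):
        return getattr(self.model_runner, attr)


class DummyRunner:
    vocab_size = 32000

    def prepare_model_input(self, seq_group_metadata_list):
        return {"seqs": seq_group_metadata_list, "skip_sampler_cpu_output": False}


class TargetLikeRunner(RunnerWrapper):
    """Overrides one method; everything else is delegated via __getattr__."""

    def __init__(self, model_runner) -> None:
        super().__init__(model_runner)
        self.disable_logprobs = True

    def prepare_model_input(self, seq_group_metadata_list):
        model_input = self.model_runner.prepare_model_input(seq_group_metadata_list)
        # Mirror the TargetModelRunner idea: skip sampler CPU output when
        # log probabilities are not needed.
        model_input["skip_sampler_cpu_output"] = self.disable_logprobs
        return model_input


runner = TargetLikeRunner(DummyRunner())
print(runner.vocab_size)                      # 32000, delegated attribute
print(runner.prepare_model_input(["seq-0"]))  # overridden behaviour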