-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core] clean up executor class hierarchy between v1 and v0 #12171
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,75 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Type | ||
|
||
from vllm.config import VllmConfig | ||
from vllm.executor.executor_base import ExecutorBase | ||
from vllm.executor.ray_distributed_executor import RayDistributedExecutor | ||
from vllm.executor.uniproc_executor import (ExecutorWithExternalLauncher, | ||
UniProcExecutor) | ||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec | ||
from vllm.v1.outputs import ModelRunnerOutput | ||
|
||
|
||
class Executor(ABC): | ||
"""Abstract class for executors.""" | ||
class Executor(ExecutorBase): | ||
"""Abstract class for v1 executors, mainly define some new methods.""" | ||
|
||
@staticmethod | ||
def get_class(vllm_config: VllmConfig) -> Type["Executor"]: | ||
executor_class: Type[Executor] | ||
distributed_executor_backend = ( | ||
vllm_config.parallel_config.distributed_executor_backend) | ||
if distributed_executor_backend == "ray": | ||
from vllm.executor.ray_distributed_executor import ( # noqa | ||
RayDistributedExecutor) | ||
executor_class = RayDistributedExecutor | ||
executor_class = RayDistributedExecutorV1 | ||
elif distributed_executor_backend == "mp": | ||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor | ||
executor_class = MultiprocExecutor | ||
else: | ||
assert (distributed_executor_backend is None) | ||
from vllm.v1.executor.uniproc_executor import UniprocExecutor | ||
executor_class = UniprocExecutor | ||
elif distributed_executor_backend == "uni": | ||
executor_class = UniprocExecutorV1 | ||
elif distributed_executor_backend == "external_launcher": | ||
# TODO: make v1 scheduling deterministic | ||
# to support external launcher | ||
executor_class = ExecutorWithExternalLauncherV1 | ||
return executor_class | ||
|
||
@abstractmethod | ||
def __init__(self, vllm_config: VllmConfig) -> None: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def initialize(self, kv_cache_config: KVCacheConfig) -> None: | ||
raise NotImplementedError | ||
""" | ||
Initialize the KV caches and begin the model execution loop of the | ||
underlying workers. | ||
""" | ||
self.collective_rpc("initialize_cache", args=(kv_cache_config, )) | ||
self.collective_rpc("compile_or_warm_up_model") | ||
|
||
@abstractmethod | ||
def determine_available_memory(self) -> int: # in bytes | ||
raise NotImplementedError | ||
output = self.collective_rpc("determine_available_memory") | ||
# Since we use a shared centralized controller, we take the minimum | ||
# memory size across all workers to make sure all the memory | ||
# operators can be applied to all workers. | ||
return min(output) | ||
|
||
@abstractmethod | ||
def get_kv_cache_spec(self) -> KVCacheSpec: | ||
raise NotImplementedError | ||
output = self.collective_rpc("get_kv_cache_spec") | ||
for x in output: | ||
assert x == output[0] | ||
return output[0] | ||
|
||
@abstractmethod | ||
def execute_model( | ||
self, | ||
scheduler_output, | ||
) -> ModelRunnerOutput: | ||
raise NotImplementedError | ||
output = self.collective_rpc("execute_model", | ||
args=(scheduler_output, )) | ||
return output[0] | ||
|
||
@abstractmethod | ||
def profile(self, is_start: bool = True): | ||
raise NotImplementedError | ||
self.collective_rpc("profile", args=(is_start, )) | ||
|
||
|
||
class UniprocExecutorV1(UniProcExecutor, Executor): | ||
pass | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems that we don't have any xxxV1 classes in the v1 code now. Would this be better? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good, added in 21a4b0e |
||
|
||
|
||
class ExecutorWithExternalLauncherV1(ExecutorWithExternalLauncher, Executor): | ||
pass | ||
|
||
@abstractmethod | ||
def shutdown(self): | ||
pass | ||
|
||
@abstractmethod | ||
def check_health(self) -> None: | ||
raise NotImplementedError | ||
class RayDistributedExecutorV1(RayDistributedExecutor, Executor): | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A clearer comment: "Abstract class for v1 executors; it mainly defines methods specific to v1. Methods shared by v0 and v1 should be defined in ExecutorBase." Also, how would we add an interface that applies only to v0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed in dc2a886
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't introduce new features that only work for v0.