Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serve] Refactor replica wrapper #49806

Merged
merged 10 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions python/ray/serve/_private/default_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@
)
from ray.serve._private.grpc_util import gRPCServer
from ray.serve._private.handle_options import DynamicHandleOptions, InitHandleOptions
from ray.serve._private.replica_scheduler import (
ActorReplicaWrapper,
PowerOfTwoChoicesReplicaScheduler,
)
from ray.serve._private.replica_scheduler import PowerOfTwoChoicesReplicaScheduler
from ray.serve._private.replica_scheduler.replica_wrapper import RunningReplica
from ray.serve._private.router import Router, SingletonThreadRouter
from ray.serve._private.utils import (
generate_request_id,
Expand Down Expand Up @@ -168,7 +166,7 @@ def create_router(
use_replica_queue_len_cache=(
not is_inside_ray_client_context and RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE
),
create_replica_wrapper_func=lambda r: ActorReplicaWrapper(r),
create_replica_wrapper_func=lambda r: RunningReplica(r),
)

return SingletonThreadRouter(
Expand Down
12 changes: 7 additions & 5 deletions python/ray/serve/_private/handle_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@ class DynamicHandleOptionsBase(ABC):
multiplexed_model_id: str = ""
stream: bool = False

@abstractmethod
def copy_and_update(self, **kwargs) -> "DynamicHandleOptionsBase":
pass


@dataclass(frozen=True)
class DynamicHandleOptions(DynamicHandleOptionsBase):
def copy_and_update(self, **kwargs) -> "DynamicHandleOptions":
new_kwargs = {}

for f in fields(self):
Expand All @@ -63,8 +70,3 @@ def copy_and_update(self, **kwargs) -> "DynamicHandleOptionsBase":
new_kwargs[f.name] = kwargs[f.name]

return DynamicHandleOptions(**new_kwargs)


@dataclass(frozen=True)
class DynamicHandleOptions(DynamicHandleOptionsBase):
pass
8 changes: 4 additions & 4 deletions python/ray/serve/_private/replica_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Callable, Coroutine, Optional, Union

import ray
from ray.serve._private.common import RequestMetadata
from ray.serve._private.utils import calculate_remaining_timeout
from ray.serve.exceptions import RequestCancelledError

Expand Down Expand Up @@ -54,13 +55,12 @@ class ActorReplicaResult(ReplicaResult):
def __init__(
self,
obj_ref_or_gen: Union[ray.ObjectRef, ray.ObjectRefGenerator],
is_streaming: bool,
request_id: str,
metadata: RequestMetadata,
):
self._obj_ref: Optional[ray.ObjectRef] = None
self._obj_ref_gen: Optional[ray.ObjectRefGenerator] = None
self._is_streaming: bool = is_streaming
self._request_id: str = request_id
self._is_streaming: bool = metadata.is_streaming
self._request_id: str = metadata.request_id
self._object_ref_or_gen_sync_lock = threading.Lock()

if isinstance(obj_ref_or_gen, ray.ObjectRefGenerator):
Expand Down
28 changes: 14 additions & 14 deletions python/ray/serve/_private/replica_scheduler/pow_2_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
ReplicaQueueLengthCache,
)
from ray.serve._private.replica_scheduler.replica_scheduler import ReplicaScheduler
from ray.serve._private.replica_scheduler.replica_wrapper import ReplicaWrapper
from ray.serve._private.replica_scheduler.replica_wrapper import RunningReplica
from ray.util import metrics

logger = logging.getLogger(SERVE_LOGGER_NAME)
Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(
use_replica_queue_len_cache: bool = False,
get_curr_time_s: Optional[Callable[[], float]] = None,
create_replica_wrapper_func: Optional[
Callable[[RunningReplicaInfo], ReplicaWrapper]
Callable[[RunningReplicaInfo], RunningReplica]
] = None,
):
self._deployment_id = deployment_id
Expand All @@ -117,7 +117,7 @@ def __init__(
# Current replicas available to be scheduled.
# Updated via `update_replicas`.
self._replica_id_set: Set[ReplicaID] = set()
self._replicas: Dict[ReplicaID, ReplicaWrapper] = {}
self._replicas: Dict[ReplicaID, RunningReplica] = {}
self._replica_queue_len_cache = ReplicaQueueLengthCache(
get_curr_time_s=get_curr_time_s,
)
Expand Down Expand Up @@ -236,7 +236,7 @@ def target_num_scheduling_tasks(self) -> int:
return min(self.num_pending_requests, self.max_num_scheduling_tasks)

@property
def curr_replicas(self) -> Dict[ReplicaID, ReplicaWrapper]:
def curr_replicas(self) -> Dict[ReplicaID, RunningReplica]:
return self._replicas

@property
Expand All @@ -249,7 +249,7 @@ def replica_queue_len_cache(self) -> ReplicaQueueLengthCache:

def create_replica_wrapper(
self, replica_info: RunningReplicaInfo
) -> ReplicaWrapper:
) -> RunningReplica:
return self._create_replica_wrapper_func(replica_info)

def on_replica_actor_died(self, replica_id: ReplicaID):
Expand All @@ -272,7 +272,7 @@ def on_new_queue_len_info(
replica_id, queue_len_info.num_ongoing_requests
)

def update_replicas(self, replicas: List[ReplicaWrapper]):
def update_replicas(self, replicas: List[RunningReplica]):
"""Update the set of available replicas to be considered for scheduling.

When the set of replicas changes, we may spawn additional scheduling tasks
Expand Down Expand Up @@ -515,9 +515,9 @@ async def choose_two_replicas_with_backoff(

async def _probe_queue_lens(
self,
replicas: List[ReplicaWrapper],
replicas: List[RunningReplica],
backoff_index: int,
) -> List[Tuple[ReplicaWrapper, Optional[int]]]:
) -> List[Tuple[RunningReplica, Optional[int]]]:
"""Actively probe the queue length from each of the replicas.

Sends an RPC to each replica to fetch its queue length, with a response deadline
Expand All @@ -531,7 +531,7 @@ async def _probe_queue_lens(
This method also updates the local cache of replica queue lengths according to
the responses.
"""
result: List[Tuple[ReplicaWrapper, int]] = []
result: List[Tuple[RunningReplica, int]] = []
if len(replicas) == 0:
return result

Expand Down Expand Up @@ -618,9 +618,9 @@ async def _probe_queue_lens(

async def select_from_candidate_replicas(
self,
candidates: List[ReplicaWrapper],
candidates: List[RunningReplica],
backoff_index: int,
) -> Optional[ReplicaWrapper]:
) -> Optional[RunningReplica]:
"""Chooses the best replica from the list of candidates.

If none of the replicas can be scheduled, returns `None`.
Expand All @@ -633,7 +633,7 @@ async def select_from_candidate_replicas(
"""
lowest_queue_len = math.inf
chosen_replica_id: Optional[str] = None
not_in_cache: List[ReplicaWrapper] = []
not_in_cache: List[RunningReplica] = []
if self._use_replica_queue_len_cache:
# Populate available queue lens from the cache.
for r in candidates:
Expand Down Expand Up @@ -693,7 +693,7 @@ def _get_pending_request_matching_metadata(

def fulfill_next_pending_request(
self,
replica: ReplicaWrapper,
replica: RunningReplica,
request_metadata: Optional[RequestMetadata] = None,
):
"""Assign the replica to the next pending request in FIFO order.
Expand Down Expand Up @@ -812,7 +812,7 @@ def maybe_start_scheduling_tasks(self):

async def choose_replica_for_request(
self, pending_request: PendingRequest, *, is_retry: bool = False
) -> ReplicaWrapper:
) -> RunningReplica:
"""Chooses a replica to send the provided request to.

By default, requests are scheduled in FIFO order, so this places a future on the
Expand Down
10 changes: 5 additions & 5 deletions python/ray/serve/_private/replica_scheduler/replica_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
PendingRequest,
ReplicaQueueLengthCache,
)
from ray.serve._private.replica_scheduler.replica_wrapper import ReplicaWrapper
from ray.serve._private.replica_scheduler.replica_wrapper import RunningReplica


class ReplicaScheduler(ABC):
Expand All @@ -15,17 +15,17 @@ class ReplicaScheduler(ABC):
@abstractmethod
async def choose_replica_for_request(
self, pending_request: PendingRequest, *, is_retry: bool = False
) -> ReplicaWrapper:
) -> RunningReplica:
pass

@abstractmethod
def create_replica_wrapper(
self, replica_info: RunningReplicaInfo
) -> ReplicaWrapper:
) -> RunningReplica:
pass

@abstractmethod
def update_replicas(self, replicas: List[ReplicaWrapper]):
def update_replicas(self, replicas: List[RunningReplica]):
pass

def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
Expand All @@ -49,5 +49,5 @@ def replica_queue_len_cache(self) -> ReplicaQueueLengthCache:

@property
@abstractmethod
def curr_replicas(self) -> Dict[ReplicaID, ReplicaWrapper]:
def curr_replicas(self) -> Dict[ReplicaID, RunningReplica]:
pass
Loading
Loading