Skip to content

Commit

Permalink
[serve] add get request metadata func (ray-project#49571)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

Create request metadata for a new request through new function
`get_request_metadata`.

---------

Signed-off-by: Cindy Zhang <[email protected]>
Signed-off-by: lielin.hyl <[email protected]>
  • Loading branch information
zcin authored and HYLcool committed Jan 13, 2025
1 parent 912e11b commit 7546b34
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 28 deletions.
2 changes: 2 additions & 0 deletions python/ray/serve/_private/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,8 @@ class RequestMetadata:
# Serve's gRPC context associated with this request for getting and setting metadata
grpc_context: Optional[RayServegRPCContext] = None

_by_reference: bool = True

@property
def is_http_request(self) -> bool:
return self._request_protocol == RequestProtocol.HTTP
Expand Down
31 changes: 31 additions & 0 deletions python/ray/serve/_private/default_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
DeploymentHandleSource,
DeploymentID,
EndpointInfo,
RequestMetadata,
RequestProtocol,
)
from ray.serve._private.constants import (
RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE,
Expand All @@ -30,6 +32,7 @@
)
from ray.serve._private.router import Router, SingletonThreadRouter
from ray.serve._private.utils import (
generate_request_id,
get_current_actor_id,
get_head_node_id,
inside_ray_client_context,
Expand Down Expand Up @@ -91,6 +94,34 @@ def create_init_handle_options(**kwargs):
return InitHandleOptions.create(**kwargs)


def get_request_metadata(init_options, handle_options):
_request_context = ray.serve.context._serve_request_context.get()

request_protocol = RequestProtocol.UNDEFINED
if init_options and init_options._source == DeploymentHandleSource.PROXY:
if _request_context.is_http_request:
request_protocol = RequestProtocol.HTTP
elif _request_context.grpc_context:
request_protocol = RequestProtocol.GRPC

return RequestMetadata(
request_id=_request_context.request_id
if _request_context.request_id
else generate_request_id(),
internal_request_id=_request_context._internal_request_id
if _request_context._internal_request_id
else generate_request_id(),
call_method=handle_options.method_name,
route=_request_context.route,
app_name=_request_context.app_name,
multiplexed_model_id=handle_options.multiplexed_model_id,
is_streaming=handle_options.stream,
_request_protocol=request_protocol,
grpc_context=_request_context.grpc_context,
_by_reference=True,
)


def _get_node_id_and_az() -> Tuple[str, Optional[str]]:
node_id = ray.get_runtime_context().get_node_id()
try:
Expand Down
31 changes: 3 additions & 28 deletions python/ray/serve/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Tuple, Union

import ray
from ray import serve
from ray._raylet import ObjectRefGenerator
from ray.serve._private.common import (
DeploymentHandleSource,
DeploymentID,
RequestMetadata,
RequestProtocol,
)
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve._private.default_impl import (
Expand All @@ -30,7 +30,6 @@
from ray.serve._private.utils import (
DEFAULT,
calculate_remaining_timeout,
generate_request_id,
get_random_string,
inside_ray_client_context,
is_running_in_asyncio_loop,
Expand Down Expand Up @@ -726,32 +725,8 @@ def remote(
**kwargs: Keyword arguments to be serialized and passed to the
remote method call.
"""
_request_context = ray.serve.context._serve_request_context.get()

request_protocol = RequestProtocol.UNDEFINED
if (
self.init_options
and self.init_options._source == DeploymentHandleSource.PROXY
):
if _request_context.is_http_request:
request_protocol = RequestProtocol.HTTP
elif _request_context.grpc_context:
request_protocol = RequestProtocol.GRPC

request_metadata = RequestMetadata(
request_id=_request_context.request_id
if _request_context.request_id
else generate_request_id(),
internal_request_id=_request_context._internal_request_id
if _request_context._internal_request_id
else generate_request_id(),
call_method=self.handle_options.method_name,
route=_request_context.route,
app_name=self.app_name,
multiplexed_model_id=self.handle_options.multiplexed_model_id,
is_streaming=self.handle_options.stream,
_request_protocol=request_protocol,
grpc_context=_request_context.grpc_context,
request_metadata = serve._private.default_impl.get_request_metadata(
self.init_options, self.handle_options
)

future = self._remote(request_metadata, args, kwargs)
Expand Down

0 comments on commit 7546b34

Please sign in to comment.