Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tiny code cleanup in tokenizer_manager.py #2586

Merged
merged 9 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 47 additions & 48 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import sys
import time
import uuid
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union, Generic, TypeVar

import fastapi
import uvloop
Expand Down Expand Up @@ -173,6 +173,9 @@ def __init__(

# Others
self.gracefully_exit = False
self.init_weights_update_group_communicator = _Communicator(self.send_to_scheduler, server_args.dp_size)
self.update_weights_from_distributed_communicator = _Communicator(self.send_to_scheduler, server_args.dp_size)
self.get_weights_by_name_communicator = _Communicator(self.send_to_scheduler, server_args.dp_size)

# Metrics
if self.enable_metrics:
Expand All @@ -190,8 +193,7 @@ async def generate_request(
):
created_time = time.time()

if self.to_create_loop:
self.create_handle_loop()
self.auto_create_handle_loop()

if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
Expand Down Expand Up @@ -440,8 +442,7 @@ async def update_weights_from_disk(
obj: UpdateWeightFromDiskReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
if self.to_create_loop:
self.create_handle_loop()
self.auto_create_handle_loop()

# default the load format to the server_args
if obj.load_format is None:
Expand All @@ -456,7 +457,7 @@ async def update_weights_from_disk(

async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
) -> Tuple[bool, str, int]:
) -> Tuple[bool, str]:
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
if self.server_args.dp_size == 1:
Expand Down Expand Up @@ -485,60 +486,40 @@ async def init_weights_update_group(
obj: InitWeightsUpdateGroupReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
if self.to_create_loop:
self.create_handle_loop()
self.send_to_scheduler.send_pyobj(obj)

self.init_weights_update_group_result = asyncio.Future()
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
result = await self.init_weights_update_group_result
result = (await self.init_weights_update_group_communicator(obj))[0]
return result.success, result.message

async def update_weights_from_distributed(
self,
obj: UpdateWeightsFromDistributedReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
if self.to_create_loop:
self.create_handle_loop()
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1
), "dp_size must be for update weights from distributed"

# This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
self.send_to_scheduler.send_pyobj(obj)
self.parameter_update_result: Awaitable[
UpdateWeightsFromDistributedReqOutput
] = asyncio.Future()
assert (
self.server_args.dp_size == 1
), "dp_size must be for update weights from distributed"
result = await self.parameter_update_result
result = (await self.update_weights_from_distributed_communicator(obj))[0]
return result.success, result.message

async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
if self.to_create_loop:
self.create_handle_loop()

self.send_to_scheduler.send_pyobj(obj)
self.get_weights_by_name_result = asyncio.Future()
if self.server_args.dp_size == 1:
result = await self.get_weights_by_name_result
return result.parameter
else:
self.get_weights_by_name_tmp = []
result = await self.get_weights_by_name_result
all_parameters = [r.parameter for r in result]
return all_parameters
self.auto_create_handle_loop()
results = await self.get_weights_by_name_communicator(obj)
return [r.parameter for r in results]

async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
if self.to_create_loop:
self.create_handle_loop()
self.auto_create_handle_loop()

session_id = uuid.uuid4().hex
obj.session_id = session_id
Expand Down Expand Up @@ -568,7 +549,7 @@ async def abort_request():
background_tasks.add_task(abort_request)
return background_tasks

def create_handle_loop(self):
def auto_create_handle_loop(self):
if not self.to_create_loop:
return

Expand Down Expand Up @@ -711,21 +692,14 @@ async def handle_loop(self):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
self.init_weights_update_group_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.parameter_update_result.set_result(recv_obj)
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
if self.server_args.dp_size == 1:
self.get_weights_by_name_result.set_result(recv_obj)
else:
self.get_weights_by_name_tmp.append(recv_obj)
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
self.get_weights_by_name_result.set_result(
self.get_weights_by_name_tmp
)
self.get_weights_by_name_communicator.handle_recv(recv_obj)
else:
raise ValueError(f"Invalid object: {recv_obj=}")

Expand Down Expand Up @@ -809,3 +783,28 @@ def signal_handler(self, signum=None, frame=None):
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
)
self.tokenizer_manager.gracefully_exit = True


T = TypeVar('T')


class _Communicator(Generic[T]):
def __init__(self, sender, fan_out: int):
self._sender = sender
self._fan_out = fan_out
self._result_future: Optional[asyncio.Future] = None
self._result_values: Optional[List[T]] = None

async def __call__(self, obj):
self._sender.send_pyobj(obj)
self._result_future = asyncio.Future()
self._result_values = []
await self._result_future
result_values = self._result_values
self._result_future = self._result_values = None
return result_values

def handle_recv(self, recv_obj: T):
self._result_values.append(recv_obj)
if len(self._result_values) == self._fan_out:
self._result_future.set_result(None)
53 changes: 16 additions & 37 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
try:
ret = await tokenizer_manager.get_weights_by_name(obj, request)
if ret is None:
return ORJSONResponse(
{"error": {"message": "Get parameter by name failed"}},
status_code=HTTPStatus.BAD_REQUEST,
)
return _create_error_response("Get parameter by name failed")
else:
return ORJSONResponse(ret, status_code=200)
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)


@app.api_route("/open_session", methods=["GET", "POST"])
Expand All @@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
session_id = await tokenizer_manager.open_session(obj, request)
return session_id
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)


@app.api_route("/close_session", methods=["GET", "POST"])
Expand All @@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
await tokenizer_manager.close_session(obj, request)
return Response(status_code=200)
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)


# fastapi implicitly converts json in the request to obj (dataclass)
Expand Down Expand Up @@ -312,9 +303,7 @@ async def stream_results() -> AsyncIterator[bytes]:
return ret
except ValueError as e:
logger.error(f"Error: {e}")
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)


@app.api_route("/encode", methods=["POST", "PUT"])
Expand All @@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)


@app.api_route("/classify", methods=["POST", "PUT"])
Expand All @@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)


##### OpenAI-compatible API endpoints #####
Expand Down Expand Up @@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
return await v1_retrieve_file_content(file_id)


def _create_error_response(e):
    """Build the shared HTTP 400 JSON error body ``{"error": {"message": ...}}``.

    ``e`` may be an exception or a plain message string; it is stringified
    into the response payload.
    """
    payload = {"error": {"message": str(e)}}
    return ORJSONResponse(payload, status_code=HTTPStatus.BAD_REQUEST)


def launch_engine(
server_args: ServerArgs,
):
Expand Down Expand Up @@ -849,12 +840,8 @@ def init_weights_update_group(
group_name=group_name,
backend=backend,
)

async def _init_group():
return await tokenizer_manager.init_weights_update_group(obj, None)

loop = asyncio.get_event_loop()
return loop.run_until_complete(_init_group())
return loop.run_until_complete(tokenizer_manager.init_weights_update_group(obj, None))

def update_weights_from_distributed(self, name, dtype, shape):
"""Update weights from distributed source."""
Expand All @@ -863,22 +850,14 @@ def update_weights_from_distributed(self, name, dtype, shape):
dtype=dtype,
shape=shape,
)

async def _update_weights():
return await tokenizer_manager.update_weights_from_distributed(obj, None)

loop = asyncio.get_event_loop()
return loop.run_until_complete(_update_weights())
return loop.run_until_complete(tokenizer_manager.update_weights_from_distributed(obj, None))

def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)

async def _get_weights():
return await tokenizer_manager.get_weights_by_name(obj, None)

loop = asyncio.get_event_loop()
return loop.run_until_complete(_get_weights())
return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))


class Runtime:
Expand Down
Loading