From d70d0bbeecf103d6ad28eb7d3234be10761fb463 Mon Sep 17 00:00:00 2001
From: Bill Nell
Date: Mon, 24 Feb 2025 15:57:56 +0000
Subject: [PATCH] get rid of req_id_to_index Dict

Signed-off-by: Bill Nell
---
 vllm/v1/core/scheduler.py          | 29 ++++++++++++-----------------
 vllm/v1/engine/__init__.py         |  6 +++---
 vllm/v1/engine/output_processor.py | 15 ++++++---------
 vllm/v1/serial_utils.py            | 23 +++++++++++++++++++++++
 4 files changed, 44 insertions(+), 29 deletions(-)

diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index e3b1f7ff7acf8..0b1f51c9178ae 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -17,6 +17,7 @@
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
 
 import cProfile
 import pyinstrument
@@ -101,7 +102,7 @@ def __init__(
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)
 
-        self.profiler = cProfile.Profile()
+        #self.profiler = cProfile.Profile()
 
     def schedule(self) -> "SchedulerOutput":
         # NOTE(woosuk) on the scheduling algorithm:
@@ -477,7 +478,7 @@ def _try_schedule_encoder_inputs(
                 encoder_inputs_to_schedule.append(i)
         return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
 
-    def _update_from_output(
+    def Xupdate_from_output(
         self,
         scheduler_output: "SchedulerOutput",
         model_runner_output: "ModelRunnerOutput",
@@ -517,9 +518,6 @@
             req_index = model_runner_output.req_id_to_index[req_id]
             generated_token_ids = sampled_token_ids[req_index]
 
-            # MASK
-
-            #print(f"TY = {type(generated_token_ids)} {generated_token_ids} {sampled_token_ids}")
             if not isinstance(generated_token_ids, np.ndarray):
                 generated_token_ids = [generated_token_ids]
 
@@ -541,11 +539,9 @@
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens[req_id])
 
-            num_generated_token_ids = len(generated_token_ids)
-
             num_computed_tokens_step = num_scheduled_tokens[req_id] - (
                 len(scheduled_spec_token_ids) + 1 -
-                num_generated_token_ids)
+                len(generated_token_ids))
             request.num_computed_tokens += num_computed_tokens_step
 
             cached_encoder_input_ids = (
@@ -567,7 +563,6 @@
 
             stopped = False
             new_logprobs = None
-            new_token_ids = np.empty(0, dtype=int)
             num_new_tokens = 0
 
             if request.num_computed_tokens >= request.num_tokens:
@@ -575,7 +570,8 @@
                 #print(f"G = {generated_token_ids}")
                 for output_token_id in generated_token_ids:
                     output_token_id = int(output_token_id)
-                    #print(f"{output_token_id}, {type(output_token_id)}")
+                    if output_token_id == INVALID_TOKEN_ID:
+                        continue
                     request.append_output_token_ids(output_token_id)
                     num_new_tokens = num_new_tokens + 1
 
@@ -594,11 +590,14 @@
                     new_logprobs = logprobs.slice(req_index, req_index + 1)
 
             # Transmit partial if chunked prefill & prompt logprobs is enabled
-            if new_token_ids or req_id in prompt_logprobs_dict:
+            if num_new_tokens > 0 or req_id in prompt_logprobs_dict:
                 # Update EngineCoreOutputs for this Request.
                 output.request_ids.append(req_id)
 
-                if (num_new_tokens > 1 or
+                # TODO: try to eliminate this if all the offsets are adjacent?
+                output.new_token_id_offsets.append(model_runner_output.req_id_to_index[req_id])
+
+                if (num_new_tokens != 1 or
                         output.new_token_id_counts is not None):
                     if output.new_token_id_counts is None:
                         output.new_token_id_counts = [1] * i
@@ -621,12 +620,8 @@
             if not stopped:
                 new_running.append(request)
 
-        output.req_id_to_index = model_runner_output.req_id_to_index
-        output.new_token_ids = sampled_token_ids
-
-        #print(f"NEW TOK IDS {output.new_token_ids}")
-
         self.running = new_running
+        output.new_token_ids = sampled_token_ids
         output.new_prompt_logprobs_tensors = prompt_logprobs_dict
         output.scheduler_stats = self.make_stats()
         return output
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 5961868b1dad2..b73bc5785af44 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -128,9 +128,9 @@ class EngineCoreOutputs(
 
     # [num_reqs]
     request_ids: List[str] = []
-    req_id_to_index : Dict[str, int] = {}
-    new_token_id_counts: Optional[np.ndarray] = None
-    new_token_ids: np.ndarray = np.empty(0, dtype=int) # List[int]
+    new_token_id_offsets : List[int] = []
+    new_token_id_counts: Optional[List[int]] = None # ndarray?
+    new_token_ids: np.ndarray = np.empty(0, dtype=int) # Optional?
 
     # req_id -> LogprobsLists
     new_logprobs: Dict[str, LogprobsLists] = {}
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index c79384a7a5fcf..fdb73bcdbe8d4 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -153,25 +153,22 @@ def process_outputs(
         request_outputs: List[RequestOutput] = []
         reqs_to_abort: List[str] = []
 
-        req_id_to_index = engine_core_outputs.req_id_to_index
+        new_token_id_offsets = engine_core_outputs.new_token_id_offsets
         new_token_id_counts = engine_core_outputs.new_token_id_counts
 
+        # TODO: for i = first:last
         for i, req_id in enumerate(
                 engine_core_outputs.request_ids[first:last]):
-            req_idx = i + first
             req_state = self.request_states.get(req_id)
             if req_state is None:
                 # Ignore output for already-aborted request.
                 continue
 
-            start = req_id_to_index[req_id]
-
-            if new_token_id_counts is not None:
-                end = start + new_token_id_counts[req_idx]
-            else:
-                end = start + 1
+            req_idx = i + first
+            start = new_token_id_offsets[req_idx]
+            num_tokens = new_token_id_counts[req_idx] if new_token_id_counts is not None else 1
+            end = start + num_tokens
 
-            num_tokens = end - start
             events = engine_core_outputs.events[
                 req_idx] if engine_core_outputs.events is not None else None
 
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 3f000abcde0d1..437f6111c0d99 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -5,9 +5,12 @@
 
 import torch
 from msgspec import msgpack
+import msgspec
+import numpy as np
 
 CUSTOM_TYPE_TENSOR = 1
 CUSTOM_TYPE_PICKLE = 2
+CUSTOM_TYPE_NDARRAY = 3
 
 
 class MsgpackEncoder:
@@ -34,6 +37,15 @@ def decode(self, obj: Any):
         return self.decoder.decode(obj)
 
 
+class NumpySerializedRepresentation(msgspec.Struct, gc=False, array_like=True):
+    dtype:str
+    shape:tuple
+    data:bytes
+
+numpy_array_encoder = msgspec.msgpack.Encoder()
+numpy_array_decoder = msgspec.msgpack.Decoder(type=NumpySerializedRepresentation)
+
+
 def custom_enc_hook(obj: Any) -> Any:
     if isinstance(obj, torch.Tensor):
         # NOTE(rob): it is fastest to use numpy + pickle
@@ -41,6 +53,13 @@ def custom_enc_hook(obj: Any) -> Any:
         # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
         return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
 
+    if isinstance(obj, np.ndarray):
+        return msgspec.msgpack.Ext(CUSTOM_TYPE_NDARRAY,
+                                   numpy_array_encoder.encode(NumpySerializedRepresentation(
+                                       dtype=obj.dtype.str,
+                                       shape=obj.shape,
+                                       data=obj.data)))
+
     return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
 
 
@@ -49,5 +68,9 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
         return torch.from_numpy(pickle.loads(data))
     if code == CUSTOM_TYPE_PICKLE:
         return pickle.loads(data)
+    if code == CUSTOM_TYPE_NDARRAY:
+        serialized_array_rep = numpy_array_decoder.decode(data)
+        return np.frombuffer(serialized_array_rep.data, dtype=serialized_array_rep.dtype).reshape(
+            serialized_array_rep.shape)
 
     raise NotImplementedError(f"Extension type code {code} is not supported")
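
As a reading aid for the diff above, here is a small standalone sketch (not part of the patch; the request ids, offsets, counts, and token values are invented) of how the new_token_id_offsets / new_token_id_counts fields added to EngineCoreOutputs are consumed, mirroring the slicing loop in process_outputs:

# Hypothetical example data; field names mirror EngineCoreOutputs above,
# but the values are made up for illustration.
import numpy as np

request_ids = ["req-0", "req-1", "req-2"]
new_token_ids = np.array([11, 22, 33, 44], dtype=int)  # flat array of sampled token ids
new_token_id_offsets = [0, 1, 2]                        # start index per request
new_token_id_counts = [1, 1, 2]                         # None would mean one new token each

for req_idx, req_id in enumerate(request_ids):
    start = new_token_id_offsets[req_idx]
    # Same fallback as in process_outputs: a missing counts list means one token.
    num_tokens = (new_token_id_counts[req_idx]
                  if new_token_id_counts is not None else 1)
    end = start + num_tokens
    print(req_id, new_token_ids[start:end].tolist())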
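
Similarly, a minimal round-trip sketch for the CUSTOM_TYPE_NDARRAY extension added to vllm/v1/serial_utils.py. It assumes only the custom_enc_hook / custom_ext_hook functions shown in the diff and msgspec's standard enc_hook / ext_hook parameters; the encoder and decoder here are constructed directly rather than through the MsgpackEncoder / MsgpackDecoder wrappers:

# Round-trip check for the ndarray extension type using the hooks from the patch.
import msgspec
import numpy as np

from vllm.v1.serial_utils import custom_enc_hook, custom_ext_hook

encoder = msgspec.msgpack.Encoder(enc_hook=custom_enc_hook)
decoder = msgspec.msgpack.Decoder(ext_hook=custom_ext_hook)

arr = np.arange(12, dtype=np.int64).reshape(3, 4)
restored = decoder.decode(encoder.encode(arr))

# The ext hook rebuilds the array from dtype/shape/data, so contents,
# dtype, and shape should all survive the round trip.
assert restored.dtype == arr.dtype
assert restored.shape == arr.shape
assert (restored == arr).all()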