get rid of req_id_to_index Dict
Signed-off-by: Bill Nell <[email protected]>
bnellnm committed Feb 24, 2025
1 parent 7021595 commit d70d0bb
Showing 4 changed files with 44 additions and 29 deletions.
29 changes: 12 additions & 17 deletions vllm/v1/core/scheduler.py
@@ -17,6 +17,7 @@
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID

import cProfile
import pyinstrument
@@ -101,7 +102,7 @@ def __init__(
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)

self.profiler = cProfile.Profile()
#self.profiler = cProfile.Profile()

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
@@ -477,7 +478,7 @@ def _try_schedule_encoder_inputs(
encoder_inputs_to_schedule.append(i)
return encoder_inputs_to_schedule, num_new_tokens, encoder_budget

def _update_from_output(
def Xupdate_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
@@ -517,9 +518,6 @@ def update_from_output
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]

# MASK

#print(f"TY = {type(generated_token_ids)} {generated_token_ids} {sampled_token_ids}")
if not isinstance(generated_token_ids, np.ndarray):
generated_token_ids = [generated_token_ids]

@@ -541,11 +539,9 @@
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens[req_id])

num_generated_token_ids = len(generated_token_ids)

num_computed_tokens_step = num_scheduled_tokens[req_id] - (
len(scheduled_spec_token_ids) + 1 -
num_generated_token_ids)
len(generated_token_ids))
request.num_computed_tokens += num_computed_tokens_step

cached_encoder_input_ids = (
@@ -567,15 +563,15 @@

stopped = False
new_logprobs = None
new_token_ids = np.empty(0, dtype=int)
num_new_tokens = 0

if request.num_computed_tokens >= request.num_tokens:
# This loop seems inefficient.
#print(f"G = {generated_token_ids}")
for output_token_id in generated_token_ids:
output_token_id = int(output_token_id)
#print(f"{output_token_id}, {type(output_token_id)}")
if output_token_id == INVALID_TOKEN_ID:
continue
request.append_output_token_ids(output_token_id)
num_new_tokens = num_new_tokens + 1

@@ -594,11 +590,14 @@ def update_from_output
new_logprobs = logprobs.slice(req_index, req_index + 1)

# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or req_id in prompt_logprobs_dict:
if num_new_tokens > 0 or req_id in prompt_logprobs_dict:
# Update EngineCoreOutputs for this Request.
output.request_ids.append(req_id)

if (num_new_tokens > 1 or
# TODO: try to eliminate this if all the offsets are adjacent?
output.new_token_id_offsets.append(model_runner_output.req_id_to_index[req_id])

if (num_new_tokens != 1 or
output.new_token_id_counts is not None):
if output.new_token_id_counts is None:
output.new_token_id_counts = [1] * i
@@ -621,12 +620,8 @@
if not stopped:
new_running.append(request)

output.req_id_to_index = model_runner_output.req_id_to_index
output.new_token_ids = sampled_token_ids

#print(f"NEW TOK IDS {output.new_token_ids}")

self.running = new_running
output.new_token_ids = sampled_token_ids
output.new_prompt_logprobs_tensors = prompt_logprobs_dict
output.scheduler_stats = self.make_stats()
return output
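
Note: the reworked update_from_output above counts newly generated tokens directly instead of building a per-request token list first, skipping entries that rejection sampling marked invalid. A minimal standalone sketch of that counting loop follows; the INVALID_TOKEN_ID value of -1 and the sample token ids are assumptions for illustration, not values taken from this diff.

import numpy as np

# Assumed sentinel; the real constant comes from vllm.v1.sample.rejection_sampler.
INVALID_TOKEN_ID = -1

# One request's row of sampled token ids; the trailing entry stands in for a
# rejected speculative token.
generated_token_ids = np.array([101, 102, INVALID_TOKEN_ID], dtype=int)

accepted_token_ids = []
num_new_tokens = 0
for output_token_id in generated_token_ids:
    output_token_id = int(output_token_id)
    if output_token_id == INVALID_TOKEN_ID:
        # Rejected draft tokens are neither appended to the request nor counted.
        continue
    accepted_token_ids.append(output_token_id)
    num_new_tokens += 1

assert num_new_tokens == 2 and accepted_token_ids == [101, 102]
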
6 changes: 3 additions & 3 deletions vllm/v1/engine/__init__.py
@@ -128,9 +128,9 @@ class EngineCoreOutputs(

# [num_reqs]
request_ids: List[str] = []
req_id_to_index : Dict[str, int] = {}
new_token_id_counts: Optional[np.ndarray] = None
new_token_ids: np.ndarray = np.empty(0, dtype=int) # List[int]
new_token_id_offsets : List[int] = []
new_token_id_counts: Optional[List[int]] = None # ndarray?
new_token_ids: np.ndarray = np.empty(0, dtype=int) # Optional?

# req_id -> LogprobsLists
new_logprobs: Dict[str, LogprobsLists] = {}
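
Note: with the req_id_to_index Dict removed, EngineCoreOutputs carries a flat new_token_ids array plus per-request new_token_id_offsets and an optional new_token_id_counts list (None meaning one token per request). Below is a small sketch of how a consumer could slice that layout; the request ids, offsets, counts, and token values are made up for illustration, and the sketch assumes new_token_ids is a flat 1-D integer array.

from typing import Dict, List, Optional

import numpy as np

# Flat array of newly sampled token ids across all scheduled requests.
new_token_ids = np.array([11, 12, 13, 21, 31, 32], dtype=int)
request_ids: List[str] = ["req-a", "req-b", "req-c"]
new_token_id_offsets: List[int] = [0, 3, 4]            # start index per request
new_token_id_counts: Optional[List[int]] = [3, 1, 2]   # None => one token each

per_request: Dict[str, np.ndarray] = {}
for req_idx, req_id in enumerate(request_ids):
    start = new_token_id_offsets[req_idx]
    num_tokens = (new_token_id_counts[req_idx]
                  if new_token_id_counts is not None else 1)
    per_request[req_id] = new_token_ids[start:start + num_tokens]

assert per_request["req-a"].tolist() == [11, 12, 13]
assert per_request["req-c"].tolist() == [31, 32]
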
15 changes: 6 additions & 9 deletions vllm/v1/engine/output_processor.py
@@ -153,25 +153,22 @@ def process_outputs(

request_outputs: List[RequestOutput] = []
reqs_to_abort: List[str] = []
req_id_to_index = engine_core_outputs.req_id_to_index
new_token_id_offsets = engine_core_outputs.new_token_id_offsets
new_token_id_counts = engine_core_outputs.new_token_id_counts

# TODO: for i = first:last
for i, req_id in enumerate(
engine_core_outputs.request_ids[first:last]):
req_idx = i + first
req_state = self.request_states.get(req_id)
if req_state is None:
# Ignore output for already-aborted request.
continue

start = req_id_to_index[req_id]

if new_token_id_counts is not None:
end = start + new_token_id_counts[req_idx]
else:
end = start + 1
req_idx = i + first
start = new_token_id_offsets[req_idx]
num_tokens = new_token_id_counts[req_idx] if new_token_id_counts is not None else 1
end = start + num_tokens

num_tokens = end - start
events = engine_core_outputs.events[
req_idx] if engine_core_outputs.events is not None else None

23 changes: 23 additions & 0 deletions vllm/v1/serial_utils.py
@@ -5,9 +5,12 @@

import torch
from msgspec import msgpack
import msgspec
import numpy as np

CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_NDARRAY = 3


class MsgpackEncoder:
@@ -34,13 +37,29 @@ def decode(self, obj: Any):
return self.decoder.decode(obj)


class NumpySerializedRepresentation(msgspec.Struct, gc=False, array_like=True):
dtype:str
shape:tuple
data:bytes

numpy_array_encoder = msgspec.msgpack.Encoder()
numpy_array_decoder = msgspec.msgpack.Decoder(type=NumpySerializedRepresentation)


def custom_enc_hook(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))

if isinstance(obj, np.ndarray):
return msgspec.msgpack.Ext(CUSTOM_TYPE_NDARRAY,
numpy_array_encoder.encode(NumpySerializedRepresentation(
dtype=obj.dtype.str,
shape=obj.shape,
data=obj.data)))

return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))


@@ -49,5 +68,9 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_NDARRAY:
serialized_array_rep = numpy_array_decoder.decode(data)
return np.frombuffer(serialized_array_rep.data, dtype=serialized_array_rep.dtype).reshape(
serialized_array_rep.shape)

raise NotImplementedError(f"Extension type code {code} is not supported")
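
Note: the new CUSTOM_TYPE_NDARRAY extension lets the msgpack encoder ship numpy arrays via a small msgspec.Struct instead of pickling them. The following is a self-contained round-trip sketch modeled on these hooks; it mirrors the names in the diff but is illustrative rather than the vllm module itself (for instance, it uses obj.tobytes() where the diff passes obj.data).

import msgspec
import numpy as np

CUSTOM_TYPE_NDARRAY = 3


class NumpySerializedRepresentation(msgspec.Struct, gc=False, array_like=True):
    dtype: str
    shape: tuple
    data: bytes


rep_encoder = msgspec.msgpack.Encoder()
rep_decoder = msgspec.msgpack.Decoder(type=NumpySerializedRepresentation)


def enc_hook(obj):
    if isinstance(obj, np.ndarray):
        # Describe the array as (dtype, shape, raw bytes) and wrap it in an Ext.
        rep = NumpySerializedRepresentation(dtype=obj.dtype.str,
                                            shape=obj.shape,
                                            data=obj.tobytes())
        return msgspec.msgpack.Ext(CUSTOM_TYPE_NDARRAY, rep_encoder.encode(rep))
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")


def ext_hook(code, data):
    if code == CUSTOM_TYPE_NDARRAY:
        rep = rep_decoder.decode(data)
        # Rebuild the array from the raw buffer; the result is read-only.
        return np.frombuffer(rep.data, dtype=rep.dtype).reshape(rep.shape)
    raise NotImplementedError(f"Extension type code {code} is not supported")


encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(ext_hook=ext_hook)

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
restored = decoder.decode(encoder.encode(arr))
assert restored.dtype == arr.dtype and np.array_equal(restored, arr)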
