[core] SequenceController in SamplingParams #4775

Closed
wants to merge 11 commits
16 changes: 16 additions & 0 deletions vllm/core/block/block_table.py
@@ -87,6 +87,22 @@ def allocate(self,
device=device)
self._num_full_slots = len(token_ids)

def backtrack(self, num_slots: int) -> None:
"""Remove the specified number of slots from the end of the table.

Args:
num_slots (int): The number of slots to backtrack by.
"""
assert self._is_allocated
assert num_slots <= self._num_full_slots
if num_slots == 0:
return
self._num_full_slots -= num_slots
blocks = self._blocks[self._num_full_slots // self._block_size:]
blocks[0].trim(self._num_full_slots % self._block_size)
for b in blocks[1:]:
b.trim(0)

def append_token_ids(self,
token_ids: List[int],
num_lookahead_slots: int = 0) -> None:
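The slot arithmetic above trims the boundary block partially and empties every block after it. A minimal standalone sketch of that logic, using plain lists in place of the real Block objects (block size and token counts below are made up):

# Toy illustration (not vllm code) of BlockTable.backtrack()'s trimming logic.
from typing import List

def backtrack_blocks(blocks: List[List[int]], block_size: int,
                     num_full_slots: int, num_slots: int) -> int:
    assert num_slots <= num_full_slots
    if num_slots == 0:
        return num_full_slots
    num_full_slots -= num_slots
    first = num_full_slots // block_size   # first block touched by the trim
    keep = num_full_slots % block_size     # tokens kept in that block
    del blocks[first][keep:]
    for b in blocks[first + 1:]:
        del b[:]                           # every later block becomes empty
    return num_full_slots

blocks = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]
assert backtrack_blocks(blocks, block_size=4, num_full_slots=10, num_slots=3) == 7
assert blocks == [[1, 2, 3, 4], [5, 6, 7], []]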
4 changes: 4 additions & 0 deletions vllm/core/block/interfaces.py
@@ -43,6 +43,10 @@ def is_full(self) -> bool:
def prev_block(self) -> Optional["Block"]:
pass

@abstractmethod
def trim(self, num_tokens: int):
pass

@property
@abstractmethod
def computed(self) -> bool:
3 changes: 3 additions & 0 deletions vllm/core/block/naive_block.py
@@ -269,6 +269,9 @@ def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
assert self.num_empty_slots >= len(token_ids)
self._token_ids.extend(token_ids)

def trim(self, num_tokens: int):
del self._token_ids[num_tokens:]

@property
def computed(self) -> bool:
raise NotImplementedError
3 changes: 3 additions & 0 deletions vllm/core/block/prefix_caching_block.py
@@ -456,6 +456,9 @@ def __init__(
_cow_target=self,
)

def trim(self, num_tokens: int):
return self._block.trim(num_tokens)

@property
def computed(self) -> bool:
return self._computed
3 changes: 3 additions & 0 deletions vllm/core/block_manager_v1.py
@@ -386,10 +386,13 @@ def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int = 0,
backtrack: int = 0,
) -> List[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
block_table = self.block_tables[seq.seq_id]
assert backtrack == 0, \
"Backtrack not supported; consider --use-v2-block-manager"
# If we need to allocate a new physical block
if len(block_table) < len(logical_blocks):
# Currently this code only supports adding one physical block
2 changes: 2 additions & 0 deletions vllm/core/block_manager_v2.py
@@ -167,10 +167,12 @@ def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
backtrack: int = 0,
) -> List[Tuple[int, int]]:

block_table = self.block_tables[seq.seq_id]

block_table.backtrack(backtrack)
block_table.append_token_ids(
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots,
1 change: 1 addition & 0 deletions vllm/core/embedding_model_block_manager.py
@@ -36,6 +36,7 @@ def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
backtrack: int = 0,
) -> List[Tuple[int, int]]:
return None # type: ignore

1 change: 1 addition & 0 deletions vllm/core/interfaces.py
@@ -60,6 +60,7 @@ def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
backtrack: int = 0,
) -> List[Tuple[int, int]]:
pass

14 changes: 13 additions & 1 deletion vllm/core/scheduler.py
@@ -13,6 +13,7 @@
from vllm.lora.request import LoRARequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.sequence_controller import SequenceController

logger = init_logger(__name__)

@@ -939,8 +940,11 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {}

ctrl = seq_group.sampling_params.controller
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
if ctrl:
ctrl.scheduled(seq)
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)
@@ -987,6 +991,9 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
)
seq_group_metadata_list.append(seq_group_metadata)

if not scheduler_outputs.is_empty():
SequenceController.forward_started()

# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
@@ -1002,6 +1009,8 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:

def free_seq(self, seq: Sequence) -> None:
"""Free a sequence from a block table."""
if seq.controller:
seq.controller.free(seq)
self.block_manager.free(seq)

def free_finished_seq_groups(self) -> None:
@@ -1032,7 +1041,10 @@ def _append_slots(
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)

for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
cows = self.block_manager.append_slots(seq, num_lookahead_slots)
cows = self.block_manager.append_slots(seq,
num_lookahead_slots,
backtrack=seq.backtrack)
seq.backtrack = 0
blocks_to_copy.extend(cows)

def _preempt(
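Taken together, the scheduler changes above give a controller four touch points per engine step. A rough sketch of the call order under this diff's design (ToyController and the driver loop below are invented for illustration; in the PR, forward_started() is a staticmethod invoked once per step on the class):

# Toy walk-through (not vllm code) of the order in which the new hooks fire.
class ToyController:
    def __init__(self):
        self.calls = []
    def scheduled(self, seq):
        self.calls.append(f"scheduled:{seq}")
    def forward_started(self):
        self.calls.append("forward_started")
    def sampled(self, seq, token_id, logprobs):
        self.calls.append(f"sampled:{seq}:{token_id}")
        return 0, [token_id], False
    def free(self, seq):
        self.calls.append(f"free:{seq}")

ctrl, running = ToyController(), ["seq0", "seq1"]
for s in running:                # Scheduler.schedule(): once per running sequence
    ctrl.scheduled(s)
ctrl.forward_started()           # once the whole batch has been queued
for s in running:                # output processor, after the model forward
    ctrl.sampled(s, 42, {})
ctrl.free("seq1")                # Scheduler.free_seq() when a sequence finishes
print(ctrl.calls)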
4 changes: 3 additions & 1 deletion vllm/engine/llm_engine.py
@@ -475,8 +475,10 @@ def _create_sequence_group_with_sampling(
f"{max_logprobs} logprobs.")

# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
# this doesn't deep-copy LogitsProcessor or SequenceController objects
sampling_params = sampling_params.clone()
# Link controller to sequence.
seq.controller = sampling_params.controller
# Add the eos token id into the sampling_params to support min_tokens
# processing
if seq.eos_token_id is not None:
20 changes: 18 additions & 2 deletions vllm/engine/output_processor/single_step.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Set, Tuple, Union

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@@ -85,6 +85,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []

to_stop: Set[int] = set()

# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutput] = parent_child_dict[
@@ -108,9 +110,20 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
child_seqs.append((parent, parent))
ctrl = seq_group.sampling_params.controller
if ctrl:
sid = parent.seq_id
sampled_token = last_child_sample.output_token
backtrack, ff_tokens, should_stop = ctrl.sampled(
parent, sampled_token, last_child_sample.logprobs)
if should_stop:
to_stop.add(sid)
if backtrack != 0 or ff_tokens != [sampled_token]:
parent.splice_tokens(backtrack, ff_tokens)
continue # don't call append_token_id()
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize and self.detokenizer:
@@ -120,6 +133,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
new_char_count = 0
self.stop_checker.maybe_stop_sequence(seq, new_char_count,
seq_group.sampling_params)
if seq.seq_id in to_stop:
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = "<SequenceController>"

# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
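To make the (backtrack, ff_tokens, should_stop) handling above concrete, here is a list-only sketch (no vllm types) of how the output processor applies a controller's result to a sequence's output tokens:

# Toy semantics (not vllm code): apply a controller result to a token list.
def apply_controller_result(tokens, sampled_token, backtrack, ff_tokens):
    """Mimics single_step.py: plain append, or backtrack plus fast-forward."""
    if backtrack == 0 and ff_tokens == [sampled_token]:
        tokens.append(sampled_token)      # common case: keep the sampled token
    else:
        if backtrack:
            del tokens[-backtrack:]       # drop the last `backtrack` tokens
        tokens.extend(ff_tokens)          # then fast-forward the forced tokens
    return tokens

assert apply_controller_result([5, 6, 7], 8, 0, [8]) == [5, 6, 7, 8]
assert apply_controller_result([5, 6, 7], 8, 2, [9, 10]) == [5, 9, 10]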
19 changes: 15 additions & 4 deletions vllm/sampling_params.py
@@ -8,6 +8,8 @@
from pydantic import Field
from typing_extensions import Annotated

from .sequence_controller import SequenceController

_SAMPLING_EPS = 1e-5


@@ -169,6 +171,7 @@ def __init__(
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.controller: Optional[SequenceController] = None
self.truncate_prompt_tokens = truncate_prompt_tokens
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
@@ -277,6 +280,9 @@ def _verify_greedy_sampling(self) -> None:
def update_from_generation_config(
self, generation_config: Dict[str, Any]) -> None:
"""Update if there are non-default values from generation_config"""
# If present, we want the controller to control stopping.
if self.controller:
return
# Update eos_token_id for generation
if (not self.ignore_eos) and (eos_ids :=
generation_config.get("eos_token_id")):
@@ -305,10 +311,15 @@ def clone(self) -> "SamplingParams":
See https://github.com/vllm-project/vllm/issues/3087
"""

logit_processor_refs = None if self.logits_processors is None else {
id(lp): lp
for lp in self.logits_processors
}
logit_processor_refs: Optional[
dict] = None if self.logits_processors is None else {
id(lp): lp
for lp in self.logits_processors
}
if self.controller:
if logit_processor_refs is None:
logit_processor_refs = {}
logit_processor_refs[id(self.controller)] = self.controller
return copy.deepcopy(self, memo=logit_processor_refs)

def __repr__(self) -> str:
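A short usage sketch for the new field, assuming this branch is installed: the controller is attached after construction, and clone() registers it in the deepcopy memo so the copy shares the same controller object by identity, mirroring how logits processors are kept shared.

from vllm import SamplingParams
from vllm.sequence_controller import SequenceController

params = SamplingParams(max_tokens=64)
params.controller = SequenceController()   # typically a user-defined subclass

clone = params.clone()
# The defensive copy must not duplicate the stateful callback object.
assert clone.controller is params.controller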
61 changes: 59 additions & 2 deletions vllm/sequence.py
@@ -9,6 +9,7 @@
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence_controller import SequenceController

if TYPE_CHECKING:
import torch
@@ -223,6 +224,8 @@ def __init__(
self.data: SequenceData = SequenceData(prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
self.backtrack = 0
self.controller: Optional[SequenceController] = None

self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
@@ -296,6 +299,58 @@ def append_token_id(
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)

def splice_tokens(self, backtrack: int, token_ids: List[int]):
assert self.backtrack == 0

data = self.data

if not token_ids:
# we need at least one token in forward step,
# so we pretend we're backtracking one token more
# and repeat the token that was there
# otherwise, the _num_computed_tokens gets out of sync
backtrack += 1
if backtrack <= len(data.output_token_ids):
token_ids = [data.output_token_ids[-backtrack]]
else:
off = backtrack - len(data.output_token_ids)
if off <= len(data.prompt_token_ids):
token_ids = [data.prompt_token_ids[-off]]
else:
token_ids = [1]

backtrack = min(backtrack, data.get_len())
self.backtrack = backtrack

if backtrack > 0:
prompt_backtrack = 0
output_len = len(data.output_token_ids)
if backtrack > output_len:
prompt_backtrack = backtrack - output_len
backtrack = output_len
del data.output_token_ids[-backtrack:]
del self.output_logprobs[-backtrack:]
data._num_computed_tokens = min(data._num_computed_tokens,
len(data.output_token_ids))
if prompt_backtrack > 0:
assert not data.output_token_ids
del data.prompt_token_ids[-prompt_backtrack:]
needed_blocks = \
(self.get_len() + self.block_size - 1) // self.block_size
if len(self.logical_token_blocks) > needed_blocks:
del self.logical_token_blocks[needed_blocks:]
if needed_blocks > 0:
last_block = self.logical_token_blocks[-1]
last_num_tokens = self.get_len() % self.block_size
if last_num_tokens == 0:
last_num_tokens = self.block_size
last_block.num_tokens = last_num_tokens

for t in token_ids:
self.append_token_id(t, {t: Logprob(logprob=0.0)})
if data.get_num_uncomputed_tokens() > 1:
data._stage = SequenceStage.PREFILL

def get_len(self) -> int:
return self.data.get_len()

@@ -454,8 +509,10 @@ def lora_int_id(self) -> int:

def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if self.is_prefill():
# If still in prefill phase, raise Error (unless using controllers,
# where the request may go from decode to prefill).
if self.is_prefill() and not (self.sampling_params
and self.sampling_params.controller):
raise ValueError(
"seq_group.get_last_latency() should not be called "
"if the seq_group is in prefill phase.")
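A small worked example of the empty fast-forward fallback in splice_tokens() above (plain integers, no vllm objects): when the controller supplies no replacement tokens, the method backtracks one slot further and replays the token that used to sit there, so the next forward step still has exactly one new token to compute.

# Toy illustration of the fallback in splice_tokens() when token_ids == []:
output_tokens = [11, 12, 13, 14]
backtrack, token_ids = 2, []

if not token_ids:
    backtrack += 1                            # back up one extra slot...
    token_ids = [output_tokens[-backtrack]]   # ...and replay the token there

del output_tokens[-backtrack:]
output_tokens.extend(token_ids)
assert output_tokens == [11, 12]              # net effect: two tokens removed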
44 changes: 44 additions & 0 deletions vllm/sequence_controller.py
@@ -0,0 +1,44 @@
from typing import Dict, List, Tuple, TYPE_CHECKING

if TYPE_CHECKING:
from .sequence import Logprob, Sequence


class SequenceController:
"""Callback for generation control for a single sequence group.

This can be part of SamplingParams and gets callbacks for various
steps. It is to be used together with LogitsProcessor.
"""

def scheduled(self, seq: 'Sequence'):
"""
Called whenever the current sequence is scheduled to be run
in the next step.
"""
pass

@staticmethod
def forward_started():
"""
Called when all sequences for the current step have been queued.
"""
pass

def sampled(self, seq: 'Sequence', token_id: int,
logprobs: Dict[int, 'Logprob']) -> Tuple[int, List[int], bool]:
"""
Informs the controller a given token has been sampled.
Returns the number of tokens to backtrack, the tokens to append,
and whether to stop.
"""
if token_id == seq.eos_token_id:
return 0, [], True
return 0, [token_id], False

def free(self, seq: 'Sequence'):
"""
Called when the sequence is stopped, and deallocated.
.scheduled() will not be called again for this sequence.
"""
pass
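Finally, a hedged end-to-end sketch of how this API might be used, assuming this branch is installed; the subclass, the stop-token id, and the model choice below are invented for illustration and are not part of the PR:

from typing import List, Tuple

from vllm import LLM, SamplingParams
from vllm.sequence_controller import SequenceController


class StopOnTokenController(SequenceController):
    """Illustrative controller: stop as soon as a chosen token id is sampled."""

    def __init__(self, stop_token: int):
        self.stop_token = stop_token

    def sampled(self, seq, token_id: int,
                logprobs) -> Tuple[int, List[int], bool]:
        if token_id == self.stop_token:
            return 0, [], True           # keep nothing new and stop the sequence
        return 0, [token_id], False      # keep the sampled token and continue


llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=128)
params.controller = StopOnTokenController(stop_token=50118)  # made-up token id
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)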