From 0cadb73586af794259fc7d4ab569c7a9abf9925c Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 9 May 2024 18:22:09 -0700 Subject: [PATCH 1/9] aici rebase --- vllm/core/block/block_table.py | 16 +++ vllm/core/block/interfaces.py | 4 + vllm/core/block/naive_block.py | 3 + vllm/core/block/prefix_caching_block.py | 3 + vllm/core/block_manager_v1.py | 3 + vllm/core/block_manager_v2.py | 16 ++- vllm/core/interfaces.py | 1 + vllm/core/scheduler.py | 25 ++++- vllm/engine/output_processor/single_step.py | 24 ++++- vllm/entrypoints/openai/api_server.py | 61 +++++++++++- vllm/entrypoints/openai/protocol.py | 26 +++++ vllm/entrypoints/openai/serving_aici.py | 102 ++++++++++++++++++++ vllm/entrypoints/openai/serving_engine.py | 2 +- vllm/model_executor/layers/sampler.py | 38 ++++++++ vllm/outputs.py | 4 +- vllm/sampling_params.py | 4 + vllm/sequence.py | 54 +++++++++++ 17 files changed, 377 insertions(+), 9 deletions(-) create mode 100644 vllm/entrypoints/openai/serving_aici.py diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index f1b65b2514f76..bdbed63a93c7a 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -85,6 +85,22 @@ def allocate(self, device=device) self._num_full_slots = len(token_ids) + def backtrack(self, num_slots: int) -> None: + """Remove the specified number of slots from the end of the table. + + Args: + num_slots (int): The number of slots to backtrack by. + """ + assert self._is_allocated + assert num_slots <= self._num_full_slots + if num_slots == 0: + return + self._num_full_slots -= num_slots + blocks = self._blocks[self._num_full_slots // self._block_size:] + blocks[0].trim(self._num_full_slots % self._block_size) + for b in blocks[1:]: + b.trim(0) + def append_token_ids(self, token_ids: List[int], num_lookahead_slots: int = 0) -> None: diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 50ce922118124..c28b95b367bc1 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -35,6 +35,10 @@ def is_full(self) -> bool: def prev_block(self) -> Optional["Block"]: pass + @abstractmethod + def trim(self, num_tokens: int): + pass + class Factory(Protocol): @abstractmethod diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index f8e9265bb2d67..2c9ef1bbf3c7e 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -247,6 +247,9 @@ def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: assert self.num_empty_slots >= len(token_ids) self._token_ids.extend(token_ids) + def trim(self, num_tokens: int): + del self._token_ids[num_tokens:] + @property def block_id(self) -> Optional[int]: return self._block_id diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 6aa75a8abb80a..94e99329836fe 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -361,6 +361,9 @@ def __init__( _cow_target=self, ) + def trim(self, num_tokens: int): + return self._block.trim(num_tokens) + def append_token_ids(self, token_ids: List[int]) -> None: """Appends the given token IDs to the block and registers the block as immutable if the block becomes full. 
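The backtrack()/trim() hunks above all reduce to one piece of index math. The following is a standalone Python sketch of that arithmetic, not code from the patch; the function name and the example numbers are illustrative only:

def backtrack_layout(num_full_slots: int, block_size: int, num_slots: int):
    # Mirrors the index math in BlockTable.backtrack(): after dropping
    # num_slots from the end, the block at index remaining // block_size
    # becomes the new partially-filled tail (kept at remaining % block_size
    # tokens), and every later block is trimmed to zero tokens.
    assert num_slots <= num_full_slots
    remaining = num_full_slots - num_slots
    return remaining, remaining // block_size, remaining % block_size

# Example: 10 slots filled, block_size = 4, backtrack by 3 tokens:
# backtrack_layout(10, 4, 3) == (7, 1, 3) -> block 0 stays full, block 1 is
# trimmed to 3 tokens, and any block after it is trimmed to 0 tokens.

Note that backtrack() only shrinks the per-block token counts; the blocks themselves stay in the table, which is why trim() is added to the Block interface instead of freeing blocks here.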
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index be093922b84f2..0c5b8d985eecb 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -375,10 +375,13 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int = 0, + backtrack: int = 0, ) -> Dict[int, List[int]]: """Allocate a physical slot for a new token.""" logical_blocks = seq.logical_token_blocks block_table = self.block_tables[seq.seq_id] + assert backtrack == 0, \ + "Backtrack not supported; consider --use-v2-block-manager" # If we need to allocate a new physical block if len(block_table) < len(logical_blocks): # Currently this code only supports adding one physical block diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 6339a6baf4161..3fa7da96f93cf 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -168,14 +168,22 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, + backtrack: int = 0, ) -> Dict[int, List[int]]: block_table = self.block_tables[seq.seq_id] - block_table.append_token_ids( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - ) + block_table.backtrack(backtrack) + token_ids = block_table.get_unseen_token_ids(seq.get_token_ids()) + if seq.has_aici and not token_ids: + # AICI may want to "append" empty tokens, either to just backtrack + # or to force a wait for one step. + assert num_lookahead_slots == 0 + else: + block_table.append_token_ids( + token_ids=token_ids, + num_lookahead_slots=num_lookahead_slots, + ) # Return any new copy-on-writes. new_cows = self.block_allocator.clear_copy_on_writes() diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 56c2c5995c38b..12e33647a0664 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -54,6 +54,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, + backtrack: int = 0, ) -> Dict[int, List[int]]: pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 99f7a34d336a4..9ab470760b41e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -4,6 +4,8 @@ from dataclasses import dataclass, field from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +from pyaici.comms import AiciRunner + from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.policy import Policy, PolicyFactory @@ -261,6 +263,7 @@ def __init__( version="v2" if self.scheduler_config. use_v2_block_manager else "v1") + self.aici_runner: AiciRunner = None # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( block_size=self.cache_config.block_size, @@ -296,6 +299,11 @@ def num_decoding_tokens_per_seq(self) -> int: return 1 def add_seq_group(self, seq_group: SequenceGroup) -> None: + if seq_group.sampling_params.has_aici: + seq = seq_group.get_seqs()[0] + seq.has_aici = True + self.aici_runner.assign_seq_id(seq_group.request_id, seq.seq_id) + # Add sequence groups to the waiting queue. self.waiting.append(seq_group) @@ -884,6 +892,8 @@ def _can_swap_in(self, seq_group: SequenceGroup) -> bool: ) def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + runner = self.aici_runner + # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. 
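The scheduler, sampler, and output-processor changes in this series follow a fixed call order on every engine step. The sketch below is a minimal, runnable illustration of that order using a stubbed runner; the method names are the ones this series calls on AiciRunner, while the stub class, request id, and token values are made up for the example:

from typing import Dict, List


class StubAiciRunner:
    """Records the calls the engine makes on the runner during one step."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def assign_seq_id(self, request_id: str, seq_id: int) -> None:
        self.calls.append(f"assign_seq_id({request_id}, {seq_id})")

    def add_mid(self, seq_id: int) -> None:
        self.calls.append(f"add_mid({seq_id})")

    def exec_mid(self) -> None:
        self.calls.append("exec_mid()")

    def tokens_generated(self, seq_id: int, tokens: List[int],
                         backtrack: int = 0) -> None:
        self.calls.append(
            f"tokens_generated({seq_id}, {tokens}, backtrack={backtrack})")

    def seq_freed(self, seq_id: int) -> None:
        self.calls.append(f"seq_freed({seq_id})")


def one_engine_step(runner: StubAiciRunner, running_seq_ids: List[int],
                    sampled: Dict[int, int]) -> None:
    # schedule(): register every running AICI sequence, then kick off the
    # controllers' mid-process for the batch (skipped when the batch is empty).
    for seq_id in running_seq_ids:
        runner.add_mid(seq_id)
    if running_seq_ids:
        runner.exec_mid()
    # The sampler then consumes recv_logit_bias_torch() inside
    # _apply_aici_logit_bias(); afterwards the output processor reports what
    # was actually appended, possibly as a backtrack/fast-forward splice.
    for seq_id, token in sampled.items():
        runner.tokens_generated(seq_id, [token])


runner = StubAiciRunner()
runner.assign_seq_id("run-1234", 0)   # add_seq_group()
one_engine_step(runner, [0], {0: 42})
runner.seq_freed(0)                   # free_seq() once the controller stops
print("\n".join(runner.calls))

When the scheduler produces an empty batch, exec_mid() is not issued at all, which is what the "assert not runner.needs_exec_mid()" guard in schedule() checks.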
@@ -905,6 +915,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id + if seq_group.sampling_params.has_aici: + runner.add_mid(seq_id) seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) self.block_manager.access_all_blocks_in_seq(seq, now) @@ -935,6 +947,12 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: ) seq_group_metadata_list.append(seq_group_metadata) + if runner: + if scheduler_outputs.is_empty(): + assert not runner.needs_exec_mid() + else: + runner.exec_mid() + # Now that the batch has been created, we can assume all blocks in the # batch will have been computed before the next scheduling invocation. # This is because the engine assumes that a failure in model execution @@ -950,6 +968,8 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: def free_seq(self, seq: Sequence) -> None: """Free a sequence from a block table.""" + if seq.has_aici: + self.aici_runner.seq_freed(seq.seq_id) self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: @@ -979,7 +999,10 @@ def _append_slots( num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - cows = self.block_manager.append_slots(seq, num_lookahead_slots) + cows = self.block_manager.append_slots(seq, + num_lookahead_slots, + backtrack=seq.backtrack) + seq.backtrack = 0 for src, dests in cows.items(): if src not in blocks_to_copy: diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 7e9d652446703..0b30cf838a226 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -79,6 +79,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # List of (child, parent) child_seqs: List[Tuple[Sequence, Sequence]] = [] + aici_runner = self.scheduler.aici_runner + # Process the child samples for each parent sequence for parent in parent_seqs: child_samples: List[SequenceOutput] = parent_child_dict[ @@ -102,9 +104,26 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # We reuse the parent sequence here to reduce redundant memory # copies, especially when using non-beam search sampling methods. 
last_child_sample = child_samples[-1] + child_seqs.append((parent, parent)) + if seq_group.sampling_params.has_aici: + sid = parent.seq_id + sampled_token = last_child_sample.output_token + r = aici_runner.mid_status(sid) + assert len(r.branches) <= 1 + if r.branches: + splice = r.branches[0].find_splice(sampled_token) + if splice: + parent.splice_tokens(splice.backtrack, + splice.ff_tokens) + aici_runner.tokens_generated( + sid, splice.ff_tokens, backtrack=splice.backtrack) + continue # don't call append_token_id() + else: + aici_runner.tokens_generated(sid, [sampled_token]) parent.append_token_id(last_child_sample.output_token, last_child_sample.logprobs) - child_seqs.append((parent, parent)) + + to_stop = aici_runner.get_seqs_to_stop() if aici_runner else set() for seq, _ in child_seqs: if seq_group.sampling_params.detokenize and self.detokenizer: @@ -114,6 +133,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, new_char_count = 0 self.stop_checker.maybe_stop_sequence(seq, new_char_count, seq_group.sampling_params) + if seq.seq_id in to_stop: + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = "" # Non-beam search case if not seq_group.sampling_params.use_beam_search: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 37d76b8e74055..d72e358c03676 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -6,6 +6,7 @@ from http import HTTPStatus import fastapi +import pyaici import uvicorn from fastapi import Request from fastapi.exceptions import RequestValidationError @@ -19,7 +20,9 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, - CompletionRequest, ErrorResponse) + CompletionRequest, ErrorResponse, + RunRequest, SetTagsRequest) +from vllm.entrypoints.openai.serving_aici import AiciRunnerCompletion from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.logger import init_logger @@ -29,6 +32,8 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion +pyaici_runner_completion: AiciRunnerCompletion + logger = init_logger(__name__) @@ -51,6 +56,7 @@ async def _force_log(): def parse_args(): parser = make_arg_parser() + parser = pyaici.add_cli_args(parser) return parser.parse_args() @@ -114,6 +120,51 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) +def _no_aici(): + return JSONResponse({"error": "AICI runtime is not enabled"}, + status_code=501) + + +@app.post("/v1/controllers") +async def upload_aici_module(request: Request): + if not pyaici_runner_completion: + return _no_aici() + contents = await request.body() + return JSONResponse( + await + pyaici_runner_completion.aici_runner.upload_module_async(contents)) + + +@app.post("/v1/run") +async def aici_run(request: RunRequest, raw_request: Request): + if not pyaici_runner_completion: + return _no_aici() + request_id, inst_res = \ + await pyaici_runner_completion.prep_completion(request) + generator = pyaici_runner_completion.create_completion( + request_id, inst_res, request, raw_request) + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@app.post("/v1/controllers/tags") +async def aici_set_tags(request: SetTagsRequest): + if not pyaici_runner_completion: + 
return _no_aici() + # non-admin users can only set tags that start with their username + auto_info = {"user": "vllm", "is_admin": True} + r = await pyaici_runner_completion.aici_runner.set_tags( + request.module_id, request.tags, auth_info=auto_info) + return JSONResponse(r) + + +@app.get("/v1/controllers/tags") +async def aici_get_tags(): + if not pyaici_runner_completion: + return _no_aici() + r = await pyaici_runner_completion.aici_runner.get_tags() + return JSONResponse(r) + + if __name__ == "__main__": args = parse_args() @@ -158,6 +209,14 @@ async def authentication(request: Request, call_next): engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + if args.aici_rt: + config = asyncio.run(engine.get_model_config()) + dtype = str(config.dtype).replace("torch.", "").replace("float", "f") + pyaici_runner = pyaici.runner_from_cli(args, dtype=dtype) + pyaici_runner.fast_api() + assert len(served_model_names) == 1 + pyaici_runner_completion = AiciRunnerCompletion( + pyaici_runner, engine, served_model_names[0]) openai_serving_chat = OpenAIServingChat(engine, served_model_names, args.response_role, args.lora_modules, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d9763d024eb83..7b28c3fd3eb4d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -343,6 +343,32 @@ def check_guided_decoding_count(cls, data): return data +class RunRequest(BaseModel): + prompt: str + controller: str + controller_arg: Union[str, dict] + temperature: Optional[float] = 0.0 + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + max_tokens: Optional[int] = None + + def to_sampling_params(self): + r = SamplingParams( + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + max_tokens=self.max_tokens, + ignore_eos=True, + ) + r.has_aici = True + return r + + +class SetTagsRequest(BaseModel): + module_id: str + tags: List[str] + + class LogProbs(BaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) diff --git a/vllm/entrypoints/openai/serving_aici.py b/vllm/entrypoints/openai/serving_aici.py new file mode 100644 index 0000000000000..b4476f18fbe03 --- /dev/null +++ b/vllm/entrypoints/openai/serving_aici.py @@ -0,0 +1,102 @@ +from typing import List, Union + +from fastapi import Request +from pyaici.comms import AiciRunner + +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.utils import random_uuid + +from .protocol import RunRequest + +# logger = init_logger(__name__) + + +class AiciRunnerCompletion(OpenAIServing): + + def __init__(self, aici_runner: AiciRunner, engine: AsyncLLMEngine, + served_model_names: List[str]): + super().__init__(engine=engine, + served_model_names=served_model_names, + lora_modules=None) + self.aici_runner = aici_runner + self.empty_prompt: List[int] = self.tokenizer("").input_ids + if not self.empty_prompt: + # if there's no start symbol, add a space, otherwise Engine + # gets stuck on empty prompt + self.empty_prompt = self.tokenizer(" ").input_ids + assert self.empty_prompt + # TODO: this is a hack: + engine.engine.scheduler.aici_runner = aici_runner + + # this is separate from create_completion() so fastapi exceptions + # from .instantiate_async() are properly sent to the user + async def prep_completion(self, request: 
RunRequest): + request_id = f"run-{random_uuid()}" + prompt = self.tokenizer(request.prompt).input_ids + inst_res = await self.aici_runner.instantiate_async( + request_id, prompt, request.controller, request.controller_arg) + return request_id, inst_res + + async def create_completion(self, request_id: str, inst_res: Union[dict, + list], + request: RunRequest, raw_request: Request): + """Completion API for AICI controllers. + + See https://github.com/microsoft/aici/blob/main/docs/REST.md + """ + runner = self.aici_runner + yield runner.data_line( + runner.initial_json(request_id, self.served_model_names[0])) + + if isinstance(inst_res, dict): + # error case + yield runner.data_line(inst_res) + yield runner.final_data() + return + + # Engine doesn't like prompts with no tokens + # self.empty_prompt is either start symbol or a single space + if len(inst_res) == 0: + inst_res = self.empty_prompt + + sampling_params = request.to_sampling_params() + sampling_params.stop_token_ids = [] + generator = self.engine.generate(prompt=None, + sampling_params=sampling_params, + request_id=request_id, + prompt_token_ids=inst_res) + + previous_texts = [] + ff_tokens = len(inst_res) + sampled_tokens = 0 + + async for res in generator: + # Abort the request if the client disconnects. + if await raw_request.is_disconnected(): + await self.engine.abort(request_id) + raise StopAsyncIteration() + forks = [] + for output in res.outputs: + # TODO: + ff_tokens += 1 + sampled_tokens += 1 + + i = output.index + while len(previous_texts) <= i: + previous_texts.append("") + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + + fork_res = runner.seq_logs( + output.seq_id, + index=i, + text=delta_text, + finish_reason=output.finish_reason, + ) + forks.append(fork_res) + yield runner.data_line( + runner.run_json(forks, + runner.usage_json(ff_tokens, sampled_tokens))) + + yield runner.final_data() diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 31da27a447c6c..730ad50e6ceed 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -32,7 +32,7 @@ class OpenAIServing: def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], - lora_modules=Optional[List[LoRA]]): + lora_modules: Optional[List[LoRA]] = None): self.engine = engine self.served_model_names = served_model_names if lora_modules is None: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c4b11cb33a677..f41856c0d4c62 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from pyaici.comms import AiciRunner from vllm.model_executor.layers.ops.sample import sample as sample_triton from vllm.model_executor.sampling_metadata import (SamplingMetadata, @@ -51,6 +52,9 @@ def forward( assert logits is not None _, vocab_size = logits.shape + # Start with constrained decoding + logits = _apply_aici_logit_bias(logits, sampling_metadata) + # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens # have not been generated yet logits = _apply_min_tokens_penalty(logits, sampling_metadata) @@ -144,6 +148,40 @@ def _get_bin_counts_and_mask( return bin_counts, mask +def _apply_aici_logit_bias( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +): + aici_runner = AiciRunner.instance + if not aici_runner: + return logits + mid_results, arr = 
aici_runner.recv_logit_bias_torch() + # logits.dtype should generally match arr.dtype + bias = arr.to(logits.device).to(logits.dtype) + if bias.shape[0] == 0: + return logits + + logits_row_idx = 0 + for seq_ids, sampling_params in sampling_metadata.seq_groups: + if sampling_params.has_aici: + for id in seq_ids: + r = mid_results.get(id) + if r and len(r.branches) >= 1: + # this is actually also enforced by AICIrt since + # we don't pass --cap-fork + assert len(r.branches) <= 1, "Only one branch is supported" + mask = r.branches[0].mask + if mask is not None: + logits[logits_row_idx] += bias[mask, 0:logits.shape[1]] + temp = r.branches[0].temperature + if temp is not None: + sampling_params.temperature = temp + logits_row_idx += 1 + else: + logits_row_idx += len(seq_ids) + + return logits + def _apply_min_tokens_penalty( logits: torch.Tensor, diff --git a/vllm/outputs.py b/vllm/outputs.py index d01be0eb0efd2..7ab9793bb92d1 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -26,6 +26,7 @@ class CompletionOutput: def __init__( self, + seq_id: int, index: int, text: str, token_ids: List[int], @@ -35,6 +36,7 @@ def __init__( stop_reason: Union[int, str, None] = None, lora_request: Optional[LoRARequest] = None, ) -> None: + self.seq_id = seq_id self.index = index self.text = text self.token_ids = token_ids @@ -114,7 +116,7 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": include_logprobs = seq_group.sampling_params.logprobs is not None text_buffer_length = seq_group.sampling_params.output_text_buffer_length outputs = [ - CompletionOutput(seqs.index(seq), + CompletionOutput(seq.seq_id, seqs.index(seq), seq.get_output_text_to_return(text_buffer_length), seq.get_output_token_ids(), seq.get_cumulative_logprob(), diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index dc0e60344d858..9744be9b832e9 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -166,6 +166,7 @@ def __init__( self.spaces_between_special_tokens = spaces_between_special_tokens self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output + self.has_aici = False self.truncate_prompt_tokens = truncate_prompt_tokens # Number of characters to hold back for stop string evaluation # until sequence is finished. @@ -274,6 +275,9 @@ def _verify_greedy_sampling(self) -> None: def update_from_generation_config( self, generation_config: Dict[str, Any]) -> None: """Update if there are non-default values from generation_config""" + # For AICI, we want the controller to control stopping. + if self.has_aici: + return # Update eos_token_id for generation if eos_ids := generation_config.get("eos_token_id"): # it can be either int or list of int diff --git a/vllm/sequence.py b/vllm/sequence.py index b296b37a84f15..82d22e0b8300e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -218,6 +218,8 @@ def __init__( self.data = SequenceData(prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" + self.has_aici = False + self.backtrack = 0 self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. 
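The _apply_aici_logit_bias hunk above adds a per-row bias (received via recv_logit_bias_torch()) to the logits of every AICI-controlled sequence before sampling. The snippet below is a small self-contained illustration of the masking effect such an additive bias has; the vocabulary size, allowed-token set, and logit values are made up and do not come from the runner:

import torch

# Toy vocabulary of 6 tokens; suppose the controller allows only tokens 2 and 5.
logits = torch.tensor([[1.0, 0.5, 2.0, -1.0, 0.3, 1.5]])
bias = torch.full((1, 6), float("-inf"))
bias[0, [2, 5]] = 0.0                       # allowed tokens keep their logits

constrained = logits + bias                 # same idea as logits[row] += bias[...]
print(torch.softmax(constrained, dim=-1))   # probability mass only on 2 and 5

In the patch the branch result can also carry a controller-chosen temperature, which is why sampling_params.temperature may be overwritten in the same loop that applies the bias.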
@@ -291,6 +293,58 @@ def append_token_id( self.output_logprobs.append(logprobs) self.data.append_token_id(token_id, logprobs[token_id].logprob) + def splice_tokens(self, backtrack: int, token_ids: List[int]): + assert self.backtrack == 0 + + data = self.data + + if not token_ids: + # we need at least one token in forward step, + # so we pretend we're backtracking one token more + # and repeat the token that was there + # otherwise, the _num_comptued_tokens gets out of sync + backtrack += 1 + if backtrack <= len(data.output_token_ids): + token_ids = [data.output_token_ids[-backtrack]] + else: + off = backtrack - len(data.output_token_ids) + if off <= len(data.prompt_token_ids): + token_ids = [data.prompt_token_ids[-off]] + else: + token_ids = [1] + + backtrack = min(backtrack, data.get_len()) + self.backtrack = backtrack + + if backtrack > 0: + prompt_backtrack = 0 + output_len = len(data.output_token_ids) + if backtrack > output_len: + prompt_backtrack = backtrack - output_len + backtrack = output_len + del data.output_token_ids[-backtrack:] + del self.output_logprobs[-backtrack:] + data._num_computed_tokens = min(data._num_computed_tokens, + len(data.output_token_ids)) + if prompt_backtrack > 0: + assert not data.output_token_ids + del data.prompt_token_ids[-prompt_backtrack:] + needed_blocks = \ + (self.get_len() + self.block_size - 1) // self.block_size + if len(self.logical_token_blocks) > needed_blocks: + del self.logical_token_blocks[needed_blocks:] + if needed_blocks > 0: + last_block = self.logical_token_blocks[-1] + last_num_tokens = self.get_len() % self.block_size + if last_num_tokens == 0: + last_num_tokens = self.block_size + last_block.num_tokens = last_num_tokens + + for t in token_ids: + self.append_token_id(t, {t: Logprob(logprob=0.0)}) + if data.get_num_uncomputed_tokens() > 1: + data._stage = SequenceStage.PREFILL + def get_len(self) -> int: return self.data.get_len() From bc8a7222859fc87a74b38cbad04aa2e1e284f4b6 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 10 May 2024 01:35:23 +0000 Subject: [PATCH 2/9] adapt to new APIs --- vllm/entrypoints/openai/api_server.py | 18 +++++++++--------- vllm/entrypoints/openai/serving_aici.py | 6 ++++-- vllm/model_executor/layers/sampler.py | 11 ++++++----- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 84d43f8f4a748..cec4a910f1b08 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -219,15 +219,6 @@ async def authentication(request: Request, call_next): engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - if args.aici_rt: - config = asyncio.run(engine.get_model_config()) - dtype = str(config.dtype).replace("torch.", "").replace("float", "f") - pyaici_runner = pyaici.runner_from_cli(args, dtype=dtype) - pyaici_runner.fast_api() - assert len(served_model_names) == 1 - pyaici_runner_completion = AiciRunnerCompletion( - pyaici_runner, engine, served_model_names[0]) - event_loop: Optional[asyncio.AbstractEventLoop] try: event_loop = asyncio.get_running_loop() @@ -242,6 +233,15 @@ async def authentication(request: Request, call_next): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) + if args.aici_rt: + config = asyncio.run(engine.get_model_config()) + dtype = str(config.dtype).replace("torch.", "").replace("float", "f") + pyaici_runner 
= pyaici.runner_from_cli(args, dtype=dtype) + pyaici_runner.fast_api() + assert len(served_model_names) == 1 + pyaici_runner_completion = AiciRunnerCompletion( + pyaici_runner, engine, model_config, served_model_names[0]) + openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, args.response_role, diff --git a/vllm/entrypoints/openai/serving_aici.py b/vllm/entrypoints/openai/serving_aici.py index b4476f18fbe03..81be479f5166b 100644 --- a/vllm/entrypoints/openai/serving_aici.py +++ b/vllm/entrypoints/openai/serving_aici.py @@ -3,6 +3,7 @@ from fastapi import Request from pyaici.comms import AiciRunner +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.utils import random_uuid @@ -15,8 +16,9 @@ class AiciRunnerCompletion(OpenAIServing): def __init__(self, aici_runner: AiciRunner, engine: AsyncLLMEngine, - served_model_names: List[str]): + model_config: ModelConfig, served_model_names: List[str]): super().__init__(engine=engine, + model_config=model_config, served_model_names=served_model_names, lora_modules=None) self.aici_runner = aici_runner @@ -67,7 +69,7 @@ async def create_completion(self, request_id: str, inst_res: Union[dict, request_id=request_id, prompt_token_ids=inst_res) - previous_texts = [] + previous_texts: List[str] = [] ff_tokens = len(inst_res) sampled_tokens = 0 diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index d2bfe287eb916..2b3992fa887f4 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -152,6 +152,7 @@ def _get_bin_counts_and_mask( return bin_counts, mask + def _apply_aici_logit_bias( logits: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -166,9 +167,9 @@ def _apply_aici_logit_bias( return logits logits_row_idx = 0 - for seq_ids, sampling_params in sampling_metadata.seq_groups: - if sampling_params.has_aici: - for id in seq_ids: + for sg in sampling_metadata.seq_groups: + if sg.sampling_params.has_aici: + for id in sg.seq_ids: r = mid_results.get(id) if r and len(r.branches) >= 1: # this is actually also enforced by AICIrt since @@ -179,10 +180,10 @@ def _apply_aici_logit_bias( logits[logits_row_idx] += bias[mask, 0:logits.shape[1]] temp = r.branches[0].temperature if temp is not None: - sampling_params.temperature = temp + sg.sampling_params.temperature = temp logits_row_idx += 1 else: - logits_row_idx += len(seq_ids) + logits_row_idx += len(sg.seq_ids) return logits From 69e08563035f228945f3152b8bd99d3b26f22774 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 10 May 2024 01:39:03 +0000 Subject: [PATCH 3/9] avoid crash --- vllm/sequence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 1316f7f0089b8..9a07748fd43c6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -499,7 +499,8 @@ def lora_int_id(self) -> int: def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. - if self.is_prefill(): + # With AICI, the request may go from decode to prefill, so ignore. 
+ if self.is_prefill() and not self.sampling_params.has_aici: raise ValueError( "seq_group.get_last_latency() should not be called " "if the seq_group is in prefill phase.") From 371aa8040760bf4105506b6acd13a1fa3f104aee Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 11 May 2024 21:21:22 +0000 Subject: [PATCH 4/9] simplify condition --- aici.patch | 713 ++++++++++++++++++++++++++++++++++ vllm/core/block_manager_v2.py | 14 +- 2 files changed, 717 insertions(+), 10 deletions(-) create mode 100644 aici.patch diff --git a/aici.patch b/aici.patch new file mode 100644 index 0000000000000..884d22bd8720b --- /dev/null +++ b/aici.patch @@ -0,0 +1,713 @@ +diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py +index b0d9511f..eace526e 100644 +--- a/vllm/core/block/block_table.py ++++ b/vllm/core/block/block_table.py +@@ -87,6 +87,22 @@ class BlockTable: + device=device) + self._num_full_slots = len(token_ids) + ++ def backtrack(self, num_slots: int) -> None: ++ """Remove the specified number of slots from the end of the table. ++ ++ Args: ++ num_slots (int): The number of slots to backtrack by. ++ """ ++ assert self._is_allocated ++ assert num_slots <= self._num_full_slots ++ if num_slots == 0: ++ return ++ self._num_full_slots -= num_slots ++ blocks = self._blocks[self._num_full_slots // self._block_size:] ++ blocks[0].trim(self._num_full_slots % self._block_size) ++ for b in blocks[1:]: ++ b.trim(0) ++ + def append_token_ids(self, + token_ids: List[int], + num_lookahead_slots: int = 0) -> None: +diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py +index 140fbbb0..999360ff 100644 +--- a/vllm/core/block/interfaces.py ++++ b/vllm/core/block/interfaces.py +@@ -43,6 +43,10 @@ class Block(ABC): + def prev_block(self) -> Optional["Block"]: + pass + ++ @abstractmethod ++ def trim(self, num_tokens: int): ++ pass ++ + @property + @abstractmethod + def computed(self) -> bool: +diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py +index ae019308..8f8ae375 100644 +--- a/vllm/core/block/naive_block.py ++++ b/vllm/core/block/naive_block.py +@@ -269,6 +269,9 @@ class NaiveBlock(Block): + assert self.num_empty_slots >= len(token_ids) + self._token_ids.extend(token_ids) + ++ def trim(self, num_tokens: int): ++ del self._token_ids[num_tokens:] ++ + @property + def computed(self) -> bool: + raise NotImplementedError +diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py +index 882f301c..4575fcbd 100644 +--- a/vllm/core/block/prefix_caching_block.py ++++ b/vllm/core/block/prefix_caching_block.py +@@ -456,6 +456,9 @@ class PrefixCachingBlock(Block): + _cow_target=self, + ) + ++ def trim(self, num_tokens: int): ++ return self._block.trim(num_tokens) ++ + @property + def computed(self) -> bool: + return self._computed +diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py +index 52a170d7..0359a4ea 100644 +--- a/vllm/core/block_manager_v1.py ++++ b/vllm/core/block_manager_v1.py +@@ -386,10 +386,13 @@ class BlockSpaceManagerV1(BlockSpaceManager): + self, + seq: Sequence, + num_lookahead_slots: int = 0, ++ backtrack: int = 0, + ) -> List[Tuple[int, int]]: + """Allocate a physical slot for a new token.""" + logical_blocks = seq.logical_token_blocks + block_table = self.block_tables[seq.seq_id] ++ assert backtrack == 0, \ ++ "Backtrack not supported; consider --use-v2-block-manager" + # If we need to allocate a new physical block + if len(block_table) < len(logical_blocks): + # Currently 
this code only supports adding one physical block +diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py +index f0bc9656..0a62f83d 100644 +--- a/vllm/core/block_manager_v2.py ++++ b/vllm/core/block_manager_v2.py +@@ -167,10 +167,12 @@ class BlockSpaceManagerV2(BlockSpaceManager): + self, + seq: Sequence, + num_lookahead_slots: int, ++ backtrack: int = 0, + ) -> List[Tuple[int, int]]: + + block_table = self.block_tables[seq.seq_id] + ++ block_table.backtrack(backtrack) + block_table.append_token_ids( + token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, +diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py +index b2a5e419..c56fd760 100644 +--- a/vllm/core/interfaces.py ++++ b/vllm/core/interfaces.py +@@ -55,6 +55,7 @@ class BlockSpaceManager(ABC): + self, + seq: Sequence, + num_lookahead_slots: int, ++ backtrack: int = 0, + ) -> List[Tuple[int, int]]: + pass + +diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py +index 35e3db18..a6cce282 100644 +--- a/vllm/core/scheduler.py ++++ b/vllm/core/scheduler.py +@@ -6,6 +6,8 @@ from collections import deque + from dataclasses import dataclass, field + from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union + ++from pyaici.comms import AiciRunner ++ + from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig + from vllm.core.interfaces import AllocStatus, BlockSpaceManager + from vllm.core.policy import Policy, PolicyFactory +@@ -274,6 +276,7 @@ class Scheduler: + version="v2" if self.scheduler_config. + use_v2_block_manager else "v1") + ++ self.aici_runner: AiciRunner = None + # Create the block space manager. + self.block_manager = BlockSpaceManagerImpl( + block_size=self.cache_config.block_size, +@@ -316,6 +319,11 @@ class Scheduler: + return 1 + + def add_seq_group(self, seq_group: SequenceGroup) -> None: ++ if seq_group.sampling_params.has_aici: ++ seq = seq_group.get_seqs()[0] ++ seq.has_aici = True ++ self.aici_runner.assign_seq_id(seq_group.request_id, seq.seq_id) ++ + # Add sequence groups to the waiting queue. + self.waiting.append(seq_group) + +@@ -915,6 +923,8 @@ class Scheduler: + ) + + def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: ++ runner = self.aici_runner ++ + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. +@@ -936,6 +946,8 @@ class Scheduler: + + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq_id = seq.seq_id ++ if seq_group.sampling_params.has_aici: ++ runner.add_mid(seq_id) + seq_data[seq_id] = seq.data + block_tables[seq_id] = self.block_manager.get_block_table(seq) + self.block_manager.access_all_blocks_in_seq(seq, now) +@@ -981,6 +993,12 @@ class Scheduler: + ) + seq_group_metadata_list.append(seq_group_metadata) + ++ if runner: ++ if scheduler_outputs.is_empty(): ++ assert not runner.needs_exec_mid() ++ else: ++ runner.exec_mid() ++ + # Now that the batch has been created, we can assume all blocks in the + # batch will have been computed before the next scheduling invocation. 
+ # This is because the engine assumes that a failure in model execution +@@ -996,6 +1014,8 @@ class Scheduler: + + def free_seq(self, seq: Sequence) -> None: + """Free a sequence from a block table.""" ++ if seq.has_aici: ++ self.aici_runner.seq_freed(seq.seq_id) + self.block_manager.free(seq) + + def free_finished_seq_groups(self) -> None: +@@ -1026,7 +1046,10 @@ class Scheduler: + num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) + + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): +- cows = self.block_manager.append_slots(seq, num_lookahead_slots) ++ cows = self.block_manager.append_slots(seq, ++ num_lookahead_slots, ++ backtrack=seq.backtrack) ++ seq.backtrack = 0 + blocks_to_copy.extend(cows) + + def _preempt( +diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py +index 07b14058..6cd6ed0f 100644 +--- a/vllm/engine/output_processor/single_step.py ++++ b/vllm/engine/output_processor/single_step.py +@@ -85,6 +85,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + # List of (child, parent) + child_seqs: List[Tuple[Sequence, Sequence]] = [] + ++ aici_runner = self.scheduler.aici_runner ++ + # Process the child samples for each parent sequence + for parent in parent_seqs: + child_samples: List[SequenceOutput] = parent_child_dict[ +@@ -108,9 +110,26 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + # We reuse the parent sequence here to reduce redundant memory + # copies, especially when using non-beam search sampling methods. + last_child_sample = child_samples[-1] ++ child_seqs.append((parent, parent)) ++ if seq_group.sampling_params.has_aici: ++ sid = parent.seq_id ++ sampled_token = last_child_sample.output_token ++ r = aici_runner.mid_status(sid) ++ assert len(r.branches) <= 1 ++ if r.branches: ++ splice = r.branches[0].find_splice(sampled_token) ++ if splice: ++ parent.splice_tokens(splice.backtrack, ++ splice.ff_tokens) ++ aici_runner.tokens_generated( ++ sid, splice.ff_tokens, backtrack=splice.backtrack) ++ continue # don't call append_token_id() ++ else: ++ aici_runner.tokens_generated(sid, [sampled_token]) + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) +- child_seqs.append((parent, parent)) ++ ++ to_stop = aici_runner.get_seqs_to_stop() if aici_runner else set() + + for seq, _ in child_seqs: + if seq_group.sampling_params.detokenize and self.detokenizer: +@@ -120,6 +139,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + new_char_count = 0 + self.stop_checker.maybe_stop_sequence(seq, new_char_count, + seq_group.sampling_params) ++ if seq.seq_id in to_stop: ++ seq.status = SequenceStatus.FINISHED_STOPPED ++ seq.stop_reason = "" + + # Non-beam search case + if not seq_group.sampling_params.use_beam_search: +diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py +index 362f28d0..cec4a910 100644 +--- a/vllm/entrypoints/openai/api_server.py ++++ b/vllm/entrypoints/openai/api_server.py +@@ -7,6 +7,7 @@ from http import HTTPStatus + from typing import Optional, Set + + import fastapi ++import pyaici + import uvicorn + from fastapi import Request + from fastapi.exceptions import RequestValidationError +@@ -22,7 +23,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.entrypoints.openai.cli_args import make_arg_parser + from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, +- CompletionRequest, ErrorResponse) ++ 
CompletionRequest, ErrorResponse, ++ RunRequest, SetTagsRequest) ++from vllm.entrypoints.openai.serving_aici import AiciRunnerCompletion + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion + from vllm.logger import init_logger +@@ -32,6 +35,8 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds + + openai_serving_chat: OpenAIServingChat + openai_serving_completion: OpenAIServingCompletion ++pyaici_runner_completion: AiciRunnerCompletion ++ + logger = init_logger(__name__) + + _running_tasks: Set[asyncio.Task] = set() +@@ -58,6 +63,7 @@ app = fastapi.FastAPI(lifespan=lifespan) + + def parse_args(): + parser = make_arg_parser() ++ parser = pyaici.add_cli_args(parser) + return parser.parse_args() + + +@@ -123,6 +129,51 @@ async def create_completion(request: CompletionRequest, raw_request: Request): + return JSONResponse(content=generator.model_dump()) + + ++def _no_aici(): ++ return JSONResponse({"error": "AICI runtime is not enabled"}, ++ status_code=501) ++ ++ ++@app.post("/v1/controllers") ++async def upload_aici_module(request: Request): ++ if not pyaici_runner_completion: ++ return _no_aici() ++ contents = await request.body() ++ return JSONResponse( ++ await ++ pyaici_runner_completion.aici_runner.upload_module_async(contents)) ++ ++ ++@app.post("/v1/run") ++async def aici_run(request: RunRequest, raw_request: Request): ++ if not pyaici_runner_completion: ++ return _no_aici() ++ request_id, inst_res = \ ++ await pyaici_runner_completion.prep_completion(request) ++ generator = pyaici_runner_completion.create_completion( ++ request_id, inst_res, request, raw_request) ++ return StreamingResponse(content=generator, media_type="text/event-stream") ++ ++ ++@app.post("/v1/controllers/tags") ++async def aici_set_tags(request: SetTagsRequest): ++ if not pyaici_runner_completion: ++ return _no_aici() ++ # non-admin users can only set tags that start with their username ++ auto_info = {"user": "vllm", "is_admin": True} ++ r = await pyaici_runner_completion.aici_runner.set_tags( ++ request.module_id, request.tags, auth_info=auto_info) ++ return JSONResponse(r) ++ ++ ++@app.get("/v1/controllers/tags") ++async def aici_get_tags(): ++ if not pyaici_runner_completion: ++ return _no_aici() ++ r = await pyaici_runner_completion.aici_runner.get_tags() ++ return JSONResponse(r) ++ ++ + if __name__ == "__main__": + args = parse_args() + +@@ -168,7 +219,6 @@ if __name__ == "__main__": + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER) +- + event_loop: Optional[asyncio.AbstractEventLoop] + try: + event_loop = asyncio.get_running_loop() +@@ -183,6 +233,15 @@ if __name__ == "__main__": + # When using single vLLM without engine_use_ray + model_config = asyncio.run(engine.get_model_config()) + ++ if args.aici_rt: ++ config = asyncio.run(engine.get_model_config()) ++ dtype = str(config.dtype).replace("torch.", "").replace("float", "f") ++ pyaici_runner = pyaici.runner_from_cli(args, dtype=dtype) ++ pyaici_runner.fast_api() ++ assert len(served_model_names) == 1 ++ pyaici_runner_completion = AiciRunnerCompletion( ++ pyaici_runner, engine, model_config, served_model_names[0]) ++ + openai_serving_chat = OpenAIServingChat(engine, model_config, + served_model_names, + args.response_role, +diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py +index 3cd9ddad..02c0c907 100644 +--- 
a/vllm/entrypoints/openai/protocol.py ++++ b/vllm/entrypoints/openai/protocol.py +@@ -363,6 +363,32 @@ class CompletionRequest(OpenAIBaseModel): + return data + + ++class RunRequest(BaseModel): ++ prompt: str ++ controller: str ++ controller_arg: Union[str, dict] ++ temperature: Optional[float] = 0.0 ++ top_p: Optional[float] = 1.0 ++ top_k: Optional[int] = -1 ++ max_tokens: Optional[int] = None ++ ++ def to_sampling_params(self): ++ r = SamplingParams( ++ temperature=self.temperature, ++ top_p=self.top_p, ++ top_k=self.top_k, ++ max_tokens=self.max_tokens, ++ ignore_eos=True, ++ ) ++ r.has_aici = True ++ return r ++ ++ ++class SetTagsRequest(BaseModel): ++ module_id: str ++ tags: List[str] ++ ++ + class LogProbs(OpenAIBaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) +diff --git a/vllm/entrypoints/openai/serving_aici.py b/vllm/entrypoints/openai/serving_aici.py +new file mode 100644 +index 00000000..81be479f +--- /dev/null ++++ b/vllm/entrypoints/openai/serving_aici.py +@@ -0,0 +1,104 @@ ++from typing import List, Union ++ ++from fastapi import Request ++from pyaici.comms import AiciRunner ++ ++from vllm.config import ModelConfig ++from vllm.engine.async_llm_engine import AsyncLLMEngine ++from vllm.entrypoints.openai.serving_engine import OpenAIServing ++from vllm.utils import random_uuid ++ ++from .protocol import RunRequest ++ ++# logger = init_logger(__name__) ++ ++ ++class AiciRunnerCompletion(OpenAIServing): ++ ++ def __init__(self, aici_runner: AiciRunner, engine: AsyncLLMEngine, ++ model_config: ModelConfig, served_model_names: List[str]): ++ super().__init__(engine=engine, ++ model_config=model_config, ++ served_model_names=served_model_names, ++ lora_modules=None) ++ self.aici_runner = aici_runner ++ self.empty_prompt: List[int] = self.tokenizer("").input_ids ++ if not self.empty_prompt: ++ # if there's no start symbol, add a space, otherwise Engine ++ # gets stuck on empty prompt ++ self.empty_prompt = self.tokenizer(" ").input_ids ++ assert self.empty_prompt ++ # TODO: this is a hack: ++ engine.engine.scheduler.aici_runner = aici_runner ++ ++ # this is separate from create_completion() so fastapi exceptions ++ # from .instantiate_async() are properly sent to the user ++ async def prep_completion(self, request: RunRequest): ++ request_id = f"run-{random_uuid()}" ++ prompt = self.tokenizer(request.prompt).input_ids ++ inst_res = await self.aici_runner.instantiate_async( ++ request_id, prompt, request.controller, request.controller_arg) ++ return request_id, inst_res ++ ++ async def create_completion(self, request_id: str, inst_res: Union[dict, ++ list], ++ request: RunRequest, raw_request: Request): ++ """Completion API for AICI controllers. 
++ ++ See https://github.com/microsoft/aici/blob/main/docs/REST.md ++ """ ++ runner = self.aici_runner ++ yield runner.data_line( ++ runner.initial_json(request_id, self.served_model_names[0])) ++ ++ if isinstance(inst_res, dict): ++ # error case ++ yield runner.data_line(inst_res) ++ yield runner.final_data() ++ return ++ ++ # Engine doesn't like prompts with no tokens ++ # self.empty_prompt is either start symbol or a single space ++ if len(inst_res) == 0: ++ inst_res = self.empty_prompt ++ ++ sampling_params = request.to_sampling_params() ++ sampling_params.stop_token_ids = [] ++ generator = self.engine.generate(prompt=None, ++ sampling_params=sampling_params, ++ request_id=request_id, ++ prompt_token_ids=inst_res) ++ ++ previous_texts: List[str] = [] ++ ff_tokens = len(inst_res) ++ sampled_tokens = 0 ++ ++ async for res in generator: ++ # Abort the request if the client disconnects. ++ if await raw_request.is_disconnected(): ++ await self.engine.abort(request_id) ++ raise StopAsyncIteration() ++ forks = [] ++ for output in res.outputs: ++ # TODO: ++ ff_tokens += 1 ++ sampled_tokens += 1 ++ ++ i = output.index ++ while len(previous_texts) <= i: ++ previous_texts.append("") ++ delta_text = output.text[len(previous_texts[i]):] ++ previous_texts[i] = output.text ++ ++ fork_res = runner.seq_logs( ++ output.seq_id, ++ index=i, ++ text=delta_text, ++ finish_reason=output.finish_reason, ++ ) ++ forks.append(fork_res) ++ yield runner.data_line( ++ runner.run_json(forks, ++ runner.usage_json(ff_tokens, sampled_tokens))) ++ ++ yield runner.final_data() +diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py +index e52e350d..2b3992fa 100644 +--- a/vllm/model_executor/layers/sampler.py ++++ b/vllm/model_executor/layers/sampler.py +@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple + + import torch + import torch.nn as nn ++from pyaici.comms import AiciRunner + + from vllm.model_executor.layers.ops.sample import sample as sample_triton + from vllm.model_executor.sampling_metadata import (SamplingMetadata, +@@ -59,6 +60,9 @@ class Sampler(nn.Module): + assert logits is not None + _, vocab_size = logits.shape + ++ # Start with constrained decoding ++ logits = _apply_aici_logit_bias(logits, sampling_metadata) ++ + logits = _apply_min_tokens_penalty(logits, sampling_metadata) + + # Prepare sampling tensors with pinned memory to avoid blocking. 
+@@ -149,6 +153,41 @@ def _get_bin_counts_and_mask( + return bin_counts, mask + + ++def _apply_aici_logit_bias( ++ logits: torch.Tensor, ++ sampling_metadata: SamplingMetadata, ++): ++ aici_runner = AiciRunner.instance ++ if not aici_runner: ++ return logits ++ mid_results, arr = aici_runner.recv_logit_bias_torch() ++ # logits.dtype should generally match arr.dtype ++ bias = arr.to(logits.device).to(logits.dtype) ++ if bias.shape[0] == 0: ++ return logits ++ ++ logits_row_idx = 0 ++ for sg in sampling_metadata.seq_groups: ++ if sg.sampling_params.has_aici: ++ for id in sg.seq_ids: ++ r = mid_results.get(id) ++ if r and len(r.branches) >= 1: ++ # this is actually also enforced by AICIrt since ++ # we don't pass --cap-fork ++ assert len(r.branches) <= 1, "Only one branch is supported" ++ mask = r.branches[0].mask ++ if mask is not None: ++ logits[logits_row_idx] += bias[mask, 0:logits.shape[1]] ++ temp = r.branches[0].temperature ++ if temp is not None: ++ sg.sampling_params.temperature = temp ++ logits_row_idx += 1 ++ else: ++ logits_row_idx += len(sg.seq_ids) ++ ++ return logits ++ ++ + def _apply_min_tokens_penalty( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +diff --git a/vllm/outputs.py b/vllm/outputs.py +index d01be0eb..7ab9793b 100644 +--- a/vllm/outputs.py ++++ b/vllm/outputs.py +@@ -26,6 +26,7 @@ class CompletionOutput: + + def __init__( + self, ++ seq_id: int, + index: int, + text: str, + token_ids: List[int], +@@ -35,6 +36,7 @@ class CompletionOutput: + stop_reason: Union[int, str, None] = None, + lora_request: Optional[LoRARequest] = None, + ) -> None: ++ self.seq_id = seq_id + self.index = index + self.text = text + self.token_ids = token_ids +@@ -114,7 +116,7 @@ class RequestOutput: + include_logprobs = seq_group.sampling_params.logprobs is not None + text_buffer_length = seq_group.sampling_params.output_text_buffer_length + outputs = [ +- CompletionOutput(seqs.index(seq), ++ CompletionOutput(seq.seq_id, seqs.index(seq), + seq.get_output_text_to_return(text_buffer_length), + seq.get_output_token_ids(), + seq.get_cumulative_logprob(), +diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py +index 5fa94eb1..57739a51 100644 +--- a/vllm/sampling_params.py ++++ b/vllm/sampling_params.py +@@ -169,6 +169,7 @@ class SamplingParams: + self.spaces_between_special_tokens = spaces_between_special_tokens + self.logits_processors = logits_processors + self.include_stop_str_in_output = include_stop_str_in_output ++ self.has_aici = False + self.truncate_prompt_tokens = truncate_prompt_tokens + # Number of characters to hold back for stop string evaluation + # until sequence is finished. +@@ -277,6 +278,9 @@ class SamplingParams: + def update_from_generation_config( + self, generation_config: Dict[str, Any]) -> None: + """Update if there are non-default values from generation_config""" ++ # For AICI, we want the controller to control stopping. ++ if self.has_aici: ++ return + # Update eos_token_id for generation + if (not self.ignore_eos) and (eos_ids := + generation_config.get("eos_token_id")): +diff --git a/vllm/sequence.py b/vllm/sequence.py +index 3cebb85b..9a07748f 100644 +--- a/vllm/sequence.py ++++ b/vllm/sequence.py +@@ -221,6 +221,8 @@ class Sequence: + self.data: SequenceData = SequenceData(prompt_token_ids) + self.output_logprobs: SampleLogprobs = [] + self.output_text = "" ++ self.has_aici = False ++ self.backtrack = 0 + + self.logical_token_blocks: List[LogicalTokenBlock] = [] + # Initialize the logical token blocks with the prompt token ids. 
+@@ -294,6 +296,58 @@ class Sequence: + self.output_logprobs.append(logprobs) + self.data.append_token_id(token_id, logprobs[token_id].logprob) + ++ def splice_tokens(self, backtrack: int, token_ids: List[int]): ++ assert self.backtrack == 0 ++ ++ data = self.data ++ ++ if not token_ids: ++ # we need at least one token in forward step, ++ # so we pretend we're backtracking one token more ++ # and repeat the token that was there ++ # otherwise, the _num_comptued_tokens gets out of sync ++ backtrack += 1 ++ if backtrack <= len(data.output_token_ids): ++ token_ids = [data.output_token_ids[-backtrack]] ++ else: ++ off = backtrack - len(data.output_token_ids) ++ if off <= len(data.prompt_token_ids): ++ token_ids = [data.prompt_token_ids[-off]] ++ else: ++ token_ids = [1] ++ ++ backtrack = min(backtrack, data.get_len()) ++ self.backtrack = backtrack ++ ++ if backtrack > 0: ++ prompt_backtrack = 0 ++ output_len = len(data.output_token_ids) ++ if backtrack > output_len: ++ prompt_backtrack = backtrack - output_len ++ backtrack = output_len ++ del data.output_token_ids[-backtrack:] ++ del self.output_logprobs[-backtrack:] ++ data._num_computed_tokens = min(data._num_computed_tokens, ++ len(data.output_token_ids)) ++ if prompt_backtrack > 0: ++ assert not data.output_token_ids ++ del data.prompt_token_ids[-prompt_backtrack:] ++ needed_blocks = \ ++ (self.get_len() + self.block_size - 1) // self.block_size ++ if len(self.logical_token_blocks) > needed_blocks: ++ del self.logical_token_blocks[needed_blocks:] ++ if needed_blocks > 0: ++ last_block = self.logical_token_blocks[-1] ++ last_num_tokens = self.get_len() % self.block_size ++ if last_num_tokens == 0: ++ last_num_tokens = self.block_size ++ last_block.num_tokens = last_num_tokens ++ ++ for t in token_ids: ++ self.append_token_id(t, {t: Logprob(logprob=0.0)}) ++ if data.get_num_uncomputed_tokens() > 1: ++ data._stage = SequenceStage.PREFILL ++ + def get_len(self) -> int: + return self.data.get_len() + +@@ -445,7 +499,8 @@ class SequenceGroup: + def get_last_latency(self, now: float) -> Optional[float]: + """Sets the last token time for Request level timings.""" + # If still in prefill phase, raise Error. +- if self.is_prefill(): ++ # With AICI, the request may go from decode to prefill, so ignore. ++ if self.is_prefill() and not self.sampling_params.has_aici: + raise ValueError( + "seq_group.get_last_latency() should not be called " + "if the seq_group is in prefill phase.") diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 92a7f69051834..0a62f83d06e5c 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -173,16 +173,10 @@ def append_slots( block_table = self.block_tables[seq.seq_id] block_table.backtrack(backtrack) - token_ids = block_table.get_unseen_token_ids(seq.get_token_ids()) - if seq.has_aici and not token_ids: - # AICI may want to "append" empty tokens, either to just backtrack - # or to force a wait for one step. - assert num_lookahead_slots == 0 - else: - block_table.append_token_ids( - token_ids=token_ids, - num_lookahead_slots=num_lookahead_slots, - ) + block_table.append_token_ids( + token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + ) # Return any new copy-on-writes. 
new_cows = self.block_allocator.clear_copy_on_writes() From 697ff789ec77d39e5841616c37d7ed702d5a1621 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 11 May 2024 21:25:51 +0000 Subject: [PATCH 5/9] remove most of AICI-specific code --- aici.patch | 713 -------------------- vllm/core/scheduler.py | 10 - vllm/engine/output_processor/single_step.py | 2 - vllm/entrypoints/openai/api_server.py | 63 +- vllm/entrypoints/openai/protocol.py | 26 - vllm/entrypoints/openai/serving_aici.py | 104 --- vllm/model_executor/layers/sampler.py | 39 -- vllm/outputs.py | 4 +- 8 files changed, 3 insertions(+), 958 deletions(-) delete mode 100644 aici.patch delete mode 100644 vllm/entrypoints/openai/serving_aici.py diff --git a/aici.patch b/aici.patch deleted file mode 100644 index 884d22bd8720b..0000000000000 --- a/aici.patch +++ /dev/null @@ -1,713 +0,0 @@ -diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py -index b0d9511f..eace526e 100644 ---- a/vllm/core/block/block_table.py -+++ b/vllm/core/block/block_table.py -@@ -87,6 +87,22 @@ class BlockTable: - device=device) - self._num_full_slots = len(token_ids) - -+ def backtrack(self, num_slots: int) -> None: -+ """Remove the specified number of slots from the end of the table. -+ -+ Args: -+ num_slots (int): The number of slots to backtrack by. -+ """ -+ assert self._is_allocated -+ assert num_slots <= self._num_full_slots -+ if num_slots == 0: -+ return -+ self._num_full_slots -= num_slots -+ blocks = self._blocks[self._num_full_slots // self._block_size:] -+ blocks[0].trim(self._num_full_slots % self._block_size) -+ for b in blocks[1:]: -+ b.trim(0) -+ - def append_token_ids(self, - token_ids: List[int], - num_lookahead_slots: int = 0) -> None: -diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py -index 140fbbb0..999360ff 100644 ---- a/vllm/core/block/interfaces.py -+++ b/vllm/core/block/interfaces.py -@@ -43,6 +43,10 @@ class Block(ABC): - def prev_block(self) -> Optional["Block"]: - pass - -+ @abstractmethod -+ def trim(self, num_tokens: int): -+ pass -+ - @property - @abstractmethod - def computed(self) -> bool: -diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py -index ae019308..8f8ae375 100644 ---- a/vllm/core/block/naive_block.py -+++ b/vllm/core/block/naive_block.py -@@ -269,6 +269,9 @@ class NaiveBlock(Block): - assert self.num_empty_slots >= len(token_ids) - self._token_ids.extend(token_ids) - -+ def trim(self, num_tokens: int): -+ del self._token_ids[num_tokens:] -+ - @property - def computed(self) -> bool: - raise NotImplementedError -diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py -index 882f301c..4575fcbd 100644 ---- a/vllm/core/block/prefix_caching_block.py -+++ b/vllm/core/block/prefix_caching_block.py -@@ -456,6 +456,9 @@ class PrefixCachingBlock(Block): - _cow_target=self, - ) - -+ def trim(self, num_tokens: int): -+ return self._block.trim(num_tokens) -+ - @property - def computed(self) -> bool: - return self._computed -diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py -index 52a170d7..0359a4ea 100644 ---- a/vllm/core/block_manager_v1.py -+++ b/vllm/core/block_manager_v1.py -@@ -386,10 +386,13 @@ class BlockSpaceManagerV1(BlockSpaceManager): - self, - seq: Sequence, - num_lookahead_slots: int = 0, -+ backtrack: int = 0, - ) -> List[Tuple[int, int]]: - """Allocate a physical slot for a new token.""" - logical_blocks = seq.logical_token_blocks - block_table = self.block_tables[seq.seq_id] -+ 
assert backtrack == 0, \ -+ "Backtrack not supported; consider --use-v2-block-manager" - # If we need to allocate a new physical block - if len(block_table) < len(logical_blocks): - # Currently this code only supports adding one physical block -diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py -index f0bc9656..0a62f83d 100644 ---- a/vllm/core/block_manager_v2.py -+++ b/vllm/core/block_manager_v2.py -@@ -167,10 +167,12 @@ class BlockSpaceManagerV2(BlockSpaceManager): - self, - seq: Sequence, - num_lookahead_slots: int, -+ backtrack: int = 0, - ) -> List[Tuple[int, int]]: - - block_table = self.block_tables[seq.seq_id] - -+ block_table.backtrack(backtrack) - block_table.append_token_ids( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, -diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py -index b2a5e419..c56fd760 100644 ---- a/vllm/core/interfaces.py -+++ b/vllm/core/interfaces.py -@@ -55,6 +55,7 @@ class BlockSpaceManager(ABC): - self, - seq: Sequence, - num_lookahead_slots: int, -+ backtrack: int = 0, - ) -> List[Tuple[int, int]]: - pass - -diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py -index 35e3db18..a6cce282 100644 ---- a/vllm/core/scheduler.py -+++ b/vllm/core/scheduler.py -@@ -6,6 +6,8 @@ from collections import deque - from dataclasses import dataclass, field - from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union - -+from pyaici.comms import AiciRunner -+ - from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig - from vllm.core.interfaces import AllocStatus, BlockSpaceManager - from vllm.core.policy import Policy, PolicyFactory -@@ -274,6 +276,7 @@ class Scheduler: - version="v2" if self.scheduler_config. - use_v2_block_manager else "v1") - -+ self.aici_runner: AiciRunner = None - # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( - block_size=self.cache_config.block_size, -@@ -316,6 +319,11 @@ class Scheduler: - return 1 - - def add_seq_group(self, seq_group: SequenceGroup) -> None: -+ if seq_group.sampling_params.has_aici: -+ seq = seq_group.get_seqs()[0] -+ seq.has_aici = True -+ self.aici_runner.assign_seq_id(seq_group.request_id, seq.seq_id) -+ - # Add sequence groups to the waiting queue. - self.waiting.append(seq_group) - -@@ -915,6 +923,8 @@ class Scheduler: - ) - - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: -+ runner = self.aici_runner -+ - # Schedule sequence groups. - # This function call changes the internal states of the scheduler - # such as self.running, self.swapped, and self.waiting. -@@ -936,6 +946,8 @@ class Scheduler: - - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq_id = seq.seq_id -+ if seq_group.sampling_params.has_aici: -+ runner.add_mid(seq_id) - seq_data[seq_id] = seq.data - block_tables[seq_id] = self.block_manager.get_block_table(seq) - self.block_manager.access_all_blocks_in_seq(seq, now) -@@ -981,6 +993,12 @@ class Scheduler: - ) - seq_group_metadata_list.append(seq_group_metadata) - -+ if runner: -+ if scheduler_outputs.is_empty(): -+ assert not runner.needs_exec_mid() -+ else: -+ runner.exec_mid() -+ - # Now that the batch has been created, we can assume all blocks in the - # batch will have been computed before the next scheduling invocation. 
- # This is because the engine assumes that a failure in model execution -@@ -996,6 +1014,8 @@ class Scheduler: - - def free_seq(self, seq: Sequence) -> None: - """Free a sequence from a block table.""" -+ if seq.has_aici: -+ self.aici_runner.seq_freed(seq.seq_id) - self.block_manager.free(seq) - - def free_finished_seq_groups(self) -> None: -@@ -1026,7 +1046,10 @@ class Scheduler: - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) - - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): -- cows = self.block_manager.append_slots(seq, num_lookahead_slots) -+ cows = self.block_manager.append_slots(seq, -+ num_lookahead_slots, -+ backtrack=seq.backtrack) -+ seq.backtrack = 0 - blocks_to_copy.extend(cows) - - def _preempt( -diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py -index 07b14058..6cd6ed0f 100644 ---- a/vllm/engine/output_processor/single_step.py -+++ b/vllm/engine/output_processor/single_step.py -@@ -85,6 +85,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): - # List of (child, parent) - child_seqs: List[Tuple[Sequence, Sequence]] = [] - -+ aici_runner = self.scheduler.aici_runner -+ - # Process the child samples for each parent sequence - for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ -@@ -108,9 +110,26 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): - # We reuse the parent sequence here to reduce redundant memory - # copies, especially when using non-beam search sampling methods. - last_child_sample = child_samples[-1] -+ child_seqs.append((parent, parent)) -+ if seq_group.sampling_params.has_aici: -+ sid = parent.seq_id -+ sampled_token = last_child_sample.output_token -+ r = aici_runner.mid_status(sid) -+ assert len(r.branches) <= 1 -+ if r.branches: -+ splice = r.branches[0].find_splice(sampled_token) -+ if splice: -+ parent.splice_tokens(splice.backtrack, -+ splice.ff_tokens) -+ aici_runner.tokens_generated( -+ sid, splice.ff_tokens, backtrack=splice.backtrack) -+ continue # don't call append_token_id() -+ else: -+ aici_runner.tokens_generated(sid, [sampled_token]) - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) -- child_seqs.append((parent, parent)) -+ -+ to_stop = aici_runner.get_seqs_to_stop() if aici_runner else set() - - for seq, _ in child_seqs: - if seq_group.sampling_params.detokenize and self.detokenizer: -@@ -120,6 +139,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): - new_char_count = 0 - self.stop_checker.maybe_stop_sequence(seq, new_char_count, - seq_group.sampling_params) -+ if seq.seq_id in to_stop: -+ seq.status = SequenceStatus.FINISHED_STOPPED -+ seq.stop_reason = "" - - # Non-beam search case - if not seq_group.sampling_params.use_beam_search: -diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py -index 362f28d0..cec4a910 100644 ---- a/vllm/entrypoints/openai/api_server.py -+++ b/vllm/entrypoints/openai/api_server.py -@@ -7,6 +7,7 @@ from http import HTTPStatus - from typing import Optional, Set - - import fastapi -+import pyaici - import uvicorn - from fastapi import Request - from fastapi.exceptions import RequestValidationError -@@ -22,7 +23,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine - from vllm.entrypoints.openai.cli_args import make_arg_parser - from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionResponse, -- CompletionRequest, ErrorResponse) -+ 
CompletionRequest, ErrorResponse, -+ RunRequest, SetTagsRequest) -+from vllm.entrypoints.openai.serving_aici import AiciRunnerCompletion - from vllm.entrypoints.openai.serving_chat import OpenAIServingChat - from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion - from vllm.logger import init_logger -@@ -32,6 +35,8 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds - - openai_serving_chat: OpenAIServingChat - openai_serving_completion: OpenAIServingCompletion -+pyaici_runner_completion: AiciRunnerCompletion -+ - logger = init_logger(__name__) - - _running_tasks: Set[asyncio.Task] = set() -@@ -58,6 +63,7 @@ app = fastapi.FastAPI(lifespan=lifespan) - - def parse_args(): - parser = make_arg_parser() -+ parser = pyaici.add_cli_args(parser) - return parser.parse_args() - - -@@ -123,6 +129,51 @@ async def create_completion(request: CompletionRequest, raw_request: Request): - return JSONResponse(content=generator.model_dump()) - - -+def _no_aici(): -+ return JSONResponse({"error": "AICI runtime is not enabled"}, -+ status_code=501) -+ -+ -+@app.post("/v1/controllers") -+async def upload_aici_module(request: Request): -+ if not pyaici_runner_completion: -+ return _no_aici() -+ contents = await request.body() -+ return JSONResponse( -+ await -+ pyaici_runner_completion.aici_runner.upload_module_async(contents)) -+ -+ -+@app.post("/v1/run") -+async def aici_run(request: RunRequest, raw_request: Request): -+ if not pyaici_runner_completion: -+ return _no_aici() -+ request_id, inst_res = \ -+ await pyaici_runner_completion.prep_completion(request) -+ generator = pyaici_runner_completion.create_completion( -+ request_id, inst_res, request, raw_request) -+ return StreamingResponse(content=generator, media_type="text/event-stream") -+ -+ -+@app.post("/v1/controllers/tags") -+async def aici_set_tags(request: SetTagsRequest): -+ if not pyaici_runner_completion: -+ return _no_aici() -+ # non-admin users can only set tags that start with their username -+ auto_info = {"user": "vllm", "is_admin": True} -+ r = await pyaici_runner_completion.aici_runner.set_tags( -+ request.module_id, request.tags, auth_info=auto_info) -+ return JSONResponse(r) -+ -+ -+@app.get("/v1/controllers/tags") -+async def aici_get_tags(): -+ if not pyaici_runner_completion: -+ return _no_aici() -+ r = await pyaici_runner_completion.aici_runner.get_tags() -+ return JSONResponse(r) -+ -+ - if __name__ == "__main__": - args = parse_args() - -@@ -168,7 +219,6 @@ if __name__ == "__main__": - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER) -- - event_loop: Optional[asyncio.AbstractEventLoop] - try: - event_loop = asyncio.get_running_loop() -@@ -183,6 +233,15 @@ if __name__ == "__main__": - # When using single vLLM without engine_use_ray - model_config = asyncio.run(engine.get_model_config()) - -+ if args.aici_rt: -+ config = asyncio.run(engine.get_model_config()) -+ dtype = str(config.dtype).replace("torch.", "").replace("float", "f") -+ pyaici_runner = pyaici.runner_from_cli(args, dtype=dtype) -+ pyaici_runner.fast_api() -+ assert len(served_model_names) == 1 -+ pyaici_runner_completion = AiciRunnerCompletion( -+ pyaici_runner, engine, model_config, served_model_names[0]) -+ - openai_serving_chat = OpenAIServingChat(engine, model_config, - served_model_names, - args.response_role, -diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py -index 3cd9ddad..02c0c907 100644 ---- 
a/vllm/entrypoints/openai/protocol.py -+++ b/vllm/entrypoints/openai/protocol.py -@@ -363,6 +363,32 @@ class CompletionRequest(OpenAIBaseModel): - return data - - -+class RunRequest(BaseModel): -+ prompt: str -+ controller: str -+ controller_arg: Union[str, dict] -+ temperature: Optional[float] = 0.0 -+ top_p: Optional[float] = 1.0 -+ top_k: Optional[int] = -1 -+ max_tokens: Optional[int] = None -+ -+ def to_sampling_params(self): -+ r = SamplingParams( -+ temperature=self.temperature, -+ top_p=self.top_p, -+ top_k=self.top_k, -+ max_tokens=self.max_tokens, -+ ignore_eos=True, -+ ) -+ r.has_aici = True -+ return r -+ -+ -+class SetTagsRequest(BaseModel): -+ module_id: str -+ tags: List[str] -+ -+ - class LogProbs(OpenAIBaseModel): - text_offset: List[int] = Field(default_factory=list) - token_logprobs: List[Optional[float]] = Field(default_factory=list) -diff --git a/vllm/entrypoints/openai/serving_aici.py b/vllm/entrypoints/openai/serving_aici.py -new file mode 100644 -index 00000000..81be479f ---- /dev/null -+++ b/vllm/entrypoints/openai/serving_aici.py -@@ -0,0 +1,104 @@ -+from typing import List, Union -+ -+from fastapi import Request -+from pyaici.comms import AiciRunner -+ -+from vllm.config import ModelConfig -+from vllm.engine.async_llm_engine import AsyncLLMEngine -+from vllm.entrypoints.openai.serving_engine import OpenAIServing -+from vllm.utils import random_uuid -+ -+from .protocol import RunRequest -+ -+# logger = init_logger(__name__) -+ -+ -+class AiciRunnerCompletion(OpenAIServing): -+ -+ def __init__(self, aici_runner: AiciRunner, engine: AsyncLLMEngine, -+ model_config: ModelConfig, served_model_names: List[str]): -+ super().__init__(engine=engine, -+ model_config=model_config, -+ served_model_names=served_model_names, -+ lora_modules=None) -+ self.aici_runner = aici_runner -+ self.empty_prompt: List[int] = self.tokenizer("").input_ids -+ if not self.empty_prompt: -+ # if there's no start symbol, add a space, otherwise Engine -+ # gets stuck on empty prompt -+ self.empty_prompt = self.tokenizer(" ").input_ids -+ assert self.empty_prompt -+ # TODO: this is a hack: -+ engine.engine.scheduler.aici_runner = aici_runner -+ -+ # this is separate from create_completion() so fastapi exceptions -+ # from .instantiate_async() are properly sent to the user -+ async def prep_completion(self, request: RunRequest): -+ request_id = f"run-{random_uuid()}" -+ prompt = self.tokenizer(request.prompt).input_ids -+ inst_res = await self.aici_runner.instantiate_async( -+ request_id, prompt, request.controller, request.controller_arg) -+ return request_id, inst_res -+ -+ async def create_completion(self, request_id: str, inst_res: Union[dict, -+ list], -+ request: RunRequest, raw_request: Request): -+ """Completion API for AICI controllers. 
-+ -+ See https://github.com/microsoft/aici/blob/main/docs/REST.md -+ """ -+ runner = self.aici_runner -+ yield runner.data_line( -+ runner.initial_json(request_id, self.served_model_names[0])) -+ -+ if isinstance(inst_res, dict): -+ # error case -+ yield runner.data_line(inst_res) -+ yield runner.final_data() -+ return -+ -+ # Engine doesn't like prompts with no tokens -+ # self.empty_prompt is either start symbol or a single space -+ if len(inst_res) == 0: -+ inst_res = self.empty_prompt -+ -+ sampling_params = request.to_sampling_params() -+ sampling_params.stop_token_ids = [] -+ generator = self.engine.generate(prompt=None, -+ sampling_params=sampling_params, -+ request_id=request_id, -+ prompt_token_ids=inst_res) -+ -+ previous_texts: List[str] = [] -+ ff_tokens = len(inst_res) -+ sampled_tokens = 0 -+ -+ async for res in generator: -+ # Abort the request if the client disconnects. -+ if await raw_request.is_disconnected(): -+ await self.engine.abort(request_id) -+ raise StopAsyncIteration() -+ forks = [] -+ for output in res.outputs: -+ # TODO: -+ ff_tokens += 1 -+ sampled_tokens += 1 -+ -+ i = output.index -+ while len(previous_texts) <= i: -+ previous_texts.append("") -+ delta_text = output.text[len(previous_texts[i]):] -+ previous_texts[i] = output.text -+ -+ fork_res = runner.seq_logs( -+ output.seq_id, -+ index=i, -+ text=delta_text, -+ finish_reason=output.finish_reason, -+ ) -+ forks.append(fork_res) -+ yield runner.data_line( -+ runner.run_json(forks, -+ runner.usage_json(ff_tokens, sampled_tokens))) -+ -+ yield runner.final_data() -diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py -index e52e350d..2b3992fa 100644 ---- a/vllm/model_executor/layers/sampler.py -+++ b/vllm/model_executor/layers/sampler.py -@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple - - import torch - import torch.nn as nn -+from pyaici.comms import AiciRunner - - from vllm.model_executor.layers.ops.sample import sample as sample_triton - from vllm.model_executor.sampling_metadata import (SamplingMetadata, -@@ -59,6 +60,9 @@ class Sampler(nn.Module): - assert logits is not None - _, vocab_size = logits.shape - -+ # Start with constrained decoding -+ logits = _apply_aici_logit_bias(logits, sampling_metadata) -+ - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - - # Prepare sampling tensors with pinned memory to avoid blocking. 
-@@ -149,6 +153,41 @@ def _get_bin_counts_and_mask( - return bin_counts, mask - - -+def _apply_aici_logit_bias( -+ logits: torch.Tensor, -+ sampling_metadata: SamplingMetadata, -+): -+ aici_runner = AiciRunner.instance -+ if not aici_runner: -+ return logits -+ mid_results, arr = aici_runner.recv_logit_bias_torch() -+ # logits.dtype should generally match arr.dtype -+ bias = arr.to(logits.device).to(logits.dtype) -+ if bias.shape[0] == 0: -+ return logits -+ -+ logits_row_idx = 0 -+ for sg in sampling_metadata.seq_groups: -+ if sg.sampling_params.has_aici: -+ for id in sg.seq_ids: -+ r = mid_results.get(id) -+ if r and len(r.branches) >= 1: -+ # this is actually also enforced by AICIrt since -+ # we don't pass --cap-fork -+ assert len(r.branches) <= 1, "Only one branch is supported" -+ mask = r.branches[0].mask -+ if mask is not None: -+ logits[logits_row_idx] += bias[mask, 0:logits.shape[1]] -+ temp = r.branches[0].temperature -+ if temp is not None: -+ sg.sampling_params.temperature = temp -+ logits_row_idx += 1 -+ else: -+ logits_row_idx += len(sg.seq_ids) -+ -+ return logits -+ -+ - def _apply_min_tokens_penalty( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -diff --git a/vllm/outputs.py b/vllm/outputs.py -index d01be0eb..7ab9793b 100644 ---- a/vllm/outputs.py -+++ b/vllm/outputs.py -@@ -26,6 +26,7 @@ class CompletionOutput: - - def __init__( - self, -+ seq_id: int, - index: int, - text: str, - token_ids: List[int], -@@ -35,6 +36,7 @@ class CompletionOutput: - stop_reason: Union[int, str, None] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: -+ self.seq_id = seq_id - self.index = index - self.text = text - self.token_ids = token_ids -@@ -114,7 +116,7 @@ class RequestOutput: - include_logprobs = seq_group.sampling_params.logprobs is not None - text_buffer_length = seq_group.sampling_params.output_text_buffer_length - outputs = [ -- CompletionOutput(seqs.index(seq), -+ CompletionOutput(seq.seq_id, seqs.index(seq), - seq.get_output_text_to_return(text_buffer_length), - seq.get_output_token_ids(), - seq.get_cumulative_logprob(), -diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py -index 5fa94eb1..57739a51 100644 ---- a/vllm/sampling_params.py -+++ b/vllm/sampling_params.py -@@ -169,6 +169,7 @@ class SamplingParams: - self.spaces_between_special_tokens = spaces_between_special_tokens - self.logits_processors = logits_processors - self.include_stop_str_in_output = include_stop_str_in_output -+ self.has_aici = False - self.truncate_prompt_tokens = truncate_prompt_tokens - # Number of characters to hold back for stop string evaluation - # until sequence is finished. -@@ -277,6 +278,9 @@ class SamplingParams: - def update_from_generation_config( - self, generation_config: Dict[str, Any]) -> None: - """Update if there are non-default values from generation_config""" -+ # For AICI, we want the controller to control stopping. -+ if self.has_aici: -+ return - # Update eos_token_id for generation - if (not self.ignore_eos) and (eos_ids := - generation_config.get("eos_token_id")): -diff --git a/vllm/sequence.py b/vllm/sequence.py -index 3cebb85b..9a07748f 100644 ---- a/vllm/sequence.py -+++ b/vllm/sequence.py -@@ -221,6 +221,8 @@ class Sequence: - self.data: SequenceData = SequenceData(prompt_token_ids) - self.output_logprobs: SampleLogprobs = [] - self.output_text = "" -+ self.has_aici = False -+ self.backtrack = 0 - - self.logical_token_blocks: List[LogicalTokenBlock] = [] - # Initialize the logical token blocks with the prompt token ids. 
-@@ -294,6 +296,58 @@ class Sequence: - self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob) - -+ def splice_tokens(self, backtrack: int, token_ids: List[int]): -+ assert self.backtrack == 0 -+ -+ data = self.data -+ -+ if not token_ids: -+ # we need at least one token in forward step, -+ # so we pretend we're backtracking one token more -+ # and repeat the token that was there -+ # otherwise, the _num_comptued_tokens gets out of sync -+ backtrack += 1 -+ if backtrack <= len(data.output_token_ids): -+ token_ids = [data.output_token_ids[-backtrack]] -+ else: -+ off = backtrack - len(data.output_token_ids) -+ if off <= len(data.prompt_token_ids): -+ token_ids = [data.prompt_token_ids[-off]] -+ else: -+ token_ids = [1] -+ -+ backtrack = min(backtrack, data.get_len()) -+ self.backtrack = backtrack -+ -+ if backtrack > 0: -+ prompt_backtrack = 0 -+ output_len = len(data.output_token_ids) -+ if backtrack > output_len: -+ prompt_backtrack = backtrack - output_len -+ backtrack = output_len -+ del data.output_token_ids[-backtrack:] -+ del self.output_logprobs[-backtrack:] -+ data._num_computed_tokens = min(data._num_computed_tokens, -+ len(data.output_token_ids)) -+ if prompt_backtrack > 0: -+ assert not data.output_token_ids -+ del data.prompt_token_ids[-prompt_backtrack:] -+ needed_blocks = \ -+ (self.get_len() + self.block_size - 1) // self.block_size -+ if len(self.logical_token_blocks) > needed_blocks: -+ del self.logical_token_blocks[needed_blocks:] -+ if needed_blocks > 0: -+ last_block = self.logical_token_blocks[-1] -+ last_num_tokens = self.get_len() % self.block_size -+ if last_num_tokens == 0: -+ last_num_tokens = self.block_size -+ last_block.num_tokens = last_num_tokens -+ -+ for t in token_ids: -+ self.append_token_id(t, {t: Logprob(logprob=0.0)}) -+ if data.get_num_uncomputed_tokens() > 1: -+ data._stage = SequenceStage.PREFILL -+ - def get_len(self) -> int: - return self.data.get_len() - -@@ -445,7 +499,8 @@ class SequenceGroup: - def get_last_latency(self, now: float) -> Optional[float]: - """Sets the last token time for Request level timings.""" - # If still in prefill phase, raise Error. -- if self.is_prefill(): -+ # With AICI, the request may go from decode to prefill, so ignore. -+ if self.is_prefill() and not self.sampling_params.has_aici: - raise ValueError( - "seq_group.get_last_latency() should not be called " - "if the seq_group is in prefill phase.") diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index a6cce28263da0..33fd709532e80 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -6,8 +6,6 @@ from dataclasses import dataclass, field from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union -from pyaici.comms import AiciRunner - from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.policy import Policy, PolicyFactory @@ -276,7 +274,6 @@ def __init__( version="v2" if self.scheduler_config. use_v2_block_manager else "v1") - self.aici_runner: AiciRunner = None # Create the block space manager. 
self.block_manager = BlockSpaceManagerImpl( block_size=self.cache_config.block_size, @@ -319,11 +316,6 @@ def num_decoding_tokens_per_seq(self) -> int: return 1 def add_seq_group(self, seq_group: SequenceGroup) -> None: - if seq_group.sampling_params.has_aici: - seq = seq_group.get_seqs()[0] - seq.has_aici = True - self.aici_runner.assign_seq_id(seq_group.request_id, seq.seq_id) - # Add sequence groups to the waiting queue. self.waiting.append(seq_group) @@ -923,8 +915,6 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: ) def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: - runner = self.aici_runner - # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 6cd6ed0f25185..2f5595077d957 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -85,8 +85,6 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # List of (child, parent) child_seqs: List[Tuple[Sequence, Sequence]] = [] - aici_runner = self.scheduler.aici_runner - # Process the child samples for each parent sequence for parent in parent_seqs: child_samples: List[SequenceOutput] = parent_child_dict[ diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index cec4a910f1b08..362f28d05c3bb 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -7,7 +7,6 @@ from typing import Optional, Set import fastapi -import pyaici import uvicorn from fastapi import Request from fastapi.exceptions import RequestValidationError @@ -23,9 +22,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, - CompletionRequest, ErrorResponse, - RunRequest, SetTagsRequest) -from vllm.entrypoints.openai.serving_aici import AiciRunnerCompletion + CompletionRequest, ErrorResponse) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.logger import init_logger @@ -35,8 +32,6 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion -pyaici_runner_completion: AiciRunnerCompletion - logger = init_logger(__name__) _running_tasks: Set[asyncio.Task] = set() @@ -63,7 +58,6 @@ async def _force_log(): def parse_args(): parser = make_arg_parser() - parser = pyaici.add_cli_args(parser) return parser.parse_args() @@ -129,51 +123,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -def _no_aici(): - return JSONResponse({"error": "AICI runtime is not enabled"}, - status_code=501) - - -@app.post("/v1/controllers") -async def upload_aici_module(request: Request): - if not pyaici_runner_completion: - return _no_aici() - contents = await request.body() - return JSONResponse( - await - pyaici_runner_completion.aici_runner.upload_module_async(contents)) - - -@app.post("/v1/run") -async def aici_run(request: RunRequest, raw_request: Request): - if not pyaici_runner_completion: - return _no_aici() - request_id, inst_res = \ - await pyaici_runner_completion.prep_completion(request) - generator = pyaici_runner_completion.create_completion( - request_id, inst_res, request, raw_request) 
- return StreamingResponse(content=generator, media_type="text/event-stream") - - -@app.post("/v1/controllers/tags") -async def aici_set_tags(request: SetTagsRequest): - if not pyaici_runner_completion: - return _no_aici() - # non-admin users can only set tags that start with their username - auto_info = {"user": "vllm", "is_admin": True} - r = await pyaici_runner_completion.aici_runner.set_tags( - request.module_id, request.tags, auth_info=auto_info) - return JSONResponse(r) - - -@app.get("/v1/controllers/tags") -async def aici_get_tags(): - if not pyaici_runner_completion: - return _no_aici() - r = await pyaici_runner_completion.aici_runner.get_tags() - return JSONResponse(r) - - if __name__ == "__main__": args = parse_args() @@ -219,6 +168,7 @@ async def authentication(request: Request, call_next): engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + event_loop: Optional[asyncio.AbstractEventLoop] try: event_loop = asyncio.get_running_loop() @@ -233,15 +183,6 @@ async def authentication(request: Request, call_next): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) - if args.aici_rt: - config = asyncio.run(engine.get_model_config()) - dtype = str(config.dtype).replace("torch.", "").replace("float", "f") - pyaici_runner = pyaici.runner_from_cli(args, dtype=dtype) - pyaici_runner.fast_api() - assert len(served_model_names) == 1 - pyaici_runner_completion = AiciRunnerCompletion( - pyaici_runner, engine, model_config, served_model_names[0]) - openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, args.response_role, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 02c0c90701244..3cd9ddad3b7b7 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -363,32 +363,6 @@ def check_guided_decoding_count(cls, data): return data -class RunRequest(BaseModel): - prompt: str - controller: str - controller_arg: Union[str, dict] - temperature: Optional[float] = 0.0 - top_p: Optional[float] = 1.0 - top_k: Optional[int] = -1 - max_tokens: Optional[int] = None - - def to_sampling_params(self): - r = SamplingParams( - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - max_tokens=self.max_tokens, - ignore_eos=True, - ) - r.has_aici = True - return r - - -class SetTagsRequest(BaseModel): - module_id: str - tags: List[str] - - class LogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) diff --git a/vllm/entrypoints/openai/serving_aici.py b/vllm/entrypoints/openai/serving_aici.py deleted file mode 100644 index 81be479f5166b..0000000000000 --- a/vllm/entrypoints/openai/serving_aici.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import List, Union - -from fastapi import Request -from pyaici.comms import AiciRunner - -from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.utils import random_uuid - -from .protocol import RunRequest - -# logger = init_logger(__name__) - - -class AiciRunnerCompletion(OpenAIServing): - - def __init__(self, aici_runner: AiciRunner, engine: AsyncLLMEngine, - model_config: ModelConfig, served_model_names: List[str]): - super().__init__(engine=engine, - model_config=model_config, - 
served_model_names=served_model_names, - lora_modules=None) - self.aici_runner = aici_runner - self.empty_prompt: List[int] = self.tokenizer("").input_ids - if not self.empty_prompt: - # if there's no start symbol, add a space, otherwise Engine - # gets stuck on empty prompt - self.empty_prompt = self.tokenizer(" ").input_ids - assert self.empty_prompt - # TODO: this is a hack: - engine.engine.scheduler.aici_runner = aici_runner - - # this is separate from create_completion() so fastapi exceptions - # from .instantiate_async() are properly sent to the user - async def prep_completion(self, request: RunRequest): - request_id = f"run-{random_uuid()}" - prompt = self.tokenizer(request.prompt).input_ids - inst_res = await self.aici_runner.instantiate_async( - request_id, prompt, request.controller, request.controller_arg) - return request_id, inst_res - - async def create_completion(self, request_id: str, inst_res: Union[dict, - list], - request: RunRequest, raw_request: Request): - """Completion API for AICI controllers. - - See https://github.com/microsoft/aici/blob/main/docs/REST.md - """ - runner = self.aici_runner - yield runner.data_line( - runner.initial_json(request_id, self.served_model_names[0])) - - if isinstance(inst_res, dict): - # error case - yield runner.data_line(inst_res) - yield runner.final_data() - return - - # Engine doesn't like prompts with no tokens - # self.empty_prompt is either start symbol or a single space - if len(inst_res) == 0: - inst_res = self.empty_prompt - - sampling_params = request.to_sampling_params() - sampling_params.stop_token_ids = [] - generator = self.engine.generate(prompt=None, - sampling_params=sampling_params, - request_id=request_id, - prompt_token_ids=inst_res) - - previous_texts: List[str] = [] - ff_tokens = len(inst_res) - sampled_tokens = 0 - - async for res in generator: - # Abort the request if the client disconnects. - if await raw_request.is_disconnected(): - await self.engine.abort(request_id) - raise StopAsyncIteration() - forks = [] - for output in res.outputs: - # TODO: - ff_tokens += 1 - sampled_tokens += 1 - - i = output.index - while len(previous_texts) <= i: - previous_texts.append("") - delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - - fork_res = runner.seq_logs( - output.seq_id, - index=i, - text=delta_text, - finish_reason=output.finish_reason, - ) - forks.append(fork_res) - yield runner.data_line( - runner.run_json(forks, - runner.usage_json(ff_tokens, sampled_tokens))) - - yield runner.final_data() diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2b3992fa887f4..e52e350d2726f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn -from pyaici.comms import AiciRunner from vllm.model_executor.layers.ops.sample import sample as sample_triton from vllm.model_executor.sampling_metadata import (SamplingMetadata, @@ -60,9 +59,6 @@ def forward( assert logits is not None _, vocab_size = logits.shape - # Start with constrained decoding - logits = _apply_aici_logit_bias(logits, sampling_metadata) - logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Prepare sampling tensors with pinned memory to avoid blocking. 
@@ -153,41 +149,6 @@ def _get_bin_counts_and_mask( return bin_counts, mask -def _apply_aici_logit_bias( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -): - aici_runner = AiciRunner.instance - if not aici_runner: - return logits - mid_results, arr = aici_runner.recv_logit_bias_torch() - # logits.dtype should generally match arr.dtype - bias = arr.to(logits.device).to(logits.dtype) - if bias.shape[0] == 0: - return logits - - logits_row_idx = 0 - for sg in sampling_metadata.seq_groups: - if sg.sampling_params.has_aici: - for id in sg.seq_ids: - r = mid_results.get(id) - if r and len(r.branches) >= 1: - # this is actually also enforced by AICIrt since - # we don't pass --cap-fork - assert len(r.branches) <= 1, "Only one branch is supported" - mask = r.branches[0].mask - if mask is not None: - logits[logits_row_idx] += bias[mask, 0:logits.shape[1]] - temp = r.branches[0].temperature - if temp is not None: - sg.sampling_params.temperature = temp - logits_row_idx += 1 - else: - logits_row_idx += len(sg.seq_ids) - - return logits - - def _apply_min_tokens_penalty( logits: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/outputs.py b/vllm/outputs.py index 7ab9793bb92d1..d01be0eb0efd2 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -26,7 +26,6 @@ class CompletionOutput: def __init__( self, - seq_id: int, index: int, text: str, token_ids: List[int], @@ -36,7 +35,6 @@ def __init__( stop_reason: Union[int, str, None] = None, lora_request: Optional[LoRARequest] = None, ) -> None: - self.seq_id = seq_id self.index = index self.text = text self.token_ids = token_ids @@ -116,7 +114,7 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": include_logprobs = seq_group.sampling_params.logprobs is not None text_buffer_length = seq_group.sampling_params.output_text_buffer_length outputs = [ - CompletionOutput(seq.seq_id, seqs.index(seq), + CompletionOutput(seqs.index(seq), seq.get_output_text_to_return(text_buffer_length), seq.get_output_token_ids(), seq.get_cumulative_logprob(), From 4f8e4b0052612961a0c27a6b4c229f7f550e2f3e Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 11 May 2024 22:07:21 +0000 Subject: [PATCH 6/9] introduce SequenceController interface --- vllm/core/scheduler.py | 17 ++++----- vllm/engine/llm_engine.py | 1 + vllm/engine/output_processor/single_step.py | 30 +++++++-------- vllm/sampling_params.py | 20 ++++++---- vllm/sequence.py | 9 +++-- vllm/sequence_controller.py | 42 +++++++++++++++++++++ 6 files changed, 82 insertions(+), 37 deletions(-) create mode 100644 vllm/sequence_controller.py diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 33fd709532e80..e053ea0d47ee6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) +from vllm.sequence_controller import SequenceController logger = init_logger(__name__) @@ -934,10 +935,11 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # seq_id -> physical block numbers block_tables: Dict[int, List[int]] = {} + ctrl = seq_group.sampling_params.controller for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id - if seq_group.sampling_params.has_aici: - runner.add_mid(seq_id) + if ctrl: + ctrl.scheduled(seq) seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) 
self.block_manager.access_all_blocks_in_seq(seq, now) @@ -983,11 +985,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: ) seq_group_metadata_list.append(seq_group_metadata) - if runner: - if scheduler_outputs.is_empty(): - assert not runner.needs_exec_mid() - else: - runner.exec_mid() + if not scheduler_outputs.is_empty(): + SequenceController.forward_started() # Now that the batch has been created, we can assume all blocks in the # batch will have been computed before the next scheduling invocation. @@ -1004,8 +1003,8 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: def free_seq(self, seq: Sequence) -> None: """Free a sequence from a block table.""" - if seq.has_aici: - self.aici_runner.seq_freed(seq.seq_id) + if seq.controller: + seq.controller.free(seq) self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b9938b045ba2b..72edae5c1b0c7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -431,6 +431,7 @@ def add_request( "not initialized") seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, eos_token_id, lora_request) + seq.controller = sampling_params.controller # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 2f5595077d957..03a82d68d54c7 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Set, Tuple, Union from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -85,6 +85,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # List of (child, parent) child_seqs: List[Tuple[Sequence, Sequence]] = [] + to_stop: Set[int] = set() + # Process the child samples for each parent sequence for parent in parent_seqs: child_samples: List[SequenceOutput] = parent_child_dict[ @@ -109,26 +111,20 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # copies, especially when using non-beam search sampling methods. 
last_child_sample = child_samples[-1] child_seqs.append((parent, parent)) - if seq_group.sampling_params.has_aici: + ctrl = seq_group.sampling_params.controller + if ctrl: sid = parent.seq_id sampled_token = last_child_sample.output_token - r = aici_runner.mid_status(sid) - assert len(r.branches) <= 1 - if r.branches: - splice = r.branches[0].find_splice(sampled_token) - if splice: - parent.splice_tokens(splice.backtrack, - splice.ff_tokens) - aici_runner.tokens_generated( - sid, splice.ff_tokens, backtrack=splice.backtrack) - continue # don't call append_token_id() - else: - aici_runner.tokens_generated(sid, [sampled_token]) + backtrack, ff_tokens, should_stop = ctrl.sampled( + parent, sampled_token, last_child_sample.logprobs) + if should_stop: + to_stop.add(sid) + if backtrack != 0 or ff_tokens != [sampled_token]: + parent.splice_tokens(backtrack, ff_tokens) + continue # don't call append_token_id() parent.append_token_id(last_child_sample.output_token, last_child_sample.logprobs) - to_stop = aici_runner.get_seqs_to_stop() if aici_runner else set() - for seq, _ in child_seqs: if seq_group.sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( @@ -139,7 +135,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, seq_group.sampling_params) if seq.seq_id in to_stop: seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = "" + seq.stop_reason = "" # Non-beam search case if not seq_group.sampling_params.use_beam_search: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 57739a5108377..422c2692181ef 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -7,6 +7,7 @@ import torch from pydantic import Field from typing_extensions import Annotated +from .sequence_controller import SequenceController _SAMPLING_EPS = 1e-5 @@ -169,7 +170,7 @@ def __init__( self.spaces_between_special_tokens = spaces_between_special_tokens self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output - self.has_aici = False + self.controller: Optional[SequenceController] = None self.truncate_prompt_tokens = truncate_prompt_tokens # Number of characters to hold back for stop string evaluation # until sequence is finished. @@ -278,8 +279,8 @@ def _verify_greedy_sampling(self) -> None: def update_from_generation_config( self, generation_config: Dict[str, Any]) -> None: """Update if there are non-default values from generation_config""" - # For AICI, we want the controller to control stopping. - if self.has_aici: + # If present, we want the controller to control stopping. 
+ if self.controller: return # Update eos_token_id for generation if (not self.ignore_eos) and (eos_ids := @@ -309,10 +310,15 @@ def clone(self) -> "SamplingParams": See https://github.com/vllm-project/vllm/issues/3087 """ - logit_processor_refs = None if self.logits_processors is None else { - id(lp): lp - for lp in self.logits_processors - } + logit_processor_refs: Optional[ + dict] = None if self.logits_processors is None else { + id(lp): lp + for lp in self.logits_processors + } + if self.controller: + if logit_processor_refs is None: + logit_processor_refs = {} + logit_processor_refs[id(self.controller)] = self.controller return copy.deepcopy(self, memo=logit_processor_refs) def __repr__(self) -> str: diff --git a/vllm/sequence.py b/vllm/sequence.py index 9a07748fd43c6..4dc32c562a39e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -7,6 +7,7 @@ from vllm.block import LogicalTokenBlock from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams +from vllm.sequence_controller import SequenceController if TYPE_CHECKING: import torch @@ -221,8 +222,8 @@ def __init__( self.data: SequenceData = SequenceData(prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" - self.has_aici = False self.backtrack = 0 + self.controller: Optional[SequenceController] = None self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. @@ -498,9 +499,9 @@ def lora_int_id(self) -> int: def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" - # If still in prefill phase, raise Error. - # With AICI, the request may go from decode to prefill, so ignore. - if self.is_prefill() and not self.sampling_params.has_aici: + # If still in prefill phase, raise Error (unless using controllers, + # where the request may go from decode to prefill). + if self.is_prefill() and not self.sampling_params.controller: raise ValueError( "seq_group.get_last_latency() should not be called " "if the seq_group is in prefill phase.") diff --git a/vllm/sequence_controller.py b/vllm/sequence_controller.py new file mode 100644 index 0000000000000..1cd62bf844e22 --- /dev/null +++ b/vllm/sequence_controller.py @@ -0,0 +1,42 @@ +from typing import Dict, List, Tuple +from .sequence import Logprob, Sequence + + +class SequenceController: + """Callback for generation control for a single sequence group. + + This can be part of SamplingParams and gets callbacks for various + steps. It is to be used together with LogitsProcessor. + """ + + def scheduled(self, seq: Sequence): + """ + Called whenever the current sequence is scheduled to be run + in the next step. + """ + pass + + @staticmethod + def forward_started(): + """ + Called when all sequences for the current step have been queued. + """ + pass + + def sampled(self, seq: Sequence, token_id: int, + logprobs: Dict[int, Logprob]) -> Tuple[int, List[int], bool]: + """ + Informs the controller a given token has been sampled. + Returns the number of tokens to backtrack, the tokens to append, + and whether to stop. + """ + if token_id == seq.eos_token_id: + return 0, [], True + return 0, [token_id], False + + def free(self, seq: Sequence): + """ + Called when the sequence is stopped, and deallocated. + .scheduled() will not be called again for this sequence. 
+ """ + pass From 09947750e6e9ec406776b2501affb4551cd0c737 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sun, 12 May 2024 18:31:54 +0000 Subject: [PATCH 7/9] formatting --- vllm/sampling_params.py | 1 + vllm/sequence_controller.py | 1 + 2 files changed, 2 insertions(+) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 422c2692181ef..9b1984f9d54b7 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -7,6 +7,7 @@ import torch from pydantic import Field from typing_extensions import Annotated + from .sequence_controller import SequenceController _SAMPLING_EPS = 1e-5 diff --git a/vllm/sequence_controller.py b/vllm/sequence_controller.py index 1cd62bf844e22..72b35bbd31cff 100644 --- a/vllm/sequence_controller.py +++ b/vllm/sequence_controller.py @@ -1,4 +1,5 @@ from typing import Dict, List, Tuple + from .sequence import Logprob, Sequence From 64560f7ac2ca97200e6bdd6bcc5e16a0a230f6b6 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sun, 12 May 2024 18:37:22 +0000 Subject: [PATCH 8/9] fix merge --- vllm/core/embedding_model_block_manager.py | 1 + vllm/engine/llm_engine.py | 5 +++-- vllm/sequence.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index a09d79ec3c420..e21f7c0f6847a 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -36,6 +36,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, + backtrack: int = 0, ) -> List[Tuple[int, int]]: return None # type: ignore diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a05b4a44a089e..6e85891447555 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -429,7 +429,6 @@ def add_request( "not initialized") seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, eos_token_id, lora_request) - seq.controller = sampling_params.controller # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -476,8 +475,10 @@ def _create_sequence_group_with_sampling( f"{max_logprobs} logprobs.") # Defensive copy of SamplingParams, which are used by the sampler, - # this doesn't deep-copy LogitsProcessor objects + # this doesn't deep-copy LogitsProcessor or SequenceController objects sampling_params = sampling_params.clone() + # Link controller to sequence. + seq.controller = sampling_params.controller # Add the eos token id into the sampling_params to support min_tokens # processing if seq.eos_token_id is not None: diff --git a/vllm/sequence.py b/vllm/sequence.py index cb5d97f380adb..27d89ed17126c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -511,7 +511,8 @@ def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error (unless using controllers, # where the request may go from decode to prefill). 
- if self.is_prefill() and not self.sampling_params.controller: + if self.is_prefill() and not (self.sampling_params + and self.sampling_params.controller): raise ValueError( "seq_group.get_last_latency() should not be called " "if the seq_group is in prefill phase.") From 77ffb6fb4ea04eed52625b2f0289fa1c34251ce1 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 24 May 2024 15:56:47 -0700 Subject: [PATCH 9/9] resolve circular deps --- vllm/sequence_controller.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/sequence_controller.py b/vllm/sequence_controller.py index 72b35bbd31cff..9449211eb7245 100644 --- a/vllm/sequence_controller.py +++ b/vllm/sequence_controller.py @@ -1,6 +1,7 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, TYPE_CHECKING -from .sequence import Logprob, Sequence +if TYPE_CHECKING: + from .sequence import Logprob, Sequence class SequenceController: @@ -10,7 +11,7 @@ class SequenceController: steps. It is to be used together with LogitsProcessor. """ - def scheduled(self, seq: Sequence): + def scheduled(self, seq: 'Sequence'): """ Called whenever the current sequence is scheduled to be run in the next step. @@ -24,8 +25,8 @@ def forward_started(): """ pass - def sampled(self, seq: Sequence, token_id: int, - logprobs: Dict[int, Logprob]) -> Tuple[int, List[int], bool]: + def sampled(self, seq: 'Sequence', token_id: int, + logprobs: Dict[int, 'Logprob']) -> Tuple[int, List[int], bool]: """ Informs the controller a given token has been sampled. Returns the number of tokens to backtrack, the tokens to append, @@ -35,7 +36,7 @@ def sampled(self, seq: Sequence, token_id: int, return 0, [], True return 0, [token_id], False - def free(self, seq: Sequence): + def free(self, seq: 'Sequence'): """ Called when the sequence is stopped, and deallocated. .scheduled() will not be called again for this sequence.
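Taken together, patches 6-9 replace the AICI-specific hooks with the generic `SequenceController` callback carried on `SamplingParams.controller`. As a usage sketch (not part of the patch series), a downstream controller could plug into that interface roughly as below; it assumes the patched tree above (so `vllm.sequence_controller` and the `controller` field on `SamplingParams` exist), and the class name, trigger token, and fast-forward ids are illustrative placeholders.

```python
# Illustrative sketch only (not part of the patch series): a controller
# that keeps every sampled token, splices in a canned fast-forward
# continuation when a trigger token is sampled, and stops after a fixed
# budget.  Assumes vllm/sequence_controller.py and the
# SamplingParams.controller field introduced in patches 6-9 are present.
from typing import List, Tuple

from vllm.sampling_params import SamplingParams
from vllm.sequence_controller import SequenceController


class FastForwardController(SequenceController):
    """Hypothetical controller; all names and token ids are placeholders."""

    def __init__(self, trigger_token: int, ff_tokens: List[int],
                 max_new_tokens: int = 64) -> None:
        self.trigger_token = trigger_token
        self.ff_tokens = ff_tokens
        self.max_new_tokens = max_new_tokens
        self.generated = 0  # assumes one sequence per group (n=1)

    def scheduled(self, seq) -> None:
        # Called just before the sequence is batched for the next step.
        pass

    def sampled(self, seq, token_id: int,
                logprobs) -> Tuple[int, List[int], bool]:
        # Return (backtrack, tokens_to_append, should_stop); the output
        # processor turns anything other than (0, [token_id], _) into a
        # Sequence.splice_tokens() call.
        self.generated += 1
        if (token_id == seq.eos_token_id
                or self.generated >= self.max_new_tokens):
            return 0, [], True
        if token_id == self.trigger_token:
            # Keep the trigger token and fast-forward the canned ids.
            return 0, [token_id] + self.ff_tokens, False
        return 0, [token_id], False

    def free(self, seq) -> None:
        # Sequence finished and is being deallocated.
        self.generated = 0


# Wiring: the engine copies SamplingParams.controller onto the Sequence,
# and the scheduler / output processor invoke it (patches 6 and 8).
params = SamplingParams(max_tokens=128, ignore_eos=True)
params.controller = FastForwardController(trigger_token=42,
                                           ff_tokens=[13, 198])
```

The tuple returned by `sampled()` mirrors the `splice_tokens(backtrack, ff_tokens)` call in `single_step.py`: returning an empty token list or a non-zero backtrack drives the same backtracking path that the earlier AICI-specific code exercised, while the default `(0, [token_id], False)` keeps ordinary decoding unchanged.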