Clean unused KVCache after usage (vllm-project#10)
* Add underlying functions

* tests done
gc-fu authored Oct 25, 2023
1 parent a8561b8 commit fbcfee9
Showing 5 changed files with 39 additions and 8 deletions.
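At a high level, the commit threads a list of finished sequence ids from the scheduler, through SchedulerOutputs and the async engine, down to each worker, which then deletes the matching entries from its per-sequence kv_cache dict. The following is a minimal standalone sketch of that flow; the names are illustrative stand-ins, not the real vLLM signatures:

from typing import Dict, List

cleaned: List[int] = []                               # scheduler side: ids queued by free_seq
kv_cache: Dict[int, str] = {1: "kv#1", 2: "kv#2"}     # worker side: per-sequence KV entries

def free_seq(seq_id: int) -> None:
    cleaned.append(seq_id)            # no block_manager.free any more, just record the id

def schedule_step() -> List[int]:
    global cleaned
    finished_seqs, cleaned = cleaned.copy(), []
    return finished_seqs              # carried on SchedulerOutputs.finished_seqs

def execute_model(finished_seqs: List[int]) -> None:
    for seq_id in finished_seqs:
        del kv_cache[seq_id]          # worker drops the KVCache of finished sequences

free_seq(1)
execute_model(schedule_step())
assert kv_cache == {2: "kv#2"}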
2 changes: 2 additions & 0 deletions tests/under_models/send_mock_request.py
@@ -42,6 +42,7 @@ async def step_async(self) -> List[RequestOutput]:
blocks_to_swap_in={},
blocks_to_swap_out={},
blocks_to_copy={},
finished_seqs=[],
)
print(output)

@@ -68,6 +69,7 @@ async def step_async_multiple(self) -> List[RequestOutput]:
blocks_to_swap_in={},
blocks_to_swap_out={},
blocks_to_copy={},
finished_seqs=[],
)

# TODO: change this to real one
13 changes: 11 additions & 2 deletions vllm/core/scheduler.py
@@ -36,6 +36,7 @@ def __init__(
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup],
finished_seqs: List[int],
) -> None:
self.scheduled_seq_groups = scheduled_seq_groups
self.prompt_run = prompt_run
@@ -46,11 +47,13 @@ def __init__(
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups
self.finished_seqs = finished_seqs

def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy
and not self.finished_seqs)


class Scheduler:
@@ -417,6 +420,7 @@ def __init__(
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
self.cleaned: List[int] = []

def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
@@ -456,6 +460,8 @@ def _schedule(self) -> SchedulerOutputs:

ignored_seq_groups: List[SequenceGroup] = []
scheduled: List[SequenceGroup] = []
finished_seqs: List[int] = self.cleaned.copy()
self.cleaned = []
# The total number of sequences on the fly, including the
# requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
@@ -518,6 +524,7 @@ def _schedule(self) -> SchedulerOutputs:
blocks_to_swap_out={},
blocks_to_copy={},
ignored_seq_groups=ignored_seq_groups,
finished_seqs=finished_seqs,
)
return scheduler_outputs

@@ -539,6 +546,7 @@ def _schedule(self) -> SchedulerOutputs:
blocks_to_swap_out={},
blocks_to_copy={},
ignored_seq_groups=[],
finished_seqs=finished_seqs,
)
return scheduler_outputs

@@ -576,7 +584,8 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_manager.fork(parent_seq, child_seq)

def free_seq(self, seq: Sequence) -> None:
# self.block_manager.free(seq)
self.cleaned.append(seq.seq_id)

def free_finished_seq_groups(self) -> None:
for seq_group in self.running:
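The scheduler now defers KVCache cleanup: free_seq only records the sequence id in self.cleaned, and the next _schedule call snapshots that list into SchedulerOutputs.finished_seqs and clears it, so each id is forwarded to the workers exactly once. A small standalone sketch of this copy-then-reset pattern, using an illustrative class rather than the real Scheduler:

from typing import List

class DeferredCleanup:
    """Illustrative stand-in for the scheduler's self.cleaned handling."""

    def __init__(self) -> None:
        self.cleaned: List[int] = []

    def free_seq(self, seq_id: int) -> None:
        # Record the id instead of freeing KVCache blocks right away.
        self.cleaned.append(seq_id)

    def drain(self) -> List[int]:
        # Snapshot then reset, so each id reaches the workers exactly once.
        finished_seqs = self.cleaned.copy()
        self.cleaned = []
        return finished_seqs

sched = DeferredCleanup()
sched.free_seq(3)
sched.free_seq(7)
assert sched.drain() == [3, 7]
assert sched.drain() == []   # already reported, nothing left to clean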
3 changes: 3 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -192,12 +192,15 @@ async def step_async(self) -> List[RequestOutput]:
return ignored

# Execute the model.
# Co(gc): Since we do not yet have page table support, we need to pass the
# list of finished sequences so that the workers can clean up their KVCache.
output = await self._run_workers_async(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
finished_seqs=scheduler_outputs.finished_seqs,
)
print("We finished model_execution")
return self._process_model_outputs(output, scheduler_outputs) + ignored
10 changes: 4 additions & 6 deletions vllm/engine/llm_engine.py
@@ -384,8 +384,7 @@ def _process_sequence_group_samples(
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
# The outputs diverges, we need to fork the requests
@@ -425,7 +424,7 @@ def _process_sequence_group_samples(
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
pass
return

@@ -523,8 +522,7 @@ def _process_sequence_group_samples(
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)

# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
@@ -533,7 +531,7 @@ def _process_sequence_group_samples(
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)

def _process_model_outputs(
self, output: SamplerOutput,
19 changes: 19 additions & 0 deletions vllm/worker/worker.py
@@ -50,6 +50,21 @@ def __init__(

self.kv_cache = dict()

def clean_finished_seqs(
self,
finished_seqs: List[int]
) -> None:
"""
Clean the finished sequences and their KVCache entries from self.kv_cache.
"""
for seq_id in finished_seqs:
if seq_id not in self.kv_cache:
raise ValueError(
f"Sequence {seq_id} not found in worker's KVCache during cleanup"
)
del self.kv_cache[seq_id]


def init_model(self):
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
@@ -293,6 +308,7 @@ def execute_model(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
finished_seqs: List[int],
) -> SamplerOutput:
# Issue cache operations.
# issued_cache_op = False
@@ -310,6 +326,9 @@
# cache_events = self.cache_events
# else:
# cache_events = None
if finished_seqs:
self.clean_finished_seqs(finished_seqs)

cache_events = None
# If there is no input, we don't need to execute the model.
if not seq_group_metadata_list:
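On the worker side, every finished sequence id is expected to have an entry in self.kv_cache; cleanup deletes those entries before the next model execution, and an unknown id raises instead of being silently ignored. A tiny standalone illustration with a mock kv_cache dict (not the real Worker API):

from typing import Dict, List

kv_cache: Dict[int, str] = {3: "kv tensors for seq 3", 7: "kv tensors for seq 7"}

def clean_finished_seqs(finished_seqs: List[int]) -> None:
    for seq_id in finished_seqs:
        if seq_id not in kv_cache:
            # An unknown id signals a bookkeeping bug, so fail loudly.
            raise ValueError(f"Sequence {seq_id} not found in worker's KVCache")
        del kv_cache[seq_id]

clean_finished_seqs([3])
assert kv_cache == {7: "kv tensors for seq 7"}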