Implement preemption via recomputation & Refactor scheduling logic #12
Conversation
LGTM! Left some small comments.
cacheflow/master/scheduler.py
```python
# sequences, we only support swapping.
# TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences.
```
Should we add different preemption methods as options? For example, add a `preempt_method` function argument so the caller can pick between swapping and recomputation.
I added `PreemptionMode` and allowed the caller of `_preempt` to specify the mode. If the mode is not specified, we use recomputation for single-output requests and swapping for multi-output requests.
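For reference, a minimal sketch of what such a mode enum and dispatch could look like; the helper names `_preempt_by_recompute` and `_preempt_by_swap`, the `num_seqs()` check, and the exact signatures are assumptions, not necessarily the code in this PR:

```python
import enum
from typing import Dict, Optional


class PreemptionMode(enum.Enum):
    # Copy the group's KV-cache blocks to CPU memory and copy them back later.
    SWAP = enum.auto()
    # Drop the KV-cache blocks and recompute them from the prompt later.
    RECOMPUTE = enum.auto()


class Scheduler:
    def _preempt(
        self,
        seq_group,                            # SequenceGroup
        blocks_to_swap_out: Dict[int, int],   # GPU block -> CPU block mapping
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        # Default: recomputation for single-sequence groups (cheaper in the
        # benchmarks), swapping for multi-sequence groups, whose shared
        # blocks make recomputation complex.
        if preemption_mode is None:
            if seq_group.num_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP

        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        else:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)

    def _preempt_by_recompute(self, seq_group) -> None:
        ...  # free the group's blocks and move it back to the waiting queue

    def _preempt_by_swap(self, seq_group, blocks_to_swap_out) -> None:
        ...  # record GPU->CPU block copies and move the group to swapped
```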
```python
class PolicyFactory:

    _POLICY_REGISTRY = {
        'fcfs': FCFS,
```
Will we add SSF in another PR?
Yes. In this PR, I tried to make minimal changes.
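For context, a rough sketch of how a registry like this can be extended with a new policy such as SSF; the `Policy` base class, the `arrival_time` field, and the SSF priority shown here are illustrative assumptions, not code from this PR:

```python
from typing import List


class Policy:
    # A policy assigns each sequence group a priority and sorts the queue
    # by that priority (higher value = scheduled earlier).
    def get_priority(self, now: float, seq_group) -> float:
        raise NotImplementedError

    def sort_by_priority(self, now: float, seq_groups: List) -> List:
        return sorted(
            seq_groups,
            key=lambda sg: self.get_priority(now, sg),
            reverse=True,
        )


class FCFS(Policy):
    # First-come-first-served: the longer a group has waited, the higher
    # its priority.
    def get_priority(self, now: float, seq_group) -> float:
        return now - seq_group.arrival_time


class SSF(Policy):
    # Shortest-sequence-first (illustrative): shorter prompts run first.
    def get_priority(self, now: float, seq_group) -> float:
        return -seq_group.seqs[0].get_len()


class PolicyFactory:

    _POLICY_REGISTRY = {
        'fcfs': FCFS,
        'ssf': SSF,
    }

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
        return cls._POLICY_REGISTRY[policy_name](**kwargs)
```

Adding a policy then only requires a new `Policy` subclass and one registry entry; the scheduler keeps calling `sort_by_priority` unchanged.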
```python
    # No other sequence groups can be swapped out.
    if self.running:
        # Preempt the lowest-priority sequence groups.
        victim_seq_group = self.running.pop(-1)
        self._preempt(victim_seq_group, blocks_to_swap_out)
        preempted.append(victim_seq_group)
    else:
        # No other sequence groups can be preempted.
        # Preempt the current sequence group.
        self._preempt(seq_group, blocks_to_swap_out)
        preempted.append(seq_group)
        break
else:
    # Append new slots to the sequence group.
    self._append(seq_group, blocks_to_copy)
self.running = self.running[:victim_idx + 1]

# 2. Swap in the swapped sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling.
# The swapped sequences are in LIFO order.
for i, seq_group in enumerate(reversed(self.swapped)):
    if self.block_manager.can_swap_in(seq_group):
        self._swap_in(seq_group, blocks_to_swap_in)
        self._append(seq_group, blocks_to_copy)
    else:
        # OOM. Stop swapping.
        self.swapped = self.swapped[:len(self.swapped) - i]
    running.append(seq_group)
self.running = running

# Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped)
while self.swapped:
    seq_group = self.swapped[0]
    # If the sequence group has been preempted in this step, stop.
    if seq_group in preempted:
        break
    # If the sequence group cannot be swapped in, stop.
    if not self.block_manager.can_swap_in(seq_group):
        break
else:
    # All swapped sequences are swapped in.
    self.swapped.clear()

# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
    assert not blocks_to_swap_out
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group, blocks_to_copy)
self.running.append(seq_group)

num_batched_tokens = sum(
    seq_group.num_seqs(status=SequenceStatus.RUNNING)
    for seq_group in self.running
)

# 3. Join new sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
# Join waiting sequences if possible.
prompt_group_ids: List[int] = []
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
# prioritized over the sequence groups in the WAITING state.
# This is because we want to bound the amount of CPU memory taken by
# the swapped sequence groups.
if not self.swapped:
    for i, seq_group in enumerate(self.pending):
        self.waiting = self.policy.sort_by_priority(now, self.waiting)
        while self.waiting:
            seq_group = self.waiting[0]
            # If the sequence group has been preempted in this step, stop.
            if seq_group in preempted:
                break
            # If the sequence group cannot be allocated, stop.
            if not self.block_manager.can_allocate(seq_group):
                break

# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.seqs[0].get_len()
if self.block_manager.can_allocate(seq_group):
    if (num_batched_tokens + num_prompt_tokens
            <= self.max_num_batched_tokens):
        self._allocate(seq_group)
        num_batched_tokens += num_prompt_tokens
        continue

    self.pending = self.pending[i:]
    break
else:
    self.pending.clear()
if (num_batched_tokens + num_prompt_tokens
        > self.max_num_batched_tokens):
    break

seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
prompt_group_ids.append(seq_group.group_id)
```
Maybe move this part to a new function dedicated to swapping and finding which sequences to run?
Good point. I moved the scheduling logic to a new function, `_schedule`.
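For readers skimming the diff, the refactored flow boils down to roughly the following shape. This is a simplified skeleton under the assumption of the queues and helpers shown above (`self.running` / `self.swapped` / `self.waiting`, `self.policy`, `self.block_manager`), not the exact code:

```python
import time
from typing import Dict, List, Tuple


def _schedule(self) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]]]:
    # Simplified skeleton of the scheduler method; each phase's body is
    # elided (see the diff above for the real logic).
    blocks_to_swap_in: Dict[int, int] = {}
    blocks_to_swap_out: Dict[int, int] = {}
    blocks_to_copy: Dict[int, List[int]] = {}
    preempted = []
    now = time.time()

    # Phase 1: reserve new slots for the RUNNING groups, preempting the
    # lowest-priority groups when GPU blocks run out.
    self.running = self.policy.sort_by_priority(now, self.running)
    ...

    # Phase 2: swap SWAPPED groups back in while memory allows, unless a
    # group was just preempted in this step.
    self.swapped = self.policy.sort_by_priority(now, self.swapped)
    ...

    # Phase 3: admit WAITING groups (new prompts) only when the swapped
    # queue is empty, within the max_num_batched_tokens budget.
    if not self.swapped:
        self.waiting = self.policy.sort_by_priority(now, self.waiting)
        ...

    return blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
```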
```diff
@@ -76,7 +76,8 @@ def __init__(
         self.block_tables: Dict[int, BlockTable] = {}

     def can_allocate(self, seq_group: SequenceGroup) -> bool:
-        # NOTE: Here we assume that all sequences in the group have the same prompt.
+        # FIXME(woosuk): Here we assume that all sequences in the group share
+        # the same prompt. This may not be true for preempted sequences.
```
If I understand correctly, is this function only wrong when we use recomputation preemption for parallel decoding?
Yes, and for beam search as well.
Hi @WoosukKwon, if we had a kernel that could do one of the following, I think we could solve the problem of preemption by recomputation for multi-sequence requests. Do you agree? We would first run the normal prefill on the shared prompt tokens, followed by the necessary copying of partially shared blocks.
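A rough illustration of that two-step idea; every helper name below is a hypothetical placeholder rather than an existing API:

```python
def recompute_multi_seq_group(seq_group, block_manager, model_runner):
    # Two-step recomputation for a preempted multi-sequence group, as
    # described above: (1) run one ordinary prefill over the shared prompt,
    # then (2) give each sequence a private copy of any partially shared
    # block before it appends new tokens.
    prompt_token_ids = seq_group.seqs[0].prompt_token_ids

    # (1) Allocate blocks for the shared prompt and fill them with a single
    # normal prefill pass.
    shared_blocks = block_manager.allocate_shared(seq_group, len(prompt_token_ids))
    model_runner.prefill(prompt_token_ids, shared_blocks)

    # (2) Sequences diverge after the prompt, so the last, partially filled
    # block is copied per sequence (copy-on-write style).
    for seq in seq_group.seqs:
        block_manager.copy_last_block(seq, shared_blocks[-1])
```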
This PR implements a new preemption (eviction) mechanism: recomputation. In our benchmark results, recomputation is more efficient than swapping, because swapping incurs significant overhead from numerous small data transfers between CPU and GPU. Thus, we use recomputation as our default preemption mechanism.
However, we currently do not support recomputation for sequence groups with multiple sequences: when token blocks are shared, the recomputation logic becomes very complex, and we do not have CUDA kernels to support it efficiently. We will use swapping for this case despite its overheads.
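A minimal sketch of what the two mechanisms amount to inside the scheduler; the queue names, sequence statuses, and block-manager methods used below are assumptions (they follow the `_preempt` dispatch sketched earlier in the conversation):

```python
def _preempt_by_recompute(self, seq_group) -> None:
    # Drop the group's KV-cache blocks entirely and put the group back at
    # the front of the waiting queue; its prompt is simply recomputed when
    # it is scheduled again. One large prefill beats many small CPU<->GPU
    # block copies.
    for seq in seq_group.seqs:
        seq.status = SequenceStatus.WAITING
        self.block_manager.free(seq)
    self.waiting.insert(0, seq_group)


def _preempt_by_swap(self, seq_group, blocks_to_swap_out) -> None:
    # Keep the KV cache but move it to CPU memory: record the GPU->CPU
    # block copies and mark the group SWAPPED. Used for multi-sequence
    # groups, whose shared blocks make recomputation complex.
    blocks_to_swap_out.update(self.block_manager.swap_out(seq_group))
    for seq in seq_group.seqs:
        seq.status = SequenceStatus.SWAPPED
    self.swapped.append(seq_group)
```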
In addition, this PR refactors the scheduling logic to make it easier to understand.