Skip to content

Commit

Permalink
test: rewrite test_running_prefill_prioritized_over_swap
Browse files (browse the repository at this point in the history)
  • Loading branch information
toslunar committed Jul 6, 2024
1 parent 791a238 commit 231ec60
Showing 1 changed file with 38 additions and 23 deletions.
61 changes: 38 additions & 23 deletions tests/core/test_chunked_prefill_scheduler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -383,7 +383,12 @@ def test_running_prefill_prioritized_over_swap():
cache_config.num_gpu_blocks = 8
scheduler = Scheduler(scheduler_config, cache_config, None)

_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
# Artificial priority is needed for testing with the fcfs policy in
# _schedule_running. This seq will be prioritized more among running
# seqs but less among waiting seqs.
_, seq_group2 = create_dummy_prompt("2", prompt_length=20 + 30 + 30 + 30)

_, seq_group = create_dummy_prompt("1", prompt_length=30 + 10, best_of=2)
scheduler.add_seq_group(seq_group)
_, out = schedule_and_update_computed_tokens(scheduler)
# The request is chunked.
Expand All @@ -393,7 +398,27 @@ def test_running_prefill_prioritized_over_swap():
assert seq_group.is_prefill()
assert out.num_batched_tokens == max_num_batched_tokens

# The request should be swapped out.
# Add 1 more task.
scheduler.add_seq_group(seq_group2)
_, out = schedule_and_update_computed_tokens(scheduler)
# task 1 finished the last 10 tokens of prefill.
# task 2 started the first 20 tokens of prefill.
assert len(out.scheduled_seq_groups) == 2
assert out.num_prefill_groups == 2
assert not seq_group.is_prefill()
assert seq_group2.is_prefill()
assert out.num_batched_tokens == max_num_batched_tokens

# seq_group starts decoding with best_of=2
# see vllm/engine/output_processor/single_step.py
seq = seq_group.seqs_dict[1]
new_seq_id = 3
new_seq = seq.fork(new_seq_id)
seq_group.add(new_seq)
scheduler.fork_seq(seq, new_seq)
append_new_token(seq_group, 1)

# The first request should be swapped out.
scheduler.block_manager.can_append_slots = MagicMock()

def cannot_append_second_group(seq_group, num_lookahead_slots):
Expand All @@ -402,32 +427,29 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)

# The running prefill is now swapped.
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 0
assert out.num_batched_tokens == 0
assert out.blocks_to_swap_out != []
assert len(out.scheduled_seq_groups) == 1
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out != []
assert out.scheduled_seq_groups[0].seq_group == seq_group2

# Add 1 more task. Swap is not possible, so prefill is running.
# Swap is not possible, so prefill is running.
scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER

_, seq_group2 = create_dummy_prompt("2", prompt_length=60)
scheduler.add_seq_group(seq_group2)
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert seq_group2.is_prefill()
assert out.scheduled_seq_groups[0].seq_group == seq_group2

# Now although swap is possible, running prefill is prioritized.
scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
Expand All @@ -437,23 +459,16 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):

# Decoding is prioritized.
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
assert len(out.scheduled_seq_groups) == 2
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 1
assert out.blocks_to_swap_in == []
assert out.num_batched_tokens == 3
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
assert not seq_group.is_prefill()
assert not seq_group2.is_prefill()
assert out.scheduled_seq_groups[0].seq_group == seq_group2
append_new_token(seq_group, 1)
append_new_token(seq_group2, 1)

# Since we abort the sequence group, we can finally swap.
scheduler.abort_seq_group(seq_group2.request_id)
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []


def test_chunked_prefill_preempt():
"""Verify preempt works with chunked prefill requests"""
Expand Down

0 comments on commit 231ec60

Please sign in to comment.