[V1][BugFix] Fix edge case in VLM scheduling #12065

Merged · 3 commits · Jan 15, 2025
26 changes: 15 additions & 11 deletions vllm/v1/core/scheduler.py
@@ -373,18 +373,22 @@ def _try_schedule_encoder_inputs(
             if self.encoder_cache_manager.has_cache(request, i):
                 # The encoder input is already computed and cached.
                 continue
-            if not self.encoder_cache_manager.can_allocate(request, i):
-                # The encoder cache is full. We can only schedule the decoder
-                # tokens just before the encoder input.
-                num_new_tokens = start_pos - num_computed_tokens
-                break
-            if num_encoder_tokens > encoder_budget:
-                # The encoder budget is exhausted. We can only schedule the
-                # decoder tokens up until the encoder input.
-                # NOTE(woosuk): We assume that the encoder tokens should be
-                # processed altogether, as the encoder usually uses
+            if (not self.encoder_cache_manager.can_allocate(request, i)
+                    or num_encoder_tokens > encoder_budget):
+                # The encoder cache is full or the encoder budget is exhausted.
+                # NOTE(woosuk): We assume that the encoder input tokens should
+                # be processed altogether, as the encoder usually uses
                 # bidirectional attention.
-                num_new_tokens = start_pos - num_computed_tokens
+                if num_computed_tokens < start_pos:
+                    # We only schedule the decoder tokens just before the
+                    # encoder input.
+                    num_new_tokens = start_pos - num_computed_tokens
+                else:
+                    # Because of prefix caching, num_computed_tokens is greater
+                    # than start_pos even though its encoder input is not
+                    # available. In this case, we can't schedule any token for
+                    # the request in this step.
+                    num_new_tokens = 0
                 break

             encoder_budget -= num_encoder_tokens
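
The heart of the fix is the new else branch: with prefix caching, a request can resume with num_computed_tokens already past an encoder input's start_pos while the corresponding encoder output is not in the cache, so the old start_pos - num_computed_tokens arithmetic would go negative. A minimal standalone sketch of that decision follows; the helper name and signature are illustrative only, not vLLM's actual API.

# Minimal sketch of the scheduling decision this PR changes. The helper
# name and signature are hypothetical, not vLLM's actual API.

def decoder_tokens_to_schedule(start_pos: int, num_computed_tokens: int) -> int:
    """How many decoder tokens to schedule when the encoder input starting
    at start_pos cannot run this step (cache full or budget exhausted)."""
    if num_computed_tokens < start_pos:
        # Normal case: schedule the decoder tokens just before the
        # encoder input.
        return start_pos - num_computed_tokens
    # Edge case fixed by this PR: prefix caching advanced
    # num_computed_tokens past start_pos, but the encoder output is not
    # available. The old arithmetic would return a negative count here;
    # the fix schedules no tokens for this request in this step.
    return 0

# Normal case: 40 tokens computed, encoder input (e.g. an image
# placeholder) starts at position 100 -> schedule 60 decoder tokens.
assert decoder_tokens_to_schedule(start_pos=100, num_computed_tokens=40) == 60
# Edge case: a prefix-cache hit already covers 150 tokens -> schedule 0.
assert decoder_tokens_to_schedule(start_pos=100, num_computed_tokens=150) == 0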