[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. #3951
Conversation
Thoughts before discussing this PR. Skip sampler & tests.
speculative_max_model_len is mainly used for testing that sequences can
skip speculation.
"""
Do we want to add a check to make sure speculative_max_model_len < min(draft_max_model_len, target_max_model_len), in case the user sets speculative_max_model_len inappropriately?
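For illustration, a minimal sketch of the bound check being suggested; the names draft_max_model_len and target_max_model_len follow the comment above and are assumptions, not necessarily the exact vLLM config fields:

```python
def _verify_speculative_max_model_len(speculative_max_model_len: int,
                                      draft_max_model_len: int,
                                      target_max_model_len: int) -> None:
    """Raise if a user-provided speculative_max_model_len exceeds what the
    draft or the target model can actually support."""
    upper_bound = min(draft_max_model_len, target_max_model_len)
    if speculative_max_model_len > upper_bound:
        raise ValueError(
            f"speculative_max_model_len={speculative_max_model_len} must not "
            f"exceed min(draft_max_model_len, target_max_model_len)="
            f"{upper_bound}.")
```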
- Cade to fix
# process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
- A bit of a concern here: from this, it seems we assume the DECODE stage only has 1 new token?
- I assume chunked prefill and speculative decoding cannot be turned on at the same time? Did we explicitly check or document that somewhere?
- (Cade fill out answer)
- Cade to verify args and raise if chunked prefill enabled while spec decode enabled
Answer for future readers:
- Chunked prefill and speculative decoding are compatible from a systems perspective; however, the current vLLM implementations need more work before they can be enabled together. I'll add a validation check which raises if both are enabled.
- The DECODE stage currently only reports 1 new token. This is used by the scheduler, via the token budget, to prevent a batch from becoming compute-bound. When chunked prefill is enabled, we will need to adjust this value to take into account the "new tokens" computed during speculative verification. When chunked prefill is disabled, the new token budget is max_num_batched_tokens, and we are OK with the budget system not taking speculative decoding into account.
- I'll make an issue for integrating chunked prefill with spec decode soon!
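As a rough sketch of the validation check described above (the argument names enable_chunked_prefill and speculative_model are assumptions for illustration, not necessarily the exact vLLM engine args):

```python
from typing import Optional


def _verify_chunked_prefill_compatible(enable_chunked_prefill: bool,
                                       speculative_model: Optional[str]) -> None:
    """Reject configurations that turn on chunked prefill together with
    speculative decoding, since the current implementation does not support
    the combination."""
    if enable_chunked_prefill and speculative_model is not None:
        raise ValueError(
            "Speculative decoding and chunked prefill are currently mutually "
            "exclusive. Please disable one of them.")
```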
@@ -680,12 +760,36 @@ def _get_logprobs(
    return result_prompt_logprobs, result_sample_logprobs


def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
- Cade to list how this fits into the sampler overall.
- Why not use a very small temperature instead?
Addressed my concerns after discussion; please add some docs to clarify, thanks!
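For future readers, a minimal sketch of the idea behind _modify_greedy_probs_inplace: rows that were sampled greedily get their distribution overwritten with a one-hot at the sampled token, so the rejection sampler sees probabilities exactly consistent with greedy sampling (a very small but nonzero temperature would only approximate this). The sketch below is illustrative and not the exact vLLM implementation:

```python
import torch


def modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                sample_indices: torch.Tensor,
                                greedy_samples: torch.Tensor) -> None:
    """Overwrite greedily-sampled rows with a one-hot distribution at the
    sampled token: probability 1 (logprob 0) there, 0 (-inf) elsewhere."""
    probs[sample_indices, :] = 0.0
    probs[sample_indices, greedy_samples] = 1.0
    logprobs[sample_indices, :] = float("-inf")
    logprobs[sample_indices, greedy_samples] = 0.0
```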
Applied feedback @LiuXiaoxuanPKU. I will enable auto-merge with your approval; if you have any more comments, I'm happy to take them in a future PR.
The main branch was broken; merging again to pick up #4271.
This PR adds e2e correctness tests for speculative decoding. It is PR 7/9 in the speculative decoding open sourcing plan.
The E2E correctness tests verify that the generated output of a sequence with speculative decoding is equal to the generated output without speculative decoding when the temperature is 0. We test various batch sizes, models, speculative lengths, block sizes, num_gpu_blocks (& preemption), and max_model_lens (& skipping speculation for some/all sequences), and verify that this core greedy-equality property holds.
See test_correctness.py for more details on test methodology.
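A rough illustration of the methodology (model names are placeholders and the speculative_model / num_speculative_tokens engine args are assumed here; see test_correctness.py for the real parametrization):

```python
from vllm import LLM, SamplingParams


def test_greedy_equality_with_spec_decode():
    """With temperature=0, outputs generated with speculative decoding must
    exactly match outputs generated without it."""
    prompts = ["The future of AI is", "San Francisco is a"]
    greedy = SamplingParams(temperature=0.0, max_tokens=32)

    baseline = LLM(model="facebook/opt-125m")
    expected = [out.outputs[0].text for out in baseline.generate(prompts, greedy)]
    del baseline  # free GPU memory before starting the second engine

    spec = LLM(model="facebook/opt-125m",
               speculative_model="facebook/opt-125m",
               num_speculative_tokens=3)
    actual = [out.outputs[0].text for out in spec.generate(prompts, greedy)]

    assert actual == expected
```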
Bugfixes

To make the tests pass, this PR introduces several fixes that are listed in order of notoriety:
- System efficiency calculation fixed.

Minor feature additions

The following features were added:
- max_model_len for use in testing. This was required to test preemption.