[Core] Reduce unnecessary compute when logprobs=None #6532

Merged 15 commits on Jul 29, 2024
143 changes: 80 additions & 63 deletions vllm/model_executor/layers/sampler.py
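For context, the fast path this PR adds targets the most common serving case: requests that never ask for logprobs. A minimal sketch against the public vLLM API (model choice is arbitrary; `logprobs` defaults to `None` in `SamplingParams`):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any model; chosen only for illustration

# logprobs is left at its default of None. Before this PR, the sampler still
# gathered per-token logprobs, computed ranks, and copied tensors to the CPU
# even for requests like this one.
params = SamplingParams(temperature=0.8, max_tokens=16)
outputs = llm.generate(["The future of AI is"], params)
print(outputs[0].outputs[0].text)
```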
@@ -774,8 +774,11 @@ def _get_logprobs(
     # The next token ids to get the logprob value from.
     next_token_ids: List[int] = []
     # The largest requested number of logprobs. We find logprobs as many as the
-    # largest num logprobs in this API.
-    largest_num_logprobs = 1
+    # largest num logprobs in this API. If every logprobs is None, it will be
+    # set to -1.
+    largest_num_logprobs = -1
+    # If beam search is enabled.
+    use_beam_search = False
 
     # Select indices to compute logprob from, ranks of token ids, and the top
     # k token ids from logprobs.
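A note on the hunk above: initializing `largest_num_logprobs` to -1 rather than 1 is what makes the no-logprobs case detectable at all. A simplified sketch of the accumulation, outside vLLM:

```python
# `requested` stands in for sampling_params.logprobs across the sequence
# groups of a batch; None means the request did not ask for logprobs.
largest_num_logprobs = -1  # sentinel: "no one asked for logprobs"
for requested in [None, None, None]:
    if requested is not None:
        largest_num_logprobs = max(largest_num_logprobs, requested)

# The sentinel survives only if every request left logprobs as None,
# which is the case the later hunks use to skip the logprob pass.
assert largest_num_logprobs == -1
```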
@@ -808,42 +811,49 @@ def _get_logprobs(
             largest_num_logprobs = max(largest_num_logprobs,
                                        sampling_params.logprobs)
 
+        use_beam_search = use_beam_search or sampling_params.use_beam_search
+
     assert len(next_token_ids) == len(query_indices)
 
     if len(query_indices) == 0:
         empty_sampled_logprob: SampleLogprobs = []
         empty_prompt_logprob: Optional[PromptLogprobs] = None
         return [empty_prompt_logprob], [empty_sampled_logprob]
 
-    query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
-    next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
-
-    # (num_selected_query_tokens, num_logprobs). Note that query_indices can
-    # contain duplicates if beam search is enabled.
-    selected_logprobs = logprobs[[
-        query_indices_gpu,
-        next_token_ids_gpu,
-    ]]
-    ranks = _get_ranks(
-        logprobs[query_indices_gpu],
-        next_token_ids_gpu,
-    )
-    assert selected_logprobs.shape[0] == ranks.shape[0]
-
-    # Logprobs of topk tokens for a batch of sequence groups.
-    # (num_query_tokens_across_batch).
-    if largest_num_logprobs > 0:
-        top_logprobs, top_token_ids = torch.topk(logprobs,
-                                                 largest_num_logprobs,
-                                                 dim=-1)
-    else:
-        top_logprobs, top_token_ids = None, None
+    selected_logprobs, ranks = None, None
+    top_logprobs, top_token_ids = None, None
+
+    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
+    # skip the whole logprob calculation.
+    if largest_num_logprobs >= 0 or use_beam_search:
+        query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
+        next_token_ids_gpu = torch.tensor(next_token_ids,
+                                          device=logprobs.device)
+
+        # (num_selected_query_tokens, num_logprobs). Note that query_indices can
+        # contain duplicates if beam search is enabled.
+        selected_logprobs = logprobs[[
+            query_indices_gpu,
+            next_token_ids_gpu,
+        ]]
+        ranks = _get_ranks(
+            logprobs[query_indices_gpu],
+            next_token_ids_gpu,
+        )
+        assert selected_logprobs.shape[0] == ranks.shape[0]
+
+        # We need to compute top k only if there exists logprobs > 0.
+        if largest_num_logprobs > 0:
+            # Logprobs of topk tokens for a batch of sequence groups.
+            # (num_query_tokens_across_batch).
+            top_logprobs, top_token_ids = torch.topk(logprobs,
+                                                     largest_num_logprobs,
+                                                     dim=-1)
+            top_logprobs = top_logprobs.to('cpu')
+            top_token_ids = top_token_ids.to('cpu')
 
-    selected_logprobs = selected_logprobs.to('cpu')
-    ranks = ranks.to('cpu')
-    if top_logprobs is not None and top_token_ids is not None:
-        top_logprobs = top_logprobs.to('cpu')
-        top_token_ids = top_token_ids.to('cpu')
+        selected_logprobs = selected_logprobs.to('cpu')
+        ranks = ranks.to('cpu')
 
     # Find prompt/sample logprobs.
     prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
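To make the new control flow above concrete, here is a standalone sketch of just the gating decision, with sampling parameters reduced to `(logprobs, use_beam_search)` pairs (names are illustrative, not vLLM's):

```python
from typing import List, Optional, Tuple

def needs_logprob_pass(params: List[Tuple[Optional[int], bool]]) -> bool:
    """Mirror the gating in _get_logprobs: skip all GPU gather / top-k work
    unless some request asked for logprobs or beam search is in use."""
    largest_num_logprobs = -1
    use_beam_search = False
    for logprobs, beam in params:
        if logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs, logprobs)
        use_beam_search = use_beam_search or beam
    return largest_num_logprobs >= 0 or use_beam_search

assert not needs_logprob_pass([(None, False), (None, False)])  # fast path
assert needs_logprob_pass([(None, False), (5, False)])  # someone wants top-5
assert needs_logprob_pass([(0, False)])   # logprobs=0 still needs ranks (>= 0)
assert needs_logprob_pass([(None, True)])  # beam search needs real logprobs
```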
@@ -940,46 +950,53 @@ def _get_sampled_logprob_if_needed(
 ):
     """Compute the sample logprob if needed."""
     seq_ids = seq_group.seq_ids
-    num_logprobs = seq_group.sampling_params.logprobs or 0
+    num_logprobs = seq_group.sampling_params.logprobs
+    use_beam_search = seq_group.sampling_params.use_beam_search
     sampled_logprobs: SampleLogprobs = []
     next_token_ids, parent_seq_ids = sample_result
 
     if seq_group.do_sample:
         assert len(next_token_ids) > 0
-        # Pre-select items from tensor. tolist() is faster than repetitive
-        # `.item()` calls.
-        selected_logprob_items = selected_logprobs[
-            selected_logprobs_idx:selected_logprobs_idx +
-            len(next_token_ids)].tolist()
-        rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
-                           len(next_token_ids)].tolist()
-        for idx, (next_token_id,
-                  parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
-            # Get the logprob of a sampled token.
-            sampled_logprobs_dict = {
-                next_token_id: (selected_logprob_items[idx], rank_items[idx])
-            }
-            # Get top K logprobs.
-            if num_logprobs > 0:
-                top_ids = top_token_ids[top_logprob_idx +
-                                        parent_id, :num_logprobs].tolist()
-                top_probs = top_logprobs[top_logprob_idx +
-                                         parent_id, :num_logprobs].tolist()
-                # Top K is already sorted by rank, so we can use 1 ~
-                # num_logprobs + 1 for rank.
-                top_ranks = range(1, num_logprobs + 1)
-                sampled_logprobs_dict.update({
-                    top_id: (top_prob, rank)
-                    for top_id, top_prob, rank in zip(top_ids, top_probs,
-                                                      top_ranks)
+        if num_logprobs is None and not use_beam_search:
+            for next_token_id in next_token_ids:
+                # Use a dummy logprob
+                sampled_logprobs.append({next_token_id: Logprob(0.0)})
Review thread on the `Logprob(0.0)` line above:

Member:
Is 0 a value that makes sense for a logprob? We are still using real token ids here, so maybe 1.0 would be more representative?

Contributor (Author):
What do you think about using NaN here? The value will propagate to cumulative_logprob. I think it's better to show a NaN than an arbitrary positive number here.

Collaborator:
Yeah, +1 on NaN.

+        else:
+            # Pre-select items from tensor. tolist() is faster than repetitive
+            # `.item()` calls.
+            selected_logprob_items = selected_logprobs[
+                selected_logprobs_idx:selected_logprobs_idx +
+                len(next_token_ids)].tolist()
+            rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
+                               len(next_token_ids)].tolist()
+            for idx, (next_token_id, parent_id) in enumerate(
+                    zip(next_token_ids, parent_seq_ids)):
+                # Get the logprob of a sampled token.
+                sampled_logprobs_dict = {
+                    next_token_id:
+                    (selected_logprob_items[idx], rank_items[idx])
+                }
+                if num_logprobs is not None and num_logprobs > 0:
+                    # Get top K logprobs.
+                    top_ids = top_token_ids[top_logprob_idx +
+                                            parent_id, :num_logprobs].tolist()
+                    top_probs = top_logprobs[
+                        top_logprob_idx + parent_id, :num_logprobs].tolist()
+                    # Top K is already sorted by rank, so we can use 1 ~
+                    # num_logprobs + 1 for rank.
+                    top_ranks = range(1, num_logprobs + 1)
+                    sampled_logprobs_dict.update({
+                        top_id: (top_prob, rank)
+                        for top_id, top_prob, rank in zip(
+                            top_ids, top_probs, top_ranks)
+                    })
+
+                sampled_logprobs.append({
+                    token_id: Logprob(*logprob_and_rank)
+                    for token_id, logprob_and_rank in
+                    sampled_logprobs_dict.items()
                 })
 
-            sampled_logprobs.append({
-                token_id: Logprob(*logprob_and_rank)
-                for token_id, logprob_and_rank in
-                sampled_logprobs_dict.items()
-            })
-
     # NOTE: This part of code is not intuitive. `selected_logprobs` include
     # logprobs for the current step, which has len(next_token_ids) tokens
     # per sequence group. `logprobs` includes logprobs from the previous
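Putting the hunk above together with the review thread: on the fast path each sampled token still gets a well-formed entry keyed by its real token id; only the score is a placeholder (0.0 at this commit, with the thread converging on NaN). A sketch of the resulting shape, assuming `Logprob` is the dataclass from `vllm/sequence.py`, plus the NaN behavior the reviewers preferred:

```python
import math
from vllm.sequence import Logprob  # dataclass: logprob, rank=None, decoded_token=None

# Shape produced by the fast path for two sampled tokens when logprobs=None:
next_token_ids = [42, 7]
sampled_logprobs = [{tid: Logprob(0.0)} for tid in next_token_ids]
assert sampled_logprobs[0][42].logprob == 0.0  # dummy, not a real score

# Why the reviewers preferred NaN over 0.0: a dummy of 0.0 sums into
# cumulative_logprob as an impossibly "perfect" score, while NaN
# propagates through the sum and is unmistakable downstream.
assert sum([0.0, 0.0]) == 0.0                         # looks like a real score
assert math.isnan(sum([float("nan"), float("nan")]))  # clearly "not computed"
```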