[Core] Comment out unused code in sampler #7023

Merged · 2 commits · Aug 2, 2024
58 changes: 31 additions & 27 deletions vllm/model_executor/sampling_metadata.py
@@ -13,6 +13,8 @@
 
 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558
+# Some triton sampler related code is guarded before it is ready.
+_USE_TRITON_SAMPLER = False
 
 
 @dataclass
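
Note that despite the "Comment out" in the PR title, the unused paths are disabled with a module-level flag rather than literal comments, so the guarded code is still parsed and syntax-checked and can be re-enabled by flipping one constant. A minimal sketch of the same guard pattern, with hypothetical names not taken from vLLM:

def _triton_specific_step(values: list[int]) -> int:
    # Stand-in for the Triton-sampler-only work that is not yet enabled.
    return len(values)

# While the flag is False, the guarded branch below is parsed and checked
# by the interpreter but never executed, unlike literally commented-out code.
_USE_EXPERIMENTAL_PATH = False  # hypothetical counterpart of _USE_TRITON_SAMPLER

def compute(values: list[int]) -> int:
    total = sum(values)
    if _USE_EXPERIMENTAL_PATH:
        total += _triton_specific_step(values)
    return total

print(compute([1, 2, 3]))  # prints 6; the guarded step is skipped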
@@ -347,14 +349,16 @@ def from_sampling_metadata(
         repetition_penalties: List[float] = []
         sampling_seeds: List[int] = []
         sample_indices: List[int] = []
-        prompt_best_of: List[int] = []
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
 
-        # We need one base seed per Triton slice.
-        seeds_to_generate = (extra_seeds_to_generate +
-                             get_num_triton_sampler_splits(vocab_size))
+        if _USE_TRITON_SAMPLER:
+            prompt_best_of: List[int] = []
+
+            # We need one base seed per Triton slice.
+            seeds_to_generate = (extra_seeds_to_generate +
+                                 get_num_triton_sampler_splits(vocab_size))
 
         assert sampling_metadata.seq_groups is not None
         for seq_group in sampling_metadata.seq_groups:
@@ -366,9 +370,6 @@
             r = sampling_params.repetition_penalty
             top_p = sampling_params.top_p
             min_p = sampling_params.min_p
-            seed = sampling_params.seed
-
-            is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
 
             # k should not be greater than the vocab size.
             top_k = min(sampling_params.top_k, vocab_size)
@@ -389,8 +390,7 @@
                 do_penalties = True
 
             is_prompt = seq_group.is_prompt
-            if (seq_group.is_prompt
-                    and sampling_params.prompt_logprobs is not None):
+            if (is_prompt and sampling_params.prompt_logprobs is not None):
                 # For tokens in the prompt that we only need to get
                 # their logprobs
                 query_len = seq_group.query_len
@@ -415,23 +415,27 @@
                 frequency_penalties += [f] * len(seq_ids)
                 repetition_penalties += [r] * len(seq_ids)
 
-            if is_prompt:
-                prompt_best_of.append(sampling_params.best_of)
-                query_len = seq_group.query_len
-                assert query_len is not None
-
-            for seq_id in seq_ids:
-                seq_data = seq_group.seq_data[seq_id]
-                extra_entropy = extra_entropy or ()
-                seq_seeds = cls._get_sequence_seeds(
-                    seed,
-                    seq_data.get_len(),
-                    *extra_entropy,
-                    seq_id,
-                    seeds_to_generate=seeds_to_generate,
-                    is_greedy=is_greedy)
-                sampling_seeds.append(seq_seeds)
-            sample_indices.extend(seq_group.sample_indices)
+            if _USE_TRITON_SAMPLER:
+                if is_prompt:
+                    prompt_best_of.append(sampling_params.best_of)
+                    query_len = seq_group.query_len
+                    assert query_len is not None
+
+                seed = sampling_params.seed
+                is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
+
+                for seq_id in seq_ids:
+                    seq_data = seq_group.seq_data[seq_id]
+                    extra_entropy = extra_entropy or ()
+                    seq_seeds = cls._get_sequence_seeds(
+                        seed,
+                        seq_data.get_len(),
+                        *extra_entropy,
+                        seq_id,
+                        seeds_to_generate=seeds_to_generate,
+                        is_greedy=is_greedy)
+                    sampling_seeds.append(seq_seeds)
+                sample_indices.extend(seq_group.sample_indices)
 
         if do_penalties:
             for seq_group in sampling_metadata.seq_groups:
@@ -549,7 +553,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             device="cpu",
             dtype=torch.long,
             pin_memory=pin_memory,
-        ).T.contiguous()
+        ).t().contiguous()
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
Comment from the PR author on the .t().contiguous() line: This change is to avoid a PyTorch warning, because sampling_seeds_t becomes 1-dim here.
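
For reference, the warning comes from Tensor.T, which is deprecated for tensors whose dimension is not 2, while Tensor.t() simply returns 0-D and 1-D tensors unchanged. A minimal sketch reproducing the difference (assuming a reasonably recent PyTorch; the tensor below stands in for a 1-dim sampling_seeds_t):

import warnings

import torch

# A 1-D long tensor, mirroring sampling_seeds_t when it is one-dimensional.
seeds = torch.tensor([1, 2, 3], dtype=torch.long)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = seeds.T.contiguous()    # emits a UserWarning: .T on a non-2D tensor is deprecated
    _ = seeds.t().contiguous()  # silent: .t() is a no-op for 1-D tensors

print([str(w.message) for w in caught])  # only the Tensor.T deprecation warning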

