Skip to content

Commit

Permalink
Comment out unused code in sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
peng1999 committed Aug 1, 2024
1 parent c32ab8b commit f6c9233
Showing 1 changed file with 25 additions and 28 deletions.
53 changes: 25 additions & 28 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
make_tensor_with_pad, maybe_expand_dim)

Expand Down Expand Up @@ -347,14 +346,13 @@ def from_sampling_metadata(
repetition_penalties: List[float] = []
sampling_seeds: List[int] = []
sample_indices: List[int] = []
prompt_best_of: List[int] = []
do_penalties = False
do_top_p_top_k = False
do_min_p = False

# We need one base seed per Triton slice.
seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size))
# # We need one base seed per Triton slice.
# seeds_to_generate = (extra_seeds_to_generate +
# get_num_triton_sampler_splits(vocab_size))

assert sampling_metadata.seq_groups is not None
for seq_group in sampling_metadata.seq_groups:
Expand All @@ -366,9 +364,6 @@ def from_sampling_metadata(
r = sampling_params.repetition_penalty
top_p = sampling_params.top_p
min_p = sampling_params.min_p
seed = sampling_params.seed

is_greedy = sampling_params.sampling_type == SamplingType.GREEDY

# k should not be greater than the vocab size.
top_k = min(sampling_params.top_k, vocab_size)
Expand All @@ -389,8 +384,7 @@ def from_sampling_metadata(
do_penalties = True

is_prompt = seq_group.is_prompt
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
if (is_prompt and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
query_len = seq_group.query_len
Expand All @@ -415,23 +409,26 @@ def from_sampling_metadata(
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)

if is_prompt:
prompt_best_of.append(sampling_params.best_of)
query_len = seq_group.query_len
assert query_len is not None

for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds(
seed,
seq_data.get_len(),
*extra_entropy,
seq_id,
seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy)
sampling_seeds.append(seq_seeds)
sample_indices.extend(seq_group.sample_indices)
# The following code is for a Triton-based sampler
# that is not enabled.

# if is_prompt:
# prompt_best_of.append(sampling_params.best_of)
# query_len = seq_group.query_len
# assert query_len is not None

# for seq_id in seq_ids:
# seq_data = seq_group.seq_data[seq_id]
# extra_entropy = extra_entropy or ()
# seq_seeds = cls._get_sequence_seeds(
# seed,
# seq_data.get_len(),
# *extra_entropy,
# seq_id,
# seeds_to_generate=seeds_to_generate,
# is_greedy=is_greedy)
# sampling_seeds.append(seq_seeds)
# sample_indices.extend(seq_group.sample_indices)

if do_penalties:
for seq_group in sampling_metadata.seq_groups:
Expand Down Expand Up @@ -549,7 +546,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
).T.contiguous()
).t().contiguous()

# Because the memory is pinned, we can do non-blocking
# transfer to device.
Expand Down

0 comments on commit f6c9233

Please sign in to comment.