diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 9d0331781c7a8..015e85b4ca81d 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -7,11 +7,14 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 +# Some triton sampler related code is guarded before it is ready. +_USE_TRITON_SAMPLER = False @dataclass @@ -350,9 +353,12 @@ def from_sampling_metadata( do_top_p_top_k = False do_min_p = False - # # We need one base seed per Triton slice. - # seeds_to_generate = (extra_seeds_to_generate + - # get_num_triton_sampler_splits(vocab_size)) + if _USE_TRITON_SAMPLER: + prompt_best_of: List[int] = [] + + # We need one base seed per Triton slice. + seeds_to_generate = (extra_seeds_to_generate + + get_num_triton_sampler_splits(vocab_size)) assert sampling_metadata.seq_groups is not None for seq_group in sampling_metadata.seq_groups: @@ -409,26 +415,27 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) - # The following code is for a Triton-based sampler - # that is not enabled. - - # if is_prompt: - # prompt_best_of.append(sampling_params.best_of) - # query_len = seq_group.query_len - # assert query_len is not None - - # for seq_id in seq_ids: - # seq_data = seq_group.seq_data[seq_id] - # extra_entropy = extra_entropy or () - # seq_seeds = cls._get_sequence_seeds( - # seed, - # seq_data.get_len(), - # *extra_entropy, - # seq_id, - # seeds_to_generate=seeds_to_generate, - # is_greedy=is_greedy) - # sampling_seeds.append(seq_seeds) - # sample_indices.extend(seq_group.sample_indices) + if _USE_TRITON_SAMPLER: + if is_prompt: + prompt_best_of.append(sampling_params.best_of) + query_len = seq_group.query_len + assert query_len is not None + + seed = sampling_params.seed + is_greedy = sampling_params.sampling_type == SamplingType.GREEDY + + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + extra_entropy = extra_entropy or () + seq_seeds = cls._get_sequence_seeds( + seed, + seq_data.get_len(), + *extra_entropy, + seq_id, + seeds_to_generate=seeds_to_generate, + is_greedy=is_greedy) + sampling_seeds.append(seq_seeds) + sample_indices.extend(seq_group.sample_indices) if do_penalties: for seq_group in sampling_metadata.seq_groups: