From f6c923382b8d3c7a05935ca95ec9a13f99ee4afe Mon Sep 17 00:00:00 2001 From: Peng Guanwen Date: Thu, 1 Aug 2024 15:58:33 +0800 Subject: [PATCH 1/2] Comment out unused code in sampler --- vllm/model_executor/sampling_metadata.py | 53 +++++++++++------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 59cfec9ec8934..9d0331781c7a8 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -7,7 +7,6 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData, SequenceGroupMetadata -from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) @@ -347,14 +346,13 @@ def from_sampling_metadata( repetition_penalties: List[float] = [] sampling_seeds: List[int] = [] sample_indices: List[int] = [] - prompt_best_of: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False - # We need one base seed per Triton slice. - seeds_to_generate = (extra_seeds_to_generate + - get_num_triton_sampler_splits(vocab_size)) + # # We need one base seed per Triton slice. + # seeds_to_generate = (extra_seeds_to_generate + + # get_num_triton_sampler_splits(vocab_size)) assert sampling_metadata.seq_groups is not None for seq_group in sampling_metadata.seq_groups: @@ -366,9 +364,6 @@ def from_sampling_metadata( r = sampling_params.repetition_penalty top_p = sampling_params.top_p min_p = sampling_params.min_p - seed = sampling_params.seed - - is_greedy = sampling_params.sampling_type == SamplingType.GREEDY # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) @@ -389,8 +384,7 @@ def from_sampling_metadata( do_penalties = True is_prompt = seq_group.is_prompt - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): + if (is_prompt and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs query_len = seq_group.query_len @@ -415,23 +409,26 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) - if is_prompt: - prompt_best_of.append(sampling_params.best_of) - query_len = seq_group.query_len - assert query_len is not None - - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - extra_entropy = extra_entropy or () - seq_seeds = cls._get_sequence_seeds( - seed, - seq_data.get_len(), - *extra_entropy, - seq_id, - seeds_to_generate=seeds_to_generate, - is_greedy=is_greedy) - sampling_seeds.append(seq_seeds) - sample_indices.extend(seq_group.sample_indices) + # The following code is for a Triton-based sampler + # that is not enabled. + + # if is_prompt: + # prompt_best_of.append(sampling_params.best_of) + # query_len = seq_group.query_len + # assert query_len is not None + + # for seq_id in seq_ids: + # seq_data = seq_group.seq_data[seq_id] + # extra_entropy = extra_entropy or () + # seq_seeds = cls._get_sequence_seeds( + # seed, + # seq_data.get_len(), + # *extra_entropy, + # seq_id, + # seeds_to_generate=seeds_to_generate, + # is_greedy=is_greedy) + # sampling_seeds.append(seq_seeds) + # sample_indices.extend(seq_group.sample_indices) if do_penalties: for seq_group in sampling_metadata.seq_groups: @@ -549,7 +546,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], device="cpu", dtype=torch.long, pin_memory=pin_memory, - ).T.contiguous() + ).t().contiguous() # Because the memory is pinned, we can do non-blocking # transfer to device. From 449ce72c610e992d39da91134455a815663ee303 Mon Sep 17 00:00:00 2001 From: Peng Guanwen Date: Fri, 2 Aug 2024 14:43:44 +0800 Subject: [PATCH 2/2] use flag guards instead of commenting out --- vllm/model_executor/sampling_metadata.py | 53 ++++++++++++++---------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 9d0331781c7a8..015e85b4ca81d 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -7,11 +7,14 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 +# Some triton sampler related code is guarded before it is ready. +_USE_TRITON_SAMPLER = False @dataclass @@ -350,9 +353,12 @@ def from_sampling_metadata( do_top_p_top_k = False do_min_p = False - # # We need one base seed per Triton slice. - # seeds_to_generate = (extra_seeds_to_generate + - # get_num_triton_sampler_splits(vocab_size)) + if _USE_TRITON_SAMPLER: + prompt_best_of: List[int] = [] + + # We need one base seed per Triton slice. + seeds_to_generate = (extra_seeds_to_generate + + get_num_triton_sampler_splits(vocab_size)) assert sampling_metadata.seq_groups is not None for seq_group in sampling_metadata.seq_groups: @@ -409,26 +415,27 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) - # The following code is for a Triton-based sampler - # that is not enabled. - - # if is_prompt: - # prompt_best_of.append(sampling_params.best_of) - # query_len = seq_group.query_len - # assert query_len is not None - - # for seq_id in seq_ids: - # seq_data = seq_group.seq_data[seq_id] - # extra_entropy = extra_entropy or () - # seq_seeds = cls._get_sequence_seeds( - # seed, - # seq_data.get_len(), - # *extra_entropy, - # seq_id, - # seeds_to_generate=seeds_to_generate, - # is_greedy=is_greedy) - # sampling_seeds.append(seq_seeds) - # sample_indices.extend(seq_group.sample_indices) + if _USE_TRITON_SAMPLER: + if is_prompt: + prompt_best_of.append(sampling_params.best_of) + query_len = seq_group.query_len + assert query_len is not None + + seed = sampling_params.seed + is_greedy = sampling_params.sampling_type == SamplingType.GREEDY + + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + extra_entropy = extra_entropy or () + seq_seeds = cls._get_sequence_seeds( + seed, + seq_data.get_len(), + *extra_entropy, + seq_id, + seeds_to_generate=seeds_to_generate, + is_greedy=is_greedy) + sampling_seeds.append(seq_seeds) + sample_indices.extend(seq_group.sample_indices) if do_penalties: for seq_group in sampling_metadata.seq_groups: