From f84ff177751a36ea03cbcb631cf6e131fc13b606 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Thu, 1 Aug 2024 09:27:28 +0800 Subject: [PATCH 01/18] add cfg worker --- pyproject.toml | 3 +- requirements-build.txt | 2 +- requirements-cuda.txt | 8 +- .../cfg_model_runner.py | 157 +++++++++++++++++ vllm/classifier_free_guidance/cfg_worker.py | 159 ++++++++++++++++++ .../separated_worker.py | 116 +++++++++++++ vllm/config.py | 39 +++++ vllm/core/scheduler.py | 12 ++ vllm/engine/arg_utils.py | 19 ++- vllm/engine/llm_engine.py | 25 ++- vllm/executor/executor_base.py | 4 +- vllm/executor/gpu_executor.py | 128 +++++++++++++- vllm/inputs/data.py | 19 ++- vllm/model_executor/models/llama.py | 27 ++- vllm/model_executor/models/opt.py | 24 ++- vllm/sampling_params.py | 5 +- vllm/sequence.py | 38 ++++- vllm/spec_decode/spec_decode_worker.py | 9 + vllm/worker/model_runner.py | 11 ++ vllm/worker/model_runner_base.py | 1 + vllm/worker/worker.py | 3 +- vllm/worker/worker_base.py | 11 +- 22 files changed, 791 insertions(+), 29 deletions(-) create mode 100644 vllm/classifier_free_guidance/cfg_model_runner.py create mode 100644 vllm/classifier_free_guidance/cfg_worker.py create mode 100644 vllm/classifier_free_guidance/separated_worker.py diff --git a/pyproject.toml b/pyproject.toml index 1ba1eacd90084..4e7c78b1fa468 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,8 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.3.1", + # "torch == 2.3.1", + "torch", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-build.txt b/requirements-build.txt index b05f38a0ed919..1ade34b94a7b4 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -3,5 +3,5 @@ cmake>=3.21 ninja packaging setuptools>=49.4.0 -torch==2.3.1 +# torch==2.3.1 wheel diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 3eb91212e976e..7b99e0002696c 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -4,8 +4,8 @@ # Dependencies for NVIDIA GPUs ray >= 2.9 nvidia-ml-py # for pynvml package -torch == 2.3.1 +# torch == 2.3.1 # These must be updated alongside torch -torchvision == 0.18.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -xformers == 0.0.27 # Requires PyTorch 2.3.1 -vllm-flash-attn == 2.5.9.post1 # Requires PyTorch 2.3.1 +# torchvision == 0.18.1 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +# xformers == 0.0.27 # Requires PyTorch 2.3.1 +# vllm-flash-attn == 2.5.9.post1 # Requires PyTorch 2.3.1 diff --git a/vllm/classifier_free_guidance/cfg_model_runner.py b/vllm/classifier_free_guidance/cfg_model_runner.py new file mode 100644 index 0000000000000..b015c7bc3998d --- /dev/null +++ b/vllm/classifier_free_guidance/cfg_model_runner.py @@ -0,0 +1,157 @@ +from typing import List, Optional, Union + +import torch + +from vllm.distributed import get_pp_group +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import (ModelRunner, ModelInputForGPUWithSamplingMetadata, + FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper) + + +class CFGModelRunner(ModelRunner): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @torch.inference_mode() + def model_execute( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> torch.Tensor: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in ModelRunner") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + if self.attn_backend.get_name() == "flashinfer": + assert model_input.attn_metadata is not None + assert model_input.input_tokens is not None + if self.flashinfer_decode_workspace_buffer is None: + self.flashinfer_decode_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_decode_wrapper = \ + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_decode_workspace_buffer, "NHD") + self.flashinfer_prefill_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_prefill_wrapper = \ + BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_prefill_workspace_buffer, "NHD") + + model_input.attn_metadata.prefill_wrapper = \ + self.flashinfer_prefill_wrapper + if model_input.attn_metadata.use_cuda_graph: + batch_size = model_input.input_tokens.shape[0] + model_input.attn_metadata.decode_wrapper = self.graph_runners[ + model_input. + virtual_engine][batch_size].flashinfer_decode_wrapper + else: + model_input.attn_metadata.decode_wrapper = \ + self.flashinfer_decode_wrapper + model_input.attn_metadata.begin_forward() + + # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + # TODO(andoorve): We can remove this once all + # virtual engines share the same kv cache. 
+ virtual_engine = model_input.virtual_engine + if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] + else: + model_executable = self.model + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_seqlen_agnostic else {} + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **multi_modal_kwargs, + **seqlen_agnostic_kwargs) + + return hidden_or_intermediate_states + + @torch.inference_mode() + def get_logits( + self, + hidden_or_intermediate_states: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, + ) -> torch.Tensor: + return self.model._get_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + @torch.inference_mode() + def compute_logits( + self, + logits: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, + ) -> torch.Tensor: + return self.model.compute_logits(logits, + model_input.sampling_metadata) + + @torch.inference_mode() + def do_sample( + self, + logits: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, + ): + if not self.is_driver_worker: + return [] + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + if self.return_hidden_states: + raise NotImplementedError("return_hidden_states is not supported in CFGModelRunner") + + return [output] + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + + hidden_or_intermediate_states = self.model_execute(model_input, kv_caches, intermediate_tensors, num_steps) + + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + + hidden_or_intermediate_states = self.get_logits(hidden_or_intermediate_states, model_input) + logits = self.compute_logits(hidden_or_intermediate_states, model_input) + + return self.do_sample(logits, model_input) diff --git a/vllm/classifier_free_guidance/cfg_worker.py b/vllm/classifier_free_guidance/cfg_worker.py new file mode 100644 index 0000000000000..f3220e5d1c0d4 --- /dev/null +++ b/vllm/classifier_free_guidance/cfg_worker.py @@ -0,0 +1,159 @@ +import copy + +from typing import Any, Dict, List, Optional, Set, Tuple + +import torch + +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.config import ParallelConfig, ClassifierFreeGuidanceConfig +from vllm.logger import init_logger +from vllm.worker.worker import Worker +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase +from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceGroupMetadata, SequenceData + +from vllm.classifier_free_guidance.cfg_model_runner import CFGModelRunner +from vllm.classifier_free_guidance.separated_worker import SeparatedWorker + +logger = init_logger(__name__) + + +def create_cfg_worker(*args, **kwargs) -> "CFGWorker": + + 
assert "classifier_free_guidance_config" in kwargs + classifier_free_guidance_config: ClassifierFreeGuidanceConfig = kwargs.get("classifier_free_guidance_config") + assert classifier_free_guidance_config is not None + kwargs.pop("classifier_free_guidance_config") + + kwargs["model_runner_cls"] = CFGModelRunner + root_worker = SeparatedWorker(*args, **kwargs) + + print("create_cfg_worker") + print("args", args) + print("kwargs", kwargs) + + guidance_model_config = classifier_free_guidance_config.guidance_model_config + guidance_parallel_config = classifier_free_guidance_config.guidance_parallel_config + kwargs.update( + model_config=guidance_model_config, + parallel_config=guidance_parallel_config, + ) + guidance_worker = SeparatedWorker(*args, **kwargs) + + return CFGWorker( + root_worker=root_worker, + guidance_worker=guidance_worker, + ) + + +class CFGWorker(LoraNotSupportedWorkerBase): + def __init__( + self, + root_worker: WorkerBase, + guidance_worker: WorkerBase, + ): + self.root_worker = root_worker + self.guidance_worker = guidance_worker + + def init_device(self): + self.root_worker.init_device() + self.guidance_worker.init_device() + + def load_model(self): + self.root_worker.load_model() + # TODO(zhaoyinglia): guidance_worker shares weight with root_worker + self.guidance_worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_gpu_blocks, num_cpu_blocks = ( + self.root_worker.determine_num_available_blocks()) + + root_cache_block_size_bytes = ( + self.root_worker.get_cache_block_size_bytes()) + guidance_cache_block_size_bytes = ( + self.guidance_worker.get_cache_block_size_bytes()) + + new_num_gpu_blocks = int( + num_gpu_blocks * root_cache_block_size_bytes / + (guidance_cache_block_size_bytes + root_cache_block_size_bytes)) + return new_num_gpu_blocks, num_cpu_blocks + + def initialize_cache( + self, + num_gpu_blocks: int, + num_cpu_blocks: int + ): + print("num_gpu_blocks:", num_gpu_blocks) + self.root_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + self.guidance_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + @torch.inference_mode() + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + + # print("==>[zyl] execute_model_req:", execute_model_req) + # for seq_group_metadata in execute_model_req.seq_group_metadata_list: + # for seq_data in seq_group_metadata.seq_data.values(): + # seq_len = seq_data.get_len() + # print("[zyl] seq_len:", seq_len) + # print("[zyl] seq_data:", seq_data) + # print("[zyl] seq_data.prompt_token_ids:", seq_data.prompt_token_ids) + # print("[zyl] seq_data.negative_prompt_token_ids:", seq_data.negative_prompt_token_ids) + + # get root models's logits + scores = self.root_worker.execute_model_part(execute_model_req) + # prepare negative request with shallow copy + negative_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + negative_excute_model_req = execute_model_req.clone(negative_seq_group_metadata_list) + for seq_group_metadata in execute_model_req.seq_group_metadata_list: + negative_seq_group_metadata = copy.copy(seq_group_metadata) + negative_seq_data: Dict[int, SequenceData] = {} + for seq_id, seq_data in seq_group_metadata.seq_data.items(): + negative_seq_data[seq_id] = copy.copy(seq_data) + negative_seq_data[seq_id].prompt_token_ids = seq_data.negative_prompt_token_ids + negative_seq_data[seq_id].negative_prompt_token_ids = [] + 
negative_seq_data[seq_id].output_token_ids = seq_data.output_token_ids[:] + + negative_seq_group_metadata.seq_data = negative_seq_data + negative_seq_group_metadata_list.append(negative_seq_group_metadata) + negative_excute_model_req.seq_group_metadata_list = negative_seq_group_metadata_list + # print("==>[zyl] negative_excute_model_req:", negative_excute_model_req) + # for seq_group_metadata in negative_excute_model_req.seq_group_metadata_list: + # for seq_data in seq_group_metadata.seq_data.values(): + # seq_len = seq_data.get_len() + # print("[zyl] seq_data:", seq_data) + # print("[zyl] seq_len:", seq_len) + # print("[zyl] seq_data.prompt_token_ids:", seq_data.prompt_token_ids) + # print("[zyl] seq_data.negative_prompt_token_ids:", seq_data.negative_prompt_token_ids) + + # get unconditional logits + unconditional_logits = self.guidance_worker.execute_model_part(negative_excute_model_req) + # print("unconditional_logits:", unconditional_logits.shape, unconditional_logits) + + # do logist_processor + scores = self.root_worker.compute_logits(scores) + # print("scores:", scores.shape, scores) + + # do classifier free guidance logist process + for seq_group in self.root_worker.model_input.sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + guidance_scale = seq_group.sampling_params.guidance_scale + # print("guidance_scale:", guidance_scale) + if guidance_scale == 1.0: + break + for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices): + logits_row = torch.nn.functional.log_softmax(scores[logits_row_idx], dim=-1) + unconditional_logits_row = torch.nn.functional.log_softmax(unconditional_logits[logits_row_idx], dim=-1) + scores[logits_row_idx] = guidance_scale * (logits_row - unconditional_logits_row) + unconditional_logits_row + + # do sample + output = self.root_worker.do_sample(scores) + + # output is List[SamplerOutput] + return output + + def get_cache_block_size_bytes(self): + raise NotImplementedError diff --git a/vllm/classifier_free_guidance/separated_worker.py b/vllm/classifier_free_guidance/separated_worker.py new file mode 100644 index 0000000000000..5752cf493c582 --- /dev/null +++ b/vllm/classifier_free_guidance/separated_worker.py @@ -0,0 +1,116 @@ +from typing import List, Optional + +import torch + +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput) +from vllm.worker.worker import Worker +from vllm.worker.worker_base import WorkerInput +from vllm.worker.model_runner_base import ModelRunnerInputBase + + +class SeparatedWorker(Worker): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.model_input = None + + @torch.inference_mode() + def get_logits( + self, + hidden_or_intermediate_states: torch.Tensor, + ) -> torch.Tensor: + return self.model_runner.get_logits(hidden_or_intermediate_states, self.model_input) + + @torch.inference_mode() + def compute_logits( + self, + logits: torch.Tensor, + ) -> torch.Tensor: + return self.model_runner.compute_logits(logits, self.model_input) + + @torch.inference_mode() + def do_sample( + self, + logits: torch.Tensor, + ) -> List[SamplerOutput]: + return self.model_runner.do_sample(logits, self.model_input) + + @torch.inference_mode() + def execute_model_part( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[List[SamplerOutput]]: + + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no 
more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + self.model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + num_steps = execute_model_req.num_steps + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update( + self.model_input.as_broadcastable_tensor_dict()) + broadcast_data["num_steps"] = num_steps + broadcast_tensor_dict(broadcast_data, src=0) + else: + assert self.do_metadata_broadcast + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + num_steps = broadcast_data.pop("num_steps") + worker_input = WorkerInput.from_broadcasted_tensor_dict( + broadcast_data) + self.model_input = ( + self.model_runner. + make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict()) + + hidden_or_intermediate_states = self.model_runner.model_execute( + self.model_input, + self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors, + num_steps + ) + + logits = self.get_logits(hidden_or_intermediate_states) + # logits = self.compute_logits(logits, model_input) + # output = self.do_sample(logits) + + if not self.is_driver_worker: + return [] + + if not get_pp_group().is_last_rank: + # output is IntermediateTensors + get_pp_group().send_tensor_dict(logits.tensors) + return [None] + + return logits diff --git a/vllm/config.py b/vllm/config.py index 6403a53f86281..66960604b6f45 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1271,6 +1271,44 @@ def __repr__(self) -> str: return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" +class ClassifierFreeGuidanceConfig: + + @staticmethod + def maybe_create_spec_config( + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + guidance_model: Optional[str], + ): + if guidance_model is None: + return None + + guidance_parallel_config = target_parallel_config + assert target_model_config.model == guidance_model + guidance_model_config = target_model_config + + return ClassifierFreeGuidanceConfig( + guidance_model_config, + guidance_parallel_config + ) + + def __init__( + self, + guidance_model_config: ModelConfig, + guidance_parallel_config: ParallelConfig, + ): + self.guidance_model_config = guidance_model_config + self.guidance_parallel_config = guidance_parallel_config + + def _verify_args(self) -> None: + if self.guidance_model_config: + self.guidance_model_config.verify_with_parallel_config( + self.guidance_parallel_config) + + def __repr__(self) -> str: + guidance_model = self.guidance_model_config.model + return f"ClassifierFreeGuidanceConfig({guidance_model=})" + + @dataclass class LoRAConfig: max_lora_rank: int @@ -1602,6 +1640,7 @@ class EngineConfig: decoding_config: 
Optional[DecodingConfig] observability_config: Optional[ObservabilityConfig] prompt_adapter_config: Optional[PromptAdapterConfig] + classifier_free_guidance_config: Optional[ClassifierFreeGuidanceConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6e59c5e0f74f3..61461b1ad12a9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -769,6 +769,12 @@ def _schedule_default(self) -> SchedulerOutputs: decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. """ + # print("!!!!!!!!_schedule_default!!!!!!!!") + # print("self.waiting:", self.waiting) + # print("self.running:", self.running) + # print("self.swapped:", self.swapped) + # print("max_num_batched_tokens:", self.scheduler_config.max_num_batched_tokens) + # print("max_num_seqs:", self.scheduler_config.max_num_seqs) # Include running requests to the budget. budget = SchedulingBudget( token_budget=self.scheduler_config.max_num_batched_tokens, @@ -790,6 +796,7 @@ def _schedule_default(self) -> SchedulerOutputs: remaining_swapped, swapped_in = ( self.swapped, SchedulerSwappedInOutputs.create_empty()) + # print("not self.swapped:", (not self.swapped)) # If any requests are swapped, prioritized swapped requests. if not self.swapped: remaining_waiting, prefills = self._schedule_prefills( @@ -799,6 +806,7 @@ def _schedule_default(self) -> SchedulerOutputs: # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. + # print("len(prefills.seq_groups):", len(prefills.seq_groups)) if len(prefills.seq_groups) == 0: remaining_running, running_scheduled = self._schedule_running( self.running, @@ -818,6 +826,10 @@ def _schedule_default(self) -> SchedulerOutputs: self.scheduler_config.max_num_batched_tokens) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + # print("prefills:", prefills) + # print("running_scheduled:", running_scheduled) + # print("swapped_in:", swapped_in) + # Update waiting requests. 
self.waiting = remaining_waiting self.waiting.extendleft(running_scheduled.preempted) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cd64d3345b830..b466f2a57a7e7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,7 +8,7 @@ EngineConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig, TokenizerPoolConfig) + SpeculativeConfig, TokenizerPoolConfig, ClassifierFreeGuidanceConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -118,6 +118,9 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None + # classifier free guidance configuration + classifier_free_guidance_model: Optional[str] = None + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @@ -661,6 +664,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, help='Target URL to which OpenTelemetry traces will be sent.') + parser.add_argument( + '--classifier-free-guidance-model', + type=nullable_str, + default=EngineArgs.classifier_free_guidance_model, + help= + 'The name of the model to be used in classifier free guidance logistor.') + return parser @classmethod @@ -799,6 +809,12 @@ def create_engine_config(self, ) -> EngineConfig: disable_logprobs=self.disable_logprobs_during_spec_decoding, ) + classifier_free_guidance_config = ClassifierFreeGuidanceConfig.maybe_create_spec_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + guidance_model=self.classifier_free_guidance_model, + ) + scheduler_config = SchedulerConfig( max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, @@ -867,6 +883,7 @@ def create_engine_config(self, ) -> EngineConfig: decoding_config=decoding_config, observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, + classifier_free_guidance_config=classifier_free_guidance_config, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 48d5305892219..74fdcc9bbf461 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -12,7 +12,7 @@ EngineConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) + SpeculativeConfig, ClassifierFreeGuidanceConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -168,6 +168,7 @@ def __init__( decoding_config: Optional[DecodingConfig], observability_config: Optional[ObservabilityConfig], prompt_adapter_config: Optional[PromptAdapterConfig], + classifier_free_guidance_config: Optional[ClassifierFreeGuidanceConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -186,7 +187,7 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "enable_prefix_caching=%s)", + "enable_prefix_caching=%s, classifier_free_guidance_config=%r)", VLLM_VERSION, model_config.model, speculative_config, @@ -216,6 +217,7 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, cache_config.enable_prefix_caching, + classifier_free_guidance_config, ) # 
TODO(woosuk): Print more configs in debug mode. @@ -230,6 +232,7 @@ def __init__( self.load_config = load_config self.decoding_config = decoding_config or DecodingConfig() self.prompt_adapter_config = prompt_adapter_config + self.classifier_free_guidance_config = classifier_free_guidance_config self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats @@ -259,6 +262,7 @@ def __init__( speculative_config=speculative_config, load_config=load_config, prompt_adapter_config=prompt_adapter_config, + classifier_free_guidance_config=classifier_free_guidance_config ) if not self.model_config.embedding_mode: @@ -315,6 +319,8 @@ def __init__( for _ in range(parallel_config.pipeline_parallel_size) ] + print("scheduler_config:", scheduler_config) + # Metric Logging. if self.log_stats: if stat_loggers is not None: @@ -586,6 +592,17 @@ def process_model_inputs( else: prompt_token_ids = inputs["prompt_token_ids"] + if "negative_prompt" in inputs and "negative_prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("negative prompts must be None if " + "skip_tokenizer_init is True") + negative_prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["negative_prompt"], + lora_request=lora_request) + elif "negative_prompt_token_ids" in inputs: + negative_prompt_token_ids = inputs.get("negative_prompt_token_ids") + else: + negative_prompt_token_ids = None + if prompt_adapter_request: prompt_token_ids = \ [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ @@ -593,7 +610,9 @@ def process_model_inputs( llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + multi_modal_data=inputs.get("multi_modal_data"), + negative_prompt_token_ids=negative_prompt_token_ids, + negative_prompt=inputs.get("negative_prompt")) return self.input_processor(llm_inputs) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index a848bc70941c1..ba2f6f58d1d43 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -4,7 +4,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) + SpeculativeConfig, ClassifierFreeGuidanceConfig) from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -32,6 +32,7 @@ def __init__( multimodal_config: Optional[MultiModalConfig], speculative_config: Optional[SpeculativeConfig], prompt_adapter_config: Optional[PromptAdapterConfig], + classifier_free_guidance_config: Optional[ClassifierFreeGuidanceConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -43,6 +44,7 @@ def __init__( self.multimodal_config = multimodal_config self.speculative_config = speculative_config self.prompt_adapter_config = prompt_adapter_config + self.classifier_free_guidance_config = classifier_free_guidance_config self._init_executor() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3e77af0e20323..bb85c3bd3e5ed 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,10 +1,11 @@ +import copy from typing import Any, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request 
import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput, SequenceGroupMetadata, SequenceData from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -31,9 +32,21 @@ def _init_executor(self) -> None: assert self.parallel_config.world_size == 1, ( "GPUExecutor only supports single GPU.") + print("[zyl] gpu executor _create_worker:") self.driver_worker = self._create_worker() + print("[zyl] gpu executor init_device:") self.driver_worker.init_device() + print("[zyl] gpu executor load_model:") self.driver_worker.load_model() + print("[zyl] driver_worker:", self.driver_worker) + + # print("[zyl] gpu executor _create_worker:") + # self.cfg_worker = self._create_worker() + # print("[zyl] gpu executor init_device:") + # self.cfg_worker.init_device() + # print("[zyl] gpu executor load_model:") + # self.cfg_worker.load_model() + # print("[zyl] cfg_worker:", self.cfg_worker) def _get_worker_kwargs( self, @@ -58,6 +71,7 @@ def _get_worker_kwargs( multimodal_config=self.multimodal_config, speculative_config=self.speculative_config, prompt_adapter_config=self.prompt_adapter_config, + classifier_free_guidance_config=self.classifier_free_guidance_config, is_driver_worker=(not self.parallel_config) or (rank % self.parallel_config.tensor_parallel_size == 0), ) @@ -69,13 +83,17 @@ def _get_create_worker_kwargs( distributed_init_method: Optional[str] = None) -> Dict: worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method) - if self.speculative_config is None: + if self.speculative_config is None and self.classifier_free_guidance_config is None: worker_kwargs.update(worker_module_name="vllm.worker.worker", worker_class_name="Worker") - else: + elif self.speculative_config is not None: worker_kwargs.update( worker_module_name="vllm.spec_decode.spec_decode_worker", worker_class_name="create_spec_worker") + else: + worker_kwargs.update( + worker_module_name="vllm.classifier_free_guidance.cfg_worker", + worker_class_name="create_cfg_worker") return worker_kwargs def _create_worker(self, @@ -92,6 +110,16 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: underlying worker. """ return self.driver_worker.determine_num_available_blocks() + # num_gpu_blocks, num_cpu_blocks = self.driver_worker.determine_num_available_blocks() + + # driver_cache_block_size_bytes = self.driver_worker.get_cache_block_size_bytes() + # cfg_cache_block_size_bytes = self.cfg_worker.get_cache_block_size_bytes() + + # new_num_gpu_blocks = int( + # num_gpu_blocks * driver_cache_block_size_bytes / + # (cfg_cache_block_size_bytes + driver_cache_block_size_bytes)) + + # return new_num_gpu_blocks, num_cpu_blocks def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: """Initialize the KV cache by invoking the underlying worker. 
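The commented-out block above mirrors the GPU-block split that CFGWorker.determine_num_available_blocks performs in vllm/classifier_free_guidance/cfg_worker.py: the root and guidance workers share a single KV-cache memory budget, so the block count profiled for the root worker is rescaled by the ratio of its per-block cache size to the combined per-block size of both caches. A minimal standalone sketch of that heuristic, with illustrative names that are not part of the patch:

def split_gpu_blocks(num_gpu_blocks: int,
                     root_block_bytes: int,
                     guidance_block_bytes: int) -> int:
    """Rescale a profiled block count so two KV caches fit in one budget."""
    # Memory the root worker alone could have used for its KV cache.
    total_bytes = num_gpu_blocks * root_block_bytes
    # Allocating one logical block now costs one block in *each* worker's cache.
    per_block_bytes = root_block_bytes + guidance_block_bytes
    return int(total_bytes / per_block_bytes)

# Example: 1000 profiled blocks and identical per-block sizes for both models
# (the guidance model is currently required to be the same as the target
# model), so each worker ends up with 500 blocks.
assert split_gpu_blocks(1000, 16384, 16384) == 500
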
@@ -103,13 +131,107 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + # self.cfg_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: + # print("[zyl] gpu_executor execute_model_req:", execute_model_req) + # for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list): + # seq_data = next(iter(seq_group_metadata.seq_data.values())) + # seq_len = seq_data.get_len() + # print("[zyl] driver seq_data:", seq_data) + # print("[zyl] driver seq_len:", seq_len) + # print("[zyl] gpu executor driver_worker.execute_model:", self.driver_worker.execute_model) output = self.driver_worker.execute_model(execute_model_req) + # print("[zyl] gpu_executor output:", output) + # print("[zyl] gpu_executor sampled_token_ids:", output[0].sampled_token_ids) + # exit(0) return output + # def execute_model( + # self, execute_model_req: ExecuteModelRequest + # ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: + # # print("[zyl] gpu_executor execute_model_req:", execute_model_req) + # # for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list): + # # seq_data = next(iter(seq_group_metadata.seq_data.values())) + # # seq_len = seq_data.get_len() + # # print("[zyl] driver seq_data:", seq_data) + # # print("[zyl] driver seq_len:", seq_len) + # # print("[zyl] gpu_executor self.driver_worker:", self.driver_worker) + # logits_driver, model_input_driver = self.driver_worker.execute_model(execute_model_req, do_no_processor=False) + # # print("[zyl] gpu_executor logits_driver:", logits_driver) + + # # cfg_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + # # cfg_execute_model_req = execute_model_req.clone(cfg_seq_group_metadata_list) + # # for seq_group_metadata in execute_model_req.seq_group_metadata_list: + # # new_seq_group_metadata = copy.copy(seq_group_metadata) + # # new_seq_data: Dict[int, SequenceData] = {} + # # for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + # # if len(old_seq_data.output_token_ids) == 0: + # # new_seq_data[seq_id] = copy.copy(old_seq_data) + # # new_seq_data[seq_id].prompt_token_ids = old_seq_data.prompt_token_ids[-1:] + # # new_seq_data[seq_id].output_token_ids = () + # # else: + # # new_seq_data[seq_id] = copy.copy(old_seq_data) + # # new_seq_data[seq_id].prompt_token_ids = old_seq_data.prompt_token_ids[-1:] + # # new_seq_data[seq_id].output_token_ids = old_seq_data.output_token_ids[:] + # # new_seq_group_metadata.seq_data = new_seq_data + # # cfg_seq_group_metadata_list.append(new_seq_group_metadata) + # # cfg_execute_model_req.seq_group_metadata_list = cfg_seq_group_metadata_list + + + # # cfg_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + # # cfg_execute_model_req = execute_model_req.clone(cfg_seq_group_metadata_list) + # # for seq_group_metadata in execute_model_req.seq_group_metadata_list: + # # new_seq_group_metadata = copy.copy(seq_group_metadata) + # # new_seq_data: Dict[int, SequenceData] = {} + # # for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + # # # if seq_group_metadata.is_prompt: + # # new_seq_data[seq_id] = copy.copy(old_seq_data) + # # new_seq_data[seq_id].prompt_token_ids = old_seq_data.negative_prompt_token_ids + # # new_seq_data[seq_id].negative_prompt_token_ids = [] + # # new_seq_data[seq_id].output_token_ids = 
old_seq_data.output_token_ids[:] + + # # new_seq_group_metadata.seq_data = new_seq_data + # # cfg_seq_group_metadata_list.append(new_seq_group_metadata) + # # cfg_execute_model_req.seq_group_metadata_list = cfg_seq_group_metadata_list + + + # cfg_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + # cfg_execute_model_req = execute_model_req.clone(cfg_seq_group_metadata_list) + # for seq_group_metadata in execute_model_req.seq_group_metadata_list: + # new_seq_group_metadata = copy.deepcopy(seq_group_metadata) + # new_seq_data: Dict[int, SequenceData] = {} + # for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + # # if seq_group_metadata.is_prompt: + # new_seq_data[seq_id] = copy.deepcopy(old_seq_data) + # new_seq_data[seq_id].prompt_token_ids = old_seq_data.negative_prompt_token_ids + # new_seq_data[seq_id].negative_prompt_token_ids = [] + # new_seq_data[seq_id].output_token_ids = old_seq_data.output_token_ids[:] + + # new_seq_group_metadata.seq_data = new_seq_data + # cfg_seq_group_metadata_list.append(new_seq_group_metadata) + # cfg_execute_model_req.seq_group_metadata_list = cfg_seq_group_metadata_list + + + # # print("[zyl] gpu_executor cfg_execute_model_req:", cfg_execute_model_req) + # # for i, seq_group_metadata in enumerate(cfg_execute_model_req.seq_group_metadata_list): + # # seq_data = next(iter(seq_group_metadata.seq_data.values())) + # # seq_len = seq_data.get_len() + # # print("[zyl] cfg seq_data:", seq_data) + # # print("[zyl] cfg seq_len:", seq_len) + # logits_cfg, _ = self.cfg_worker.execute_model(cfg_execute_model_req, do_no_processor=True) + # # print("[zyl] gpu_executor logits_cfg:", logits_cfg) + + # logits = logits_cfg + 5.0 * (logits_driver - logits_cfg) + # # print("[zyl] gpu_executor logits:", logits) + # output: SamplerOutput = self.driver_worker.model_runner.model.sample(logits=logits, sampling_metadata=model_input_driver.sampling_metadata) + # # print("[zyl] gpu_executor output:", output) + # # print("[zyl] gpu_executor sampled_token_ids:", output.sampled_token_ids) + # # exit(0) + # return [output] + def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." return self.driver_worker.add_lora(lora_request) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 4443e6c70fe5b..651e16db97d81 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -91,8 +91,21 @@ class TokensPrompt(TypedDict): if the model supports it. """ +class NegativeTextPrompt(TypedDict): + """Schema for a text prompt.""" + + negative_prompt: str + """The input text to be tokenized before passing to the model.""" + + +class NegativeTokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" -PromptInputs = Union[str, TextPrompt, TokensPrompt] + negative_prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + +PromptInputs = Union[str, TextPrompt, TokensPrompt, NegativeTextPrompt, NegativeTokensPrompt] """ The inputs to the LLM, which can take one of the following forms: @@ -119,3 +132,7 @@ class LLMInputs(TypedDict): Optional multi-modal data to pass to the model, if the model supports it. 
""" + + negative_prompt_token_ids: NotRequired[Optional[List[int]]] + + negative_prompt: NotRequired[Optional[str]] diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 306d22e42ed1d..cd8a9066b110d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,13 +30,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, tensor_model_parallel_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.logits_processor import LogitsProcessor, _prune_hidden_states from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( @@ -136,7 +136,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, - bias=bias, + bias=True, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) @@ -406,11 +406,30 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, - logit_scale) + logit_scale, + logits_as_input=True) self.sampler = Sampler() + self.org_vocab_size = config.vocab_size else: self.lm_head = PPMissingLayer() + def _get_logits(self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + # Get the logits for the next tokens. + logits = self.lm_head.linear_method.apply( + self.lm_head, + hidden_states, + bias=None, + ) + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). 
+ if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index edc16710c0229..1e946dba2c6e3 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -25,13 +25,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_gather from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.logits_processor import LogitsProcessor, _prune_hidden_states from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler @@ -295,8 +295,26 @@ def __init__( self.quant_config = quant_config self.model = OPTModel(config, cache_config, quant_config) self.lm_head = self.model.decoder.embed_tokens - self.logits_processor = LogitsProcessor(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size, logits_as_input=True) self.sampler = Sampler() + self.org_vocab_size = config.vocab_size + + def _get_logits(self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + # Get the logits for the next tokens. + logits = self.lm_head.linear_method.apply( + self.lm_head, + hidden_states, + bias=None, + ) + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits def forward( self, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 638c870c04371..eaedbff7dadb7 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -138,6 +138,7 @@ def __init__( spaces_between_special_tokens: bool = True, logits_processors: Optional[List[LogitsProcessor]] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + guidance_scale: Optional[float] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -179,6 +180,7 @@ def __init__( self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output self.truncate_prompt_tokens = truncate_prompt_tokens + self.guidance_scale = guidance_scale # Number of characters to hold back for stop string evaluation # until sequence is finished. 
if self.stop and not include_stop_str_in_output: @@ -359,4 +361,5 @@ def __repr__(self) -> str: f"skip_special_tokens={self.skip_special_tokens}, " "spaces_between_special_tokens=" f"{self.spaces_between_special_tokens}, " - f"truncate_prompt_tokens={self.truncate_prompt_tokens})") + f"truncate_prompt_tokens={self.truncate_prompt_tokens}), " + f"guidance_scale={self.guidance_scale})") diff --git a/vllm/sequence.py b/vllm/sequence.py index 0cd4c7e71d78d..e9635a07b55df 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -118,12 +118,17 @@ def __init__( self, prompt_token_ids: List[int], output_token_ids: Optional[List[int]] = None, + negative_prompt_token_ids: Optional[List[int]] = None, ) -> None: self._prompt_token_ids: List[int] = list(prompt_token_ids) self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) self._output_token_ids: List[int] = ( list(output_token_ids) if output_token_ids is not None else []) + self._negative_prompt_token_ids: List[int] = ( + list(negative_prompt_token_ids) if negative_prompt_token_ids is not None else []) + self._negative_prompt_token_ids_tuple: Tuple[int, ...] = tuple(negative_prompt_token_ids or []) + self.cumulative_logprob = 0.0 # The number of tokens that are computed (that run against the model). self._num_computed_tokens = 0 @@ -145,6 +150,15 @@ def prompt_token_ids(self, new_prompt_token_ids) -> None: self._prompt_token_ids_tuple = tuple(new_prompt_token_ids) self._update_cached_all_tokens() + @property + def negative_prompt_token_ids(self) -> Tuple[int, ...]: + return tuple(self._negative_prompt_token_ids) + + @negative_prompt_token_ids.setter + def negative_prompt_token_ids(self, new_negative_prompt_token_ids) -> None: + self._negative_prompt_token_ids = list(new_negative_prompt_token_ids) + self._negative_prompt_token_ids_tuple = tuple(new_negative_prompt_token_ids) + @property def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @@ -228,6 +242,7 @@ def stage(self) -> SequenceStage: def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " + f"negative_prompt_token_ids={self._negative_prompt_token_ids}, " f"output_token_ids={self._output_token_ids}, " f"cumulative_logprob={self.cumulative_logprob})") @@ -261,7 +276,9 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.data = SequenceData(self.prompt_token_ids) + self.data = SequenceData( + self.prompt_token_ids, + negative_prompt_token_ids=self.negative_prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -290,6 +307,14 @@ def prompt_token_ids(self) -> List[int]: def multi_modal_data(self) -> "MultiModalDataDict": return self.inputs.get("multi_modal_data") or {} + @property + def negative_prompt(self) -> Optional[str]: + return self.inputs.get("negative_prompt") + + @property + def negative_prompt_token_ids(self) -> List[int]: + return self.inputs.get("negative_prompt_token_ids") or [] + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -399,7 +424,8 @@ def is_prefill(self) -> bool: def __repr__(self) -> str: return (f"Sequence(seq_id={self.seq_id}, " f"status={self.status.name}, " - f"num_blocks={self.n_blocks}, ") + f"num_blocks={self.n_blocks}, " + f"data={self.data})") @dataclass @@ -478,6 +504,14 @@ def multi_modal_data(self) -> "MultiModalDataDict": # We use the multi-modal data of an arbitrary sequence. 
return self._first_seq.multi_modal_data + @property + def negative_prompt(self) -> Optional[str]: + return self._first_seq.negative_prompt + + @property + def negative_prompt_token_ids(self) -> List[int]: + return self._first_seq.negative_prompt_token_ids + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 98960b88f719f..46facea50e247 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -237,6 +237,8 @@ def __init__( # in the subsequent step. self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs + print("[zyl] proposer_worker:", self.proposer_worker) + print("[zyl] scorer_worker:", self.scorer_worker) def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -323,6 +325,7 @@ def execute_model( ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ + print("[zyl] spec_decode_worker.execute_model") if self.rank != self._driver_rank: self._run_non_driver_rank() return [] @@ -450,9 +453,12 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. """ + print("[zyl] spec_decode_worker._run_no_spec") if not skip_proposer: + print("[zyl] self.proposer_worker.execute_model", self.proposer_worker.execute_model) self.proposer_worker.execute_model(execute_model_req) + print("[zyl] self.scorer_worker.execute_model:", self.scorer_worker.execute_model) sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -516,6 +522,7 @@ def _run_speculative_decoding_step( Returns a list of SamplerOutput, each containing a single token per sequence. """ + print("[zyl] _run_speculative_decoding_step") assert num_lookahead_slots == execute_model_req.num_lookahead_slots # Pass last hidden states from target model to proposer @@ -523,6 +530,7 @@ def _run_speculative_decoding_step( self.previous_hidden_states = None # Generate proposals using draft worker. 
+ print("[zyl] self.proposer_worker.get_spec_proposals:", self.proposer_worker.get_spec_proposals) proposals = self.proposer_worker.get_spec_proposals( execute_model_req, self._seq_with_bonus_token_in_last_step) @@ -531,6 +539,7 @@ def _run_speculative_decoding_step( raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") + print("[zyl] self.scorer.score_proposals:", self.scorer.score_proposals) proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 86d26b4a84c36..cca6e7863cc12 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1283,7 +1283,9 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + # do_no_processor: bool = None, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + # print("[zyl] model_runner execute_model") if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1364,9 +1366,18 @@ def execute_model( if not get_pp_group().is_last_rank: return hidden_or_intermediate_states + # hidden_or_intermediate_states = self.model._get_logits( + # hidden_or_intermediate_states, model_input.sampling_metadata + # ) + # if do_no_processor is not None and do_no_processor is True: + # return hidden_or_intermediate_states + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) + # if do_no_processor is not None and do_no_processor is False: + # return logits + if not self.is_driver_worker: return [] diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 5fb97025af5c0..f46abd826dc34 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -171,6 +171,7 @@ def execute_model( kv_caches: Optional[List[torch.Tensor]], intermediate_tensors: Optional[IntermediateTensors], num_steps: int = 1, + # do_no_processor: bool = None, ) -> Optional[List[SamplerOutput]]: """ Execute the model on the given input. 
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f3c379d1aa34d..c64232da57fcf 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) + SpeculativeConfig, ClassifierFreeGuidanceConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -48,6 +48,7 @@ def __init__( multimodal_config: Optional[MultiModalConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None, + classifier_free_guidance_config: Optional[ClassifierFreeGuidanceConfig] = None, is_driver_worker: bool = False, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, ) -> None: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 03e3857e23c4b..fdca87330ef6f 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -69,7 +69,8 @@ def start_worker_execution_loop(self) -> None: @abstractmethod def execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, + # do_no_processor: bool = None, ) -> Optional[List[SamplerOutput]]: raise NotImplementedError @@ -215,7 +216,8 @@ def execute_worker(self, worker_input: WorkerInput) -> None: def execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, + # do_no_processor: bool = None, ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" @@ -272,13 +274,16 @@ def execute_model( output = self.model_runner.execute_model( model_input, self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, intermediate_tensors, - num_steps) + num_steps) # , do_no_processor if not get_pp_group().is_last_rank: # output is IntermediateTensors get_pp_group().send_tensor_dict(output.tensors) return [None] + # if do_no_processor is not None: + # return output, model_input + # output is List[SamplerOutput] return output From 546575da23ad06a3803b4e9f2cd92c6f5a463112 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Thu, 1 Aug 2024 11:02:29 +0800 Subject: [PATCH 02/18] rm comment --- pyproject.toml | 3 +- requirements-build.txt | 2 +- requirements-cuda.txt | 8 +- vllm/classifier_free_guidance/cfg_worker.py | 25 ---- vllm/core/scheduler.py | 12 -- vllm/engine/llm_engine.py | 2 - vllm/executor/gpu_executor.py | 119 +------------------- vllm/spec_decode/spec_decode_worker.py | 9 -- vllm/worker/model_runner.py | 11 -- vllm/worker/model_runner_base.py | 1 - vllm/worker/worker_base.py | 11 +- 11 files changed, 10 insertions(+), 193 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4e7c78b1fa468..1ba1eacd90084 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,8 +5,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - # "torch == 2.3.1", - "torch", + "torch == 2.3.1", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-build.txt b/requirements-build.txt index 1ade34b94a7b4..b05f38a0ed919 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -3,5 +3,5 @@ cmake>=3.21 ninja packaging setuptools>=49.4.0 -# torch==2.3.1 +torch==2.3.1 wheel diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 
7b99e0002696c..3eb91212e976e 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -4,8 +4,8 @@ # Dependencies for NVIDIA GPUs ray >= 2.9 nvidia-ml-py # for pynvml package -# torch == 2.3.1 +torch == 2.3.1 # These must be updated alongside torch -# torchvision == 0.18.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -# xformers == 0.0.27 # Requires PyTorch 2.3.1 -# vllm-flash-attn == 2.5.9.post1 # Requires PyTorch 2.3.1 +torchvision == 0.18.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +xformers == 0.0.27 # Requires PyTorch 2.3.1 +vllm-flash-attn == 2.5.9.post1 # Requires PyTorch 2.3.1 diff --git a/vllm/classifier_free_guidance/cfg_worker.py b/vllm/classifier_free_guidance/cfg_worker.py index f3220e5d1c0d4..a75407700823a 100644 --- a/vllm/classifier_free_guidance/cfg_worker.py +++ b/vllm/classifier_free_guidance/cfg_worker.py @@ -27,10 +27,6 @@ def create_cfg_worker(*args, **kwargs) -> "CFGWorker": kwargs["model_runner_cls"] = CFGModelRunner root_worker = SeparatedWorker(*args, **kwargs) - print("create_cfg_worker") - print("args", args) - print("kwargs", kwargs) - guidance_model_config = classifier_free_guidance_config.guidance_model_config guidance_parallel_config = classifier_free_guidance_config.guidance_parallel_config kwargs.update( @@ -82,7 +78,6 @@ def initialize_cache( num_gpu_blocks: int, num_cpu_blocks: int ): - print("num_gpu_blocks:", num_gpu_blocks) self.root_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) self.guidance_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, @@ -94,15 +89,6 @@ def execute_model( execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: - # print("==>[zyl] execute_model_req:", execute_model_req) - # for seq_group_metadata in execute_model_req.seq_group_metadata_list: - # for seq_data in seq_group_metadata.seq_data.values(): - # seq_len = seq_data.get_len() - # print("[zyl] seq_len:", seq_len) - # print("[zyl] seq_data:", seq_data) - # print("[zyl] seq_data.prompt_token_ids:", seq_data.prompt_token_ids) - # print("[zyl] seq_data.negative_prompt_token_ids:", seq_data.negative_prompt_token_ids) - # get root models's logits scores = self.root_worker.execute_model_part(execute_model_req) # prepare negative request with shallow copy @@ -120,28 +106,17 @@ def execute_model( negative_seq_group_metadata.seq_data = negative_seq_data negative_seq_group_metadata_list.append(negative_seq_group_metadata) negative_excute_model_req.seq_group_metadata_list = negative_seq_group_metadata_list - # print("==>[zyl] negative_excute_model_req:", negative_excute_model_req) - # for seq_group_metadata in negative_excute_model_req.seq_group_metadata_list: - # for seq_data in seq_group_metadata.seq_data.values(): - # seq_len = seq_data.get_len() - # print("[zyl] seq_data:", seq_data) - # print("[zyl] seq_len:", seq_len) - # print("[zyl] seq_data.prompt_token_ids:", seq_data.prompt_token_ids) - # print("[zyl] seq_data.negative_prompt_token_ids:", seq_data.negative_prompt_token_ids) # get unconditional logits unconditional_logits = self.guidance_worker.execute_model_part(negative_excute_model_req) - # print("unconditional_logits:", unconditional_logits.shape, unconditional_logits) # do logist_processor scores = self.root_worker.compute_logits(scores) - # print("scores:", scores.shape, scores) # do classifier free guidance logist process for 
seq_group in self.root_worker.model_input.sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids guidance_scale = seq_group.sampling_params.guidance_scale - # print("guidance_scale:", guidance_scale) if guidance_scale == 1.0: break for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 61461b1ad12a9..6e59c5e0f74f3 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -769,12 +769,6 @@ def _schedule_default(self) -> SchedulerOutputs: decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. """ - # print("!!!!!!!!_schedule_default!!!!!!!!") - # print("self.waiting:", self.waiting) - # print("self.running:", self.running) - # print("self.swapped:", self.swapped) - # print("max_num_batched_tokens:", self.scheduler_config.max_num_batched_tokens) - # print("max_num_seqs:", self.scheduler_config.max_num_seqs) # Include running requests to the budget. budget = SchedulingBudget( token_budget=self.scheduler_config.max_num_batched_tokens, @@ -796,7 +790,6 @@ def _schedule_default(self) -> SchedulerOutputs: remaining_swapped, swapped_in = ( self.swapped, SchedulerSwappedInOutputs.create_empty()) - # print("not self.swapped:", (not self.swapped)) # If any requests are swapped, prioritized swapped requests. if not self.swapped: remaining_waiting, prefills = self._schedule_prefills( @@ -806,7 +799,6 @@ def _schedule_default(self) -> SchedulerOutputs: # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. - # print("len(prefills.seq_groups):", len(prefills.seq_groups)) if len(prefills.seq_groups) == 0: remaining_running, running_scheduled = self._schedule_running( self.running, @@ -826,10 +818,6 @@ def _schedule_default(self) -> SchedulerOutputs: self.scheduler_config.max_num_batched_tokens) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - # print("prefills:", prefills) - # print("running_scheduled:", running_scheduled) - # print("swapped_in:", swapped_in) - # Update waiting requests. self.waiting = remaining_waiting self.waiting.extendleft(running_scheduled.preempted) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 74fdcc9bbf461..11d21287939e2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -319,8 +319,6 @@ def __init__( for _ in range(parallel_config.pipeline_parallel_size) ] - print("scheduler_config:", scheduler_config) - # Metric Logging. 
if self.log_stats: if stat_loggers is not None: diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index bb85c3bd3e5ed..44b9d35f76861 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,11 +1,10 @@ -import copy from typing import Any, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput, SequenceGroupMetadata, SequenceData +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -32,21 +31,9 @@ def _init_executor(self) -> None: assert self.parallel_config.world_size == 1, ( "GPUExecutor only supports single GPU.") - print("[zyl] gpu executor _create_worker:") self.driver_worker = self._create_worker() - print("[zyl] gpu executor init_device:") self.driver_worker.init_device() - print("[zyl] gpu executor load_model:") self.driver_worker.load_model() - print("[zyl] driver_worker:", self.driver_worker) - - # print("[zyl] gpu executor _create_worker:") - # self.cfg_worker = self._create_worker() - # print("[zyl] gpu executor init_device:") - # self.cfg_worker.init_device() - # print("[zyl] gpu executor load_model:") - # self.cfg_worker.load_model() - # print("[zyl] cfg_worker:", self.cfg_worker) def _get_worker_kwargs( self, @@ -110,16 +97,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: underlying worker. """ return self.driver_worker.determine_num_available_blocks() - # num_gpu_blocks, num_cpu_blocks = self.driver_worker.determine_num_available_blocks() - - # driver_cache_block_size_bytes = self.driver_worker.get_cache_block_size_bytes() - # cfg_cache_block_size_bytes = self.cfg_worker.get_cache_block_size_bytes() - - # new_num_gpu_blocks = int( - # num_gpu_blocks * driver_cache_block_size_bytes / - # (cfg_cache_block_size_bytes + driver_cache_block_size_bytes)) - - # return new_num_gpu_blocks, num_cpu_blocks def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: """Initialize the KV cache by invoking the underlying worker. 
@@ -131,107 +108,13 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - # self.cfg_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: - # print("[zyl] gpu_executor execute_model_req:", execute_model_req) - # for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list): - # seq_data = next(iter(seq_group_metadata.seq_data.values())) - # seq_len = seq_data.get_len() - # print("[zyl] driver seq_data:", seq_data) - # print("[zyl] driver seq_len:", seq_len) - # print("[zyl] gpu executor driver_worker.execute_model:", self.driver_worker.execute_model) output = self.driver_worker.execute_model(execute_model_req) - # print("[zyl] gpu_executor output:", output) - # print("[zyl] gpu_executor sampled_token_ids:", output[0].sampled_token_ids) - # exit(0) return output - # def execute_model( - # self, execute_model_req: ExecuteModelRequest - # ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: - # # print("[zyl] gpu_executor execute_model_req:", execute_model_req) - # # for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list): - # # seq_data = next(iter(seq_group_metadata.seq_data.values())) - # # seq_len = seq_data.get_len() - # # print("[zyl] driver seq_data:", seq_data) - # # print("[zyl] driver seq_len:", seq_len) - # # print("[zyl] gpu_executor self.driver_worker:", self.driver_worker) - # logits_driver, model_input_driver = self.driver_worker.execute_model(execute_model_req, do_no_processor=False) - # # print("[zyl] gpu_executor logits_driver:", logits_driver) - - # # cfg_seq_group_metadata_list: List[SequenceGroupMetadata] = [] - # # cfg_execute_model_req = execute_model_req.clone(cfg_seq_group_metadata_list) - # # for seq_group_metadata in execute_model_req.seq_group_metadata_list: - # # new_seq_group_metadata = copy.copy(seq_group_metadata) - # # new_seq_data: Dict[int, SequenceData] = {} - # # for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): - # # if len(old_seq_data.output_token_ids) == 0: - # # new_seq_data[seq_id] = copy.copy(old_seq_data) - # # new_seq_data[seq_id].prompt_token_ids = old_seq_data.prompt_token_ids[-1:] - # # new_seq_data[seq_id].output_token_ids = () - # # else: - # # new_seq_data[seq_id] = copy.copy(old_seq_data) - # # new_seq_data[seq_id].prompt_token_ids = old_seq_data.prompt_token_ids[-1:] - # # new_seq_data[seq_id].output_token_ids = old_seq_data.output_token_ids[:] - # # new_seq_group_metadata.seq_data = new_seq_data - # # cfg_seq_group_metadata_list.append(new_seq_group_metadata) - # # cfg_execute_model_req.seq_group_metadata_list = cfg_seq_group_metadata_list - - - # # cfg_seq_group_metadata_list: List[SequenceGroupMetadata] = [] - # # cfg_execute_model_req = execute_model_req.clone(cfg_seq_group_metadata_list) - # # for seq_group_metadata in execute_model_req.seq_group_metadata_list: - # # new_seq_group_metadata = copy.copy(seq_group_metadata) - # # new_seq_data: Dict[int, SequenceData] = {} - # # for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): - # # # if seq_group_metadata.is_prompt: - # # new_seq_data[seq_id] = copy.copy(old_seq_data) - # # new_seq_data[seq_id].prompt_token_ids = old_seq_data.negative_prompt_token_ids - # # new_seq_data[seq_id].negative_prompt_token_ids = [] - # # new_seq_data[seq_id].output_token_ids = 
old_seq_data.output_token_ids[:] - - # # new_seq_group_metadata.seq_data = new_seq_data - # # cfg_seq_group_metadata_list.append(new_seq_group_metadata) - # # cfg_execute_model_req.seq_group_metadata_list = cfg_seq_group_metadata_list - - - # cfg_seq_group_metadata_list: List[SequenceGroupMetadata] = [] - # cfg_execute_model_req = execute_model_req.clone(cfg_seq_group_metadata_list) - # for seq_group_metadata in execute_model_req.seq_group_metadata_list: - # new_seq_group_metadata = copy.deepcopy(seq_group_metadata) - # new_seq_data: Dict[int, SequenceData] = {} - # for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): - # # if seq_group_metadata.is_prompt: - # new_seq_data[seq_id] = copy.deepcopy(old_seq_data) - # new_seq_data[seq_id].prompt_token_ids = old_seq_data.negative_prompt_token_ids - # new_seq_data[seq_id].negative_prompt_token_ids = [] - # new_seq_data[seq_id].output_token_ids = old_seq_data.output_token_ids[:] - - # new_seq_group_metadata.seq_data = new_seq_data - # cfg_seq_group_metadata_list.append(new_seq_group_metadata) - # cfg_execute_model_req.seq_group_metadata_list = cfg_seq_group_metadata_list - - - # # print("[zyl] gpu_executor cfg_execute_model_req:", cfg_execute_model_req) - # # for i, seq_group_metadata in enumerate(cfg_execute_model_req.seq_group_metadata_list): - # # seq_data = next(iter(seq_group_metadata.seq_data.values())) - # # seq_len = seq_data.get_len() - # # print("[zyl] cfg seq_data:", seq_data) - # # print("[zyl] cfg seq_len:", seq_len) - # logits_cfg, _ = self.cfg_worker.execute_model(cfg_execute_model_req, do_no_processor=True) - # # print("[zyl] gpu_executor logits_cfg:", logits_cfg) - - # logits = logits_cfg + 5.0 * (logits_driver - logits_cfg) - # # print("[zyl] gpu_executor logits:", logits) - # output: SamplerOutput = self.driver_worker.model_runner.model.sample(logits=logits, sampling_metadata=model_input_driver.sampling_metadata) - # # print("[zyl] gpu_executor output:", output) - # # print("[zyl] gpu_executor sampled_token_ids:", output.sampled_token_ids) - # # exit(0) - # return [output] - def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." return self.driver_worker.add_lora(lora_request) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 46facea50e247..98960b88f719f 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -237,8 +237,6 @@ def __init__( # in the subsequent step. self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs - print("[zyl] proposer_worker:", self.proposer_worker) - print("[zyl] scorer_worker:", self.scorer_worker) def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -325,7 +323,6 @@ def execute_model( ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ - print("[zyl] spec_decode_worker.execute_model") if self.rank != self._driver_rank: self._run_non_driver_rank() return [] @@ -453,12 +450,9 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. 
""" - print("[zyl] spec_decode_worker._run_no_spec") if not skip_proposer: - print("[zyl] self.proposer_worker.execute_model", self.proposer_worker.execute_model) self.proposer_worker.execute_model(execute_model_req) - print("[zyl] self.scorer_worker.execute_model:", self.scorer_worker.execute_model) sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -522,7 +516,6 @@ def _run_speculative_decoding_step( Returns a list of SamplerOutput, each containing a single token per sequence. """ - print("[zyl] _run_speculative_decoding_step") assert num_lookahead_slots == execute_model_req.num_lookahead_slots # Pass last hidden states from target model to proposer @@ -530,7 +523,6 @@ def _run_speculative_decoding_step( self.previous_hidden_states = None # Generate proposals using draft worker. - print("[zyl] self.proposer_worker.get_spec_proposals:", self.proposer_worker.get_spec_proposals) proposals = self.proposer_worker.get_spec_proposals( execute_model_req, self._seq_with_bonus_token_in_last_step) @@ -539,7 +531,6 @@ def _run_speculative_decoding_step( raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") - print("[zyl] self.scorer.score_proposals:", self.scorer.score_proposals) proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cca6e7863cc12..86d26b4a84c36 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1283,9 +1283,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, - # do_no_processor: bool = None, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - # print("[zyl] model_runner execute_model") if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1366,18 +1364,9 @@ def execute_model( if not get_pp_group().is_last_rank: return hidden_or_intermediate_states - # hidden_or_intermediate_states = self.model._get_logits( - # hidden_or_intermediate_states, model_input.sampling_metadata - # ) - # if do_no_processor is not None and do_no_processor is True: - # return hidden_or_intermediate_states - logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) - # if do_no_processor is not None and do_no_processor is False: - # return logits - if not self.is_driver_worker: return [] diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index f46abd826dc34..5fb97025af5c0 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -171,7 +171,6 @@ def execute_model( kv_caches: Optional[List[torch.Tensor]], intermediate_tensors: Optional[IntermediateTensors], num_steps: int = 1, - # do_no_processor: bool = None, ) -> Optional[List[SamplerOutput]]: """ Execute the model on the given input. 
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index fdca87330ef6f..03e3857e23c4b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -69,8 +69,7 @@ def start_worker_execution_loop(self) -> None: @abstractmethod def execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None, - # do_no_processor: bool = None, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[List[SamplerOutput]]: raise NotImplementedError @@ -216,8 +215,7 @@ def execute_worker(self, worker_input: WorkerInput) -> None: def execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None, - # do_no_processor: bool = None, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" @@ -274,16 +272,13 @@ def execute_model( output = self.model_runner.execute_model( model_input, self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, intermediate_tensors, - num_steps) # , do_no_processor + num_steps) if not get_pp_group().is_last_rank: # output is IntermediateTensors get_pp_group().send_tensor_dict(output.tensors) return [None] - # if do_no_processor is not None: - # return output, model_input - # output is List[SamplerOutput] return output From f18bd83bcf3af4a409b354d5475af27f57d75de5 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Thu, 1 Aug 2024 11:04:10 +0800 Subject: [PATCH 03/18] revert llama --- vllm/model_executor/models/llama.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index cd8a9066b110d..306d22e42ed1d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,13 +30,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, tensor_model_parallel_gather) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor, _prune_hidden_states +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( @@ -136,7 +136,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, - bias=True, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) @@ -406,30 +406,11 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, - logit_scale, - logits_as_input=True) + logit_scale) self.sampler = Sampler() - self.org_vocab_size = config.vocab_size else: self.lm_head = PPMissingLayer() - def _get_logits(self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) - # Get the logits for the next tokens. 
- logits = self.lm_head.linear_method.apply( - self.lm_head, - hidden_states, - bias=None, - ) - logits = tensor_model_parallel_gather(logits) - # Remove paddings in vocab (if any). - if logits is not None: - logits = logits[:, :self.org_vocab_size] - return logits - def forward( self, input_ids: torch.Tensor, From c66f6c64917c4f53c4127159e6acf36819707bb9 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Thu, 1 Aug 2024 11:08:36 +0800 Subject: [PATCH 04/18] add test --- tests/test_cfg_worker.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/test_cfg_worker.py diff --git a/tests/test_cfg_worker.py b/tests/test_cfg_worker.py new file mode 100644 index 0000000000000..ae11ce413a76b --- /dev/null +++ b/tests/test_cfg_worker.py @@ -0,0 +1,33 @@ + +from typing import List + +from vllm import LLM, SamplingParams +from vllm.inputs import PromptInputs + +llm = LLM( + model="facebook/opt-6.7b", + tensor_parallel_size=1, + use_v2_block_manager=True, + classifier_free_guidance_model="facebook/opt-6.7b" +) + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# inputs: List[PromptInputs]=[{"prompt": prompt, "negative_prompt": prompt[-1]} for prompt in prompts] +tokenizer = llm.get_tokenizer() +prompt_token_ids = [tokenizer.encode(text=prompt) for prompt in prompts] +inputs: List[PromptInputs]=[{"prompt_token_ids": token_ids, "negative_prompt_token_ids": token_ids[-1:]} for token_ids in prompt_token_ids] + +sampling_params = SamplingParams(temperature=0.8, top_p=0.95, guidance_scale=5.0) +outputs = llm.generate(inputs, sampling_params) + +for i, output in enumerate(outputs): + # prompt = output.prompt + prompt = prompts[i] + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 565fa82a92199645d15c6ed8ffe46a2b054e3c46 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Mon, 5 Aug 2024 12:03:40 +0800 Subject: [PATCH 05/18] support tp&share weight --- vllm/classifier_free_guidance/cfg_worker.py | 85 ++++++++++++++----- .../separated_worker.py | 27 +----- vllm/worker/model_runner.py | 3 + vllm/worker/worker.py | 3 + 4 files changed, 74 insertions(+), 44 deletions(-) diff --git a/vllm/classifier_free_guidance/cfg_worker.py b/vllm/classifier_free_guidance/cfg_worker.py index a75407700823a..f1fc508ea4897 100644 --- a/vllm/classifier_free_guidance/cfg_worker.py +++ b/vllm/classifier_free_guidance/cfg_worker.py @@ -38,6 +38,8 @@ def create_cfg_worker(*args, **kwargs) -> "CFGWorker": return CFGWorker( root_worker=root_worker, guidance_worker=guidance_worker, + is_driver_worker=kwargs["is_driver_worker"], + parallel_config=kwargs["parallel_config"], ) @@ -46,9 +48,14 @@ def __init__( self, root_worker: WorkerBase, guidance_worker: WorkerBase, + is_driver_worker: bool, + parallel_config: ParallelConfig, ): self.root_worker = root_worker self.guidance_worker = guidance_worker + self.is_driver_worker = is_driver_worker + self.parallel_config = parallel_config + assert self.parallel_config.pipeline_parallel_size == 1 def init_device(self): self.root_worker.init_device() @@ -56,8 +63,7 @@ def init_device(self): def load_model(self): self.root_worker.load_model() - # TODO(zhaoyinglia): guidance_worker shares weight with root_worker - self.guidance_worker.load_model() + self.guidance_worker.share_model(self.root_worker) def determine_num_available_blocks(self) -> Tuple[int, int]: num_gpu_blocks, num_cpu_blocks = ( @@ -83,36 +89,60 
@@ def initialize_cache( self.guidance_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + @torch.inference_mode() def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: - # get root models's logits - scores = self.root_worker.execute_model_part(execute_model_req) # prepare negative request with shallow copy - negative_seq_group_metadata_list: List[SequenceGroupMetadata] = [] - negative_excute_model_req = execute_model_req.clone(negative_seq_group_metadata_list) - for seq_group_metadata in execute_model_req.seq_group_metadata_list: - negative_seq_group_metadata = copy.copy(seq_group_metadata) - negative_seq_data: Dict[int, SequenceData] = {} - for seq_id, seq_data in seq_group_metadata.seq_data.items(): - negative_seq_data[seq_id] = copy.copy(seq_data) - negative_seq_data[seq_id].prompt_token_ids = seq_data.negative_prompt_token_ids - negative_seq_data[seq_id].negative_prompt_token_ids = [] - negative_seq_data[seq_id].output_token_ids = seq_data.output_token_ids[:] - - negative_seq_group_metadata.seq_data = negative_seq_data - negative_seq_group_metadata_list.append(negative_seq_group_metadata) - negative_excute_model_req.seq_group_metadata_list = negative_seq_group_metadata_list + if execute_model_req is not None: + negative_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + negative_excute_model_req = execute_model_req.clone(negative_seq_group_metadata_list) + for seq_group_metadata in execute_model_req.seq_group_metadata_list: + negative_seq_group_metadata = copy.copy(seq_group_metadata) + negative_seq_data: Dict[int, SequenceData] = {} + for seq_id, seq_data in seq_group_metadata.seq_data.items(): + negative_seq_data[seq_id] = copy.copy(seq_data) + negative_seq_data[seq_id].prompt_token_ids = seq_data.negative_prompt_token_ids + negative_seq_data[seq_id].negative_prompt_token_ids = [] + negative_seq_data[seq_id].output_token_ids = seq_data.output_token_ids[:] + + negative_seq_group_metadata.seq_data = negative_seq_data + negative_seq_group_metadata_list.append(negative_seq_group_metadata) + negative_excute_model_req.seq_group_metadata_list = negative_seq_group_metadata_list + else: + negative_excute_model_req = None + + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. 
+ broadcast_tensor_dict({}, src=0) + return None + + if self.do_metadata_broadcast: + broadcast_data = {"flag": 1} + broadcast_tensor_dict(broadcast_data, src=0) + else: + assert self.do_metadata_broadcast + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + # get root models's logits + scores = self.root_worker.execute_model_part(execute_model_req) # get unconditional logits unconditional_logits = self.guidance_worker.execute_model_part(negative_excute_model_req) - # do logist_processor - scores = self.root_worker.compute_logits(scores) - # do classifier free guidance logist process for seq_group in self.root_worker.model_input.sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -124,9 +154,22 @@ def execute_model( unconditional_logits_row = torch.nn.functional.log_softmax(unconditional_logits[logits_row_idx], dim=-1) scores[logits_row_idx] = guidance_scale * (logits_row - unconditional_logits_row) + unconditional_logits_row + # print("scores:", scores.shape, scores) + # exit(0) + + # do logist_processor + scores = self.root_worker.compute_logits(scores) + if not self.is_driver_worker: + return [] + # do sample output = self.root_worker.do_sample(scores) + if not get_pp_group().is_last_rank: + # output is IntermediateTensors + get_pp_group().send_tensor_dict(output.tensors) + return [None] + # output is List[SamplerOutput] return output diff --git a/vllm/classifier_free_guidance/separated_worker.py b/vllm/classifier_free_guidance/separated_worker.py index 5752cf493c582..ff11cec3e7446 100644 --- a/vllm/classifier_free_guidance/separated_worker.py +++ b/vllm/classifier_free_guidance/separated_worker.py @@ -44,16 +44,6 @@ def execute_model_part( ) -> Optional[List[SamplerOutput]]: if self.is_driver_worker: - if execute_model_req is None: - if self.do_metadata_broadcast: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - worker_input: WorkerInput = self.prepare_worker_input( execute_model_req=execute_model_req) self.model_input: ModelRunnerInputBase = ( @@ -72,9 +62,6 @@ def execute_model_part( else: assert self.do_metadata_broadcast broadcast_data = broadcast_tensor_dict(src=0) - if not broadcast_data: - return None - num_steps = broadcast_data.pop("num_steps") worker_input = WorkerInput.from_broadcasted_tensor_dict( broadcast_data) @@ -101,16 +88,10 @@ def execute_model_part( num_steps ) - logits = self.get_logits(hidden_or_intermediate_states) - # logits = self.compute_logits(logits, model_input) - # output = self.do_sample(logits) - - if not self.is_driver_worker: - return [] - + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: - # output is IntermediateTensors - get_pp_group().send_tensor_dict(logits.tensors) - return [None] + return hidden_or_intermediate_states + + logits = self.get_logits(hidden_or_intermediate_states) return logits diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 86d26b4a84c36..c4bf8807146b7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -787,6 +787,9 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. 
" "This may lead to less accurate results!") + def share_model(self, model: nn.Module) -> None: + self.model = model + def save_sharded_state( self, path: str, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c64232da57fcf..bfaf6f91c1f21 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -139,6 +139,9 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() + def share_model(self, shared_worker) -> None: + self.model_runner.share_model(shared_worker.model_runner.model) + def save_sharded_state( self, path: str, From 6c57bcf1375b7ef2a565fc543adde9a895942d73 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Tue, 6 Aug 2024 17:51:03 +0800 Subject: [PATCH 06/18] fix for negative prompt longer than positive prompt --- vllm/classifier_free_guidance/cfg_worker.py | 3 +++ vllm/core/scheduler.py | 16 +++++++++++++--- vllm/sequence.py | 14 ++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/vllm/classifier_free_guidance/cfg_worker.py b/vllm/classifier_free_guidance/cfg_worker.py index f1fc508ea4897..ca7add71f0f59 100644 --- a/vllm/classifier_free_guidance/cfg_worker.py +++ b/vllm/classifier_free_guidance/cfg_worker.py @@ -112,6 +112,9 @@ def execute_model( negative_seq_data[seq_id].negative_prompt_token_ids = [] negative_seq_data[seq_id].output_token_ids = seq_data.output_token_ids[:] + if negative_seq_group_metadata.is_prompt: + negative_seq_group_metadata._token_chunk_size = list(negative_seq_data.values())[0].get_len() + negative_seq_group_metadata.seq_data = negative_seq_data negative_seq_group_metadata_list.append(negative_seq_group_metadata) negative_excute_model_req.seq_group_metadata_list = negative_seq_group_metadata_list diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6e59c5e0f74f3..c35e64155b40c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -968,7 +968,7 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: return self.block_manager.can_append_slots( seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), + num_lookahead_slots=self._get_num_lookahead_slots(is_prefill, seq_group), ) def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: @@ -1087,7 +1087,7 @@ def _append_slots( the new source and destination block indices for the appended slots. """ - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) + num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False, seq_group=seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): cows = self.block_manager.append_slots(seq, num_lookahead_slots) @@ -1199,7 +1199,7 @@ def _passed_delay(self, now: float) -> bool: passed_delay = True return passed_delay - def _get_num_lookahead_slots(self, is_prefill: bool) -> int: + def _get_num_lookahead_slots(self, is_prefill: bool, seq_group: SequenceGroup=None) -> int: """The number of slots to allocate per sequence per step, beyond known token ids. Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted.
@@ -1210,6 +1210,16 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int: if is_prefill: return 0 + num_lookahead_slots = 0 + if seq_group is not None: + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + if len(seq.get_negative_token_ids()) > len(seq.get_token_ids()): + addtional_length = len(seq.get_negative_token_ids()) - len(seq.get_token_ids()) + num_lookahead_slots = max(num_lookahead_slots, addtional_length) + + if num_lookahead_slots > 0: + return num_lookahead_slots + return self.scheduler_config.num_lookahead_slots def _get_num_new_tokens(self, seq_group: SequenceGroup, diff --git a/vllm/sequence.py b/vllm/sequence.py index e9635a07b55df..d012cda32a353 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -135,11 +135,16 @@ def __init__( self._stage: SequenceStage = SequenceStage.PREFILL self._update_cached_all_tokens() + self._update_cached_all_negative_tokens() def _update_cached_all_tokens(self): self._cached_all_token_ids: List[int] = (self._prompt_token_ids + self._output_token_ids) + def _update_cached_all_negative_tokens(self): + self._cached_all_negative_token_ids: List[int] = (self._negative_prompt_token_ids + + self._output_token_ids) + @property def prompt_token_ids(self) -> Tuple[int, ...]: return self._prompt_token_ids_tuple @@ -158,6 +163,7 @@ def negative_prompt_token_ids(self) -> Tuple[int, ...]: def negative_prompt_token_ids(self, new_negative_prompt_token_ids) -> None: self._negative_prompt_token_ids = list(new_negative_prompt_token_ids) self._negative_prompt_token_ids_tuple = tuple(new_negative_prompt_token_ids) + self._update_cached_all_negative_tokens() @property def output_token_ids(self) -> Tuple[int, ...]: @@ -167,10 +173,12 @@ def output_token_ids(self) -> Tuple[int, ...]: def output_token_ids(self, new_output_token_ids) -> None: self._output_token_ids = list(new_output_token_ids) self._update_cached_all_tokens() + self._update_cached_all_negative_tokens() def append_token_id(self, token_id: int, logprob: float) -> None: self._output_token_ids.append(token_id) self._cached_all_token_ids.append(token_id) + self._cached_all_negative_token_ids.append(token_id) self.cumulative_logprob += logprob def get_len(self) -> int: @@ -185,6 +193,9 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self._cached_all_token_ids + def get_negative_token_ids(self) -> List[int]: + return self._cached_all_negative_token_ids + def get_prefix_token_ids( self, num_tokens: int ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: @@ -368,6 +379,9 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.data.get_token_ids() + def get_negative_token_ids(self) -> List[int]: + return self.data.get_negative_token_ids() + def get_prompt_token_ids(self) -> Tuple[int, ...]: return self.data.get_prompt_token_ids() From 6b4ec015c8a33612bd2cfb93d993d527f0a97de4 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Wed, 21 Aug 2024 10:06:29 +0800 Subject: [PATCH 07/18] fix for conflict --- vllm/sampling_params.py | 2 +- vllm/sequence.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 9f7984c256bfd..1ad8921bff57a 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -147,7 +147,7 @@ class SamplingParams( logits_processors: Optional[Any] = None include_stop_str_in_output: bool = False truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None - guidance_scale: Optional[float] = None, + guidance_scale: 
Optional[float] = None # The below fields are not supposed to be used as an input. # They are set in post_init. diff --git a/vllm/sequence.py b/vllm/sequence.py index badca9e7794bd..57a5eed2b4a23 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -253,7 +253,7 @@ def append_token_id(self, token_id: int, logprob: float) -> None: self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) self._cached_all_negative_token_ids.append(token_id) - self.cumulative_logprob += logprob + self._cumulative_logprob += logprob def get_len(self) -> int: return len(self._output_token_ids) + len(self._prompt_token_ids) From 8ae8f7e647333580c0cbb919cb3fc01b24fb0e20 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Wed, 21 Aug 2024 10:08:24 +0800 Subject: [PATCH 08/18] fix for conflict --- vllm/engine/llm_engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f449a3e8ae70b..7db7af142c17a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -70,9 +70,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) PromptComponents = Tuple[Optional[str], List[int], - Optional[MultiModalDataDict]] + Optional[MultiModalDataDict], + Optional[str], List[int]] DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional[MultiModalDataDict]] + Optional[MultiModalDataDict], + Optional[str], Optional[List[int]]] class LLMEngine: From a2877170754ec4f7a06ea98b07c5acea7457d840 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Wed, 21 Aug 2024 10:26:08 +0800 Subject: [PATCH 09/18] fix encoder/decoder_comps for conflict --- vllm/engine/llm_engine.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7db7af142c17a..17efd290b4ef2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -71,10 +71,10 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: PromptComponents = Tuple[Optional[str], List[int], Optional[MultiModalDataDict], - Optional[str], List[int]] + Optional[None], List[None]] DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], Optional[MultiModalDataDict], - Optional[str], Optional[List[int]]] + Optional[None], Optional[None]] class LLMEngine: @@ -916,14 +916,17 @@ def _build_decoder_only_llm_inputs( prompt_comps: PromptComponents, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> LLMInputs: - prompt, prompt_token_ids, multi_modal_data = prompt_comps + prompt, prompt_token_ids, multi_modal_data, \ + negative_prompt, negative_prompt_token_ids = prompt_comps prompt_token_ids = self._apply_prompt_adapter( prompt_token_ids, prompt_adapter_request=prompt_adapter_request) return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + negative_prompt_token_ids=negative_prompt_token_ids, + negative_prompt=negative_prompt) def _process_decoder_only_prompt( self, From fc00849074d7440c30b1954ae277778c7a52d444 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Wed, 21 Aug 2024 15:58:57 +0800 Subject: [PATCH 10/18] update cfg for latest vllm --- tests/test_cfg_worker.py | 2 +- .../cfg_model_runner.py | 8 +- vllm/classifier_free_guidance/cfg_worker.py | 85 ++++++----- .../separated_worker.py | 56 ++++-------- vllm/engine/llm_engine.py | 4 +-
vllm/sequence.py | 11 ++- 6 files changed, 72 insertions(+), 94 deletions(-) diff --git a/tests/test_cfg_worker.py b/tests/test_cfg_worker.py index ae11ce413a76b..6255f239af7cf 100644 --- a/tests/test_cfg_worker.py +++ b/tests/test_cfg_worker.py @@ -6,7 +6,7 @@ llm = LLM( model="facebook/opt-6.7b", - tensor_parallel_size=1, + tensor_parallel_size=2, use_v2_block_manager=True, classifier_free_guidance_model="facebook/opt-6.7b" ) diff --git a/vllm/classifier_free_guidance/cfg_model_runner.py b/vllm/classifier_free_guidance/cfg_model_runner.py index b015c7bc3998d..ee020402ff21f 100644 --- a/vllm/classifier_free_guidance/cfg_model_runner.py +++ b/vllm/classifier_free_guidance/cfg_model_runner.py @@ -3,6 +3,7 @@ import torch from vllm.distributed import get_pp_group +from vllm.multimodal import MultiModalInputs from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.worker.model_runner import (ModelRunner, ModelInputForGPUWithSamplingMetadata, FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper, @@ -88,13 +89,18 @@ def model_execute( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + raise NotImplementedError("") + hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **multi_modal_kwargs, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), **seqlen_agnostic_kwargs) return hidden_or_intermediate_states diff --git a/vllm/classifier_free_guidance/cfg_worker.py b/vllm/classifier_free_guidance/cfg_worker.py index ca7add71f0f59..6dcc64598cd31 100644 --- a/vllm/classifier_free_guidance/cfg_worker.py +++ b/vllm/classifier_free_guidance/cfg_worker.py @@ -1,13 +1,12 @@ import copy -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Tuple import torch -from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.config import ParallelConfig, ClassifierFreeGuidanceConfig +from vllm.distributed import get_pp_group, get_tp_group from vllm.logger import init_logger -from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceGroupMetadata, SequenceData @@ -79,11 +78,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: (guidance_cache_block_size_bytes + root_cache_block_size_bytes)) return new_num_gpu_blocks, num_cpu_blocks - def initialize_cache( - self, - num_gpu_blocks: int, - num_cpu_blocks: int - ): + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: self.root_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) self.guidance_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, @@ -107,13 +103,19 @@ def execute_model( negative_seq_group_metadata = copy.copy(seq_group_metadata) negative_seq_data: Dict[int, SequenceData] = {} for seq_id, seq_data in seq_group_metadata.seq_data.items(): - negative_seq_data[seq_id] = copy.copy(seq_data) - negative_seq_data[seq_id].prompt_token_ids = seq_data.negative_prompt_token_ids - negative_seq_data[seq_id].negative_prompt_token_ids = [] - 
negative_seq_data[seq_id].output_token_ids = seq_data.output_token_ids[:] + negative_seq_data[seq_id] = SequenceData( + seq_data._negative_prompt_token_ids, + _output_token_ids=seq_data._output_token_ids, + _cumulative_logprob=seq_data._cumulative_logprob, + _prompt_token_ids_tuple=seq_data._negative_prompt_token_ids, + _num_computed_tokens=seq_data._num_computed_tokens, + _stage=seq_data.stage, + _cached_all_token_ids=seq_data._cached_all_token_ids, + _new_appended_tokens=seq_data._new_appended_tokens, + ) if negative_seq_group_metadata.is_prompt: - negative_seq_group_metadata._token_chunk_size = list(negative_seq_data.values())[0].get_len() + negative_seq_group_metadata.token_chunk_size = list(negative_seq_data.values())[0].get_len() negative_seq_group_metadata.seq_data = negative_seq_data negative_seq_group_metadata_list.append(negative_seq_group_metadata) @@ -121,56 +123,41 @@ def execute_model( else: negative_excute_model_req = None - if self.is_driver_worker: - if execute_model_req is None: - if self.do_metadata_broadcast: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - - if self.do_metadata_broadcast: - broadcast_data = {"flag": 1} - broadcast_tensor_dict(broadcast_data, src=0) - else: - assert self.do_metadata_broadcast - broadcast_data = broadcast_tensor_dict(src=0) - if not broadcast_data: - return None + inputs = self.root_worker.prepare_input(execute_model_req) + negative_inputs = self.guidance_worker.prepare_input(negative_excute_model_req) + if inputs is None: + assert negative_inputs is None + return None # get root models's logits - scores = self.root_worker.execute_model_part(execute_model_req) + condition_logits = self.root_worker.execute_model_part(inputs) # get unconditional logits - unconditional_logits = self.guidance_worker.execute_model_part(negative_excute_model_req) + unconditional_logits = self.guidance_worker.execute_model_part(negative_inputs) # do classifier free guidance logist process - for seq_group in self.root_worker.model_input.sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - guidance_scale = seq_group.sampling_params.guidance_scale - if guidance_scale == 1.0: - break - for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices): - logits_row = torch.nn.functional.log_softmax(scores[logits_row_idx], dim=-1) - unconditional_logits_row = torch.nn.functional.log_softmax(unconditional_logits[logits_row_idx], dim=-1) - scores[logits_row_idx] = guidance_scale * (logits_row - unconditional_logits_row) + unconditional_logits_row - - # print("scores:", scores.shape, scores) - # exit(0) + model_input, _ = inputs + if condition_logits is not None: + for seq_group in model_input.sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + guidance_scale = seq_group.sampling_params.guidance_scale + if guidance_scale == 1.0: + break + for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices): + logits_row = torch.nn.functional.log_softmax(condition_logits[logits_row_idx], dim=-1) + unconditional_logits_row = torch.nn.functional.log_softmax(unconditional_logits[logits_row_idx], dim=-1) + condition_logits[logits_row_idx] = guidance_scale * (logits_row - unconditional_logits_row) + unconditional_logits_row # do logist_processor - scores = 
self.root_worker.compute_logits(scores) + scores = self.root_worker.compute_logits(condition_logits, model_input) if not self.is_driver_worker: return [] # do sample - output = self.root_worker.do_sample(scores) + output = self.root_worker.do_sample(scores, model_input) if not get_pp_group().is_last_rank: # output is IntermediateTensors - get_pp_group().send_tensor_dict(output.tensors) + get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group()) return [None] # output is List[SamplerOutput] diff --git a/vllm/classifier_free_guidance/separated_worker.py b/vllm/classifier_free_guidance/separated_worker.py index ff11cec3e7446..17102ff8c625d 100644 --- a/vllm/classifier_free_guidance/separated_worker.py +++ b/vllm/classifier_free_guidance/separated_worker.py @@ -1,73 +1,51 @@ -from typing import List, Optional +from typing import List, Optional, Tuple import torch -from vllm.distributed import broadcast_tensor_dict, get_pp_group -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SamplerOutput) +from vllm.distributed import get_pp_group, get_tp_group +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.worker.worker import Worker from vllm.worker.worker_base import WorkerInput -from vllm.worker.model_runner_base import ModelRunnerInputBase +from vllm.worker.model_runner_base import BroadcastableModelInput +from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata class SeparatedWorker(Worker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.model_input = None - @torch.inference_mode() def get_logits( self, hidden_or_intermediate_states: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, ) -> torch.Tensor: - return self.model_runner.get_logits(hidden_or_intermediate_states, self.model_input) + return self.model_runner.get_logits(hidden_or_intermediate_states, model_input) @torch.inference_mode() def compute_logits( self, logits: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, ) -> torch.Tensor: - return self.model_runner.compute_logits(logits, self.model_input) + return self.model_runner.compute_logits(logits, model_input) @torch.inference_mode() def do_sample( self, logits: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, ) -> List[SamplerOutput]: - return self.model_runner.do_sample(logits, self.model_input) + return self.model_runner.do_sample(logits, model_input) @torch.inference_mode() def execute_model_part( self, - execute_model_req: Optional[ExecuteModelRequest] = None, + inputs: Tuple[BroadcastableModelInput, WorkerInput], ) -> Optional[List[SamplerOutput]]: - if self.is_driver_worker: - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - self.model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) - num_steps = execute_model_req.num_steps - - if self.do_metadata_broadcast: - broadcast_data = worker_input.as_broadcastable_tensor_dict() - broadcast_data.update( - self.model_input.as_broadcastable_tensor_dict()) - broadcast_data["num_steps"] = num_steps - broadcast_tensor_dict(broadcast_data, src=0) - else: - assert self.do_metadata_broadcast - broadcast_data = broadcast_tensor_dict(src=0) - num_steps = broadcast_data.pop("num_steps") - worker_input = WorkerInput.from_broadcasted_tensor_dict( - broadcast_data) - self.model_input = ( - 
self.model_runner. - make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + model_input, worker_input = inputs + num_steps = worker_input.num_steps self.execute_worker(worker_input) @@ -78,10 +56,10 @@ def execute_model_part( intermediate_tensors = None if not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict()) + get_pp_group().recv_tensor_dict(all_gather_group=get_tp_group())) hidden_or_intermediate_states = self.model_runner.model_execute( - self.model_input, + model_input, self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, intermediate_tensors, @@ -92,6 +70,6 @@ def execute_model_part( if not get_pp_group().is_last_rank: return hidden_or_intermediate_states - logits = self.get_logits(hidden_or_intermediate_states) + logits = self.get_logits(hidden_or_intermediate_states, model_input) return logits diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 17efd290b4ef2..9402e714fddd8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -736,6 +736,8 @@ def _extract_prompt_components( lora_request=lora_request, ) multi_modal_data = None + negative_prompt = None + negative_prompt_token_ids = None elif isinstance(inputs, dict): if "prompt_token_ids" in inputs: prompt = None @@ -748,7 +750,7 @@ def _extract_prompt_components( request_id=request_id, lora_request=lora_request, ) - + if "negative_prompt_token_ids" in inputs: negative_prompt = None negative_prompt_token_ids = inputs["negative_prompt_token_ids"] diff --git a/vllm/sequence.py b/vllm/sequence.py index 57a5eed2b4a23..4eac5410afcae 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -152,9 +152,11 @@ class SequenceData(msgspec.Struct, _output_token_ids: array = msgspec.field( default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) - _negative_prompt_token_ids: array + _negative_prompt_token_ids: array = msgspec.field( + default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) _negative_prompt_token_ids_tuple: Tuple[int, ...] = msgspec.field(default_factory=tuple) + _cached_all_negative_token_ids: List[int] = msgspec.field(default_factory=list) ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 @@ -174,6 +176,9 @@ def __post_init__(self) -> None: assert self._output_token_ids.typecode == "l" self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( self._prompt_token_ids) + assert self._negative_prompt_token_ids.typecode == "l" + self._negative_prompt_token_ids_tuple: Tuple[int, ...] 
= tuple( + self._negative_prompt_token_ids) self._update_cached_all_tokens() self._update_cached_all_negative_tokens() @@ -425,8 +430,8 @@ def __init__( "encoder input prompt fields?") self.data = SequenceData( - prompt_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids), - negative_prompt_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, self.negative_prompt_token_ids)) + array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids), + _negative_prompt_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, self.negative_prompt_token_ids)) self.output_logprobs: SampleLogprobs = [] self.output_text = "" From e9ddb627c1b00f7971d7b01e71018c8bbb847d7c Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Wed, 21 Aug 2024 17:27:10 +0800 Subject: [PATCH 11/18] tiny fix --- vllm/engine/llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9402e714fddd8..2c1af08460281 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -71,7 +71,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: PromptComponents = Tuple[Optional[str], List[int], Optional[MultiModalDataDict], - Optional[None], List[None]] + Optional[None], Optional[None]] DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], Optional[MultiModalDataDict], Optional[None], Optional[None]] From ca52237f3ec218c501c330808567c8d55d77a1d3 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Fri, 23 Aug 2024 14:10:44 +0800 Subject: [PATCH 12/18] refactor negative sequence --- .../cfg_model_runner.py | 214 +++++++++++++++++- vllm/classifier_free_guidance/cfg_worker.py | 37 ++- vllm/core/block_manager_v2.py | 47 ++++ vllm/core/scheduler.py | 31 ++- vllm/engine/llm_engine.py | 21 +- vllm/engine/output_processor/single_step.py | 2 + vllm/model_executor/models/llama.py | 24 +- vllm/sequence.py | 126 ++++------- 8 files changed, 386 insertions(+), 116 deletions(-) diff --git a/vllm/classifier_free_guidance/cfg_model_runner.py b/vllm/classifier_free_guidance/cfg_model_runner.py index ee020402ff21f..32caa6905fe7b 100644 --- a/vllm/classifier_free_guidance/cfg_model_runner.py +++ b/vllm/classifier_free_guidance/cfg_model_runner.py @@ -1,13 +1,24 @@ +# import dataclasses from typing import List, Optional, Union import torch from vllm.distributed import get_pp_group from vllm.multimodal import MultiModalInputs -from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.worker.model_runner import (ModelRunner, ModelInputForGPUWithSamplingMetadata, - FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper, +# from vllm.utils import make_tensor_with_pad +# from vllm.model_executor import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput #, SequenceGroupMetadata +from vllm.worker.model_runner import (ModelRunner, ModelInputForGPUWithSamplingMetadata, #GPUModelRunnerBase, ModelInputForGPUBuilder, + FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper, # _PAD_SLOT_ID, BatchPrefillWithPagedKVCacheWrapper) +# from vllm.worker.model_runner_base import ( +# _add_attn_metadata_broadcastable_dict, +# _add_sampling_metadata_broadcastable_dict, +# _init_attn_metadata_from_tensor_dict, +# _init_sampling_metadata_from_tensor_dict) + +# if TYPE_CHECKING: +# from vllm.attention.backends.abstract import AttentionBackend class CFGModelRunner(ModelRunner): @@ -161,3 +172,200 @@ def execute_model( logits = self.compute_logits(hidden_or_intermediate_states, model_input) return 
self.do_sample(logits, model_input) + + +# @dataclasses.dataclass(frozen=True) +# class PositiveNegativeModelInput(ModelInputForGPUWithSamplingMetadata): +# """ +# Used by the ClassifierFreeGuidanceModelRunner. +# """ +# negative_input_tokens: Optional[torch.Tensor] = None +# negative_input_positions: Optional[torch.Tensor] = None + +# def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: +# tensor_dict = { +# "input_tokens": self.input_tokens, +# "input_positions": self.input_positions, +# "negative_input_tokens": self.negative_input_tokens, +# "negative_input_positions": self.negative_input_positions, +# "virtual_engine": self.virtual_engine, +# "request_ids_to_seq_ids": self.request_ids_to_seq_ids, +# "finished_requests_ids": self.finished_requests_ids, +# } +# _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) +# _add_sampling_metadata_broadcastable_dict(tensor_dict, +# self.sampling_metadata) +# return tensor_dict + +# @classmethod +# def from_broadcasted_tensor_dict( +# cls, +# tensor_dict: Dict[str, Any], +# attn_backend: Optional["AttentionBackend"] = None, +# ) -> "ModelInputForGPUWithSamplingMetadata": +# return cast( +# ModelInputForGPUWithSamplingMetadata, +# super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + + +# class ClassifierFreeGuidanceModelRunner(GPUModelRunnerBase[PositiveNegativeModelInput]): +# _model_input_cls: Type[PositiveNegativeModelInput] = ( +# PositiveNegativeModelInput) +# _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) + +# @torch.inference_mode() +# def execute_model( +# self, +# model_input: PositiveNegativeModelInput, +# kv_caches: List[torch.Tensor], +# intermediate_tensors: Optional[IntermediateTensors] = None, +# num_steps: int = 1, +# ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + +# if num_steps > 1: +# raise ValueError("num_steps > 1 is not supported in ModelRunner") + + + + +# def prepare_model_input( +# self, +# seq_group_metadata_list: List[SequenceGroupMetadata], +# virtual_engine: int = 0, +# finished_requests_ids: Optional[List[str]] = None +# ) -> PositiveNegativeModelInput: + +# model_input = self._prepare_model_input_tensors( +# seq_group_metadata_list, finished_requests_ids) + +# ( +# attn_metadata, +# negative_input_tokens_tensor, +# negative_input_positions_tensor, +# ) = (self._prepare_model_negative_input_tensors(seq_group_metadata_list, +# model_input)) + +# model_input = dataclasses.replace( +# model_input, +# attn_metadata=attn_metadata, +# negative_input_tokens=negative_input_tokens_tensor, +# negative_input_positions=negative_input_positions_tensor, +# ) + +# sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, +# model_input.seq_lens, +# model_input.query_lens, +# self.device, +# self.pin_memory) +# is_prompt = (seq_group_metadata_list[0].is_prompt +# if seq_group_metadata_list else None) +# return dataclasses.replace(model_input, +# sampling_metadata=sampling_metadata, +# is_prompt=is_prompt, +# virtual_engine=virtual_engine) + +# def _prepare_model_negative_input_tensors( +# self, +# seq_group_metadata_list: List[SequenceGroupMetadata], +# model_input: PositiveNegativeModelInput, +# ): +# if len(seq_group_metadata_list) == 0: +# return (model_input.attn_metadata, None, None) + +# is_prompt = seq_group_metadata_list[0].is_prompt + +# negative_seq_lens: List[int] = [] +# if is_prompt: +# # Prefill phase. 
+# negative_block_tables = self._empty_int32_tensor().view( +# len(seq_group_metadata_list), -1) + +# ( +# negative_input_tokens, +# negative_input_positions, +# negative_slot_mapping, +# ) = ( +# [], +# [], +# [], +# ) + +# for seq_group_metadata in seq_group_metadata_list: +# seq_len = seq_group_metadata.negative_seq_data.get_len() +# token_ids = seq_group_metadata.negative_seq_data.get_token_ids() +# negative_seq_lens.append(seq_len) + +# is_profile_run = (seq_group_metadata.block_tables is None) +# if is_profile_run: +# negative_slot_mapping.extend([_PAD_SLOT_ID] * seq_len) +# else: +# for i in range(0, seq_len): +# block_number = seq_group_metadata.negative_block_table[ +# i // self.block_size] +# block_offset = i % self.block_size +# slot = block_number * self.block_size + block_offset +# negative_slot_mapping.append(slot) + +# negative_input_tokens.extend(token_ids) +# negative_input_positions.extend(list(range(0, seq_len))) + +# negative_input_tokens_tensor = self._list_to_long_tensor( +# negative_input_tokens) +# negative_input_positions_tensor = self._list_to_long_tensor( +# negative_input_positions) +# negative_slot_mapping_tensor = self._list_to_long_tensor( +# negative_slot_mapping) +# else: +# # Decode phase. +# negative_input_tokens_tensor = self._empty_long_tensor() +# negative_input_positions_tensor = self._empty_long_tensor() +# negative_slot_mapping_tensor = self._empty_long_tensor() + +# negative_block_tables = [] +# for seq_group_metadata in seq_group_metadata_list: +# negative_seq_lens.append( +# seq_group_metadata.negative_seq_data.get_len()) +# negative_block_table = seq_group_metadata.negative_block_table +# negative_block_tables.append([] if ( +# negative_block_table is None) else negative_block_table) + +# negative_block_tables = make_tensor_with_pad( +# negative_block_tables, +# max_len=max( +# len(block_table) for block_table in negative_block_tables), +# pad=0, +# dtype=torch.int32, +# device=self.device, +# ) + +# max_negative_seq_len = max(negative_seq_lens, default=0) +# negative_seq_lens_tensor = self._list_to_int32_tensor(negative_seq_lens) +# negative_seq_start_loc = torch.zeros(negative_seq_lens_tensor.shape[0] + +# 1, +# dtype=torch.int32, +# device=self.device) +# torch.cumsum(negative_seq_lens_tensor, +# dim=0, +# dtype=negative_seq_start_loc.dtype, +# out=negative_seq_start_loc[1:]) + +# attn_metadata = model_input.attn_metadata +# assert attn_metadata is not None +# ( +# attn_metadata.num_negative_tokens, +# attn_metadata.negative_seq_lens, +# attn_metadata.negative_seq_lens_tensor, +# attn_metadata.max_negative_seq_len, +# attn_metadata.negative_slot_mapping, +# attn_metadata.negative_block_tables, +# ) = ( +# sum(negative_seq_lens), +# negative_seq_lens, +# negative_seq_lens_tensor, +# max_negative_seq_len, +# negative_slot_mapping_tensor, +# negative_block_tables, +# ) + +# return (attn_metadata, negative_input_tokens_tensor, +# negative_input_positions_tensor) diff --git a/vllm/classifier_free_guidance/cfg_worker.py b/vllm/classifier_free_guidance/cfg_worker.py index 6dcc64598cd31..24c936dbecb13 100644 --- a/vllm/classifier_free_guidance/cfg_worker.py +++ b/vllm/classifier_free_guidance/cfg_worker.py @@ -95,6 +95,10 @@ def execute_model( execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + # print("==> request positive :") + # for seq_group_metadata in execute_model_req.seq_group_metadata_list: + # print("seq_group_metadata:", seq_group_metadata) + # prepare negative request with shallow copy if execute_model_req 
is not None: negative_seq_group_metadata_list: List[SequenceGroupMetadata] = [] @@ -102,27 +106,38 @@ def execute_model( for seq_group_metadata in execute_model_req.seq_group_metadata_list: negative_seq_group_metadata = copy.copy(seq_group_metadata) negative_seq_data: Dict[int, SequenceData] = {} - for seq_id, seq_data in seq_group_metadata.seq_data.items(): - negative_seq_data[seq_id] = SequenceData( - seq_data._negative_prompt_token_ids, - _output_token_ids=seq_data._output_token_ids, - _cumulative_logprob=seq_data._cumulative_logprob, - _prompt_token_ids_tuple=seq_data._negative_prompt_token_ids, - _num_computed_tokens=seq_data._num_computed_tokens, - _stage=seq_data.stage, - _cached_all_token_ids=seq_data._cached_all_token_ids, - _new_appended_tokens=seq_data._new_appended_tokens, - ) + negative_block_tables: Dict[int, List[int]] = {} + assert len(seq_group_metadata.seq_data) == 1 + for seq_id in seq_group_metadata.seq_data.keys(): + negative_seq_data[seq_id] = seq_group_metadata.negative_seq_data + # negative_seq_data[seq_id] = SequenceData( + # _prompt_token_ids=seq_group_metadata.negative_seq_data.prompt_token_ids_array, + # _output_token_ids=seq_data._output_token_ids, + # _cumulative_logprob=seq_data._cumulative_logprob, + # _prompt_token_ids_tuple=seq_group_metadata.negative_seq_data.prompt_token_ids, + # _num_computed_tokens=seq_data._num_computed_tokens, + # _stage=seq_data.stage, + # _cached_all_token_ids=seq_data._cached_all_token_ids, + # _new_appended_tokens=seq_data._new_appended_tokens, + # ) + negative_block_tables[seq_id] = seq_group_metadata.negative_block_table if negative_seq_group_metadata.is_prompt: negative_seq_group_metadata.token_chunk_size = list(negative_seq_data.values())[0].get_len() negative_seq_group_metadata.seq_data = negative_seq_data + negative_seq_group_metadata.block_tables = negative_block_tables + negative_seq_group_metadata.negative_seq_data = None + negative_seq_group_metadata.negative_block_table = None negative_seq_group_metadata_list.append(negative_seq_group_metadata) negative_excute_model_req.seq_group_metadata_list = negative_seq_group_metadata_list else: negative_excute_model_req = None + # print("==> request negative:") + # for seq_group_metadata in negative_excute_model_req.seq_group_metadata_list: + # print("seq_group_metadata:", seq_group_metadata) + inputs = self.root_worker.prepare_input(execute_model_req) negative_inputs = self.guidance_worker.prepare_input(negative_excute_model_req) if inputs is None: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b7d9451f18067..348f9853fd1d4 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -15,6 +15,7 @@ from vllm.utils import Device SeqId = int +NegativeSeqId = str EncoderSeqId = str @@ -100,6 +101,7 @@ def __init__( ) self.block_tables: Dict[SeqId, BlockTable] = {} + self.negative_block_tables: Dict[NegativeSeqId, BlockTable] = {} self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} self._computed_blocks_tracker = ComputedBlocksTracker( @@ -125,6 +127,12 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) + if seq_group.has_negative_prompt(): + num_required_blocks += BlockTable.get_num_required_blocks( + seq_group.get_negative_seq().get_token_ids(), + block_size=self.block_size, + ) + if self.max_block_sliding_window is not None: num_required_blocks = min(num_required_blocks, self.max_block_sliding_window) @@ -185,6 +193,13 @@ def allocate(self, seq_group: SequenceGroup) -> 
None: assert (request_id not in self.cross_block_tables), \ "block table already exists" + assert (request_id + not in self.negative_block_tables), \ + "block table already exists" + + if seq_group.has_negative_prompt(): + block_table = self._allocate_sequence(seq_group.get_negative_seq()) + self.negative_block_tables[request_id] = block_table check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) @@ -216,6 +231,14 @@ def can_append_slots(self, seq_group: SequenceGroup, seq.get_token_ids()), num_lookahead_slots=num_lookahead_slots, )) + + negative_block_table = self.negative_block_tables[seq_group.request_id] + num_touched_blocks += ( + negative_block_table.get_num_blocks_touched_by_append_slots( + token_ids=negative_block_table.get_unseen_token_ids( + seq_group.get_negative_seq().get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( Device.GPU) @@ -225,6 +248,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, + seq_group: SequenceGroup, ) -> List[Tuple[int, int]]: block_table = self.block_tables[seq.seq_id] @@ -234,6 +258,15 @@ def append_slots( num_lookahead_slots=num_lookahead_slots, num_computed_slots=seq.data.get_num_computed_tokens(), ) + + negative_block_table = self.negative_block_tables[seq_group.request_id] + negative_seq = seq_group.negative_seq + negative_block_table.append_token_ids( + token_ids=negative_block_table.get_unseen_token_ids(negative_seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + num_computed_slots=negative_seq.data.get_num_computed_tokens(), + ) + # Return any new copy-on-writes. new_cows = self.block_allocator.clear_copy_on_writes() return new_cows @@ -265,6 +298,13 @@ def free_cross(self, seq_group: SequenceGroup) -> None: self.cross_block_tables[request_id].free() del self.cross_block_tables[request_id] + def free_negative(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.negative_block_tables: + return + self.negative_block_tables[request_id].free() + del self.negative_block_tables[request_id] + def get_block_table(self, seq: Sequence) -> List[int]: block_ids = self.block_tables[seq.seq_id].physical_block_ids return block_ids # type: ignore @@ -276,6 +316,13 @@ def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: assert all(b is not None for b in block_ids) return block_ids # type: ignore + def get_negative_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.negative_block_tables + block_ids = self.negative_block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids + def access_all_blocks_in_seq(self, seq: Sequence, now: float): if self.enable_caching: # Record the latest access time for the sequence. 
The actual update diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7ae7b3e4dcda3..ca2cfb3b841d9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -431,6 +431,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: self.free_seq(seq) self._free_seq_group_cross_attn_blocks(aborted_group) + self._free_seq_group_negative_blocks(aborted_group) def _free_seq_group_cross_attn_blocks( self, @@ -443,6 +444,13 @@ def _free_seq_group_cross_attn_blocks( if seq_group.is_encoder_decoder(): self.block_manager.free_cross(seq_group) + def _free_seq_group_negative_blocks( + self, + seq_group: SequenceGroup, + ) -> None: + if seq_group.has_negative_prompt(): + self.block_manager.free_negative(seq_group) + def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( self.swapped) != 0 @@ -1066,6 +1074,15 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: encoder_seq_data = None cross_block_table = None + if seq_group.has_negative_prompt(): + negative_seq_data = seq_group.get_negative_seq().data + negative_block_table = self.block_manager.get_negative_block_table( + seq_group + ) + else: + negative_seq_data = None + negative_block_table = None + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data @@ -1113,6 +1130,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: computed_block_nums=common_computed_block_nums, encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, + negative_seq_data=negative_seq_data, + negative_block_table=negative_block_table, state=seq_group.state, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. 
@@ -1206,7 +1225,8 @@ def _append_slots( seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - cows = self.block_manager.append_slots(seq, num_lookahead_slots) + cows = self.block_manager.append_slots(seq, num_lookahead_slots, seq_group) + assert len(cows) == 0 if len(cows) > 0: blocks_to_copy.extend(cows) @@ -1327,15 +1347,6 @@ def _get_num_lookahead_slots(self, is_prefill: bool, seq_group: SequenceGroup=No if is_prefill: return 0 - num_lookahead_slots = 0 - if seq_group is not None: - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - if len(seq.get_negative_token_ids()) > len(seq.get_token_ids()): - addtional_length = len(seq.get_negative_token_ids()) - len(seq.get_token_ids()) - num_lookahead_slots = max(num_lookahead_slots, addtional_length) - - if num_lookahead_slots > 0: - return num_lookahead_slots return self.scheduler_config.num_lookahead_slots diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2c1af08460281..7a4cb3c1a3c0a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -594,6 +594,16 @@ def _add_processed_request( seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) + negative_seq = None + if 'negative_prompt_token_ids' in processed_inputs: + negative_seq = Sequence(seq_id, + processed_inputs, + block_size, + eos_token_id, + lora_request, + prompt_adapter_request, + from_negative_prompt=True) + encoder_seq = None if 'encoder_prompt_token_ids' in processed_inputs: encoder_seq = Sequence(seq_id, @@ -614,7 +624,8 @@ def _add_processed_request( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + negative_seq=negative_seq) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -1079,6 +1090,7 @@ def _create_sequence_group_with_sampling( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, + negative_seq: Optional[Sequence] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -1105,7 +1117,8 @@ def _create_sequence_group_with_sampling( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + negative_seq=negative_seq) return seq_group @@ -1118,6 +1131,7 @@ def _create_sequence_group_with_pooling( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], encoder_seq: Optional[Sequence] = None, + negative_seq: Optional[Sequence] = None, ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -1130,7 +1144,8 @@ def _create_sequence_group_with_pooling( lora_request=lora_request, pooling_params=pooling_params, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + negative_seq=negative_seq) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4a46c93f84256..8968665835189 100644 --- a/vllm/engine/output_processor/single_step.py +++ 
b/vllm/engine/output_processor/single_step.py @@ -88,6 +88,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # only have one sequence seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs) + negative_seq = seq_group.negative_seq + negative_seq.append_token_id(sample.output_token, sample.logprobs) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0c67a9b8e198b..f3aff37577779 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,13 +30,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, tensor_model_parallel_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.logits_processor import LogitsProcessor, _prune_hidden_states from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( @@ -413,8 +413,9 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, - logit_scale) + logit_scale, logits_as_input=True) self.sampler = Sampler() + self.org_vocab_size = config.vocab_size else: self.lm_head = PPMissingLayer() @@ -430,6 +431,23 @@ def forward( attn_metadata, intermediate_tensors) return model_output + def _get_logits(self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + # Get the logits for the next tokens. + logits = self.lm_head.linear_method.apply( + self.lm_head, + hidden_states, + bias=None, + ) + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/sequence.py b/vllm/sequence.py index 4eac5410afcae..05d8fb0713863 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -152,12 +152,6 @@ class SequenceData(msgspec.Struct, _output_token_ids: array = msgspec.field( default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) - _negative_prompt_token_ids: array = msgspec.field( - default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) - _negative_prompt_token_ids_tuple: Tuple[int, - ...] = msgspec.field(default_factory=tuple) - _cached_all_negative_token_ids: List[int] = msgspec.field(default_factory=list) - ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 _prompt_token_ids_tuple: Tuple[int, @@ -176,11 +170,7 @@ def __post_init__(self) -> None: assert self._output_token_ids.typecode == "l" self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( self._prompt_token_ids) - assert self._negative_prompt_token_ids.typecode == "l" - self._negative_prompt_token_ids_tuple: Tuple[int, ...] 
= tuple( - self._negative_prompt_token_ids) self._update_cached_all_tokens() - self._update_cached_all_negative_tokens() def _update_cached_all_tokens(self): assert isinstance(self._prompt_token_ids, array) @@ -192,12 +182,6 @@ def _update_cached_all_tokens(self): def cumulative_logprob(self) -> float: return self._cumulative_logprob - def _update_cached_all_negative_tokens(self): - assert isinstance(self._negative_prompt_token_ids, array) - assert isinstance(self._output_token_ids, array) - self._cached_all_negative_token_ids: List[int] = list(self._negative_prompt_token_ids + - self._output_token_ids) - @property def prompt_token_ids(self) -> Tuple[int, ...]: return self._prompt_token_ids_tuple @@ -215,23 +199,6 @@ def prompt_token_ids_array(self) -> array: """ return self._prompt_token_ids - @property - def negative_prompt_token_ids(self) -> Tuple[int, ...]: - return self._negative_prompt_token_ids_tuple - - @negative_prompt_token_ids.setter - def negative_prompt_token_ids(self, new_negative_prompt_token_ids) -> None: - raise NotImplementedError - - @property - def negative_prompt_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. - """ - return self._negative_prompt_token_ids - @property def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @@ -241,7 +208,6 @@ def output_token_ids(self, new_output_token_ids: List[int]) -> None: self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, new_output_token_ids) self._update_cached_all_tokens() - self._update_cached_all_negative_tokens() @property def output_token_ids_array(self) -> array: @@ -257,7 +223,6 @@ def append_token_id(self, token_id: int, logprob: float) -> None: self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) - self._cached_all_negative_token_ids.append(token_id) self._cumulative_logprob += logprob def get_len(self) -> int: @@ -272,9 +237,6 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self._cached_all_token_ids - def get_negative_token_ids(self) -> List[int]: - return self._cached_all_negative_token_ids - def get_prefix_token_ids( self, num_tokens: int ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: @@ -348,7 +310,6 @@ def stage(self) -> SequenceStage: def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " - f"negative_prompt_token_ids={self.negative_prompt_token_ids}, " f"output_token_ids={self.output_token_ids}, " f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()}") @@ -389,6 +350,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, from_decoder_prompt: bool = True, + from_negative_prompt: bool = False, ) -> None: self.seq_id = seq_id self.inputs = inputs @@ -397,10 +359,9 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request self.from_decoder_prompt = from_decoder_prompt + self.from_negative_prompt = from_negative_prompt self._prompt: Optional[str] = None self._prompt_token_ids: Optional[List[int]] = None - self._negative_prompt: Optional[str] = None - self._negative_prompt_token_ids: Optional[List[int]] = None # For decoder-only models, a Sequence is constructed # from an LLMInputs instance (the 
`inputs` arg.) @@ -430,8 +391,7 @@ def __init__( "encoder input prompt fields?") self.data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids), - _negative_prompt_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, self.negative_prompt_token_ids)) + array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids)) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -456,8 +416,12 @@ def prompt(self) -> Optional[str]: # Select decoder or encoder input prompt str, # as appropriate - prompt_key: str = ("prompt" - if self.from_decoder_prompt else "encoder_prompt") + prompt_key: str = "prompt" + if not self.from_decoder_prompt: + prompt_key = "encoder_prompt" + if self.from_negative_prompt: + assert self.from_decoder_prompt is True + prompt_key = "negative_prompt" # Cache prompt self._prompt = cast(Optional[str], self.inputs.get(prompt_key)) @@ -471,9 +435,12 @@ def prompt_token_ids(self) -> List[int]: # Select decoder or encoder input prompt # token ids, as appropriate - prompt_token_ids_key: str = ("prompt_token_ids" - if self.from_decoder_prompt else - "encoder_prompt_token_ids") + prompt_token_ids_key: str = "prompt_token_ids" + if not self.from_decoder_prompt: + prompt_token_ids_key = "encoder_prompt_token_ids" + if self.from_negative_prompt: + assert self.from_decoder_prompt is True + prompt_token_ids_key = "negative_prompt_token_ids" # Cache computed prompt token ids self._prompt_token_ids = cast(List[int], @@ -484,35 +451,6 @@ def multi_modal_data(self) -> "MultiModalDataDict": return self.inputs.get("multi_modal_data") or {} - @property - def negative_prompt(self) -> Optional[str]: - if self._negative_prompt is not None: - # Reuse precomputed prompt string - return self._negative_prompt - - # Select decoder or encoder input prompt str, - # as appropriate - assert self.from_decoder_prompt is True - negative_prompt_key: str = "negative_prompt" - - # Cache prompt - self._negative_prompt = cast(Optional[str], self.inputs.get(negative_prompt_key)) - return self._negative_prompt - - @property - def negative_prompt_token_ids(self) -> List[int]: - if self._negative_prompt_token_ids is not None: - # Reuse precomputed negative prompt token ids - return self._negative_prompt_token_ids - - # Select decoder or encoder input prompt - # token ids, as appropriate - assert self.from_decoder_prompt is True - negative_prompt_token_ids_key: str = "negative_prompt_token_ids" - self._negative_prompt_token_ids = cast(List[int], - self.inputs.get(negative_prompt_token_ids_key)) - return self._negative_prompt_token_ids - @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -671,6 +609,7 @@ def __init__( embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, encoder_seq: Optional[Sequence] = None, + negative_seq: Optional[Sequence] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: @@ -694,6 +633,9 @@ def __init__( self.encoder_seq = encoder_seq self.trace_headers = trace_headers + assert self.is_single_seq is True + self.negative_seq = negative_seq + @property def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt.
@@ -722,23 +664,27 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: return (self.encoder_seq.prompt_token_ids if self.encoder_seq is not None else None) - @property - def multi_modal_data(self) -> "MultiModalDataDict": - # All sequences in the group should have the same multi-modal data. - # We use the multi-modal data of an arbitrary sequence. - return self.seqs[0].multi_modal_data - @property def negative_prompt(self) -> Optional[str]: - # All sequences in the group should have the same prompt. + # There are either 0 or 1 negative sequences # We use the prompt of an arbitrary sequence. - return self.seqs[0].negative_prompt + assert self.is_single_seq is True + return (self.negative_seqs[0].prompt + if self.negative_seqs is not None else None) @property def negative_prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return self.seqs[0].negative_prompt_token_ids + assert self.is_single_seq is True + return (self.negative_seqs[0].prompt_token_ids + if self.negative_seqs is not None else None) + + @property + def multi_modal_data(self) -> "MultiModalDataDict": + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. + return self.seqs[0].multi_modal_data @property def lora_int_id(self) -> int: @@ -833,6 +779,12 @@ def is_encoder_decoder(self) -> bool: def get_encoder_seq(self) -> Optional[Sequence]: return self.encoder_seq + def has_negative_prompt(self) -> bool: + return self.negative_seq is not None + + def get_negative_seq(self) -> Optional[Sequence]: + return self.negative_seq + def get_unfinished_seqs(self) -> List[Sequence]: if self.is_single_seq: return self.seqs if not self.seqs[0].is_finished() else [] @@ -986,6 +938,8 @@ class SequenceGroupMetadata( multi_modal_data: Optional[Any] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[List[int]] = None + negative_seq_data: Optional[SequenceData] = None + negative_block_table: Optional[List[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None token_chunk_size: Optional[int] = None From cbb7ea636c8f9cb59c6202e43c904e0b868c9417 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Fri, 23 Aug 2024 14:13:02 +0800 Subject: [PATCH 13/18] fix property --- vllm/sequence.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 05d8fb0713863..abc15c57646a0 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -669,16 +669,16 @@ def negative_prompt(self) -> Optional[str]: # There are either 0 or 1 negative sequences # We use the prompt of an arbitrary sequence. assert self.is_single_seq is True - return (self.negative_seqs[0].prompt - if self.negative_seqs is not None else None) + return (self.negative_seq.prompt + if self.negative_seq is not None else None) @property def negative_prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. 
assert self.is_single_seq is True - return (self.negative_seqs[0].prompt_token_ids - if self.negative_seqs is not None else None) + return (self.negative_seq.prompt_token_ids + if self.negative_seq is not None else None) @property def multi_modal_data(self) -> "MultiModalDataDict": From 9f152ea54b6a2ecd9db34c8ae7f66b5fe646e95f Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Fri, 23 Aug 2024 14:14:54 +0800 Subject: [PATCH 14/18] rm comment --- .../cfg_model_runner.py | 214 +----------------- vllm/classifier_free_guidance/cfg_worker.py | 18 -- 2 files changed, 3 insertions(+), 229 deletions(-) diff --git a/vllm/classifier_free_guidance/cfg_model_runner.py b/vllm/classifier_free_guidance/cfg_model_runner.py index 32caa6905fe7b..15b18ecdad041 100644 --- a/vllm/classifier_free_guidance/cfg_model_runner.py +++ b/vllm/classifier_free_guidance/cfg_model_runner.py @@ -1,24 +1,13 @@ -# import dataclasses from typing import List, Optional, Union import torch from vllm.distributed import get_pp_group from vllm.multimodal import MultiModalInputs -# from vllm.utils import make_tensor_with_pad -# from vllm.model_executor import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput #, SequenceGroupMetadata -from vllm.worker.model_runner import (ModelRunner, ModelInputForGPUWithSamplingMetadata, #GPUModelRunnerBase, ModelInputForGPUBuilder, - FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper, # _PAD_SLOT_ID, +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import (ModelRunner, ModelInputForGPUWithSamplingMetadata, + FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper) -# from vllm.worker.model_runner_base import ( -# _add_attn_metadata_broadcastable_dict, -# _add_sampling_metadata_broadcastable_dict, -# _init_attn_metadata_from_tensor_dict, -# _init_sampling_metadata_from_tensor_dict) - -# if TYPE_CHECKING: -# from vllm.attention.backends.abstract import AttentionBackend class CFGModelRunner(ModelRunner): @@ -172,200 +161,3 @@ def execute_model( logits = self.compute_logits(hidden_or_intermediate_states, model_input) return self.do_sample(logits, model_input) - - -# @dataclasses.dataclass(frozen=True) -# class PositiveNegativeModelInput(ModelInputForGPUWithSamplingMetadata): -# """ -# Used by the ClassifierFreeGuidanceModelRunner. 
-# """ -# negative_input_tokens: Optional[torch.Tensor] = None -# negative_input_positions: Optional[torch.Tensor] = None - -# def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: -# tensor_dict = { -# "input_tokens": self.input_tokens, -# "input_positions": self.input_positions, -# "negative_input_tokens": self.negative_input_tokens, -# "negative_input_positions": self.negative_input_positions, -# "virtual_engine": self.virtual_engine, -# "request_ids_to_seq_ids": self.request_ids_to_seq_ids, -# "finished_requests_ids": self.finished_requests_ids, -# } -# _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) -# _add_sampling_metadata_broadcastable_dict(tensor_dict, -# self.sampling_metadata) -# return tensor_dict - -# @classmethod -# def from_broadcasted_tensor_dict( -# cls, -# tensor_dict: Dict[str, Any], -# attn_backend: Optional["AttentionBackend"] = None, -# ) -> "ModelInputForGPUWithSamplingMetadata": -# return cast( -# ModelInputForGPUWithSamplingMetadata, -# super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) - - -# class ClassifierFreeGuidanceModelRunner(GPUModelRunnerBase[PositiveNegativeModelInput]): -# _model_input_cls: Type[PositiveNegativeModelInput] = ( -# PositiveNegativeModelInput) -# _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) - -# @torch.inference_mode() -# def execute_model( -# self, -# model_input: PositiveNegativeModelInput, -# kv_caches: List[torch.Tensor], -# intermediate_tensors: Optional[IntermediateTensors] = None, -# num_steps: int = 1, -# ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - -# if num_steps > 1: -# raise ValueError("num_steps > 1 is not supported in ModelRunner") - - - - -# def prepare_model_input( -# self, -# seq_group_metadata_list: List[SequenceGroupMetadata], -# virtual_engine: int = 0, -# finished_requests_ids: Optional[List[str]] = None -# ) -> PositiveNegativeModelInput: - -# model_input = self._prepare_model_input_tensors( -# seq_group_metadata_list, finished_requests_ids) - -# ( -# attn_metadata, -# negative_input_tokens_tensor, -# negative_input_positions_tensor, -# ) = (self._prepare_model_negative_input_tensors(seq_group_metadata_list, -# model_input)) - -# model_input = dataclasses.replace( -# model_input, -# attn_metadata=attn_metadata, -# negative_input_tokens=negative_input_tokens_tensor, -# negative_input_positions=negative_input_positions_tensor, -# ) - -# sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, -# model_input.seq_lens, -# model_input.query_lens, -# self.device, -# self.pin_memory) -# is_prompt = (seq_group_metadata_list[0].is_prompt -# if seq_group_metadata_list else None) -# return dataclasses.replace(model_input, -# sampling_metadata=sampling_metadata, -# is_prompt=is_prompt, -# virtual_engine=virtual_engine) - -# def _prepare_model_negative_input_tensors( -# self, -# seq_group_metadata_list: List[SequenceGroupMetadata], -# model_input: PositiveNegativeModelInput, -# ): -# if len(seq_group_metadata_list) == 0: -# return (model_input.attn_metadata, None, None) - -# is_prompt = seq_group_metadata_list[0].is_prompt - -# negative_seq_lens: List[int] = [] -# if is_prompt: -# # Prefill phase. 
-# negative_block_tables = self._empty_int32_tensor().view( -# len(seq_group_metadata_list), -1) - -# ( -# negative_input_tokens, -# negative_input_positions, -# negative_slot_mapping, -# ) = ( -# [], -# [], -# [], -# ) - -# for seq_group_metadata in seq_group_metadata_list: -# seq_len = seq_group_metadata.negative_seq_data.get_len() -# token_ids = seq_group_metadata.negative_seq_data.get_token_ids() -# negative_seq_lens.append(seq_len) - -# is_profile_run = (seq_group_metadata.block_tables is None) -# if is_profile_run: -# negative_slot_mapping.extend([_PAD_SLOT_ID] * seq_len) -# else: -# for i in range(0, seq_len): -# block_number = seq_group_metadata.negative_block_table[ -# i // self.block_size] -# block_offset = i % self.block_size -# slot = block_number * self.block_size + block_offset -# negative_slot_mapping.append(slot) - -# negative_input_tokens.extend(token_ids) -# negative_input_positions.extend(list(range(0, seq_len))) - -# negative_input_tokens_tensor = self._list_to_long_tensor( -# negative_input_tokens) -# negative_input_positions_tensor = self._list_to_long_tensor( -# negative_input_positions) -# negative_slot_mapping_tensor = self._list_to_long_tensor( -# negative_slot_mapping) -# else: -# # Decode phase. -# negative_input_tokens_tensor = self._empty_long_tensor() -# negative_input_positions_tensor = self._empty_long_tensor() -# negative_slot_mapping_tensor = self._empty_long_tensor() - -# negative_block_tables = [] -# for seq_group_metadata in seq_group_metadata_list: -# negative_seq_lens.append( -# seq_group_metadata.negative_seq_data.get_len()) -# negative_block_table = seq_group_metadata.negative_block_table -# negative_block_tables.append([] if ( -# negative_block_table is None) else negative_block_table) - -# negative_block_tables = make_tensor_with_pad( -# negative_block_tables, -# max_len=max( -# len(block_table) for block_table in negative_block_tables), -# pad=0, -# dtype=torch.int32, -# device=self.device, -# ) - -# max_negative_seq_len = max(negative_seq_lens, default=0) -# negative_seq_lens_tensor = self._list_to_int32_tensor(negative_seq_lens) -# negative_seq_start_loc = torch.zeros(negative_seq_lens_tensor.shape[0] + -# 1, -# dtype=torch.int32, -# device=self.device) -# torch.cumsum(negative_seq_lens_tensor, -# dim=0, -# dtype=negative_seq_start_loc.dtype, -# out=negative_seq_start_loc[1:]) - -# attn_metadata = model_input.attn_metadata -# assert attn_metadata is not None -# ( -# attn_metadata.num_negative_tokens, -# attn_metadata.negative_seq_lens, -# attn_metadata.negative_seq_lens_tensor, -# attn_metadata.max_negative_seq_len, -# attn_metadata.negative_slot_mapping, -# attn_metadata.negative_block_tables, -# ) = ( -# sum(negative_seq_lens), -# negative_seq_lens, -# negative_seq_lens_tensor, -# max_negative_seq_len, -# negative_slot_mapping_tensor, -# negative_block_tables, -# ) - -# return (attn_metadata, negative_input_tokens_tensor, -# negative_input_positions_tensor) diff --git a/vllm/classifier_free_guidance/cfg_worker.py b/vllm/classifier_free_guidance/cfg_worker.py index 24c936dbecb13..f36367ffc2348 100644 --- a/vllm/classifier_free_guidance/cfg_worker.py +++ b/vllm/classifier_free_guidance/cfg_worker.py @@ -95,10 +95,6 @@ def execute_model( execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: - # print("==> request positive :") - # for seq_group_metadata in execute_model_req.seq_group_metadata_list: - # print("seq_group_metadata:", seq_group_metadata) - # prepare negative request with shallow copy if execute_model_req 
is not None: negative_seq_group_metadata_list: List[SequenceGroupMetadata] = [] @@ -110,16 +106,6 @@ def execute_model( assert len(seq_group_metadata.seq_data) == 1 for seq_id in seq_group_metadata.seq_data.keys(): negative_seq_data[seq_id] = seq_group_metadata.negative_seq_data - # negative_seq_data[seq_id] = SequenceData( - # _prompt_token_ids=seq_group_metadata.negative_seq_data.prompt_token_ids_array, - # _output_token_ids=seq_data._output_token_ids, - # _cumulative_logprob=seq_data._cumulative_logprob, - # _prompt_token_ids_tuple=seq_group_metadata.negative_seq_data.prompt_token_ids, - # _num_computed_tokens=seq_data._num_computed_tokens, - # _stage=seq_data.stage, - # _cached_all_token_ids=seq_data._cached_all_token_ids, - # _new_appended_tokens=seq_data._new_appended_tokens, - # ) negative_block_tables[seq_id] = seq_group_metadata.negative_block_table if negative_seq_group_metadata.is_prompt: @@ -134,10 +120,6 @@ def execute_model( else: negative_excute_model_req = None - # print("==> request negative:") - # for seq_group_metadata in negative_excute_model_req.seq_group_metadata_list: - # print("seq_group_metadata:", seq_group_metadata) - inputs = self.root_worker.prepare_input(execute_model_req) negative_inputs = self.guidance_worker.prepare_input(negative_excute_model_req) if inputs is None: From 927673c4bc41efdba636ff438e510528f3872dd5 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Fri, 23 Aug 2024 14:16:07 +0800 Subject: [PATCH 15/18] revert llama --- vllm/model_executor/models/llama.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f3aff37577779..0c67a9b8e198b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,13 +30,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, tensor_model_parallel_gather) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor, _prune_hidden_states +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( @@ -413,9 +413,8 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, - logit_scale, logits_as_input=True) + logit_scale) self.sampler = Sampler() - self.org_vocab_size = config.vocab_size else: self.lm_head = PPMissingLayer() @@ -431,23 +430,6 @@ def forward( attn_metadata, intermediate_tensors) return model_output - def _get_logits(self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) - # Get the logits for the next tokens. - logits = self.lm_head.linear_method.apply( - self.lm_head, - hidden_states, - bias=None, - ) - logits = tensor_model_parallel_gather(logits) - # Remove paddings in vocab (if any). 
- if logits is not None: - logits = logits[:, :self.org_vocab_size] - return logits - def compute_logits( self, hidden_states: torch.Tensor, From a1be553ba7c4438d687d4e9a915de810f9c52187 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Fri, 23 Aug 2024 18:17:05 +0800 Subject: [PATCH 16/18] add llama --- vllm/model_executor/models/llama.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0c67a9b8e198b..f3aff37577779 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,13 +30,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, tensor_model_parallel_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.logits_processor import LogitsProcessor, _prune_hidden_states from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( @@ -413,8 +413,9 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, - logit_scale) + logit_scale, logits_as_input=True) self.sampler = Sampler() + self.org_vocab_size = config.vocab_size else: self.lm_head = PPMissingLayer() @@ -430,6 +431,23 @@ def forward( attn_metadata, intermediate_tensors) return model_output + def _get_logits(self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + # Get the logits for the next tokens. + logits = self.lm_head.linear_method.apply( + self.lm_head, + hidden_states, + bias=None, + ) + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + def compute_logits( self, hidden_states: torch.Tensor, From 97d6219ed0230d414b29d3289ee2781a15cac827 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Mon, 26 Aug 2024 15:29:36 +0800 Subject: [PATCH 17/18] fix free block and add negative for async engine --- vllm/core/scheduler.py | 1 + vllm/engine/async_llm_engine.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca2cfb3b841d9..fd5a908dd67ef 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1192,6 +1192,7 @@ def free_finished_seq_groups(self) -> None: if seq_group.is_finished(): # Free cross-attention block table, if it exists self._free_seq_group_cross_attn_blocks(seq_group) + self._free_seq_group_negative_blocks(seq_group) # Add the finished requests to the finished requests list. # This list will be used to update the Mamba cache in the # next step. 
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6385d3ca2297e..6b2b40698f88c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -449,6 +449,7 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) multi_modal_data = None + negative_prompt = negative_prompt_token_ids = None elif isinstance(inputs, dict): if "prompt_token_ids" in inputs: prompt = None @@ -462,11 +463,24 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) + if "negative_prompt_token_ids" in inputs: + negative_prompt = None + negative_prompt_token_ids = inputs["negative_prompt_token_ids"] + elif "negative_prompt" in inputs: + negative_prompt = parsed_negative_prompt = inputs["negative_prompt"] + negative_prompt_token_ids = await self._tokenize_prompt_async( + parsed_negative_prompt, + request_id=request_id, + lora_request=lora_request, + ) + else: + negative_prompt = negative_prompt_token_ids = None + multi_modal_data = inputs.get("multi_modal_data") else: assert_never(inputs) - return prompt, prompt_token_ids, multi_modal_data + return prompt, prompt_token_ids, multi_modal_data, negative_prompt, negative_prompt_token_ids async def _process_encoder_decoder_prompt_async( self, From fc22dddfb1657d49cd852f0eb99c8ec84ed3911b Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Fri, 30 Aug 2024 10:11:41 +0800 Subject: [PATCH 18/18] add __init__ --- vllm/classifier_free_guidance/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/classifier_free_guidance/__init__.py diff --git a/vllm/classifier_free_guidance/__init__.py b/vllm/classifier_free_guidance/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d
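Reviewer note (not part of the patches above): the series plumbs a second "negative" sequence through the engine (the `negative_prompt` / `negative_prompt_token_ids` input keys, a `negative_seq` on `SequenceGroup`, separate negative block tables, and a CFG worker that builds a mirrored request for the negative prompt), but the hunk that actually blends the positive and negative logits is not included in this excerpt. The sketch below shows the conventional classifier-free guidance combination on raw logits; it is an illustration only, and `guidance_scale` and the function name are assumptions rather than identifiers confirmed by this PR.

import torch
import torch.nn.functional as F

def combine_cfg_logits(positive_logits: torch.Tensor,
                       negative_logits: torch.Tensor,
                       guidance_scale: float) -> torch.Tensor:
    """Blend conditional (positive-prompt) and unconditional (negative-prompt)
    logits with the standard classifier-free guidance rule."""
    # Compare the two distributions in log-probability space.
    pos = F.log_softmax(positive_logits, dim=-1)
    neg = F.log_softmax(negative_logits, dim=-1)
    # guidance_scale == 1.0 reproduces the positive distribution; larger
    # values push the result further away from the negative prompt.
    return neg + guidance_scale * (pos - neg)

if __name__ == "__main__":
    # Toy shapes: batch of 2 sequences over a 16-token vocabulary.
    pos = torch.randn(2, 16)
    neg = torch.randn(2, 16)
    print(combine_cfg_logits(pos, neg, guidance_scale=1.5).shape)

If the user-facing API follows the input keys handled in async_llm_engine.py, a request would presumably pass {"prompt": ..., "negative_prompt": ...} as the inputs dict, with the guidance strength carried in SamplingParams (the sampling_params.py hunk is not shown in this excerpt).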