From e51331cc7d4a4b26fc4bef921022f2856c267238 Mon Sep 17 00:00:00 2001
From: Chongming Ni
Date: Mon, 23 Sep 2024 21:29:21 +0000
Subject: [PATCH] Fix formatting.

---
 vllm/model_executor/model_loader/neuron.py |  21 ++--
 vllm/worker/neuron_model_runner.py         | 135 +++++++++++++--------
 2 files changed, 95 insertions(+), 61 deletions(-)

diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index 21c1b0d996686..d156684b8209d 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -1,8 +1,8 @@
 """Utilities for selecting and loading neuron models."""
+import copy
 import importlib
 import os
 from typing import Dict, List, Optional, Tuple
-import copy
 
 import torch
 import torch.nn as nn
@@ -14,8 +14,8 @@
 from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SequenceOutput, CompletionSequenceGroupOutput, Logprob
-from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           SequenceOutput)
 
 TORCH_DTYPE_TO_NEURON_AMP = {
     "auto": "f32",
@@ -81,11 +81,16 @@ def sample(
             samples = []
             for seq_id in seq_group.seq_ids:
                 token_id = hidden_states[sample_idx].item()
-                samples.append(SequenceOutput(parent_seq_id=seq_id, output_token=token_id,
-                                              logprobs={token_id: Logprob(token_id)}))
+                samples.append(
+                    SequenceOutput(parent_seq_id=seq_id,
+                                   output_token=token_id,
+                                   logprobs={token_id: Logprob(token_id)}))
                 sample_idx += 1
-            next_tokens.append(CompletionSequenceGroupOutput(samples=samples, prompt_logprobs=None))
-        return next_tokens
+            next_tokens.append(
+                CompletionSequenceGroupOutput(samples=samples,
+                                              prompt_logprobs=None))
+
+        return SamplerOutput(outputs=next_tokens)
 
     def load_weights(self, model_name_or_path: str, **kwargs):
         arch = _get_model_architecture(self.config)
@@ -171,7 +176,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
         if model_config.quantization else None,
         continuous_batching=continuous_batching_config,
         weight_tiling=bool(model_config.quantization),
-        on_device_generation = copy.deepcopy(model_config.generation_config))
+        on_device_generation=copy.deepcopy(model_config.generation_config))
     return default_neuron_args
 
 
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index f4a2ebd6613bc..af8045bcd47e7 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -5,6 +5,7 @@
 
 import torch
 from torch import nn
+from transformers_neuronx.config import GenerationConfig
 
 from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -17,7 +18,7 @@
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
-from transformers_neuronx.config import GenerationConfig
+
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -29,6 +30,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
     """
     Used by the NeuronModelRunner.
""" + input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None input_block_ids: Optional[torch.Tensor] = None @@ -36,7 +38,7 @@ class ModelInputForNeuron(ModelRunnerInputBase): multi_modal_kwargs: Optional[BatchedTensorInputs] = None def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: + self, ) -> Dict[str, Union[int, torch.Tensor]]: raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") @classmethod @@ -71,28 +73,29 @@ def __init__( self.pin_memory = is_pin_memory_available() # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ - .create_input_mapper(self.model_config) + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY.create_input_mapper( + self.model_config) # Lazy initialization. self.model: nn.Module # initialize after load_model. self.model_config.generation_config = GenerationConfig( - max_length=self.scheduler_config.max_model_len, - do_sample=True, - per_batch_line=True, - top_k = [1] * self.scheduler_config.max_num_seqs, - top_p = [1] * self.scheduler_config.max_num_seqs, - temperature = [1] * self.scheduler_config.max_num_seqs, - dynamic=True, - global_top_k=256 - ) + max_length=self.scheduler_config.max_model_len, + do_sample=True, + per_batch_line=True, + top_k=[1] * self.scheduler_config.max_num_seqs, + top_p=[1] * self.scheduler_config.max_num_seqs, + temperature=[1] * self.scheduler_config.max_num_seqs, + dynamic=True, + global_top_k=256, + ) def load_model(self) -> None: if find_spec("transformers_neuronx") is not None: self.model = get_neuron_model( self.model_config, parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + ) else: raise NotImplementedError( "Supports only Transformer-NeuronX based models.") @@ -136,24 +139,33 @@ def _prepare_prompt( max_seq_len = max(seq_lens) assert max_seq_len > 0 - input_tokens = make_tensor_with_pad(input_tokens, - pad=0, - max_len=max_seq_len, - dtype=torch.long, - device=self.device) - input_positions = make_tensor_with_pad(input_positions, - pad=0, - max_len=max_seq_len, - dtype=torch.long, - device=self.device) + input_tokens = make_tensor_with_pad( + input_tokens, + pad=0, + max_len=max_seq_len, + dtype=torch.long, + device=self.device, + ) + input_positions = make_tensor_with_pad( + input_positions, + pad=0, + max_len=max_seq_len, + dtype=torch.long, + device=self.device, + ) input_block_ids = torch.tensor(input_block_ids, dtype=torch.long, device=self.device) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) - return (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs) + return ( + input_tokens, + input_positions, + input_block_ids, + seq_lens, + multi_modal_kwargs, + ) def _prepare_decode( self, @@ -190,11 +202,13 @@ def _prepare_decode( max_len=1, dtype=torch.long, device=self.device) - input_positions = make_tensor_with_pad(input_positions, - pad=0, - max_len=1, - dtype=torch.long, - device=self.device) + input_positions = make_tensor_with_pad( + input_positions, + pad=0, + max_len=1, + dtype=torch.long, + device=self.device, + ) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -212,7 +226,7 @@ def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None, ) -> ModelInputForNeuron: multi_modal_kwargs = None # NOTE: We 
@@ -220,12 +234,16 @@
         is_prompt = seq_group_metadata_list[0].is_prompt
         # Prepare input tensors.
         if is_prompt:
-            (input_tokens, input_positions, input_block_ids, seq_lens,
-             multi_modal_kwargs
-             ) = self._prepare_prompt(seq_group_metadata_list)
+            (
+                input_tokens,
+                input_positions,
+                input_block_ids,
+                seq_lens,
+                multi_modal_kwargs,
+            ) = self._prepare_prompt(seq_group_metadata_list)
         else:
             (input_tokens, input_positions,
-             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
+             input_block_ids) = (self._prepare_decode(seq_group_metadata_list))
             seq_lens = []
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
@@ -236,37 +254,48 @@
             seq_lens,
             self.device,
             self.pin_memory,
-            generators=self.get_generators(finished_requests_ids))
+            generators=self.get_generators(finished_requests_ids),
+        )
 
-        return ModelInputForNeuron(input_tokens=input_tokens,
-                                   input_positions=input_positions,
-                                   input_block_ids=input_block_ids,
-                                   sampling_metadata=sampling_metadata,
-                                   multi_modal_kwargs=multi_modal_kwargs)
+        return ModelInputForNeuron(
+            input_tokens=input_tokens,
+            input_positions=input_positions,
+            input_block_ids=input_block_ids,
+            sampling_metadata=sampling_metadata,
+            multi_modal_kwargs=multi_modal_kwargs,
+        )
 
     def _update_neuron_generation_config(self, sampling_metadata):
         # Update Neuron Genetation Config
-        assert self.model_config.generation_config is not None, f"Failed to update generation_config, \
-            current generation config is {self.model_config.generation_config}"
-
         current_generation_config = self.model_config.generation_config
+        assert current_generation_config is not None, (
+            f"Failed to update generation_config, "
+            f"current generation config is {current_generation_config}")
+
         top_k = current_generation_config.top_k.copy()
         top_p = current_generation_config.top_p.copy()
         temperature = current_generation_config.temperature.copy()
-        for index, sequence_group_to_sample in enumerate(sampling_metadata.seq_groups):
+        for index, sequence_group_to_sample in enumerate(
+                sampling_metadata.seq_groups):
             top_k[index] = sequence_group_to_sample.sampling_params.top_k
             top_p[index] = sequence_group_to_sample.sampling_params.top_p
-            temperature[index] = sequence_group_to_sample.sampling_params.temperature
-
-        # We only call update the generation config is the new config is different
-        # This avoids calling update_generation_config for every token within the same sequence
-        if top_k != current_generation_config.top_k or top_p != current_generation_config.top_p or temperature != current_generation_config.temperature:
+            temperature[index] = (
+                sequence_group_to_sample.sampling_params.temperature)
+
+        # Only update the generation config if the new config is different.
+        # This avoids calling update_generation_config for every token within
+        # the same sequence.
+        if (top_k != current_generation_config.top_k
+                or top_p != current_generation_config.top_p
+                or temperature != current_generation_config.temperature):
             current_generation_config.top_k = top_k
             current_generation_config.top_p = top_p
             current_generation_config.temperature = temperature
-            self.model_config.generation_config = copy.deepcopy(current_generation_config)
+            self.model_config.generation_config = copy.deepcopy(
+                current_generation_config)
 
-            self.model.model.update_generation_config(current_generation_config)
+            self.model.model.update_generation_config(
+                current_generation_config)
 
     @torch.inference_mode()
     def execute_model(
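
The last hunk only pushes a new on-device generation config when a request's sampling parameters actually differ from the batch-wide lists. A minimal, standalone sketch of that compare-then-update pattern follows; SeqParams and merge_sampling_params are illustrative stand-ins, not vLLM or transformers-neuronx APIs.

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SeqParams:
    """Illustrative stand-in for one sequence's sampling parameters."""
    top_k: int
    top_p: float
    temperature: float


def merge_sampling_params(
    current_top_k: List[int],
    current_top_p: List[float],
    current_temperature: List[float],
    per_seq_params: List[SeqParams],
) -> Tuple[bool, List[int], List[float], List[float]]:
    """Overlay per-sequence params onto the batch-wide lists.

    Returns (changed, top_k, top_p, temperature). `changed` is False when the
    merged lists equal the current ones, so the caller can skip re-sending the
    generation config to the device for every token of the same sequence.
    """
    top_k = current_top_k.copy()
    top_p = current_top_p.copy()
    temperature = current_temperature.copy()
    for index, params in enumerate(per_seq_params):
        top_k[index] = params.top_k
        top_p[index] = params.top_p
        temperature[index] = params.temperature
    changed = (top_k != current_top_k or top_p != current_top_p
               or temperature != current_temperature)
    return changed, top_k, top_p, temperature


if __name__ == "__main__":
    # Two scheduler slots, both still at the defaults ([1], [1.0], [1.0]).
    changed, *_ = merge_sampling_params(
        [1, 1], [1.0, 1.0], [1.0, 1.0],
        [SeqParams(1, 1.0, 1.0), SeqParams(1, 1.0, 1.0)])
    assert not changed  # identical params -> no device-side update needed

Returning the merged lists alongside the flag keeps the check a cheap list comparison and leaves the live config untouched until the caller decides to apply it, which is the same reasoning the patch's comment gives for guarding update_generation_config.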