Fix formatting.
chongmni-aws committed Sep 23, 2024
1 parent 5b15061 commit e51331c
Showing 2 changed files with 95 additions and 61 deletions.
21 changes: 13 additions & 8 deletions vllm/model_executor/model_loader/neuron.py
@@ -1,8 +1,8 @@
"""Utilities for selecting and loading neuron models."""
import copy
import importlib
import os
from typing import Dict, List, Optional, Tuple
import copy

import torch
import torch.nn as nn
@@ -14,8 +14,8 @@
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SequenceOutput, CompletionSequenceGroupOutput, Logprob
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput)

TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
@@ -81,11 +81,16 @@ def sample(
samples = []
for seq_id in seq_group.seq_ids:
token_id = hidden_states[sample_idx].item()
samples.append(SequenceOutput(parent_seq_id=seq_id, output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
samples.append(
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
sample_idx += 1
next_tokens.append(CompletionSequenceGroupOutput(samples=samples, prompt_logprobs=None))
return next_tokens
next_tokens.append(
CompletionSequenceGroupOutput(samples=samples,
prompt_logprobs=None))

return SamplerOutput(outputs=next_tokens)

def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
@@ -171,7 +176,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
if model_config.quantization else None,
continuous_batching=continuous_batching_config,
weight_tiling=bool(model_config.quantization),
on_device_generation = copy.deepcopy(model_config.generation_config))
on_device_generation=copy.deepcopy(model_config.generation_config))
return default_neuron_args


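The reworked `sample` hunk above wraps one `SequenceOutput` per sequence into `CompletionSequenceGroupOutput` objects and returns a `SamplerOutput` instead of a bare list. A minimal sketch of that assembly, using only the classes imported in the diff; the sequence ids and token ids below are made-up illustration values, and the token id is reused as a placeholder logprob exactly as in the hunk:

```python
# Sketch only: mirrors the output assembly in the sample() hunk above.
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SequenceOutput)


def build_sampler_output(groups):
    """groups: list of (seq_ids, token_ids) pairs, one per sequence group."""
    outputs = []
    for seq_ids, token_ids in groups:
        samples = [
            SequenceOutput(parent_seq_id=seq_id,
                           output_token=token_id,
                           # Placeholder logprob keyed by token id, as in the diff.
                           logprobs={token_id: Logprob(token_id)})
            for seq_id, token_id in zip(seq_ids, token_ids)
        ]
        outputs.append(
            CompletionSequenceGroupOutput(samples=samples,
                                          prompt_logprobs=None))
    return SamplerOutput(outputs=outputs)


# Hypothetical usage: two sequence groups with one sequence each.
sampler_output = build_sampler_output([([0], [42]), ([1], [7])])
```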
135 changes: 82 additions & 53 deletions vllm/worker/neuron_model_runner.py
@@ -5,6 +5,7 @@

import torch
from torch import nn
from transformers_neuronx.config import GenerationConfig

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
@@ -17,7 +18,7 @@
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
from transformers_neuronx.config import GenerationConfig

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

@@ -29,14 +30,15 @@ class ModelInputForNeuron(ModelRunnerInputBase):
"""
Used by the NeuronModelRunner.
"""

input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None

def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
self, ) -> Dict[str, Union[int, torch.Tensor]]:
raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")

@classmethod
@@ -71,28 +73,29 @@ def __init__(
self.pin_memory = is_pin_memory_available()

# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY.create_input_mapper(
self.model_config)

# Lazy initialization.
self.model: nn.Module # initialize after load_model.
self.model_config.generation_config = GenerationConfig(
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k = [1] * self.scheduler_config.max_num_seqs,
top_p = [1] * self.scheduler_config.max_num_seqs,
temperature = [1] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=256
)
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k=[1] * self.scheduler_config.max_num_seqs,
top_p=[1] * self.scheduler_config.max_num_seqs,
temperature=[1] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=256,
)

def load_model(self) -> None:
if find_spec("transformers_neuronx") is not None:
self.model = get_neuron_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
scheduler_config=self.scheduler_config,
)
else:
raise NotImplementedError(
"Supports only Transformer-NeuronX based models.")
@@ -136,24 +139,33 @@ def _prepare_prompt(

max_seq_len = max(seq_lens)
assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device)
input_tokens = make_tensor_with_pad(
input_tokens,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device,
)
input_positions = make_tensor_with_pad(
input_positions,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device,
)
input_block_ids = torch.tensor(input_block_ids,
dtype=torch.long,
device=self.device)

multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs)
return (
input_tokens,
input_positions,
input_block_ids,
seq_lens,
multi_modal_kwargs,
)
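`_prepare_prompt` above pads the ragged per-sequence token and position lists to the longest prompt in the batch. A small sketch of the same `make_tensor_with_pad` call, with made-up token ids and the device hard-coded to CPU for illustration:

```python
# Sketch: pad variable-length prompt token lists into a dense [batch, max_seq_len]
# tensor, as _prepare_prompt does above. Token ids are made up.
import torch

from vllm.utils import make_tensor_with_pad

input_tokens = [[11, 22, 33], [44, 55]]   # two prompts, lengths 3 and 2
max_seq_len = max(len(tokens) for tokens in input_tokens)

padded = make_tensor_with_pad(
    input_tokens,
    pad=0,                  # shorter prompts are right-padded with token id 0
    max_len=max_seq_len,
    dtype=torch.long,
    device="cpu",           # the runner passes self.device instead
)
# padded -> tensor([[11, 22, 33],
#                   [44, 55,  0]])
```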

def _prepare_decode(
self,
Expand Down Expand Up @@ -190,11 +202,13 @@ def _prepare_decode(
max_len=1,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
pad=0,
max_len=1,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(
input_positions,
pad=0,
max_len=1,
dtype=torch.long,
device=self.device,
)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
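In `_prepare_decode` above each running sequence contributes exactly one new token, so the same helper is called with `max_len=1` and produces a `[batch, 1]` tensor. A minimal sketch under the same assumptions as the prompt example:

```python
# Sketch: decode-time inputs are one token per sequence, padded to max_len=1.
import torch

from vllm.utils import make_tensor_with_pad

last_tokens = [[101], [102], [103]]   # last generated token of each sequence
decode_tokens = make_tensor_with_pad(
    last_tokens,
    pad=0,
    max_len=1,
    dtype=torch.long,
    device="cpu",                     # the runner passes self.device
)
# decode_tokens.shape -> torch.Size([3, 1])
```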
@@ -212,20 +226,24 @@ def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForNeuron:
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
(
input_tokens,
input_positions,
input_block_ids,
seq_lens,
multi_modal_kwargs,
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
input_block_ids) = (self._prepare_decode(seq_group_metadata_list))
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
@@ -236,37 +254,48 @@ def prepare_model_input(
seq_lens,
self.device,
self.pin_memory,
generators=self.get_generators(finished_requests_ids))
generators=self.get_generators(finished_requests_ids),
)

return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs)
return ModelInputForNeuron(
input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs,
)

def _update_neuron_generation_config(self, sampling_metadata):
# Update Neuron Generation Config
assert self.model_config.generation_config is not None, f"Failed to update generation_config, \
current generation config is {self.model_config.generation_config}"

current_generation_config = self.model_config.generation_config
assert current_generation_config is not None, (
f"Failed to update generation_config, "
f"current generation config is {current_generation_config}")

top_k = current_generation_config.top_k.copy()
top_p = current_generation_config.top_p.copy()
temperature = current_generation_config.temperature.copy()
for index, sequence_group_to_sample in enumerate(sampling_metadata.seq_groups):
for index, sequence_group_to_sample in enumerate(
sampling_metadata.seq_groups):
top_k[index] = sequence_group_to_sample.sampling_params.top_k
top_p[index] = sequence_group_to_sample.sampling_params.top_p
temperature[index] = sequence_group_to_sample.sampling_params.temperature

# We only update the generation config if the new config is different
# This avoids calling update_generation_config for every token within the same sequence
if top_k != current_generation_config.top_k or top_p != current_generation_config.top_p or temperature != current_generation_config.temperature:
temperature[index] = (
sequence_group_to_sample.sampling_params.temperature)

# Only update the generation config if the new config is different.
# This avoids calling update_generation_config for every token within
# the same sequence.
if (top_k != current_generation_config.top_k
or top_p != current_generation_config.top_p
or temperature != current_generation_config.temperature):
current_generation_config.top_k = top_k
current_generation_config.top_p = top_p
current_generation_config.temperature = temperature
self.model_config.generation_config = copy.deepcopy(current_generation_config)
self.model_config.generation_config = copy.deepcopy(
current_generation_config)

self.model.model.update_generation_config(current_generation_config)
self.model.model.update_generation_config(
current_generation_config)

@torch.inference_mode()
def execute_model(
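`_update_neuron_generation_config` above copies the current per-sequence `top_k`/`top_p`/`temperature` lists, overwrites the entries for the sequence groups being sampled, and only pushes a new config to the Neuron model when something actually changed. A self-contained sketch of that change-detection pattern; the config class and the per-sequence parameter dicts below are simplified stand-ins, not the real vLLM/Neuron types:

```python
# Sketch of the "update only when changed" pattern used above. A plain dataclass
# stands in for the transformers_neuronx GenerationConfig.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class StubGenerationConfig:
    top_k: List[int] = field(default_factory=lambda: [1, 1])
    top_p: List[float] = field(default_factory=lambda: [1.0, 1.0])
    temperature: List[float] = field(default_factory=lambda: [1.0, 1.0])


def maybe_update_generation_config(config: StubGenerationConfig,
                                   per_seq_params: List[Dict[str, float]]) -> bool:
    """Overwrite per-sequence sampling params; report whether an update is needed."""
    top_k = config.top_k.copy()
    top_p = config.top_p.copy()
    temperature = config.temperature.copy()
    for index, params in enumerate(per_seq_params):
        top_k[index] = params["top_k"]
        top_p[index] = params["top_p"]
        temperature[index] = params["temperature"]

    if (top_k != config.top_k or top_p != config.top_p
            or temperature != config.temperature):
        config.top_k = top_k
        config.top_p = top_p
        config.temperature = temperature
        return True   # caller would now deep-copy the config and push it to the model
    return False      # unchanged: skip the per-token device update


cfg = StubGenerationConfig()
changed = maybe_update_generation_config(
    cfg, [{"top_k": 5, "top_p": 0.9, "temperature": 0.7},
          {"top_k": 1, "top_p": 1.0, "temperature": 1.0}])
```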
