Configure pooler through PoolerConfig
Signed-off-by: Went-Liang <[email protected]>
Went-Liang committed Oct 28, 2024
1 parent 0d80f63 commit 3e2c7f4
Showing 7 changed files with 183 additions and 47 deletions.
111 changes: 81 additions & 30 deletions vllm/config.py
@@ -112,38 +112,58 @@ class ModelConfig:
Defaults to 'auto' which defaults to 'hf'.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
pooling_type: Used to configure the pooling method in the embedding
model.
pooling_norm: Whether to normalize the pooled data in the
embedding model.
pooling_softmax: Whether to apply softmax to the pooled data in
the embedding model.
pooling_step_tag_id: When pooling_step_tag_id is not -1, only the
scores at positions whose token ID equals pooling_step_tag_id
in the generated sentence are returned. Otherwise, the scores
for all tokens are returned.
pooling_returned_token_ids: A list of indices into the vocabulary
dimension to extract, such as the token IDs of good_token and
bad_token in the math-shepherd-mistral-7b-prm model.
"""

def __init__(self,
model: str,
task: Union[TaskOption, _Task],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
def __init__(
self,
model: str,
task: Union[TaskOption, _Task],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
pooling_type: Optional[str] = None,
pooling_norm: bool = False,
pooling_softmax: bool = False,
pooling_step_tag_id: int = -1,
pooling_returned_token_ids: Optional[List[int]] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -224,6 +244,13 @@ def __init__(self,
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
self.task: Final = task
self.pooler_config = self._init_pooler_config(
pooling_type,
pooling_norm,
pooling_softmax,
pooling_step_tag_id,
pooling_returned_token_ids,
)

self._verify_quantization()
self._verify_cuda_graph()
@@ -242,6 +269,19 @@ def _init_multimodal_config(

return None

def _init_pooler_config(
self, pooling_type, pooling_norm, pooling_softmax,
pooling_step_tag_id,
pooling_returned_token_ids) -> Optional["PoolerConfig"]:
if self.task == "embedding":
return PoolerConfig(
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids)
return None

def _init_attention_free(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_attention_free_model(architectures)
@@ -1647,6 +1687,17 @@ class MultiModalConfig:
# TODO: Add configs to init vision tower or not.


@dataclass
class PoolerConfig:
"""Controls the behavior of pooler in embedding model"""

pooling_type: Optional[str] = None
pooling_norm: bool = False
pooling_softmax: bool = False
pooling_step_tag_id: int = -1
pooling_returned_token_ids: Optional[List[int]] = None


_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
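For reference, here is a minimal sketch, assuming this commit is installed, of the `PoolerConfig` that `_init_pooler_config` builds when the task is `embedding`. The token IDs are illustrative; 648 and 387 mirror the defaults the old `RETURNED_TOKEN_IDS` environment variable used.

```python
from vllm.config import PoolerConfig

# Illustrative values only; -1 keeps the scores for all tokens.
pooler_config = PoolerConfig(
    pooling_type="STEP",
    pooling_norm=False,
    pooling_softmax=False,
    pooling_step_tag_id=-1,
    pooling_returned_token_ids=[648, 387],  # e.g. good/bad token IDs
)
print(pooler_config)
```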
48 changes: 48 additions & 0 deletions vllm/engine/arg_utils.py
@@ -184,6 +184,13 @@ class EngineArgs:
mm_processor_kwargs: Optional[Dict[str, Any]] = None
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

# Pooling configuration.
pooling_type: Optional[str] = None
pooling_norm: bool = False
pooling_softmax: bool = False
pooling_step_tag_id: int = -1
pooling_returned_token_ids: Optional[List[int]] = None

def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
@@ -850,6 +857,42 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).')

parser.add_argument(
'--pooling-type',
choices=['LAST', 'ALL', 'CLS', 'STEP'],
default="LAST",
help='Used to configure the pooling method in the embedding model.'
)

parser.add_argument('--pooling-norm',
action='store_true',
help="Whether to normalize the pooled data "
"in the embedding model.")

parser.add_argument('--pooling-softmax',
action='store_true',
help="Whether to apply softmax to the pooled "
"data in the embedding model.")

parser.add_argument(
'--pooling-step-tag-id',
type=int,
default=-1,
help="When pooling-step-tag-id is not -1, only the scores at "
"positions whose token ID equals pooling-step-tag-id in the "
"generated sentence are returned. Otherwise, the scores for "
"all tokens are returned.")

parser.add_argument(
'--pooling-returned-token-ids',
nargs='+',
type=int,
default=None,
help="A list of indices into the vocabulary dimension to "
"extract, such as the token IDs of good_token and bad_token "
"in the math-shepherd-mistral-7b-prm model.")

return parser

@classmethod
@@ -891,6 +934,11 @@ def create_model_config(self) -> ModelConfig:
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
pooling_type=self.pooling_type,
pooling_norm=self.pooling_norm,
pooling_softmax=self.pooling_softmax,
pooling_step_tag_id=self.pooling_step_tag_id,
pooling_returned_token_ids=self.pooling_returned_token_ids,
)

def create_load_config(self) -> LoadConfig:
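Assuming this commit, the new flags map one-to-one onto `EngineArgs` fields. A hedged sketch follows; the model name is a placeholder taken from the help text, not a tested checkpoint:

```python
from vllm.engine.arg_utils import EngineArgs

# Equivalent to: --task embedding --pooling-type STEP \
#                --pooling-returned-token-ids 648 387
engine_args = EngineArgs(
    model="math-shepherd-mistral-7b-prm",  # placeholder model name
    task="embedding",
    pooling_type="STEP",
    pooling_returned_token_ids=[648, 387],
)
model_config = engine_args.create_model_config()
# pooler_config is only populated for the embedding task.
assert model_config.pooler_config is not None
```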
9 changes: 8 additions & 1 deletion vllm/engine/llm_engine.py
@@ -256,7 +256,9 @@ def __init__(
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"chat_template_text_format=%s, mm_processor_kwargs=%s)",
"chat_template_text_format=%s, mm_processor_kwargs=%s, "
"pooling_type=%s, pooling_norm=%s, pooling_softmax=%s, "
"pooling_step_tag_id=%s, pooling_returned_token_ids=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
@@ -293,6 +295,11 @@ def __init__(
use_cached_outputs,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs,
model_config.pooler_config,
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
10 changes: 10 additions & 0 deletions vllm/entrypoints/llm.py
@@ -159,6 +159,11 @@ def __init__(
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
pooling_type: Optional[str] = None,
pooling_norm: bool = False,
pooling_softmax: bool = False,
pooling_step_tag_id: int = -1,
pooling_returned_token_ids: Optional[List[int]] = None,
**kwargs,
) -> None:
'''
@@ -193,6 +198,11 @@
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
mm_processor_kwargs=mm_processor_kwargs,
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
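An end-to-end sketch of the new `LLM` keyword arguments, assuming this commit; the model name and prompt are illustrative:

```python
from vllm import LLM

llm = LLM(
    model="math-shepherd-mistral-7b-prm",  # placeholder model name
    task="embedding",
    pooling_type="STEP",
    pooling_step_tag_id=-1,
    pooling_returned_token_ids=[648, 387],
)
# For embedding tasks, encode() returns pooled outputs per prompt.
outputs = llm.encode("Step 1: 1 + 1 = 2.")
```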
26 changes: 17 additions & 9 deletions vllm/model_executor/layers/pooler.py
@@ -1,5 +1,5 @@
import os
from enum import IntEnum
from typing import List, Optional

import torch
import torch.nn as nn
@@ -30,18 +30,21 @@ class Pooler(nn.Module):
normalize: Whether to normalize the pooled data.
"""

def __init__(self,
pooling_type: PoolingType,
normalize: bool,
softmax: bool = False):
def __init__(
self,
pooling_type: PoolingType,
normalize: bool,
softmax: bool = False,
step_tag_id: int = -1,
returned_token_ids: Optional[List[int]] = None,
):
super().__init__()

self.pooling_type = pooling_type
self.normalize = normalize
self.softmax = softmax
returned_token_ids = os.environ.get('RETURNED_TOKEN_IDS', '648,387')
self.returned_token_ids = list(map(int, returned_token_ids.split(",")))
self.step_tag_id = int(os.environ.get('STEP_TOKEN_ID', -1))
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids

def forward(
self,
@@ -68,7 +71,12 @@ def forward(
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
elif self.pooling_type == PoolingType.STEP:
logits = hidden_states[:, self.returned_token_ids].softmax(dim=-1)
if self.returned_token_ids is not None and len(
self.returned_token_ids) > 0:
logits = hidden_states[:,
self.returned_token_ids].softmax(dim=-1)
else:
logits = hidden_states.softmax(dim=-1)
offset = 0
pooled_data = []
for prompt_len, seq_data_i in zip(
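The STEP branch above reduces to the following standalone computation; shapes and token IDs are illustrative:

```python
import torch

hidden_states = torch.randn(7, 32000)  # (num_tokens, vocab_size) scores
returned_token_ids = [648, 387]        # e.g. good/bad token IDs

# With returned_token_ids set, softmax runs over the selected columns only.
logits = hidden_states[:, returned_token_ids].softmax(dim=-1)
print(logits.shape)  # torch.Size([7, 2])

# Without it, softmax runs over the full vocabulary.
full = hidden_states.softmax(dim=-1)
print(full.shape)  # torch.Size([7, 32000])
```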
15 changes: 10 additions & 5 deletions vllm/model_executor/model_loader/loader.py
@@ -23,7 +23,7 @@

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, SchedulerConfig)
ParallelConfig, PoolerConfig, SchedulerConfig)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
@@ -122,7 +122,8 @@ def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
scheduler_config: Optional[SchedulerConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}

@@ -143,7 +144,8 @@ def _get_model_initialization_kwargs(

if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config

if pooler_config:
extra_kwargs["pooler_config"] = pooler_config
return extra_kwargs


@@ -152,10 +154,12 @@ def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig], *,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
scheduler_config: Optional[SchedulerConfig],
pooler_config: Optional[PoolerConfig]) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config)
scheduler_config,
pooler_config)

return model_class(config=hf_config,
cache_config=cache_config,
@@ -180,6 +184,7 @@ def _initialize_model(
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
pooler_config=model_config.pooler_config,
)


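A small sketch, assuming a dummy model class that implements none of the optional interfaces (LoRA, multimodal, inner state), of how `pooler_config` now joins the extra kwargs:

```python
import torch.nn as nn

from vllm.config import PoolerConfig
from vllm.model_executor.model_loader.loader import (
    _get_model_initialization_kwargs)


class DummyModel(nn.Module):  # stand-in with no optional interfaces
    pass


extra = _get_model_initialization_kwargs(
    DummyModel,
    lora_config=None,
    multimodal_config=None,
    pooler_config=PoolerConfig(pooling_type="LAST"))
print(extra)  # {'pooler_config': PoolerConfig(pooling_type='LAST', ...)}
```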
11 changes: 9 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -29,7 +29,7 @@

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
@@ -500,6 +500,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None:
super().__init__()

@@ -540,7 +541,13 @@ def __init__(
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self._pooler = Pooler(pooling_type=PoolingType.STEP, normalize=False)
self._pooler = Pooler(
pooling_type=PoolingType[pooler_config.pooling_type],
normalize=pooler_config.pooling_norm,
softmax=pooler_config.pooling_softmax,
step_tag_id=pooler_config.pooling_step_tag_id,
returned_token_ids=pooler_config.pooling_returned_token_ids,
)

def forward(
self,
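For clarity, the construction above is equivalent to this standalone sketch. Note that `pooler_config.pooling_type` must name a valid `PoolingType` member (LAST, ALL, CLS, or STEP); the `PoolingType[...]` lookup raises if it is still `None`:

```python
from vllm.config import PoolerConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType

cfg = PoolerConfig(pooling_type="STEP",
                   pooling_returned_token_ids=[648, 387])  # illustrative IDs
pooler = Pooler(
    pooling_type=PoolingType[cfg.pooling_type],  # raises if pooling_type is None
    normalize=cfg.pooling_norm,
    softmax=cfg.pooling_softmax,
    step_tag_id=cfg.pooling_step_tag_id,
    returned_token_ids=cfg.pooling_returned_token_ids,
)
```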
