[Model][Quant] Fix GLM, Fix fused module mappings for quantization (vllm-project#12634)

Signed-off-by: mgoin <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Felix Marty <[email protected]>
2 people authored and fxmarty-amd committed Feb 7, 2025
1 parent 6e0ca10 commit 344c409
Showing 12 changed files with 195 additions and 151 deletions.
vllm/model_executor/layers/quantization/base_config.py (2 additions, 1 deletion)
@@ -2,7 +2,7 @@

 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, List, Mapping, Optional, Type
 
 import torch
 from torch import nn
@@ -59,6 +59,7 @@ def method_has_implemented_embedding(

 class QuantizationConfig(ABC):
     """Base class for quantization configs."""
+    packed_modules_mapping: Mapping[str, List[str]] = dict()
 
     @abstractmethod
     def get_name(self) -> str:
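For context, packed_modules_mapping describes which fused vLLM modules pack which checkpoint submodules; quantization configs can now carry this as a class attribute instead of relying on a hardcoded global table. A minimal sketch of the shape such a mapping typically takes (the entries below follow vLLM's common fused-layer conventions and are illustrative, not part of this diff):

    from types import MappingProxyType
    from typing import List, Mapping

    # Each fused vLLM linear maps to the unfused checkpoint modules it packs.
    packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    })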
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -83,7 +83,9 @@ def get_quant_method(

         # Check if the layer is skipped for quantization.
         # TODO (@robertgshaw2): support module names
-        if should_ignore_layer(prefix, ignore=self.ignore):
+        if should_ignore_layer(prefix,
+                               ignore=self.ignore,
+                               fused_mapping=self.packed_modules_mapping):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
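With the mapping passed through, a fused prefix can be expanded and each shard checked against the ignore list. A small sketch under the assumption of a standard qkv mapping (all names are hypothetical):

    ignore = ["model.layers.0.self_attn.q_proj",
              "model.layers.0.self_attn.k_proj",
              "model.layers.0.self_attn.v_proj"]
    fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    # All three shards are in the ignore list, so the fused layer is skipped.
    should_ignore_layer("model.layers.0.self_attn.qkv_proj",
                        ignore=ignore,
                        fused_mapping=fused_mapping)  # True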
@@ -379,34 +381,29 @@ def get_scheme(self,

         # Will be empty for models with only sparsity
         weight_quant = input_quant = None
-        sparsity_scheme: Optional[SparsityCompressionConfig] = None
         if self.target_scheme_map:
             matched_target = find_matched_target(
                 layer_name=layer_name,
                 module=layer,
-                targets=self.target_scheme_map.keys())
+                targets=self.target_scheme_map.keys(),
+                fused_mapping=self.packed_modules_mapping)
 
             scheme_dict = self.target_scheme_map[matched_target]
             weight_quant = scheme_dict.get("weights")
             input_quant = scheme_dict.get("input_activations")
 
-        if self.sparsity_scheme_map:
-            is_ignored = False
-            with suppress(ValueError):
-                is_ignored = find_matched_target(
-                    layer_name=layer_name,
-                    module=layer,
-                    targets=self.sparsity_ignore_list)
-
-            # if the layer is in the sparsity ignore list,
-            # we should not apply any sparsity scheme
-
-            if not is_ignored:
-                matched_target = find_matched_target(
-                    layer_name=layer_name,
-                    module=layer,
-                    targets=self.sparsity_scheme_map.keys())
-                sparsity_scheme = self.sparsity_scheme_map.get(matched_target)
+        # Find the sparsity scheme of the layer
+        # assume that fused layers inherit first component's sparsity scheme
+        sparsity_targets = (self.sparsity_scheme_map.keys() -
+                            set(self.sparsity_ignore_list))
+        sparsity_scheme: Optional[SparsityCompressionConfig] = None
+        with suppress(ValueError):
+            matched_target = find_matched_target(
+                layer_name=layer_name,
+                module=layer,
+                targets=sparsity_targets,
+                fused_mapping=self.packed_modules_mapping)
+            sparsity_scheme = self.sparsity_scheme_map[matched_target]
 
         if self.supports_cutlass_24(weight_quant=weight_quant,
                                     input_quant=input_quant,
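The rewritten sparsity lookup leans on dict.keys() being a set-like view, so the ignore list can be subtracted directly. An illustrative snippet with made-up targets:

    sparsity_scheme_map = {"Linear": "scheme_a",
                           "model.embed_tokens": "scheme_b"}
    sparsity_ignore_list = ["model.embed_tokens"]
    sparsity_targets = (sparsity_scheme_map.keys() -
                        set(sparsity_ignore_list))
    print(sparsity_targets)  # {'Linear'}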
vllm/model_executor/layers/quantization/compressed_tensors/utils.py (55 additions, 85 deletions)
@@ -1,14 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import re
-from typing import Iterable, Optional
+from types import MappingProxyType
+from typing import Iterable, List, Mapping, Optional
 
 from compressed_tensors import CompressionFormat
 from torch.nn import Module
 
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)
-
 
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
@@ -19,8 +17,11 @@ def is_activation_quantization_format(format: str) -> bool:
     return format in _ACTIVATION_QUANTIZATION_FORMATS
 
 
-def should_ignore_layer(layer_name: Optional[str],
-                        ignore: Iterable[str]) -> bool:
+def should_ignore_layer(
+        layer_name: Optional[str],
+        ignore: Iterable[str] = tuple(),
+        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> bool:
     if layer_name is None:
         return False
 
@@ -32,8 +33,8 @@ def should_ignore_layer(layer_name: Optional[str],
     # in the safetensors checkpoint. So, we convert the name
     # from the fused version to unfused + check to make sure that
     # each shard of the fused layer has the same scheme.
-    if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore:
-        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
+    if proj_name in fused_mapping and layer_name not in ignore:
+        shard_proj_names = fused_mapping[proj_name]
 
         # Convert fused_name --> [shard_names]
         shard_names = [
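The expansion truncated above presumably proceeds by substring replacement on the last path component, along these lines (a sketch with assumed names):

    fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    layer_name = "model.layers.0.self_attn.qkv_proj"
    proj_name = layer_name.split(".")[-1]  # "qkv_proj"
    shard_names = [
        layer_name.replace(proj_name, shard)
        for shard in fused_mapping[proj_name]
    ]
    # ['model.layers.0.self_attn.q_proj',
    #  'model.layers.0.self_attn.k_proj',
    #  'model.layers.0.self_attn.v_proj']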
@@ -79,55 +80,12 @@ def check_equal_or_regex_match(layer_name: str,
     return False
 
 
-def _handle_fused_layers(func):
-    """
-    Decorator to handle fused layers by mapping vllm fused layer names
-    to their corresponding unfused layer names for quantization/pruning schemes.
-    """
-    # fused_layer_name -> unfused_layer_name
-    fused_layer_map = {
-        "qkv_proj": "q_proj",
-        "gate_up_proj": "up_proj",
-    }
-
-    def fused_layer_handler(layer_name: Optional[str], module: Module,
-                            targets: Iterable[str]) -> Optional[str]:
-        """
-        Wrapper function specifically designed to support the
-        find_matched_target function.
-
-        It handles cases where the provided layer name corresponds to a
-        fused layer in vllm, mapping it to its equivalent unfused layer name
-        based on the predefined fused_layer_map. If the original layer name
-        raises a ValueError in the wrapped function, this handler
-        will attempt to resolve the issue by substituting with unfused
-        layer name.
-
-        :param layer_name: Name of the layer, which may be fused.
-        :param module: An instance of torch.nn.Module.
-        :param targets: A list of target names or patterns to match.
-        :return: The result of the wrapped find_matched_target function with
-            the resolved layer name.
-        :raises ValueError: If the layer name cannot be resolved to a
-            valid target.
-        """
-        try:
-            return func(layer_name, module, targets)
-        except ValueError:
-            if layer_name is None:
-                layer_name = ""
-            parent_name, fused_proj_name = layer_name.rsplit(".", 1)
-            unfused_proj_name = fused_layer_map.get(fused_proj_name,
-                                                    fused_proj_name)
-            new_layer_name = f"{parent_name}.{unfused_proj_name}"
-            return func(new_layer_name, module, targets)
-
-    return fused_layer_handler
-
-
-@_handle_fused_layers
-def find_matched_target(layer_name: Optional[str], module: Module,
-                        targets: Iterable[str]) -> str:
+def find_matched_target(
+        layer_name: Optional[str],
+        module: Module,
+        targets: Iterable[str],
+        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> str:
     """
     Helper function to look up which "target" in the compressed-tensors
     config that a layer corresponds to.
@@ -141,19 +99,25 @@ def find_matched_target(layer_name: Optional[str], module: Module,
     First, we try to match the layer_name with a target
     Second, we try to match the module's name with a target
     Third, we try to map the layer_name to a list of fused module names.
+        *All* component module names must match in order for a match to be
+        successful. A successful match returns the first component target
 
     :param layer_name: layer name
     :param module: torch.nn.Module
     :param targets: list of targets to match the layer against
+    :param fused_mapping: map from fused layer names to its components
+    :param fused_strategy: either "all" or "any". If using "all", fused
+        layers match if "all" of its components match
     """
 
     if layer_name is None:
         layer_name = ""
 
-    matched_target = (_find_first_match(layer_name, targets)
-                      or _find_first_match(module.__class__.__name__, targets,
-                                           True)
-                      or _match_fused_layer(layer_name, targets))
+    matched_target = (
+        _find_first_match(layer_name, targets)
+        or _find_first_match(module.__class__.__name__, targets, True)
+        or _match_fused_layer(layer_name, targets, fused_mapping))
 
     if matched_target is None:
         raise ValueError(
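A brief sketch of the three-step lookup order with invented targets (in compressed-tensors configs, a "re:" prefix marks a regex target):

    import torch.nn as nn

    layer = nn.Linear(8, 8)
    # 1) match on the layer name first (regex target)
    find_matched_target("model.layers.0.self_attn.q_proj", layer,
                        targets=["re:.*q_proj$"])  # -> "re:.*q_proj$"
    # 2) otherwise fall back to the module's class name
    find_matched_target("model.norm", layer,
                        targets=["Linear"])  # -> "Linear"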
@@ -205,39 +169,45 @@ def _is_equal_or_regex_match(value: str,
     return False
 
 
-def _match_fused_layer(layer_name: str,
-                       target_layers: Iterable[str]) -> Optional[str]:
+def _match_fused_layer(
+        layer_name: str, target_layers: Iterable[str],
+        fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
     """
     Match a fused layer name to its corresponding individual layer in
-    target_layers.
+    target_layers. Returns first value in fused_mapping which matches targets
+
+    Implements an "all" matching strategy where a fused layer matches iff
+    "all" of its components match
+
+    :param layer_name: layer name
+    :param target_layers: list of targets to match the layer against
+    :param fused_mapping: map from fused layer names to its components
 
     Examples:
         layer_name = "model.layers.0.self_attn.qkv_proj"
         target_layers = ["model.layers.0.self_attn.q_proj",
                          "model.layers.0.self_attn.k_proj",
                          "model.layers.0.self_attn.v_proj"]
     """
-    # Split into parent path and layer type
-    # e.g., "model.layers.0.self_attn" and "qkv_proj"
-    parent_path = ".".join(layer_name.split(".")[:-1])
-    layer_type = layer_name.split(".")[-1]
-
-    if layer_type not in FUSED_LAYER_NAME_MAPPING:
+    # find layer_name in mapping
+    fused = next((key for key in fused_mapping if layer_name.endswith(key)),
+                 None)
+    if fused is None:
         return None
 
-    possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type]
-
-    # Look for a target layer that:
-    # 1. Has the same parent path
-    # 2. Ends with one of the possible individual layer types
-    for target in target_layers:
-        is_same_parent = parent_path in target
-        is_matching_type = any(type_suffix in target
-                               for type_suffix in possible_layer_types)
-
-        if is_same_parent and is_matching_type and all(
-            (f"{parent_path}.{type_suffix}" in target_layers)
-                for type_suffix in possible_layer_types):
-            return target
-
-    return None
+    # expand path of unfused components
+    unfused_paths = [
+        layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
+    ]
+
+    # for each unfused component, find a match in targets
+    unfused_matches: List[Optional[str]] = []
+    for unfused in unfused_paths:
+        for target in target_layers:
+            if _is_equal_or_regex_match(unfused, target):
+                unfused_matches.append(target)
+                break
+        else:
+            unfused_matches.append(None)
+
+    return unfused_matches[0] if all(unfused_matches) else None
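A walkthrough of the "all" strategy using the docstring's example values:

    layer_name = "model.layers.0.self_attn.qkv_proj"
    fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    targets = ["model.layers.0.self_attn.q_proj",
               "model.layers.0.self_attn.k_proj",
               "model.layers.0.self_attn.v_proj"]
    # Every component resolves to a target, so the first component's match
    # is returned.
    _match_fused_layer(layer_name, targets, fused_mapping)
    # -> "model.layers.0.self_attn.q_proj"
    # Drop one component's target and the fused layer no longer matches.
    _match_fused_layer(layer_name, targets[:2], fused_mapping)  # -> None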
vllm/model_executor/layers/quantization/quark/quark.py (5 additions, 5 deletions)
@@ -18,8 +18,6 @@
     QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
 from vllm.model_executor.layers.quantization.quark.utils import (
     deep_compare, should_ignore_layer)
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)
 from vllm.platforms import current_platform
 
 __all__ = ["QuarkLinearMethod"]
@@ -58,7 +56,9 @@ def get_quant_method(self, layer: torch.nn.Module,

         # Check if the layer is skipped for quantization.
         exclude_layers = cast(List[str], self.quant_config.get("exclude"))
-        if should_ignore_layer(prefix, ignore=exclude_layers):
+        if should_ignore_layer(prefix,
+                               ignore=exclude_layers,
+                               fused_mapping=self.packed_modules_mapping):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -201,8 +201,8 @@ def _find_matched_config(self, layer_name: str,
                              module: torch.nn.Module) -> Dict[str, Any]:
 
         proj_name = layer_name.split(".")[-1]
-        if proj_name in FUSED_LAYER_NAME_MAPPING:
-            shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
+        if proj_name in self.packed_modules_mapping:
+            shard_proj_names = self.packed_modules_mapping[proj_name]
 
             # Convert fused_name --> [shard_names]
             shard_names = [
vllm/model_executor/layers/quantization/quark/utils.py (9 additions, 8 deletions)
@@ -1,10 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import re
-from typing import Any, Iterable, Optional
-
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)
+from types import MappingProxyType
+from typing import Any, Iterable, List, Mapping, Optional
 
 
 def deep_compare(dict1: Any, dict2: Any) -> bool:
@@ -20,8 +18,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
     return dict1 == dict2
 
 
-def should_ignore_layer(layer_name: Optional[str],
-                        ignore: Iterable[str]) -> bool:
+def should_ignore_layer(
+        layer_name: Optional[str],
+        ignore: Iterable[str],
+        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> bool:
     if layer_name is None:
         return False
 
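A side note on the MappingProxyType({}) default: it is a read-only view from the standard library, so the instance shared across calls cannot be mutated, sidestepping the classic mutable-default-argument pitfall:

    from types import MappingProxyType

    default = MappingProxyType({})
    try:
        default["qkv_proj"] = ["q_proj"]  # mutation is rejected
    except TypeError as err:
        print(err)  # 'mappingproxy' object does not support item assignment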
@@ -33,8 +34,8 @@ def should_ignore_layer(layer_name: Optional[str],
     # in the safetensors checkpoint. So, we convert the name
     # from the fused version to unfused + check to make sure that
     # each shard of the fused layer has the same scheme.
-    if proj_name in FUSED_LAYER_NAME_MAPPING:
-        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
+    if proj_name in fused_mapping:
+        shard_proj_names = fused_mapping[proj_name]
 
         # Convert fused_name --> [shard_names]
         shard_names = [