Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand MLA to support most types of quantization #13181

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 25 additions & 45 deletions vllm/attention/backends/mla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize, scaled_quantize)
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)

Expand Down Expand Up @@ -220,16 +220,6 @@ def _q_proj_and_k_up_proj(self, x):
.view(-1, self.num_heads, self.kv_lora_rank)

def process_weights_after_loading(self, act_dtype: torch.dtype):

def is_layer_fp8(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, Fp8LinearMethod) or\
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))

def quantization_scheme_supported(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
is_layer_fp8(layer)

# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
Expand All @@ -239,7 +229,7 @@ def quantization_scheme_supported(layer: LinearBase) -> bool:
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant is not None:
if layer.quant_method.block_quant:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
Expand Down Expand Up @@ -267,41 +257,31 @@ def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)

def get_scales(layer: LinearBase) -> torch.Tensor:
if hasattr(layer, "weight_scale_inv"):
return layer.weight_scale_inv
return layer.weight_scale

def get_and_maybe_dequant_weights(layer: LinearBase):
if is_layer_fp8(layer):
if isinstance(layer.quant_method, \
CompressedTensorsLinearMethod) and \
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight = layer.weight.T # standardize to (output, input)
else:
weight = layer.weight
_, weight_scale_group_shape = \
get_scale_group_shapes_for_fp8(layer)
scales = get_scales(layer)

return scaled_dequantize(weight, scales,
weight_scale_group_shape)
else:
def get_layer_weight(layer):
if hasattr(layer, "weight"):
return layer.weight
elif hasattr(layer, "qweight"):
return layer.qweight
else:
raise AttributeError(
f"Layer '{layer}' has neither weight nor qweight")

if not (quantization_scheme_supported(self.kv_b_proj) and\
quantization_scheme_supported(self.q_proj) and\
quantization_scheme_supported(self.o_proj)):
raise NotImplementedError(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
", please run with VLLM_MLA_DISABLE=1")

weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight

weight_dtype = get_layer_weight(self.kv_b_proj).dtype
assert get_layer_weight(self.o_proj).dtype == weight_dtype
assert get_layer_weight(self.q_proj).dtype == weight_dtype

kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
Expand Down
32 changes: 1 addition & 31 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,37 +986,7 @@ def is_cross_encoder(self) -> bool:

@property
def use_mla(self) -> bool:
if not self.is_deepseek_mla or envs.VLLM_MLA_DISABLE:
return False

if self.quantization is not None and self.quantization not in [\
"fp8", "compressed-tensors"]:
logger.warning(
"MLA is not supported with %s quantization. "
"Disabling MLA.", self.quantization)
return False

# If using a "compressed-tensors" checkpoint, check that all groups
# have fp8 for both weights and activations.
if self.quantization == "compressed-tensors":
quant_config = self._parse_quant_hf_config()
for group_name, cfg in quant_config.get("config_groups", {
"": {}
}).items():
act_cfg = cfg.get("input_activations", {})
act_type = None if act_cfg is None else act_cfg.get("type", "")
w_cfg = cfg.get("weights", {})
w_type = None if w_cfg is None else w_cfg.get("type", "")
if act_type != "fp8" or w_type != "fp8":
logger.warning(
"compressed-tensors MLA support requires fp8 "
"activations and weights in group '%s', but got "
"activations type '%s' and weights type '%s'.\n "
"Full config: %s", group_name, act_type, w_type,
quant_config)
return False

return True
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE

@property
def supported_runner_types(self) -> Set[RunnerType]:
Expand Down
90 changes: 34 additions & 56 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,30 @@ def _initialize_model(
return model_class(**kwargs)


def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
target_device: torch.device) -> None:
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)

# Currently only used by MLA.
# NOTE: This intentionally happens after other modules so we can easily
# decompress the weights for MLA.
for _, module in model.named_modules():
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)


class BaseModelLoader(ABC):
"""Base class for model loaders."""

Expand Down Expand Up @@ -376,7 +400,6 @@ def download_model(self, model_config: ModelConfig) -> None:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config

target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
Expand All @@ -394,23 +417,8 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
_process_weights_after_loading(model, model_config, target_device)

return model.eval()


Expand All @@ -429,29 +437,15 @@ def download_model(self, model_config: ModelConfig) -> None:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(
module, torch.device(device_config.device)):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
_process_weights_after_loading(model, model_config, target_device)
return model.eval()


Expand Down Expand Up @@ -632,6 +626,7 @@ def download_model(self, model_config: ModelConfig) -> None:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
from safetensors.torch import safe_open

from vllm.distributed import get_tensor_model_parallel_rank
Expand All @@ -640,18 +635,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
model_config.revision)

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(
model_config.dtype)
_process_weights_after_loading(model, model_config,
target_device)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
Expand Down Expand Up @@ -1401,16 +1388,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
self._get_weights_iterator(model_weights,
model_config.revision))

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
_process_weights_after_loading(model, model_config, target_device)
return model.eval()


Expand Down
Loading