From fbaf889bd5c8f9cb7d2dad89384a79ab2ff3c1b7 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 16 Aug 2023 10:25:04 +0100 Subject: [PATCH 1/2] fix pylint errors --- vllm/awq_quantization/qmodule.py | 77 +++++++++++++++++++++++------ vllm/config.py | 5 +- vllm/engine/arg_utils.py | 6 ++- vllm/model_executor/model_loader.py | 5 +- vllm/model_executor/models/llama.py | 44 +++++++++++------ 5 files changed, 104 insertions(+), 33 deletions(-) diff --git a/vllm/awq_quantization/qmodule.py b/vllm/awq_quantization/qmodule.py index 4fe7c3a8b5c66..9b6f72404a725 100644 --- a/vllm/awq_quantization/qmodule.py +++ b/vllm/awq_quantization/qmodule.py @@ -1,14 +1,14 @@ # adapted from llm-awq: https://github.com/mit-han-lab/llm-awq -import math import torch import torch.nn as nn try: import awq_inference_engine # with CUDA kernels except ImportError as ex: - msg = "Unable to import awq_inference_engine: run setup.py to install CUDA kernels" - raise ImportError(msg) + raise ImportError( + "Unable to import awq_inference_engine: run setup.py" + " to install AWQ CUDA kernels") class ScaledActivation(nn.Module): @@ -16,18 +16,26 @@ def __init__(self, module, scales): super().__init__() self.act = module self.scales = nn.Parameter(scales.data) - + def forward(self, x): return self.act(x) / self.scales.view(1, 1, -1).to(x.device) class WQLinear(nn.Module): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): + def __init__( + self, + w_bit, + group_size, + in_features, + out_features, + bias, + dev + ): super().__init__() - + if w_bit not in [4]: raise NotImplementedError("Only 4-bit are supported for now.") - + self.in_features = in_features self.out_features = out_features self.w_bit = w_bit @@ -37,23 +45,62 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): assert self.in_features % self.group_size == 0 assert out_features % (32 // self.w_bit) == 0 - self.register_buffer('qweight', torch.empty((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev)) - self.register_buffer('qzeros', torch.empty((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev)) - self.register_buffer('scales', torch.empty((in_features // self.group_size, out_features), dtype=torch.float16, device=dev)) + qweight_buffer = torch.empty( + (in_features, out_features // (32 // self.w_bit)), + dtype=torch.int32, + device=dev + ) + self.register_buffer("qweight", qweight_buffer) + + qzeros_buffer = torch.empty( + ( + in_features // self.group_size, + out_features // (32 // self.w_bit) + ), + dtype=torch.int32, + device=dev + ) + self.register_buffer("qzeros", qzeros_buffer) + + scales_buffer = torch.empty( + (in_features // self.group_size, out_features), + dtype=torch.float16, + device=dev + ) + self.register_buffer("scales", scales_buffer) if bias: - self.register_buffer('bias', torch.empty((out_features), dtype=torch.float16, device=dev)) + bias_buffer = torch.empty( + (out_features), + dtype=torch.float16, + device=dev + ) + self.register_buffer("bias", bias_buffer) else: self.bias = None @torch.no_grad() def forward(self, x): out_shape = x.shape[:-1] + (self.out_features, ) - out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8) + + out = awq_inference_engine.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), + self.qweight, + self.scales, + self.qzeros, + 8 + ) + out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) - + def 
extra_repr(self) -> str: - return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format( - self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size + str_repr = "in_features={}, out_features={}, " \ + "bias={}, w_bit={}, group_size={}" + return str_repr.format( + self.in_features, + self.out_features, + self.bias is not None, + self.w_bit, + self.group_size ) diff --git a/vllm/config.py b/vllm/config.py index 7cfec13f409dc..d9e857185f8af 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -33,7 +33,7 @@ def __init__( self._verify() def _verify(self) -> None: - allowed_methods = ['awq'] + allowed_methods = ["awq"] if self.method not in allowed_methods: raise ValueError( f"Unknown quantization method ({self.method})" @@ -118,7 +118,8 @@ def verify_with_parallel_config( f"({pipeline_parallel_size}).") if self.quantization_config and tensor_parallel_size > 1: - raise NotImplementedError("Quantization does not currently support tensor parallelism") + raise NotImplementedError( + "Quantization does not currently support tensor parallelism") def get_hidden_size(self) -> int: return self.hf_config.hidden_size diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index af005bb5e22e7..3d7d20ef99fe3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -152,7 +152,11 @@ def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: # Initialize the configs. - quantization_config = QuantizationConfig(self.quantization) if self.quantization else None + if self.quantization is not None: + quantization_config = QuantizationConfig(self.quantization) + else: + quantization_config = None + model_config = ModelConfig(self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.use_np_weights, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 8ccc8bad96f5a..aa897acf3beca 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -51,7 +51,10 @@ def get_model(model_config: ModelConfig) -> nn.Module: # The weights will be initialized as empty tensors. 
if _supports_quantization(model_class): - model = model_class(model_config.hf_config, model_config.quantization_config) + model = model_class( + model_config.hf_config, + model_config.quantization_config + ) else: model = model_class(model_config.hf_config) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f4aea9366f13f..dc36dacde6d85 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -164,7 +164,7 @@ def __init__( super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() - assert tp_size == 1, 'quantization does not support TP' + assert tp_size == 1, "quantization does not support TP" self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size @@ -178,7 +178,7 @@ def __init__( self.qkv_proj = get_quantized_layer( hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + self.q_size + 2 * self.kv_size, quant_config ) @@ -220,8 +220,17 @@ def __init__( quant_config: QuantizationConfig ): super().__init__() - self.gate_up_proj = get_quantized_layer(hidden_size, 2 * intermediate_size, quant_config) - self.down_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config) + + self.gate_up_proj = get_quantized_layer( + hidden_size, + 2 * intermediate_size, quant_config + ) + + self.down_proj = get_quantized_layer( + intermediate_size, + hidden_size, + quant_config + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -313,9 +322,12 @@ def __init__(self, config: LlamaConfig, quant_config: QuantizationConfig): vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( vocab_size, config.hidden_size, perform_initialization=False) + self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config, quant_config) + for _ in range(config.num_hidden_layers) ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -414,10 +426,8 @@ def load_weights(self, extra_rows = extra_rows.to(loaded_weight) loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) - is_quantized = self.quant_config is not None and self.quant_config.method is not None - # merge linear layers - if not is_quantized: + if self.quant_config is not None: is_attention_weight = False for weight_name, shard_size, offset in attention_weight_specs: if weight_name not in name: @@ -454,17 +464,21 @@ def load_weights(self, if is_gate_up_weight: continue else: - # TODO: improve this block of code (not DRY, hacky, specific to AWQ) + # TODO: improve this block of code is_attention_weight = False - for stride_id, (weight_name, shard_size, offset) in enumerate(attention_weight_specs): + for stride_id, weight_spec in enumerate(attention_weight_specs): + weight_name, shard_size, offset = weight_spec + if weight_name not in name: continue + param = state_dict[name.replace(weight_name, "qkv_proj")] # TODO: this is specific to AWQ (should be more general) - if 'qweight' in name or 'qzeros' in name: - shard_size = int(shard_size // (32 / self.quant_config.bits)) - offset = int(offset // (32 / self.quant_config.bits)) + if "qweight" in name or "qzeros" in name: + adjustment = 32 / self.quant_config.bits + shard_size = int(shard_size // adjustment) + offset = int(offset // adjustment) param_slice = param.data[:, offset:offset + shard_size] assert 
param_slice.shape == loaded_weight.shape @@ -482,7 +496,9 @@ def load_weights(self, param = state_dict[name.replace(weight_name, "gate_up_proj")] shard_size = param.shape[1] // 2 - start, end = shard_size * stride_id, shard_size * (stride_id + 1) + start = shard_size * stride_id + end = shard_size * (stride_id + 1) + param_slice = param.data[:, start:end] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) From db4db0c2b054ae587aa20d8d6be606760f2a5ada Mon Sep 17 00:00:00 2001 From: root Date: Wed, 16 Aug 2023 11:33:06 +0100 Subject: [PATCH 2/2] improve the quant weight loaded code --- vllm/model_executor/models/llama.py | 98 +++++++++++------------------ 1 file changed, 36 insertions(+), 62 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index dc36dacde6d85..4fc443c515250 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -426,86 +426,60 @@ def load_weights(self, extra_rows = extra_rows.to(loaded_weight) loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) - # merge linear layers - if self.quant_config is not None: - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "qkv_proj")] + is_quantized = self.quant_config is not None - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: + is_attention_weight = False + for weight_name, shard_size, offset in attention_weight_specs: + if weight_name not in name: continue + param = state_dict[name.replace(weight_name, "qkv_proj")] - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 + if not is_quantized: loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True - break - if is_gate_up_weight: - continue - else: - # TODO: improve this block of code - is_attention_weight = False - for stride_id, weight_spec in enumerate(attention_weight_specs): - weight_name, shard_size, offset = weight_spec - - if weight_name not in name: - continue - - param = state_dict[name.replace(weight_name, "qkv_proj")] - - # TODO: this is specific to AWQ (should be more general) + param_slice = param.data[offset:offset + shard_size] + else: + # TODO: this is specific to AWQ if "qweight" in name or "qzeros" in name: adjustment = 32 / self.quant_config.bits shard_size = int(shard_size // adjustment) offset = int(offset // adjustment) - param_slice = param.data[:, offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: + assert param_slice.shape == loaded_weight.shape + + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = 
False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] + if not is_quantized: + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + else: shard_size = param.shape[1] // 2 - start = shard_size * stride_id end = shard_size * (stride_id + 1) - param_slice = param.data[:, start:end] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True - break - if is_gate_up_weight: - continue + + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue param = state_dict[name] load_tensor_parallel_weights(param, loaded_weight, name,
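
# --- Illustrative sketch (not part of the patch): the int4 packing arithmetic
# --- behind WQLinear's registered buffers and the qweight/qzeros shard
# --- adjustment in load_weights above. All names and sizes below are
# --- hypothetical examples; the actual GEMM lives in awq_inference_engine.

w_bit = 4
pack_factor = 32 // w_bit          # 8 four-bit values per int32 word
group_size = 128
in_features, out_features = 4096, 4096

# Packed buffer shapes that WQLinear registers (scales stay unpacked):
qweight_shape = (in_features, out_features // pack_factor)
qzeros_shape = (in_features // group_size, out_features // pack_factor)
scales_shape = (in_features // group_size, out_features)

# When merging q/k/v into qkv_proj, a shard of `shard_size` output columns
# occupies only shard_size // pack_factor packed columns, which is why
# load_weights divides shard_size and offset by 32 / bits for qweight/qzeros.
float_shard_size, float_offset = 4096, 8192
packed_shard_size = float_shard_size // pack_factor
packed_offset = float_offset // pack_factor

assert qweight_shape == (4096, 512)
assert qzeros_shape == (32, 512)
assert scales_shape == (32, 4096)
assert (packed_shard_size, packed_offset) == (512, 1024)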