From 6bf3b8b5e547bdd4b13851fe5927527d915d3c3a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 21 Jan 2025 16:00:10 +0000 Subject: [PATCH] fix gpt2 quant model Signed-off-by: jiqing-feng --- optimum/exporters/ipex/model_patcher.py | 11 +- optimum/exporters/ipex/modeling_utils.py | 225 ++++++++++++----------- 2 files changed, 122 insertions(+), 114 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 9b38abad6..ee6082d1a 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -14,7 +14,7 @@ from transformers.models.bert.modeling_bert import BertIntermediate from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel -from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model +from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer, LlamaModel, @@ -27,13 +27,11 @@ from .modeling_utils import ( _IPEX_MINIMUM_VERSION_FOR_PATCHING, - _IPEXGPT2MLP, _falcon_model_forward, - _gpt2_block_forward, _gpt2_model_forward, _ipex_rms_layer_norm_forward, _IPEXFalconDecoderLayer, - _IPEXGPT2Attention, + _IPEXGPT2Block, _IPEXIntermediate, _IPEXLlamaDecoderLayer, _llama_model_forward, @@ -106,13 +104,12 @@ def _patch_gpt2_model(model): """ Patch gpt2 model: 1. Use IPEX paged attention + 2. Linear fusion with (Linear + Add) """ num_key_value_heads = model.config.num_attention_heads setattr(model.config, "num_key_value_heads", num_key_value_heads) convert_functions(model, GPT2Model, "forward", _gpt2_model_forward) - convert_functions(model, GPT2Block, "forward", _gpt2_block_forward) - convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.device, model.config) - convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.device, model.config) + convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index cd85c2531..1d9ef4b08 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -558,78 +558,6 @@ def _gpt2_model_forward( ) -# To pass input_lens, adapted from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt2/modeling_gpt2.py#L602 -def _gpt2_block_forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs, -) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs, - ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] - # residual connection - if hasattr(self.attn, "linear_add"): - hidden_states = self.attn.linear_add(attn_output, residual) - else: - hidden_states = attn_output + residual - - if encoder_hidden_states is not None: - # add one self-attention block for cross-attention - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " - "cross-attention layers by setting `config.add_cross_attention=True`" - ) - residual = hidden_states - hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( - hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - **kwargs, - ) - attn_output = cross_attn_outputs[0] - # residual connection - hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights - - residual = hidden_states - hidden_states = self.ln_2(hidden_states) - feed_forward_hidden_states = self.mlp(hidden_states) - # residual connection - if hasattr(self.mlp, "linear_add"): - hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual) - else: - hidden_states = residual + feed_forward_hidden_states - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs # hidden_states, present, (attentions, cross_attentions) - - class _IPEXAttention(nn.Module): def __init__(self, module, device, config) -> None: super().__init__() @@ -844,26 +772,27 @@ class _IPEXGPT2Attention(_IPEXAttention): def __init__(self, module, device, config) -> None: self.num_key_value_heads = config.num_key_value_heads super().__init__(module, device, config) - if getattr(config, "quantization_config", None): - _remove_hooks_for_ipex(self, True) - _setattr_from_module(self, module) - self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1]) - self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t()) - self.c_attn_linear.bias = self.c_attn.bias - self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) - self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) - self.c_proj_linear.bias = self.c_proj.bias - if self.module_device.type == "cpu": - if self.c_proj_linear not in ["LinearAllreduce"]: - self.linear_add = LinearAdd(self.c_proj_linear) - - elif self.module_device.type == "xpu": - if self.c_proj_linear not in ["LinearAllreduce"]: - self.linear_add = XPULinearAdd(self.c_proj_linear) + if getattr(config, "quantization_config", None) is None: + self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1]) + self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t()) + self.c_attn_linear.bias = self.c_attn.bias + self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) + self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) + self.c_proj_linear.bias = self.c_proj.bias + if self.module_device.type == "cpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = LinearAdd(self.c_proj_linear) + + elif self.module_device.type == "xpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = XPULinearAdd(self.c_proj_linear) def qkv_gemm(self, hidden_states): - query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1) + if hasattr(self, "c_attn_linear"): + query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1) + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1) query = query.view(-1, self.num_heads, self.head_dim) key = key.view(-1, self.num_heads, self.head_dim) value = value.view(-1, self.num_heads, self.head_dim) @@ -951,27 +880,29 @@ def forward( class _IPEXGPT2MLP(nn.Module): - def __init__(self, module, config) -> None: + def __init__(self, module, device, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config - self.module_device = next(module.parameters()).device - self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1]) - self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t()) - self.c_fc_linear.bias = self.c_fc.bias - self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) - self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) - self.c_proj_linear.bias = self.c_proj.bias - if self.module_device.type == "cpu": - self.linear_new_gelu = LinearNewGelu(self.c_fc_linear) - - if self.module_device.type == "cpu": - if self.c_proj_linear not in ["LinearAllreduce"]: - self.linear_add = LinearAdd(self.c_proj_linear) - - elif self.module_device.type == "xpu": - if self.c_proj_linear not in ["LinearAllreduce"]: - self.linear_add = XPULinearAdd(self.c_proj_linear) + self.module_device = device + + if getattr(config, "quantization_config", None) is None: + self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1]) + self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t()) + self.c_fc_linear.bias = self.c_fc.bias + self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) + self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) + self.c_proj_linear.bias = self.c_proj.bias + if self.module_device.type == "cpu": + self.linear_new_gelu = LinearNewGelu(self.c_fc_linear) + + if self.module_device.type == "cpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = LinearAdd(self.c_proj_linear) + + elif self.module_device.type == "xpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = XPULinearAdd(self.c_proj_linear) def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: if hasattr(self, "linear_new_gelu"): @@ -1048,6 +979,86 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): return outputs +class _IPEXGPT2Block(nn.Module): + def __init__(self, module, device, config): + super().__init__() + _setattr_from_module(self, module) + self.attn = _IPEXGPT2Attention(module.attn, device, config) + self.mlp = _IPEXGPT2MLP(module.mlp, device, config) + if getattr(config, "quantization_config", None): + _remove_hooks_for_ipex(self, True) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + if hasattr(self.attn, "linear_add"): + hidden_states = self.attn.linear_add(attn_output, residual) + else: + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + **kwargs, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + if hasattr(self.mlp, "linear_add"): + hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual) + else: + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524 class _IPEXIntermediate(nn.Module): def __init__(self, module, device, config):