fix gpt2 quant model
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Jan 21, 2025
1 parent dab4a78 commit 6bf3b8b
Showing 2 changed files with 122 additions and 114 deletions.
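In short, the commit guards the IPEX GPT-2 patching so that quantized checkpoints (those whose config carries a quantization_config) keep their original c_attn / c_fc / c_proj modules, instead of having their Conv1D weights transposed into fused nn.Linear replacements. A minimal sketch of that guard pattern, assuming a Conv1D-style module whose weight has shape (in_features, out_features); the helper name below is hypothetical and is only an illustration of the pattern, not the optimum-intel implementation:

from torch import nn

def build_linear_from_conv1d(conv1d_like, quantization_config=None):
    if quantization_config is not None:
        # Quantized checkpoint: keep the original (possibly quantized) projection
        # untouched; the patched modules skip Linear fusion entirely in this case.
        return conv1d_like
    # GPT-2's Conv1D stores weights as (in_features, out_features), so the
    # equivalent nn.Linear needs the transposed weight, as in the diff below.
    in_features, out_features = conv1d_like.weight.shape
    linear = nn.Linear(in_features, out_features)
    linear.weight = nn.Parameter(conv1d_like.weight.t().contiguous())
    linear.bias = conv1d_like.bias
    return linear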
11 changes: 4 additions & 7 deletions optimum/exporters/ipex/model_patcher.py
@@ -14,7 +14,7 @@

from transformers.models.bert.modeling_bert import BertIntermediate
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaModel,
@@ -27,13 +27,11 @@

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_IPEXGPT2MLP,
_falcon_model_forward,
_gpt2_block_forward,
_gpt2_model_forward,
_ipex_rms_layer_norm_forward,
_IPEXFalconDecoderLayer,
_IPEXGPT2Attention,
_IPEXGPT2Block,
_IPEXIntermediate,
_IPEXLlamaDecoderLayer,
_llama_model_forward,
@@ -106,13 +104,12 @@ def _patch_gpt2_model(model):
"""
Patch gpt2 model:
1. Use IPEX paged attention
2. Linear fusion with (Linear + Add)
"""
num_key_value_heads = model.config.num_attention_heads
setattr(model.config, "num_key_value_heads", num_key_value_heads)
convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.device, model.config)
convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.device, model.config)
convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
return model
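
The patcher leans on two helpers defined elsewhere in model_patcher.py; their bodies are not part of this diff, so the sketch below is only an assumption about their behaviour inferred from the call sites above (swap the forward method on every matching submodule, or replace every matching submodule with a wrapper class):

import types

def convert_functions(model, target_cls, name, new_fn):
    # Assumed behaviour: rebind `new_fn` as the `name` method on every submodule
    # that is an instance of `target_cls`.
    for module in model.modules():
        if isinstance(module, target_cls):
            setattr(module, name, types.MethodType(new_fn, module))

def convert_class(model, target_cls, new_cls, device, config):
    # Assumed behaviour: recursively replace each `target_cls` child with a
    # wrapper built from it, e.g. GPT2Block -> _IPEXGPT2Block.
    for child_name, child in model.named_children():
        if isinstance(child, target_cls):
            setattr(model, child_name, new_cls(child, device, config))
        else:
            convert_class(child, target_cls, new_cls, device, config)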


225 changes: 118 additions & 107 deletions optimum/exporters/ipex/modeling_utils.py
@@ -558,78 +558,6 @@ def _gpt2_model_forward(
)


# To pass input_lens, adapted from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt2/modeling_gpt2.py#L602
def _gpt2_block_forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
if hasattr(self.attn, "linear_add"):
hidden_states = self.attn.linear_add(attn_output, residual)
else:
hidden_states = attn_output + residual

if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
"cross-attention layers by setting `config.add_cross_attention=True`"
)
residual = hidden_states
hidden_states = self.ln_cross_attn(hidden_states)
cross_attn_outputs = self.crossattention(
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = residual + attn_output
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights

residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
if hasattr(self.mlp, "linear_add"):
hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
else:
hidden_states = residual + feed_forward_hidden_states

if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]

return outputs # hidden_states, present, (attentions, cross_attentions)


class _IPEXAttention(nn.Module):
def __init__(self, module, device, config) -> None:
super().__init__()
@@ -844,26 +772,27 @@ class _IPEXGPT2Attention(_IPEXAttention):
def __init__(self, module, device, config) -> None:
self.num_key_value_heads = config.num_key_value_heads
super().__init__(module, device, config)
if getattr(config, "quantization_config", None):
_remove_hooks_for_ipex(self, True)

_setattr_from_module(self, module)
self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
self.c_attn_linear.bias = self.c_attn.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)

elif self.module_device.type == "xpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = XPULinearAdd(self.c_proj_linear)
if getattr(config, "quantization_config", None) is None:
self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
self.c_attn_linear.bias = self.c_attn.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)

elif self.module_device.type == "xpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = XPULinearAdd(self.c_proj_linear)

def qkv_gemm(self, hidden_states):
query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
if hasattr(self, "c_attn_linear"):
query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
query = query.view(-1, self.num_heads, self.head_dim)
key = key.view(-1, self.num_heads, self.head_dim)
value = value.view(-1, self.num_heads, self.head_dim)
@@ -951,27 +880,29 @@ def forward(


class _IPEXGPT2MLP(nn.Module):
def __init__(self, module, config) -> None:
def __init__(self, module, device, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = next(module.parameters()).device
self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
self.c_fc_linear.bias = self.c_fc.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)

if self.module_device.type == "cpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)

elif self.module_device.type == "xpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = XPULinearAdd(self.c_proj_linear)
self.module_device = device

if getattr(config, "quantization_config", None) is None:
self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
self.c_fc_linear.bias = self.c_fc.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)

if self.module_device.type == "cpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)

elif self.module_device.type == "xpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = XPULinearAdd(self.c_proj_linear)

def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
if hasattr(self, "linear_new_gelu"):
@@ -1048,6 +979,86 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
return outputs


class _IPEXGPT2Block(nn.Module):
def __init__(self, module, device, config):
super().__init__()
_setattr_from_module(self, module)
self.attn = _IPEXGPT2Attention(module.attn, device, config)
self.mlp = _IPEXGPT2MLP(module.mlp, device, config)
if getattr(config, "quantization_config", None):
_remove_hooks_for_ipex(self, True)

def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
if hasattr(self.attn, "linear_add"):
hidden_states = self.attn.linear_add(attn_output, residual)
else:
hidden_states = attn_output + residual

if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
"cross-attention layers by setting `config.add_cross_attention=True`"
)
residual = hidden_states
hidden_states = self.ln_cross_attn(hidden_states)
cross_attn_outputs = self.crossattention(
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = residual + attn_output
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights

residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
if hasattr(self.mlp, "linear_add"):
hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
else:
hidden_states = residual + feed_forward_hidden_states

if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]

return outputs # hidden_states, present, (attentions, cross_attentions)
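
Because the fused linear_add / linear_new_gelu modules are only created when no quantization_config is present, the hasattr checks in the block, attention, and MLP forwards fall back to the plain residual adds for quantized models. A hedged usage sketch, assuming optimum-intel's IPEXModelForCausalLM entry point and the plain gpt2 checkpoint as an illustrative example:

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id)  # patching is applied on load

inputs = tokenizer("Paged attention with IPEX:", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

A checkpoint whose config carries a quantization_config would go through the same patched blocks and simply take the non-fused branch.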


# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
class _IPEXIntermediate(nn.Module):
def __init__(self, module, device, config):
