Skip to content

Commit

Permalink
Disable flash_attn during export internvl2 (#1105)
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova authored Jan 21, 2025
1 parent e465c7f commit 2b0d642
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 1 deletion.
7 changes: 6 additions & 1 deletion optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
InputEmbeddingPatcher,
InternLM2Patcher,
InternLMModelPatcher,
InternVL2ChatLangModelPatcher,
InternVLChatImageEmbeddingModelPatcher,
JaisModelPatcher,
LlamaModelPatcher,
Expand Down Expand Up @@ -1642,7 +1643,11 @@ def with_behavior(
if behavior == InternVLChatConfigBehavior.LANGUAGE:
model_type = self._orig_config.llm_config.model_type
return get_vlm_text_generation_config(
model_type, self._orig_config.llm_config, self.int_dtype, self.float_dtype
model_type,
self._orig_config.llm_config,
self.int_dtype,
self.float_dtype,
InternVL2ChatLangModelPatcher,
)

if behavior == InternVLChatConfigBehavior.VISION_EMBEDDINGS:
Expand Down
81 changes: 81 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from transformers.utils import is_tf_available

Expand Down Expand Up @@ -2992,11 +2993,91 @@ def __init__(
model.__orig_forward = model.forward
model.forward = model.extract_feature

if model.vision_model.encoder.layers[0].attn.use_flash_attn:
for layer in model.vision_model.encoder.layers:
layer.attn._orig_use_flash_attn = layer.attn.use_flash_attn
layer.attn.use_flash_attn = False

super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
    """Leave the patched state and undo the changes made to the wrapped model.

    After the parent ``__exit__`` has run, the model's ``forward`` is set
    back to the callable stored in ``__orig_forward``. If the first vision
    encoder layer carries an ``_orig_use_flash_attn`` stash, every encoder
    layer gets its ``use_flash_attn`` flag restored from that stash.
    """
    super().__exit__(exc_type, exc_value, traceback)

    patched_model = self._model
    # Literal attribute access on purpose: inside the class body the
    # double-underscore name is mangled the same way as where it was stored.
    patched_model.forward = patched_model.__orig_forward

    encoder_layers = patched_model.vision_model.encoder.layers
    if not hasattr(encoder_layers[0].attn, "_orig_use_flash_attn"):
        # Nothing was stashed, so the flash-attention flag was never touched.
        return

    for encoder_layer in encoder_layers:
        attention = encoder_layer.attn
        attention.use_flash_attn = attention._orig_use_flash_attn


class InternVL2ChatLangModelPatcher(DecoderModelPatcher):
    """Patcher for the language model of an InternVL2 chat model.

    The language model may be one of several architectures. When
    ``model.config.model_type`` is ``llama``, ``qwen2``, ``phi3`` or
    ``internlm2``, the patcher class mapped to that type is instantiated and
    entering/exiting is delegated to it; for any other type the parent
    ``DecoderModelPatcher`` behaviour is used.

    In addition, ``__enter__`` switches ``qwen2`` and ``llama`` models whose
    config does not already select ``"sdpa"`` to the SDPA attention
    implementation, and ``__exit__`` restores the previous implementation.
    """

    def __init__(
        self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any]
    ):
        """Create the patcher and, when one is registered, the per-architecture patcher.

        Args:
            config: Export configuration forwarded to the patcher constructors.
            model: Language model to patch; ``model.config.model_type`` selects
                the internal patcher.
            model_kwargs: Extra keyword arguments forwarded to the patcher
                constructors.
        """
        model_type = model.config.model_type
        patcher_for_model_type = {
            "llama": LlamaModelPatcher,
            "qwen2": UpdateCausalMaskModelPatcher,
            "phi3": Phi3ModelPatcher,
            "internlm2": InternLM2Patcher,
        }
        # Both attributes back the ``patched_forward`` property, so they are
        # created before the parent initializer runs.
        # NOTE(review): presumably the parent initializer assigns
        # ``patched_forward`` and therefore goes through the setter below - confirm.
        self._internal_patcher = None
        self._patched_forward = None
        internal_patcher_cls = patcher_for_model_type.get(model_type)
        if internal_patcher_cls is not None:
            self._internal_patcher = internal_patcher_cls(config, model, model_kwargs)
            self._patched_forward = self._internal_patcher.patched_forward
        super().__init__(config, model, model_kwargs)

    @property
    def patched_forward(self):
        """Patched forward callable; the internal patcher's one when it exists."""
        if self._internal_patcher is not None:
            return self._internal_patcher.patched_forward
        return self._patched_forward

    @patched_forward.setter
    def patched_forward(self, fn):
        # Keep the internal patcher in sync so both objects expose the same callable.
        self._patched_forward = fn
        if self._internal_patcher is not None:
            self._internal_patcher.patched_forward = fn

    def _switch_to_sdpa(self, sdpa_attn_cls=None):
        """Stash the current attention implementation and select ``"sdpa"``.

        Args:
            sdpa_attn_cls: Optional attention class whose ``forward`` is bound
                onto every ``layer.self_attn``. When ``None`` only the config
                value is changed.
        """
        model_config = self._model.config
        model_config._orig_attn_implementation = model_config._attn_implementation
        model_config._attn_implementation = "sdpa"
        if sdpa_attn_cls is None:
            return
        for layer in self._model.model.layers:
            # The previous bound forward is kept on the module so __exit__ can restore it.
            layer.self_attn._orig_forward = layer.self_attn.forward
            layer.self_attn.forward = types.MethodType(sdpa_attn_cls.forward, layer.self_attn)

    def __enter__(self):
        """Force SDPA attention where applicable, then enter the patched state.

        Returns:
            Whatever the internal patcher's ``__enter__`` returns when an
            internal patcher exists, otherwise the parent's ``__enter__`` result.
        """
        if is_torch_version(">=", "2.1.0"):
            if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
                try:
                    from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
                except ImportError:
                    # Mapping is not exported by the installed transformers:
                    # fall back to the config-only switch, exactly like the
                    # llama branch does outside its guarded version range.
                    # NOTE(review): assumes such releases pick the attention
                    # implementation from ``config._attn_implementation`` - confirm.
                    sdpa_attn = None
                else:
                    sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
                self._switch_to_sdpa(sdpa_attn)

            if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
                sdpa_attn = None
                if is_transformers_version("<", "4.47"):
                    from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

                    sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"]
                self._switch_to_sdpa(sdpa_attn)

        if self._internal_patcher is not None:
            return self._internal_patcher.__enter__()
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the patched state and restore the original attention setup.

        The attention implementation and per-layer ``forward`` are restored in
        a ``finally`` clause so the model is not left modified when the
        delegated ``__exit__`` raises.
        """
        try:
            if self._internal_patcher is not None:
                self._internal_patcher.__exit__(exc_type, exc_value, traceback)
            else:
                super().__exit__(exc_type, exc_value, traceback)
        finally:
            model_config = self._model.config
            if hasattr(model_config, "_orig_attn_implementation"):
                model_config._attn_implementation = model_config._orig_attn_implementation
                for layer in self._model.model.layers:
                    if hasattr(layer.self_attn, "_orig_forward"):
                        layer.self_attn.forward = layer.self_attn._orig_forward


def llava_vision_embed_forward(self, pixel_values):
Expand Down

0 comments on commit 2b0d642

Please sign in to comment.