force attn model
IlyasMoutawwakil committed Jan 6, 2025
1 parent 45133cb commit 35c47a2
Showing 2 changed files with 7 additions and 41 deletions.
optimum/exporters/openvino/__main__.py (1 addition, 1 deletion)
@@ -49,7 +49,7 @@
 )
 
 
-FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}
+FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}
 
 if TYPE_CHECKING:
     from optimum.intel.openvino.configuration import OVConfig
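For context: this dict maps a model type to an attention implementation that is forced when the model is loaded for export. A minimal sketch of how such a mapping can be applied at load time (load_for_export is a hypothetical helper, not the exporter's actual code path; attn_implementation is a real from_pretrained argument in recent transformers versions):

from transformers import AutoConfig, AutoModelForCausalLM

FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}

def load_for_export(model_id: str):
    # If the model type has a forced implementation, override the default.
    config = AutoConfig.from_pretrained(model_id)
    attn = FORCE_ATTN_MODEL_CLASSES.get(config.model_type)
    kwargs = {"attn_implementation": attn} if attn else {}
    return AutoModelForCausalLM.from_pretrained(model_id, **kwargs)

With "gemma2" forced to "sdpa" here, the runtime attention patch previously applied in model_patcher.py (removed below) becomes unnecessary.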
optimum/exporters/openvino/model_patcher.py (6 additions, 40 deletions)
@@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
                 offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
         if (
             self.config._attn_implementation == "sdpa"
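The change in this hunk (and the identical one in _dbrx_update_causal_mask_legacy below) is purely cosmetic: a black-style reflow of the slice assignment with no behavioural difference. For readers of the mask logic itself, a self-contained sketch of what the assignment does, with hypothetical shapes:

import torch

# Hypothetical shapes: batch=2, heads=1, queries=4, keys=8.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min  # large negative value standing in for -inf

causal_mask = torch.full((2, 1, 8, 8), min_dtype, dtype=dtype)
attention_mask = torch.ones(2, 1, 4, 8, dtype=dtype)  # 1 = attend, 0 = masked
attention_mask[..., 6:] = 0.0  # pretend the last two key positions are padding

offset = 0  # where the query window starts inside the full causal mask
mask_shape = attention_mask.shape
# Zeros in the incoming mask become min_dtype (masked); ones become 0 (visible).
mask_slice = attention_mask.eq(0.0).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = mask_slice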
@@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy(
                 offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
         if (
             self.config._attn_implementation == "sdpa"
@@ -2710,40 +2710,6 @@ def patched_forward(*args, **kwargs):

         self.patched_forward = patched_forward
 
-    def __enter__(self):
-        super().__enter__()
-
-        if is_transformers_version(">=", "4.47.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
-
-            GEMMA2_ATTENTION_FUNCTION["original_eager"] = GEMMA2_ATTENTION_FUNCTION["eager"]
-            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["sdpa"]
-
-        elif is_transformers_version(">=", "4.45.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
-
-            sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
-            eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
-
-            for layer in self._model.model.layers:
-                if isinstance(layer.self_attn, eager_attn):
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-
-        if is_transformers_version(">=", "4.47.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
-
-            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["original_eager"]
-            del GEMMA2_ATTENTION_FUNCTION["original_eager"]
-
-        elif is_transformers_version(">=", "4.45.0"):
-            for layer in self._model.model.layers:
-                if hasattr(layer.self_attn, "_orig_forward"):
-                    layer.self_attn.forward = layer.self_attn._orig_forward
-
 
 def _decilm_attn_forward(
     self,
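The deleted __enter__/__exit__ pair swapped Gemma2's eager attention for sdpa while the patcher context was active: for transformers >= 4.47 by remapping the GEMMA2_ATTENTION_FUNCTION registry, and for >= 4.45 by rebinding each layer's forward with types.MethodType. A standalone sketch of that rebind/restore pattern (DummyAttention and the numeric forwards are illustrative stand-ins, not the real attention classes):

import types

class DummyAttention:
    def forward(self, x):
        return x + 1  # stands in for the eager implementation

def sdpa_forward(self, x):
    return x + 2  # stands in for the sdpa implementation

attn = DummyAttention()

# Patch: remember the original bound method, then rebind the replacement.
attn._orig_forward = attn.forward
attn.forward = types.MethodType(sdpa_forward, attn)
assert attn.forward(0) == 2

# Restore: put the original bound method back.
attn.forward = attn._orig_forward
assert attn.forward(0) == 1

With "gemma2" now forced to "sdpa" in FORCE_ATTN_MODEL_CLASSES, this swap and its restore logic are handled once at load time instead of per export context.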
