support export more models
eaidova committed May 16, 2024
1 parent 8c2b787 · commit 3db832a
Showing 4 changed files with 330 additions and 5 deletions.
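For orientation: this commit registers OpenVINO export configs for the Aquila, XVERSE and InternLM (v1) architectures and renames the InternLM2 patcher. A minimal usage sketch, assuming an illustrative checkpoint id such as "BAAI/AquilaChat-7B" (any checkpoint whose model type is one of the newly registered "aquila", "xverse" or "internlm" should follow the same path):

from optimum.intel import OVModelForCausalLM

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly,
# resolving the architecture-specific export config registered in this commit
model = OVModelForCausalLM.from_pretrained(
    "BAAI/AquilaChat-7B",    # illustrative model id, not taken from this commit
    export=True,
    trust_remote_code=True,  # these architectures ship custom modeling code
)
model.save_pretrained("aquila-openvino")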
49 changes: 47 additions & 2 deletions optimum/exporters/openvino/model_configs.py
@@ -41,15 +41,18 @@
from optimum.utils.normalized_config import NormalizedTextConfig

from .model_patcher import (
AquilaModelPatcher,
BaichuanModelPatcher,
ChatGLMModelPatcher,
GemmaModelPatcher,
-    InternLMPatcher,
+    InternLM2Patcher,
InternLMModelPatcher,
LlamaModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
Phi3ModelPatcher,
QwenModelPatcher,
XverseModelPatcher,
)


@@ -445,7 +448,7 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
-        return InternLMPatcher(self, model, model_kwargs=model_kwargs)
+        return InternLM2Patcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
@@ -653,3 +656,45 @@ class XGLMConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
num_attention_heads="attention_heads", hidden_size="d_model"
)


@register_in_tasks_manager("aquila", *["text-generation", "text-generation-with-past"], library_name="transformers")
class AquilaMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return AquilaModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("xverse", *["text-generation", "text-generation-with-past"], library_name="transformers")
class XverseMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return XverseModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("internlm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class InternLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return InternLMModelPatcher(self, model, model_kwargs=model_kwargs)
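
The patch_model_for_export hooks above return ModelPatcher subclasses that are used as context managers around tracing, so the attention rewrites are active only during export. A rough sketch, assuming `config` is an instantiated export config (e.g. one of the classes above) and `model` and `dummy_inputs` are already prepared:

patcher = config.patch_model_for_export(model)
with patcher:
    # inside the context, each decoder layer's attention forward has been
    # swapped for the SDPA variant defined in model_patcher.py
    traced_outputs = model(**dummy_inputs)
# on exit, the patcher's __exit__ restores the original forwards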
274 changes: 271 additions & 3 deletions optimum/exporters/openvino/model_patcher.py
@@ -844,7 +844,7 @@ def __exit__(self, exc_type, exc_value, traceback):
block.attn.forward = block.attn._orig_forward


-def _internlm_attention_forward(
+def _internlm2_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -935,14 +935,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return attn_output, attn_weights, past_key_value


-class InternLMPatcher(DecoderModelPatcher):
+class InternLM2Patcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()

if is_torch_version(">=", "2.1.0"):
for block in self._model.model.layers:
block.attention._orig_forward = block.attention.forward
-                block.attention.forward = types.MethodType(_internlm_attention_forward, block.attention)
+                block.attention.forward = types.MethodType(_internlm2_attention_forward, block.attention)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
@@ -1055,3 +1055,271 @@ def __exit__(self, exc_type, exc_value, traceback):
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward


def _aquila_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed

if output_attentions:
return self._orig_forward(
hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
)
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)

key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)

value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
)
attn_weights = None

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)

return attn_output, attn_weights, past_key_value
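
The nested repeat_kv above duplicates each key/value head n_rep times so grouped-query attention can be fed to scaled_dot_product_attention. A self-contained check of the docstring's claim that it matches torch.repeat_interleave (toy shapes, not from the commit):

import torch

x = torch.randn(2, 4, 5, 8)  # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 3                    # num_attention_heads // num_key_value_heads
out = x[:, :, None, :, :].expand(2, 4, n_rep, 5, 8).reshape(2, 4 * n_rep, 5, 8)
assert torch.equal(out, torch.repeat_interleave(x, n_rep, dim=1))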


class AquilaModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_aquila_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward


def _xverse_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed

if output_attentions:
return self._orig_forward(
hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
)
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
)
attn_weights = None

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

return attn_output, attn_weights, past_key_value
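
Both patched forwards replace the eager softmax(QK^T / sqrt(d))V computation with torch.nn.functional.scaled_dot_product_attention and an explicit scale. A toy numerical check (illustrative shapes and mask, not from the commit) that the two paths agree for an additive float mask:

import math
import torch

q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))  # (bsz, heads, seq, head_dim)
mask = torch.zeros(1, 1, 4, 4)                          # additive attention mask
sdpa = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, mask, scale=1 / math.sqrt(8)
)
eager = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(8) + mask, dim=-1) @ v
assert torch.allclose(sdpa, eager, atol=1e-6)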


def _internlm_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed

if output_attentions:
return self._orig_forward(
hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
)

bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
)
attn_weights = None

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
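
The rotate_half / apply_rotary_pos_emb helpers repeated in these forwards implement rotary position embeddings: rotate_half is a quarter turn in each paired (x_i, x_{i+d/2}) plane, and because RoPE duplicates each frequency across the two halves of the head dimension, the (cos, sin) blend is an orthogonal rotation. A small property check under those assumptions (toy tensors, not from the commit):

import torch

def rotate_half(t):
    t1, t2 = t[..., : t.shape[-1] // 2], t[..., t.shape[-1] // 2 :]
    return torch.cat((-t2, t1), dim=-1)

x = torch.randn(1, 1, 4, 8)
assert torch.equal(rotate_half(rotate_half(x)), -x)  # two quarter turns negate

freqs = torch.rand(4, 4)                   # one angle per position and pair
theta = torch.cat((freqs, freqs), dim=-1)  # duplicated across both halves
rotated = x * theta.cos() + rotate_half(x) * theta.sin()
assert torch.allclose(rotated.norm(dim=-1), x.norm(dim=-1), atol=1e-5)  # norms kept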


class XverseModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_xverse_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward


class InternLMModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_internlm_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
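
All three patchers follow the same reversible monkey-patching idiom: stash the bound forward in _orig_forward, bind the module-level SDPA function to the instance with types.MethodType, and restore on exit. A distilled, self-contained sketch of that idiom (toy class and names are assumptions, not from the commit):

import types

class ToyAttention:
    def forward(self):
        return "eager"

def _sdpa_forward(self):
    return "sdpa"

attn = ToyAttention()
attn._orig_forward = attn.forward                     # keep a handle for restoration
attn.forward = types.MethodType(_sdpa_forward, attn)  # instance attribute shadows class method
assert attn.forward() == "sdpa"

attn.forward = attn._orig_forward                     # __exit__-style restoration
assert attn.forward() == "eager"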
