diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index dc13512111..c073a0dd38 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -41,15 +41,18 @@ from optimum.utils.normalized_config import NormalizedTextConfig from .model_patcher import ( + AquilaModelPatcher, BaichuanModelPatcher, ChatGLMModelPatcher, GemmaModelPatcher, - InternLMPatcher, + InternLM2Patcher, + InternLMModelPatcher, LlamaModelPatcher, MixtralModelPatcher, MPTModelPatcher, Phi3ModelPatcher, QwenModelPatcher, + XverseModelPatcher, ) @@ -445,7 +448,7 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None ) -> "ModelPatcher": - return InternLMPatcher(self, model, model_kwargs=model_kwargs) + return InternLM2Patcher(self, model, model_kwargs=model_kwargs) @register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers") @@ -653,3 +656,45 @@ class XGLMConfig(TextDecoderWithPositionIdsOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( num_attention_heads="attention_heads", hidden_size="d_model" ) + + +@register_in_tasks_manager("aquila", *["text-generation", "text-generation-with-past"], library_name="transformers") +class AquilaMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return AquilaModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("xverse", *["text-generation", "text-generation-with-past"], library_name="transformers") +class XverseMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return XverseModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("internlm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class InternLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return InternLMModelPatcher(self, model, model_kwargs=model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 55afb0ffea..a6d2bc3831 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -844,7 +844,7 @@ def __exit__(self, exc_type, exc_value, traceback): block.attn.forward = block.attn._orig_forward -def _internlm_attention_forward( +def _internlm2_attention_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -935,14 +935,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return attn_output, attn_weights, past_key_value -class InternLMPatcher(DecoderModelPatcher): +class InternLM2Patcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_torch_version(">=", "2.1.0"): for block in self._model.model.layers: block.attention._orig_forward = block.attention.forward - block.attention.forward = types.MethodType(_internlm_attention_forward, block.attention) + block.attention.forward = types.MethodType(_internlm2_attention_forward, block.attention) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) @@ -1055,3 +1055,271 @@ def __exit__(self, exc_type, exc_value, traceback): for layer in self._model.model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward + + +def _aquila_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +class AquilaModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_aquila_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +def _xverse_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +def _internlm_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class XverseModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_xverse_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +class InternLMModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_internlm_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 0a0b66b86d..75d76ecae2 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -557,6 +557,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "gpt_neox_japanese", "cohere", "xglm", + "aquila", + "xverse", + "internlm", ) GENERATION_LENGTH = 100 REMOTE_CODE_MODELS = ( @@ -569,6 +572,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "internlm2", "orion", "phi3", + "aquila", + "xverse", + "internlm", ) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -596,6 +602,7 @@ def test_compare_to_transformers(self, model_arch): self.assertEqual(ov_model.stateful, ov_model.config.model_type not in not_stateful) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) tokens = tokenizer("This is a sample output", return_tensors="pt") + tokens.pop("token_type_ids", None) ov_outputs = ov_model(**tokens) self.assertTrue("logits" in ov_outputs) @@ -630,6 +637,7 @@ def test_compare_to_transformers(self, model_arch): # Compare batched generation tokenizer.padding_side = "left" tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) + tokens.pop("token_type_ids", None) ov_model.generation_config.eos_token_id = None transformers_model.generation_config.eos_token_id = None ov_model.config.eos_token_id = None @@ -853,6 +861,7 @@ def test_beam_search(self, model_arch): transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) tokenizer.pad_token_id = tokenizer.eos_token_id tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) + tokens.pop("token_type_ids", None) ov_model_stateful.generation_config.eos_token_id = None ov_model_stateless.generation_config.eos_token_id = None transformers_model.generation_config.eos_token_id = None diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index aa3ea5f33a..5c10eba1f5 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -18,6 +18,7 @@ MODEL_NAMES = { "albert": "hf-internal-testing/tiny-random-albert", + "aquila": "katuni4ka/tiny-random-aquila", "audio_spectrogram_transformer": "Ericwang/tiny-random-ast", "bge": "BAAI/bge-small-en-v1.5", "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", @@ -57,6 +58,7 @@ "gptj": "hf-internal-testing/tiny-random-GPTJModel", "hubert": "hf-internal-testing/tiny-random-HubertModel", "ibert": "hf-internal-testing/tiny-random-ibert", + "internlm": "katuni4ka/tiny-random-internlm", "internlm2": "katuni4ka/tiny-random-internlm2", "levit": "hf-internal-testing/tiny-random-LevitModel", "longt5": "hf-internal-testing/tiny-random-longt5", @@ -120,6 +122,7 @@ "xlm": "hf-internal-testing/tiny-random-xlm", "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta", "xglm": "hf-internal-testing/tiny-random-XGLMForCausalLM", + "xverse": "katuni4ka/tiny-random-xverse", }