diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
index 8c5ef5030..c171aef29 100644
--- a/optimum/exporters/ipex/model_patcher.py
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -20,6 +20,11 @@
     LlamaModel,
     LlamaRMSNorm,
 )
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2DecoderLayer,
+    Qwen2Model,
+    Qwen2RMSNorm,
+)
 from transformers.models.vit.modeling_vit import ViTIntermediate
 
 from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version
@@ -36,7 +41,9 @@
     _IPEXGPT2Attention,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
+    _IPEXQwen2DecoderLayer,
     _llama_model_forward,
+    _qwen2_model_forward,
 )
 
 
@@ -116,6 +123,18 @@ def _patch_gpt2_model(model):
     return model
 
 
+def _patch_qwen2_model(model):
+    """
+    Patch qwen2 model:
+        1. Use IPEX rope and paged cache
+        2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
+    """
+    convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
+    convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
+    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
+    return model
+
+
 def _patch_bert_model(model):
     """
     Patch bert model:
@@ -149,6 +168,8 @@ def _patch_model(model):
         model = _patch_falcon_model(model)
     elif model.config.model_type == "gpt2":
         model = _patch_gpt2_model(model)
+    elif model.config.model_type == "qwen2":
+        model = _patch_qwen2_model(model)
     elif model.config.model_type == "bert":
         model = _patch_bert_model(model)
     elif model.config.model_type == "vit":
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 41dd5693d..2b440aa91 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -603,6 +603,125 @@ def _gpt2_block_forward(
     return outputs  # hidden_states, present, (attentions, cross_attentions)
 
 
+# Adapted from https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L499
+def _qwen2_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    batch_size, seq_length = inputs_embeds.shape[:2]
+
+    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+    if cache_position is None:
+        cache_position = torch.arange(
+            past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    hidden_states = inputs_embeds
+
+    # create position embeddings to be shared across the decoder layers
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+
+    if past_key_values_length == 0 and past_key_values is not None:
+        # first token, remove the padding from hidden_states, varlen do not accept attention mask
+        hidden_states_copy = hidden_states
+        index = attention_mask.view(-1) != 0
+        hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
+        cos = position_embeddings[0]
+        sin = position_embeddings[1]
+        cos = (cos.reshape(-1, cos.shape[-1]))[index]
+        sin = (sin.reshape(-1, sin.shape[-1]))[index]
+        position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
+    else:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
+        attention_mask = causal_mask
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+
+    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            input_lens=input_lens,
+            **kwargs,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    if hidden_states.shape[0] != batch_size * seq_length:
+        (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
+        hidden_states = hidden_states_copy
+    hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    output = BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=past_key_values if use_cache else None,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+    return output if return_dict else output.to_tuple()
+
+
 class _IPEXAttention(nn.Module):
     def __init__(self, module, config) -> None:
         super().__init__()
@@ -618,8 +737,10 @@ def __init__(self, module, config) -> None:
     def qkv_gemm(self, hidden_states):
         raise NotImplementedError("Need to implement in specific model class")
 
-    def rope(self, *args, **kwargs):
-        raise NotImplementedError("Need to implement in specific model class")
+    def rope(self, query, key, **kwargs):
+        position_embeddings = kwargs.pop("position_embeddings", None)
+        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
+        return query, key
 
     def postprocess_attention_output(self, attn_output):
         if self.use_sdpa:
@@ -748,13 +869,13 @@ class _IPEXLlamaAttention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         super().__init__(module, config)
         concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
-        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
+        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias is not None]
         use_bias = bias_list != []
         self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
         self.concat_qkv.weight = nn.Parameter(concat_weight)
         if use_bias:
             concat_bias = torch.concat(bias_list, 0).contiguous()
-            self.concat_linear.bias = nn.Parameter(concat_bias)
+            self.concat_qkv.bias = nn.Parameter(concat_bias)
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
@@ -774,11 +895,6 @@ def qkv_gemm(self, hidden_states):
 
         return query, key, value
 
-    def rope(self, query, key, **kwargs):
-        position_embeddings = kwargs.pop("position_embeddings", None)
-        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
-        return query, key
-
 
 class _IPEXFalconAttention(_IPEXAttention):
     def __init__(self, module, config):
@@ -801,11 +917,6 @@ def qkv_gemm(self, hidden_states):
         value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         return query, key, value
 
-    def rope(self, query, key, **kwargs):
-        position_embeddings = kwargs.pop("position_embeddings", None)
-        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
-        return query, key
-
 
 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
@@ -1006,6 +1117,12 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         return outputs
 
 
+# Currently we can simply reuse the llama decoder layer.
+class _IPEXQwen2DecoderLayer(_IPEXLlamaDecoderLayer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
 class _IPEXIntermediate(nn.Module):
     def __init__(self, module, config):
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 3263e31db..76746c171 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -58,11 +58,11 @@
 logger = logging.getLogger(__name__)
 
 
-_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
+_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2", "qwen2")
 _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
 _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
 # TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6
-_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")
+_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2", "qwen2")
 
 
 def _is_patched_with_ipex(model, task, use_cache: bool = True):
diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py
index 419e1bb42..a91d08ee0 100644
--- a/tests/ipex/test_modeling.py
+++ b/tests/ipex/test_modeling.py
@@ -233,8 +233,9 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
         "distilgpt2",
         "mpt",
         "opt",
+        "qwen2",
     )
-    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2")
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2", "qwen2")
     GENERATION_LENGTH = 100
     SPEEDUP_CACHE = 1.0
 
diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py
index f376c6050..62c3877b5 100644
--- a/tests/ipex/test_pipelines.py
+++ b/tests/ipex/test_pipelines.py
@@ -66,6 +66,7 @@ class PipelinesIntegrationTest(unittest.TestCase):
         "mistral",
         "mpt",
         "opt",
+        "qwen2",
     )
     QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES = (
         "bert",
@@ -144,10 +145,11 @@ def test_text_generation_pipeline_inference(self, model_arch):
             "text-generation", model_id, accelerator="ipex", torch_dtype=dtype, device_map=DEVICE
         )
         inputs = "Describe a real-world application of AI."
+        max_new_tokens = 10 if model_arch != "qwen2" else 2
         with torch.inference_mode():
-            transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10)
+            transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
         with torch.inference_mode():
-            ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10)
+            ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
         self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM))
         self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])
 
diff --git a/tests/ipex/utils_tests.py b/tests/ipex/utils_tests.py
index 8cd93516d..72f407cc1 100644
--- a/tests/ipex/utils_tests.py
+++ b/tests/ipex/utils_tests.py
@@ -50,6 +50,7 @@
     "mt5": "stas/mt5-tiny-random",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "phi": "echarlaix/tiny-random-PhiForCausalLM",
+    "qwen2": "Jiqing/tiny-random-Qwen2",
     "resnet": "hf-internal-testing/tiny-random-resnet",
     "roberta": "hf-internal-testing/tiny-random-roberta",
     "roformer": "hf-internal-testing/tiny-random-roformer",
@@ -64,4 +65,5 @@
     "patched_falcon": "Intel/tiny-random-falcon_ipex_model",
     "patched_gpt2": "Intel/tiny-random-gpt2_ipex_model",
     "patched_llama2": "Intel/tiny-random-llama2_ipex_model",
+    "patched_qwen2": "Jiqing/tiny-random-Qwen2_ipex_model",
 }
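
Usage sketch (reviewer note, not part of the patch): with this change a Qwen2 checkpoint loads through the existing IPEXModelForCausalLM entry point and runs with a regular transformers pipeline, the same way llama, falcon and gpt2 already do. The snippet below is a minimal illustration and assumes an environment with optimum-intel and IPEX installed; the model id is the tiny test checkpoint referenced in tests/ipex/utils_tests.py above, and any Qwen2 checkpoint can be substituted.

    # Minimal sketch of loading a Qwen2 model through the IPEX-enabled causal LM class.
    from transformers import AutoTokenizer, pipeline
    from optimum.intel import IPEXModelForCausalLM

    model_id = "Jiqing/tiny-random-Qwen2"  # tiny test checkpoint from tests/ipex/utils_tests.py
    # qwen2 is now listed in _IPEX_SUPPORT_MODEL_TYPES, so loading should go through the
    # IPEX patching path added in this diff (_patch_qwen2_model).
    model = IPEXModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
    print(generator("Describe a real-world application of AI.", do_sample=False, max_new_tokens=10))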