enable qwen2 model #1107

Open · wants to merge 39 commits into base: main

Commits (39)
6d21075
use real varlen attn
jiqing-feng Dec 11, 2024
b792875
optimize gpt2 by using linear instead of conv1D
jiqing-feng Dec 12, 2024
422134f
Merge branch 'huggingface:main' into varlen
jiqing-feng Dec 12, 2024
36884cb
fix usage without pkv
jiqing-feng Dec 12, 2024
d061e69
use sdpa for no cache forward
jiqing-feng Dec 12, 2024
31c635a
fix format
jiqing-feng Dec 12, 2024
73a5ef7
fix sdpa
jiqing-feng Dec 12, 2024
f9c021b
revert shape for sdpa
jiqing-feng Dec 12, 2024
d069407
fix sdpa precision, still have error
jiqing-feng Dec 12, 2024
2c54045
fix sdpa shape
jiqing-feng Dec 13, 2024
bce9aa9
upgrad minimum torch version to 2.5
jiqing-feng Dec 13, 2024
72ac9e6
rm pdb
jiqing-feng Dec 13, 2024
3fdb3a5
fix non patch path
jiqing-feng Dec 16, 2024
7e20b86
Merge branch 'main' into varlen
jiqing-feng Dec 18, 2024
c1bd7f7
Merge branch 'huggingface:main' into varlen
jiqing-feng Dec 25, 2024
fb71c2e
Merge branch 'huggingface:main' into varlen
jiqing-feng Jan 13, 2025
6186aaf
use varlen if flash attn not available
jiqing-feng Jan 14, 2025
cbc232b
revert ipex version change
jiqing-feng Jan 14, 2025
4dd2e44
fix flash attn check
jiqing-feng Jan 14, 2025
372d3f8
prefill attn
jiqing-feng Jan 14, 2025
daddabf
fix cache
jiqing-feng Jan 14, 2025
8e8c95f
qwen2 model forward
jiqing-feng Jan 14, 2025
95b7043
refactor attention
jiqing-feng Jan 14, 2025
71aa6b0
use flash attn for decode
jiqing-feng Jan 14, 2025
9211803
fix dtype
jiqing-feng Jan 14, 2025
333bd86
Merge branch 'varlen' into qwen
jiqing-feng Jan 14, 2025
d3fbd65
enable qwen2 model
jiqing-feng Jan 14, 2025
06798e2
enable qwen2 test
jiqing-feng Jan 14, 2025
12dd802
set default block size
jiqing-feng Jan 15, 2025
c6d2d0f
decoding use single query
jiqing-feng Jan 15, 2025
00e6bf3
rebase
jiqing-feng Jan 15, 2025
acfd0ce
fix position_id init for qwen2
jiqing-feng Jan 15, 2025
ccbe97a
add patched qwen2 test
jiqing-feng Jan 15, 2025
ee7dd81
fix format
jiqing-feng Jan 15, 2025
c86fd1c
fix pipeline test
jiqing-feng Jan 15, 2025
5b93036
set block size as a env parameter
jiqing-feng Jan 16, 2025
31accd2
set different default value for block size based on device
jiqing-feng Jan 16, 2025
e75b45b
Merge branch 'block_size' into qwen
jiqing-feng Jan 16, 2025
8656c26
Merge branch 'huggingface:main' into qwen
jiqing-feng Jan 17, 2025
21 changes: 21 additions & 0 deletions optimum/exporters/ipex/model_patcher.py
@@ -20,6 +20,11 @@
LlamaModel,
LlamaRMSNorm,
)
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2DecoderLayer,
Qwen2Model,
Qwen2RMSNorm,
)
from transformers.models.vit.modeling_vit import ViTIntermediate

from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version
@@ -36,7 +41,9 @@
_IPEXGPT2Attention,
_IPEXIntermediate,
_IPEXLlamaDecoderLayer,
_IPEXQwen2DecoderLayer,
_llama_model_forward,
_qwen2_model_forward,
)


@@ -116,6 +123,18 @@ def _patch_gpt2_model(model):
return model


def _patch_qwen2_model(model):
"""
Patch qwen2 model:
1. Use IPEX rope and paged cache
2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
"""
convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
return model


def _patch_bert_model(model):
"""
Patch bert model:
@@ -149,6 +168,8 @@ def _patch_model(model):
model = _patch_falcon_model(model)
elif model.config.model_type == "gpt2":
model = _patch_gpt2_model(model)
elif model.config.model_type == "qwen2":
model = _patch_qwen2_model(model)
elif model.config.model_type == "bert":
model = _patch_bert_model(model)
elif model.config.model_type == "vit":
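For context, a minimal usage sketch (not part of this diff) of how the new dispatch branch gets exercised: once `model.config.model_type == "qwen2"` routes to `_patch_qwen2_model`, a Qwen2 checkpoint loaded through `IPEXModelForCausalLM` is patched on export. The checkpoint id below is a placeholder for illustration.

```python
# Hypothetical usage sketch; the checkpoint id is a placeholder, not from this PR.
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "Qwen/Qwen2-0.5B-Instruct"  # placeholder Qwen2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# model_type == "qwen2" now takes the _patch_qwen2_model branch during export
model = IPEXModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Describe a real-world application of AI.", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```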
145 changes: 131 additions & 14 deletions optimum/exporters/ipex/modeling_utils.py
@@ -603,6 +603,125 @@ def _gpt2_block_forward(
return outputs # hidden_states, present, (attentions, cross_attentions)


# Adapted from https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L499
def _qwen2_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

batch_size, seq_length = inputs_embeds.shape[:2]

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
        position_ids = position_ids.unsqueeze(0).repeat_interleave(batch_size, 0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)

hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

if past_key_values_length == 0 and past_key_values is not None:
        # first token: remove padding from hidden_states, since varlen attention does not accept an attention mask
hidden_states_copy = hidden_states
index = attention_mask.view(-1) != 0
hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
cos = position_embeddings[0]
sin = position_embeddings[1]
cos = (cos.reshape(-1, cos.shape[-1]))[index]
sin = (sin.reshape(-1, sin.shape[-1]))[index]
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

if past_key_values is None:
attention_mask = causal_mask

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
input_lens=input_lens,
**kwargs,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

if hidden_states.shape[0] != batch_size * seq_length:
(hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
hidden_states = hidden_states_copy
hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()


class _IPEXAttention(nn.Module):
def __init__(self, module, config) -> None:
super().__init__()
@@ -618,8 +737,10 @@ def __init__(self, module, config) -> None:
def qkv_gemm(self, hidden_states):
raise NotImplementedError("Need to implement in specific model class")

def rope(self, *args, **kwargs):
raise NotImplementedError("Need to implement in specific model class")
def rope(self, query, key, **kwargs):
position_embeddings = kwargs.pop("position_embeddings", None)
rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
return query, key

def postprocess_attention_output(self, attn_output):
if self.use_sdpa:
@@ -748,13 +869,13 @@ class _IPEXLlamaAttention(_IPEXAttention):
def __init__(self, module, config) -> None:
super().__init__(module, config)
concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias is not None]
use_bias = bias_list != []
self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
self.concat_qkv.weight = nn.Parameter(concat_weight)
if use_bias:
concat_bias = torch.concat(bias_list, 0).contiguous()
self.concat_linear.bias = nn.Parameter(concat_bias)
self.concat_qkv.bias = nn.Parameter(concat_bias)
self.q_slice = self.q_proj.weight.shape[0]
self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
@@ -774,11 +895,6 @@ def qkv_gemm(self, hidden_states):

return query, key, value

def rope(self, query, key, **kwargs):
position_embeddings = kwargs.pop("position_embeddings", None)
rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
return query, key


class _IPEXFalconAttention(_IPEXAttention):
def __init__(self, module, config):
@@ -801,11 +917,6 @@ def qkv_gemm(self, hidden_states):
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
return query, key, value

def rope(self, query, key, **kwargs):
position_embeddings = kwargs.pop("position_embeddings", None)
rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
return query, key


class _IPEXGPT2Attention(_IPEXAttention):
def __init__(self, module, config) -> None:
@@ -1006,6 +1117,12 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
return outputs


# Currently the llama decoder layer can be reused as-is for qwen2.
class _IPEXQwen2DecoderLayer(_IPEXLlamaDecoderLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
class _IPEXIntermediate(nn.Module):
def __init__(self, module, config):
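As a reading aid for `_qwen2_model_forward` above, here is a self-contained sketch (not from this PR) of the prefill-time pack/unpack pattern it relies on: variable-length attention consumes only the real tokens, so padded positions are dropped with a boolean index before the decoder stack and scattered back afterwards. All shapes and tensors below are illustrative.

```python
import torch

batch_size, seq_len, hidden = 2, 4, 8
hidden_states = torch.randn(batch_size, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # right-padded batch

# pack: keep only real tokens, as varlen attention takes no attention mask
index = attention_mask.view(-1) != 0
packed = hidden_states.view(-1, hidden)[index]  # (num_real_tokens, hidden)
assert packed.shape[0] == int(attention_mask.sum())

# ... the decoder layers would run on the packed tokens here ...

# unpack: scatter results back into the padded (batch, seq, hidden) layout
restored = torch.zeros(batch_size * seq_len, hidden)
restored[index] = packed
restored = restored.view(batch_size, seq_len, hidden)
```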
4 changes: 2 additions & 2 deletions optimum/intel/ipex/modeling_base.py
@@ -58,11 +58,11 @@
logger = logging.getLogger(__name__)


_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2", "qwen2")
_IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
_IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
# TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6
_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")
_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2", "qwen2")


def _is_patched_with_ipex(model, task, use_cache: bool = True):
3 changes: 2 additions & 1 deletion tests/ipex/test_modeling.py
@@ -233,8 +233,9 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
"distilgpt2",
"mpt",
"opt",
"qwen2",
)
IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2")
IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2", "qwen2")
GENERATION_LENGTH = 100
SPEEDUP_CACHE = 1.0

6 changes: 4 additions & 2 deletions tests/ipex/test_pipelines.py
@@ -66,6 +66,7 @@ class PipelinesIntegrationTest(unittest.TestCase):
"mistral",
"mpt",
"opt",
"qwen2",
)
QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES = (
"bert",
@@ -144,10 +145,11 @@ def test_text_generation_pipeline_inference(self, model_arch):
"text-generation", model_id, accelerator="ipex", torch_dtype=dtype, device_map=DEVICE
)
inputs = "Describe a real-world application of AI."
max_new_tokens = 10 if model_arch != "qwen2" else 2
with torch.inference_mode():
transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10)
transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
with torch.inference_mode():
ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10)
ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM))
self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])

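A hedged sketch (not part of this diff) of exercising the new qwen2 entry the way the pipeline test above does, using the tiny test checkpoint registered in utils_tests.py below:

```python
# Sketch only; assumes optimum-intel's IPEX pipeline wrapper as used by the test above.
# "Jiqing/tiny-random-Qwen2" is the tiny test checkpoint added in utils_tests.py.
from optimum.intel.pipelines import pipeline as ipex_pipeline

generator = ipex_pipeline("text-generation", "Jiqing/tiny-random-Qwen2", accelerator="ipex")
print(generator("Describe a real-world application of AI.", do_sample=False, max_new_tokens=2))
```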
2 changes: 2 additions & 0 deletions tests/ipex/utils_tests.py
@@ -50,6 +50,7 @@
"mt5": "stas/mt5-tiny-random",
"opt": "hf-internal-testing/tiny-random-OPTModel",
"phi": "echarlaix/tiny-random-PhiForCausalLM",
"qwen2": "Jiqing/tiny-random-Qwen2",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-roberta",
"roformer": "hf-internal-testing/tiny-random-roformer",
@@ -64,4 +65,5 @@
"patched_falcon": "Intel/tiny-random-falcon_ipex_model",
"patched_gpt2": "Intel/tiny-random-gpt2_ipex_model",
"patched_llama2": "Intel/tiny-random-llama2_ipex_model",
"patched_qwen2": "Jiqing/tiny-random-Qwen2_ipex_model",
}