Enable Flash Attention #1065

Merged
merged 24 commits on Jan 14, 2025
Changes from 21 commits
2 changes: 1 addition & 1 deletion optimum/exporters/ipex/cache_utils.py
@@ -44,7 +44,7 @@ def __init__(
self.batch_size = batch_size
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
self.block_size = 16
self.block_size = 64
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
batch_size, -1
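Note on the sizing above: with the block size raised from 16 to 64, the paged KV cache still allocates enough fixed-size blocks per sequence to cover max_cache_len, via the ceiling division shown in the diff. A minimal sketch of that arithmetic, assuming only that max_cache_len and batch_size are positive integers (the helper name num_cache_blocks is ours, not part of the PR):

import torch

def num_cache_blocks(max_cache_len: int, batch_size: int, block_size: int = 64) -> int:
    # Ceiling division: blocks per sequence, then one set of blocks for every sequence in the batch.
    blocks_per_sequence = max_cache_len // block_size + (max_cache_len % block_size != 0)
    return blocks_per_sequence * batch_size

# Example: a 1024-token cache for a batch of 2 sequences -> 16 blocks per sequence, 32 in total.
assert num_cache_blocks(1024, 2) == 32
# As in the diff, the block table starts unassigned (-1), reshaped to one row per sequence.
block_tables = -1 * torch.ones([num_cache_blocks(1024, 2)], dtype=torch.int32).reshape(2, -1)
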
4 changes: 3 additions & 1 deletion optimum/exporters/ipex/model_patcher.py
@@ -14,7 +14,7 @@

from transformers.models.bert.modeling_bert import BertIntermediate
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaModel,
@@ -27,6 +27,7 @@

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_IPEXGPT2MLP,
_falcon_model_forward,
_gpt2_block_forward,
_gpt2_model_forward,
@@ -111,6 +112,7 @@ def _patch_gpt2_model(model):
convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
return model


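How the new GPT2MLP patch gets applied: convert_class itself is not part of this diff, so the sketch below is only an assumed shape for it — a recursive swap of submodules by class, which is the usual pattern for this kind of patching. The name convert_class_sketch and its body are hypothetical, not the library's implementation.

from torch import nn

def convert_class_sketch(model: nn.Module, orig_cls, new_cls, config) -> None:
    # Walk the module tree and wrap every instance of orig_cls with new_cls(child, config).
    for name, child in model.named_children():
        if isinstance(child, orig_cls):
            setattr(model, name, new_cls(child, config))
        else:
            convert_class_sketch(child, orig_cls, new_cls, config)

Under that assumption, convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config) replaces each GPT2MLP with the fused-linear variant added in modeling_utils.py below, mirroring what the existing lines already do for GPT2Attention.
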
200 changes: 170 additions & 30 deletions optimum/exporters/ipex/modeling_utils.py
@@ -19,9 +19,12 @@
import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions

from optimum.intel.utils.import_utils import is_ipex_version
from optimum.intel.utils.import_utils import is_ipex_version, is_torch_version
from optimum.intel.utils.modeling_utils import _setattr_from_module

from .cache_utils import IPEXPagedCache
@@ -43,6 +46,7 @@
LinearAdd,
LinearAddAdd,
LinearGelu,
LinearNewGelu,
PagedAttention,
)

@@ -194,7 +198,10 @@ def _llama_model_forward(
next_decoder_cache = () if use_cache else None

position_embeddings = self.rotary_emb(hidden_states, position_ids)
if past_key_values_length == 0:

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

if past_key_values_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
hidden_states_copy = hidden_states
index = attention_mask.view(-1) != 0
@@ -207,7 +214,13 @@
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)

for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -309,7 +322,9 @@ def _falcon_model_forward(
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

if past_key_values_length == 0:
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

if past_key_values_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
hidden_states_copy = hidden_states
index = attention_mask.view(-1) != 0
@@ -321,7 +336,14 @@
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)

next_decoder_cache = None
all_self_attentions = () if output_attentions else None
@@ -436,15 +458,23 @@ def _gpt2_model_forward(

hidden_states = self.drop(hidden_states)

if past_length == 0:
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

if past_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
hidden_states_copy = hidden_states
index = attention_mask.view(-1) != 0
hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)

presents = None
all_self_attentions = () if output_attentions else None
@@ -528,7 +558,10 @@ def _gpt2_block_forward(
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + residual
if hasattr(self.attn, "linear_add"):
hidden_states = self.attn.linear_add(attn_output, residual)
else:
hidden_states = attn_output + residual

if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
@@ -557,7 +590,10 @@ def _gpt2_block_forward(
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
if hasattr(self.mlp, "linear_add"):
hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
else:
hidden_states = residual + feed_forward_hidden_states

if use_cache:
outputs = (hidden_states,) + outputs
@@ -577,6 +613,7 @@ def __init__(self, module, config) -> None:
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
).repeat_interleave(self.num_groups)
self.use_sdpa = False

def qkv_gemm(self, hidden_states):
raise NotImplementedError("Need to implement in specific model class")
@@ -585,9 +622,75 @@ def rope(self, *args, **kwargs):
raise NotImplementedError("Need to implement in specific model class")

def postprocess_attention_output(self, attn_output):
if self.use_sdpa:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
return attn_output

# Maybe removed after torch 2.6 released
def has_flash_attn(self, query):
if query.device.type == "cpu":
return is_torch_version(">", "2.4.99")
elif query.device.type == "xpu":
return is_torch_version(">", "2.5.99")

def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens):
if past_key_value is None:
n_rep = query.shape[1] // key.shape[1]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
.transpose(1, 2)
.repeat_interleave(n_rep, 1),
value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
.transpose(1, 2)
.repeat_interleave(n_rep, 1),
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=True,
)
self.use_sdpa = True
elif self.has_flash_attn(query):
# prefill, remove padding
attn_output = torch.empty_like(query)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
PagedAttention.flash_attn_varlen_func(
attn_output,
query,
key_cache,
value_cache,
seq_len_tensor,
seq_len_tensor,
input_lens.max(),
input_lens.max(),
1.0 / math.sqrt(self.head_dim),
True,
past_key_value.block_tables,
None,
)
else:
# prefill, remove padding
attn_output = torch.empty_like(query)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
attn_output,
seq_len_tensor,
seq_len_tensor,
input_lens.max(),
input_lens.max(),
0.0,
1.0 / math.sqrt(self.head_dim),
False,
True,
False,
None,
)

return attn_output

def forward(
self,
hidden_states: torch.Tensor,
@@ -607,31 +710,18 @@ def forward(
query, key, value = self.qkv_gemm(hidden_states)
query, key = self.rope(query, key, **kwargs)

key_cache, value_cache = None, None
if past_key_value is not None:
key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)

attn_output = torch.empty_like(query)
if past_len == 0:
# prefill, remove padding
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
attn_output,
seq_len_tensor,
seq_len_tensor,
input_lens.max(),
input_lens.max(),
0.0,
1.0 / math.sqrt(self.head_dim),
False,
True,
False,
None,
# prefill
attn_output = self.prefill_attn(
query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens
)
else:
# decode
attn_output = torch.empty_like(query)
PagedAttention.single_query_cached_kv_attention(
attn_output,
query,
@@ -720,9 +810,23 @@ class _IPEXGPT2Attention(_IPEXAttention):
def __init__(self, module, config) -> None:
self.num_key_value_heads = config.num_key_value_heads
super().__init__(module, config)
_setattr_from_module(self, module)
self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
self.c_attn_linear.bias = self.c_attn.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)

elif self.module_device.type == "xpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = XPULinearAdd(self.c_proj_linear)

def qkv_gemm(self, hidden_states):
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
query = query.view(-1, self.num_heads, self.head_dim)
key = key.view(-1, self.num_heads, self.head_dim)
value = value.view(-1, self.num_heads, self.head_dim)
@@ -732,9 +836,11 @@ def rope(self, query, key, *args, **kwargs):
return query, key

def postprocess_attention_output(self, attn_output):
if self.use_sdpa:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
if not hasattr(self, "linear_add"):
attn_output = self.c_proj(attn_output)
return attn_output


@@ -805,6 +911,40 @@ def forward(
return output


class _IPEXGPT2MLP(nn.Module):
def __init__(self, module, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = next(module.parameters()).device
self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
self.c_fc_linear.bias = self.c_fc.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)

if self.module_device.type == "cpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)

elif self.module_device.type == "xpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = XPULinearAdd(self.c_proj_linear)

def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
if hasattr(self, "linear_new_gelu"):
hidden_states = self.linear_new_gelu(hidden_states)
else:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
if not hasattr(self, "linear_add"):
hidden_states = self.c_proj(hidden_states)
return hidden_states


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayer(nn.Module):
def __init__(self, module, config):
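Two reviewer-style notes on the modeling_utils.py changes, with illustrative sketches rather than the actual kernels.

First, the prefill dispatch: when no paged cache is passed, the patched model forwards build a 4D causal mask and prefill_attn falls back to plain SDPA; with a paged cache, has_flash_attn gates the flash varlen kernel on the torch version, and older torch keeps the previous varlen_attention path. A condensed sketch of just that decision logic (no kernels are called; the version thresholds are copied from the diff):

def prefill_backend(device_type: str, has_paged_cache: bool, torch_version: tuple) -> str:
    # Without a paged KV cache, prefill uses torch SDPA with the 4D causal mask built by
    # _prepare_4d_causal_attention_mask_for_sdpa in the patched model forwards.
    if not has_paged_cache:
        return "sdpa"
    # With a paged cache, prefer PagedAttention.flash_attn_varlen_func on new enough torch:
    # CPU needs torch > 2.4.99 (i.e. 2.5+), XPU needs torch > 2.5.99 (i.e. 2.6+).
    if device_type == "cpu" and torch_version >= (2, 5):
        return "flash_attn_varlen"
    if device_type == "xpu" and torch_version >= (2, 6):
        return "flash_attn_varlen"
    # Otherwise keep the existing varlen_attention path; decode always goes through
    # PagedAttention.single_query_cached_kv_attention, unchanged by this choice.
    return "varlen_attention"

assert prefill_backend("cpu", True, (2, 5, 0)) == "flash_attn_varlen"
assert prefill_backend("xpu", True, (2, 5, 1)) == "varlen_attention"

Second, the Conv1D-to-Linear rebuilds in _IPEXGPT2Attention and _IPEXGPT2MLP: GPT-2's Conv1D stores its weight transposed relative to nn.Linear, so the constructors transpose it once up front to feed the fused LinearAdd / LinearNewGelu ops. A minimal, self-contained illustration with hypothetical shapes:

import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

conv = Conv1D(nf=3 * 768, nx=768)              # GPT-2 style projection, weight shape (nx, nf)
linear = nn.Linear(conv.weight.shape[0], conv.weight.shape[1])
linear.weight = nn.Parameter(conv.weight.t())  # nn.Linear expects (out_features, in_features)
linear.bias = conv.bias

x = torch.randn(4, 768)
torch.testing.assert_close(linear(x), conv(x))  # same projection, now usable by the fused ops
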
4 changes: 2 additions & 2 deletions optimum/intel/ipex/modeling_base.py
@@ -316,9 +316,9 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
return self.model.prepare_inputs_for_generation(*args, **kwargs)

def generate(self, *args, **kwargs):
if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None):
if self._add_patch and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
# Patch functions to support ipex_paged cache
if self._add_patch:
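End-to-end, the user-visible change in modeling_base.py is that assisted decoding is now rejected for patched models on every IPEX version, not just ipex < 2.4. A hedged usage sketch, assuming the usual optimum-intel entry point (IPEXModelForCausalLM) and a placeholder checkpoint id:

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # placeholder; any supported causal-LM checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Paged KV cache with flash attention", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)  # prefill/decode go through the patched paths
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# For patched models, passing a draft model now raises a ValueError:
# model.generate(**inputs, assistant_model=draft_model)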