add support for flash decoding on xpu #1118

Open · wants to merge 11 commits into base: main
69 changes: 48 additions & 21 deletions optimum/exporters/ipex/cache_utils.py
@@ -5,6 +5,8 @@
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig

from optimum.intel.utils.import_utils import is_ipex_version


class IPEXPagedCache(Cache):
"""
@@ -43,10 +45,14 @@ def __init__(
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.device = device
self._supports_flash_decoding = (
is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
)
# Used in `generate` to keep tally of how many tokens the cache has seen

self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
default_block_size = 16 if device.type == "cpu" else 64
default_block_size = 16
self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
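The `_supports_flash_decoding` flag added here is a plain IPEX version gate (> 2.4.99 on CPU, > 2.5.99 on XPU), and the block-table sizing below it amounts to ceil(max_cache_len / block_size) blocks per sequence, replicated for every sequence in the batch. A quick arithmetic sketch with invented numbers:

```python
import math

# Hypothetical values, purely to illustrate the rounding used in the constructor.
max_cache_len, max_batch_size, block_size = 1000, 4, 64

blocks_per_seq = max_cache_len // block_size + (max_cache_len % block_size != 0)
num_blocks = blocks_per_seq * max_batch_size

assert blocks_per_seq == math.ceil(max_cache_len / block_size) == 16
assert num_blocks == 64
```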
@@ -70,14 +76,44 @@ def __init__(
key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
elif device.type == "xpu":
key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
if self._supports_flash_decoding:
key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
else:
key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
for i in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

def reshape_and_cache(
self,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
# TODO: unify API definition between CPU and XPU in IPEX version > 2.6
if self.device.type == "xpu" and self._supports_flash_decoding:
PagedAttention.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slots,
)
else:
PagedAttention.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
)

def update_for_prefill(
self,
key_states: torch.Tensor,
@@ -95,7 +131,7 @@ def update_for_prefill(
block_table = self.free_blocks.nonzero().view(-1)[0:nb]
self.block_tables[i][0:nb] = block_table
self.free_blocks[block_table] = 0
slots_range = torch.arange(input_lens[i], device=key_states.device)
slots_range = torch.arange(input_lens[i], device=self.device)
block_indices = slots_range // self.block_size
slot_offsets = slots_range % self.block_size
all_block_indices.append(self.block_tables[i][block_indices])
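The prefill slot computation above maps token t of sequence i to block `block_tables[i][t // block_size]` at offset `t % block_size`; the flat slot id is `block * block_size + offset`. A small worked example with invented values:

```python
import torch

block_size = 4                            # hypothetical, kept small for readability
block_table_i = torch.tensor([7, 2, 5])   # blocks assigned to one sequence
input_len_i = 6                           # prompt length of that sequence

slots_range = torch.arange(input_len_i)
block_indices = slots_range // block_size   # tensor([0, 0, 0, 0, 1, 1])
slot_offsets = slots_range % block_size     # tensor([0, 1, 2, 3, 0, 1])
slots = block_table_i[block_indices] * block_size + slot_offsets
print(slots.tolist())                       # [28, 29, 30, 31, 8, 9]
```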
@@ -105,12 +141,8 @@ def update_for_prefill(
all_slot_offsets = torch.cat(all_slot_offsets)
self.slots = all_block_indices * self.block_size + all_slot_offsets
# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.slots,
self.reshape_and_cache(
key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
)

# Update the number of seen tokens
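Both the prefill and decode paths now write through the new `reshape_and_cache` wrapper, which forwards to `PagedAttention.reshape_and_cache_flash` only on an XPU build that supports flash decoding and to `PagedAttention.reshape_and_cache` otherwise. A hedged sketch of a single call, with invented shapes; the token-major layout shown is the one allocated for the flash path:

```python
import torch

# Hypothetical sizes, for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
num_tokens = 5

key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)
# Flash-decoding layout from the constructor: (num_blocks, block_size, num_kv_heads, head_size)
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
slots = torch.tensor([0, 1, 2, 16, 17], dtype=torch.int32)  # flat slot ids produced as above

# cache.reshape_and_cache(key, value, key_cache, value_cache, slots) then scatters each
# token into block slots[t] // block_size at offset slots[t] % block_size.
```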
@@ -128,7 +160,7 @@ def update_for_decode(
if layer_idx == 0:
start_block_idx = self._seen_tokens // self.block_size
slot_offset_in_block = (self._seen_tokens) % self.block_size
self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32)
for i in range(batch_size):
if slot_offset_in_block[i] == 0:
# need a new block:
@@ -139,12 +171,8 @@
self.free_blocks[self.block_tables[i][b_idx]] = 0
self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.slots,
self.reshape_and_cache(
key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
)

# Update the number of seen tokens
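A worked example of the decode-time slot selection above, with invented values: every sequence writes exactly one token per step, and a new block is claimed only when the write offset wraps around to 0.

```python
import torch

block_size = 4                                 # hypothetical
seen_tokens = torch.tensor([4, 6])             # tokens already cached per sequence
block_tables = torch.tensor([[3, 9, -1],       # sequence 0: block 9 just allocated
                             [1, 5, -1]])      # sequence 1: still inside block 5

start_block_idx = seen_tokens // block_size    # tensor([1, 1])
slot_offset_in_block = seen_tokens % block_size  # tensor([0, 2]) -> sequence 0 starts a new block
slots = block_tables[torch.arange(2), start_block_idx] * block_size + slot_offset_in_block
print(slots.tolist())                          # [36, 22]
```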
@@ -194,16 +222,15 @@ def get_max_length(self) -> Optional[int]:

def reset(self):
"""Resets the cache values while preserving the objects"""
self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.block_tables.device)
self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.device)
self.block_tables.fill_(-1)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device)
self.max_seq_len = 0

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.block_tables.device
origin_table = self.block_tables.clone()
updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
num_blocks = mask.cumsum(-1)[:, -1]
updated_table = torch.zeros_like(beam_idx)
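The beam-search reorder shown here is essentially an `index_select` over the per-beam block tables; a toy example with invented tables:

```python
import torch

block_tables = torch.tensor([[3, 7, -1],   # hypothetical: one row of block ids per beam
                             [2, 5, -1]])
beam_idx = torch.tensor([1, 1])            # both surviving beams fork from old beam 1

updated_block_tables = block_tables.index_select(0, beam_idx)
print(updated_block_tables.tolist())       # [[2, 5, -1], [2, 5, -1]]
```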
82 changes: 65 additions & 17 deletions optimum/exporters/ipex/modeling_utils.py
@@ -179,8 +179,8 @@ def _llama_model_forward(

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
@@ -200,6 +200,9 @@
position_embeddings = self.rotary_emb(hidden_states, position_ids)

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
max_input_lens = input_lens.max().item()

if past_key_values_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -235,6 +238,9 @@ def _llama_model_forward(
use_cache=use_cache,
position_embeddings=position_embeddings,
input_lens=input_lens,
max_input_lens=max_input_lens,
seq_len_tensor=seq_len_tensor,
query_len_tensor=query_len_tensor,
)

hidden_states = layer_outputs[0]
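The three helpers hoisted into the model forward (`seq_len_tensor`, `query_len_tensor`, `max_input_lens`) are derived once from the attention mask and then passed down to every layer. A worked example with an invented two-sequence batch:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 1, 1]])                  # hypothetical padded batch

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)  # tensor([3, 4])
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0]).int()
max_input_lens = input_lens.max().item()

print(seq_len_tensor.tolist())   # [0, 3, 7] -> varlen cu_seqlens over the packed batch
print(query_len_tensor.tolist()) # [0, 1, 2] -> one query token per sequence during decode
print(max_input_lens)            # 4
```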
@@ -303,11 +309,10 @@ def _falcon_model_forward(

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
batch_size, seq_length, _ = inputs_embeds.shape
device = input_ids.device if input_ids is not None else inputs_embeds.device

if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
)
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)

if position_ids is None:
position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
@@ -323,6 +328,9 @@
position_embeddings = self.rotary_emb(hidden_states, position_ids)

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
max_input_lens = input_lens.max().item()

if past_key_values_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -365,6 +373,9 @@ def _falcon_model_forward(
cache_position=cache_position,
position_embeddings=position_embeddings,
input_lens=input_lens,
max_input_lens=max_input_lens,
seq_len_tensor=seq_len_tensor,
query_len_tensor=query_len_tensor,
)

hidden_states = outputs[0]
@@ -459,6 +470,9 @@ def _gpt2_model_forward(
hidden_states = self.drop(hidden_states)

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
max_input_lens = input_lens.max().item()

if past_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -494,6 +508,9 @@ def _gpt2_model_forward(
use_cache=use_cache,
output_attentions=output_attentions,
input_lens=input_lens,
max_input_lens=max_input_lens,
seq_len_tensor=seq_len_tensor,
query_len_tensor=query_len_tensor,
)

hidden_states = outputs[0]
@@ -636,6 +653,7 @@ def _qwen2_model_forward(
inputs_embeds = self.embed_tokens(input_ids)

batch_size, seq_length = inputs_embeds.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device

past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
@@ -660,6 +678,9 @@
position_embeddings = self.rotary_emb(hidden_states, position_ids)

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
max_input_lens = input_lens.max().item()

if past_key_values_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -695,6 +716,9 @@ def _qwen2_model_forward(
cache_position=cache_position,
position_embeddings=position_embeddings,
input_lens=input_lens,
max_input_lens=max_input_lens,
seq_len_tensor=seq_len_tensor,
query_len_tensor=query_len_tensor,
**kwargs,
)

@@ -749,14 +773,26 @@ def postprocess_attention_output(self, attn_output):
return attn_output

# Maybe removed after torch 2.6 released
def has_flash_attn(self, query):
if query.device.type == "cpu":
def has_flash_attn(self):
if self.module_device.type == "cpu":
return is_torch_version(">", "2.4.99")
elif query.device.type == "xpu":
elif self.module_device.type == "xpu":
return is_torch_version(">", "2.5.99")

def attention_interface(
self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
self,
query,
key_cache,
value_cache,
key,
value,
past_key_value,
attention_mask,
input_lens,
past_len,
seq_len_tensor,
query_len_tensor,
max_input_lens,
):
if past_key_value is None:
n_rep = query.shape[1] // key.shape[1]
@@ -773,20 +809,19 @@ def attention_interface(
is_causal=True,
)
self.use_sdpa = True
elif self.has_flash_attn(query):
elif self.has_flash_attn():
attn_output = torch.empty_like(query)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
query_max_len = input_lens.max() if past_len == 0 else 1
query_len_tensor = seq_len_tensor if past_len == 0 else query_len_tensor
query_max_len = max_input_lens if past_len == 0 else 1
PagedAttention.flash_attn_varlen_func(
attn_output,
query.contiguous() if query.device.type == "xpu" else query,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
key_cache,
value_cache,
query_len_tensor,
seq_len_tensor,
query_max_len,
input_lens.max(),
max_input_lens,
1.0 / math.sqrt(self.head_dim),
True,
past_key_value.block_tables,
@@ -795,7 +830,6 @@
elif past_len == 0:
# prefill, remove padding
attn_output = torch.empty_like(query)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
@@ -844,6 +878,9 @@ def forward(
if past_key_value is None and kwargs.get("layer_past", None) is not None:
past_key_value = kwargs.pop("layer_past", None)
input_lens = kwargs.pop("input_lens", None)
seq_len_tensor = kwargs.pop("seq_len_tensor", None)
query_len_tensor = kwargs.pop("query_len_tensor", None)
max_input_lens = kwargs.pop("max_input_lens", 0)
past_len = 0
if past_key_value is not None:
past_len = past_key_value.get_seq_length()
@@ -855,7 +892,18 @@
key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)

attn_output = self.attention_interface(
query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
query,
key_cache,
value_cache,
key,
value,
past_key_value,
attention_mask,
input_lens,
past_len,
seq_len_tensor,
query_len_tensor,
max_input_lens,
)

attn_output = self.postprocess_attention_output(attn_output)
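Taken together, `forward` now pops the precomputed `seq_len_tensor`, `query_len_tensor`, and `max_input_lens` from kwargs and hands them to `attention_interface`, which, per `has_flash_attn`, uses the flash varlen kernel when torch is newer than 2.4.99 on CPU or 2.5.99 on XPU, taking the full sequence lengths for prefill and a query length of 1 per sequence for decode. A condensed, illustrative decision sketch, not the literal implementation (the final decode fallback sits outside the hunks shown here):

```python
def pick_attention_path(past_key_value, past_len, has_flash_attn,
                        seq_len_tensor, query_len_tensor, max_input_lens):
    """Illustrative summary of the dispatch order in attention_interface."""
    if past_key_value is None:
        return "sdpa"  # no paged cache yet: plain scaled_dot_product_attention
    if has_flash_attn:
        # Prefill packs all prompt tokens; decode has exactly one query token per sequence.
        cu_seqlens_q = seq_len_tensor if past_len == 0 else query_len_tensor
        max_seqlen_q = max_input_lens if past_len == 0 else 1
        return "flash_attn_varlen_func", cu_seqlens_q, max_seqlen_q
    if past_len == 0:
        return "varlen_attention"  # prefill fallback without flash support
    return "paged decode fallback"  # branch not shown in this diff
```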
2 changes: 1 addition & 1 deletion tests/ipex/test_modeling.py
@@ -291,7 +291,7 @@ def test_forward(self, model_arch):
dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
self.assertIsInstance(ipex_model.config, PretrainedConfig)
input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long)
input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long).to(DEVICE)
outputs = ipex_model(input_ids)

self.assertIsInstance(outputs.logits, torch.Tensor)
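The single-line test fix above just moves the inputs onto the same device as the model. A minimal usage sketch along the same lines; the model id and the device pick are placeholders, not part of the PR:

```python
import torch
from optimum.intel import IPEXModelForCausalLM

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"  # hypothetical device pick
dtype = torch.float16 if device == "xpu" else torch.float32

model = IPEXModelForCausalLM.from_pretrained("<model-id>", torch_dtype=dtype, device_map=device)
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long).to(device)  # keep inputs and model on one device
outputs = model(input_ids)
print(outputs.logits.shape)
```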