How to do KV Cache with FlexAttention and BlockMask by slicing? #60

Open

Leo-T-Zang opened this issue Oct 21, 2024 · 4 comments

Leo-T-Zang commented Oct 21, 2024

Is there any example code for this? Should I generate a new BlockMask every time?

Thanks!


Essentially, my problem is slicing the BlockMask. For example, for a prompt of 1000 tokens (prefill stage), I have the following attention code, which may be wrong. My question is: when I need to generate the 1001st token (a single query token), how do I slice out the exact position in the BlockMask for it?

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, create_mask
torch.set_default_device('cuda')

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention)

B, H, S = 1, 2, 5000 #5000 is the max model length

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S)
print('Mask Shape (Max Length): ', block_mask.shape)

# Input Prompt length is 1000, 3500 is the max token length we want
query = torch.randn(B, H, 1000, 64, device="cuda", dtype=torch.float32)
key = torch.randn(B, H, 3500, 64, device="cuda", dtype=torch.float32)
value = torch.randn(B, H, 3500, 64, device="cuda", dtype=torch.float32)

# slice block mask
q_slice = torch.arange(0, 3500//128 + 1).to(device)
block_mask = block_mask[:, :, q_slice]
print('Mask Shape (Sliced): ', block_mask.shape)

print('Query Shape:', query.shape)
print('Key Shape:', key.shape)
print('Value Shape:', value.shape)

out = flex_attention(query, key, value, block_mask=block_mask)

print('Attention Output Shape:', out.shape)

block_mask = block_mask.to_string(limit=32,)
print(block_mask)

# Generate a new token: q has length 1 and its position is 1000
q = torch.randn(1, H, 1, 64, device="cuda", dtype=torch.float32)
# The problem: how do I select position 1000 in the BlockMask?

Another question: if I use a prefix-LM mask for the prompt tokens, it works when I set H=None, but it fails with errors when I set H=H.

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention)

B, H, S = 1, 2, 500

device = 'cuda' if torch.cuda.is_available() else 'cpu'

full_attention_idx = torch.tensor([[0, 100]], dtype=torch.long).to(device)

def prefix_lm_causal_mask(b, h, q_idx, kv_idx):
    # Tokens in [full_attention_idx[b][0], full_attention_idx[b][1]] get full (prefix) attention
    full_mask = (kv_idx <= full_attention_idx[b][1]) & (kv_idx >= full_attention_idx[b][0])
    causal_mask = q_idx >= kv_idx
    return full_mask | causal_mask

# In this case, our mask is different per sequence, so we set B equal to our batch size
block_mask = create_block_mask(prefix_lm_causal_mask, B=B, H=None, Q_LEN=S, KV_LEN=S)
print(block_mask.shape)

query = torch.randn(B, H, 100, 64, device="cuda", dtype=torch.float32)
key = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)
value = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)

# slice block mask
q_slice = torch.arange(0, 350//128 + 1)
block_mask = block_mask[:, :, q_slice]
print(block_mask.shape)

print('Query Shape:', query.shape)
print('Key Shape:', key.shape)
print('Value Shape:', value.shape)

out = flex_attention(query, key, value, block_mask=block_mask)

print('Attention Output Shape:', out.shape)

block_mask = block_mask.to_string(limit=32,)
print(block_mask)

When H=H, the code is identical except for the create_block_mask call:

# Same code as above; the only change is passing H=H instead of H=None
block_mask = create_block_mask(prefix_lm_causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S)

Errors

/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [0,32,0], thread: [0,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [0,32,0], thread: [1,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [0,32,0], thread: [2,0,0] Assertion `` failed.
...
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [55,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [56,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [57,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [58,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [59,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [60,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [61,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [62,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [63,0,0] Assertion `` failed.
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/tc289/project/p3i/test/attention.py", line 208, in <module>
    block_mask = block_mask.to_string(limit=32,)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 513, in to_string
    dense_mask = self.to_dense()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 500, in to_dense
    partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 173, in _ordered_to_dense
    out = create_dense_batched(num_blocks_in_row, col_indices)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 166, in create_dense_one
    dense_mask[row_indices, valid_indices] = 1
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 106, in __torch_function__
    return func(*args, **kwargs)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


Leo-T-Zang changed the title from "How to do KV Cache with FlexAttention and BlockMask?" to "How to do KV Cache with FlexAttention and BlockMask by slicing?" on Oct 21, 2024
Leo-T-Zang (Author) commented:

Are there any suggestions?
I'd really appreciate any ideas on how to solve this problem (@drisspg @Chillee).

Chillee (Contributor) commented Oct 23, 2024

For prefill it might be worth just regenerating the block mask. But in general, indexing at a position just gives you the block mask corresponding to that position. I believe we also support using a sequence smaller than the BlockMask with the BlockMask. So if you generate a BlockMask with S=2048, for example, you can pass in a sequence of length 1001.
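
For concreteness, something along these lines (an untested sketch; whether shorter-than-mask sequences are accepted is the "I believe" above, not a guarantee):

# Untested sketch: build the BlockMask once at a larger S, then reuse it for a
# shorter prefill. For prefill you could also just regenerate the mask at the
# prompt length instead.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

S = 2048
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

# Prefill a 1001-token prompt against the 2048-length mask.
q = torch.randn(1, 2, 1001, 64, device="cuda")
k = torch.randn(1, 2, 1001, 64, device="cuda")
v = torch.randn(1, 2, 1001, 64, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)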

Leo-T-Zang (Author) commented Oct 23, 2024

> For prefill it might be worth just regenerating the block mask. But in general, indexing at a position just gives you the block mask corresponding to that position. I believe we also support using a sequence smaller than the BlockMask with the BlockMask. So if you generate a BlockMask with S=2048, for example, you can pass in a sequence of length 1001.

Thank you so much for your reply!

I understand that with a BlockMask created with S=2048 we can process a shorter prompt (e.g., 1000 tokens) in the prefill stage. I am sorry if I did not make this clear in the original comment. The issue I'm facing is during the decoding stage:

  1. During autoregressive decoding, we typically generate one token at a time and need to compute attention scores for a single query position (e.g., the 1001st token).

  2. To do this efficiently with a KV cache, we need to slice the BlockMask to get the exact row for the current position ID and compute attention against the cached K and V.

  3. The challenge is that BlockMask can only be sliced in blocks of 128 tokens, not individual rows:

q_slice = torch.arange(0, position_id//128 + 1)
block_mask = block_mask[:, :, q_slice]

# I can't do the following; it gives an error
block_mask = block_mask[:, :, position_id]

How can we implement a KV cache in this scenario, where we can't slice individual rows from the BlockMask?
Additionally, could you help with the error I get when setting H=H versus H=None?
Thanks!

joydddd commented Nov 13, 2024

We integrated flexattention in gpt-fast for decoding: pytorch-labs/gpt-fast#196

You only need to build a new BlockMask every 128 generated tokens. The BlockMask stays the same for tokens #1024 through #1024 + 128.
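
In case it helps, here is a rough sketch of that pattern as I understand it (not the actual gpt-fast code; the BLOCK_SIZE attribute, the mask_mod reassignment, and running K/V shorter than the mask's KV_LEN are assumptions on my part):

# Sketch: keep one full-size BlockMask, re-slice its q-block row only when the
# decode position crosses a 128-token block boundary, and shift mask_mod so the
# single query token is judged at its absolute position in the KV cache.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

MAX_S = 5000
full_block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=MAX_S, KV_LEN=MAX_S, device="cuda")
Q_BLOCK = full_block_mask.BLOCK_SIZE[0]  # 128 by default (assumed attribute name)

def offset_mask_mod(mask_mod, offset):
    # Re-interpret the local q_idx (0 for a single decoded token) as the
    # absolute position `offset` in the cache.
    def _mask_mod(b, h, q_idx, kv_idx):
        return mask_mod(b, h, q_idx + offset, kv_idx)
    return _mask_mod

def decode_mask(pos):
    # The q-block row containing `pos` only changes every Q_BLOCK tokens.
    row = torch.tensor([pos // Q_BLOCK], device="cuda")
    mask = full_block_mask[:, :, row]
    mask.mask_mod = offset_mask_mod(full_block_mask.mask_mod, pos)
    return mask

# Decoding the 1001st token (position 1000) against 1001 cached K/V entries:
q = torch.randn(1, 2, 1, 64, device="cuda")
k = torch.randn(1, 2, 1001, 64, device="cuda")
v = torch.randn(1, 2, 1001, 64, device="cuda")
out = flex_attention(q, k, v, block_mask=decode_mask(1000))

The idea is that the block-level structure only needs refreshing at block boundaries; the per-token exactness within the boundary block comes from mask_mod.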
