Ensure that the BlockMask length always exactly matches the sequence length in flex_attention (pytorch#141625)

Fixes pytorch#141435

Pull Request resolved: pytorch#141625
Approved by: https://github.com/drisspg
ghstack dependencies: pytorch#138788
Chillee authored and pytorchmergebot committed Dec 2, 2024
1 parent 8eb259f commit 795f28a
Showing 7 changed files with 171 additions and 118 deletions.
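
What this enforces from the caller's side, as a minimal sketch (the `causal` mask_mod, the shapes, and the CUDA device are illustrative assumptions, not taken from the diff): the BlockMask must be built for exactly the query/KV lengths of the tensors it is used with.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # Illustrative mask_mod: standard causal masking.
    return q_idx >= kv_idx

B, H, SEQ_LEN, HEAD_DIM = 2, 4, 512, 64
q = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)

# The mask is created for exactly the same Q/KV lengths as q and k;
# after this commit a mismatch is rejected instead of being silently rounded.
block_mask = create_block_mask(causal, B, H, SEQ_LEN, SEQ_LEN, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```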
149 changes: 85 additions & 64 deletions test/inductor/test_flex_attention.py

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions test/inductor/test_flex_decoding.py
@@ -579,9 +579,9 @@ def run_test_with_call_paged_attention(
ref_out = golden_call(q_ref, k_ref, v_ref)

if mask_mod is not None:
block_mask = create_block_mask(mask_mod, Q_B, 1, 1, S)
block_mask = create_block_mask(mask_mod, Q_B, 1, Q_S, KV_S)
else:
block_mask = create_block_mask(noop_mask, Q_B, 1, 1, S)
block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S)

compiled_out, _ = self.run_paged_attention(
score_mod, q, k, v, dtype, block_mask
@@ -682,7 +682,7 @@ def test_builtin_score_mods_different_block_size(
score_mod: Callable,
BLOCK_SIZE: Union[int, Tuple[int, int]],
):
block_mask = create_block_mask(noop_mask, B, 1, S, S, BLOCK_SIZE=BLOCK_SIZE)
block_mask = create_block_mask(noop_mask, B, 1, 1, S, BLOCK_SIZE=BLOCK_SIZE)
self.run_test(score_mod, dtype, block_mask=block_mask)

def input_strides_1(B, H, S, D):
@@ -1098,7 +1098,7 @@ def scoremod_1(qk, b, h, q, kv):
def scoremod_2(qk, b, h, q, kv):
return torch.where(q >= kv, qk, -float("inf"))

block_mask = create_block_mask(noop_mask, 1, 1, 1, S)
block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024)

def f(q, k1, k2, v1, v2):
q2 = flex_attention(q, k1, v1, score_mod=scoremod_1, block_mask=block_mask)
@@ -1167,7 +1167,7 @@ def scoremod_1(qk, b, h, q, kv):
def scoremod_2(qk, b, h, q, kv):
return torch.where(q >= kv, qk, -float("inf"))

block_mask = create_block_mask(noop_mask, 1, 1, 1, S)
block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024)

attention1 = functools.partial(
flex_attention, score_mod=scoremod_1, block_mask=block_mask
@@ -1567,8 +1567,8 @@ def mask_mod(b, h, q_idx, kv_idx):
mask_mod=mask_mod,
B=2,
H=None,
Q_LEN=128,
KV_LEN=256,
Q_LEN=2,
KV_LEN=2,
device="cuda",
)

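The test updates above build the mask for the real query/KV lengths rather than placeholder sizes. A sketch of the decode-style shapes involved (values are hypothetical):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

Q_S, KV_S = 1, 4096  # hypothetical decode shapes: one query token, long KV cache
block_mask = create_block_mask(causal, B=1, H=1, Q_LEN=Q_S, KV_LEN=KV_S, device="cuda")
print(block_mask.shape)  # (1, 1, 1, 4096): the exact lengths are recorded on the mask
```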
13 changes: 10 additions & 3 deletions torch/_higher_order_ops/flex_attention.py
@@ -614,7 +614,7 @@ def forward(
value,
out,
logsumexp,
*block_mask[:10],
*block_mask[:-1],
*score_mod_other_buffers,
*mask_mod_other_buffers,
),
@@ -630,6 +630,8 @@ def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Option
value,
out,
logsumexp,
query_lengths,
kv_lengths,
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
@@ -672,6 +674,8 @@ def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Option
fw_graph,
joint_graph,
(
query_lengths,
kv_lengths,
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
@@ -708,7 +712,8 @@ def flex_attention_autograd(

with TransformGetItemToIndex():
input_requires_grad = any(
t.requires_grad for t in (query, key, value, *score_mod_other_buffers)
isinstance(t, torch.Tensor) and t.requires_grad
for t in (query, key, value, *score_mod_other_buffers)
)
if torch.is_grad_enabled() and input_requires_grad:
example_vals = (
@@ -1130,7 +1135,9 @@ def flex_attention_backward_fake_tensor_mode(
grad_value = torch.empty_like(value)
grad_score_mod_captured = tuple(
[
torch.empty_like(buffer) if buffer.requires_grad else None
torch.empty_like(buffer)
if isinstance(buffer, torch.Tensor) and buffer.requires_grad
else None
for buffer in score_mod_other_buffers
]
)
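Why the `isinstance` guard added above matters: the captured score_mod buffers are not guaranteed to be tensors, so probing `.requires_grad` unconditionally can raise. A small illustration with made-up buffer values:

```python
import torch

score_mod_other_buffers = (
    torch.randn(8, requires_grad=True),  # differentiable tensor capture
    torch.arange(8),                     # integer tensor, never requires grad
    3.14,                                # plain Python scalar capture, no .requires_grad
)

input_requires_grad = any(
    isinstance(t, torch.Tensor) and t.requires_grad
    for t in score_mod_other_buffers
)
print(input_requires_grad)  # True; the non-tensor capture is skipped safely
```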
21 changes: 11 additions & 10 deletions torch/_inductor/kernel/flex_attention.py
@@ -810,6 +810,8 @@ def flex_attention(
mask_mod_other_buffers,
):
(
_, # q_length
_, # kv_length
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
@@ -968,12 +970,6 @@ def flex_attention(
# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."

# Note, we don't need to pass in the captured buffers explicitly
# because they're implicitly added by the score_mod function
@@ -1509,7 +1505,7 @@ def bwd_dq_block_mn(
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# apply mask for partial masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1541,7 +1537,7 @@ def bwd_dq_block_mn(
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
ds = tl.where(mask_mod_output, ds, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1691,7 +1687,7 @@ def bwd_dkdv_block_mn(
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for fully masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1749,7 +1745,7 @@ def bwd_dkdv_block_mn(
dsT = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
dsT = tl.where(mask_mod_output, dsT, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1860,6 +1856,8 @@ def flex_attention_backward(*args, **kwargs):
mask_mod_other_buffers,
) = args
(
_, # q_length
_, # kv_length
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
@@ -2036,6 +2034,9 @@
or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
):
continue
if num_warps == 8:
# Working around https://github.com/pytorch/pytorch/issues/141603
continue

# Performance tuning
cur_kernel_options = original_kernel_options.copy()
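The two new `_` placeholders above correspond to the sequence lengths that `BlockMask.as_tuple()` now emits ahead of the block tables. A sketch of the flattened layout (mask_mod and shapes are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

bm = create_block_mask(causal, B=1, H=1, Q_LEN=256, KV_LEN=256, device="cuda")
q_len, kv_len, kv_num_blocks, kv_indices, *rest = bm.as_tuple()
print(q_len, kv_len)     # 256 256: the lengths now travel with the mask
print(kv_indices.shape)  # (1, 1, 2, 2) with the default 128x128 sparse block size
```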
2 changes: 2 additions & 0 deletions torch/_inductor/kernel/flex_decoding.py
@@ -332,6 +332,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
mask_mod_other_buffers,
) = args
(
_, # q_length
_, # kv_length
kv_num_blocks,
kv_indices,
full_kv_num_blocks, # full_kv_num_blocks,
2 changes: 2 additions & 0 deletions torch/nn/attention/experimental/_paged_attention.py
@@ -264,13 +264,15 @@ def convert_logical_block_mask(

new_mask_mod = self.get_mask_mod(block_mask.mask_mod)

seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
return BlockMask.from_kv_blocks(
new_kv_num_blocks,
new_kv_indices,
new_full_kv_num_blocks,
new_full_kv_indices,
block_mask.BLOCK_SIZE,
new_mask_mod,
seq_lengths=seq_lengths,
)

def get_mask_mod(
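The length bookkeeping added above, spelled out with hypothetical pool sizes: the logical query length is preserved, while the KV length of the converted mask becomes the size of the physical page pool.

```python
n_pages, page_size = 64, 128        # hypothetical page pool configuration
logical_seq_lengths = (1024, 1024)  # (q_len, kv_len) of the logical BlockMask

# Mirrors the computation in convert_logical_block_mask above.
converted_seq_lengths = (logical_seq_lengths[0], n_pages * page_size)
print(converted_seq_lengths)        # (1024, 8192)
```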
88 changes: 54 additions & 34 deletions torch/nn/attention/flex_attention.py
@@ -262,6 +262,7 @@ class BlockMask:
the backwards pass. These are autogenerated from 2.
"""

seq_lengths: Tuple[int, int]
kv_num_blocks: Tensor
kv_indices: Tensor
full_kv_num_blocks: Optional[Tensor]
@@ -275,6 +276,7 @@

def __init__(
self,
seq_lengths: Tuple[int, int],
kv_num_blocks: Tensor,
kv_indices: Tensor,
full_kv_num_blocks: Optional[Tensor],
@@ -299,6 +301,7 @@ def __init__(
full_q_indices is None
), "full_q_num_blocks and full_q_indices must be both provided or omitted"

self.seq_lengths = seq_lengths
self.kv_num_blocks = kv_num_blocks
self.kv_indices = kv_indices
self.full_kv_num_blocks = full_kv_num_blocks
@@ -319,6 +322,7 @@ def from_kv_blocks(
full_kv_indices: Optional[Tensor] = None,
BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
mask_mod: Optional[_mask_mod_signature] = None,
seq_lengths: Optional[Tuple[int, int]] = None,
):
"""
Creates a BlockMask instance from key-value block information.
@@ -359,8 +363,13 @@
BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)

mask_mod = mask_mod if mask_mod is not None else noop_mask
if seq_lengths is None:
q_length = kv_indices.shape[-2] * BLOCK_SIZE[0]
kv_length = kv_indices.shape[-1] * BLOCK_SIZE[1]
seq_lengths = (q_length, kv_length)

return cls(
seq_lengths=seq_lengths,
kv_num_blocks=kv_num_blocks,
kv_indices=kv_indices,
full_kv_num_blocks=full_kv_num_blocks,
@@ -380,11 +389,15 @@ def as_tuple(self, flatten: bool = True):
Args:
flatten (bool): If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE)
"""
block_size = (
(self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) if flatten else (self.BLOCK_SIZE,)
)
if flatten:
block_size = (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) # type: ignore[assignment]
seq_lengths = (self.seq_lengths[0], self.seq_lengths[1]) # type: ignore[assignment]
else:
block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
seq_lengths = (self.seq_lengths,) # type: ignore[assignment]

return (
*seq_lengths,
self.kv_num_blocks,
self.kv_indices,
self.full_kv_num_blocks,
@@ -397,6 +410,11 @@ def as_tuple(self, flatten: bool = True):
self.mask_mod,
)

@property
def shape(self):
*batch_dims, _, _ = self.kv_indices.shape
return tuple(batch_dims) + self.seq_lengths

def __str__(self):
s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n"
mask_str = self.to_string().strip()
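
With `seq_lengths` stored on the mask, the new `shape` property reports the batch/head dims of `kv_indices` followed by the exact lengths the mask was created for. A quick sketch (shapes chosen arbitrarily):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

bm = create_block_mask(causal, B=2, H=4, Q_LEN=512, KV_LEN=768, device="cuda")
print(bm.shape)             # (2, 4, 512, 768): batch dims plus stored seq_lengths
print(bm.kv_indices.shape)  # (2, 4, 4, 6): block tables stay 128-block granular
```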
@@ -457,6 +475,7 @@ def causal_mask(b, h, q_idx, kv_idx):
new_full_kv_indices,
BLOCK_SIZE=self.BLOCK_SIZE,
mask_mod=None,
seq_lengths=self.seq_lengths,
)

def __repr__(self):
@@ -509,14 +528,6 @@ def _adjust(self, new_q_len: int, new_kv_len: int):
self.mask_mod,
)

@property
def shape(self):
"""Returns the shape of the mask."""
*batch_dims, q_length, _ = self.kv_indices.shape
q_length = self.kv_indices.shape[-2] * self.BLOCK_SIZE[0]
kv_length = self.kv_indices.shape[-1] * self.BLOCK_SIZE[1]
return tuple(batch_dims + [q_length, kv_length])

def numel(self):
"""Returns the number of elements (not accounting for sparsity) in the mask."""
shape = self.shape
@@ -739,6 +750,7 @@ def _convert_block_mask_to_mask(
def _create_sparse_block_from_block_mask(
block_mask: Tuple[Tensor, Optional[Tensor]],
mask_mod: Optional[Callable],
seq_lengths: Tuple[int, int],
Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
) -> BlockMask:
@@ -757,6 +769,7 @@ def _create_sparse_block_from_block_mask(
full_bm[1],
BLOCK_SIZE=(Q_BLOCK_SIZE, KV_BLOCK_SIZE),
mask_mod=mask_mod,
seq_lengths=seq_lengths,
)


@@ -878,7 +891,11 @@ def causal_mask(b, h, q_idx, kv_idx):
separate_full_blocks=True,
)
block_mask = _create_sparse_block_from_block_mask(
(partial_block_mask, full_block_mask), mask_mod, Q_BLOCK_SIZE, KV_BLOCK_SIZE
(partial_block_mask, full_block_mask),
mask_mod,
(Q_LEN, KV_LEN),
Q_BLOCK_SIZE,
KV_BLOCK_SIZE,
)
return block_mask

@@ -894,6 +911,7 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE,
seq_lengths=(1, 1),
)


@@ -1237,29 +1255,31 @@ def score_mod(

if block_mask is None:
block_mask = _create_empty_block_mask(query, key)
elif (
not query.is_nested
and (query.requires_grad or key.requires_grad or value.requires_grad)
and (
query.size(-2)
< block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
or key.size(-2) < block_mask.kv_indices.size(-1) * block_mask.BLOCK_SIZE[1]
)
):
new_q_len = _round_up_to_multiple(query.size(-2), block_mask.BLOCK_SIZE[0])
new_kv_len = _round_up_to_multiple(key.size(-2), block_mask.BLOCK_SIZE[1])
block_mask = block_mask._adjust(new_q_len, new_kv_len)
elif query.is_nested and (
block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
!= _round_up_to_multiple(
query._values.size(query._ragged_idx - 1), block_mask.BLOCK_SIZE[0] # type: ignore[attr-defined]
)

if (
block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE
and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE
):
# TODO: Maybe we want to auto-adjust for this case as well?
raise RuntimeError(
f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input "
f"with total sequence length of {query._values.size(query._ragged_idx - 1)}" # type: ignore[attr-defined]
)
# This corresponds to the case where we essentially have a "no-op" block mask.
pass
else:
block_mask_q_len = block_mask.shape[-2]
block_mask_kv_len = block_mask.shape[-1]
if query.size(-2) > block_mask_q_len or key.size(-2) > block_mask_kv_len:
raise ValueError(
f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
"As the block mask was created for a smaller length than you're using it for, you likely need to create a new block mask."
)
elif (
query.size(-2) < block_mask_q_len and key.size(-2) <= block_mask_kv_len
) or (query.size(-2) <= block_mask_q_len and key.size(-2) < block_mask_kv_len):
raise ValueError(
f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
"As the block mask was created for a larger length than you're using it for, you can either 1. create a new block mask with the correct length, or 2. 'adjust' the existing block mask to the correct length by calling block_mask._adjust(q_len, kv_len). This essentially 'crops' the block mask to the upper left corner, which does not work for all mask_mods!"
)
assert query.size(-2) == block_mask_q_len
assert key.size(-2) == block_mask_kv_len

if scale is None:
scale = 1.0 / math.sqrt(query.size(-1))

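The validation added at the end of this file replaces the old silent rounding with explicit errors. A sketch of both failure modes and the suggested remedies (mask_mod and shapes are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def make_qkv(seq_len):
    return torch.randn(1, 1, seq_len, 64, device="cuda", dtype=torch.float16)

q = k = v = make_qkv(512)

# Case 1: mask built for a shorter sequence than the inputs -> ValueError;
# the only fix is to create a new mask for the correct length.
short_mask = create_block_mask(causal, 1, 1, 256, 256, device="cuda")
try:
    flex_attention(q, k, v, block_mask=short_mask)
except ValueError as e:
    print(e)

# Case 2: mask built for a longer sequence -> ValueError as well; either rebuild
# it or crop it with _adjust, keeping in mind that _adjust keeps the upper-left
# corner of the mask and is therefore not valid for every mask_mod.
long_mask = create_block_mask(causal, 1, 1, 1024, 1024, device="cuda")
cropped = long_mask._adjust(512, 512)
out = flex_attention(q, k, v, block_mask=cropped)
```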
