Skip to content

Commit

Permalink
[Flex Attention] update __getitem__ without tree_map_only to support …
Browse files Browse the repository at this point in the history
…compile (pytorch#134627)

Adds a helper function for getting the block mask for a specific row index during decoding. We need this change to avoid the pytree + torch.compile issue pytorch#134731. Tested in gpt-fast [pr](pytorch-labs/gpt-fast#196).

Pull Request resolved: pytorch#134627
Approved by: https://github.com/Chillee
  • Loading branch information
BoyuanFeng authored and Chao1Han committed Sep 20, 2024
1 parent 10c1988 commit d751d41
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 5 deletions.
57 changes: 57 additions & 0 deletions test/inductor/test_flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,63 @@ def causal_mask(b, h, q, kv):
self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())

@supported_platform
def test_getitem(self):
offset = torch.zeros(8, device="cuda")

def causal_mask(b, h, q, kv):
return (q + (offset[b] * 128)) >= kv

block_mask = create_block_mask(causal_mask, 4, 2, 512, 512)
assert block_mask.kv_num_blocks.shape == (4, 2, 4)
assert block_mask.kv_indices.shape == (4, 2, 4, 4)

# Index on batch dimension
new_block_mask = block_mask[0]
assert new_block_mask.kv_num_blocks.shape == (2, 4)
assert new_block_mask.kv_indices.shape == (2, 4, 4)

# Index on batch and head dimension
new_block_mask = block_mask[0, 1]
assert new_block_mask.kv_num_blocks.shape == (4,)
assert new_block_mask.kv_indices.shape == (4, 4)

# slicing on batch and head dimension
new_block_mask = block_mask[0:2, 1:2]
assert new_block_mask.kv_num_blocks.shape == (2, 1, 4)
assert new_block_mask.kv_indices.shape == (2, 1, 4, 4)

# slicing on batch, head, and query dimension
new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)]
assert new_block_mask.kv_num_blocks.shape == (2, 1, 1)
assert new_block_mask.kv_indices.shape == (2, 1, 1, 4)

# slicing on batch, head, and query dimension
q_index = torch.tensor([0], dtype=torch.int32)
new_block_mask = block_mask[:, :, q_index]

self.assertEqual(new_block_mask.kv_num_blocks.ndim, 3)
self.assertEqual(new_block_mask.kv_indices.ndim, 4)
torch.testing.assert_close(
new_block_mask.kv_num_blocks,
block_mask.kv_num_blocks[:, :, q_index],
)
torch.testing.assert_close(
new_block_mask.kv_indices, block_mask.kv_indices[:, :, q_index, :]
)

if block_mask.full_kv_num_blocks is not None:
assert new_block_mask.full_kv_num_blocks is not None
assert new_block_mask.full_kv_indices is not None
torch.testing.assert_close(
new_block_mask.full_kv_num_blocks,
block_mask.full_kv_num_blocks[:, :, q_index],
)
torch.testing.assert_close(
new_block_mask.full_kv_indices,
block_mask.full_kv_indices[:, :, q_index, :],
)

@supported_platform
def test_block_mask_device_change(self):
offset = torch.zeros(8, device="cuda")
Expand Down
57 changes: 52 additions & 5 deletions torch/nn/attention/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,12 +391,59 @@ def __str__(self):
return s

def __getitem__(self, index) -> "BlockMask":
mapped_attributes = tree_map_only(
torch.Tensor,
lambda x: x[index],
self.as_tuple(flatten=False),
"""
Returns a new BlockMask instance by getting the mask for the given index position.
Args:
index: Index to apply to all attributes.
Example Usage:
.. code-block:: python
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device="cuda")
assert block_mask.kv_num_blocks.shape == (4,2,4)
assert block_mask.kv_indices.shape == (4,2,4,4)
# Index on batch dimension
new_block_mask = block_mask[0]
assert new_block_mask.kv_num_blocks.shape == (2,4)
assert new_block_mask.kv_indices.shape == (2,4,4)
# Index on batch and head dimension
new_block_mask = block_mask[0, 1]
assert new_block_mask.kv_num_blocks.shape == (4,)
assert new_block_mask.kv_indices.shape == (4,4)
# slicing on batch and head dimension
new_block_mask = block_mask[0:2, 1:2]
assert new_block_mask.kv_num_blocks.shape == (2,1,4)
assert new_block_mask.kv_indices.shape == (2,1,4,4)
# slicing on batch, head, and query dimension
new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)]
assert new_block_mask.kv_num_blocks.shape == (2,1,1)
assert new_block_mask.kv_indices.shape == (2,1,1,4)
"""
new_kv_num_blocks = self.kv_num_blocks[index]
new_kv_indices = self.kv_indices[index]
if self.full_kv_num_blocks is not None:
assert self.full_kv_indices is not None
new_full_kv_num_blocks = self.full_kv_num_blocks[index]
new_full_kv_indices = self.full_kv_indices[index]
else:
new_full_kv_num_blocks = None
new_full_kv_indices = None
return BlockMask.from_kv_blocks(
new_kv_num_blocks,
new_kv_indices,
new_full_kv_num_blocks,
new_full_kv_indices,
BLOCK_SIZE=self.BLOCK_SIZE,
mask_mod=None,
)
return BlockMask(*mapped_attributes)

def __repr__(self):
def shape_or_none(x: Optional[torch.Tensor]):
Expand Down

0 comments on commit d751d41

Please sign in to comment.