Block Size when Q_LEN and KV_LEN are different #71

Open
johng149 opened this issue Nov 4, 2024 · 0 comments
johng149 commented Nov 4, 2024

I was playing around with Flex Attention, specifically document masking / jagged sequences, and I noticed that when doing cross attention with Q_LEN different from KV_LEN, running create_block_mask with a BLOCK_SIZE that is not a common divisor of Q_LEN and KV_LEN fails. Is this expected behavior?

device = "cpu" # happens on "cuda" as well

seq_len1a = 150
seq_len2a = 150
q_len = seq_len1a + seq_len2a
doc_idsa = torch.tensor([0]*seq_len1a + [1]*seq_len2a)

seq_len1b = 70
seq_len2b = 110
kv_len = seq_len1b + seq_len2b
doc_idsb = torch.tensor([0]*seq_len1b + [1]*seq_len2b)

doc_idsa = doc_idsa.to(device)
doc_idsb = doc_idsb.to(device)

def x_masking(b, h, q_idx, kv_idx):
    return doc_idsa[q_idx] == doc_idsb[kv_idx]
    
q = torch.rand(1, num_heads, q_len, embed_dim).to(device)
kv = torch.rand(1, num_heads, kv_len, embed_dim).to(device)

# Greatest common divisor of q_len = 300 and kv_len = 180 is 60;
# any common divisor of the two will work
x_mask = create_block_mask(x_masking, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, BLOCK_SIZE=60, device=device) # works

# using a block size that is not a common divisor of q_len and kv_len fails
x_mask = create_block_mask(x_masking, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, BLOCK_SIZE=8, device=device) # fails

The error when using a block size that is not a common divisor is:

IndexError                                Traceback (most recent call last)
Cell In[118], line 1
----> 1 x_mask = create_block_mask(x_masking, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, BLOCK_SIZE=8, device=device)

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/nn/attention/flex_attention.py:850, in create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, BLOCK_SIZE, _compile)
    848     inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
    849 with TransformGetItemToIndex():
--> 850     partial_block_mask, full_block_mask = inner_func(
    851         mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE
    852     )
    853     block_mask = _create_sparse_block_from_block_mask(
    854         (partial_block_mask, full_block_mask), mask_mod
    855     )
    856 return block_mask

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/nn/attention/flex_attention.py:775, in _create_block_mask_inner(mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE)
    761 def _create_block_mask_inner(
    762     mask_mod: Callable,
    763     B: int,
   (...)
    769     Q_BLOCK_SIZE: int,
    770 ):
    771     r"""Work around for being unable to instantiate __torch_function__ mode under compile.
    772     `create_block_mask` will compile this inner function and wrap the call to this
    773     with the __torch_function__ mode.
    774     """
--> 775     mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
    776     partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
    777         mask_tensor,
    778         KV_BLOCK_SIZE=KV_BLOCK_SIZE,
    779         Q_BLOCK_SIZE=Q_BLOCK_SIZE,
    780         separate_full_blocks=True,
    781     )
    782     return partial_block_mask, full_block_mask

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/nn/attention/flex_attention.py:755, in create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device, _compile)
    753     mask_mod = mod_fn
    754     mask_mod = _vmap_for_bhqkv(mask_mod, prefix=())
--> 755     mask = mask_mod(b, h, m, n)
    756     return mask
    757 else:

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
    332     func,
    333     batch_size,
    334     flat_in_dims,
    335     flat_args,
    336     args_spec,
    337     out_dims,
    338     randomness,
    339     **kwargs,
    340 )

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
    476     batched_inputs = _create_batched_inputs(
    477         flat_in_dims, flat_args, vmap_level, args_spec
    478     )
--> 479     batched_outputs = func(*batched_inputs, **kwargs)
    480     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
    332     func,
    333     batch_size,
    334     flat_in_dims,
    335     flat_args,
    336     args_spec,
    337     out_dims,
    338     randomness,
    339     **kwargs,
    340 )

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
    476     batched_inputs = _create_batched_inputs(
    477         flat_in_dims, flat_args, vmap_level, args_spec
    478     )
--> 479     batched_outputs = func(*batched_inputs, **kwargs)
    480     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)

    [... skipping similar frames: _flat_vmap at line 479 (1 times), vmap_impl at line 331 (1 times), vmap.<locals>.wrapped at line 203 (1 times)]

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
    332     func,
    333     batch_size,
    334     flat_in_dims,
    335     flat_args,
    336     args_spec,
    337     out_dims,
    338     randomness,
    339     **kwargs,
    340 )

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
    476     batched_inputs = _create_batched_inputs(
    477         flat_in_dims, flat_args, vmap_level, args_spec
    478     )
--> 479     batched_outputs = func(*batched_inputs, **kwargs)
    480     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)

Cell In[111], line 2
      1 def x_masking(b, h, q_idx, kv_idx):
----> 2     return doc_idsa[q_idx] == doc_idsb[kv_idx]

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_higher_order_ops/flex_attention.py:84, in TransformGetItemToIndex.__torch_function__(self, func, types, args, kwargs)
     82     index_args = pytree.tree_leaves(args[1])
     83     if all(isinstance(x, torch.Tensor) for x in index_args):
---> 84         return torch.ops.aten.index(args[0], index_args)
     85 return func(*args, **(kwargs or {}))

File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_ops.py:1116, in OpOverloadPacket.__call__(self, *args, **kwargs)
   1114 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
   1115     return _call_overload_packet_from_python(self, args, kwargs)
-> 1116 return self._op(*args, **(kwargs or {}))

IndexError: index 300 is out of bounds for dimension 0 with size 300
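
For what it's worth, the out-of-bounds index (300 for a tensor of size 300) suggests that the mask is evaluated on index grids rounded up to the next multiple of BLOCK_SIZE, so the mask_mod ends up indexing past the end of doc_idsa. As a stopgap I can set BLOCK_SIZE = math.gcd(q_len, kv_len). Padding the document-id tensors up to the next block boundary with distinct sentinel ids should also avoid the error, if my reading of the traceback is right. A sketch of that workaround (pad_doc_ids is a hypothetical helper I wrote, not part of the API, and I haven't verified this is the intended fix):

import math
import torch

def pad_doc_ids(ids, block_size, fill):
    # Hypothetical helper: pad a document-id tensor up to the next multiple of
    # block_size with a sentinel id so padded positions never match a real document.
    padded_len = math.ceil(ids.shape[0] / block_size) * block_size
    out = ids.new_full((padded_len,), fill)
    out[: ids.shape[0]] = ids
    return out

BLOCK_SIZE = 8
doc_idsa_pad = pad_doc_ids(doc_idsa, BLOCK_SIZE, fill=-1)
doc_idsb_pad = pad_doc_ids(doc_idsb, BLOCK_SIZE, fill=-2)  # different sentinels so pads never match

def x_masking_padded(b, h, q_idx, kv_idx):
    return doc_idsa_pad[q_idx] == doc_idsb_pad[kv_idx]

x_mask = create_block_mask(x_masking_padded, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, BLOCK_SIZE=BLOCK_SIZE, device=device)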

Thank you
