I was playing around with Flex Attention, specifically document masking / jagged sequences, and I noticed that when doing cross attention with Q_LEN different from KV_LEN, running `create_block_mask` with a block size that is not a common divisor of Q_LEN and KV_LEN fails. Is this expected behavior?
device="cpu"# happens on "cuda" as wellseq_len1a=150seq_len2a=150q_len=seq_len1a+seq_len2adoc_idsa=torch.tensor([0]*seq_len1a+ [1]*seq_len2a)
seq_len1b=70seq_len2b=110kv_len=seq_len1b+seq_len2bdoc_idsb=torch.tensor([0]*seq_len1b+ [1]*seq_len2b)
doc_idsa=doc_idsa.to(device)
doc_idsb=doc_idsb.to(device)
defx_masking(b, h, q_idx, kv_idx):
returndoc_idsa[q_idx] ==doc_idsb[kv_idx]
q=torch.rand(1, num_heads, q_len, embed_dim).to(device)
kv=torch.rand(1, num_heads, kv_len, embed_dim).to(device)
# Greatest common divisor with q_len = 300 and kv_len = 180 is 60# any divisor of the two will workx_mask=create_block_mask(x_masking, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, BLOCK_SIZE=60, device=device) # works# using a block size that is not divisor of q_len and kv_len failsx_mask=create_block_mask(x_masking, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, BLOCK_SIZE=8, device=device) # fails
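For context, a quick sanity check of the divisor arithmetic from the comment above (plain Python, nothing flex-attention-specific):

```python
import math

q_len, kv_len = 300, 180
print(math.gcd(q_len, kv_len))  # 60 -> 60 (or any divisor of 60) tiles both lengths evenly
print(q_len % 60, kv_len % 60)  # 0 0
print(q_len % 8, kv_len % 8)    # 4 4 -> BLOCK_SIZE=8 leaves a partial block at the end of each
```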
The error when using a block size that is not a common divisor is:
```
IndexError                                Traceback (most recent call last)
Cell In[118], line 1
----> 1 x_mask = create_block_mask(x_masking, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, BLOCK_SIZE=8, device=device)
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/nn/attention/flex_attention.py:850, in create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, BLOCK_SIZE, _compile)
848 inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
849 with TransformGetItemToIndex():
--> 850 partial_block_mask, full_block_mask = inner_func(
851 mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE
852 )
853 block_mask = _create_sparse_block_from_block_mask(
854 (partial_block_mask, full_block_mask), mask_mod
855 )
856 return block_mask
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/nn/attention/flex_attention.py:775, in _create_block_mask_inner(mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE)
761 def _create_block_mask_inner(
762 mask_mod: Callable,
763 B: int,
(...)
769 Q_BLOCK_SIZE: int,
770 ):
771 r"""Work around for being unable to instantiate __torch_function__ mode under compile.
772 `create_block_mask` will compile this inner function and wrap the call to this
773 with the __torch_function__ mode.
774 """
--> 775 mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
776 partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
777 mask_tensor,
778 KV_BLOCK_SIZE=KV_BLOCK_SIZE,
779 Q_BLOCK_SIZE=Q_BLOCK_SIZE,
780 separate_full_blocks=True,
781 )
782 return partial_block_mask, full_block_mask
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/nn/attention/flex_attention.py:755, in create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device, _compile)
753 mask_mod = mod_fn
754 mask_mod = _vmap_for_bhqkv(mask_mod, prefix=())
--> 755 mask = mask_mod(b, h, m, n)
756 return mask
757 else:
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
202 def wrapped(*args, **kwargs):
--> 203 return vmap_impl(
204 func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
205 )
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
320 return _chunked_vmap(
321 func,
322 flat_in_dims,
(...)
327 **kwargs,
328 )
330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
332 func,
333 batch_size,
334 flat_in_dims,
335 flat_args,
336 args_spec,
337 out_dims,
338 randomness,
339 **kwargs,
340 )
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
476 batched_inputs = _create_batched_inputs(
477 flat_in_dims, flat_args, vmap_level, args_spec
478 )
--> 479 batched_outputs = func(*batched_inputs, **kwargs)
480 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
202 def wrapped(*args, **kwargs):
--> 203 return vmap_impl(
204 func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
205 )
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
320 return _chunked_vmap(
321 func,
322 flat_in_dims,
(...)
327 **kwargs,
328 )
330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
332 func,
333 batch_size,
334 flat_in_dims,
335 flat_args,
336 args_spec,
337 out_dims,
338 randomness,
339 **kwargs,
340 )
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
476 batched_inputs = _create_batched_inputs(
477 flat_in_dims, flat_args, vmap_level, args_spec
478 )
--> 479 batched_outputs = func(*batched_inputs, **kwargs)
480 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
[... skipping similar frames: _flat_vmap at line 479 (1 times), vmap_impl at line 331 (1 times), vmap.<locals>.wrapped at line 203 (1 times)]
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
202 def wrapped(*args, **kwargs):
--> 203 return vmap_impl(
204 func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
205 )
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:331, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
320 return _chunked_vmap(
321 func,
322 flat_in_dims,
(...)
327 **kwargs,
328 )
330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
332 func,
333 batch_size,
334 flat_in_dims,
335 flat_args,
336 args_spec,
337 out_dims,
338 randomness,
339 **kwargs,
340 )
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_functorch/vmap.py:479, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
475 with vmap_increment_nesting(batch_size, randomness) as vmap_level:
476 batched_inputs = _create_batched_inputs(
477 flat_in_dims, flat_args, vmap_level, args_spec
478 )
--> 479 batched_outputs = func(*batched_inputs, **kwargs)
480 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
Cell In[111], line 2
1 def x_masking(b, h, q_idx, kv_idx):
----> 2 return doc_idsa[q_idx] == doc_idsb[kv_idx]
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_higher_order_ops/flex_attention.py:84, in TransformGetItemToIndex.__torch_function__(self, func, types, args, kwargs)
82 index_args = pytree.tree_leaves(args[1])
83 if all(isinstance(x, torch.Tensor) for x in index_args):
---> 84 return torch.ops.aten.index(args[0], index_args)
85 return func(*args, **(kwargs or {}))
File /run/media/Secondary/Projects/ML/PreDATFlex/.venv/lib64/python3.12/site-packages/torch/_ops.py:1116, in OpOverloadPacket.__call__(self, *args, **kwargs)
1114 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
1115 return _call_overload_packet_from_python(self, args, kwargs)
-> 1116 return self._op(*args, **(kwargs or {}))
IndexError: index 300 is out of bounds for dimension 0 with size 300
```
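In case it helps to reproduce a non-failing variant: the `IndexError` (index 300 for a tensor of size 300) suggests the mask_mod is being evaluated at query/key positions rounded up to the next multiple of the block size. Below is a minimal workaround sketch under that assumption, reusing the variables from the repro above. `pad_to_multiple` and the sentinel values are my own illustration, not part of the flex_attention API, and I'm not sure this is the intended fix rather than just a way to keep the lookups in bounds.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

BLOCK_SIZE = 8

def pad_to_multiple(t, multiple, fill):
    # pad a 1-D doc-id tensor up to the next multiple of `multiple` with a sentinel value
    pad = (-t.numel()) % multiple
    return torch.cat([t, torch.full((pad,), fill, dtype=t.dtype, device=t.device)])

# sentinel -1 never equals sentinel -2, so padded query positions match no key positions
doc_idsa_padded = pad_to_multiple(doc_idsa, BLOCK_SIZE, fill=-1)
doc_idsb_padded = pad_to_multiple(doc_idsb, BLOCK_SIZE, fill=-2)

def x_masking_padded(b, h, q_idx, kv_idx):
    return doc_idsa_padded[q_idx] == doc_idsb_padded[kv_idx]

# same call as before, but the doc-id lookups stay in bounds even when
# BLOCK_SIZE does not divide q_len or kv_len
x_mask = create_block_mask(x_masking_padded, B=None, H=None, Q_LEN=q_len,
                           KV_LEN=kv_len, BLOCK_SIZE=BLOCK_SIZE, device=device)
```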
Thank you