The fact that it's possible to create arbitrary score mod / mask mod patterns is really powerful!
I'm wondering if there is any way to reason about the efficiency of different masking patterns (if this is a relevant consideration)?
For example, is a 'full' score_mod (e.g. one returning bias[b, h, i, j], where bias is an explicitly materialised attention bias tensor) going to yield any efficiency gains over manually adding the bias to the attention logits? And what are the relative efficiencies of, say, structured vs. random sparsity patterns in mask_mod?
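To make the comparison concrete, here is a minimal sketch of the two variants I have in mind (shapes, dtypes and tensor names are just for illustration):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# Fully materialised [B, H, S, S] attention bias.
bias = torch.randn(B, H, S, S, device="cuda", dtype=torch.float16)

# Variant 1: a 'full' score_mod that looks the bias up per (b, h, q_idx, kv_idx).
def bias_score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[b, h, q_idx, kv_idx]

out_flex = flex_attention(q, k, v, score_mod=bias_score_mod)

# Variant 2: the baseline of adding the bias to the logits by hand,
# here via scaled_dot_product_attention's additive attn_mask.
out_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```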
Thanks
@alex-hh Generally speaking, the less memory you have to access from outside the kernel, the better. So loading from a full bias (i.e. of size S^2) is going to be slower than loading from a 1D bias (i.e. of size S), which in turn is going to be slower than not loading anything at all and computing the modification directly from the indices.
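For concreteness, a rough sketch of the three regimes (tensor names and shapes here are illustrative assumptions, not part of the library API):

```python
import torch

S = 1024

# (a) Full 2D bias: O(S^2) values streamed in from global memory.
full_bias = torch.randn(S, S, device="cuda")
def full_bias_mod(score, b, h, q_idx, kv_idx):
    return score + full_bias[q_idx, kv_idx]

# (b) 1D bias: O(S) values, e.g. a learned bias indexed by relative distance.
rel_bias = torch.randn(2 * S - 1, device="cuda")
def rel_bias_mod(score, b, h, q_idx, kv_idx):
    return score + rel_bias[q_idx - kv_idx + S - 1]

# (c) No extra loads at all: the modification is computed from values already
#     available in the kernel (here, tanh soft-capping of the logits).
def softcap_mod(score, b, h, q_idx, kv_idx):
    return 30.0 * torch.tanh(score / 30.0)
```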
For sparsity, FlexAttention is fundamentally block-sparse: it only skips work when an entire block of the attention mask is masked out. So pure random sparsity, which rarely leaves whole blocks empty, is unlikely to help much.
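As a quick illustration (a sketch with made-up sizes), create_block_mask reports how many blocks can actually be skipped; a structured mask skips far more of them than a random mask with a similar fraction of masked entries:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

S, WINDOW = 4096, 256

# Structured sparsity: a causal sliding window empties whole tiles far from
# the diagonal, so those blocks are skipped entirely.
def sliding_window(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx < WINDOW)

# Unstructured sparsity: entries are masked independently, so almost every
# tile still contains at least one live entry and must be computed.
keep = torch.rand(S, S, device="cuda") > 0.9  # ~90% of entries masked
def random_mask(b, h, q_idx, kv_idx):
    return keep[q_idx, kv_idx]

structured = create_block_mask(sliding_window, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
unstructured = create_block_mask(random_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

# Roughly the same fraction of entries is masked in both, but only the
# structured mask translates into a high fraction of skippable blocks.
print(structured.sparsity(), unstructured.sparsity())
```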
Regarding block sparsity - does this mean that given a particular mask_mod pattern, there is potentially an optimal way of permuting the inputs before applying flex attention?
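For illustration (a toy sketch, with document ids and sizes made up by me): a document-packing mask over interleaved tokens has poor block sparsity, but sorting the tokens by document id (and permuting q/k/v the same way) turns the same logical mask into dense diagonal blocks:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

S = 4096
# Tokens from 8 hypothetical "documents", interleaved; attention is only
# allowed within a document.
doc_id = torch.randint(0, 8, (S,), device="cuda")

def doc_mask(ids):
    def mask_mod(b, h, q_idx, kv_idx):
        return ids[q_idx] == ids[kv_idx]
    return mask_mod

# Same logical mask, but with tokens permuted so each document is contiguous.
perm = torch.argsort(doc_id)
scattered = create_block_mask(doc_mask(doc_id), None, None, S, S, device="cuda")
grouped = create_block_mask(doc_mask(doc_id[perm]), None, None, S, S, device="cuda")

# The grouped layout lets far more blocks be skipped; q, k and v would need
# the same permutation, and the output would be un-permuted afterwards.
print(scattered.sparsity(), grouped.sparsity())
```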