
Optimal ordering with block mask #56

francois-rozet opened this issue Oct 19, 2024 · 9 comments

@francois-rozet

francois-rozet commented Oct 19, 2024

From my understanding, flex attention (using block_mask) gets faster as the number of empty blocks grows. If the inputs (Q, K, V) do not represent sequences but graphs with local connectivity (e.g. pixels in an image), the ordering of the elements has a huge impact on the number of empty blocks.

It would be very useful to add helpers to find optimal, or simply better, orderings given a mask. For example, for images, it is likely better to order the pixels by small patch (close to the attention window size), rather than the standard row-by-row order.

Note that this is related to the minimum degree algorithm.
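
To make this concrete, here is a rough sketch (not a proposed helper) that counts non-empty blocks for a 2D local-window mask under the default row-major ordering and under a simple patch-major reordering. The image size, block size, patch size and window radius are arbitrary choices for illustration.

import torch

H = W = 256        # image size; rows much wider than a 128-token block
BLOCK = 128        # block size used by flex attention's BlockMask by default
RADIUS = 4         # local attention window radius, in pixels (arbitrary)
P = 16             # patch size for the reordering (arbitrary)

ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
ys, xs = ys.flatten(), xs.flatten()                      # row-major ordering

# Patch-major permutation: group pixels into P x P patches, row-major inside each.
patch_id = (ys // P) * (W // P) + (xs // P)
perm = torch.argsort(patch_id * (H * W) + ys * W + xs)

def nonempty_blocks(y, x):
    # Count (query block, kv block) pairs with at least one unmasked entry.
    n, total = y.numel(), 0
    for start in range(0, n, BLOCK):
        qy, qx = y[start:start + BLOCK, None], x[start:start + BLOCK, None]
        hit = ((qy - y).abs() <= RADIUS) & ((qx - x).abs() <= RADIUS)  # [BLOCK, n]
        total += hit.view(BLOCK, n // BLOCK, BLOCK).any(2).any(0).sum().item()
    return total

print("row-major:  ", nonempty_blocks(ys, xs))
print("patch-major:", nonempty_blocks(ys[perm], xs[perm]))

With the row-major ordering, each 128-token block covers only half of a 256-pixel row, so its ±4-pixel window touches blocks from roughly 9 different rows; the patch-major ordering keeps each block's neighbourhood compact, which is the effect a helper could try to optimize for.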

@Chillee
Contributor

Chillee commented Oct 21, 2024

Yeah, this is a pretty fun idea :) I had previously played around with an idea like this using a permutation transform like so:

def permute_mod(mod, permutation):
    # Wrap an existing mask_mod so that it sees permuted indices:
    # permutation[i] is the original index of the token now stored at position i.
    def new_mod(b, h, q, kv):
        q_idx = permutation[q]
        kv_idx = permutation[kv]
        return mod(b, h, q_idx, kv_idx)
    return new_mod

And so this allows you to transform any existing mask_mod into one that operates on a permuted input. Unfortunately, this does require a bunch of additional memory accesses, so it might not be worth it unless you get way more sparsity. But I had some good successes in certain cases with a Hilbert curve.
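
For concreteness, a hypothetical usage of this wrapper together with create_block_mask (the sequence length, window size and random permutation below are placeholders; in practice the permutation would come from a Hilbert/Morton ordering):

import torch
from torch.nn.attention.flex_attention import create_block_mask

S = 4096
device = "cuda" if torch.cuda.is_available() else "cpu"
perm = torch.randperm(S, device=device)  # stand-in for a Hilbert/Morton ordering

def window_1d(b, h, q_idx, kv_idx):
    # some pre-existing mask defined over the original (unpermuted) indices
    return (q_idx - kv_idx).abs() <= 64

# BlockMask for the permuted sequence; tokens must be stored in the same order.
block_mask = create_block_mask(permute_mod(window_1d, perm), None, None, S, S, device=device)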

For 2d images, uwu (on Discord) suggested trying a Morton curve, which could be a good alternative, since it's cheap to "compute" :)

@francois-rozet
Author

I think it is worth it if you can do the permutation once before a series of attention operations. That is pretty much the case in vision transformers with local windows.

I also tried the Hilbert and Moore curves, but I haven't conducted a proper benchmark.

@Chillee
Contributor

Chillee commented Oct 21, 2024

The issue isn't necessarily that permuting the tokens itself is expensive, but rather that after the permutation you need to load the permutation index into the "inner loop" of the attention, which does offset some of the sparsity gains you can get.

Which is why Morton curves were an interesting suggestion to me, since I think they're fairly cheap to compute "within" the kernel itself.
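
As a sketch of what "computable within the kernel" could look like (purely illustrative, not from any benchmark in this thread): if the tokens of a 64 x 64 image are stored in Morton (Z-order) order, the mask_mod can recover each index's (x, y) position with a few bit operations instead of loading a permutation tensor.

import torch
from torch.nn.attention.flex_attention import create_block_mask

RADIUS = 4  # local window radius, in pixels (arbitrary)

def _compact_bits_2d(v):
    # Inverse of 2D bit interleaving: keep every other bit and pack them together.
    v = v & 0x55555555
    v = (v | (v >> 1)) & 0x33333333
    v = (v | (v >> 2)) & 0x0F0F0F0F
    v = (v | (v >> 4)) & 0x00FF00FF
    v = (v | (v >> 8)) & 0x0000FFFF
    return v

def morton_window_mask(b, h, q_idx, kv_idx):
    # Decode (x, y) from Morton-ordered indices, then apply a 2D local window.
    qx, qy = _compact_bits_2d(q_idx), _compact_bits_2d(q_idx >> 1)
    kx, ky = _compact_bits_2d(kv_idx), _compact_bits_2d(kv_idx >> 1)
    return ((qx - kx).abs() <= RADIUS) & ((qy - ky).abs() <= RADIUS)

device = "cuda" if torch.cuda.is_available() else "cpu"
block_mask = create_block_mask(morton_window_mask, None, None, 64 * 64, 64 * 64, device=device)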

@francois-rozet
Author

francois-rozet commented Oct 21, 2024

I don't think you need to load the permutation index if you compute the BlockMask once.

Basically my idea was to find a permutation to minimize the number of (non-empty) blocks in a BlockMask. Then you can reuse the same block mask again and again.
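
A small sketch of that reuse pattern (the block-diagonal mask_mod below is just a stand-in for the permuted mask, and the shapes are arbitrary): the BlockMask is built once and then passed to every flex_attention call that shares the same ordering.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, NH, S, D = 1, 4, 2048, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

def block_diagonal(b, h, q_idx, kv_idx):
    # stand-in for the permuted local-window mask discussed above
    return (q_idx // 256) == (kv_idx // 256)

block_mask = create_block_mask(block_diagonal, None, None, S, S, device=device)  # built once

q = torch.randn(B, NH, S, D, device=device)
k, v = torch.randn_like(q), torch.randn_like(q)

# Reused across calls (e.g. across all layers of a ViT with the same window).
out1 = flex_attention(q, k, v, block_mask=block_mask)
out2 = flex_attention(q, k, v, block_mask=block_mask)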

@Chillee
Contributor

Chillee commented Oct 21, 2024

If you can guarantee that all of your non-empty blocks are "full" (i.e. non-masked at all), then you don't need to load the permutation index for those blocks.

However, for the partially-masked blocks, you still need to load the permutation index to compute the mask for those blocks. For example, this is NATTEN with a Hilbert curve:

[image: block mask sparsity for NATTEN with a Hilbert curve ordering]
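
To make the full vs. partial distinction concrete, here is a small check of how many blocks of a BlockMask still require evaluating the mask_mod (and hence the permutation lookup) versus how many are fully unmasked and skip it. The attribute names follow the current BlockMask layout and the banded mask is arbitrary; treat both as assumptions.

import torch
from torch.nn.attention.flex_attention import create_block_mask

S = 2048
device = "cuda" if torch.cuda.is_available() else "cpu"

def banded(b, h, q_idx, kv_idx):
    # simple 1D local window, just to produce a mix of full and partial blocks
    return (q_idx - kv_idx).abs() <= 256

bm = create_block_mask(banded, None, None, S, S, device=device)

partial = bm.kv_num_blocks.sum().item()  # blocks where mask_mod is still evaluated
full = bm.full_kv_num_blocks.sum().item() if bm.full_kv_num_blocks is not None else 0
print(f"partial blocks: {partial}, full blocks: {full}")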

@francois-rozet
Author

francois-rozet commented Oct 22, 2024

I don't see where the permutation indices appear anymore after the BlockMask has been created with the permuted mask_mod. If both the sequence and the block mask are permuted, there is no permutation happening anymore. There is indexing with respect to column indices in the block mask, but no "permutation".

@Chillee
Contributor

Chillee commented Oct 22, 2024

Yes, that's what I mean. You must load from your column indices (which represent a permutation) in your inner loop.

@francois-rozet
Author

I think we are speaking of the same thing in different terms, but I don't see how the column indices represent permutations. They are ordered (which allows faster access than random indexing) and target a subset of the full block. I agree that the subset is determined by the original permutation, but the indexing operation does not involve a permutation.

Anyway, your NATTEN + Hilbert curve seems much more efficient than NATTEN alone! Do you still have the code to generate the permutation? I used a random Python library previously.
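
For reference, one way to generate such a permutation is the third-party hilbertcurve package on PyPI (a sketch, not necessarily the library that was used here):

import torch
from hilbertcurve.hilbertcurve import HilbertCurve  # pip install hilbertcurve

H = W = 64                    # image side; a power of two keeps the curve exact
hc = HilbertCurve(p=6, n=2)   # 2**6 = 64 cells per dimension, 2 dimensions

ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
points = torch.stack([ys.flatten(), xs.flatten()], dim=-1)  # row-major pixel coords

# Hilbert distance of every pixel; argsort gives the permutation
# (perm[i] = row-major index of the pixel placed at position i on the curve).
dists = torch.tensor(hc.distances_from_points(points.tolist()))
perm = torch.argsort(dists)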

@cat-state

cat-state commented Nov 12, 2024

As long as the positions of your latents aren't changing, there shouldn't be any case where you need to apply a permutation at score/mask mod time, unless you are relying on some non-permutation-equivariant function of q_idx and kv_idx.
If your latents don't move, then all you should need to do is reorder them (and any position-associated data, e.g. RoPE coefficients) once.

Then:

  1. Your score_mod is nothing or doesn't depend on any input data - then the ordering doesn't matter and the block_mask encodes the sparsity pattern.
  2. Your score_mod depends on q_idx/kv_idx to access position-associated data - then the upfront reordering of this data should cancel out.

There is one case where you would need to permute/unpermute:

  3. Your score_mod depends on q_idx/kv_idx in a way that isn't permutation equivariant, e.g. if you had score * torch.exp(q_idx - kv_idx). Then the reordering changes the result, and you would need to unpermute.

Here is the Morton code in PyTorch (for 3D, but you can also use it for 2D by passing in 3D coords with the last dim zeroed) if anyone finds this helpful. Sorting by this is decent, but probably won't be as good an ordering as a Hilbert curve:

import torch

def quantize_coords(coords: torch.Tensor, bits: int = 21):
    # Map coordinates in [0, 1] to integers in [0, 2**bits - 1].
    max_int = (1 << bits) - 1
    coords = coords.clamp(0, 1)
    return (coords * max_int).long()

def split_by_3_bits_21(x: torch.Tensor):
    # Spread the lowest 21 bits of x so consecutive bits end up 3 positions apart.
    x = (x | (x << 32)) & 0x1f00000000ffff
    x = (x | (x << 16)) & 0x1f0000ff0000ff
    x = (x | (x << 8)) & 0x100f00f00f00f00f
    x = (x | (x << 4)) & 0x10c30c30c30c30c3
    x = (x | (x << 2)) & 0x1249249249249249
    return x

@torch.compile
def morton_encode(coords: torch.Tensor, bits: int = 21):
    # Interleave the bits of (x, y, z) into a single 63-bit Morton code.
    coords = quantize_coords(coords, bits)
    x = split_by_3_bits_21(coords[..., 0])
    y = split_by_3_bits_21(coords[..., 1]) << 1
    z = split_by_3_bits_21(coords[..., 2]) << 2
    morton_code = x | y | z

    return morton_code
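
And one possible way to turn that into an ordering for an H x W image, with the z coordinate zeroed as suggested and the coordinates normalized to [0, 1] to match quantize_coords (sizes are arbitrary):

# assumes the morton_encode above
H = W = 64
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
coords = torch.stack([
    xs.flatten().float() / (W - 1),
    ys.flatten().float() / (H - 1),
    torch.zeros(H * W),        # unused third dimension
], dim=-1)

codes = morton_encode(coords)
perm = torch.argsort(codes)    # perm[i] = original (row-major) index of the i-th token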
