[QST] masking steps in flash decoding #1449

Open
aws-jiadingg opened this issue Jan 17, 2025 · 1 comment

Comments

aws-jiadingg commented Jan 17, 2025

Flash decoding divides the sequence blocks into a series of splits, with each split assigned to a thread block. However, in the masking-steps loop (code), every split runs the same masking iterations, even though only the final split might actually need them. Is this the intended behavior? Should there be control logic so that only the final split goes through the "masking steps" loop? Thanks!

tridao (Member) commented Jan 18, 2025

The outputs are still correct with the extra masking iterations: the mask takes m_block and n_block as inputs, so if a block doesn't go out of bounds, the masking code leaves its elements unchanged.
Having separate masking iterations is just a speed optimization. You could add a check so that only the final split masks, but that seems more complicated, and it likely isn't faster during decoding, where the bottleneck is loading KV, not computation.
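
To illustrate the point about in-bounds blocks, here is a minimal Python sketch, not the actual kernel code. The helpers `split_block_range` and `apply_mask`, the even partition of key blocks across splits, and the toy sizes are all assumptions made for illustration. It only models sequence-length masking (no causal or local masking) and shows that a bounds-based mask leaves in-bounds blocks untouched, so only the final split's last block changes.

```python
import math


def ceil_div(a, b):
    return (a + b - 1) // b


def split_block_range(split_idx, num_splits, seqlen_k, block_n):
    """Half-open range of key blocks owned by one split.

    Assumption: blocks are partitioned evenly across splits.
    """
    n_blocks = ceil_div(seqlen_k, block_n)
    per_split = ceil_div(n_blocks, num_splits)
    return split_idx * per_split, min(n_blocks, (split_idx + 1) * per_split)


def apply_mask(scores, n_block, block_n, seqlen_k):
    """Return a copy of `scores` with out-of-bounds key columns set to -inf.

    The mask depends only on the global column index derived from n_block,
    so a block that lies fully inside seqlen_k comes back unchanged.
    """
    col_start = n_block * block_n
    return [
        [s if col_start + j < seqlen_k else -math.inf for j, s in enumerate(row)]
        for row in scores
    ]


# Hypothetical toy sizes: 22 keys, blocks of 4 keys, 3 splits.
block_n, seqlen_k, num_splits = 4, 22, 3
scores = [[0.5, 1.5, 2.5, 3.5]]  # one query row, one key block of scores

changed = []
for split_idx in range(num_splits):
    n_block_min, n_block_max = split_block_range(
        split_idx, num_splits, seqlen_k, block_n
    )
    last_block = n_block_max - 1  # the block each split would run its masking step on
    masked = apply_mask(scores, last_block, block_n, seqlen_k)
    changed.append(masked != scores)

print(changed)  # [False, False, True]
```

In this sketch every split runs the masking step, but the scores differ only for the final split, whose last block extends past `seqlen_k`.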
