Flash decoding divides the sequence blocks into a series of splits, with each split assigned to a thread block. However, in the masking-step loop (code), every split goes through the same masking process, even though only the final split might actually require it. Is this the intended behavior? Should there be control logic so that only the final split goes through this "masking steps" loop? Thanks!
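To make the question concrete, here is a minimal Python sketch. It assumes that the K/V column blocks are divided evenly across splits and that the key-length (`seqlen_k`) boundary is the only thing that needs masking. The helpers `split_ranges` and `split_needs_mask` are hypothetical and are not part of the FlashAttention code.

```python
import math


def split_ranges(seqlen_k, block_n, num_splits):
    """Hypothetical helper: assign K/V column blocks to splits as evenly as possible.

    Returns a list of (start_block, end_block) half-open ranges, one per split.
    """
    n_blocks = math.ceil(seqlen_k / block_n)
    blocks_per_split = math.ceil(n_blocks / num_splits)
    ranges = []
    for s in range(num_splits):
        start = min(n_blocks, s * blocks_per_split)
        end = min(n_blocks, start + blocks_per_split)
        ranges.append((start, end))
    return ranges


def split_needs_mask(block_range, seqlen_k, block_n):
    """True if the split's last block contains key columns at or beyond seqlen_k."""
    start, end = block_range
    if start >= end:  # empty split: nothing to mask
        return False
    return end * block_n > seqlen_k


ranges = split_ranges(seqlen_k=1000, block_n=128, num_splits=4)
print(ranges)  # [(0, 2), (2, 4), (4, 6), (6, 8)]
print([split_needs_mask(r, 1000, 128) for r in ranges])  # [False, False, False, True]
```

Under these assumptions only the split that owns the last, partially filled block crosses `seqlen_k`. When `seqlen_k` is a multiple of `block_n`, no split needs masking at all.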
The outputs are still correct with the extra masking iterations, because the mask takes m_block and n_block as inputs: if a block does not go out of bounds, the masking code leaves its elements unchanged.
Having separate masking iterations is just a speed optimization. You could add a check so that only the final split masks, but that seems more complicated (and is likely not faster during decoding, where the bottleneck is loading KV, not computation).
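The point about the mask depending on m_block and n_block can be illustrated with a minimal Python sketch. It is a simplified stand-in for the kernel's masking code, and it assumes a key-length mask plus an optional bottom-right-aligned causal limit. The function `apply_mask` and its parameters are hypothetical. A block that lies entirely within the valid key range comes back unchanged, so running it through the masking loop in a non-final split is harmless.

```python
NEG_INF = float("-inf")


def apply_mask(scores, m_block, n_block, block_m, block_n, seqlen_q, seqlen_k, causal):
    """Hypothetical simplified mask: set disallowed (query, key) scores to -inf.

    scores is a list of rows (queries) for query block m_block, where each row
    holds block_n scores for key block n_block.
    """
    out = []
    for i, row in enumerate(scores):
        q_idx = m_block * block_m + i
        if causal:
            # Assumption: causal mask aligned to the bottom-right corner,
            # so the last query row may attend to every key.
            col_limit = min(q_idx + seqlen_k - seqlen_q + 1, seqlen_k)
        else:
            col_limit = seqlen_k
        new_row = []
        for j, s in enumerate(row):
            k_idx = n_block * block_n + j
            new_row.append(s if k_idx < col_limit else NEG_INF)
        out.append(new_row)
    return out


# Decoding-style example: one query row, key blocks of 4 columns, seqlen_k = 10.
block = [[0.5, 1.0, -2.0, 3.0]]
for n_block in range(3):
    print(n_block, apply_mask(block, 0, n_block, 1, 4, 1, 10, causal=True))
# 0 [[0.5, 1.0, -2.0, 3.0]]   keys 0..3, in bounds, unchanged
# 1 [[0.5, 1.0, -2.0, 3.0]]   keys 4..7, in bounds, unchanged
# 2 [[0.5, 1.0, -inf, -inf]]  keys 10 and 11 are past seqlen_k
```

With `seqlen_q = 1` the causal limit in this sketch equals `seqlen_k`. In that case only the last key block is ever modified, which matches the explanation that the extra iterations cost some compute but do not change the result.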