Description
I believe we've identified a couple of issues with how attention is calculated, related to the 1/sqrt(d) scaling factor on the QK product.

First, when we invoke the cudnn_flash_te kernel for GPUs, we supply a scaling factor of 1/sqrt(d):

maxtext/MaxText/layers/attentions.py
Line 470 in ee14ae6
This would normally be correct, but for the naive dot_product and TPU flash kernels we do not supply this scaling factor. That is because of these lines, which initialize the Q weights to be scaled down by that factor:
maxtext/MaxText/layers/attentions.py
Lines 1203 to 1213 in ee14ae6
This would mean the QK product for GPU flash attention is doubly scaled down, which can hurt performance. In any case, it differs from the TPU and dot_product implementations.
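To make the double scaling concrete, here is a minimal sketch (toy shapes and names, not the MaxText code): if the Q projection weights already carry an extra 1/sqrt(d) factor from the initializer and the kernel applies its own 1/sqrt(d) on top, the logits end up roughly 1/d-scaled at initialization.

```python
# Minimal sketch of the double-scaling issue; shapes and names are illustrative,
# not taken from MaxText.
import jax
import jax.numpy as jnp

d = 64
k_x, k_w, k_k = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (8, d))                   # activations feeding the Q projection
w_q = jax.random.normal(k_w, (d, d)) / jnp.sqrt(d)   # typical fan-in style init (illustrative)
w_q_scaled = w_q / jnp.sqrt(d)                       # the extra 1/sqrt(d) from the initializer trick
k_heads = jax.random.normal(k_k, (8, d))             # stand-in for the key tensor

q = x @ w_q_scaled
logits_tpu = q @ k_heads.T                   # dot_product / TPU flash: init scaling only
logits_gpu = q @ k_heads.T / jnp.sqrt(d)     # cudnn_flash_te: init scaling AND kernel scaling

# The GPU logits come out a further 1/sqrt(d) smaller than the TPU ones.
print(float(jnp.std(logits_gpu) / jnp.std(logits_tpu)))  # ~0.125 for d=64
```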
Second, investigating this led us to think harder about that comment, and there may be another issue there, which also applies to TPUs. If you use the adafactor optimizer, then gradients are implicitly scaled by the size of the weights, so you can scale just the initializer and this is equivalent to scaling the QK product. However, this does not hold for other optimizers, including MaxText's default of Adam. This means that under Adam, you would expect training to be less stable than with the more typical formulation of dividing the QK product by sqrt(d).

To be precise:
- That comment claims that initializing the Q weights to be scaled down is equivalent to dividing the QK product by the same factor.
- I claim that's not true under Adam, because the update does not depend on the magnitude of the weights (see the sketch after this list).
- We can test this empirically. If the initialization strategy is sufficient, then both should produce exactly the same loss curves (up to floating point error). If not, we should see significant differences.
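One way to see the optimizer dependence directly is a minimal sketch that calls optax outside of MaxText's training loop (shapes and learning rates are arbitrary): with its default multiply_by_parameter_scale=True, Adafactor's update shrinks along with the parameters, while Adam's first update has essentially the same magnitude regardless of how the weights were scaled at init.

```python
# Minimal sketch (optax directly, not MaxText's training loop): compare how the
# size of the parameters affects the first update under Adam vs. Adafactor.
import jax
import jax.numpy as jnp
import optax

def first_update_norm(optimizer, params, grads):
    state = optimizer.init(params)
    updates, _ = optimizer.update(grads, state, params)
    return float(jnp.linalg.norm(updates))

k_w, k_g = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(k_w, (256, 256))
g = jax.random.normal(k_g, (256, 256))

for scale in (1.0, 0.125):  # 0.125 mimics scaling the initializer by 1/sqrt(64)
    adam = first_update_norm(optax.adam(1e-3), scale * w, g)
    adafactor = first_update_norm(optax.adafactor(1e-3), scale * w, g)
    print(f"init scale {scale}: adam {adam:.5f}, adafactor {adafactor:.5f}")

# Adam's update norm is the same for both scales; Adafactor's shrinks with the
# parameters, which is what lets the init-scaling trick stand in for 1/sqrt(d).
```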
Here are two runs (using Adam) which differ only in this way.

Since the loss curves differ significantly, these are not equivalent.
Traditionally we would scale the QK product instead of the initializer, so we could default to that behavior.
However, in most of the experiments I've run, the loss curve of the init strategy is actually slightly better than scaling the QK product. Pending further testing, we should probably continue to support it as an option, and maybe default to it.
This PR has one possible resolution of these issues:
- Remove the 1/sqrt(d) scaling factor from the cudnn_flash_te kernel to bring it in line with the dot_product and TPU flash implementations.
- Add a query_scaling option, which lets you choose between scaling the initializer and scaling the query. Default to scaling the initializer, which matches the previous behavior.

Note that scaling the QK product is equivalent to scaling the query, so since the TPU flash attention kernel doesn't support scaling the QK product, I opted to scale the query for all attention implementations. This is how the new MLA attention module already does it, so I'm just doing the same everywhere else:
maxtext/MaxText/layers/attentions.py
Lines 1540 to 1542 in ee14ae6
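For completeness, the query-scaling vs. product-scaling equivalence is just moving the scalar inside the contraction; here is a minimal sketch with toy shapes (not the MaxText kernel code):

```python
# Minimal sketch (toy shapes, not MaxText code): dividing the query by sqrt(d)
# gives the same logits as dividing the QK product by sqrt(d).
import jax
import jax.numpy as jnp

head_dim = 64
k_q, k_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(k_q, (8, head_dim))  # (seq, head_dim)
k = jax.random.normal(k_k, (8, head_dim))

scaled_product = jnp.einsum("qd,kd->qk", q, k) / jnp.sqrt(head_dim)
scaled_query = jnp.einsum("qd,kd->qk", q / jnp.sqrt(head_dim), k)

print(bool(jnp.allclose(scaled_product, scaled_query, atol=1e-5)))  # True
```

Since the scale can be folded into the query tensor before any kernel is called, the same approach works for dot_product, TPU flash, and cudnn_flash_te.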
Tests
In addition to the graph shown above, I ran a number of other variations, such as turning off gradient clipping (which I thought could be a confounder), and not scaling at all. I also ran some GPU tests using cudnn_flash_te, and some with fused_qkv. I can provide loss curves from any of those if needed.
Checklist
Before submitting this PR, please make sure (put X in square brackets):