Attention scaling fixes #1349

Open

philip-essential wants to merge 1 commit into main
Conversation

philip-essential (Contributor)

Description

I believe we've identified a couple of issues with how attention is calculated, related to the 1/sqrt(d) scaling factor on the QK product.

First, when we invoke the cudnn_flash_te kernel for GPUs, we supply a scaling factor of 1/sqrt(d):

scale_factor=1.0 / math.sqrt(head_dim),

This would normally be correct, but for the naive dot_product and TPU flash kernels we do not supply this scaling factor. This is because of these lines, which initialize the Q weights scaled down by that factor:

def query_projection(self, inputs_q: Array) -> Array:
  """Query projection."""

  # NOTE: T5 does not explicitly rescale the attention logits by
  # 1/sqrt(depth_kq)! This is folded into the initializers of the
  # linear transformations, which is equivalent under Adafactor.
  depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)

  def query_init(*args):
    # pylint: disable=no-value-for-parameter
    return self.kernel_init(*args) / depth_scaling

This means the QK product for GPU flash attention is scaled down twice, by roughly 1/d overall instead of the intended 1/sqrt(d), which can hurt performance. In any case, it differs from the TPU and dot_product implementations.
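As a sanity check, here is a toy sketch of the effective logit scaling on each path (head_dim is just an illustrative value, not taken from a specific config):

import math

head_dim = 128
init_scale = 1.0 / math.sqrt(head_dim)        # folded into the Q initializer
gpu_kernel_scale = 1.0 / math.sqrt(head_dim)  # scale_factor passed to cudnn_flash_te
tpu_kernel_scale = 1.0                        # dot_product / TPU flash: no extra scale

print(init_scale * tpu_kernel_scale)  # ~1/sqrt(d), the intended scaling
print(init_scale * gpu_kernel_scale)  # ~1/d, doubly scaled down on the GPU path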

Second, investigating this led us to think harder about that comment, and there may be another issue there, which also applies to TPUs. If you use the Adafactor optimizer, parameter updates are implicitly scaled by the magnitude of the weights, so scaling just the initializer is equivalent to scaling the QK product. However, this does not hold for other optimizers, including MaxText's default of Adam. This means that under Adam, you would expect training to be less stable than with the more typical formulation of dividing the QK product by sqrt(d).

To be precise:

  • That comment claims that initializing the Q weights to be scaled down is equivalent to dividing the QK product by the same factor.

  • I claim that's not true under Adam, because Adam's update does not depend on the magnitude of the weights (see the optimizer sketch after this list).

  • We can test this empirically. If the two formulations are equivalent, both runs should produce exactly the same loss curves (up to floating-point error); if not, we should see significant differences.

  • Here are two runs (using Adam) which differ only in this way.
    [Screenshot (2025-03-05): loss curves for the two runs, scaled init vs. scaled QK product]

  • Since the loss curves differ significantly, these are not equivalent.

  • Traditionally we would scale the QK product instead of the initializer, so we could default to that behavior.

  • However, in most of the experiments I've done, the loss curve of the init strategy is actually slightly better than that of scaling the QK product. Pending further testing, we should probably continue to support it as an option, and maybe default to it.
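To make the optimizer dependence concrete, here is a hedged toy sketch using optax (not MaxText code; shapes and learning rates are illustrative). Adafactor multiplies its update by the parameter's RMS, so a down-scaled parameter also takes proportionally smaller steps, while Adam's step size is unaffected by the parameter's scale:

import jax.numpy as jnp
import optax

def first_update_magnitude(opt, param):
    """Mean |update| produced for a constant gradient at step 1."""
    state = opt.init(param)
    grad = jnp.ones_like(param)  # same gradient in both cases
    updates, _ = opt.update(grad, state, param)
    return float(jnp.abs(updates).mean())

big = jnp.full((4, 4), 1.0)
small = big / 8.0  # e.g. the same weights folded down by 1/sqrt(64)

adam = optax.adam(1e-3)
adafactor = optax.adafactor(learning_rate=1e-3)  # multiply_by_parameter_scale=True by default

# Adam: update magnitude is independent of the parameter's scale.
print(first_update_magnitude(adam, big), first_update_magnitude(adam, small))
# Adafactor: the smaller parameter receives a proportionally smaller update.
print(first_update_magnitude(adafactor, big), first_update_magnitude(adafactor, small))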

This PR has one possible resolution of these issues:

  • Remove the extra 1/sqrt(d) scaling factor from the cudnn_flash_te kernel to bring it in line with the dot_product and tpu flash implementations.
  • Add a config option query_scaling, which lets you choose between scaling the initializer and scaling the query (a rough sketch of the switch follows this list). Default to scaling the initializer, which matches the previous behavior.
  • Fix the fused_qkv implementation, which didn't scale the query anyway. I'm not sure if anyone uses fused_qkv regularly, but we got very bad results with it, which this PR fixes.
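Roughly, the intent of the query_scaling switch is the following. This is a hedged, standalone sketch only; the actual config plumbing, initializer, and module wiring in MaxText differ:

import jax
import jax.numpy as jnp

def make_query_projection(key, model_dim, head_dim, query_scaling="init"):
    """Toy projection illustrating the two scaling strategies."""
    depth_scaling = jnp.sqrt(float(head_dim))
    w_q = jax.nn.initializers.lecun_normal()(key, (model_dim, head_dim))
    if query_scaling == "init":
        # Previous behavior: fold 1/sqrt(d) into the kernel at init time.
        w_q = w_q / depth_scaling
        post_scale = 1.0
    else:
        # "query": keep the standard init and divide the activations instead.
        post_scale = 1.0 / depth_scaling
    return lambda x: (x @ w_q) * post_scale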

Note that scaling the QK product is equivalent to scaling the query, so since the TPU flash attention kernel doesn't support scaling the QK product, I opted to scale the query for all attention implementations. This is how the new MLA attention module already does it, so I'm just doing the same everywhere else:

# Query projection is scaled by 1 / self.softmax_scale to be consistent MaxText implementation.
# DeepSeek v3 was doing it in attention score computation.
return jnp.concatenate([q_nope, q_pe], axis=-1) / self.softmax_scale
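For reference, a quick numeric check of that equivalence (toy shapes, not the MLA module itself):

import jax
import jax.numpy as jnp

key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (4, 64))
k = jax.random.normal(key_k, (4, 64))
scale = jnp.sqrt(64.0)

# Dividing the query by the scale gives the same logits as dividing the product.
print(jnp.allclose((q / scale) @ k.T, (q @ k.T) / scale, atol=1e-5))  # True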

Tests

In addition to the graph shown above, I ran a number of other variations, such as turning off gradient clipping (which I thought could be a confounder) and not scaling at all. I also ran some GPU tests using cudnn_flash_te, and some with fused_qkv. I can provide loss curves from any of those if needed.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.
