Attention scaling fixes #1349

Open · wants to merge 1 commit into main
11 changes: 11 additions & 0 deletions MaxText/configs/base.yml
@@ -215,6 +215,17 @@ final_logits_soft_cap: 0.0
 use_post_attn_norm: False
 use_post_ffw_norm: False
 
+# T5 does not explicitly rescale the attention logits by 1/sqrt(d). Instead,
+# the factor is folded into the initializers of the linear transformations,
+# which is equivalent under Adafactor. Under Adam, the two are not equivalent,
+# so scaling down the query activations may be more stable. Supported options:
+#
+# init:  scale the initialization of the query weights by 1/sqrt(d)
+# query: scale the query activations by 1/sqrt(d) (equivalent to scaling the qk product)
+# both:  scale both the init and the query activations
+# none:  scale neither
+query_scaling: init
+
 # MLA parameters
 q_lora_rank: 0
 kv_lora_rank: 512
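As background for the `query` option above, here is a minimal standalone JAX sketch (shapes and variable names are illustrative, not part of this diff) of why scaling the query activations by 1/sqrt(d) is equivalent to scaling the qk product:

```python
import jax.numpy as jnp
from jax import random

d = 64  # head_dim
kq, kk = random.split(random.PRNGKey(0))
q = random.normal(kq, (8, d))  # (seq_len, head_dim)
k = random.normal(kk, (8, d))

# query_scaling="query": scale the query activations before the dot product.
logits_query_scaled = (q / jnp.sqrt(d)) @ k.T

# Classic attention: scale the qk product by 1/sqrt(d).
logits_product_scaled = (q @ k.T) / jnp.sqrt(d)

# Identical up to floating-point rounding.
assert jnp.allclose(logits_query_scaled, logits_product_scaled, atol=1e-5)
```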
25 changes: 18 additions & 7 deletions MaxText/layers/attentions.py
@@ -16,7 +16,6 @@
 
 import enum
 import functools
-import math
 from typing import Any, Optional, Tuple
 
 from flax import linen as nn
@@ -506,7 +505,7 @@ def cudnn_flash_attention(
         dtype=self.dtype,
         float32_logits=self.float32_logits,
         qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
-        scale_factor=1.0 / math.sqrt(head_dim),
+        scale_factor=1.0,  # handled elsewhere; see config.query_scaling
         transpose_batch_sequence=False,
         window_size=sliding_window_size,
     )
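The intent of `scale_factor=1.0` is that exactly one factor of 1/sqrt(d) reaches the logits: with the factor now applied upstream (at init or on the query activations), keeping the kernel's own scaling would apply it twice. A small sketch of the distinction, with assumed toy shapes:

```python
import jax.numpy as jnp

d = 64
q = jnp.ones((4, d)) / jnp.sqrt(d)  # query already scaled upstream
k = jnp.ones((4, d))

# scale_factor=1.0: the logits carry exactly one factor of 1/sqrt(d).
logits = (q @ k.T) * 1.0

# Leaving scale_factor at 1/sqrt(d) on top of pre-scaled queries
# would scale twice, i.e. by 1/d overall.
logits_double = (q @ k.T) / jnp.sqrt(d)

assert jnp.allclose(logits, jnp.sqrt(d))  # 64 / sqrt(64) == 8
assert jnp.allclose(logits_double, 1.0)   # 64 / 64 == 1
```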
@@ -1347,14 +1346,14 @@ def setup(self):
   def query_projection(self, inputs_q: Array) -> Array:
     """Query projection."""
 
-    # NOTE: T5 does not explicitly rescale the attention logits by
-    # 1/sqrt(depth_kq)! This is folded into the initializers of the
-    # linear transformations, which is equivalent under Adafactor.
     depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
 
     def query_init(*args):
       # pylint: disable=no-value-for-parameter
-      return self.kernel_init(*args) / depth_scaling
+      if self.config.query_scaling in ["init", "both"]:
+        return self.kernel_init(*args) / depth_scaling
+      else:
+        return self.kernel_init(*args)
 
     query_proj = DenseGeneral(
         features=(self.num_query_heads, self.head_dim),
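For the `init`/`both` branch, the 1/sqrt(head_dim) factor is folded into the query weights at initialization, so the initial logits match the explicitly scaled formulation while the optimizer thereafter sees plain weights. A minimal sketch, using `jax.nn.initializers.variance_scaling` as a stand-in for `self.kernel_init` and hypothetical shapes:

```python
import jax.numpy as jnp
from jax import random
from jax.nn.initializers import variance_scaling

head_dim = 64
kernel_init = variance_scaling(1.0, "fan_in", "normal")  # stand-in for self.kernel_init

def query_init(key, shape, dtype=jnp.float32):
  # Fold 1/sqrt(head_dim) into the initial query weights ("init"/"both").
  return kernel_init(key, shape, dtype) / jnp.sqrt(head_dim)

shape = (512, 4, head_dim)  # (embed, num_query_heads, head_dim)
w_scaled = query_init(random.PRNGKey(0), shape)
w_plain = kernel_init(random.PRNGKey(0), shape)
assert jnp.allclose(w_scaled, w_plain / jnp.sqrt(head_dim))
```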
@@ -1404,10 +1403,19 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
   def qkv_projection(self, inputs: Array, proj_name: str):
     """Fused QKV projection"""
 
+    depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
+
+    def qkv_init(*args):
+      # pylint: disable=no-value-for-parameter
+      if self.config.query_scaling in ["init", "both"]:
+        return self.kernel_init(*args).at[:, 0, :, :].divide(depth_scaling)
+      else:
+        return self.kernel_init(*args)
+
     qkv_proj = DenseGeneral(
         features=(3, self.num_query_heads, self.head_dim),
         axis=-1,
-        kernel_init=self.kernel_init,
+        kernel_init=qkv_init,
         kernel_axes=("embed", "qkv", "heads", "kv"),
         dtype=self.dtype,
         weight_dtype=self.weight_dtype,
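In the fused case, `kernel_axes` gives the kernel shape (embed, qkv=3, heads, head_dim), and index 0 on the qkv axis is the query projection, so `.at[:, 0, :, :].divide(depth_scaling)` rescales only the query slice. A quick sketch with toy dimensions:

```python
import jax.numpy as jnp

embed, heads, head_dim = 16, 4, 8
kernel = jnp.ones((embed, 3, heads, head_dim))  # ("embed", "qkv", "heads", "kv")

# Divide only the query slice (index 0 on the qkv axis); the key and
# value slices are left untouched.
scaled = kernel.at[:, 0, :, :].divide(jnp.sqrt(head_dim))

assert jnp.allclose(scaled[:, 0], 1.0 / jnp.sqrt(head_dim))  # query scaled
assert jnp.allclose(scaled[:, 1:], 1.0)                      # key/value unchanged
```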
@@ -1546,6 +1554,9 @@ def __call__(
 
     assert not self.config.quantize_kvcache or self.kv_quant
 
+    if self.config.query_scaling in ["query", "both"]:
+      query /= jnp.sqrt(query.shape[-1])
+
     if self.config.attention == "paged" and model_mode != common_types.MODEL_MODE_TRAIN:
       unnormalized_out, _, exp_sum = self.paged_attention_op(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, page_state=page_state
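The runtime half mirrors the init half: for `query`/`both`, the activations are divided by sqrt(head_dim) just before attention. A standalone sketch of that dispatch (hypothetical helper, not MaxText's API):

```python
import jax.numpy as jnp

def apply_query_scaling(query, query_scaling: str):
  # Runtime part of query_scaling, mirroring the check in __call__ above.
  if query_scaling in ["query", "both"]:
    query = query / jnp.sqrt(query.shape[-1])
  return query

q = jnp.ones((2, 8, 4, 64))  # (batch, seq, heads, head_dim), illustrative
assert jnp.allclose(apply_query_scaling(q, "query"), 1.0 / 8.0)  # sqrt(64) == 8
assert jnp.allclose(apply_query_scaling(q, "none"), 1.0)
```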