From 2206a5a8a04445f6142cf5294f2bbe2fcd4cb5ba Mon Sep 17 00:00:00 2001
From: Philip Monk <169196560+philip-essential@users.noreply.github.com>
Date: Thu, 6 Mar 2025 00:45:21 +0000
Subject: [PATCH] attention scaling fixes

---
 MaxText/configs/base.yml     | 11 +++++++++++
 MaxText/layers/attentions.py | 25 ++++++++++++++++++-------
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 7a536405d..3017526c6 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -215,6 +215,17 @@ final_logits_soft_cap: 0.0
 use_post_attn_norm: False
 use_post_ffw_norm: False
 
+# T5 does not explicitly rescale the attention logits by 1/sqrt(d). This is
+# folded into the initializers of the linear transformations, which is
+# equivalent under Adafactor. Under Adam, these are not equivalent, so scaling
+# down the query activations may be more stable. Supported options:
+#
+# init: scale initialization of query weights by 1/sqrt(d)
+# query: scale query activations by 1/sqrt(d) (equivalent to scaling qk product)
+# both: scale both init and query
+# none: don't scale either
+query_scaling: init
+
 # MLA parameters
 q_lora_rank: 0
 kv_lora_rank: 512
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index fee603dfd..e0fd58fca 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -16,7 +16,6 @@
 
 import enum
 import functools
-import math
 from typing import Any, Optional, Tuple
 
 from flax import linen as nn
@@ -506,7 +505,7 @@ def cudnn_flash_attention(
           dtype=self.dtype,
           float32_logits=self.float32_logits,
           qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
-          scale_factor=1.0 / math.sqrt(head_dim),
+          scale_factor=1.0,  # handled elsewhere; see config.query_scaling
           transpose_batch_sequence=False,
           window_size=sliding_window_size,
       )
@@ -1347,14 +1346,14 @@ def setup(self):
   def query_projection(self, inputs_q: Array) -> Array:
     """Query projection."""
 
-    # NOTE: T5 does not explicitly rescale the attention logits by
-    # 1/sqrt(depth_kq)! This is folded into the initializers of the
-    # linear transformations, which is equivalent under Adafactor.
     depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
 
     def query_init(*args):
       # pylint: disable=no-value-for-parameter
-      return self.kernel_init(*args) / depth_scaling
+      if self.config.query_scaling in ["init", "both"]:
+        return self.kernel_init(*args) / depth_scaling
+      else:
+        return self.kernel_init(*args)
 
     query_proj = DenseGeneral(
         features=(self.num_query_heads, self.head_dim),
@@ -1404,10 +1403,19 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
 
   def qkv_projection(self, inputs: Array, proj_name: str):
     """Fused QKV projection"""
+    depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
+
+    def qkv_init(*args):
+      # pylint: disable=no-value-for-parameter
+      if self.config.query_scaling in ["init", "both"]:
+        return self.kernel_init(*args).at[:, 0, :, :].divide(depth_scaling)
+      else:
+        return self.kernel_init(*args)
+
     qkv_proj = DenseGeneral(
         features=(3, self.num_query_heads, self.head_dim),
         axis=-1,
-        kernel_init=self.kernel_init,
+        kernel_init=qkv_init,
         kernel_axes=("embed", "qkv", "heads", "kv"),
         dtype=self.dtype,
         weight_dtype=self.weight_dtype,
@@ -1546,6 +1554,9 @@ def __call__(
 
     assert not self.config.quantize_kvcache or self.kv_quant
 
+    if self.config.query_scaling in ["query", "both"]:
+      query /= jnp.sqrt(query.shape[-1])
+
     if self.config.attention == "paged" and model_mode != common_types.MODEL_MODE_TRAIN:
       unnormalized_out, _, exp_sum = self.paged_attention_op(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, page_state=page_state
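
A minimal standalone sketch (not part of the patch; d, q, and k are illustrative toy values) of the equivalence noted in the new config comment: dividing the query activations by sqrt(d) gives the same attention logits as dividing the qk product by sqrt(d). This is why the patch sets the cuDNN scale_factor to 1.0 and applies the scaling via config.query_scaling instead.

import jax.numpy as jnp
from jax import random

d = 64
q = random.normal(random.PRNGKey(0), (8, d))  # toy queries, shape [seq, head_dim]
k = random.normal(random.PRNGKey(1), (8, d))  # toy keys

# "query" option: scale the query activations before the matmul.
logits_query_scaled = (q / jnp.sqrt(d)) @ k.T
# Classic formulation: scale the qk product by 1/sqrt(d).
logits_product_scaled = (q @ k.T) / jnp.sqrt(d)

assert jnp.allclose(logits_query_scaled, logits_product_scaled, atol=1e-5)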