From b4d39a6296748c64f4f59c2f5f30ccf9e8d21982 Mon Sep 17 00:00:00 2001
From: Sascha Rothe
Date: Tue, 4 Mar 2025 04:37:20 -0800
Subject: [PATCH] Add option to load checkpoints with transposed Gating Einsum.

PiperOrigin-RevId: 733277427
---
 examples/gemma/modules.py      | 13 +++++++++-
 examples/gemma/modules_test.py | 16 +++++++++---
 examples/gemma/transformer.py  | 46 +++++++++++++++++++++++++++++++++-
 3 files changed, 70 insertions(+), 5 deletions(-)

diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py
index c7572e055..be152eda6 100644
--- a/examples/gemma/modules.py
+++ b/examples/gemma/modules.py
@@ -33,6 +33,7 @@
 Shape = Sequence[Union[int, Any]]
 
 K_MASK = -2.3819763e38  # Set to a large negative number.
+DEFAULT_ROPE_BASE_FREQUENCY = 10_000
 
 
 class AttentionType(enum.Enum):
@@ -80,9 +81,11 @@ def __init__(
       num_kv_heads: int,
       features: int,
       head_dim: int,
+      query_pre_attn_scalar: float,
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -93,6 +96,7 @@ def __init__(
           '`sliding_window_size` must be set if `attn_type` is Local Sliding.'
       )
 
+    self.query_pre_attn_scalar = query_pre_attn_scalar
     self.attn_type = attn_type
     self.sliding_window_size = sliding_window_size
     self.attn_logits_soft_cap = attn_logits_soft_cap
@@ -101,6 +105,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.rope_base_frequency = rope_base_frequency
     self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config
 
@@ -148,12 +153,14 @@ def __call__(
         query_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )
-    query_scaled = query_proj * self.head_dim**-0.5
+    query_scaled = query_proj * self.query_pre_attn_scalar
     key_proj = positional_embeddings.apply_rope(
         key_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )
 
     # Cache is left aligned.
@@ -304,9 +311,11 @@ def __init__(
       hidden_dim: int,
       use_post_attn_norm: bool,
       use_post_ffw_norm: bool,
+      query_pre_attn_scalar: float,
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -318,7 +327,9 @@ def __init__(
         num_kv_heads=num_kv_heads,
         features=embed_dim,
         head_dim=head_dim,
+        query_pre_attn_scalar=query_pre_attn_scalar,
         attn_type=attn_type,
+        rope_base_frequency=rope_base_frequency,
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,
diff --git a/examples/gemma/modules_test.py b/examples/gemma/modules_test.py
index 7439585cc..0f94a5e4f 100644
--- a/examples/gemma/modules_test.py
+++ b/examples/gemma/modules_test.py
@@ -78,6 +78,7 @@ def test_head_dim(self, head_dim):
         num_kv_heads=4,
         features=5,
         head_dim=head_dim,
+        query_pre_attn_scalar=1.0,
         attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -107,6 +108,7 @@ def test_use_qkv_einsum(
         num_kv_heads=num_kv_heads,
         features=5,
         head_dim=8,
+        query_pre_attn_scalar=1.0,
         attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -144,7 +146,8 @@ def test_attention(
         num_heads,
         features,
         head_dim,
-        modules.AttentionType.GLOBAL,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
     cache = attn.init_cache(
@@ -177,7 +180,8 @@ def test_sliding_window(self, sliding_window_size):
         num_heads,
         features,
         head_dim,
-        modules.AttentionType.GLOBAL,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
     cache = attn.init_cache(
@@ -191,7 +195,8 @@ def test_sliding_window(self, sliding_window_size):
         num_heads,
         features,
         head_dim,
-        modules.AttentionType.LOCAL_SLIDING,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.LOCAL_SLIDING,
         sliding_window_size=sliding_window_size,
         rngs=nnx.Rngs(params=0),
     )
@@ -272,6 +277,7 @@ def test_block(
         1,
         use_post_attn_norm,
         use_post_ffw_norm,
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -315,6 +321,7 @@ def test_post_attention_norm(
         1,
         True,
         False,  # use_post_ffw_norm
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -326,6 +333,7 @@ def test_post_attention_norm(
         1,
         False,
         False,  # use_post_ffw_norm
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -373,6 +381,7 @@ def test_post_ffw_norm(
         1,
         True,
         True,  # use_post_ffw_norm
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -384,6 +393,7 @@ def test_post_ffw_norm(
         1,
         False,
         False,  # use_post_ffw_norm
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py
index e253c4082..ce9eed319 100644
--- a/examples/gemma/transformer.py
+++ b/examples/gemma/transformer.py
@@ -18,6 +18,8 @@
 
 from collections.abc import Iterable
 import dataclasses
+import enum
+import functools
 from typing import Any
 
 from flax import nnx
@@ -32,6 +34,19 @@
 Cache = dict[str, modules.LayerCache]
 
 
+class QueryPreAttentionNormalisation(enum.Enum):
+  """Query pre-attention normalisation strategy."""
+
+  # Scale the query by `1/sqrt(head_dim)`.
+  BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto()
+
+  # Scale the query by `embed_dim // num_heads`.
+  BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto()
+
+  # Scale the query by `1/sqrt(embed_dim // num_heads)`.
+  BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto()
+
+
 @dataclasses.dataclass(frozen=True)
 class TransformerConfig:
   """Configuration for the gemma transformer."""
@@ -47,10 +62,26 @@ class TransformerConfig:
   use_post_attn_norm: bool
   use_post_ffw_norm: bool
   attention_types: Iterable[modules.AttentionType]
+  query_pre_attn_norm: QueryPreAttentionNormalisation = (
+      QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
+  )
   attn_logits_soft_cap: float | None = None
+  transpose_gating_einsum: bool = False
+  local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
+  global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
   use_qk_norm: bool = False
   sliding_window_size: int | None = None
 
+  def query_pre_attn_scalar(self) -> float:
+    """Returns the scalar to multiply the query by before attention."""
+    match self.query_pre_attn_norm:
+      case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS:
+        return self.embed_dim // self.num_heads
+      case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS:  # pylint: disable=line-too-long
+        return (self.embed_dim // self.num_heads) ** -0.5
+      case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _:
+        return self.head_dim**-0.5
+
   @classmethod
   def from_path(cls, path: str) -> TransformerConfig:
     """Creates a TransformerConfig from loaded parameters."""
@@ -176,6 +207,7 @@ def gemma_9b(cls):
 
 
 def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
+  """Maps linen variable names to nnx variable names."""
   new_key = []
   for k in key:
     if k.startswith('layer_'):
@@ -199,8 +231,12 @@ def _assign_linen_params_to_nnx_state(
     state: dict[tuple[str, ...], Any],
     mapped_path: tuple[str | int, ...],
     val: Any,
+    transpose_gating_einsum: bool,
 ) -> dict[tuple[str, ...], Any]:
+  """Splits and maybe transposes gate_proj."""
   if 'gate_proj' in mapped_path:
+    if transpose_gating_einsum:
+      val = jnp.swapaxes(val, 1, 2)
     state[mapped_path].value = val[0]
     state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1]
   else:
@@ -217,11 +253,15 @@ def from_params(
   ) -> Transformer:
     if config is None:
       config = TransformerConfig.from_params(params)
+    assign_val_fn = functools.partial(
+        _assign_linen_params_to_nnx_state,
+        transpose_gating_einsum=config.transpose_gating_einsum,
+    )
     return helpers.module_from_linen_variables(
         module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)),
         variables=params['transformer'],
         map_key_fn=_map_linen_var_names,
-        assign_val_fn=_assign_linen_params_to_nnx_state,
+        assign_val_fn=assign_val_fn,
     )
 
   def __init__(
@@ -248,7 +288,11 @@ def __init__(
            use_post_ffw_norm=config.use_post_ffw_norm,
            attn_logits_soft_cap=config.attn_logits_soft_cap,
            attn_type=attn_type,
+           query_pre_attn_scalar=config.query_pre_attn_scalar(),
            rngs=rngs,
+           rope_base_frequency=config.local_base_frequency
+           if attn_type == modules.AttentionType.LOCAL_SLIDING
+           else config.global_base_frequency,
            use_qk_norm=config.use_qk_norm,
            sow_config=sow_config,
        )
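
Reviewer sketch of the two new knobs, assuming illustrative shapes (embed_dim=16, hidden_dim=32, num_heads=4, head_dim=8) and a hypothetical split_gating helper; the swap-then-split mirrors _assign_linen_params_to_nnx_state above, and the scalar table mirrors TransformerConfig.query_pre_attn_scalar(). Which checkpoint layout counts as "transposed" is an assumption for illustration, not something the patch itself pins down.

import jax.numpy as jnp

# Illustrative shapes only (assumed, not taken from any real checkpoint).
embed_dim, hidden_dim = 16, 32
num_heads, head_dim = 4, 8

def split_gating(val, transpose_gating_einsum: bool):
  # Hypothetical helper mirroring the gate_proj branch of
  # _assign_linen_params_to_nnx_state: optionally undo the transposed
  # layout, then split the stacked kernel into gate_proj and up_proj.
  if transpose_gating_einsum:
    val = jnp.swapaxes(val, 1, 2)
  return val[0], val[1]

# Default layout: (2, embed_dim, hidden_dim) -> no transpose needed.
legacy = jnp.zeros((2, embed_dim, hidden_dim))
gate, up = split_gating(legacy, transpose_gating_einsum=False)
assert gate.shape == up.shape == (embed_dim, hidden_dim)

# Assumed "transposed" layout: (2, hidden_dim, embed_dim) -> the new flag
# swaps the last two axes back before splitting.
transposed = jnp.zeros((2, hidden_dim, embed_dim))
gate, up = split_gating(transposed, transpose_gating_einsum=True)
assert gate.shape == up.shape == (embed_dim, hidden_dim)

# query_pre_attn_norm reduces to a plain multiplier on the query, matching
# TransformerConfig.query_pre_attn_scalar():
scalars = {
    'BY_ONE_OVER_SQRT_HEAD_DIM': head_dim**-0.5,
    'BY_EMBED_DIM_DIV_NUM_HEADS': embed_dim // num_heads,
    'BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS': (embed_dim // num_heads) ** -0.5,
}

With the defaults transpose_gating_einsum=False and query_pre_attn_norm=BY_ONE_OVER_SQRT_HEAD_DIM, checkpoint loading and query scaling behave exactly as before this patch.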