Add option to load checkpoints with transposed Gating Einsum.
PiperOrigin-RevId: 733277427
casaro authored and Flax Authors committed Mar 5, 2025
1 parent 0769411 commit b4d39a6
Showing 3 changed files with 70 additions and 5 deletions.
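
In short, the new `TransformerConfig.transpose_gating_einsum` flag lets `Transformer.from_params` consume checkpoints whose stacked gating einsum kernel is stored with its last two axes swapped: the loader applies `jnp.swapaxes(val, 1, 2)` before splitting the kernel into `gate_proj` and `up_proj`. A minimal opt-in sketch; the import path and the `params` variable are placeholders, while `gemma_9b`, `from_params`, and the field names come from this file:

import dataclasses

from examples.gemma import transformer as transformer_lib  # import path assumed

# `params` stands for the linen checkpoint variables restored elsewhere;
# restoring them is outside the scope of this sketch.
config = dataclasses.replace(
    transformer_lib.TransformerConfig.gemma_9b(),
    transpose_gating_einsum=True,  # the checkpoint stores the gating einsum transposed
)
model = transformer_lib.Transformer.from_params(params, config=config)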
13 changes: 12 additions & 1 deletion examples/gemma/modules.py
@@ -33,6 +33,7 @@
Shape = Sequence[Union[int, Any]]

K_MASK = -2.3819763e38 # Set to a large negative number.
DEFAULT_ROPE_BASE_FREQUENCY = 10_000


class AttentionType(enum.Enum):
@@ -80,9 +81,11 @@ def __init__(
num_kv_heads: int,
features: int,
head_dim: int,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
use_qk_norm: bool = False,
@@ -93,6 +96,7 @@ def __init__(
'`sliding_window_size` must be set if `attn_type` is Local Sliding.'
)

self.query_pre_attn_scalar = query_pre_attn_scalar
self.attn_type = attn_type
self.sliding_window_size = sliding_window_size
self.attn_logits_soft_cap = attn_logits_soft_cap
@@ -101,6 +105,7 @@ def __init__(
shape=(num_heads, head_dim, features),
rngs=rngs,
)
self.rope_base_frequency = rope_base_frequency
self.use_qk_norm = use_qk_norm
self.sow_config = sow_config

@@ -148,12 +153,14 @@ def __call__(
query_proj,
segment_pos,
head_dim=self.head_dim,
max_wavelength=self.rope_base_frequency,
)
- query_scaled = query_proj * self.head_dim**-0.5
+ query_scaled = query_proj * self.query_pre_attn_scalar
key_proj = positional_embeddings.apply_rope(
key_proj,
segment_pos,
head_dim=self.head_dim,
max_wavelength=self.rope_base_frequency,
)

# Cache is left aligned.
@@ -304,9 +311,11 @@ def __init__(
hidden_dim: int,
use_post_attn_norm: bool,
use_post_ffw_norm: bool,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
use_qk_norm: bool = False,
@@ -318,7 +327,9 @@ def __init__(
num_kv_heads=num_kv_heads,
features=embed_dim,
head_dim=head_dim,
query_pre_attn_scalar=query_pre_attn_scalar,
attn_type=attn_type,
rope_base_frequency=rope_base_frequency,
attn_logits_soft_cap=attn_logits_soft_cap,
sliding_window_size=sliding_window_size,
rngs=rngs,
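Taken together, the modules.py changes thread two new knobs through `Attention` (and on through `Block`): the query is now scaled by an explicit `query_pre_attn_scalar` instead of the hard-coded `head_dim**-0.5`, and RoPE takes a configurable base frequency via `rope_base_frequency`. A rough construction sketch using the argument names from this diff; the concrete sizes are illustrative only:

from flax import nnx
from examples.gemma import modules  # import path assumed

head_dim = 8
attn = modules.Attention(
    num_heads=2,
    num_kv_heads=2,
    features=16,
    head_dim=head_dim,
    query_pre_attn_scalar=head_dim**-0.5,  # reproduces the previous fixed scaling
    attn_type=modules.AttentionType.GLOBAL,
    rngs=nnx.Rngs(params=0),
    rope_base_frequency=modules.DEFAULT_ROPE_BASE_FREQUENCY,  # 10_000
)
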
16 changes: 13 additions & 3 deletions examples/gemma/modules_test.py
@@ -78,6 +78,7 @@ def test_head_dim(self, head_dim):
num_kv_heads=4,
features=5,
head_dim=head_dim,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -107,6 +108,7 @@ def test_use_qkv_einsum(
num_kv_heads=num_kv_heads,
features=5,
head_dim=8,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -144,7 +146,8 @@ def test_attention(
num_heads,
features,
head_dim,
- modules.AttentionType.GLOBAL,
+ query_pre_attn_scalar=1.0,
+ attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -177,7 +180,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
- modules.AttentionType.GLOBAL,
+ query_pre_attn_scalar=1.0,
+ attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -191,7 +195,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
- modules.AttentionType.LOCAL_SLIDING,
+ query_pre_attn_scalar=1.0,
+ attn_type=modules.AttentionType.LOCAL_SLIDING,
sliding_window_size=sliding_window_size,
rngs=nnx.Rngs(params=0),
)
@@ -272,6 +277,7 @@ def test_block(
1,
use_post_attn_norm,
use_post_ffw_norm,
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -315,6 +321,7 @@ def test_post_attention_norm(
1,
True,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -326,6 +333,7 @@ def test_post_attention_norm(
1,
False,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -373,6 +381,7 @@ def test_post_ffw_norm(
1,
True,
True, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -384,6 +393,7 @@ def test_post_ffw_norm(
1,
False,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
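In the updated tests, `query_pre_attn_scalar=1.0` leaves the queries unscaled, and `attn_type` is now passed by keyword because the new scalar parameter precedes it in the signature. A tiny sketch of the scaling step itself; the array shape is an assumed illustration, not taken from the tests:

import jax.numpy as jnp

head_dim = 8
query_proj = jnp.ones((1, 4, 2, head_dim))  # assumed (batch, seq, heads, head_dim) layout

unscaled = query_proj * 1.0                # query_pre_attn_scalar=1.0, as in the tests
old_default = query_proj * head_dim**-0.5  # the behaviour the explicit scalar replaces
assert jnp.allclose(unscaled, query_proj)
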
46 changes: 45 additions & 1 deletion examples/gemma/transformer.py
@@ -18,6 +18,8 @@

from collections.abc import Iterable
import dataclasses
import enum
import functools
from typing import Any

from flax import nnx
@@ -32,6 +34,19 @@
Cache = dict[str, modules.LayerCache]


class QueryPreAttentionNormalisation(enum.Enum):
"""Strategy for scaling the query before attention."""

# Scale the query by `1/sqrt(head_dim)`.
BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto()

# Scale the query by `embed_dim // num_heads`.
BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto()

# Scale the query by `1/sqrt(embed_dim // num_heads)`.
BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto()


@dataclasses.dataclass(frozen=True)
class TransformerConfig:
"""Configuration for the gemma transformer."""
@@ -47,10 +62,26 @@ class TransformerConfig:
use_post_attn_norm: bool
use_post_ffw_norm: bool
attention_types: Iterable[modules.AttentionType]
query_pre_attn_norm: QueryPreAttentionNormalisation = (
QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
)
attn_logits_soft_cap: float | None = None
transpose_gating_einsum: bool = False
local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
use_qk_norm: bool = False
sliding_window_size: int | None = None

def query_pre_attn_scalar(self) -> float:
"""Returns the scalar to multiply the query by before attention."""
match self.query_pre_attn_norm:
case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS:
return self.embed_dim // self.num_heads
case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS: # pylint: disable=line-too-long
return (self.embed_dim // self.num_heads) ** -0.5
case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _:
return self.head_dim**-0.5

@classmethod
def from_path(cls, path: str) -> TransformerConfig:
"""Creates a TransformerConfig from loaded parameters."""
@@ -176,6 +207,7 @@ def gemma_9b(cls):


def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
"""Maps linen variable names to nnx variable names."""
new_key = []
for k in key:
if k.startswith('layer_'):
@@ -199,8 +231,12 @@ def _assign_linen_params_to_nnx_state(
state: dict[tuple[str, ...], Any],
mapped_path: tuple[str | int, ...],
val: Any,
transpose_gating_einsum: bool,
) -> dict[tuple[str, ...], Any]:
"""Splits and maybe transposes gate_proj."""
if 'gate_proj' in mapped_path:
if transpose_gating_einsum:
val = jnp.swapaxes(val, 1, 2)
state[mapped_path].value = val[0]
state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1]
else:
@@ -217,11 +253,15 @@ def from_params(
) -> Transformer:
if config is None:
config = TransformerConfig.from_params(params)
assign_val_fn = functools.partial(
_assign_linen_params_to_nnx_state,
transpose_gating_einsum=config.transpose_gating_einsum,
)
return helpers.module_from_linen_variables(
module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)),
variables=params['transformer'],
map_key_fn=_map_linen_var_names,
- assign_val_fn=_assign_linen_params_to_nnx_state,
+ assign_val_fn=assign_val_fn,
)

def __init__(
@@ -248,7 +288,11 @@ def __init__(
use_post_ffw_norm=config.use_post_ffw_norm,
attn_logits_soft_cap=config.attn_logits_soft_cap,
attn_type=attn_type,
query_pre_attn_scalar=config.query_pre_attn_scalar(),
rngs=rngs,
rope_base_frequency=config.local_base_frequency
if attn_type == modules.AttentionType.LOCAL_SLIDING
else config.global_base_frequency,
use_qk_norm=config.use_qk_norm,
sow_config=sow_config,
)
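On the transformer.py side, each `Block` now receives `config.query_pre_attn_scalar()` and a RoPE base frequency chosen by attention type (`local_base_frequency` for LOCAL_SLIDING layers, `global_base_frequency` otherwise), while the checkpoint loader swaps the gating kernel's last two axes when `transpose_gating_einsum` is set. A worked sketch of how the three normalisation options resolve and what the transpose does; all dimensions are illustrative, not tied to any particular Gemma variant:

import jax.numpy as jnp

embed_dim, num_heads, head_dim = 2048, 8, 256  # illustrative sizes

# QueryPreAttentionNormalisation -> query_pre_attn_scalar():
print(head_dim**-0.5)                    # BY_ONE_OVER_SQRT_HEAD_DIM (default): 0.0625
print(embed_dim // num_heads)            # BY_EMBED_DIM_DIV_NUM_HEADS: 256
print((embed_dim // num_heads) ** -0.5)  # BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS: 0.0625

# Gating einsum transpose: a checkpoint may store the stacked (gate, up) kernel
# with its last two axes swapped; the loader undoes that before splitting it.
hidden_dim = 4 * embed_dim                   # assumed hidden size
val = jnp.zeros((2, hidden_dim, embed_dim))  # transposed layout (shapes assumed)
val = jnp.swapaxes(val, 1, 2)                # -> (2, embed_dim, hidden_dim)
gate_proj_kernel, up_proj_kernel = val[0], val[1]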
