Add option to load checkpoints with transposed Gating Einsum. #4597

Open · wants to merge 1 commit into base: main
13 changes: 12 additions & 1 deletion examples/gemma/modules.py
@@ -33,6 +33,7 @@
Shape = Sequence[Union[int, Any]]

K_MASK = -2.3819763e38 # Set to a large negative number.
DEFAULT_ROPE_BASE_FREQUENCY = 10_000


class AttentionType(enum.Enum):
@@ -80,9 +81,11 @@ def __init__(
num_kv_heads: int,
features: int,
head_dim: int,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
use_qk_norm: bool = False,
@@ -93,6 +96,7 @@ def __init__(
'`sliding_window_size` must be set if `attn_type` is Local Sliding.'
)

self.query_pre_attn_scalar = query_pre_attn_scalar
self.attn_type = attn_type
self.sliding_window_size = sliding_window_size
self.attn_logits_soft_cap = attn_logits_soft_cap
@@ -101,6 +105,7 @@ def __init__(
shape=(num_heads, head_dim, features),
rngs=rngs,
)
self.rope_base_frequency = rope_base_frequency
self.use_qk_norm = use_qk_norm
self.sow_config = sow_config

@@ -148,12 +153,14 @@ def __call__(
query_proj,
segment_pos,
head_dim=self.head_dim,
max_wavelength=self.rope_base_frequency,
)
- query_scaled = query_proj * self.head_dim**-0.5
+ query_scaled = query_proj * self.query_pre_attn_scalar
key_proj = positional_embeddings.apply_rope(
key_proj,
segment_pos,
head_dim=self.head_dim,
max_wavelength=self.rope_base_frequency,
)

# Cache is left aligned.
@@ -304,9 +311,11 @@ def __init__(
hidden_dim: int,
use_post_attn_norm: bool,
use_post_ffw_norm: bool,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
use_qk_norm: bool = False,
@@ -318,7 +327,9 @@ def __init__(
num_kv_heads=num_kv_heads,
features=embed_dim,
head_dim=head_dim,
query_pre_attn_scalar=query_pre_attn_scalar,
attn_type=attn_type,
rope_base_frequency=rope_base_frequency,
attn_logits_soft_cap=attn_logits_soft_cap,
sliding_window_size=sliding_window_size,
rngs=rngs,
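Taken together, the modules.py changes thread two new knobs through `Attention`: an explicit `query_pre_attn_scalar` that replaces the hard-coded `head_dim**-0.5` scaling, and a `rope_base_frequency` that is forwarded to `apply_rope` as `max_wavelength`. A minimal usage sketch follows; the shape and scalar values are illustrative only, and `modules` is assumed to be importable from examples/gemma.

```python
from flax import nnx

import modules  # examples/gemma/modules.py, assumed importable

# Illustrative values only; query_pre_attn_scalar reproduces the previous
# hard-coded head_dim**-0.5 behaviour, and rope_base_frequency is passed to
# apply_rope as max_wavelength.
attn = modules.Attention(
    num_heads=4,
    num_kv_heads=4,
    features=256,
    head_dim=64,
    query_pre_attn_scalar=64**-0.5,
    attn_type=modules.AttentionType.GLOBAL,
    rngs=nnx.Rngs(params=0),
    rope_base_frequency=modules.DEFAULT_ROPE_BASE_FREQUENCY,
)
```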
16 changes: 13 additions & 3 deletions examples/gemma/modules_test.py
@@ -78,6 +78,7 @@ def test_head_dim(self, head_dim):
num_kv_heads=4,
features=5,
head_dim=head_dim,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -107,6 +108,7 @@ def test_use_qkv_einsum(
num_kv_heads=num_kv_heads,
features=5,
head_dim=8,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -144,7 +146,8 @@ def test_attention(
num_heads,
features,
head_dim,
- modules.AttentionType.GLOBAL,
+ query_pre_attn_scalar=1.0,
+ attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -177,7 +180,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
- modules.AttentionType.GLOBAL,
+ query_pre_attn_scalar=1.0,
+ attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -191,7 +195,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
- modules.AttentionType.LOCAL_SLIDING,
+ query_pre_attn_scalar=1.0,
+ attn_type=modules.AttentionType.LOCAL_SLIDING,
sliding_window_size=sliding_window_size,
rngs=nnx.Rngs(params=0),
)
@@ -272,6 +277,7 @@ def test_block(
1,
use_post_attn_norm,
use_post_ffw_norm,
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -315,6 +321,7 @@ def test_post_attention_norm(
1,
True,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -326,6 +333,7 @@ def test_post_attention_norm(
1,
False,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -373,6 +381,7 @@ def test_post_ffw_norm(
1,
True,
True, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -384,6 +393,7 @@ def test_post_ffw_norm(
1,
False,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
46 changes: 45 additions & 1 deletion examples/gemma/transformer.py
@@ -18,6 +18,8 @@

from collections.abc import Iterable
import dataclasses
import enum
import functools
from typing import Any

from flax import nnx
@@ -32,6 +34,19 @@
Cache = dict[str, modules.LayerCache]


class QueryPreAttentionNormalisation(enum.Enum):
"""Initialization strategy."""

# Scale the query by 1/sqrt(head_dim).
BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto()

# Scale the query by `embed_dim // num_heads`.
BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto()

# Scale the query by `1/sqrt(embed_dim // num_heads)`.
BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto()


@dataclasses.dataclass(frozen=True)
class TransformerConfig:
"""Configuration for the gemma transformer."""
@@ -47,10 +62,26 @@ class TransformerConfig:
use_post_attn_norm: bool
use_post_ffw_norm: bool
attention_types: Iterable[modules.AttentionType]
query_pre_attn_norm: QueryPreAttentionNormalisation = (
QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
)
attn_logits_soft_cap: float | None = None
transpose_gating_einsum: bool = False
local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
use_qk_norm: bool = False
sliding_window_size: int | None = None

def query_pre_attn_scalar(self) -> float:
"""Returns the scalar to multiply the query by before attention."""
match self.query_pre_attn_norm:
case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS:
return self.embed_dim // self.num_heads
case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS: # pylint: disable=line-too-long
return (self.embed_dim // self.num_heads) ** -0.5
case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _:
return self.head_dim**-0.5

@classmethod
def from_path(cls, path: str) -> TransformerConfig:
"""Creates a TransformerConfig from loaded parameters."""
@@ -176,6 +207,7 @@ def gemma_9b(cls):


def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
"""Maps linen variable names to nnx variable names."""
new_key = []
for k in key:
if k.startswith('layer_'):
@@ -199,8 +231,12 @@ def _assign_linen_params_to_nnx_state(
state: dict[tuple[str, ...], Any],
mapped_path: tuple[str | int, ...],
val: Any,
transpose_gating_einsum: bool,
) -> dict[tuple[str, ...], Any]:
"""Splits and maybe transposes gate_proj."""
if 'gate_proj' in mapped_path:
if transpose_gating_einsum:
val = jnp.swapaxes(val, 1, 2)
state[mapped_path].value = val[0]
state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1]
else:
@@ -217,11 +253,15 @@ def from_params(
) -> Transformer:
if config is None:
config = TransformerConfig.from_params(params)
assign_val_fn = functools.partial(
_assign_linen_params_to_nnx_state,
transpose_gating_einsum=config.transpose_gating_einsum,
)
return helpers.module_from_linen_variables(
module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)),
variables=params['transformer'],
map_key_fn=_map_linen_var_names,
- assign_val_fn=_assign_linen_params_to_nnx_state,
+ assign_val_fn=assign_val_fn,
)

def __init__(
@@ -248,7 +288,11 @@ def __init__(
use_post_ffw_norm=config.use_post_ffw_norm,
attn_logits_soft_cap=config.attn_logits_soft_cap,
attn_type=attn_type,
query_pre_attn_scalar=config.query_pre_attn_scalar(),
rngs=rngs,
rope_base_frequency=config.local_base_frequency
if attn_type == modules.AttentionType.LOCAL_SLIDING
else config.global_base_frequency,
use_qk_norm=config.use_qk_norm,
sow_config=sow_config,
)
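For orientation, here is a hedged end-to-end sketch of how the new option might be used when restoring a Linen checkpoint whose gating einsum is stored transposed. The checkpoint-loading helper is hypothetical and `gemma_9b()` is just one of the presets referenced in this diff; the parts actually introduced here are `TransformerConfig.transpose_gating_einsum` and the `functools.partial` wiring into `Transformer.from_params`.

```python
import dataclasses

import transformer as transformer_lib  # examples/gemma/transformer.py, assumed importable

# Assumed to return the restored Linen checkpoint pytree, with the usual
# {'transformer': {...}} layout; how it is loaded is outside the scope of this diff.
params = load_linen_checkpoint('/path/to/checkpoint')  # hypothetical helper

# Opt in to the transposed gating-einsum layout on top of an existing preset.
config = dataclasses.replace(
    transformer_lib.TransformerConfig.gemma_9b(),
    transpose_gating_einsum=True,
)

# from_params partially applies _assign_linen_params_to_nnx_state with the new
# flag, so each gate_proj kernel is swapaxes(1, 2)-ed before being split into
# gate_proj and up_proj.
model = transformer_lib.Transformer.from_params(params, config=config)
```

Because the flag defaults to False, existing checkpoints load exactly as before; only checkpoints that store the gating einsum in the transposed layout need to opt in.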