[nnx] Add NNX WeightNorm. #4568

Open · wants to merge 6 commits into base: main
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
@@ -104,6 +104,7 @@
from .nn.normalization import LayerNorm as LayerNorm
from .nn.normalization import RMSNorm as RMSNorm
from .nn.normalization import GroupNorm as GroupNorm
from .nn.normalization import WeightNorm as WeightNorm
from .nn.stochastic import Dropout as Dropout
from .rnglib import Rngs as Rngs
from .rnglib import RngStream as RngStream
154 changes: 153 additions & 1 deletion flax/nnx/nn/normalization.py
@@ -181,6 +181,24 @@ def _normalize(
  return jnp.asarray(y, dtype)


def _l2_normalize(x, axis=None, eps=1e-12):
  """Normalizes along dimension ``axis`` using an L2 norm.

  This specialized function exists for numerical stability reasons.

  Args:
    x: An input ndarray.
    axis: Dimension along which to normalize, e.g. ``1`` to separately
      normalize vectors in a batch. Passing ``None`` views ``x`` as a
      flattened vector when calculating the norm (equivalent to the
      Frobenius norm).
    eps: Epsilon to avoid dividing by zero.

  Returns:
    An array of the same shape as ``x`` L2-normalized along ``axis``.
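
  Example (illustrative values)::

    _l2_normalize(jnp.array([3.0, 4.0]))  # ≈ [0.6, 0.8]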
"""
return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)


class BatchNorm(Module):
"""BatchNorm Module.

@@ -835,4 +853,138 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
      (self.feature_axis,),
      self.dtype,
      self.epsilon,
    )


class WeightNorm(nnx.Module):
  """L2 weight normalization (https://arxiv.org/abs/1602.07868).

  Weight normalization normalizes the weight params so that the l2-norm of
  each weight, taken over its non-feature axes, is equal to 1, optionally
  multiplied by a learned per-feature scale. This is implemented as a layer
  wrapper where each wrapped layer will have its params l2-normalized before
  computing its ``__call__`` output.
  Example usage::

    >>> import jax
    >>> import numpy as np
    >>> from flax import nnx

    >>> class Foo(nnx.Module):
    ...   def __init__(self, rngs: nnx.Rngs):
    ...     self.normed_linear = nnx.WeightNorm(
    ...       nnx.Linear(8, 4, rngs=rngs),
    ...       variable_filter=nnx.PathContains('kernel'),
    ...       rngs=rngs,
    ...     )
    ...
    ...   def __call__(self, x: jax.Array) -> jax.Array:
    ...     return self.normed_linear(x)

    >>> rng = jax.random.PRNGKey(42)
    >>> model = Foo(rngs=nnx.Rngs(rng))

    >>> x = jax.random.normal(rng, (5, 8))
    >>> y = model(x)
    >>> y.shape
    (5, 4)

    >>> w = model.normed_linear.layer_instance.kernel.value
    >>> col_norms = np.linalg.norm(np.array(w), axis=0)
    >>> np.testing.assert_allclose(col_norms, np.ones(4))

  Args:
    layer_instance: The layer instance to wrap.
    feature_axes: The axes to normalize, by default the last axis.
    use_scale: Whether to use a scale parameter.
    scale_init: The initializer for the scale parameter, by default ones.
    epsilon: The epsilon value for the normalization, by default 1e-12.
    dtype: The dtype of the result, by default inferred from input and params.
    param_dtype: The dtype of the parameters, by default float32.
    variable_filter: The variable filter, by default
      ``nnx.PathContains('kernel')``.
    rngs: The rng streams, used to initialize the scale parameter on the
      first call.
  """
  def __init__(
    self,
    layer_instance: nnx.Module,
    *,
    feature_axes: Axes | None = -1,
    use_scale: bool = True,
    scale_init: Initializer = initializers.ones,
    epsilon: float = 1e-12,
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    variable_filter: nnx.filterlib.Filter = nnx.PathContains('kernel'),
    rngs: rnglib.Rngs,
  ):
    self.layer_instance = layer_instance
    self.feature_axes = feature_axes
    self.use_scale = use_scale
    self.scale_init = scale_init
    self.epsilon = epsilon
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.variable_filter = variable_filter
    self.rngs = rngs

  def __call__(self, x: Array, *args, **kwargs) -> Array:
    """Compute the l2-norm of the weights in ``self.layer_instance``
    and normalize the weights using this value before computing the
    ``__call__`` output.

    Args:
      x: The input, passed as the first positional argument to the call
        method of the underlying layer instance in ``self.layer_instance``.
      *args: positional arguments to be passed into the call method of the
        underlying layer instance in ``self.layer_instance``.
      **kwargs: keyword arguments to be passed into the call method of the
        underlying layer instance in ``self.layer_instance``.

    Returns:
      Output of the layer using l2-normalized weights.
    """
    state = nnx.state(self.layer_instance)

    def apply_weightnorm(path, var_state):
      # Leave variables that do not match the filter (e.g. biases) untouched.
      if not self.variable_filter(path, var_state):
        return var_state

      param_val = jnp.asarray(var_state.value)
      if self.feature_axes is None:
        # Treat the whole array as one flat vector (Frobenius norm).
        feature_axes = ()
        reduction_axes = tuple(range(param_val.ndim))
      else:
        # Normalize over every axis that is not a feature axis.
        feature_axes = _canonicalize_axes(param_val.ndim, self.feature_axes)
        reduction_axes = tuple(
          i for i in range(param_val.ndim) if i not in feature_axes
        )

      value_bar = _l2_normalize(
        param_val, axis=reduction_axes, eps=self.epsilon
      )

      if self.use_scale:
        scale_shape = tuple(param_val.shape[ax] for ax in feature_axes)
        scale_path = path + ('scale',)
        try:
          scale_state = state[scale_path]
          scale_value = scale_state.value
        except KeyError:
          # Lazily initialize the scale parameter on the first call.
          key = self.rngs.params()
          scale_value = self.scale_init(key, scale_shape, self.param_dtype)
          state[scale_path] = nnx.Param(scale_value)

        if len(feature_axes) < param_val.ndim:
          # Reshape the scale so it broadcasts against the normalized weights.
          broadcast_shape = [1] * param_val.ndim
          for ax in feature_axes:
            broadcast_shape[ax] = param_val.shape[ax]
          scale_value = scale_value.reshape(broadcast_shape)
        value_bar = value_bar * scale_value

      cast_args = [param_val]
      if self.use_scale:
        cast_args.append(scale_value)

      final_dtype = dtypes.canonicalize_dtype(*cast_args, dtype=self.dtype)
      new_val = jnp.asarray(value_bar, final_dtype)

      return nnx.Param(new_val)

    state = nnx.map_state(apply_weightnorm, state)
    nnx.update(self.layer_instance, state)

    return self.layer_instance(x, *args, **kwargs)  # type: ignore
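
For intuition, here is a standalone sketch of the axis bookkeeping that `apply_weightnorm` performs for a `Linear(8, 4)` kernel under the default `feature_axes=-1` (illustrative only; the shapes and variable names are hypothetical, not part of this diff):

```python
import jax
import jax.numpy as jnp

# A Linear(8, 4) kernel: inputs along axis 0, features along axis 1.
kernel = jax.random.normal(jax.random.key(0), (8, 4))

# feature_axes=-1 canonicalizes to (1,); every remaining axis is reduced over.
feature_axes = (1,)
reduction_axes = tuple(i for i in range(kernel.ndim) if i not in feature_axes)

# Same stabilized norm as _l2_normalize: each feature column gets unit L2 norm.
normed = kernel * jax.lax.rsqrt(
  (kernel * kernel).sum(axis=reduction_axes, keepdims=True) + 1e-12
)
print(jnp.linalg.norm(normed, axis=0))  # ≈ [1., 1., 1., 1.]
```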
65 changes: 64 additions & 1 deletion tests/nnx/nn/normalization_test.py
@@ -323,6 +323,69 @@ def __call__(self, x, *, mask=None):
    assert isinstance(linen_out, jax.Array)
    np.testing.assert_array_equal(linen_out, nnx_out)

  @parameterized.product(
    dtype=[jnp.float32, jnp.float16],
    param_dtype=[jnp.float32, jnp.float16],
    scale_init=[
      nnx.initializers.ones,
      nnx.initializers.constant(10.0),
      nnx.initializers.constant(0.5),
    ],
  )
  def test_nnx_linen_weightnorm_equivalence(
    self,
    dtype: tp.Optional[Dtype],
    param_dtype: Dtype,
    scale_init: nnx.Initializer,
  ):
    class NNXModel(nnx.Module):
      def __init__(self, dtype, param_dtype, rngs):
        self.dense = nnx.Linear(
          8, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs
        )
        self.normed = nnx.WeightNorm(
          self.dense,
          use_scale=True,
          scale_init=scale_init,
          feature_axes=-1,
          dtype=dtype,
          param_dtype=param_dtype,
          rngs=rngs,
        )

      def __call__(self, x, *, mask=None):
        return self.normed(x)

    class LinenModel(linen.Module):
      dtype: tp.Optional[Dtype] = None
      param_dtype: Dtype = jnp.float32

      def setup(self):
        self.dense = linen.Dense(
          4, dtype=self.dtype, param_dtype=self.param_dtype
        )
        self.weight_norm = linen.WeightNorm(
          self.dense, variable_filter={'kernel'}, scale_init=scale_init
        )

      def __call__(self, x, *, mask=None):
        return self.weight_norm(x)

    rngs = nnx.Rngs(42)

    x = jax.random.normal(jax.random.key(0), (10, 8))

    linen_model = LinenModel(dtype=dtype, param_dtype=param_dtype)
    variables = linen_model.init(jax.random.key(1), x)

    nnx_model = NNXModel(dtype=dtype, param_dtype=param_dtype, rngs=rngs)
    nnx_model.dense.kernel.value = variables['params']['dense']['kernel']
    nnx_model.dense.bias.value = variables['params']['dense']['bias']

    linen_out = linen_model.apply(variables, x)

    nnx_out = nnx_model(x)
    np.testing.assert_array_equal(linen_out, nnx_out)

if __name__ == '__main__':
  absltest.main()
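
As a closing note on the `scale_init` variants the test sweeps over: with `use_scale=True`, the wrapped kernel ends up with per-feature L2 norm equal to the scale value rather than 1. A minimal sketch, assuming this PR's `nnx.WeightNorm` API as shown above:

```python
import jax
import numpy as np
from flax import nnx

rngs = nnx.Rngs(0)
layer = nnx.Linear(8, 4, rngs=rngs)
normed = nnx.WeightNorm(
  layer, scale_init=nnx.initializers.constant(10.0), rngs=rngs
)

# The first call normalizes the kernel in place and applies the scale.
_ = normed(jax.random.normal(jax.random.key(1), (2, 8)))

# Each feature column's norm is now ≈ 10.0 (the constant scale), not 1.
print(np.linalg.norm(np.array(layer.kernel.value), axis=0))
```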