diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 443fc752..8b747af4 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -891,7 +891,7 @@ class WeightNorm(nnx.Module): >>> w = model.normed_linear.layer_instance.kernel.value >>> col_norms = np.linalg.norm(np.array(w), axis=0) - >>> np.testing.assert_allclose(col_norms, np.ones(4), rtol=1e-5, atol=1e-5) + >>> np.testing.assert_allclose(col_norms, np.ones(4)) Args: layer_instance: The layer instance to wrap.