
Commit

[nnx] add flaxlib
cgarciae committed Jan 3, 2025
1 parent cf6db71 commit 78be9e9
Showing 22 changed files with 1,340 additions and 463 deletions.
4 changes: 3 additions & 1 deletion benchmarks/nnx_graph_overhead.py
@@ -97,9 +97,11 @@ def main(argv):
  def step_nnx(model: MLP, optimizer: nnx.Optimizer):
    pass

  cached_step_nnx = nnx.cache_args(step_nnx, model, optimizer)

  t0 = time()
  for _ in range(total_steps):
    step_nnx(model, optimizer)
    cached_step_nnx()

  total_time = time() - t0
  time_per_step = total_time / total_steps
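For context: nnx.cache_args, exported from flax.nnx later in this commit, returns a callable with the given graph-node arguments already bound, so the loop above only passes whatever arguments remain (here, none) and repeated calls appear to skip re-processing those arguments. A minimal usage sketch, with nnx.Linear and the learning rate standing in as illustrative choices that are not part of the diff:

import optax
from flax import nnx

model = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))

@nnx.jit
def step(model, optimizer):
  pass  # forward/backward pass elided, as in the benchmark

# Bind the graph arguments once; the returned callable takes the remaining arguments.
cached_step = nnx.cache_args(step, model, optimizer)
cached_step()  # equivalent to step(model, optimizer)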
235 changes: 235 additions & 0 deletions benchmarks/nnx_mlpmixer_training.py
@@ -0,0 +1,235 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
from functools import partial
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np
from einop import einop
from time import time
from tqdm import tqdm

from flax import nnx

from absl import flags
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum(
  'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
flags.DEFINE_integer('depth', 4, 'Depth of the model')


class MlpBlock(nnx.Module):
  def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs):
    self.din, self.mlp_dim = din, mlp_dim
    self.linear_in = nnx.Linear(din, mlp_dim, rngs=rngs)
    self.linear_out = nnx.Linear(mlp_dim, din, rngs=rngs)

  def __call__(self, x):
    return self.linear_out(nnx.gelu(self.linear_in(x)))


class MixerBlock(nnx.Module):
  def __init__(
    self,
    tokens_mlp_dim: int,
    channels_mlp_dim: int,
    hidden_dim: int,
    rngs: nnx.Rngs,
  ):
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    self.hidden_dim = hidden_dim
    self.token_mixing = MlpBlock(tokens_mlp_dim, hidden_dim, rngs=rngs)
    self.channel_mixing = MlpBlock(channels_mlp_dim, hidden_dim, rngs=rngs)
    self.ln1 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.ln2 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)

  def __call__(self, x):
    y = self.ln1(x)
    y = y.swapaxes(1, 2)
    y = self.token_mixing(y)
    y = y.swapaxes(1, 2)
    x = x + y
    y = self.ln2(x)
    return x + self.channel_mixing(y)


class MlpMixer(nnx.Module):
  def __init__(
    self,
    din: int,
    kernel_size: tuple[int, int],
    strides: tuple[int, int],
    num_blocks: int,
    hidden_dim: int,
    tokens_mlp_dim: int,
    channels_mlp_dim: int,
    rngs: nnx.Rngs,
  ):
    self.din = din
    self.kernel_size = kernel_size
    self.num_blocks = num_blocks
    self.hidden_dim = hidden_dim
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    self.stem = nnx.Conv(
      din + 1,
      channels_mlp_dim,
      kernel_size=kernel_size,
      strides=strides,
      rngs=rngs,
    )
    self.blocks = [
      MixerBlock(tokens_mlp_dim, channels_mlp_dim, hidden_dim, rngs=rngs)
      for _ in range(num_blocks)
    ]
    self.pre_head_layer_norm = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.conv_t = nnx.ConvTranspose(
      channels_mlp_dim, din, kernel_size=kernel_size, strides=strides, rngs=rngs
    )

  def __call__(self, *, x, t):
    # add time feature to input
    t = einop(t, 'n -> n h w c', h=x.shape[1], w=x.shape[2], c=1)
    x = jnp.concatenate([x, t], axis=-1)
    # create patches
    x = self.stem(x)
    h, w = x.shape[1], x.shape[2]
    x = einop(x, 'n h w c -> n (h w) c')
    # apply blocks
    for block in self.blocks:
      x = block(x)
    x = self.pre_head_layer_norm(x)
    # recreate image
    x = einop(x, 'n (h w) c -> n h w c', h=h, w=w)
    x = self.conv_t(x)
    return x


def main(argv):
  print(argv)
  mode: str = FLAGS.mode
  total_steps: int = FLAGS.total_steps
  batch_size: int = FLAGS.batch_size
  width: int = FLAGS.width
  depth: int = FLAGS.depth

  print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

  X = np.random.uniform(size=(batch_size, 28, 28, 1))

  if mode == 'nnx' or mode == 'all':
    rngs = nnx.Rngs(0)
    flow = MlpMixer(
      din=1,
      kernel_size=(2, 2),
      strides=(2, 2),
      num_blocks=4,
      hidden_dim=512,
      tokens_mlp_dim=196,
      channels_mlp_dim=512,
      rngs=rngs,
    )
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))
    t0 = time()

    mse = lambda a, b: jnp.mean((a - b) ** 2)

    @nnx.jit(donate_argnums=(0, 1, 2))
    def train_step_nnx(flow, optimizer, rngs, x_1):
      print('JITTING NNX')
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))

      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0

      loss, grads = nnx.value_and_grad(
        lambda flow: mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(grads)
      return loss

    losses = []
    t0 = time()
    for step in tqdm(range(total_steps), desc='NNX'):
      loss = train_step_nnx(flow, optimizer, rngs, X)
      losses.append(loss)

    total_time = time() - t0
    print('### NNX ###')
    print(f'final loss: {losses[-1]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')

  if mode == 'jax' or mode == 'all':
    rngs = nnx.Rngs(0)
    flow = MlpMixer(
      din=1,
      kernel_size=(2, 2),
      strides=(2, 2),
      num_blocks=depth,
      hidden_dim=width,
      tokens_mlp_dim=196,
      channels_mlp_dim=width,
      rngs=rngs,
    )
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))
    graphdef, state = nnx.split((flow, optimizer, rngs))
    t0 = time()

    mse = lambda a, b: jnp.mean((a - b) ** 2)

    @partial(nnx.jit, donate_argnums=0)
    def train_step_jax(state, x_1):
      print('JITTING JAX')
      flow, optimizer, rngs = nnx.merge(graphdef, state)
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))

      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0

      loss, grads = nnx.value_and_grad(
        lambda flow: mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(grads)
      state = nnx.state((flow, optimizer, rngs))
      return loss, state

    losses = []
    t0 = time()
    for step in tqdm(range(total_steps), desc='JAX'):
      loss, state = train_step_jax(state, X)
      losses.append(loss)

    nnx.update((flow, optimizer, rngs), state)
    total_time = time() - t0
    print('### JAX ###')
    print(f'final loss: {losses[-1]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')


if __name__ == '__main__':
  app.run(main)
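Note on the objective: the MLP-Mixer above is trained with what appears to be a flow-matching (rectified-flow) regression. A point on the straight path between noise x_0 and data x_1 is sampled and the model regresses the constant velocity of that path; in the notation of the code,

  x_t = (1 - t) * x_0 + t * x_1
  loss = mean || flow(x=x_t, t=t) - (x_1 - x_0) ||^2

which is what both train_step_nnx and train_step_jax compute.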
24 changes: 14 additions & 10 deletions benchmarks/nnx_simple_training.py
@@ -13,6 +13,7 @@
# limitations under the License.

# %%
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
@@ -97,7 +98,7 @@ def main(argv):
  optimizer = nnx.Optimizer(model, tx)
  t0 = time()

  @nnx.jit
  @nnx.jit(donate_argnums=(0, 1))
  def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
    x, y = batch

@@ -108,18 +109,21 @@ def loss_fn(model: MLP):
    grads: nnx.State = nnx.grad(loss_fn)(model)
    optimizer.update(grads)

  @nnx.jit
  @nnx.jit(donate_argnums=0)
  def test_step_nnx(model: MLP, batch):
    x, y = batch
    y_pred = model(x)
    loss = jnp.mean((y - y_pred) ** 2)
    return {'loss': loss}

  cached_train_step_nnx = nnx.cache_args(train_step_nnx, model, optimizer)
  cached_test_step_nnx = nnx.cache_args(test_step_nnx, model)

  for step, batch in enumerate(dataset(X, Y, batch_size)):
    train_step_nnx(model, optimizer, batch)
    cached_train_step_nnx(batch)

    if step % 1000 == 0:
      logs = test_step_nnx(model, (X, Y))
      logs = cached_test_step_nnx((X, Y))

    if step >= total_steps - 1:
      break
@@ -137,8 +141,8 @@ def test_step_nnx(model: MLP, batch):
  optimizer = nnx.Optimizer(model, tx)
  t0 = time()

  @jax.jit
  def train_step_jax(graphdef, state, batch):
  @partial(jax.jit, donate_argnums=0)
  def train_step_jax(state, batch):
    model, optimizer = nnx.merge(graphdef, state)
    x, y = batch

@@ -151,8 +155,8 @@ def loss_fn(model: MLP):

    return nnx.state((model, optimizer))

  @jax.jit
  def test_step_jax(graphdef, state, batch):
  @partial(jax.jit, donate_argnums=0)
  def test_step_jax(state, batch):
    model, optimizer = nnx.merge(graphdef, state)
    x, y = batch
    y_pred = model(x)
@@ -163,10 +167,10 @@ def test_step_jax(graphdef, state, batch):
  graphdef, state = nnx.split((model, optimizer))

  for step, batch in enumerate(dataset(X, Y, batch_size)):
    state = train_step_jax(graphdef, state, batch)
    state = train_step_jax(state, batch)

    if step % 1000 == 0:
      state, logs = test_step_jax(graphdef, state, (X, Y))
      state, logs = test_step_jax(state, (X, Y))

    if step >= total_steps - 1:
      break
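For reference, the updated pure-JAX path above splits the objects once, closes over the static graphdef, threads only the (donated) state through jax.jit, and writes the final state back with nnx.update. A self-contained sketch of that pattern, with a toy nnx.Linear model and shapes that are illustrative rather than taken from the benchmark:

from functools import partial
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(1e-2))
graphdef, state = nnx.split((model, optimizer))  # split once; graphdef is closed over

@partial(jax.jit, donate_argnums=0)  # donate the state buffers, as in the diff above
def train_step(state, x, y):
  model, optimizer = nnx.merge(graphdef, state)
  loss, grads = nnx.value_and_grad(lambda m: jnp.mean((m(x) - y) ** 2))(model)
  optimizer.update(grads)
  return loss, nnx.state((model, optimizer))

loss, state = train_step(state, jnp.ones((8, 1)), jnp.zeros((8, 1)))
nnx.update((model, optimizer), state)  # write the final state back into the objects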
6 changes: 4 additions & 2 deletions examples/nnx_toy_examples/02_lifted_transforms.py
@@ -82,13 +82,15 @@ def test_step(model: MLP, batch):
  loss = jnp.mean((y - y_pred) ** 2)
  return {'loss': loss}

cached_train_step = nnx.cache_args(train_step, model, optimizer)
cached_test_step = nnx.cache_args(test_step, model)

total_steps = 10_000
for step, batch in enumerate(dataset(32)):
  train_step(model, optimizer, batch)
  cached_train_step(batch)

  if step % 1000 == 0:
    logs = test_step(model, (X, Y))
    logs = cached_test_step((X, Y))
    print(f"step: {step}, loss: {logs['loss']}")

  if step >= total_steps - 1:
11 changes: 11 additions & 0 deletions flax/configurations.py
@@ -22,6 +22,7 @@


class Config:
  flax_use_flaxlib: bool
  # See https://google.github.io/pytype/faq.html.
  _HAS_DYNAMIC_ATTRIBUTES = True

@@ -62,6 +63,10 @@ def update(self, name_or_holder, value, /):
      raise LookupError(f'Unrecognized config option: {name}')
    self._values[name] = value

  def __repr__(self):
    values_repr = ', '.join(f'\n {k}={v!r}' for k, v in self._values.items())
    return f'Config({values_repr}\n)'


config = Config()

@@ -201,3 +206,9 @@ def temp_flip_flag(var_name: str, var_value: bool):
    ' PRNG keys.'
  ),
)

flax_use_flaxlib = bool_flag(
  name='flax_use_flaxlib',
  default=False,
  help='Whether to use flaxlib for C++ acceleration.',
)
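The new option is registered like the other bool_flag entries in this file, so it can be toggled at runtime through the Config.update method shown above; whether it is also picked up from a FLAX_USE_FLAXLIB environment variable is assumed here from the convention the other flags follow. An illustrative sketch:

import flax

# Toggle at runtime via the Config.update method shown in the hunk above.
flax.config.update('flax_use_flaxlib', True)
print(flax.config.flax_use_flaxlib)

# Assumed, following the usual flag convention: set it via the environment instead,
#   FLAX_USE_FLAXLIB=true python my_script.py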
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
@@ -56,6 +56,7 @@
from .graph import MergeContext as MergeContext
from .graph import merge_context as merge_context
from .graph import variables as variables
from .graph import cache_args as cache_args
from .nn import initializers as initializers
from .nn.activations import celu as celu
from .nn.activations import elu as elu