Skip to content

Commit

Permalink
Merge pull request #3876 from google:nnx-v0.1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 628712571
  • Loading branch information
Flax Authors committed Apr 27, 2024
2 parents ea3bcab + 08b1336 commit 2c7d7cd
Show file tree
Hide file tree
Showing 71 changed files with 3,718 additions and 4,645 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version: 2
build:
os: ubuntu-22.04
tools:
python: "3.9"
python: "3.10"

# Build documentation in the docs/ directory with Sphinx
sphinx:
Expand Down
1 change: 1 addition & 0 deletions docs/api_reference/flax.experimental.nnx/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/exper
transforms
variables
helpers
visualization

3 changes: 2 additions & 1 deletion docs/api_reference/flax.experimental.nnx/module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ module
.. autoclass:: Module
:members:

.. autofunction:: merge
.. automethod:: sow
.. automethod:: iter_modules
2 changes: 0 additions & 2 deletions docs/api_reference/flax.experimental.nnx/variables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ variables
:members:
.. autoclass:: Param
:members:
.. autoclass:: Rng
:members:
.. autoclass:: Variable
:members:
.. autoclass:: VariableMetadata
Expand Down
7 changes: 7 additions & 0 deletions docs/api_reference/flax.experimental.nnx/visualization.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
visualization
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx

.. autofunction:: display
55 changes: 33 additions & 22 deletions docs/experimental/nnx/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,6 @@ Features
NNX prioritizes simplicity for common use cases, building upon lessons learned from Linen
to provide a streamlined experience.


Installation
^^^^^^^^^^^^
NNX is under active development; we recommend using the latest version from Flax's GitHub repository:

.. code-block:: bash
pip install git+https://github.com/google/flax.git
Basic usage
^^^^^^^^^^^^

Expand All @@ -89,22 +79,43 @@ Basic usage
.. testcode::

from flax.experimental import nnx
import optax


class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
key = rngs() # get a unique random key
self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,))) # initialize parameters
self.din, self.dout = din, dout
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)

def __call__(self, x: jax.Array):
return x @ self.w.value + self.b.value
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing

rngs = nnx.Rngs(0) # explicit RNG handling
model = Linear(din=2, dout=3, rngs=rngs) # initialize the model
@nnx.jit # automatic state management
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()

loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads) # inplace updates

return loss


Installation
^^^^^^^^^^^^
NNX is under active development; we recommend using the latest version from Flax's GitHub repository:

.. code-block:: bash
pip install git+https://github.com/google/flax.git
x = jnp.empty((1, 2)) # generate random data
y = model(x) # forward pass
----

Expand Down
380 changes: 225 additions & 155 deletions docs/experimental/nnx/mnist_tutorial.ipynb

Large diffs are not rendered by default.

136 changes: 50 additions & 86 deletions docs/experimental/nnx/mnist_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ Since NNX is under active development, we recommend using the latest version fro
```{code-cell} ipython3
:tags: [skip-execution]
# TODO: Fix text descriptions in this tutorial
!pip install git+https://github.com/google/flax.git
# !pip install git+https://github.com/google/flax.git
```

## 2. Load the MNIST Dataset

We'll use TensorFlow Datasets (TFDS) for loading and preparing the MNIST dataset:
First, the MNIST dataset is loaded and prepared for training and testing using
TensorFlow Datasets. Image values are normalized, the data is shuffled and divided
into batches, and samples are prefetched to enhance performance.

```{code-cell} ipython3
import tensorflow_datasets as tfds # TFDS for MNIST
Expand Down Expand Up @@ -77,40 +78,28 @@ Create a convolutional neural network with NNX by subclassing `nnx.Module`.

```{code-cell} ipython3
from flax.experimental import nnx # NNX API
from functools import partial
class CNN(nnx.Module):
"""A simple CNN model."""
def __init__(self, *, rngs: nnx.Rngs):
self.conv1 = nnx.Conv(
in_features=1, out_features=32, kernel_size=(3, 3), rngs=rngs
)
self.conv2 = nnx.Conv(
in_features=32, out_features=64, kernel_size=(3, 3), rngs=rngs
)
self.linear1 = nnx.Linear(in_features=3136, out_features=256, rngs=rngs)
self.linear2 = nnx.Linear(in_features=256, out_features=10, rngs=rngs)
self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
self.linear2 = nnx.Linear(256, 10, rngs=rngs)
def __call__(self, x):
x = self.conv1(x)
x = nnx.relu(x)
x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = self.conv2(x)
x = nnx.relu(x)
x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = self.linear1(x)
x = nnx.relu(x)
x = self.avg_pool(nnx.relu(self.conv1(x)))
x = self.avg_pool(nnx.relu(self.conv2(x)))
x = x.reshape(x.shape[0], -1) # flatten
x = nnx.relu(self.linear1(x))
x = self.linear2(x)
return x
model = CNN(rngs=nnx.Rngs(0))
print(f'model = {model}'[:500] + '\n...\n') # print a part of the model
print(
f'{model.conv1.kernel.value.shape = }'
) # inspect the shape of the kernel of the first convolutional layer
nnx.display(model)
```

### Run model
Expand All @@ -123,84 +112,71 @@ Let's put our model to the test! We'll perform a forward pass with arbitrary da
import jax.numpy as jnp # JAX NumPy
y = model(jnp.ones((1, 28, 28, 1)))
y
nnx.display(y)
```

## 4. Create the `TrainState`

In Flax, a common practice is to use a dataclass to encapsulate the entire training state, which would allow you to simply pass only two arguments (the train state and batched data) to functions like `train_step`. The training state would typically contain an [`nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/training/optimizer.html#flax.experimental.nnx.optimizer.Optimizer) (which contains the step number, model and optimizer state) and an `nnx.Module` (for easier access to the model from the top-level of the train state). The training state can also be easily extended to add training and test metrics, as you will see in this tutorial (see [`nnx.metrics`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/training/metrics.html#module-flax.experimental.nnx.metrics) for more detail on NNX's metric classes).

```{code-cell} ipython3
import dataclasses
@dataclasses.dataclass
class TrainState(nnx.GraphNode):
optimizer: nnx.Optimizer
model: CNN
metrics: nnx.MultiMetric
```
## 4. Create Optimizer and Metrics

We use `optax` to create an optimizer ([`adamw`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adamw)) and initialize the `nnx.Optimizer`. We use `nnx.MultiMetric` to keep track of both the accuracy and average loss for both training and test batches.
In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model and an `optax` optimizer that defines the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.

```{code-cell} ipython3
import optax
learning_rate = 0.005
momentum = 0.9
tx = optax.adamw(learning_rate, momentum)
state = TrainState(
optimizer=nnx.Optimizer(model=model, tx=tx),
model=model,
metrics=nnx.MultiMetric(
accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
),
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
accuracy=nnx.metrics.Accuracy(),
loss=nnx.metrics.Average('loss'),
)
nnx.display(optimizer)
```

## 5. Training step

We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing.

```{code-cell} ipython3
def loss_fn(model, batch):
def loss_fn(model: CNN, batch):
logits = model(batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']
).mean()
return loss, logits
```

Next, we create the training step function. This function takes the `state` and a data `batch` and does the following:
Next, we create the training step function. This function takes the `model`, the `optimizer`, the `metrics`, and a data `batch`, and does the following:

* Computes the loss, logits and gradients with respect to the loss function using `nnx.value_and_grad`.
* Updates the training loss using the loss and updates the training accuracy using the logits and batch labels
* Updates model parameters and optimizer state by applying the gradient pytree to the optimizer.
* Updates training accuracy using the loss, logits, and batch labels.
* Updates model parameters via the optimizer by applying the gradient updates.

```{code-cell} ipython3
@nnx.jit
def train_step(state: TrainState, batch):
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
"""Train for a single step."""
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.model, batch)
state.metrics.update(values=loss, logits=logits, labels=batch['label'])
state.optimizer.update(grads=grads)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label'])
optimizer.update(grads)
```

The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with
[XLA](https://www.tensorflow.org/xla), optimizing performance on
hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),
except it can decorate functions that make stateful updates to NNX classes.
except it can transform functions that take NNX objects as inputs and return them as outputs.

## 6. Metric Computation
## 6. Evaluation step

Create a separate function to calculate loss and accuracy metrics for the test batch, since this will be outside the `train_step` function. Loss is determined using the `optax.softmax_cross_entropy_with_integer_labels` function, since we're reusing the loss function defined earlier.

```{code-cell} ipython3
@nnx.jit
def compute_test_metrics(*, state: TrainState, batch):
loss, logits = loss_fn(state.model, batch)
state.metrics.update(values=loss, logits=logits, labels=batch['label'])
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
loss, logits = loss_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label'])
```

## 7. Seed randomness
Expand All @@ -213,20 +189,9 @@ tf.random.set_seed(0)

## 8. Train and Evaluate

**Dataset Preparation:** create a "shuffled" dataset
- Repeat the dataset for the desired number of training epochs.
- Establish a 1024-sample buffer (holding the dataset's initial 1024 samples).
Randomly draw batches from this buffer.
- As samples are drawn, replenish the buffer with subsequent dataset samples.

**Training Loop:** Iterate through epochs
- Sample batches randomly from the dataset.
- Execute an optimization step for each training batch.
- Calculate mean training metrics across batches within the epoch.
- With updated parameters, compute metrics on the test set.
- Log train and test metrics for visualization.

After 10 training and testing epochs, your model should reach approximately 99% accuracy.
Now we train a model using batches of data for 10 epochs, evaluate its performance
on the test set after each epoch, and log the training and testing metrics (loss and
accuracy) throughout the process. Typically this leads to a model with around 99% accuracy.

```{code-cell} ipython3
:outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87
Expand All @@ -245,22 +210,22 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
# - the train state's model parameters
# - the optimizer state
# - the training loss and accuracy batch metrics
train_step(state, batch)
train_step(model, optimizer, metrics, batch)
if (step + 1) % num_steps_per_epoch == 0: # one training epoch has passed
# Log training metrics
for metric, value in state.metrics.compute().items(): # compute metrics
for metric, value in metrics.compute().items(): # compute metrics
metrics_history[f'train_{metric}'].append(value) # record metrics
state.metrics.reset() # reset metrics for test set
metrics.reset() # reset metrics for test set
# Compute metrics on the test set after each training epoch
for test_batch in test_ds.as_numpy_iterator():
compute_test_metrics(state=state, batch=test_batch)
eval_step(model, metrics, test_batch)
# Log test metrics
for metric, value in state.metrics.compute().items():
for metric, value in metrics.compute().items():
metrics_history[f'test_{metric}'].append(value)
state.metrics.reset() # reset metrics for next training epoch
metrics.reset() # reset metrics for next training epoch
print(
f"train epoch: {(step+1) // num_steps_per_epoch}, "
Expand Down Expand Up @@ -293,7 +258,6 @@ for dataset in ('train', 'test'):
ax1.legend()
ax2.legend()
plt.show()
plt.clf()
```

## 10. Perform inference on test set
Expand All @@ -302,16 +266,16 @@ Define a jitted inference function, `pred_step`, to generate predictions on the

```{code-cell} ipython3
@nnx.jit
def pred_step(state: TrainState, batch):
logits = state.model(batch['image'])
def pred_step(model: CNN, batch):
logits = model(batch['image'])
return logits.argmax(axis=1)
```

```{code-cell} ipython3
:outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e
test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(state, test_batch)
pred = pred_step(model, test_batch)
fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
Expand Down
Loading

0 comments on commit 2c7d7cd

Please sign in to comment.