Skip to content

Commit

Permalink
Merge pull request #4286 from 8bitmp3:update-nnx-jax-transforms
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705649786
  • Loading branch information
Flax Authors committed Dec 12, 2024
2 parents fcc2e0e + c8bfd2a commit 207966e
Showing 1 changed file with 35 additions and 31 deletions.
66 changes: 35 additions & 31 deletions docs_nnx/guides/jax_and_nnx_transforms.rst
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
Flax NNX vs JAX Transformations
==========================
Flax NNX vs JAX transformations
===============================

.. attention::
This page relates to the new Flax NNX API.

In this guide, you will learn the differences using Flax NNX and JAX transformations, and how to
seamlessly switch between them or use them together. We will be focusing on the ``jit`` and
``grad`` function transformations in this guide.
This guide describes the differences between
`Flax NNX transformations <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__
and `JAX transformations <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__,
and how to seamlessly switch between them or use them side-by-side. The examples here will focus on
``nnx.jit``, ``jax.jit``, ``nnx.grad`` and ``jax.grad`` function transformations (transforms).

First, let's set up imports and generate some dummy data:

Expand All @@ -18,27 +17,34 @@ First, let's set up imports and generate some dummy data:
x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))

Differences between NNX and JAX transformations
***********************************************
Differences
***********

Flax NNX transformations can transform functions that are not pure and make mutations and
side-effects:
- Flax NNX transforms enable you to transform functions that take in Flax NNX graph objects as
arguments - such as ``nnx.Module``, ``nnx.Rngs``, ``nnx.Optimizer``, and so on - even those whose state
will be mutated.
- In comparison, these kinds of objects aren't recognized in JAX transformations.

The Flax NNX `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`_
provides a way to convert graph structures to `pytrees <https://jax.readthedocs.io/en/latest/working-with-pytrees.html>`__
and back. By doing this at every function boundary you can effectively use graph structures with any
JAX transforms and propagate state updates in a way consistent with functional purity.

The primary difference between Flax NNX and JAX transformations is that Flax NNX transformations allow you to
transform functions that take in Flax NNX graph objects as arguments (`Module`, `Rngs`, `Optimizer`, etc),
even those whose state will be mutated, whereas they aren't recognized in JAX transformations.
Therefore Flax NNX transformations can transform functions that are not pure and make mutations and
side-effects.
Flax NNX custom transforms, such as ``nnx.jit`` and ``nnx.grad``, simply remove the boilerplate, and
as a result the code looks stateful.

Flax NNX's `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`_
provides a way to convert graph structures to pytrees and back. By doing this at every function
boundary you can effectively use graph structures with any JAX transform and propagate state updates
in a way consistent with functional purity. Flax NNX custom transforms such as ``nnx.jit`` and ``nnx.grad``
simply remove the boilerplate, as a result the code looks stateful.
Below is an example of using the ``nnx.jit`` and ``nnx.grad`` transforms compared to the
the code that uses ``jax.jit`` and ``jax.grad`` transforms.

Below is an example of using the ``nnx.jit`` and ``nnx.grad`` transformations compared to using the
``jax.jit`` and ``jax.grad`` transformations. Notice the function signature of Flax NNX-transformed
functions can accept the ``nnx.Linear`` module directly and can make stateful updates to the module,
whereas the function signature of JAX-transformed functions can only accept the pytree-registered
``State`` and ``GraphDef`` objects and must return an updated copy of them to maintain the purity of
the transformed function.
Notice that:

- The function signature of Flax NNX-transformed functions can accept the ``nnx.Linear``
``nnx.Module`` instances directly and make stateful updates to the ``Module``.
- The function signature of JAX-transformed functions can only accept the pytree-registered
``nnx.State`` and ``nnx.GraphDef`` objects, and must return an updated copy of them to maintain the
purity of the transformed function.

.. codediff::
:title: Flax NNX transforms, JAX transforms
Expand Down Expand Up @@ -79,11 +85,11 @@ the transformed function.
graphdef, state = train_step(graphdef, state, x, y) #!


Mixing Flax NNX and JAX transformations
Mixing Flax NNX and JAX transforms
**********************************

Flax NNX and JAX transformations can be mixed together, so long as the JAX-transformed function is
pure and has valid argument types that are recognized by JAX.
Both Flax NNX transforms and JAX transforms can be mixed together, so long as the JAX-transformed function
in your code is pure and has valid argument types that are recognized by JAX.

.. codediff::
:title: Using ``nnx.jit`` with ``jax.grad``, Using ``jax.jit`` with ``nnx.grad``
Expand Down Expand Up @@ -121,5 +127,3 @@ pure and has valid argument types that are recognized by JAX.

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)


0 comments on commit 207966e

Please sign in to comment.