From 10d8e5ca8a7a6849e02e5f40bd4630fd895d9194 Mon Sep 17 00:00:00 2001
From: Cristian Garcia <cgarcia.e88@gmail.com>
Date: Tue, 18 Feb 2025 16:55:58 -0500
Subject: [PATCH] [nnx] add shard_map

---
 .../api_reference/flax.nnx/transforms.rst     |   1 +
 flax/nnx/__init__.py                          |   1 +
 flax/nnx/transforms/compilation.py            | 297 +++++++++++++++++-
 tests/nnx/bridge/wrappers_test.py             |  10 +
 tests/nnx/transforms_test.py                  | 123 ++++++++
 uv.lock                                       |   6 +-
 6 files changed, 433 insertions(+), 5 deletions(-)

diff --git a/docs_nnx/api_reference/flax.nnx/transforms.rst b/docs_nnx/api_reference/flax.nnx/transforms.rst
index 54ba3399a..5b4440ed3 100644
--- a/docs_nnx/api_reference/flax.nnx/transforms.rst
+++ b/docs_nnx/api_reference/flax.nnx/transforms.rst
@@ -15,6 +15,7 @@ transforms
 
 .. autofunction:: grad
 .. autofunction:: jit
+.. autofunction:: shard_map
 .. autofunction:: remat
 .. autofunction:: scan
 .. autofunction:: value_and_grad
diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index f059358ba..20bd940e4 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -139,6 +139,7 @@
 from .transforms.autodiff import custom_vjp as custom_vjp
 from .transforms.autodiff import remat as remat
 from .transforms.compilation import jit as jit
+from .transforms.compilation import shard_map as shard_map
 from .transforms.compilation import StateSharding as StateSharding
 from .transforms.iteration import Carry as Carry
 from .transforms.iteration import scan as scan
diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py
index 99d757ae2..c064427cc 100644
--- a/flax/nnx/transforms/compilation.py
+++ b/flax/nnx/transforms/compilation.py
@@ -17,6 +17,10 @@
 import functools
 import typing as tp
 
+import jax.experimental
+import jax.experimental.shard_map
+from jax.sharding import PartitionSpec
+
 from flax.nnx import (
   extract,
   filterlib,
@@ -27,11 +31,13 @@
 import jax
 import jax.core
 import jax.stages
+from jax._src.mesh import Mesh, AbstractMesh
 
 from flax.typing import Missing
 
 F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
-
+Specs = tp.Any
+AxisName = tp.Hashable
 
 # -------------------------------
 # jit
@@ -341,7 +347,6 @@ def jit_wrapper(*args, **kwargs):
         check_aliasing=in_shardings is not None or kwarg_shardings is not None,
         ctxtag=jit_wrapper,
       )
-      jax_in_shardings, kwarg_shardings, jax_out_shardings
       pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
         *pure_args, **pure_kwargs
       )
@@ -371,3 +376,291 @@ def jit_wrapper(*args, **kwargs):
   jit_wrapper.inner = jitted_fn  # type: ignore
 
   return jit_wrapper  # type: ignore
+
+# -------------------------------
+# shard_map
+# -------------------------------
+
+# TODO: create StateSpec and consider enabling a mode that does
+# not use filters during split for performance. Overall there might
+# be performance limitations to using shard_map at the top level.
+
+@dataclasses.dataclass(eq=False)
+class ShardMapFn:
+  f: tp.Callable[..., tp.Any]
+  in_specs: tp.Any
+  out_specs: tp.Any
+  kwarg_specs: tp.Any
+  ctxtag: tp.Hashable
+
+  def __post_init__(self):
+    functools.update_wrapper(self, self.f)
+
+  def __call__(self, *pure_args, **pure_kwargs):
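+    # rebuild the graph nodes (Modules, Variables, etc.) from their pure
+    # pytree representation inside the shard_map body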
+    args, kwargs = extract.from_tree(
+      (pure_args, pure_kwargs),
+      merge_fn=_jit_merge_fn,
+      ctxtag=self.ctxtag,
+      is_inner=True,
+    )
+
+    out = self.f(*args, **kwargs)
+
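+    # return the (possibly updated) graph nodes and the output as pure pytrees
+    # so they can cross the shard_map boundary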
+    args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
+    pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
+      (args_out, kwargs_out, out),
+      prefix=(self.in_specs, self.kwarg_specs, self.out_specs),
+      ctxtag=self.ctxtag,
+      split_fn=_jit_split_fn,
+    )
+
+    return pure_args_out, pure_kwargs_out, pure_out
+
+
+@tp.overload
+def shard_map(
+  f: F,
+  *,
+  mesh: Mesh | AbstractMesh,
+  in_specs: Specs,
+  out_specs: Specs,
+  check_rep: bool = True,
+  auto: frozenset[AxisName] = frozenset(),
+) -> F: ...
+@tp.overload
+def shard_map(
+  *,
+  mesh: Mesh | AbstractMesh,
+  in_specs: Specs,
+  out_specs: Specs,
+  check_rep: bool = True,
+  auto: frozenset[AxisName] = frozenset(),
+) -> tp.Callable[[F], F]: ...
+def shard_map(
+  f: F | type[Missing] = Missing,
+  *,
+  mesh: Mesh | AbstractMesh,
+  in_specs: Specs,
+  out_specs: Specs,
+  check_rep: bool = True,
+  auto: frozenset[AxisName] = frozenset(),
+) -> F | tp.Callable[[F], F]:
+  """
+  Lifted version of
+  `jax.experimental.shard_map.shard_map <https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html>`_
+  that can handle Modules / graph nodes as arguments.
+
+  Simple data parallel example::
+
+    import jax
+    import jax.numpy as jnp
+    from flax import nnx
+    from jax.sharding import PartitionSpec as P
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))
+
+    m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    @nnx.shard_map(
+      mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data')
+    )
+    def f(m, x):
+      return m(x)
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(y)
+
+  Notice that here we simply used a ``PartitionSpec`` to define the spec for
+  the whole model and the data. This works for simple cases, but if we need
+  to assign a different ``PartitionSpec`` to different parts of the model we
+  need to use ``StateSharding`` and create filters that allow us to target
+  specific parts of the model. Here's an example of how to do tensor parallelism
+  for a simple MLP block using ``StateSharding`` and filters::
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))
+
+    class MLP(nnx.Module):
+      def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+        self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
+        self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)
+
+      def __call__(self, x):
+        return self.linear2(jax.nn.relu(self.linear1(x)))
+
+    m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    def path_ends_with(*path_suffix): # custom filter
+      return lambda path, value: path[-len(path_suffix):] == path_suffix
+
+    model_spec = nnx.StateSharding({
+      path_ends_with('linear1', 'kernel'): P(None, 'model'),
+      path_ends_with('linear2', 'kernel'): P('model', None),
+    })
+
+    @nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None))
+    def f(m, x):
+      y = m(x)
+      return jax.lax.psum(y, 'model')
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(m.linear1.kernel.value)
+    jax.debug.visualize_array_sharding(m.linear2.kernel.value)
+
+
+  Alternatively, a ``State`` object with the exact ``PartitionSpec`` for each
+  state can be passed to ``StateSharding``::
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))
+
+    class MLP(nnx.Module):
+      def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+        self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
+        self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)
+
+      def __call__(self, x):
+        return self.linear2(jax.nn.relu(self.linear1(x)))
+
+    m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    model_spec = nnx.State(
+      {
+        'linear1': {'kernel': P(None, 'model')},
+        'linear2': {'kernel': P('model', None)},
+      }
+    )
+
+    @nnx.shard_map(
+      mesh=mesh,
+      in_specs=(nnx.StateSharding(model_spec), P(None)),
+      out_specs=P(None),
+    )
+    def f(m, x):
+      y = m(x)
+      return jax.lax.psum(y, 'model')
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(m.linear1.kernel.value)
+    jax.debug.visualize_array_sharding(m.linear2.kernel.value)
+
+  Here ``model_spec`` was created manually, but you can also use
+  ``nnx.get_partition_spec`` to create it automatically (see
+  `Scale up on multiple devices <https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html>`_
+  ).
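+
+  As a rough sketch (assuming the model's parameters carry sharding metadata,
+  e.g. created with ``nnx.with_partitioning``), the spec could be derived
+  along these lines::
+
+    _, state = nnx.split(m)
+    model_spec = nnx.get_partition_spec(state)
+    state_sharding = nnx.StateSharding(model_spec)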
+
+  Args:
+    f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
+      takes as input a shard of the mapped-over arguments and produces a shard
+      of the output.
+    mesh: a ``jax.sharding.Mesh`` representing the array of devices over which
+      to shard the data and on which to execute instances of ``f``. The names of
+      the ``Mesh`` can be used in collective communication operations in ``f``.
+      This is typically created by a utility function like
+      :func:`jax.experimental.mesh_utils.create_device_mesh`.
+    in_specs: a pytree with ``jax.sharding.PartitionSpec`` or ``nnx.StateSharding``
+      (mapping substates to ``PartitionSpec``s) instances as leaves,
+      with a tree structure that is a tree prefix of the
+      args tuple to be mapped over. Similar to ``jax.sharding.NamedSharding``,
+      each ``PartitionSpec`` represents how the corresponding argument (or subtree
+      of arguments) should be sharded along the named axes of ``mesh``. In each
+      ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses sharding
+      the corresponding argument array axis along that positional axis; not
+      mentioning an axis name expresses replication. If an argument, or argument
+      subtree, has a corresponding spec of None, that argument is not sharded.
+    out_specs: a pytree with ``jax.sharding.PartitionSpec`` or ``nnx.StateSharding``
+      (mapping substates to ``PartitionSpec``s) instances as leaves, with a tree structure
+      that is a tree prefix of the output of ``f``.
+      Each ``PartitionSpec`` represents how the corresponding output shards should be
+      concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at
+      a position expresses concatenation of that mesh axis's shards along the
+      corresponding positional axis. Not mentioning a ``mesh`` axis name
+      expresses a promise that the output values are equal along that mesh axis,
+      and that, rather than concatenating, only a single value should be produced.
+    check_rep: If True (default), enables additional validity checks and automatic
+      differentiation optimizations. The validity checks concern whether any mesh
+      axis names not mentioned in ``out_specs`` are consistent with how the outputs
+      of ``f`` are replicated. Must be set to False if using a Pallas kernel in ``f``.
+    auto: (experimental) an optional set of axis names from ``mesh`` over which we
+      do not shard the data or map the function, but rather we allow the
+      compiler to control sharding. These names cannot be used in ``in_specs``,
+      ``out_specs``, or in communication collectives in ``f``.
+
+  Returns:
+    A callable that applies the input function ``f`` across data sharded according to
+    the ``mesh`` and ``in_specs``.
+  """
+  if f is Missing:
+    return functools.partial(
+      shard_map,
+      mesh=mesh,
+      in_specs=in_specs,
+      out_specs=out_specs,
+      check_rep=check_rep,
+      auto=auto,
+    )  # type: ignore[return-value]
+  assert not isinstance(f, type)
+
+  kwarg_specs = PartitionSpec()
+  jax_in_specs = jax.tree.map(
+    lambda x: extract.NodeStates(
+      _graphdef=PartitionSpec(),  # type: ignore[arg-type]
+      states=x.shardings,
+      metadata=x,
+    )
+    if isinstance(x, StateSharding)
+    else x,
+    in_specs,
+  )
+  jax_out_specs = jax.tree.map(
+    lambda x: extract.NodeStates(
+      _graphdef=PartitionSpec(),  # type: ignore[arg-type]
+      states=x.shardings,
+      metadata=x,
+    )
+    if isinstance(x, StateSharding)
+    else x,
+    out_specs,
+  )
+
+  @functools.wraps(f)
+  def shard_map_wrapper(*args, **kwargs):
+    # run dynamic_cache_context before update_context
+    with graph.update_context(shard_map_wrapper):
+      pure_args, pure_kwargs = extract.to_tree(
+        (args, kwargs),
+        prefix=(in_specs, kwarg_specs)
+        if in_specs is not None or kwarg_specs is not None
+        else None,
+        split_fn=_jit_split_fn,
+        check_aliasing=in_specs is not None or kwarg_specs is not None,
+        ctxtag=shard_map_wrapper,
+      )
+      pure_args_out, pure_kwargs_out, pure_out = shard_map_fn(
+        *pure_args, **pure_kwargs
+      )
+      _args_out, _kwargs_out, out = extract.from_tree(
+        (pure_args_out, pure_kwargs_out, pure_out),
+        merge_fn=_jit_merge_fn,
+        is_inner=False,
+        ctxtag=shard_map_wrapper,
+      )
+    return out
+
+  shard_map_fn = jax.experimental.shard_map.shard_map(
+    ShardMapFn(f, in_specs, out_specs, kwarg_specs, shard_map_wrapper),
+    mesh=mesh,
+    in_specs=jax_in_specs,
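+    # ShardMapFn returns (pure_args_out, pure_kwargs_out, pure_out), so the
+    # outer out_specs mirrors that triple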
+    out_specs=(jax_in_specs, kwarg_specs, jax_out_specs),  # type: ignore
+    check_rep=check_rep,
+    auto=auto,
+  )
+
+  shard_map_wrapper.inner = shard_map_fn  # type: ignore
+
+  return shard_map_wrapper  # type: ignore
\ No newline at end of file
diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py
index e1ef5a775..591442dd3 100644
--- a/tests/nnx/bridge/wrappers_test.py
+++ b/tests/nnx/bridge/wrappers_test.py
@@ -500,6 +500,7 @@ class Foo(bridge.Module):
     nnx.update(foo, state)
 
   def test_compact_basic(self):
+    test = self
     class Linear(bridge.Module):
       dout: int
 
@@ -519,11 +520,20 @@ def __call__(self, x):
         din = x.shape[-1]
         self.linear = Linear(self.dout)
         x = self.linear(x)
+
+        # NNX
+        graphdef, state = nnx.split(self)
+        test.assertIn('Linear_0', state)
+        test.assertIn('w', state['Linear_0'])
+        test.assertIn('b', state['Linear_0'])
+
         return x
 
     foo = Foo(5)
     x = jnp.ones((3, 2))
 
+    self.assertIsInstance(foo, nnx.Module)
+
     variables = foo.init(0, x)
     params = variables['params']
 
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index 41208e882..d166e986d 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -416,6 +416,129 @@ def f(cached_m: nnx.Linear, m: nnx.Linear):
     cached_m2 = cached_f(m)
     self.assertIs(cached_m, cached_m2)
 
+class TestShardMap(absltest.TestCase):
+  def test_basic_shardmap(self):
+    n_devices = jax.local_device_count()
+    devices = mesh_utils.create_device_mesh((n_devices,))
+    mesh = jax.sharding.Mesh(devices, ('a',))
+    PS = jax.sharding.PartitionSpec
+
+    state_sharding = nnx.StateSharding(
+      {
+        nnx.PathContains('kernel'): PS(None, 'a'),
+        nnx.PathContains('bias'): PS(),
+      }
+    )
+
+    m = nnx.Linear(16, 32, rngs=nnx.Rngs(0))
+
+    self.assertNotIsInstance(
+      m.kernel.value.sharding, jax.sharding.NamedSharding
+    )
+
+    @nnx.shard_map(mesh=mesh, in_specs=(state_sharding,), out_specs=None)
+    def f(m: nnx.Linear):
+      self.assertEqual(
+        m.kernel.value.shape, (m.in_features, m.out_features // n_devices)
+      )
+      self.assertEqual(m.bias.shape, (m.out_features,))
+
+    f(m)
+
+    self.assertIsInstance(m.kernel.value.sharding, jax.sharding.NamedSharding)
+
+  def test_from_state(self):
+    n_devices = jax.local_device_count()
+    devices = mesh_utils.create_device_mesh((n_devices,))
+    mesh = jax.sharding.Mesh(devices, ('a',))
+    PS = jax.sharding.PartitionSpec
+
+    state_spec = nnx.State(
+      {
+        'kernel': PS(None, 'a'),
+        'bias': PS(),
+      }
+    )
+    state_sharding = nnx.StateSharding(state_spec)
+
+    m = nnx.Linear(16, 32, rngs=nnx.Rngs(0))
+
+    self.assertNotIsInstance(
+      m.kernel.value.sharding, jax.sharding.NamedSharding
+    )
+
+    @nnx.shard_map(mesh=mesh, in_specs=(state_sharding,), out_specs=None)
+    def f(m: nnx.Linear):
+      self.assertEqual(
+        m.kernel.value.shape, (m.in_features, m.out_features // n_devices)
+      )
+      self.assertEqual(m.bias.shape, (m.out_features,))
+
+    f(m)
+
+    self.assertIsInstance(m.kernel.value.sharding, jax.sharding.NamedSharding)
+    self.assertIsInstance(m.bias.value.sharding, jax.sharding.NamedSharding)
+
+  def test_simple_data_parallel(self):
+    P = jax.sharding.PartitionSpec
+    n_devices = jax.local_device_count()
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))
+
+    m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    @nnx.shard_map(
+      mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data')
+    )
+    def f(m, x):
+      self.assertEqual(x.shape, (32 // n_devices, 2))
+      return m(x)
+
+    y = f(m, x)
+
+    self.assertEqual(y.shape, (32, 3))
+    self.assertIsInstance(y.sharding, jax.sharding.NamedSharding)
+    self.assertIsInstance(m.kernel.value.sharding, jax.sharding.NamedSharding)
+    self.assertIsInstance(m.bias.value.sharding, jax.sharding.NamedSharding)
+
+  def test_simple_tensor_parallel(self):
+    P = jax.sharding.PartitionSpec
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))
+
+    class MLP(nnx.Module):
+      def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+        self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
+        self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)
+
+      def __call__(self, x):
+        return self.linear2(jax.nn.relu(self.linear1(x)))
+
+    m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    def path_ends_with(path_suffix):
+      return lambda path, value: path[-len(path_suffix) :] == path_suffix
+
+    model_sharding = nnx.StateSharding(
+      {
+        path_ends_with(('linear1', 'kernel')): P(None, 'model'),
+        path_ends_with(('linear2', 'kernel')): P('model', None),
+      }
+    )
+
+    @nnx.shard_map(
+      mesh=mesh, in_specs=(model_sharding, P(None)), out_specs=P(None)
+    )
+    def f(m, x):
+      y = m(x)
+      return jax.lax.psum(y, 'model')
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(m.linear1.kernel.value)
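+
+    # a few sanity checks mirroring the tests above (assumption: psum over
+    # 'model' together with out_specs=P(None) leaves y fully replicated)
+    self.assertEqual(y.shape, (32, 3))
+    self.assertIsInstance(m.linear1.kernel.value.sharding, jax.sharding.NamedSharding)
+    self.assertIsInstance(m.linear2.kernel.value.sharding, jax.sharding.NamedSharding)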
+
 
 class TestGrad(parameterized.TestCase):
   def test_grad(self):
diff --git a/uv.lock b/uv.lock
index 21c497104..3aaa836df 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2258,7 +2258,7 @@ wheels = [
 
 [[package]]
 name = "orbax-checkpoint"
-version = "0.11.5"
+version = "0.11.6"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "absl-py" },
@@ -2276,9 +2276,9 @@ dependencies = [
     { name = "tensorstore" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/7e/b4/262439a3fe00064b53a5182c0149447304f5a2d5a328a92d64390c18d189/orbax_checkpoint-0.11.5.tar.gz", hash = "sha256:8331ff594980a241ba43eb59dd683e5b590b339cff32a7b72d78cb5a350030b4", size = 249258 }
+sdist = { url = "https://files.pythonhosted.org/packages/a8/66/1499f770885b6b42ed4ac839ac4482e681b3a60f96dcd935a32221c165b3/orbax_checkpoint-0.11.6.tar.gz", hash = "sha256:e16a8bbabe7bc0c94f611d115b2b7790183e6847152804a261048160b81b9628", size = 258729 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/66/80/e659696b5b1c2ced427efedd2d9d29c1bc31d841ac8a031215aa38f6b2ae/orbax_checkpoint-0.11.5-py3-none-any.whl", hash = "sha256:b55a7a254ea0ab18237e8234a6ca8bf5522f589fcc2ac698cf6893d5e7ae3500", size = 342800 },
+    { url = "https://files.pythonhosted.org/packages/82/f5/6707aabc6d4928bdac3eadee00972616f41a83362a5b0441d63a93d81b75/orbax_checkpoint-0.11.6-py3-none-any.whl", hash = "sha256:fb208012e5d3601ee37b1100fe4331f9982b814df89f572749be9094fa499e1f", size = 361764 },
 ]
 
 [[package]]