From 10d8e5ca8a7a6849e02e5f40bd4630fd895d9194 Mon Sep 17 00:00:00 2001 From: Cristian Garcia <cgarcia.e88@gmail.com> Date: Tue, 18 Feb 2025 16:55:58 -0500 Subject: [PATCH] [nnx] add shard_map --- .../api_reference/flax.nnx/transforms.rst | 1 + flax/nnx/__init__.py | 1 + flax/nnx/transforms/compilation.py | 297 +++++++++++++++++- tests/nnx/bridge/wrappers_test.py | 10 + tests/nnx/transforms_test.py | 123 ++++++++ uv.lock | 6 +- 6 files changed, 433 insertions(+), 5 deletions(-) diff --git a/docs_nnx/api_reference/flax.nnx/transforms.rst b/docs_nnx/api_reference/flax.nnx/transforms.rst index 54ba3399a..5b4440ed3 100644 --- a/docs_nnx/api_reference/flax.nnx/transforms.rst +++ b/docs_nnx/api_reference/flax.nnx/transforms.rst @@ -15,6 +15,7 @@ transforms .. autofunction:: grad .. autofunction:: jit +.. autofunction:: shard_map .. autofunction:: remat .. autofunction:: scan .. autofunction:: value_and_grad diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index f059358ba..20bd940e4 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -139,6 +139,7 @@ from .transforms.autodiff import custom_vjp as custom_vjp from .transforms.autodiff import remat as remat from .transforms.compilation import jit as jit +from .transforms.compilation import shard_map as shard_map from .transforms.compilation import StateSharding as StateSharding from .transforms.iteration import Carry as Carry from .transforms.iteration import scan as scan diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 99d757ae2..c064427cc 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -17,6 +17,10 @@ import functools import typing as tp +import jax.experimental +import jax.experimental.shard_map +from jax.sharding import PartitionSpec + from flax.nnx import ( extract, filterlib, @@ -27,11 +31,13 @@ import jax import jax.core import jax.stages +from jax._src.mesh import Mesh, AbstractMesh from flax.typing import Missing F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) - +Specs = tp.Any +AxisName = tp.Hashable # ------------------------------- # jit @@ -341,7 +347,6 @@ def jit_wrapper(*args, **kwargs): check_aliasing=in_shardings is not None or kwarg_shardings is not None, ctxtag=jit_wrapper, ) - jax_in_shardings, kwarg_shardings, jax_out_shardings pure_args_out, pure_kwargs_out, pure_out = jitted_fn( *pure_args, **pure_kwargs ) @@ -371,3 +376,291 @@ def jit_wrapper(*args, **kwargs): jit_wrapper.inner = jitted_fn # type: ignore return jit_wrapper # type: ignore + +# ------------------------------- +# shard_map +# ------------------------------- + +# TODO: create StateSpec and consider enabling a mode that does +# not use filters during split for performance. 
Overall there might
+# be performance limitations when using shard_map at the top level.
+
+@dataclasses.dataclass(eq=False)
+class ShardMapFn:
+  f: tp.Callable[..., tp.Any]
+  in_specs: tp.Any
+  out_specs: tp.Any
+  kwarg_specs: tp.Any
+  ctxtag: tp.Hashable
+
+  def __post_init__(self):
+    functools.update_wrapper(self, self.f)
+
+  def __call__(self, *pure_args, **pure_kwargs):
+    args, kwargs = extract.from_tree(
+      (pure_args, pure_kwargs),
+      merge_fn=_jit_merge_fn,
+      ctxtag=self.ctxtag,
+      is_inner=True,
+    )
+
+    out = self.f(*args, **kwargs)
+
+    args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
+    pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
+      (args_out, kwargs_out, out),
+      prefix=(self.in_specs, self.kwarg_specs, self.out_specs),
+      ctxtag=self.ctxtag,
+      split_fn=_jit_split_fn,
+    )
+
+    return pure_args_out, pure_kwargs_out, pure_out
+
+
+@tp.overload
+def shard_map(
+  f: F,
+  *,
+  mesh: Mesh | AbstractMesh,
+  in_specs: Specs,
+  out_specs: Specs,
+  check_rep: bool = True,
+  auto: frozenset[AxisName] = frozenset(),
+) -> F: ...
+@tp.overload
+def shard_map(
+  *,
+  mesh: Mesh | AbstractMesh,
+  in_specs: Specs,
+  out_specs: Specs,
+  check_rep: bool = True,
+  auto: frozenset[AxisName] = frozenset(),
+) -> tp.Callable[[F], F]: ...
+def shard_map(
+  f: F | type[Missing] = Missing,
+  *,
+  mesh: Mesh | AbstractMesh,
+  in_specs: Specs,
+  out_specs: Specs,
+  check_rep: bool = True,
+  auto: frozenset[AxisName] = frozenset(),
+) -> F | tp.Callable[[F], F]:
+  """
+  Lifted version of
+  `jax.experimental.shard_map.shard_map <https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html>`_
+  that can handle Modules / graph nodes as arguments.
+
+  Simple data parallel example::
+
+    import jax
+    import jax.numpy as jnp
+    from flax import nnx
+    from jax.sharding import PartitionSpec as P
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))
+
+    m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    @nnx.shard_map(
+      mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data')
+    )
+    def f(m, x):
+      return m(x)
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(y)
+
+  Notice that here we simply used ``PartitionSpec`` to define the spec for
+  the whole model and the data. This works for simple cases, but if we need
+  to assign a different ``PartitionSpec`` to different parts of the model, we
+  need to use ``StateSharding`` and create filters that allow us to target
+  specific parts of the model.
Here's an example of how to do tensor parallelism
+  for a simple MLP block using ``StateSharding`` and filters::
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))
+
+    class MLP(nnx.Module):
+      def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+        self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
+        self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)
+
+      def __call__(self, x):
+        return self.linear2(jax.nn.relu(self.linear1(x)))
+
+    m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    def path_ends_with(*path_suffix):  # custom filter
+      return lambda path, value: path[-len(path_suffix):] == path_suffix
+
+    model_spec = nnx.StateSharding({
+      path_ends_with('linear1', 'kernel'): P(None, 'model'),
+      path_ends_with('linear2', 'kernel'): P('model', None),
+    })
+
+    @nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None))
+    def f(m, x):
+      y = m(x)
+      return jax.lax.psum(y, 'model')
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(m.linear1.kernel.value)
+    jax.debug.visualize_array_sharding(m.linear2.kernel.value)
+
+
+  Alternatively, a ``State`` object with the exact ``PartitionSpec`` for each
+  state can be passed to ``StateSharding``::
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))
+
+    class MLP(nnx.Module):
+      def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+        self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
+        self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)
+
+      def __call__(self, x):
+        return self.linear2(jax.nn.relu(self.linear1(x)))
+
+    m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    model_spec = nnx.State(
+      {
+        'linear1': {'kernel': P(None, 'model')},
+        'linear2': {'kernel': P('model', None)},
+      }
+    )
+
+    @nnx.shard_map(
+      mesh=mesh,
+      in_specs=(nnx.StateSharding(model_spec), P(None)),
+      out_specs=P(None),
+    )
+    def f(m, x):
+      y = m(x)
+      return jax.lax.psum(y, 'model')
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(m.linear1.kernel.value)
+    jax.debug.visualize_array_sharding(m.linear2.kernel.value)
+
+  Here ``model_spec`` was created manually, but you can also use
+  ``nnx.get_partition_spec`` to create it automatically (see
+  `Scale up on multiple devices <https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html>`_
+  ).
+
+  Args:
+    f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
+      takes as input a shard of the mapped-over arguments and produces a shard
+      of the output.
+    mesh: a ``jax.sharding.Mesh`` representing the array of devices over which
+      to shard the data and on which to execute instances of ``f``. The names of
+      the ``Mesh`` can be used in collective communication operations in ``f``.
+      This is typically created by a utility function like
+      :func:`jax.experimental.mesh_utils.create_device_mesh`.
+    in_specs: a pytree with ``jax.sharding.PartitionSpec`` or ``nnx.StateSharding``
+      (mapping substates to ``PartitionSpec``s) instances as leaves,
+      with a tree structure that is a tree prefix of the
+      args tuple to be mapped over. Similar to ``jax.sharding.NamedSharding``,
+      each ``PartitionSpec`` represents how the corresponding argument (or subtree
+      of arguments) should be sharded along the named axes of ``mesh``. In each
+      ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses sharding
+      the corresponding argument array axis along that positional axis; not
+      mentioning an axis name expresses replication.
If an argument, or argument
+      subtree, has a corresponding spec of None, that argument is not sharded.
+    out_specs: a pytree with ``jax.sharding.PartitionSpec`` or ``nnx.StateSharding``
+      (mapping substates to ``PartitionSpec``s) instances as leaves, with a tree structure
+      that is a tree prefix of the output of ``f``.
+      Each ``PartitionSpec`` represents how the corresponding output shards should be
+      concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at
+      a position expresses concatenation of that mesh axis's shards along the
+      corresponding positional axis. Not mentioning a ``mesh`` axis name
+      expresses a promise that the output values are equal along that mesh axis,
+      and that, rather than concatenating, only a single value should be produced.
+    check_rep: If True (default), enable additional validity checks and automatic
+      differentiation optimizations. The validity checks concern whether any mesh
+      axis names not mentioned in ``out_specs`` are consistent with how the outputs
+      of ``f`` are replicated. Must be set to False if using a Pallas kernel in ``f``.
+    auto: (experimental) an optional set of axis names from ``mesh`` over which we
+      do not shard the data or map the function, but rather we allow the
+      compiler to control sharding. These names cannot be used in ``in_specs``,
+      ``out_specs``, or in communication collectives in ``f``.
+
+  Returns:
+    A callable that applies the input function ``f`` across data sharded according to
+    the ``mesh`` and ``in_specs``.
+  """
+  if f is Missing:
+    return functools.partial(
+      shard_map,
+      mesh=mesh,
+      in_specs=in_specs,
+      out_specs=out_specs,
+      check_rep=check_rep,
+      auto=auto,
+    )  # type: ignore[return-value]
+  assert not isinstance(f, type)
+
+  kwarg_specs = PartitionSpec()
+  jax_in_specs = jax.tree.map(
+    lambda x: extract.NodeStates(
+      _graphdef=PartitionSpec(),  # type: ignore[arg-type]
+      states=x.shardings,
+      metadata=x,
+    )
+    if isinstance(x, StateSharding)
+    else x,
+    in_specs,
+  )
+  jax_out_specs = jax.tree.map(
+    lambda x: extract.NodeStates(
+      _graphdef=PartitionSpec(),  # type: ignore[arg-type]
+      states=x.shardings,
+      metadata=x,
+    )
+    if isinstance(x, StateSharding)
+    else x,
+    out_specs,
+  )
+
+  @functools.wraps(f)
+  def shard_map_wrapper(*args, **kwargs):
+    # run dynamic_cache_context before update_context
+    with graph.update_context(shard_map_wrapper):
+      pure_args, pure_kwargs = extract.to_tree(
+        (args, kwargs),
+        prefix=(in_specs, kwarg_specs)
+        if in_specs is not None or kwarg_specs is not None
+        else None,
+        split_fn=_jit_split_fn,
+        check_aliasing=in_specs is not None or kwarg_specs is not None,
+        ctxtag=shard_map_wrapper,
+      )
+      pure_args_out, pure_kwargs_out, pure_out = shard_map_fn(
+        *pure_args, **pure_kwargs
+      )
+      _args_out, _kwargs_out, out = extract.from_tree(
+        (pure_args_out, pure_kwargs_out, pure_out),
+        merge_fn=_jit_merge_fn,
+        is_inner=False,
+        ctxtag=shard_map_wrapper,
+      )
+    return out
+
+  shard_map_fn = jax.experimental.shard_map.shard_map(
+    ShardMapFn(f, in_specs, out_specs, kwarg_specs, shard_map_wrapper),
+    mesh=mesh,
+    in_specs=jax_in_specs,
+    out_specs=(jax_in_specs, kwarg_specs, jax_out_specs),  # type: ignore
+    check_rep=check_rep,
+    auto=auto,
+  )
+
+  shard_map_wrapper.inner = shard_map_fn  # type: ignore
+
+  return shard_map_wrapper  # type: ignore
\ No newline at end of file
diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py
index e1ef5a775..591442dd3 100644
--- a/tests/nnx/bridge/wrappers_test.py
+++ b/tests/nnx/bridge/wrappers_test.py
@@ -500,6 +500,7 @@
class Foo(bridge.Module): nnx.update(foo, state) def test_compact_basic(self): + test = self class Linear(bridge.Module): dout: int @@ -519,11 +520,20 @@ def __call__(self, x): din = x.shape[-1] self.linear = Linear(self.dout) x = self.linear(x) + + # NNX + graphdef, state = nnx.split(self) + test.assertIn('Linear_0', state) + test.assertIn('w', state['Linear_0']) + test.assertIn('b', state['Linear_0']) + return x foo = Foo(5) x = jnp.ones((3, 2)) + self.assertIsInstance(foo, nnx.Module) + variables = foo.init(0, x) params = variables['params'] diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 41208e882..d166e986d 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -416,6 +416,129 @@ def f(cached_m: nnx.Linear, m: nnx.Linear): cached_m2 = cached_f(m) self.assertIs(cached_m, cached_m2) +class TestShardMap(absltest.TestCase): + def test_basic_shardmap(self): + n_devices = jax.local_device_count() + devices = mesh_utils.create_device_mesh((n_devices,)) + mesh = jax.sharding.Mesh(devices, ('a',)) + PS = jax.sharding.PartitionSpec + + state_sharding = nnx.StateSharding( + { + nnx.PathContains('kernel'): PS(None, 'a'), + nnx.PathContains('bias'): PS(), + } + ) + + m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) + + self.assertNotIsInstance( + m.kernel.value.sharding, jax.sharding.NamedSharding + ) + + @nnx.shard_map(mesh=mesh, in_specs=(state_sharding,), out_specs=None) + def f(m: nnx.Linear): + self.assertEqual( + m.kernel.value.shape, (m.in_features, m.out_features // n_devices) + ) + self.assertEqual(m.bias.shape, (m.out_features,)) + + f(m) + + self.assertIsInstance(m.kernel.value.sharding, jax.sharding.NamedSharding) + + def test_from_state(self): + n_devices = jax.local_device_count() + devices = mesh_utils.create_device_mesh((n_devices,)) + mesh = jax.sharding.Mesh(devices, ('a',)) + PS = jax.sharding.PartitionSpec + + state_spec = nnx.State( + { + 'kernel': PS(None, 'a'), + 'bias': PS(), + } + ) + state_sharding = nnx.StateSharding(state_spec) + + m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) + + self.assertNotIsInstance( + m.kernel.value.sharding, jax.sharding.NamedSharding + ) + + @nnx.shard_map(mesh=mesh, in_specs=(state_sharding,), out_specs=None) + def f(m: nnx.Linear): + self.assertEqual( + m.kernel.value.shape, (m.in_features, m.out_features // n_devices) + ) + self.assertEqual(m.bias.shape, (m.out_features,)) + + f(m) + + self.assertIsInstance(m.kernel.value.sharding, jax.sharding.NamedSharding) + self.assertIsInstance(m.bias.value.sharding, jax.sharding.NamedSharding) + + def test_simple_data_parallel(self): + P = jax.sharding.PartitionSpec + n_devices = jax.local_device_count() + + mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) + + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + x = jnp.ones((32, 2)) + + @nnx.shard_map( + mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data') + ) + def f(m, x): + self.assertEqual(x.shape, (32 // n_devices, 2)) + return m(x) + + y = f(m, x) + + self.assertEqual(y.shape, (32, 3)) + self.assertIsInstance(y.sharding, jax.sharding.NamedSharding) + self.assertIsInstance(m.kernel.value.sharding, jax.sharding.NamedSharding) + self.assertIsInstance(m.bias.value.sharding, jax.sharding.NamedSharding) + + def test_simple_tensor_parallel(self): + P = jax.sharding.PartitionSpec + + mesh = jax.sharding.Mesh(jax.local_devices(), ('model',)) + + class MLP(nnx.Module): + def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs) + 
self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs) + + def __call__(self, x): + return self.linear2(jax.nn.relu(self.linear1(x))) + + m = MLP(2, 64, 3, rngs=nnx.Rngs(0)) + x = jnp.ones((32, 2)) + + def path_ends_with(path_suffix): + return lambda path, value: path[-len(path_suffix) :] == path_suffix + + model_sharding = nnx.StateSharding( + { + path_ends_with(('linear1', 'kernel')): P(None, 'model'), + path_ends_with(('linear2', 'kernel')): P('model', None), + } + ) + + @nnx.shard_map( + mesh=mesh, in_specs=(model_sharding, P(None)), out_specs=P(None) + ) + def f(m, x): + y = m(x) + return jax.lax.psum(y, 'model') + + y = f(m, x) + + jax.debug.visualize_array_sharding(m.linear1.kernel.value) + class TestGrad(parameterized.TestCase): def test_grad(self): diff --git a/uv.lock b/uv.lock index 21c497104..3aaa836df 100644 --- a/uv.lock +++ b/uv.lock @@ -2258,7 +2258,7 @@ wheels = [ [[package]] name = "orbax-checkpoint" -version = "0.11.5" +version = "0.11.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, @@ -2276,9 +2276,9 @@ dependencies = [ { name = "tensorstore" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7e/b4/262439a3fe00064b53a5182c0149447304f5a2d5a328a92d64390c18d189/orbax_checkpoint-0.11.5.tar.gz", hash = "sha256:8331ff594980a241ba43eb59dd683e5b590b339cff32a7b72d78cb5a350030b4", size = 249258 } +sdist = { url = "https://files.pythonhosted.org/packages/a8/66/1499f770885b6b42ed4ac839ac4482e681b3a60f96dcd935a32221c165b3/orbax_checkpoint-0.11.6.tar.gz", hash = "sha256:e16a8bbabe7bc0c94f611d115b2b7790183e6847152804a261048160b81b9628", size = 258729 } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/80/e659696b5b1c2ced427efedd2d9d29c1bc31d841ac8a031215aa38f6b2ae/orbax_checkpoint-0.11.5-py3-none-any.whl", hash = "sha256:b55a7a254ea0ab18237e8234a6ca8bf5522f589fcc2ac698cf6893d5e7ae3500", size = 342800 }, + { url = "https://files.pythonhosted.org/packages/82/f5/6707aabc6d4928bdac3eadee00972616f41a83362a5b0441d63a93d81b75/orbax_checkpoint-0.11.6-py3-none-any.whl", hash = "sha256:fb208012e5d3601ee37b1100fe4331f9982b814df89f572749be9094fa499e1f", size = 361764 }, ] [[package]]
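A minimal usage sketch (not part of the patch) of the ``check_rep`` behavior
described in the new docstring: with the default ``check_rep=True``, an
``out_specs`` that does not mention a mesh axis promises that the output is
replicated along that axis, and a collective such as ``jax.lax.pmean`` is what
establishes that replication. The model, loss, and data below are illustrative
assumptions, not code from this patch::

    import jax
    import jax.numpy as jnp
    from flax import nnx
    from jax.sharding import PartitionSpec as P

    mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))

    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    x = jnp.ones((32, 2))
    y = jnp.ones((32, 3))

    @nnx.shard_map(
      mesh=mesh, in_specs=(P(None), P('data'), P('data')), out_specs=P()
    )
    def mean_loss(model, x, y):
      # Each instance sees one shard of the batch; pmean over 'data' makes the
      # scalar identical on every device, satisfying check_rep for an
      # out_specs that does not mention 'data'.
      local = jnp.mean((model(x) - y) ** 2)
      return jax.lax.pmean(local, 'data')

    loss = mean_loss(model, x, y)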