add AxisMetadata docstring
jheek committed Sep 30, 2022
1 parent a0267ef commit f17e89e
Showing 1 changed file with 111 additions and 13 deletions.
124 changes: 111 additions & 13 deletions docs/flip/2434-general-metadata.md
@@ -16,11 +16,16 @@ Users can extend the base class to keep track of per-axis metadata in a way that

## Motivation

Generally, there is no way in Flax to track metadata for variables across lifted transformations.
Axis metadata is used to pass semantic information about axes on to other (Flax-independent) APIs.
For example, optimizers like AdaFactor can be configured on a per-axis level, and partitioning APIs
in JAX like xmap or pjit require per-variable annotations to map efficiently to parallel hardware.

Currently, there is an experimental [API](https://github.com/google/flax/blob/main/flax/linen/partitioning.py)
supporting partitioning annotations with wrappers around lifted transforms that change axes (``nn.scan_with_axes``, ``nn.vmap_with_axes``)
and special APIs to create variables (``param_with_axes`` and ``variable_with_axes``).
The experimental partitioning API stores the metadata in a separate collection named "[collection]_axes".

The experimental API has a number of shortcomings that we would like to solve:
1. The current API works for tracking PartitionSpecs but not for other types of metadata like optimizer annotations.
@@ -34,32 +39,95 @@ The experimental API has a number of shortcomings that we like to solve:
To generalize metadata tracking and keep the specific metadata out of core Flax we propose the following abstract base class:

```python
import abc
from typing import Any, Dict, TypeVar

TAxisMetadata = TypeVar("TAxisMetadata", bound="AxisMetadata")


class AxisMetadata(metaclass=abc.ABCMeta):
  """Abstract base class for boxed Metadata.

  ``AxisMetadata`` enables arbitrary, per-axis metadata for variables.
  By using ``unbox`` the metadata is stripped away to obtain the original
  variables. By using unboxing, most code handling variables does not need
  to handle ``AxisMetadata`` specifically, but can directly operate on the JAX
  arrays that they wrap.

  Additionally, ``AxisMetadata`` supports updating metadata whenever an axis
  is added or removed by a functional transformation
  (e.g.: ``nn.scan`` or ``nn.vmap``) using the ``add_axis`` and ``remove_axis``
  methods.

  By extending ``AxisMetadata``, custom metadata can be stored. See
  ``Partitioned`` for a specific implementation.
  """

  @abc.abstractmethod
  def unbox(self) -> Any:
    """Returns the content of the AxisMetadata box.

    Note that unlike ``meta.unbox`` the unbox call should not recursively unbox
    metadata. It should simply return the value that it wraps directly, even
    if that value itself is an instance of AxisMetadata.

    In practice, AxisMetadata subclasses should be registered as PyTree nodes to
    support passing instances to JAX and Flax APIs. The leaves returned for this
    node should correspond to the value returned by unbox.

    Returns:
      The unboxed value.
    """
    pass

  @abc.abstractmethod
  def add_axis(self: TAxisMetadata, index: int,
               params: Dict[Any, Any]) -> TAxisMetadata:
    """Adds a new axis to the axis metadata.

    Note that add_axis and remove_axis should act as each other's inverse
    (meaning: ``x.add_axis(i, p).remove_axis(i, p) == x``)

    Args:
      index: The position at which the new axis will be inserted
      params: An arbitrary dictionary of parameters passed by the transformation
        that introduces the new axis (e.g.: ``nn.scan`` or ``nn.vmap``). The
        user passes this dictionary as the `metadata_param` argument to the
        transformation.

    Returns:
      A new instance of the same type as self and with the same ``unbox``
      content with updated axis metadata.
    """
    pass

  @abc.abstractmethod
  def remove_axis(self: TAxisMetadata, index: int,
                  params: Dict[Any, Any]) -> TAxisMetadata:
    """Removes an axis from the axis metadata.

    Note that add_axis and remove_axis should act as each other's inverse
    (meaning: ``x.remove_axis(i, p).add_axis(i, p) == x``)

    Args:
      index: The position of the axis that is to be removed
      params: An arbitrary dictionary of parameters passed by the transformation
        that introduced the axis (e.g.: ``nn.scan`` or ``nn.vmap``). The
        user passes this dictionary as the `metadata_param` argument to the
        transformation.

    Returns:
      A new instance of the same type as self and with the same ``unbox``
      content with updated axis metadata.
    """
    pass
```
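
To make the inverse contract between ``add_axis`` and ``remove_axis`` concrete, here is a minimal, dependency-free sketch. It is not the proposed implementation: ``NamedAxes`` and the ``"axis_name"`` key are hypothetical, the base class is a trimmed stand-in for the one above, and a plain frozen dataclass is used instead of a registered PyTree node so the example runs without JAX or Flax.

```python
import abc
import dataclasses
from typing import Any, Dict, Optional, Tuple


class AxisMetadata(abc.ABC):
  """Trimmed stand-in for the abstract base class above."""

  @abc.abstractmethod
  def unbox(self) -> Any: ...

  @abc.abstractmethod
  def add_axis(self, index: int, params: Dict[Any, Any]) -> "AxisMetadata": ...

  @abc.abstractmethod
  def remove_axis(self, index: int, params: Dict[Any, Any]) -> "AxisMetadata": ...


@dataclasses.dataclass(frozen=True)
class NamedAxes(AxisMetadata):
  """Hypothetical box that stores one optional name per axis."""
  value: Any
  names: Tuple[Optional[str], ...]

  def unbox(self) -> Any:
    # Return the wrapped value directly, without unboxing recursively.
    return self.value

  def add_axis(self, index: int, params: Dict[Any, Any]) -> "NamedAxes":
    names = list(self.names)
    # "axis_name" is an assumed key; the real params dictionary is user-defined.
    names.insert(index, params.get("axis_name"))
    return dataclasses.replace(self, names=tuple(names))

  def remove_axis(self, index: int, params: Dict[Any, Any]) -> "NamedAxes":
    names = list(self.names)
    names.pop(index)
    return dataclasses.replace(self, names=tuple(names))


# A nested list stands in for a (4, 8) array.
box = NamedAxes(value=[[0.0] * 8 for _ in range(4)], names=(None, "data"))
stacked = box.add_axis(0, {"axis_name": "layers"})
restored = stacked.remove_axis(0, {"axis_name": "layers"})
```

Both methods return a new instance rather than mutating ``box``, so ``restored`` compares equal to the original box while ``box`` itself is unchanged.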

We call this type of class wrapping a value and keeping track of some additional data a **box**.
By defining an abstract base class for this box, the API does not need to be aware of the specifics of the metadata that is tracked.
This should make the API future-proof and modular.

The ``add_axis`` and ``remove_axis`` methods return an instance of their own type instead of mutating in-place.
Typically, an implementation would be a ``flax.struct.PyTreeNode`` because the box should still be a valid JAX value and must therefore be handled by the PyTree API.
Calling ``jax.tree_map`` on a boxed value will simply map over the value in the box.
The lifted transforms that need to handle metadata will call ``jax.tree_map(..., is_leaf=lambda x: isinstance(x, AxisMetadata))`` to find the AxisMetadata instances within a PyTree.
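
The following sketch illustrates how a transform could locate boxes and update their metadata while leaving unboxed leaves alone. It deliberately avoids JAX: ``map_tree`` is a toy stand-in for ``jax.tree_map(..., is_leaf=...)`` that only understands nested dicts, ``Box`` is a hypothetical stand-in for an ``AxisMetadata`` subclass, and strings are used as placeholders for arrays.

```python
import dataclasses
from typing import Any, Callable, Dict, Tuple


@dataclasses.dataclass(frozen=True)
class Box:
  """Hypothetical stand-in for an AxisMetadata subclass."""
  value: Any
  names: Tuple[Any, ...]

  def add_axis(self, index: int, params: Dict[Any, Any]) -> "Box":
    names = list(self.names)
    names.insert(index, params.get("axis_name"))  # assumed key
    return dataclasses.replace(self, names=tuple(names))


def map_tree(fn: Callable[[Any], Any], tree: Any,
             is_leaf: Callable[[Any], bool]) -> Any:
  """Toy tree map over nested dicts; stops descending where is_leaf is true."""
  if is_leaf(tree) or not isinstance(tree, dict):
    return fn(tree)
  return {k: map_tree(fn, v, is_leaf) for k, v in tree.items()}


def add_axis_to_boxes(tree: Any, index: int, params: Dict[Any, Any]) -> Any:
  """Calls add_axis on every box and passes other leaves through unchanged."""
  def update(x):
    return x.add_axis(index, params) if isinstance(x, Box) else x
  return map_tree(update, tree, is_leaf=lambda x: isinstance(x, Box))


variables = {
    "params": {
        "kernel": Box("kernel-array", (None, "data")),
        "bias": "bias-array",  # a plain, unboxed leaf
    }
}
updated = add_axis_to_boxes(variables, 0, {"axis_name": "layers"})
```

Treating the box itself as a leaf is what gives the mapped function access to the metadata; without the ``is_leaf`` predicate, the map would only ever see the wrapped value.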

Advantages of the boxing approach:
1. Boxing can be used outside of Flax and metadata is automatically "inherited". For example, the optimizer state will
have the same partitioning spec as the parameters, because the state is initialized using a ``jax.tree_map`` over the boxed parameters.
2. Boxes are composable.
3. Boxing avoids string manipulation and generally avoids having to handle additional auxiliary collections like "param_axes" in the current
@@ -68,26 +136,35 @@ Advantages of the boxing approach:


Disadvantages:
1. Handling boxed values requires the relatively new ``is_leaf=`` syntax, which users might not be familiar with, although in most cases users will
probably call Flax-provided utilities that handle the low-level ``tree_map`` calls.
2. Adding the boxes changes the PyTree hierarchy and introduces dataclasses within the otherwise plain, nested dict of variables.
3. Custom PyTree nodes have a small runtime overhead. It's hard to observe this in practice because JAX calls are async.


### Init syntax


Boxes can be created directly by the init function of a variable. Therefore, we propose to create metadata using higher-order initializers.
The main advantage of this is that we can decouple metadata handling completely from the Module definition. Also, most Modules already expose
attributes for overriding the default initializers, so users can add metadata to existing Modules without requiring any code changes.

To illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``:

```python
class Partitioned(flax.struct.PyTreeNode, AxisMetadata):
  value: Any
  names: Tuple[Optional[str], ...] = flax.struct.field(pytree_node=False)

  ...

  def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
    axis_name = self._get_partition_name(params)
    names = list(self.names)
    names.insert(index, axis_name)
    return self.replace(names=tuple(names))

  def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
    axis_name = self._get_partition_name(params)
    names = list(self.names)
    assert names.pop(index) == axis_name
    return self.replace(names=tuple(names))


def with_partitioning(init_fn, names):
  def wrapper(*args, **kwargs):
@@ -103,6 +180,27 @@ Here we also defined a small utility called ``with_partitioning`` that we can us
partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal(), (None, "data")))
```
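
The body of ``with_partitioning`` is not shown in full in this diff, so the sketch below does not reproduce it. It only illustrates the general shape a higher-order initializer could take, using hypothetical names (``Box``, ``with_names``, ``zeros_init``) and plain Python lists in place of JAX arrays.

```python
import dataclasses
from typing import Any, Callable, Tuple


@dataclasses.dataclass(frozen=True)
class Box:
  """Hypothetical box pairing a value with per-axis names."""
  value: Any
  names: Tuple[Any, ...]


def with_names(init_fn: Callable[..., Any], names: Tuple[Any, ...]) -> Callable[..., Box]:
  """Hypothetical helper: wraps init_fn so that its result is returned boxed."""
  def wrapper(*args, **kwargs):
    return Box(init_fn(*args, **kwargs), names)
  return wrapper


def zeros_init(key: Any, shape: Tuple[int, int]) -> Any:
  """Toy initializer with a (key, shape) calling convention; ignores the key."""
  del key
  rows, cols = shape
  return [[0.0] * cols for _ in range(rows)]


# The wrapped initializer keeps the calling convention of the original one,
# so a Module that accepts an initializer attribute needs no changes.
boxed_init = with_names(zeros_init, (None, "data"))
kernel = boxed_init(None, (4, 8))
```

Because the wrapper forwards all arguments unchanged, the metadata is attached entirely at the call site that configures the Module, which is the decoupling described above.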

Initializing a model that creates partitioned weights would result in the following variable structure:

```python
variables = partitioned_dense.init(rng, jnp.ones((4,)))
jax.tree_map(np.shape, variables)  # => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), "bias": (8,)}}
```

The variable tree with metadata can be used to integrate with other libraries and APIs.
For example, we can turn the ``Partitioned`` metadata into ``pjit`` sharding annotations:

```python
def to_sharding_spec(x):
  if isinstance(x, Partitioned):
    return PartitionSpec(*x.names)
  else:
    # fully replicated
    return PartitionSpec()

# Result: {"params": {"kernel": PartitionSpec(None, "data"), "bias": PartitionSpec()}}
variables_pspec = jax.tree_map(to_sharding_spec, variables, is_leaf=lambda x: isinstance(x, Partitioned))
```

### Unbox syntax
