From 9b7c0f1ce8338015fbaab5d59e580d4c0fdec0b0 Mon Sep 17 00:00:00 2001 From: jheek Date: Tue, 23 Aug 2022 13:13:30 +0200 Subject: [PATCH 1/3] General metadata FLIP --- docs/flip/2434-general-metadata.md | 132 +++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 docs/flip/2434-general-metadata.md diff --git a/docs/flip/2434-general-metadata.md b/docs/flip/2434-general-metadata.md new file mode 100644 index 00000000..f5840c75 --- /dev/null +++ b/docs/flip/2434-general-metadata.md @@ -0,0 +1,132 @@ +# FLIP: Axis Metadata + + +- Start Date: 2022-08-08 +- FLIP Issue: [#2434](https://github.com/google/flax/issues/2434) +- FLIP PR: [#2435](https://github.com/google/flax/pull/2435) +- Status: Proposal + + +## Summary + +This FLIP proposes to extend Flax's variable collections with a generic axis metadata API. +The core of the API is a abstract base class that is recognized by lifting transformations that can add an axis (vmap, scan). +Users can extend the base class to keep track of per-axis metadata in a way that works with lifted transformations. + + +## Motivation + +Generally, there is no way in Flax to track metadata for variables across lifted transformations. Axis metadata is used to keep track of semantic information about axes into other (Flax independent) APIs. +For example, optimizers like AdaFactor can be configured on a per-axis level and partitioning APIs in JAX like xmap or pjit require per variable annotations to map effectiently to parallel hardware. + +Currently, there is experimental support for partitioning annotations which requires using dedicated wrapper around lifted transforms that change axes (``nn.scan``, ``nn.vmap``) and a special APIs to create variables (``param_with_axes`` and ``variable_with_axes``). +The experimental partitioning API stores the metadata in a seperate collection named "[collection]_axes". + +The experimental API has a number of shortcomings that we like to solve: +1. 
The current API works for tracking PartitionSpecs but not for other types of metadata like optimizer annotations. +2. The implementation using an "xxx_axes" collection requires error-prone and non-composable string manipulation. +3. Special, partioning-aware variable creators and lifted transforms are required +4. The partioning API is hard to use with pre-existing Modules that aren't partioning aware. + + +## Proposal + +To generalize metadata tracking and keep the specific metadata out of core Flax we propose the following abstract base class: + +```python +class AxisMetadata(metaclass=abc.ABCMeta): + + @abc.abstractmethod + def unbox(self) -> Any: + pass + + @abc.abstractmethod + def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: + pass + + @abc.abstractmethod + def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: + pass +``` + +We call this type of class wrapping a value and keeping track of some additional data a **box**. +By defining an abstract base class for this box, the API does not need to be aware of the specifics of the metadata that is tracked. +This should make the API future proof and modular. + +The ``add_axis`` and ``remove_axis`` callback return an instance of their own type instead of mutating in-place. +Typically, an implementation would be a ``flax.struct.PyTreeNode`` because the box should still be a valid JAX value and must therefore be handled by the PyTree API. +Calling ``jax.tree_map`` on a boxed value will simply map over the value in the box. +The lifted transforms that need to handle metadata will call ``jax.tree_map(..., is_leaf=lambda x: isinstance(x, AxisMetadata))`` to find the AxisMetadata instances within a PyTree. + +Advantages of the boxing approach: +1. Boxing can be used outside of Flax and metadata is automatically "inherited". 
For example, the compiler state will + have the same partitioning spec as the parameters, because the state is initialized using a ``jax.tree_map`` over the boxed parameters. +2. Boxes are composable. +3. Boxing avoids string manipulation and generally avoids having to handle additional auxilary collections like "param_axes" in the current + partitioning API. +4. No need to lift metadata collections seperately. + + +Disadvantages: +1. Handling boxed values requires the relatively new ``is_leaf=`` syntax which users might not be familiar with. Although users will + probably call Flax provided utils that handle teh low-level tree_map calls in most cases. +3. Custom Pytree nodes have a small runtime overhead. It's hard to observe this in practise because JAX calls are async. + + +### Init syntax + + +Boxes can be created directly by the init function of a variable. Therefore, we propose to create metadata using higher-order initializers. +The main advantage of this is that we can decouple metadata handling completly from the Module definition. Also, most Modules already over +attributes to override the default initialzers so users can add metadata to existing Modules without requiring any code changes. + +Too illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``: + +```python +class Partitioned(flax.struct.PyTreeNode, AxisMetadata): + value: Any + named_axes: Tuple[Optional[str]] + + ... 
+ +def with_partitioning(init_fn, names): + def wrapper(*args, **kwargs): + return Partitioned(init_fn(*args, **kwargs), names) + return wrapper +``` + +Here we also defined a small utility called ``with_partitioning`` that we can use to wrap existing initialzers to add metadata: + + +```python +# init kernel with lecun normal and split the output features over the data axis +partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, [None, "data"])) +``` + + +### Unbox syntax + + +Metadata typically doesn't need to be handled by Modules directly. Therefore, we prosose to make Modules agnostic to Metadata boxes by default. +The ``unbox`` method can be used to unpack a variable such that only the original JAX arrays remain. Users can manually call unbox but to make +sure Module classes don't have to call it everywhere we add an unbox keyword arg to variable returning APIs (e.g.: ``.param``, ``.variable``, ``.get_variable``). +The keyword arg ``unbox`` will default to ``True`` such that a Modules are metadata agnostic by default. This also means existing Modules will be backward compatible +with the new API. + +```python +kernel = self.param("kernel", self.kernel_init, shape) # No AxisMetadata instances +kernel_box = self.get_variable("param", "kernel", unbox=False) # AxisMetadata boxes are preserved +``` + + +### Lift syntax + +When calling a lifted transformation that adds an axis you will now be able to pass a dictionary with arguments. +These params will be passed to ``AxisMetadata`` add_axis/remove_axis callbacks: + +```python +nn.scan(..., variable_axes={"params": 0}, metadata_params={nn.Partitioned.AXIS_NAME: "layers"}) +``` + +A dict is used such that users can add their own arguments to custom AxisMetadata classes. 
+ From a0267efc2787f4c481d7106ffd9816f277e0f4bb Mon Sep 17 00:00:00 2001 From: jheek Date: Fri, 16 Sep 2022 15:09:56 +0200 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Marc van Zee --- docs/flip/2434-general-metadata.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/flip/2434-general-metadata.md b/docs/flip/2434-general-metadata.md index f5840c75..e53649c8 100644 --- a/docs/flip/2434-general-metadata.md +++ b/docs/flip/2434-general-metadata.md @@ -10,7 +10,7 @@ ## Summary This FLIP proposes to extend Flax's variable collections with a generic axis metadata API. -The core of the API is a abstract base class that is recognized by lifting transformations that can add an axis (vmap, scan). +The core of the API is an abstract base class that is recognized by lifting transformations that can add an axis (vmap, scan). Users can extend the base class to keep track of per-axis metadata in a way that works with lifted transformations. @@ -69,7 +69,7 @@ Advantages of the boxing approach: Disadvantages: 1. Handling boxed values requires the relatively new ``is_leaf=`` syntax which users might not be familiar with. Although users will - probably call Flax provided utils that handle teh low-level tree_map calls in most cases. + probably call Flax provided utils that handle the low-level tree_map calls in most cases. 3. Custom Pytree nodes have a small runtime overhead. It's hard to observe this in practise because JAX calls are async. @@ -77,10 +77,10 @@ Disadvantages: Boxes can be created directly by the init function of a variable. Therefore, we propose to create metadata using higher-order initializers. -The main advantage of this is that we can decouple metadata handling completly from the Module definition. Also, most Modules already over +The main advantage of this is that we can decouple metadata handling compltely from the Module definition. 
Also, most Modules already overwrite attributes to override the default initialzers so users can add metadata to existing Modules without requiring any code changes. -Too illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``: +To illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``: ```python class Partitioned(flax.struct.PyTreeNode, AxisMetadata): @@ -100,14 +100,14 @@ Here we also defined a small utility called ``with_partitioning`` that we can us ```python # init kernel with lecun normal and split the output features over the data axis -partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, [None, "data"])) +partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, (None, "data"))) ``` ### Unbox syntax -Metadata typically doesn't need to be handled by Modules directly. Therefore, we prosose to make Modules agnostic to Metadata boxes by default. +Metadata typically doesn't need to be handled by Modules directly. Therefore, we propose to make Modules agnostic to Metadata boxes by default. The ``unbox`` method can be used to unpack a variable such that only the original JAX arrays remain. Users can manually call unbox but to make sure Module classes don't have to call it everywhere we add an unbox keyword arg to variable returning APIs (e.g.: ``.param``, ``.variable``, ``.get_variable``). The keyword arg ``unbox`` will default to ``True`` such that a Modules are metadata agnostic by default. 
This also means existing Modules will be backward compatible From f17e89e6ae49d63a2e79087d8ac1610f6072c3f4 Mon Sep 17 00:00:00 2001 From: jheek Date: Thu, 22 Sep 2022 09:50:53 +0200 Subject: [PATCH 3/3] add AxisMetadata docstring --- docs/flip/2434-general-metadata.md | 124 ++++++++++++++++++++++++++--- 1 file changed, 111 insertions(+), 13 deletions(-) diff --git a/docs/flip/2434-general-metadata.md b/docs/flip/2434-general-metadata.md index e53649c8..d432eaa3 100644 --- a/docs/flip/2434-general-metadata.md +++ b/docs/flip/2434-general-metadata.md @@ -16,11 +16,16 @@ Users can extend the base class to keep track of per-axis metadata in a way that ## Motivation -Generally, there is no way in Flax to track metadata for variables across lifted transformations. Axis metadata is used to keep track of semantic information about axes into other (Flax independent) APIs. -For example, optimizers like AdaFactor can be configured on a per-axis level and partitioning APIs in JAX like xmap or pjit require per variable annotations to map effectiently to parallel hardware. +Generally, there is no way in Flax to track metadata for variables across lifted transformations. +Axis metadata is used to keep track of semantic information about axes into other (Flax independent) APIs. +For example, optimizers like AdaFactor can be configured on a per-axis level and partitioning APIs +in JAX like xmap or pjit require per variable annotations to map efficiently to parallel hardware. + +Currently, there is an experimental [API](https://github.com/google/flax/blob/main/flax/linen/partitioning.py) +supporting partitioning annotations with wrappers around lifted transforms that change axes (``nn.scan_with_axes``, ``nn.vmap_with_axes``) +and special APIs to create variables (``param_with_axes`` and ``variable_with_axes``). +The experimental partitioning API stores the metadata in a separate collection named "[collection]_axes". 
-Currently, there is experimental support for partitioning annotations which requires using dedicated wrapper around lifted transforms that change axes (``nn.scan``, ``nn.vmap``) and a special APIs to create variables (``param_with_axes`` and ``variable_with_axes``). -The experimental partitioning API stores the metadata in a seperate collection named "[collection]_axes". The experimental API has a number of shortcomings that we like to solve: 1. The current API works for tracking PartitionSpecs but not for other types of metadata like optimizer annotations. @@ -34,18 +39,81 @@ The experimental API has a number of shortcomings that we like to solve: To generalize metadata tracking and keep the specific metadata out of core Flax we propose the following abstract base class: ```python +TAxisMetadata = TypeVar("TAxisMetadata", bound="AxisMetadata") + class AxisMetadata(metaclass=abc.ABCMeta): + """Abstract base class for boxed Metadata. + + ``AxisMetadata`` enables arbitrary, per axis metadata for variables. + By using ``unbox`` the metadata is stripped away to obtain the original + variables. By using unboxing, most code handling variables does not need + to handle ``AxisMetadata`` specifically, but can directly operate on the JAX + arrays that they wrap. + + Additionally, ``AxisMetadata`` supports updating metadata whenever an axis + is added or removed by a functional transformation + (e.g.: ``nn.scan`` or ``nn.vmap``) using the ``add_axis`` and ``remove_axis`` + methods. + + By extending ``AxisMetadata``, custom metadata can be stored. See + ``Partitioned`` for a specific implementation. + """ @abc.abstractmethod def unbox(self) -> Any: + """Returns the content of the AxisMetadata box. + + Note that unlike ``meta.unbox`` the unbox call should not recursively unbox + metadata. It should simply return the value that it wraps directly even + if that value itself is an instance of AxisMetadata. 
+ + In practice, AxisMetadata subclasses should be registered as PyTree nodes to + support passing instances to JAX and Flax APIs. The leaves returned for this + node should correspond to the value returned by unbox. + + Returns: + The unboxed value. + """ pass @abc.abstractmethod - def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: + def add_axis(self: TAxisMetadata, index: int, + params: Dict[Any, Any]) -> TAxisMetadata: + """Adds a new axis to the axis metadata. + + Note that add_axis and remove_axis should act as each other's inverse + (meaning: ``x.add_axis(i, p).remove_axis(i, p) == x``) + + Args: + index: The position at which the new axis will be inserted + params: An arbitrary dictionary of parameters passed by the transformation + that introduces the new axis (e.g.: ``nn.scan`` or ``nn.vmap``). The + user passes this dictionary as the ``metadata_params`` argument to the + transformation. + Returns: + A new instance of the same type as self and with the same ``unbox`` + content with updated axis metadata. + """ pass @abc.abstractmethod - def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: + def remove_axis(self: TAxisMetadata, index: int, + params: Dict[Any, Any]) -> TAxisMetadata: + """Removes an axis from the axis metadata. + + Note that add_axis and remove_axis should act as each other's inverse + (meaning: ``x.remove_axis(i, p).add_axis(i, p) == x``) + + Args: + index: The position of the axis that is to be removed + params: An arbitrary dictionary of parameters passed by the transformation + that introduced the axis (e.g.: ``nn.scan`` or ``nn.vmap``). The + user passes this dictionary as the ``metadata_params`` argument to the + transformation. + Returns: + A new instance of the same type as self and with the same ``unbox`` + content with updated axis metadata. 
+ """ pass ``` @@ -53,13 +121,13 @@ We call this type of class wrapping a value and keeping track of some additional By defining an abstract base class for this box, the API does not need to be aware of the specifics of the metadata that is tracked. This should make the API future proof and modular. -The ``add_axis`` and ``remove_axis`` callback return an instance of their own type instead of mutating in-place. +The ``add_axis`` and ``remove_axis`` methods return an instance of their own type instead of mutating in-place. Typically, an implementation would be a ``flax.struct.PyTreeNode`` because the box should still be a valid JAX value and must therefore be handled by the PyTree API. Calling ``jax.tree_map`` on a boxed value will simply map over the value in the box. The lifted transforms that need to handle metadata will call ``jax.tree_map(..., is_leaf=lambda x: isinstance(x, AxisMetadata))`` to find the AxisMetadata instances within a PyTree. Advantages of the boxing approach: -1. Boxing can be used outside of Flax and metadata is automatically "inherited". For example, the compiler state will +1. Boxing can be used outside of Flax and metadata is automatically "inherited". For example, the optimizer state will have the same partitioning spec as the parameters, because the state is initialized using a ``jax.tree_map`` over the boxed parameters. 2. Boxes are composable. 3. Boxing avoids string manipulation and generally avoids having to handle additional auxilary collections like "param_axes" in the current @@ -68,8 +136,7 @@ Advantages of the boxing approach: Disadvantages: -1. Handling boxed values requires the relatively new ``is_leaf=`` syntax which users might not be familiar with. Although users will - probably call Flax provided utils that handle the low-level tree_map calls in most cases. +1. Adding the boxes changes the PyTree hierarchy and introduces dataclasses within the otherwise plain, nested dict of variables. 3. 
Custom Pytree nodes have a small runtime overhead. It's hard to observe this in practise because JAX calls are async. @@ -77,7 +144,7 @@ Disadvantages: Boxes can be created directly by the init function of a variable. Therefore, we propose to create metadata using higher-order initializers. -The main advantage of this is that we can decouple metadata handling compltely from the Module definition. Also, most Modules already overwrite +The main advantage of this is that we can decouple metadata handling completely from the Module definition. Also, most Modules already offer attributes to override the default initialzers so users can add metadata to existing Modules without requiring any code changes. To illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``: @@ -85,9 +152,19 @@ To illustrate this, let's consider a metadata class that keeps track of Partitio ```python class Partitioned(flax.struct.PyTreeNode, AxisMetadata): value: Any - named_axes: Tuple[Optional[str]] + names: Tuple[Optional[str], ...] = flax.struct.field(pytree_node=False) - ... 
+ def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: + axis_name = self._get_partition_name(params) + names = list(self.names) + names.insert(index, axis_name) + return self.replace(names=tuple(names)) + + def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: + axis_name = self._get_partition_name(params) + names = list(self.names) + assert names.pop(index) == axis_name + return self.replace(names=tuple(names)) def with_partitioning(init_fn, names): def wrapper(*args, **kwargs): @@ -103,6 +180,27 @@ Here we also defined a small utility called ``with_partitioning`` that we can us partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, (None, "data"))) ``` +Initializing a model that creates partitioned weights would result in the following variable structure: + +```python +variables = partitioned_dense.init(rng, jnp.ones((4,))) +jax.tree_map(np.shape, variables) # => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), "bias": (8,)}} +``` + +The variable tree with metadata can be used to integrate with other libraries and APIs. +For example, we can turn the ``Partitioned`` metadata into ``jax.pjit`` sharding annotations: + +```python +def to_sharding_spec(x): + if isinstance(x, Partitioned): + return PartitionSpec(*x.names) + else: + # fully replicated + return PartitionSpec() + +# Result: {"params": {"kernel": PartitionSpec(None, "data"), "bias": PartitionSpec()}} +variables_pspec = jax.tree_map(to_sharding_spec, variables, is_leaf=lambda x: isinstance(x, Partitioned)) +``` ### Unbox syntax