Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Marc van Zee <[email protected]>
  • Loading branch information
jheek and marcvanzee authored Sep 16, 2022
1 parent 9b7c0f1 commit a0267ef
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions docs/flip/2434-general-metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
## Summary

This FLIP proposes to extend Flax's variable collections with a generic axis metadata API.
The core of the API is a abstract base class that is recognized by lifting transformations that can add an axis (vmap, scan).
The core of the API is an abstract base class that is recognized by lifting transformations that can add an axis (vmap, scan).
Users can extend the base class to keep track of per-axis metadata in a way that works with lifted transformations.


Expand Down Expand Up @@ -69,18 +69,18 @@ Advantages of the boxing approach:

Disadvantages:
1. Handling boxed values requires the relatively new ``is_leaf=`` syntax which users might not be familiar with. Although users will
probably call Flax provided utils that handle teh low-level tree_map calls in most cases.
probably call Flax provided utils that handle the low-level tree_map calls in most cases.
3. Custom Pytree nodes have a small runtime overhead. It's hard to observe this in practise because JAX calls are async.


### Init syntax


Boxes can be created directly by the init function of a variable. Therefore, we propose to create metadata using higher-order initializers.
The main advantage of this is that we can decouple metadata handling completly from the Module definition. Also, most Modules already over
The main advantage of this is that we can decouple metadata handling compltely from the Module definition. Also, most Modules already overwrite
attributes to override the default initialzers so users can add metadata to existing Modules without requiring any code changes.

Too illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``:
To illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``:

```python
class Partitioned(flax.struct.PyTreeNode, AxisMetadata):
Expand All @@ -100,14 +100,14 @@ Here we also defined a small utility called ``with_partitioning`` that we can us

```python
# init kernel with lecun normal and split the output features over the data axis
partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, [None, "data"]))
partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, (None, "data")))
```


### Unbox syntax


Metadata typically doesn't need to be handled by Modules directly. Therefore, we prosose to make Modules agnostic to Metadata boxes by default.
Metadata typically doesn't need to be handled by Modules directly. Therefore, we propose to make Modules agnostic to Metadata boxes by default.
The ``unbox`` method can be used to unpack a variable such that only the original JAX arrays remain. Users can manually call unbox but to make
sure Module classes don't have to call it everywhere we add an unbox keyword arg to variable returning APIs (e.g.: ``.param``, ``.variable``, ``.get_variable``).
The keyword arg ``unbox`` will default to ``True`` such that a Modules are metadata agnostic by default. This also means existing Modules will be backward compatible
Expand Down

0 comments on commit a0267ef

Please sign in to comment.