From 7b50ffec23f10b71532220937481149f9b65b258 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 27 Nov 2024 16:28:56 +0000 Subject: [PATCH] [nnx] optimize Variable --- docs_nnx/nnx_basics.ipynb | 116 ++++++++++-- flax/nnx/graph.py | 118 ++++++------ flax/nnx/reprlib.py | 16 +- flax/nnx/transforms/iteration.py | 15 +- flax/nnx/variablelib.py | 298 ++++++++++++------------------- tests/nnx/containers_test.py | 13 +- tests/nnx/graph_utils_test.py | 2 +- tests/nnx/module_test.py | 8 +- tests/nnx/spmd_test.py | 27 +-- uv.lock | 2 +- 10 files changed, 335 insertions(+), 280 deletions(-) diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index 30c2dcb4..351ae8b6 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "tags": [ "skip-execution" @@ -88,7 +88,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -180,7 +192,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -235,7 +259,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -359,7 +395,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -418,7 +466,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -467,7 +527,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -479,7 +551,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -580,7 +652,31 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -592,7 +688,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index fec21add..2339f5c1 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -290,6 +290,24 @@ def __treescope_repr__(self, path, subtree_renderer): jax.tree_util.register_static(VariableDef) +@dataclasses.dataclass(frozen=True, slots=True) +class SubGraphAttribute: + key: Key + value: NodeDef[tp.Any] | NodeRef[tp.Any] + + +@dataclasses.dataclass(frozen=True, slots=True) +class StaticAttribute: + key: Key + value: tp.Any + + +@dataclasses.dataclass(frozen=True, slots=True) +class LeafAttribute: + key: Key + value: VariableDef | NodeRef[tp.Any] + + @dataclasses.dataclass(frozen=True, repr=False, slots=True) class NodeDef(GraphDef[Node], reprlib.Representable): """A dataclass that denotes the tree structure of a @@ -298,10 +316,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable): type: tp.Type[Node] index: int - attributes: tuple[Key, ...] - subgraphs: HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]] - static_fields: HashableMapping[Key, tp.Any] - leaves: HashableMapping[Key, VariableDef | NodeRef[tp.Any]] + attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...] metadata: tp.Any index_mapping: HashableMapping[Index, Index] | None @@ -310,10 +325,7 @@ def create( cls, type: tp.Type[Node], index: int, - attributes: tuple[Key, ...], - subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]], - static_fields: tp.Iterable[tuple[Key, tp.Any]], - leaves: tp.Iterable[tuple[Key, VariableDef | NodeRef[tp.Any]]], + attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...], metadata: tp.Any, index_mapping: tp.Mapping[Index, Index] | None, ): @@ -321,9 +333,6 @@ def create( type=type, index=index, attributes=attributes, - subgraphs=HashableMapping(subgraphs), - static_fields=HashableMapping(static_fields), - leaves=HashableMapping(leaves), metadata=metadata, index_mapping=HashableMapping(index_mapping) if index_mapping is not None @@ -335,12 +344,7 @@ def __nnx_repr__(self): yield reprlib.Attr('type', self.type.__name__) yield reprlib.Attr('index', self.index) - yield reprlib.Attr('attributes', self.attributes) - yield reprlib.Attr('subgraphs', reprlib.PrettyMapping(self.subgraphs)) - yield reprlib.Attr( - 'static_fields', reprlib.PrettyMapping(self.static_fields) - ) - yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves)) + yield reprlib.Attr('attributes', reprlib.PrettySequence(self.attributes)) yield reprlib.Attr('metadata', self.metadata) yield reprlib.Attr( 'index_mapping', @@ -352,18 +356,15 @@ def __nnx_repr__(self): def __treescope_repr__(self, path, subtree_renderer): import treescope # type: ignore[import-not-found,import-untyped] return treescope.repr_lib.render_object_constructor( - object_type=type(self), - attributes={ - 'type': self.type, - 'index': self.index, - 'attributes': self.attributes, - 'subgraphs': dict(self.subgraphs), - 'static_fields': dict(self.static_fields), - 'leaves': dict(self.leaves), - 'metadata': self.metadata, - }, - path=path, - subtree_renderer=subtree_renderer, + object_type=type(self), + attributes={ + 'type': self.type, + 'index': self.index, + 'attributes': self.attributes, + 'metadata': self.metadata, + }, + path=path, + subtree_renderer=subtree_renderer, ) def apply( @@ -426,40 +427,39 @@ def _graph_flatten( else: index = -1 - subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = [] - static_fields: list[tuple[Key, tp.Any]] = [] - leaves: list[tuple[Key, VariableDef | NodeRef]] = [] + attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = [] values, metadata = node_impl.flatten(node) for key, value in values: if is_node(value): nodedef = _graph_flatten((*path, key), ref_index, flat_state, value) - subgraphs.append((key, nodedef)) + # subgraphs.append((key, nodedef)) + attributes.append(SubGraphAttribute(key, nodedef)) elif isinstance(value, Variable): if value in ref_index: - leaves.append((key, NodeRef(type(value), ref_index[value]))) + attributes.append( + LeafAttribute(key, NodeRef(type(value), ref_index[value])) + ) else: flat_state[(*path, key)] = value.to_state() variable_index = ref_index[value] = len(ref_index) variabledef = VariableDef( type(value), variable_index, HashableMapping(value.get_metadata()) ) - leaves.append((key, variabledef)) + attributes.append(LeafAttribute(key, variabledef)) else: if isinstance(value, (jax.Array, np.ndarray)): path_str = '/'.join(map(str, (*path, key))) raise ValueError( f'Arrays leaves are not supported, at {path_str!r}: {value}' ) - static_fields.append((key, value)) + # static_fields.append((key, value)) + attributes.append(StaticAttribute(key, value)) nodedef = NodeDef.create( type=node_impl.type, index=index, - attributes=tuple(key for key, _ in values), - subgraphs=subgraphs, - static_fields=static_fields, - leaves=leaves, + attributes=tuple(attributes), metadata=metadata, index_mapping=None, ) @@ -529,22 +529,20 @@ def _graph_unflatten( def _get_children(): children: dict[Key, NodeLeaf | Node] = {} - - # NOTE: we could allw adding new StateLeafs here - if unkown_keys := set(state) - set(nodedef.attributes): - raise ValueError(f'Unknown keys: {unkown_keys}') + state_keys: set = set(state.keys()) # for every key in attributes there are 6 possible cases: # - (2) the key can either be present in the state or not # - (3) the key can be a subgraph, a leaf, or a static attribute - for key in nodedef.attributes: + for attribute in nodedef.attributes: + key = attribute.key if key not in state: # if key is not present create an empty types - if key in nodedef.static_fields: - children[key] = nodedef.static_fields[key] - elif key in nodedef.subgraphs: + if type(attribute) is StaticAttribute: + children[key] = attribute.value + elif type(attribute) is SubGraphAttribute: # if the key is a subgraph we create an empty node - subgraphdef = nodedef.subgraphs[key] + subgraphdef = attribute.value assert not isinstance(subgraphdef, VariableDef) if isinstance(subgraphdef, NodeRef): # subgraph exists, take it from the cache @@ -558,8 +556,8 @@ def _get_children(): children[key] = _graph_unflatten( subgraphdef, substate, index_ref, index_ref_cache ) - elif key in nodedef.leaves: - variabledef = nodedef.leaves[key] + elif type(attribute) is LeafAttribute: + variabledef = attribute.value if variabledef.index in index_ref: # variable exists, take it from the cache children[key] = index_ref[variabledef.index] @@ -572,19 +570,21 @@ def _get_children(): else: raise RuntimeError(f'Unknown static field: {key!r}') else: + state_keys.remove(key) value = state[key] - if key in nodedef.static_fields: + # if key in nodedef.static_fields: + if type(attribute) is StaticAttribute: raise ValueError( f'Got state for static field {key!r}, this is not supported.' ) - if key in nodedef.subgraphs: + elif type(attribute) is SubGraphAttribute: if is_state_leaf(value): raise ValueError( - f'Expected value of type {nodedef.subgraphs[key]} for ' + f'Expected value of type {attribute.value} for ' f'{key!r}, but got {value!r}' ) assert isinstance(value, dict) - subgraphdef = nodedef.subgraphs[key] + subgraphdef = attribute.value if isinstance(subgraphdef, NodeRef): children[key] = index_ref[subgraphdef.index] @@ -593,8 +593,8 @@ def _get_children(): subgraphdef, value, index_ref, index_ref_cache ) - elif key in nodedef.leaves: - variabledef = nodedef.leaves[key] + elif type(attribute) is LeafAttribute: + variabledef = attribute.value if variabledef.index in index_ref: # add an existing variable @@ -631,6 +631,10 @@ def _get_children(): else: raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') + # NOTE: we could allw adding new StateLeafs here + if state_keys: + raise ValueError(f'Unknown keys: {state_keys}') + return children if isinstance(node_impl, GraphNodeImpl): diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py index 855a3049..6ed7660c 100644 --- a/flax/nnx/reprlib.py +++ b/flax/nnx/reprlib.py @@ -16,7 +16,6 @@ import dataclasses import threading import typing as tp -from abc import ABC, abstractmethod A = tp.TypeVar('A') B = tp.TypeVar('B') @@ -48,10 +47,9 @@ class Attr: end: str = '' -class Representable(ABC): +class Representable: __slots__ = () - @abstractmethod def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]: raise NotImplementedError @@ -121,4 +119,14 @@ def __nnx_repr__(self): yield Object(type='', value_sep=': ', start='{', end='}') for key, value in self.mapping.items(): - yield Attr(repr(key), value) \ No newline at end of file + yield Attr(repr(key), value) + +@dataclasses.dataclass(repr=False) +class PrettySequence(Representable): + list: tp.Sequence + + def __nnx_repr__(self): + yield Object(type='', value_sep='', start='[', end=']') + + for value in self.list: + yield Attr('', value) \ No newline at end of file diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index c9a3c1c4..994e5828 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -1341,11 +1341,16 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef): global_index_mapping[nd.index] = nd.index if isinstance(nd, graph.NodeRef): return - for sub_nd in nd.subgraphs.values(): - per_node_def(sub_nd) - for l in nd.leaves.values(): - if isinstance(l, (graph.VariableDef, graph.NodeRef)) and l.index >= 0: - global_index_mapping[l.index] = l.index + + for attribute in nd.attributes: + if type(attribute) is graph.SubGraphAttribute: + per_node_def(attribute.value) + elif ( + type(attribute) is graph.LeafAttribute + and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef)) + and attribute.value.index >= 0 + ): + global_index_mapping[attribute.value.index] = attribute.value.index return per_node_def(ns._graphdef) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 91d6c861..7af20cdb 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -120,136 +120,86 @@ class Variable(tp.Generic[A], reprlib.Representable): """ raw_value: A - set_value_hooks: tuple[SetValueHook[A], ...] - get_value_hooks: tuple[GetValueHook[A], ...] - create_value_hooks: tuple[CreateValueHook[A], ...] - add_axis_hooks: tuple[AddAxisHook[Variable[A]], ...] - remove_axis_hooks: tuple[RemoveAxisHook[Variable[A]], ...] _trace_state: tracers.TraceState + _var_metadata: dict[str, tp.Any] def __init__( self, value: tp.Union[A, VariableMetadata[A]], - *, - set_value_hooks: tp.Union[ - SetValueHook[A], tp.Sequence[SetValueHook[A]] - ] = (), - get_value_hooks: tp.Union[ - GetValueHook[A], tp.Sequence[GetValueHook[A]] - ] = (), - create_value_hooks: tp.Union[ - CreateValueHook[A], tp.Sequence[CreateValueHook[A]] - ] = (), - add_axis_hooks: tp.Union[ - AddAxisHook[Variable[A]], tp.Sequence[AddAxisHook[Variable[A]]] - ] = (), - remove_axis_hooks: tp.Union[ - RemoveAxisHook[Variable[A]], - tp.Sequence[RemoveAxisHook[Variable[A]]], - ] = (), **metadata: tp.Any, ): - vars(self)['_trace_state'] = tracers.TraceState() - if callable(set_value_hooks): - set_value_hooks = (set_value_hooks,) - else: - set_value_hooks = tuple(set_value_hooks) - - if callable(get_value_hooks): - get_value_hooks = (get_value_hooks,) - else: - get_value_hooks = tuple(get_value_hooks) - - if callable(create_value_hooks): - create_value_hooks = (create_value_hooks,) - else: - create_value_hooks = tuple(create_value_hooks) - - if callable(add_axis_hooks): - add_axis_hooks = (add_axis_hooks,) - else: - add_axis_hooks = tuple(add_axis_hooks) - - if callable(remove_axis_hooks): - remove_axis_hooks = (remove_axis_hooks,) - else: - remove_axis_hooks = tuple(remove_axis_hooks) + type_vars = vars(type(self)) + vars_self = vars(self) + vars_self['_trace_state'] = tracers.TraceState() if isinstance(value, VariableMetadata): - value_metadata = dict(value.metadata) - if value.set_value_hooks: - set_value_hooks = set_value_hooks + value.set_value_hooks - if value.get_value_hooks: - get_value_hooks = get_value_hooks + value.get_value_hooks - if value.create_value_hooks: - create_value_hooks = create_value_hooks + value.create_value_hooks - if value.add_axis_hooks: - add_axis_hooks = add_axis_hooks + value.add_axis_hooks - if value.remove_axis_hooks: - remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks - - metadata.update(value_metadata) + metadata.update(value.metadata) value = tp.cast(A, value.raw_value) - self.raw_value = value + object.__setattr__(self, 'raw_value', value) - if 'on_get_value' in vars(type(self)): - on_get_value = getattr(type(self), 'on_get_value') - if on_get_value not in get_value_hooks: - get_value_hooks = (on_get_value, *get_value_hooks) + if 'on_get_value' in type_vars and 'on_get_value' not in metadata: + metadata['get_value'] = getattr(type(self), 'on_get_value') - if 'on_set_value' in vars(type(self)): - on_set_value = getattr(type(self), 'on_set_value') - if on_set_value not in set_value_hooks: - set_value_hooks = (on_set_value, *set_value_hooks) + if 'on_set_value' in type_vars and 'on_set_value' not in metadata: + metadata['set_value'] = getattr(type(self), 'on_set_value') - if 'on_create_value' in vars(type(self)): - on_create_value = getattr(type(self), 'on_create_value') - if on_create_value not in create_value_hooks: - create_value_hooks = (on_create_value, *create_value_hooks) + if 'on_create_value' in type_vars and 'on_create_value' not in metadata: + metadata['create_value'] = getattr(type(self), 'on_create_value') - if 'on_add_axis' in vars(type(self)): - on_add_axis = getattr(type(self), 'on_add_axis') - if on_add_axis not in add_axis_hooks: - add_axis_hooks = (on_add_axis, *add_axis_hooks) + if 'on_add_axis' in type_vars and 'on_add_axis' not in metadata: + metadata['add_axis'] = getattr(type(self), 'on_add_axis') - if 'on_remove_axis' in vars(type(self)): - on_remove_axis = getattr(type(self), 'on_remove_axis') - if on_remove_axis not in remove_axis_hooks: - remove_axis_hooks = (on_remove_axis, *remove_axis_hooks) - - self.get_value_hooks = get_value_hooks - self.set_value_hooks = set_value_hooks - self.create_value_hooks = create_value_hooks - self.add_axis_hooks = add_axis_hooks - self.remove_axis_hooks = remove_axis_hooks - vars(self).update(metadata) + if 'on_remove_axis' in type_vars and 'on_remove_axis' not in metadata: + metadata['remove_axis'] = getattr(type(self), 'on_remove_axis') + vars_self['_var_metadata'] = metadata # run create_value hooks - self.raw_value = self.create_value(self.raw_value) + vars_self['raw_value'] = self.create_value(self.raw_value) + + def __getattr__(self, name: str) -> tp.Any: + if name in vars(self)['_var_metadata']: + return self._var_metadata[name] + return getattr(self.value, name) - if not tp.TYPE_CHECKING: + def __setattr__(self, name: str, value: tp.Any): + if not self._trace_state.is_valid(): + raise errors.TraceContextError( + f'Cannot mutate {type(self).__name__} from a different trace level' + ) - def __setattr__(self, name: str, value: Any) -> None: - return self._setattr(name, value) + if ( + name == 'value' + or name == 'raw_value' + or name == '_var_metadata' + or name == '_trace_state' + ): + object.__setattr__(self, name, value) + else: + self._var_metadata[name] = value - def _setattr(self, name: str, value: tp.Any): + def __delattr__(self, name: str): if not self._trace_state.is_valid(): raise errors.TraceContextError( f'Cannot mutate {type(self).__name__} from a different trace level' ) - object.__setattr__(self, name, value) + if ( + name == 'value' + or name == 'raw_value' + or name == '_var_metadata' + or name == '_trace_state' + ): + object.__delattr__(self, name) + else: + del self._var_metadata[name] @classmethod def state(cls, value: A, **metadata) -> VariableState[A]: return cls(value, **metadata).to_state() def get_metadata(self): - metadata = vars(self).copy() - del metadata['raw_value'] - del metadata['_trace_state'] - return metadata + return self._var_metadata def copy_from(self, other: Variable[A]) -> None: if type(self) is not type(other): @@ -259,29 +209,20 @@ def copy_from(self, other: Variable[A]) -> None: ) if self is other: return - trace_state = self._trace_state - vars_dict = vars(self) - other_vars = vars(other).copy() - del other_vars['_trace_state'] - vars_dict.clear() - vars_dict.update(other_vars, _trace_state=trace_state) + self.raw_value = other.raw_value + self._var_metadata.clear() + self._var_metadata.update(other.get_metadata()) def update_from_state(self, variable_state: VariableState[A]): - trace_state = self._trace_state - variable_vars = vars(self) - variable_vars.clear() - variable_vars.update( - variable_state.get_metadata(), - raw_value=variable_state.value, - _trace_state=trace_state, - ) + vars_self = vars(self) + vars_self['raw_value'] = variable_state.value + vars_self['_var_metadata'] = variable_state.get_metadata().copy() @property def value(self) -> A: value = self.raw_value - if self.get_value_hooks: - for hook in self.get_value_hooks: - value = hook(self, value) + if 'on_get_value' in self._var_metadata: + value = self._var_metadata['on_get_value'](self, value) return value @value.setter @@ -290,23 +231,22 @@ def value(self, value: A): raise ValueError( 'Cannot set value to a Variable, ' 'use `copy_from` method instead' ) - if self.set_value_hooks: - for hook in self.set_value_hooks: - value = hook(self, value) - self.raw_value = value + if 'on_set_value' in self._var_metadata: + value = self._var_metadata['on_set_value'](self, value) + vars(self)['raw_value'] = value def create_value(self, value: A): - for hook in self.create_value_hooks: - value = hook(self, value) + if 'on_create_value' in self._var_metadata: + value = self._var_metadata['on_create_value'](self, value) return value def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - for hook in self.add_axis_hooks: - hook(self, axis_index, axis_name) + if 'on_add_axis' in self._var_metadata: + self._var_metadata['on_add_axis'](self, axis_index, axis_name) def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - for hook in self.remove_axis_hooks: - hook(self, axis_index, axis_name) + if 'on_remove_axis' in self._var_metadata: + self._var_metadata['on_remove_axis'](self, axis_index, axis_name) def __eq__(self, other: object) -> bool: return type(self) is type(other) and vars(other) == vars(self) @@ -344,26 +284,27 @@ def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]: return value # get and update attributes - attributes = vars(self).copy() - attributes.update(**kwargs) # return new instance with updated attributes obj = object.__new__(type(self)) - vars(obj).update(attributes) + object.__setattr__(obj, '_trace_state', self._trace_state) + object.__setattr__(obj, 'raw_value', kwargs.pop('raw_value')) + object.__setattr__(obj, '_var_metadata', self.get_metadata()) + obj._var_metadata.update(kwargs) return obj @classmethod def from_metadata(cls, value: A, attributes: tp.Mapping[str, tp.Any]): obj = object.__new__(cls) - vars(obj).update( - attributes, raw_value=value, _trace_state=tracers.TraceState() - ) + object.__setattr__(obj, '_trace_state', tracers.TraceState()) + object.__setattr__(obj, 'raw_value', value) + object.__setattr__(obj, '_var_metadata', attributes) return obj def copy(self: Variable[A]) -> Variable[A]: obj = object.__new__(type(self)) - attributes = vars(self).copy() - attributes['_trace_state'] = tracers.TraceState() - vars(obj).update(attributes) + object.__setattr__(obj, '_trace_state', self._trace_state) + object.__setattr__(obj, 'raw_value', self.raw_value) + object.__setattr__(obj, '_var_metadata', self.get_metadata().copy()) return obj def to_state(self: Variable[A]) -> VariableState[A]: @@ -372,23 +313,14 @@ def to_state(self: Variable[A]) -> VariableState[A]: def __nnx_repr__(self): yield reprlib.Object(type=type(self)) - for name, value in vars(self).items(): - if name == 'raw_value': - name = 'value' - if name.endswith('_hooks') or name == '_trace_state': - continue + yield reprlib.Attr('value', self.raw_value) + for name, value in self._var_metadata.items(): yield reprlib.Attr(name, repr(value)) def __treescope_repr__(self, path, subtree_renderer): import treescope # type: ignore[import-not-found,import-untyped] - children = {} - for name, value in vars(self).items(): - if name == 'raw_value': - name = 'value' - if name.endswith('_hooks') or name == '_trace_state': - continue - children[name] = value + children = {'value': self.raw_value, **self._var_metadata} return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes=children, @@ -426,10 +358,6 @@ def __setstate__(self, state): # -------------------------------------------- # proxy methods # -------------------------------------------- - # NOTE: we dont override __setattr__ to avoid cases where - # you need to set an attribute on the variable instance - def __getattr__(self, name: str) -> tp.Any: - return getattr(self.value, name) def __getitem__(self, key) -> tp.Any: return self.value[key] # type: ignore @@ -803,39 +731,51 @@ class Intermediate(Variable[A]): class VariableState(tp.Generic[A], reprlib.Representable): + __slots__ = ('type', 'value', '_var_metadata') + type: type[Variable[A]] + value: A + _var_metadata: dict[str, tp.Any] + def __init__( self, - type: type[Variable[tp.Any]], + type: type[Variable[A]], # type: ignore [valid-type] value: A, **metadata, ): - self.type = type - self.value = value - vars(self).update(metadata) - - if tp.TYPE_CHECKING: + object.__setattr__(self, 'type', type) + object.__setattr__(self, 'value', value) + object.__setattr__(self, '_var_metadata', metadata) + + def __getattr__(self, name: str) -> None: + var_metadata = object.__getattribute__(self, '_var_metadata') + if name not in var_metadata: + raise AttributeError(f"'VariableState' object has no attribute '{name}'") + return var_metadata[name] + + def __setattr__(self, name: str, value: Any) -> None: + if name == 'type' or name == 'value' or name == '_var_metadata': + object.__setattr__(self, name, value) + else: + self._var_metadata[name] = value - def __getattr__(self, name: str) -> None: ... - def __setattr__(self, name: str, value: Any) -> None: ... - def __delattr__(self, name: str) -> None: ... + def __delattr__(self, name: str) -> None: + if name == 'type' or name == 'value' or name == '_var_metadata': + object.__delattr__(self, name) + else: + del self._var_metadata[name] def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('type', self.type.__name__) + yield reprlib.Attr('value', self.value) - for name, value in vars(self).items(): - if name == 'type' or name.endswith('_hooks'): - continue + for name, value in self._var_metadata.items(): yield reprlib.Attr(name, repr(value)) def __treescope_repr__(self, path, subtree_renderer): import treescope # type: ignore[import-not-found,import-untyped] - children = {'type': self.type} - for name, value in vars(self).items(): - if name == 'type' or name.endswith('_hooks'): - continue - children[name] = value + children = {'type': self.type, 'value': self.value, **self._var_metadata} return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes=children, @@ -849,29 +789,25 @@ def replace(self, value: B) -> VariableState[B]: def to_variable(self) -> Variable[A]: # we use object.__new__ to avoid calling __init__ and bypass the # __init__ logic which should not be called twice - metadata = self.get_metadata() - variables = object.__new__(self.type) - vars(variables).update( - metadata, raw_value=self.value, _trace_state=tracers.TraceState() - ) - return variables + variable = object.__new__(self.type) + object.__setattr__(variable, '_trace_state', tracers.TraceState()) + object.__setattr__(variable, 'raw_value', self.value) + object.__setattr__(variable, '_var_metadata', self.get_metadata().copy()) + return variable def copy(self: VariableState[A]) -> VariableState[A]: return jax.tree.map(lambda x: x, self) def get_metadata(self) -> dict[str, tp.Any]: - metadata = vars(self).copy() - del metadata['type'] - del metadata['value'] - return metadata + return self._var_metadata def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - for hook in self.add_axis_hooks: - hook(self, axis_index, axis_name) + if 'on_add_axis' in self._var_metadata: + self._var_metadata['on_add_axis'](self, axis_index, axis_name) def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - for hook in self.remove_axis_hooks: - hook(self, axis_index, axis_name) + if 'on_remove_axis' in self._var_metadata: + self._var_metadata['on_remove_axis'](self, axis_index, axis_name) def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool): diff --git a/tests/nnx/containers_test.py b/tests/nnx/containers_test.py index 97785e76..92345abc 100644 --- a/tests/nnx/containers_test.py +++ b/tests/nnx/containers_test.py @@ -21,15 +21,15 @@ class TestContainers(absltest.TestCase): def test_unbox(self): x = nnx.Param( 1, - get_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2], # type: ignore + on_get_value=lambda c, x: x + 3, # type: ignore ) assert x.value == 4 - def test_box(self): + def test_on_set_value(self): x: nnx.Param[int] = nnx.Param( 1, # type: ignore - set_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2], # type: ignore + on_set_value=lambda c, x: x + 7, # type: ignore ) x.value = 5 @@ -38,9 +38,7 @@ def test_box(self): def test_module_unbox(self): class Foo(nnx.Module): def __init__(self) -> None: - self.x = nnx.Param( - 1, get_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2] - ) + self.x = nnx.Param(1, on_get_value=lambda c, x: x + 3) module = Foo() @@ -51,7 +49,8 @@ def test_module_box(self): class Foo(nnx.Module): def __init__(self) -> None: self.x = nnx.Param( - 1, set_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2] + 1, + on_set_value=lambda c, x: x + 7, # type: ignore ) module = Foo() diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index fb0496e0..a7bbf178 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -303,7 +303,7 @@ def __init__(self): assert 'tree' in state assert 'a' in state.tree - assert graphdef.subgraphs['tree'].type is nnx.graph.GenericPytree + assert graphdef.attributes[0].value.type is nnx.graph.GenericPytree m2 = nnx.merge(graphdef, state) diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 498ce3de..ce65186d 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -40,7 +40,7 @@ def __setitem__(self, idx, value): class Dict(nnx.Module): def __init__(self, *args, **kwargs): - self.items = dict(*args, **kwargs) + vars(self)['items'] = dict(*args, **kwargs) def __getitem__(self, key): return vars(self)['items'][key] @@ -48,6 +48,12 @@ def __getitem__(self, key): def __setitem__(self, key, value): vars(self)['items'][key] = value + def __setattr__(self, key, value): + if key == 'items': + object.__setattr__(self, key, value) + else: + vars(self)['items'][key] = value + def __getattr__(self, key): attrs = vars(self) if 'items' not in attrs: diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 828ee568..2372fbad 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -112,19 +112,20 @@ class MLP(nnx.Module): ) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear( - 3, - 3, - kernel_init=nnx.with_metadata( - nnx.initializers.lecun_normal(), sharding=('din', 'dout'), - add_axis_hooks=lambda _, idx, name: kadds.append((idx, name)), - remove_axis_hooks=lambda _, idx, name: kremoves.append((idx, name)), - ), - bias_init=nnx.with_metadata( - nnx.initializers.zeros_init(), # no sharding annotation here! - add_axis_hooks=lambda _, idx, name: badds.append((idx, name)), - remove_axis_hooks=lambda _, idx, name: bremoves.append((idx, name)), - ), - rngs=rngs, + 3, + 3, + kernel_init=nnx.with_metadata( + nnx.initializers.lecun_normal(), + sharding=('din', 'dout'), + on_add_axis=lambda _, idx, name: kadds.append((idx, name)), + on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)), + ), + bias_init=nnx.with_metadata( + nnx.initializers.zeros_init(), # no sharding annotation here! + on_add_axis=lambda _, idx, name: badds.append((idx, name)), + on_remove_axis=lambda _, idx, name: bremoves.append((idx, name)), + ), + rngs=rngs, ) @nnx.scan( diff --git a/uv.lock b/uv.lock index 0d68a86b..a3015511 100644 --- a/uv.lock +++ b/uv.lock @@ -773,7 +773,7 @@ wheels = [ [[package]] name = "flax" -version = "0.10.1" +version = "0.10.2" source = { editable = "." } dependencies = [ { name = "jax" },