Skip to content

Commit

Permalink
Merge pull request #2313 from levskaya:idfix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 463175098
  • Loading branch information
Flax Authors committed Jul 25, 2022
2 parents 922e1ee + 963fa4a commit db4a74b
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 24 deletions.
5 changes: 3 additions & 2 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
from typing import (Any, Callable, Dict, Generic, Iterable, Mapping, Optional,
Sequence, Set, Tuple, TypeVar, Union)

from . import tracers
from flax.ids import uuid
from flax import config
from flax import errors
from flax import struct
from flax import traceback_util
from .frozen_dict import freeze
from .frozen_dict import FrozenDict
from .frozen_dict import unfreeze
from . import tracers
import jax
from jax import config as jax_config
from jax import numpy as jnp
Expand All @@ -51,7 +52,6 @@
# When conditioning on filters we require explicit boolean comparisons.
# pylint: disable=g-bool-id-comparison


@dataclasses.dataclass(frozen=True, eq=True)
class DenyList:
"""DenyList represents an opt-out based mutability filter.
Expand Down Expand Up @@ -343,6 +343,7 @@ def __init__(self, scope: 'Scope', collection: str, name: str):
collection: The collection of the variable (e.g., "params").
name: The name of the variable (e.g., "dense").
"""
self._id = uuid()
self.scope = scope
self.collection = collection
self.name = name
Expand Down
57 changes: 57 additions & 0 deletions flax/ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2022 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""UUIDs for Flax internals."""

import threading


class UUIDManager:
"""Globally unique counter-based id manager.
We need globally unique key ids for Module and Variable object instances
to preserve and recreate sharing-by-reference relationship when lifting
transforms and adopting outside Modules.
- Use of id() is unacceptable because these identifiers are literally
pointers which can be recycled, so we rely on a globally unique counter id
instead.
- We need to handle copy/deepcopy uniqueness via a wrapped type.
"""
def __init__(self):
self._lock = threading.Lock()
self._id = 0

def __call__(self):
with self._lock:
self._id += 1
return FlaxId(self._id)

uuid = UUIDManager()


class FlaxId:
"""Hashable wrapper for ids that handles uniqueness of copies."""
def __init__(self, rawid):
self.id = rawid
def __eq__(self, other):
return isinstance(other, FlaxId) and other.id == self.id
def __hash__(self):
return hash(self.id)
def __repr__(self):
return f"FlaxId({self.id})"
def __deepcopy__(self, memo):
del memo
return uuid()
def __copy__(self):
return uuid()
14 changes: 7 additions & 7 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from flax.core.scope import ( # pylint: disable=g-multiple-import
CollectionFilter, DenyList, FrozenVariableDict, Variable, VariableDict,
union_filters)
from flax.ids import uuid
from flax.linen import summary


Expand Down Expand Up @@ -730,6 +731,7 @@ def __post_init__(self) -> None:
# initialization, attach this Module as a submodule of a parent, or bind
# this Module at the top-level to variables and rngs.

object.__setattr__(self, '_id', uuid())
object.__setattr__(self, '_state', _ModuleInternalState())

# Typically we set the parent based on the dynamic module context.
Expand Down Expand Up @@ -827,14 +829,12 @@ def adopt_attr_modules(cache, queue, suffix, subvalue):
# Module was passed from outside. It needs to be cloned.
# Outside modules are named by attachment, not an outer name.
object.__setattr__(subvalue, 'name', None)
key = id(subvalue)
# Preserve sharing-by-reference relationships during adoption
# via cache keyed on unique instance ids.
key = subvalue._id
if key not in cache:
# since we use id() as key, we need to keep a reference to original
# subvalue to ensure it's lifetime is long enough for the entire
# model setup and the id() is not recycled.
# TODO(levskaya): consider switching to per-module UUIDs
cache[key] = (subvalue.clone(), subvalue)
subvalue = cache[key][0]
cache[key] = subvalue.clone()
subvalue = cache[key]
if subvalue.name is None:
object.__setattr__(subvalue, 'parent', self)
object.__setattr__(subvalue, 'name', f'{name}{suffix}')
Expand Down
31 changes: 16 additions & 15 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

traceback_util.register_exclusion(__file__)

# pylint: disable=protected-access

# Utils
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -79,12 +80,12 @@ def wrapped_fn(x):
nonlocal refs
if isinstance(x, (VariablePlaceholder, InstancePlaceholder)):
x_id = x.id
elif isinstance(x, (Variable, Module)):
x_id = x._id
else:
x_id = id(x)
return fn(x)
if x_id not in refs:
refs[x_id] = fn(x)
else:
pass
return refs[x_id]
return wrapped_fn

Expand Down Expand Up @@ -124,25 +125,25 @@ def get_arg_scope(x):
nonlocal scopes
if isinstance(x, Variable) and isinstance(x.scope, Scope):
scopes.append(x.scope)
return VariablePlaceholder(x.collection, x.name, id(x))
return VariablePlaceholder(x.collection, x.name, x._id)
elif isinstance(x, Module) and isinstance(x.scope, Scope):
x._try_setup(shallow=True) # pylint: disable=protected-access
x._try_setup(shallow=True)
scopes.append(x.scope)
attrs = {
f.name: getattr(x, f.name)
for f in dataclasses.fields(x)
if f.name != 'parent' and f.init
}
attrs = jax.tree_util.tree_map(get_arg_scope, attrs)
return InstancePlaceholder(x.__class__, attrs, id(x))
return InstancePlaceholder(x.__class__, attrs, x._id)
return x
new_args, new_kwargs = jax.tree_util.tree_map(get_arg_scope, (args, kwargs))

# Gather scopes in Variables and Submodules passed as Module attributes.
@functools.partial(_memoize_by_id, refs=refs)
def get_scopes(module):
nonlocal scopes
module._try_setup(shallow=True) # pylint: disable=protected-access
module._try_setup(shallow=True)
def get_scopes_inner(x):
nonlocal scopes
if isinstance(x, Module) and isinstance(x.scope, Scope):
Expand Down Expand Up @@ -303,9 +304,9 @@ def core_fn(scopes, *args, **kwargs):
# we reference module_class, not self.__class__ to avoid infinite loop
cloned = module_class(parent=None, **attrs)
cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes)
object.__setattr__(cloned, '_state', state.export()) # pylint: disable=protected-access
object.__setattr__(cloned, '_state', state.export())
res = fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
self._state.reimport(cloned._state)
_test_transformed_return_values(res, fn_name)
return res
# here we apply the given lifting transform to the scope-ingesting fn
Expand Down Expand Up @@ -351,9 +352,9 @@ def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs):
if not multi_scope:
scopes = [scopes]
cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes)
object.__setattr__(cloned, '_state', state.export()) # pylint: disable=protected-access
object.__setattr__(cloned, '_state', state.export())
res = prewrapped_fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
self._state.reimport(cloned._state)
_test_transformed_return_values(res, getattr(class_fn, '__name__', None))
return res
core_fns = [functools.partial(core_fn, prewrapped_fn, class_fn)
Expand Down Expand Up @@ -1325,8 +1326,8 @@ def wrapped_fn(self, *args, **kwargs):
prewrapped_fn = wrap_method_once(class_fn)
@functools.wraps(prewrapped_fn)
def wrapped_fn(self, *args, **kwargs):
if ((not force and not linen_module._use_named_call) # pylint: disable=protected-access
or self._state.in_setup): # pylint: disable=protected-access
if ((not force and not linen_module._use_named_call)
or self._state.in_setup):
return prewrapped_fn(self, *args, **kwargs)
fn_name = class_fn.__name__
method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
Expand All @@ -1335,9 +1336,9 @@ def wrapped_fn(self, *args, **kwargs):
# make a scope-function to transform
def core_fn(scopes, *args, **kwargs):
cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes)
object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access
object.__setattr__(cloned, '_state', self._state.export())
res = prewrapped_fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
self._state.reimport(cloned._state)
_test_transformed_return_values(res, fn_name)
return res
# here we apply the given lifting transform to the scope-ingesting fn
Expand Down
16 changes: 16 additions & 0 deletions tests/linen/linen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

"""Tests for flax.linen."""

import copy
from absl.testing import absltest, parameterized

from flax import ids
from flax import linen as nn

import jax
Expand Down Expand Up @@ -361,5 +363,19 @@ def test_optimized_lstm_cell_matches_regular(self):
jtu.check_eq(lstm_params, lstm_opt_params)


class IdsTest(absltest.TestCase):

def test_hashable(self):
id1 = ids.uuid()
id2 = ids.uuid()
self.assertEqual(id1, id1)
self.assertNotEqual(id1, id2)
self.assertNotEqual(hash(id1), hash(id2))
id1c = copy.copy(id1)
id1dc = copy.deepcopy(id1)
self.assertNotEqual(hash(id1), hash(id1c))
self.assertNotEqual(hash(id1), hash(id1dc))


if __name__ == '__main__':
absltest.main()

0 comments on commit db4a74b

Please sign in to comment.