Skip to content

Commit

Permalink
Replace use of id() with global counter-based id.
Browse files Browse the repository at this point in the history
Historically we key'd on id() to record sharing relationships during
lifting and outer module adoption.  This was dumb, and after recently
fixing one bad bug arising from id reuse, we should use something
sound instead.
  • Loading branch information
levskaya committed Jul 22, 2022
1 parent e7742de commit 963fa4a
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 24 deletions.
5 changes: 3 additions & 2 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
from typing import (Any, Callable, Dict, Generic, Iterable, Mapping, Optional,
Sequence, Set, Tuple, TypeVar, Union)

from . import tracers
from flax.ids import uuid
from flax import config
from flax import errors
from flax import struct
from flax import traceback_util
from .frozen_dict import freeze
from .frozen_dict import FrozenDict
from .frozen_dict import unfreeze
from . import tracers
import jax
from jax import config as jax_config
from jax import numpy as jnp
Expand All @@ -51,7 +52,6 @@
# When conditioning on filters we require explicit boolean comparisons.
# pylint: disable=g-bool-id-comparison


@dataclasses.dataclass(frozen=True, eq=True)
class DenyList:
"""DenyList represents an opt-out based mutability filter.
Expand Down Expand Up @@ -343,6 +343,7 @@ def __init__(self, scope: 'Scope', collection: str, name: str):
collection: The collection of the variable (e.g., "params").
name: The name of the variable (e.g., "dense").
"""
self._id = uuid()
self.scope = scope
self.collection = collection
self.name = name
Expand Down
57 changes: 57 additions & 0 deletions flax/ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2022 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""UUIDs for Flax internals."""

import threading


class UUIDManager:
"""Globally unique counter-based id manager.
We need globally unique key ids for Module and Variable object instances
to preserve and recreate sharing-by-reference relationship when lifting
transforms and adopting outside Modules.
- Use of id() is unacceptable because these identifiers are literally
pointers which can be recycled, so we rely on a globally unique counter id
instead.
- We need to handle copy/deepcopy uniqueness via a wrapped type.
"""
def __init__(self):
self._lock = threading.Lock()
self._id = 0

def __call__(self):
with self._lock:
self._id += 1
return FlaxId(self._id)

uuid = UUIDManager()


class FlaxId:
"""Hashable wrapper for ids that handles uniqueness of copies."""
def __init__(self, rawid):
self.id = rawid
def __eq__(self, other):
return isinstance(other, FlaxId) and other.id == self.id
def __hash__(self):
return hash(self.id)
def __repr__(self):
return f"FlaxId({self.id})"
def __deepcopy__(self, memo):
del memo
return uuid()
def __copy__(self):
return uuid()
14 changes: 7 additions & 7 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from flax.core.scope import ( # pylint: disable=g-multiple-import
CollectionFilter, DenyList, FrozenVariableDict, Variable, VariableDict,
union_filters)
from flax.ids import uuid
from flax.linen import summary


Expand Down Expand Up @@ -730,6 +731,7 @@ def __post_init__(self) -> None:
# initialization, attach this Module as a submodule of a parent, or bind
# this Module at the top-level to variables and rngs.

object.__setattr__(self, '_id', uuid())
object.__setattr__(self, '_state', _ModuleInternalState())

# Typically we set the parent based on the dynamic module context.
Expand Down Expand Up @@ -827,14 +829,12 @@ def adopt_attr_modules(cache, queue, suffix, subvalue):
# Module was passed from outside. It needs to be cloned.
# Outside modules are named by attachment, not an outer name.
object.__setattr__(subvalue, 'name', None)
key = id(subvalue)
# Preserve sharing-by-reference relationships during adoption
# via cache keyed on unique instance ids.
key = subvalue._id
if key not in cache:
# since we use id() as key, we need to keep a reference to original
# subvalue to ensure it's lifetime is long enough for the entire
# model setup and the id() is not recycled.
# TODO(levskaya): consider switching to per-module UUIDs
cache[key] = (subvalue.clone(), subvalue)
subvalue = cache[key][0]
cache[key] = subvalue.clone()
subvalue = cache[key]
if subvalue.name is None:
object.__setattr__(subvalue, 'parent', self)
object.__setattr__(subvalue, 'name', f'{name}{suffix}')
Expand Down
31 changes: 16 additions & 15 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

traceback_util.register_exclusion(__file__)

# pylint: disable=protected-access

# Utils
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -79,12 +80,12 @@ def wrapped_fn(x):
nonlocal refs
if isinstance(x, (VariablePlaceholder, InstancePlaceholder)):
x_id = x.id
elif isinstance(x, (Variable, Module)):
x_id = x._id
else:
x_id = id(x)
return fn(x)
if x_id not in refs:
refs[x_id] = fn(x)
else:
pass
return refs[x_id]
return wrapped_fn

Expand Down Expand Up @@ -124,25 +125,25 @@ def get_arg_scope(x):
nonlocal scopes
if isinstance(x, Variable) and isinstance(x.scope, Scope):
scopes.append(x.scope)
return VariablePlaceholder(x.collection, x.name, id(x))
return VariablePlaceholder(x.collection, x.name, x._id)
elif isinstance(x, Module) and isinstance(x.scope, Scope):
x._try_setup(shallow=True) # pylint: disable=protected-access
x._try_setup(shallow=True)
scopes.append(x.scope)
attrs = {
f.name: getattr(x, f.name)
for f in dataclasses.fields(x)
if f.name != 'parent' and f.init
}
attrs = jax.tree_util.tree_map(get_arg_scope, attrs)
return InstancePlaceholder(x.__class__, attrs, id(x))
return InstancePlaceholder(x.__class__, attrs, x._id)
return x
new_args, new_kwargs = jax.tree_util.tree_map(get_arg_scope, (args, kwargs))

# Gather scopes in Variables and Submodules passed as Module attributes.
@functools.partial(_memoize_by_id, refs=refs)
def get_scopes(module):
nonlocal scopes
module._try_setup(shallow=True) # pylint: disable=protected-access
module._try_setup(shallow=True)
def get_scopes_inner(x):
nonlocal scopes
if isinstance(x, Module) and isinstance(x.scope, Scope):
Expand Down Expand Up @@ -303,9 +304,9 @@ def core_fn(scopes, *args, **kwargs):
# we reference module_class, not self.__class__ to avoid infinite loop
cloned = module_class(parent=None, **attrs)
cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes)
object.__setattr__(cloned, '_state', state.export()) # pylint: disable=protected-access
object.__setattr__(cloned, '_state', state.export())
res = fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
self._state.reimport(cloned._state)
_test_transformed_return_values(res, fn_name)
return res
# here we apply the given lifting transform to the scope-ingesting fn
Expand Down Expand Up @@ -351,9 +352,9 @@ def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs):
if not multi_scope:
scopes = [scopes]
cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes)
object.__setattr__(cloned, '_state', state.export()) # pylint: disable=protected-access
object.__setattr__(cloned, '_state', state.export())
res = prewrapped_fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
self._state.reimport(cloned._state)
_test_transformed_return_values(res, getattr(class_fn, '__name__', None))
return res
core_fns = [functools.partial(core_fn, prewrapped_fn, class_fn)
Expand Down Expand Up @@ -1325,8 +1326,8 @@ def wrapped_fn(self, *args, **kwargs):
prewrapped_fn = wrap_method_once(class_fn)
@functools.wraps(prewrapped_fn)
def wrapped_fn(self, *args, **kwargs):
if ((not force and not linen_module._use_named_call) # pylint: disable=protected-access
or self._state.in_setup): # pylint: disable=protected-access
if ((not force and not linen_module._use_named_call)
or self._state.in_setup):
return prewrapped_fn(self, *args, **kwargs)
fn_name = class_fn.__name__
method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
Expand All @@ -1335,9 +1336,9 @@ def wrapped_fn(self, *args, **kwargs):
# make a scope-function to transform
def core_fn(scopes, *args, **kwargs):
cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes)
object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access
object.__setattr__(cloned, '_state', self._state.export())
res = prewrapped_fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
self._state.reimport(cloned._state)
_test_transformed_return_values(res, fn_name)
return res
# here we apply the given lifting transform to the scope-ingesting fn
Expand Down
16 changes: 16 additions & 0 deletions tests/linen/linen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

"""Tests for flax.linen."""

import copy
from absl.testing import absltest, parameterized

from flax import ids
from flax import linen as nn

import jax
Expand Down Expand Up @@ -361,5 +363,19 @@ def test_optimized_lstm_cell_matches_regular(self):
jtu.check_eq(lstm_params, lstm_opt_params)


class IdsTest(absltest.TestCase):

def test_hashable(self):
id1 = ids.uuid()
id2 = ids.uuid()
self.assertEqual(id1, id1)
self.assertNotEqual(id1, id2)
self.assertNotEqual(hash(id1), hash(id2))
id1c = copy.copy(id1)
id1dc = copy.deepcopy(id1)
self.assertNotEqual(hash(id1), hash(id1c))
self.assertNotEqual(hash(id1), hash(id1dc))


if __name__ == '__main__':
absltest.main()

0 comments on commit 963fa4a

Please sign in to comment.