Skip to content

Commit

Permalink
[Feature] lock_ / unlock_ graphs
Browse files Browse the repository at this point in the history
ghstack-source-id: a418d41b561118498199695d479176ce99dd6b80
Pull Request resolved: #2729
  • Loading branch information
vmoens committed Jan 30, 2025
1 parent b20239e commit 03b0261
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 49 deletions.
112 changes: 88 additions & 24 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,36 +697,100 @@ def test_create_composite_nested(shape, device):
assert c["a"].device == device


@pytest.mark.parametrize("recurse", [True, False])
def test_lock(recurse):
shape = [3, 4, 5]
spec = Composite(
a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]),
shape=shape[:1],
)
spec["a"] = spec["a"].clone()
spec["a", "b"] = spec["a", "b"].clone()
assert not spec.locked
spec.lock_(recurse=recurse)
assert spec.locked
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
class TestLock:
@pytest.mark.parametrize("recurse", [None, True, False])
def test_lock(self, recurse):
catch_warn = (
pytest.warns(DeprecationWarning, match="recurse")
if recurse is None
else contextlib.nullcontext()
)

shape = [3, 4, 5]
spec = Composite(
a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]),
shape=shape[:1],
)
spec["a"] = spec["a"].clone()
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec.set("a", spec["a"].clone())
if recurse:
assert spec["a"].locked
spec["a", "b"] = spec["a", "b"].clone()
assert not spec.locked
with catch_warn:
spec.lock_(recurse=recurse)
assert spec.locked
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a"].set("b", spec["a", "b"].clone())
spec["a"] = spec["a"].clone()
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec.set("a", spec["a"].clone())
if recurse:
assert spec["a"].locked
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a"].set("b", spec["a", "b"].clone())
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a", "b"] = spec["a", "b"].clone()
else:
assert not spec["a"].locked
spec["a", "b"] = spec["a", "b"].clone()
else:
assert not spec["a"].locked
spec["a"].set("b", spec["a", "b"].clone())
with catch_warn:
spec.unlock_(recurse=recurse)
spec["a"] = spec["a"].clone()
spec["a", "b"] = spec["a", "b"].clone()
spec["a"].set("b", spec["a", "b"].clone())
spec.unlock_(recurse=recurse)
spec["a"] = spec["a"].clone()
spec["a", "b"] = spec["a", "b"].clone()
spec["a"].set("b", spec["a", "b"].clone())

def test_edge_cases(self):
    """Locking a root with ``recurse=True`` must make every sub-spec refuse
    to unlock while a locked parent graph still references it."""
    # Build a 4-level chain of nested Composites: level0 -> level1 -> level2 -> level3.
    level3 = Composite()
    level2 = Composite(level3=level3)
    level1 = Composite(level2=level2)
    level0 = Composite(level1=level1)
    # locking level0 locks them all
    level0.lock_(recurse=True)
    assert level3.is_locked
    # We cannot unlock level3
    with pytest.raises(
        RuntimeError,
        match="Cannot unlock a Composite that is part of a locked graph",
    ):
        level3.unlock_(recurse=True)
    # The failed unlock must leave the lock state untouched.
    assert level3.is_locked
    # Adding level2 to a new spec and locking it makes it hard to unlock the level0 root
    new_spec = Composite(level2=level2)
    new_spec.lock_(recurse=True)
    with pytest.raises(
        RuntimeError,
        match="Cannot unlock a Composite that is part of a locked graph",
    ):
        level0.unlock_(recurse=True)
    assert level0.is_locked

def test_lock_mix_recurse_nonrecurse(self):
    """Recursive locks followed by non-recursive unlocks: only a root with no
    locked parent may unlock with ``recurse=False``, and doing so must leave
    its (recursively locked) children locked."""
    # lock with recurse
    level3 = Composite()
    level2 = Composite(level3=level3)
    level1 = Composite(level2=level2)
    level0 = Composite(level1=level1)
    # locking level0 locks them all
    level0.lock_(recurse=True)
    # level2 now belongs to two locked graphs (level0's and new_spec's).
    new_spec = Composite(level2=level2)
    new_spec.lock_(recurse=True)

    # Unlock with recurse=False
    # Non-roots still referenced by a locked parent must refuse to unlock.
    with pytest.raises(RuntimeError, match="Cannot unlock"):
        level3.unlock_(recurse=False)
    assert level3.is_locked
    assert level2.is_locked
    assert new_spec.is_locked
    with pytest.raises(RuntimeError, match="Cannot unlock"):
        level2.unlock_(recurse=False)
    with pytest.raises(RuntimeError, match="Cannot unlock"):
        level1.unlock_(recurse=False)
    # level0 is a root: unlocking it non-recursively succeeds but leaves children locked.
    level0.unlock_(recurse=False)
    assert level3.is_locked
    assert level2.is_locked
    assert level1.is_locked
    # Same for the second root: its children stay locked.
    new_spec.unlock_(recurse=False)
    assert level3.is_locked
    assert level2.is_locked
    assert level1.is_locked


def test_keys_to_empty_composite_spec():
Expand Down
158 changes: 141 additions & 17 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import abc
import enum
import gc
import math
import warnings
import weakref
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass
Expand Down Expand Up @@ -4428,7 +4430,7 @@ class Composite(TensorSpec):
@classmethod
def __new__(cls, *args, **kwargs):
    """Create a new instance, (re)setting class-level lock/device defaults.

    NOTE(review): these assignments write *class* attributes; instances fall
    back to them until they shadow them with instance attributes.
    """
    cls._device = None
    # Stale diff-residue line ``cls._locked = False`` removed: the attribute
    # was renamed to ``_is_locked`` and ``_locked`` is no longer read.
    cls._is_locked = False
    return super().__new__(cls)

@property
Expand Down Expand Up @@ -5170,14 +5172,67 @@ def unbind(self, dim: int = 0):
for i in range(self.shape[dim])
)

def lock_(self, recurse=False):
"""Locks the Composite and prevents modification of its content.
# Locking functionality
@property
def is_locked(self) -> bool:
    """Whether this Composite is currently locked against modification."""
    return self._is_locked

This is only a first-level lock, unless specified otherwise through the
``recurse`` arg.
@is_locked.setter
def is_locked(self, value: bool) -> None:
    """Lock (``True``) or unlock (``False``) the spec by delegating to
    :meth:`lock_` / :meth:`unlock_`."""
    if not value:
        self.unlock_()
    else:
        self.lock_()

Leaf specs can always be modified in place, but they cannot be replaced
in their Composite parent.
def _propagate_lock(
    self, *, recurse: bool, lock_parents_weakrefs=None, is_compiling
):
    """Registers the parent composite that handles the lock.

    Marks this spec as locked and, unless compiling, records weak references
    to the ancestor Composites that locked it in ``_lock_parents_weakrefs``.

    Args:
        recurse (bool): whether to propagate the lock to child Composites.
        lock_parents_weakrefs (list, optional): weakrefs to the ancestors that
            triggered this lock; ``None`` means this spec is the root of the
            lock operation.
        is_compiling (bool): presumably flags torch.compile tracing, in which
            case weakref bookkeeping is skipped — TODO confirm.
    """
    self._is_locked = True
    if lock_parents_weakrefs is not None:
        # Deduplicate by identity: drop refs already tracked for this spec.
        lock_parents_weakrefs = [
            ref
            for ref in lock_parents_weakrefs
            if not any(refref is ref for refref in self._lock_parents_weakrefs)
        ]
    if not is_compiling:
        # A root call passes no parent refs; children accumulate them.
        is_root = lock_parents_weakrefs is None
        if is_root:
            lock_parents_weakrefs = []
        else:
            self._lock_parents_weakrefs = (
                self._lock_parents_weakrefs + lock_parents_weakrefs
            )
        # Copy before appending our own ref so the caller's list is not mutated.
        lock_parents_weakrefs = list(lock_parents_weakrefs)
        lock_parents_weakrefs.append(weakref.ref(self))

    if recurse:
        for value in self.values():
            if isinstance(value, Composite):
                value._propagate_lock(
                    recurse=True,
                    lock_parents_weakrefs=lock_parents_weakrefs,
                    is_compiling=is_compiling,
                )

@property
def _lock_parents_weakrefs(self):
_lock_parents_weakrefs = self.__dict__.get("__lock_parents_weakrefs")
if _lock_parents_weakrefs is None:
self.__dict__["__lock_parents_weakrefs"] = []
_lock_parents_weakrefs = self.__dict__["__lock_parents_weakrefs"]
return _lock_parents_weakrefs

@_lock_parents_weakrefs.setter
def _lock_parents_weakrefs(self, value: list):
self.__dict__["__lock_parents_weakrefs"] = value

def lock_(self, recurse: bool | None = None) -> T:
    """Locks the Composite and prevents modification of its content.

    Args:
        recurse (bool, optional): whether the lock is propagated to sub-specs.
            The current default is ``False`` but it will be turned to ``True``
            in v0.8 for consistency with the TensorDict API; a
            ``DeprecationWarning`` is raised when left unspecified.

    Returns:
        the locked spec (``self``), for call chaining.
    """
    # Note: scrape residue removed here — the truncated example section and the
    # stale pre-rename line ``self._locked = True`` were interleaved in the diff.
    if self.is_locked:
        # Already locked: no-op.
        return self
    is_comp = is_compiling()
    if is_comp:
        # TODO: See what to do when compiling
        pass
    if recurse is None:
        warnings.warn(
            "You have not specified a value for recurse when calling CompositeSpec.lock_(). "
            "The current default is False but it will be turned to True in v0.8. To adapt to these changes "
            "and silence this warning, pass the value of recurse explicitly.",
            category=DeprecationWarning,
        )
        recurse = False
    self._propagate_lock(recurse=recurse, is_compiling=is_comp)
    return self

def _propagate_unlock(self, recurse: bool):
# if we end up here, we can clear the graph associated with this td
self._is_locked = False

self._is_shared = False
self._is_memmap = False

if recurse:
sub_specs = []
for value in self.values():
if isinstance(value, Composite):
value.lock_(recurse)
return self
sub_specs.extend(value._propagate_unlock(recurse=recurse))
sub_specs.append(value)
return sub_specs
return []

def _check_unlock(self, first_attempt=True):
if not first_attempt:
gc.collect()
obj = None
for ref in self._lock_parents_weakrefs:
obj = ref()
# check if the locked parent exists and if it's locked
# we check _is_locked because it can be False or None in the case of Lazy stacks,
# but if we check obj.is_locked it will be True for this class.
if obj is not None and obj._is_locked:
break

else:
try:
self._lock_parents_weakrefs = []
except AttributeError:
# Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref
pass
return

def unlock_(self, recurse=False):
if first_attempt:
del obj
return self._check_unlock(False)
raise RuntimeError(
"Cannot unlock a Composite that is part of a locked graph. "
"Graphs are locked when a Composite is locked with recurse=True. "
"Unlock the root Composite first. If the Composite is part of multiple graphs, "
"group the graphs under a common Composite an unlock this root. "
f"self: {self}, obj: {obj}"
)

def unlock_(self, recurse: bool | None = None) -> T:
    """Unlocks the Composite and allows modification of its content.

    Args:
        recurse (bool, optional): whether the unlock is propagated to
            sub-specs. The current default is ``False`` but it will be turned
            to ``True`` in v0.8; a ``DeprecationWarning`` is raised when left
            unspecified.

    Raises:
        RuntimeError: if this spec (or a sub-spec when ``recurse=True``) is
            still part of a locked graph; the spec is re-locked before the
            error is surfaced.

    Returns:
        the unlocked spec (``self``), for call chaining.
    """
    # Defect fixed: the scraped diff interleaved the old implementation
    # (``self._locked = False`` plus a recursive value.unlock_ loop) into the
    # new try-based body; the stale lines are removed.
    try:
        if recurse is None:
            warnings.warn(
                "You have not specified a value for recurse when calling CompositeSpec.unlock_(). "
                "The current default is False but it will be turned to True in v0.8. To adapt to these changes "
                "and silence this warning, pass the value of recurse explicitly.",
                category=DeprecationWarning,
            )
            recurse = False
        sub_specs = self._propagate_unlock(recurse=recurse)
        if recurse:
            for sub_spec in sub_specs:
                sub_spec._check_unlock()
        self._check_unlock()
    except RuntimeError as err:
        # Restore the locked state before surfacing the error.
        self.lock_()
        raise err
    return self

@property
def locked(self):
    """Whether the spec is locked; mirrors :attr:`is_locked`.

    Defect fixed: the scraped diff kept the old ``return self._locked`` line
    above the new one, so the property returned an attribute that is no longer
    set (``__new__`` only initializes ``_is_locked``).
    """
    return self._is_locked


class StackedComposite(_LazyStackedMixin[Composite], Composite):
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/libs/isaacgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
specs = make_composite_from_td(data)

obs_spec = self.observation_spec
obs_spec.unlock_()
obs_spec.unlock_(recurse=True)
obs_spec.update(specs)
obs_spec.lock_()
obs_spec.lock_(recurse=True)

def _output_transform(self, output):
obs, reward, done, info = output
Expand Down
7 changes: 1 addition & 6 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,7 @@
Unbounded,
UnboundedContinuous,
)
from torchrl.envs.common import (
_do_nothing,
_EnvPostInit,
EnvBase,
make_tensordict,
)
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import (
_get_reset,
Expand Down

0 comments on commit 03b0261

Please sign in to comment.