Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] lock_ / unlock_ graphs #2729

Open
wants to merge 5 commits into
base: gh/vmoens/82/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 88 additions & 24 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,36 +697,100 @@ def test_create_composite_nested(shape, device):
assert c["a"].device == device


@pytest.mark.parametrize("recurse", [True, False])
def test_lock(recurse):
shape = [3, 4, 5]
spec = Composite(
a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]),
shape=shape[:1],
)
spec["a"] = spec["a"].clone()
spec["a", "b"] = spec["a", "b"].clone()
assert not spec.locked
spec.lock_(recurse=recurse)
assert spec.locked
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
class TestLock:
@pytest.mark.parametrize("recurse", [None, True, False])
def test_lock(self, recurse):
catch_warn = (
pytest.warns(DeprecationWarning, match="recurse")
if recurse is None
else contextlib.nullcontext()
)

shape = [3, 4, 5]
spec = Composite(
a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]),
shape=shape[:1],
)
spec["a"] = spec["a"].clone()
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec.set("a", spec["a"].clone())
if recurse:
assert spec["a"].locked
spec["a", "b"] = spec["a", "b"].clone()
assert not spec.locked
with catch_warn:
spec.lock_(recurse=recurse)
assert spec.locked
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a"].set("b", spec["a", "b"].clone())
spec["a"] = spec["a"].clone()
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec.set("a", spec["a"].clone())
if recurse:
assert spec["a"].locked
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a"].set("b", spec["a", "b"].clone())
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a", "b"] = spec["a", "b"].clone()
else:
assert not spec["a"].locked
spec["a", "b"] = spec["a", "b"].clone()
else:
assert not spec["a"].locked
spec["a"].set("b", spec["a", "b"].clone())
with catch_warn:
spec.unlock_(recurse=recurse)
spec["a"] = spec["a"].clone()
spec["a", "b"] = spec["a", "b"].clone()
spec["a"].set("b", spec["a", "b"].clone())
spec.unlock_(recurse=recurse)
spec["a"] = spec["a"].clone()
spec["a", "b"] = spec["a", "b"].clone()
spec["a"].set("b", spec["a", "b"].clone())

def test_edge_cases(self):
level3 = Composite()
level2 = Composite(level3=level3)
level1 = Composite(level2=level2)
level0 = Composite(level1=level1)
# locking level0 locks them all
level0.lock_(recurse=True)
assert level3.is_locked
# We cannot unlock level3
with pytest.raises(
RuntimeError,
match="Cannot unlock a Composite that is part of a locked graph",
):
level3.unlock_(recurse=True)
assert level3.is_locked
# Adding level2 to a new spec and locking it makes it hard to unlock the level0 root
new_spec = Composite(level2=level2)
new_spec.lock_(recurse=True)
with pytest.raises(
RuntimeError,
match="Cannot unlock a Composite that is part of a locked graph",
):
level0.unlock_(recurse=True)
assert level0.is_locked

def test_lock_mix_recurse_nonrecurse(self):
# lock with recurse
level3 = Composite()
level2 = Composite(level3=level3)
level1 = Composite(level2=level2)
level0 = Composite(level1=level1)
# locking level0 locks them all
level0.lock_(recurse=True)
new_spec = Composite(level2=level2)
new_spec.lock_(recurse=True)

# Unlock with recurse=False
with pytest.raises(RuntimeError, match="Cannot unlock"):
level3.unlock_(recurse=False)
assert level3.is_locked
assert level2.is_locked
assert new_spec.is_locked
with pytest.raises(RuntimeError, match="Cannot unlock"):
level2.unlock_(recurse=False)
with pytest.raises(RuntimeError, match="Cannot unlock"):
level1.unlock_(recurse=False)
level0.unlock_(recurse=False)
assert level3.is_locked
assert level2.is_locked
assert level1.is_locked
new_spec.unlock_(recurse=False)
assert level3.is_locked
assert level2.is_locked
assert level1.is_locked


def test_keys_to_empty_composite_spec():
Expand Down
177 changes: 160 additions & 17 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import abc
import enum
import gc
import math
import warnings
import weakref
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass
Expand Down Expand Up @@ -4428,7 +4430,7 @@ class Composite(TensorSpec):
@classmethod
def __new__(cls, *args, **kwargs):
cls._device = None
cls._locked = False
cls._is_locked = False
return super().__new__(cls)

@property
Expand Down Expand Up @@ -4959,6 +4961,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite:
return self.__class__(**kwargs, device=_device, shape=self.shape)

def clone(self) -> Composite:
"""Clones the Composite spec.

Locked specs will not produce locked clones.
"""
try:
device = self.device
except RuntimeError:
Expand Down Expand Up @@ -5170,14 +5176,82 @@ def unbind(self, dim: int = 0):
for i in range(self.shape[dim])
)

def lock_(self, recurse=False):
"""Locks the Composite and prevents modification of its content.
# Locking functionality
@property
def is_locked(self) -> bool:
return self._is_locked

@is_locked.setter
def is_locked(self, value: bool) -> None:
if value:
self.lock_()
else:
self.unlock_()

def __getstate__(self):
result = self.__dict__.copy()
__lock_parents_weakrefs = result.pop("__lock_parents_weakrefs", None)
if __lock_parents_weakrefs is not None:
result["_lock_recurse"] = True
return result

def __setstate__(self, state):
_lock_recurse = state.pop("_lock_recurse", False)
for key, value in state.items():
setattr(self, key, value)
if self._is_locked:
self._is_locked = False
self.lock_(recurse=_lock_recurse)

def _propagate_lock(
self, *, recurse: bool, lock_parents_weakrefs=None, is_compiling
):
"""Registers the parent composite that handles the lock."""
self._is_locked = True
if lock_parents_weakrefs is not None:
lock_parents_weakrefs = [
ref
for ref in lock_parents_weakrefs
if not any(refref is ref for refref in self._lock_parents_weakrefs)
]
if not is_compiling:
is_root = lock_parents_weakrefs is None
if is_root:
lock_parents_weakrefs = []
else:
self._lock_parents_weakrefs = (
self._lock_parents_weakrefs + lock_parents_weakrefs
)
lock_parents_weakrefs = list(lock_parents_weakrefs)
lock_parents_weakrefs.append(weakref.ref(self))

This is only a first-level lock, unless specified otherwise through the
``recurse`` arg.
if recurse:
for value in self.values():
if isinstance(value, Composite):
value._propagate_lock(
recurse=True,
lock_parents_weakrefs=lock_parents_weakrefs,
is_compiling=is_compiling,
)

Leaf specs can always be modified in place, but they cannot be replaced
in their Composite parent.
@property
def _lock_parents_weakrefs(self):
_lock_parents_weakrefs = self.__dict__.get("__lock_parents_weakrefs")
if _lock_parents_weakrefs is None:
self.__dict__["__lock_parents_weakrefs"] = []
_lock_parents_weakrefs = self.__dict__["__lock_parents_weakrefs"]
return _lock_parents_weakrefs

@_lock_parents_weakrefs.setter
def _lock_parents_weakrefs(self, value: list):
self.__dict__["__lock_parents_weakrefs"] = value

def lock_(self, recurse: bool | None = None) -> T:
"""Locks the Composite and prevents modification of its content.

The recurse argument control whether the lock will be propagated to sub-specs.
The current default is ``False`` but it will be turned to ``True`` for consistency
with the TensorDict API in v0.8.

Examples:
>>> shape = [3, 4, 5]
Expand Down Expand Up @@ -5211,30 +5285,99 @@ def lock_(self, recurse=False):
failed!

"""
self._locked = True
if self.is_locked:
return self
is_comp = is_compiling()
if is_comp:
# TODO: See what to do when compiling
pass
if recurse is None:
warnings.warn(
"You have not specified a value for recurse when calling CompositeSpec.lock_(). "
"The current default is False but it will be turned to True in v0.8. To adapt to these changes "
"and silence this warning, pass the value of recurse explicitly.",
category=DeprecationWarning,
)
recurse = False
self._propagate_lock(recurse=recurse, is_compiling=is_comp)
return self

def _propagate_unlock(self, recurse: bool):
# if we end up here, we can clear the graph associated with this td
self._is_locked = False

self._is_shared = False
self._is_memmap = False

if recurse:
sub_specs = []
for value in self.values():
if isinstance(value, Composite):
value.lock_(recurse)
return self
sub_specs.extend(value._propagate_unlock(recurse=recurse))
sub_specs.append(value)
return sub_specs
return []

def _check_unlock(self, first_attempt=True):
if not first_attempt:
gc.collect()
obj = None
for ref in self._lock_parents_weakrefs:
obj = ref()
# check if the locked parent exists and if it's locked
# we check _is_locked because it can be False or None in the case of Lazy stacks,
# but if we check obj.is_locked it will be True for this class.
if obj is not None and obj._is_locked:
break

def unlock_(self, recurse=False):
else:
try:
self._lock_parents_weakrefs = []
except AttributeError:
# Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref
pass
return

if first_attempt:
del obj
return self._check_unlock(False)
raise RuntimeError(
"Cannot unlock a Composite that is part of a locked graph. "
"Graphs are locked when a Composite is locked with recurse=True. "
"Unlock the root Composite first. If the Composite is part of multiple graphs, "
"group the graphs under a common Composite an unlock this root. "
f"self: {self}, obj: {obj}"
)

def unlock_(self, recurse: bool | None = None) -> T:
"""Unlocks the Composite and allows modification of its content.

This is only a first-level lock modification, unless specified
otherwise through the ``recurse`` arg.

"""
self._locked = False
if recurse:
for value in self.values():
if isinstance(value, Composite):
value.unlock_(recurse)
try:
if recurse is None:
warnings.warn(
"You have not specified a value for recurse when calling CompositeSpec.unlock_(). "
"The current default is False but it will be turned to True in v0.8. To adapt to these changes "
"and silence this warning, pass the value of recurse explicitly.",
category=DeprecationWarning,
)
recurse = False
sub_specs = self._propagate_unlock(recurse=recurse)
if recurse:
for sub_spec in sub_specs:
sub_spec._check_unlock()
self._check_unlock()
except RuntimeError as err:
self.lock_()
raise err
return self

@property
def locked(self):
return self._locked
return self._is_locked


class StackedComposite(_LazyStackedMixin[Composite], Composite):
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/libs/isaacgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
specs = make_composite_from_td(data)

obs_spec = self.observation_spec
obs_spec.unlock_()
obs_spec.unlock_(recurse=True)
obs_spec.update(specs)
obs_spec.lock_()
obs_spec.lock_(recurse=True)

def _output_transform(self, output):
obs, reward, done, info = output
Expand Down
Loading
Loading