diff --git a/test/test_specs.py b/test/test_specs.py index 07762c7ad30..340afaa449a 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -697,36 +697,100 @@ def test_create_composite_nested(shape, device): assert c["a"].device == device -@pytest.mark.parametrize("recurse", [True, False]) -def test_lock(recurse): - shape = [3, 4, 5] - spec = Composite( - a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]), - shape=shape[:1], - ) - spec["a"] = spec["a"].clone() - spec["a", "b"] = spec["a", "b"].clone() - assert not spec.locked - spec.lock_(recurse=recurse) - assert spec.locked - with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): +class TestLock: + @pytest.mark.parametrize("recurse", [None, True, False]) + def test_lock(self, recurse): + catch_warn = ( + pytest.warns(DeprecationWarning, match="recurse") + if recurse is None + else contextlib.nullcontext() + ) + + shape = [3, 4, 5] + spec = Composite( + a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]), + shape=shape[:1], + ) spec["a"] = spec["a"].clone() - with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): - spec.set("a", spec["a"].clone()) - if recurse: - assert spec["a"].locked + spec["a", "b"] = spec["a", "b"].clone() + assert not spec.locked + with catch_warn: + spec.lock_(recurse=recurse) + assert spec.locked with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): - spec["a"].set("b", spec["a", "b"].clone()) + spec["a"] = spec["a"].clone() with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): + spec.set("a", spec["a"].clone()) + if recurse: + assert spec["a"].locked + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): + spec["a"].set("b", spec["a", "b"].clone()) + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): + spec["a", "b"] = spec["a", "b"].clone() + else: + assert not spec["a"].locked spec["a", "b"] = spec["a", "b"].clone() - else: - assert not spec["a"].locked + spec["a"].set("b", spec["a", "b"].clone()) + with catch_warn: + spec.unlock_(recurse=recurse) + spec["a"] = spec["a"].clone() spec["a", "b"] = spec["a", "b"].clone() spec["a"].set("b", spec["a", "b"].clone()) - spec.unlock_(recurse=recurse) - spec["a"] = spec["a"].clone() - spec["a", "b"] = spec["a", "b"].clone() - spec["a"].set("b", spec["a", "b"].clone()) + + def test_edge_cases(self): + level3 = Composite() + level2 = Composite(level3=level3) + level1 = Composite(level2=level2) + level0 = Composite(level1=level1) + # locking level0 locks them all + level0.lock_(recurse=True) + assert level3.is_locked + # We cannot unlock level3 + with pytest.raises( + RuntimeError, + match="Cannot unlock a Composite that is part of a locked graph", + ): + level3.unlock_(recurse=True) + assert level3.is_locked + # Adding level2 to a new spec and locking it makes it hard to unlock the level0 root + new_spec = Composite(level2=level2) + new_spec.lock_(recurse=True) + with pytest.raises( + RuntimeError, + match="Cannot unlock a Composite that is part of a locked graph", + ): + level0.unlock_(recurse=True) + assert level0.is_locked + + def test_lock_mix_recurse_nonrecurse(self): + # lock with recurse + level3 = Composite() + level2 = Composite(level3=level3) + level1 = Composite(level2=level2) + level0 = Composite(level1=level1) + # locking level0 locks them all + level0.lock_(recurse=True) + new_spec = Composite(level2=level2) + new_spec.lock_(recurse=True) + + # Unlock with recurse=False + with pytest.raises(RuntimeError, match="Cannot unlock"): + level3.unlock_(recurse=False) + assert level3.is_locked + assert level2.is_locked + assert new_spec.is_locked + with pytest.raises(RuntimeError, match="Cannot unlock"): + level2.unlock_(recurse=False) + with pytest.raises(RuntimeError, match="Cannot unlock"): + level1.unlock_(recurse=False) + level0.unlock_(recurse=False) + assert level3.is_locked + assert level2.is_locked + assert level1.is_locked + new_spec.unlock_(recurse=False) + assert level3.is_locked + assert level2.is_locked + assert level1.is_locked def test_keys_to_empty_composite_spec(): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index a7914f4f1d7..0296b55f972 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -7,8 +7,10 @@ import abc import enum +import gc import math import warnings +import weakref from collections.abc import Iterable from copy import deepcopy from dataclasses import dataclass @@ -4428,7 +4430,7 @@ class Composite(TensorSpec): @classmethod def __new__(cls, *args, **kwargs): cls._device = None - cls._locked = False + cls._is_locked = False return super().__new__(cls) @property @@ -4959,6 +4961,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite: return self.__class__(**kwargs, device=_device, shape=self.shape) def clone(self) -> Composite: + """Clones the Composite spec. + + Locked specs will not produce locked clones. + """ try: device = self.device except RuntimeError: @@ -5170,14 +5176,82 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def lock_(self, recurse=False): - """Locks the Composite and prevents modification of its content. + # Locking functionality + @property + def is_locked(self) -> bool: + return self._is_locked + + @is_locked.setter + def is_locked(self, value: bool) -> None: + if value: + self.lock_() + else: + self.unlock_() + + def __getstate__(self): + result = self.__dict__.copy() + __lock_parents_weakrefs = result.pop("__lock_parents_weakrefs", None) + if __lock_parents_weakrefs is not None: + result["_lock_recurse"] = True + return result + + def __setstate__(self, state): + _lock_recurse = state.pop("_lock_recurse", False) + for key, value in state.items(): + setattr(self, key, value) + if self._is_locked: + self._is_locked = False + self.lock_(recurse=_lock_recurse) + + def _propagate_lock( + self, *, recurse: bool, lock_parents_weakrefs=None, is_compiling + ): + """Registers the parent composite that handles the lock.""" + self._is_locked = True + if lock_parents_weakrefs is not None: + lock_parents_weakrefs = [ + ref + for ref in lock_parents_weakrefs + if not any(refref is ref for refref in self._lock_parents_weakrefs) + ] + if not is_compiling: + is_root = lock_parents_weakrefs is None + if is_root: + lock_parents_weakrefs = [] + else: + self._lock_parents_weakrefs = ( + self._lock_parents_weakrefs + lock_parents_weakrefs + ) + lock_parents_weakrefs = list(lock_parents_weakrefs) + lock_parents_weakrefs.append(weakref.ref(self)) - This is only a first-level lock, unless specified otherwise through the - ``recurse`` arg. + if recurse: + for value in self.values(): + if isinstance(value, Composite): + value._propagate_lock( + recurse=True, + lock_parents_weakrefs=lock_parents_weakrefs, + is_compiling=is_compiling, + ) - Leaf specs can always be modified in place, but they cannot be replaced - in their Composite parent. + @property + def _lock_parents_weakrefs(self): + _lock_parents_weakrefs = self.__dict__.get("__lock_parents_weakrefs") + if _lock_parents_weakrefs is None: + self.__dict__["__lock_parents_weakrefs"] = [] + _lock_parents_weakrefs = self.__dict__["__lock_parents_weakrefs"] + return _lock_parents_weakrefs + + @_lock_parents_weakrefs.setter + def _lock_parents_weakrefs(self, value: list): + self.__dict__["__lock_parents_weakrefs"] = value + + def lock_(self, recurse: bool | None = None) -> T: + """Locks the Composite and prevents modification of its content. + + The recurse argument control whether the lock will be propagated to sub-specs. + The current default is ``False`` but it will be turned to ``True`` for consistency + with the TensorDict API in v0.8. Examples: >>> shape = [3, 4, 5] @@ -5211,30 +5285,99 @@ def lock_(self, recurse=False): failed! """ - self._locked = True + if self.is_locked: + return self + is_comp = is_compiling() + if is_comp: + # TODO: See what to do when compiling + pass + if recurse is None: + warnings.warn( + "You have not specified a value for recurse when calling CompositeSpec.lock_(). " + "The current default is False but it will be turned to True in v0.8. To adapt to these changes " + "and silence this warning, pass the value of recurse explicitly.", + category=DeprecationWarning, + ) + recurse = False + self._propagate_lock(recurse=recurse, is_compiling=is_comp) + return self + + def _propagate_unlock(self, recurse: bool): + # if we end up here, we can clear the graph associated with this td + self._is_locked = False + + self._is_shared = False + self._is_memmap = False + if recurse: + sub_specs = [] for value in self.values(): if isinstance(value, Composite): - value.lock_(recurse) - return self + sub_specs.extend(value._propagate_unlock(recurse=recurse)) + sub_specs.append(value) + return sub_specs + return [] + + def _check_unlock(self, first_attempt=True): + if not first_attempt: + gc.collect() + obj = None + for ref in self._lock_parents_weakrefs: + obj = ref() + # check if the locked parent exists and if it's locked + # we check _is_locked because it can be False or None in the case of Lazy stacks, + # but if we check obj.is_locked it will be True for this class. + if obj is not None and obj._is_locked: + break - def unlock_(self, recurse=False): + else: + try: + self._lock_parents_weakrefs = [] + except AttributeError: + # Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref + pass + return + + if first_attempt: + del obj + return self._check_unlock(False) + raise RuntimeError( + "Cannot unlock a Composite that is part of a locked graph. " + "Graphs are locked when a Composite is locked with recurse=True. " + "Unlock the root Composite first. If the Composite is part of multiple graphs, " + "group the graphs under a common Composite an unlock this root. " + f"self: {self}, obj: {obj}" + ) + + def unlock_(self, recurse: bool | None = None) -> T: """Unlocks the Composite and allows modification of its content. This is only a first-level lock modification, unless specified otherwise through the ``recurse`` arg. """ - self._locked = False - if recurse: - for value in self.values(): - if isinstance(value, Composite): - value.unlock_(recurse) + try: + if recurse is None: + warnings.warn( + "You have not specified a value for recurse when calling CompositeSpec.unlock_(). " + "The current default is False but it will be turned to True in v0.8. To adapt to these changes " + "and silence this warning, pass the value of recurse explicitly.", + category=DeprecationWarning, + ) + recurse = False + sub_specs = self._propagate_unlock(recurse=recurse) + if recurse: + for sub_spec in sub_specs: + sub_spec._check_unlock() + self._check_unlock() + except RuntimeError as err: + self.lock_() + raise err return self @property def locked(self): - return self._locked + return self._is_locked class StackedComposite(_LazyStackedMixin[Composite], Composite): diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py index fb37639ad37..0a64c395126 100644 --- a/torchrl/envs/libs/isaacgym.py +++ b/torchrl/envs/libs/isaacgym.py @@ -80,9 +80,9 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 specs = make_composite_from_td(data) obs_spec = self.observation_spec - obs_spec.unlock_() + obs_spec.unlock_(recurse=True) obs_spec.update(specs) - obs_spec.lock_() + obs_spec.lock_(recurse=True) def _output_transform(self, output): obs, reward, done, info = output diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index bb92042bf0d..e8b0a744d5c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -530,22 +530,6 @@ def reset_parent(self) -> None: self.__dict__["_container"] = None self.__dict__["_parent"] = None - def __getstate__(self): - result = self.__dict__.copy() - container = result["_container"] - if container is not None: - container = container() - result["_container"] = container - return result - - def __setstate__(self, state): - state["_container"] = ( - weakref.ref(state["_container"]) - if state["_container"] is not None - else None - ) - self.__dict__.update(state) - def clone(self) -> T: self_copy = copy(self) state = copy(self.__dict__) @@ -589,6 +573,22 @@ def container(self): container = container_weakref return container + def __getstate__(self): + result = self.__dict__.copy() + container = result["_container"] + if container is not None: + container = container() + result["_container"] = container + return result + + def __setstate__(self, state): + state["_container"] = ( + weakref.ref(state["_container"]) + if state["_container"] is not None + else None + ) + self.__dict__.update(state) + @property def parent(self) -> Optional[EnvBase]: """Returns the parent env of the transform. @@ -1736,7 +1736,7 @@ def reset_key(self): f"Got more than one reset key in env {self.container}, cannot infer which one to use. Consider providing the reset key in the {type(self)} constructor." ) reset_key = reset_keys[0] - self._reset_key = reset_keys + self._reset_key = reset_key return reset_key @reset_key.setter