From 55b7a5aa564833235ffe819823d09d5883316afd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 29 Jan 2025 12:44:22 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_specs.py | 112 ++++++++++++++---- torchrl/data/tensor_specs.py | 158 +++++++++++++++++++++++--- torchrl/envs/libs/isaacgym.py | 4 +- torchrl/envs/transforms/transforms.py | 7 +- 4 files changed, 232 insertions(+), 49 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index 07762c7ad30..340afaa449a 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -697,36 +697,100 @@ def test_create_composite_nested(shape, device): assert c["a"].device == device -@pytest.mark.parametrize("recurse", [True, False]) -def test_lock(recurse): - shape = [3, 4, 5] - spec = Composite( - a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]), - shape=shape[:1], - ) - spec["a"] = spec["a"].clone() - spec["a", "b"] = spec["a", "b"].clone() - assert not spec.locked - spec.lock_(recurse=recurse) - assert spec.locked - with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): +class TestLock: + @pytest.mark.parametrize("recurse", [None, True, False]) + def test_lock(self, recurse): + catch_warn = ( + pytest.warns(DeprecationWarning, match="recurse") + if recurse is None + else contextlib.nullcontext() + ) + + shape = [3, 4, 5] + spec = Composite( + a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]), + shape=shape[:1], + ) spec["a"] = spec["a"].clone() - with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): - spec.set("a", spec["a"].clone()) - if recurse: - assert spec["a"].locked + spec["a", "b"] = spec["a", "b"].clone() + assert not spec.locked + with catch_warn: + spec.lock_(recurse=recurse) + assert spec.locked with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): - spec["a"].set("b", spec["a", "b"].clone()) + spec["a"] = spec["a"].clone() with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): + spec.set("a", spec["a"].clone()) + if recurse: + assert spec["a"].locked + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): + spec["a"].set("b", spec["a", "b"].clone()) + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): + spec["a", "b"] = spec["a", "b"].clone() + else: + assert not spec["a"].locked spec["a", "b"] = spec["a", "b"].clone() - else: - assert not spec["a"].locked + spec["a"].set("b", spec["a", "b"].clone()) + with catch_warn: + spec.unlock_(recurse=recurse) + spec["a"] = spec["a"].clone() spec["a", "b"] = spec["a", "b"].clone() spec["a"].set("b", spec["a", "b"].clone()) - spec.unlock_(recurse=recurse) - spec["a"] = spec["a"].clone() - spec["a", "b"] = spec["a", "b"].clone() - spec["a"].set("b", spec["a", "b"].clone()) + + def test_edge_cases(self): + level3 = Composite() + level2 = Composite(level3=level3) + level1 = Composite(level2=level2) + level0 = Composite(level1=level1) + # locking level0 locks them all + level0.lock_(recurse=True) + assert level3.is_locked + # We cannot unlock level3 + with pytest.raises( + RuntimeError, + match="Cannot unlock a Composite that is part of a locked graph", + ): + level3.unlock_(recurse=True) + assert level3.is_locked + # Adding level2 to a new spec and locking it makes it hard to unlock the level0 root + new_spec = Composite(level2=level2) + new_spec.lock_(recurse=True) + with pytest.raises( + RuntimeError, + match="Cannot unlock a Composite that is part of a locked graph", + ): + level0.unlock_(recurse=True) + assert level0.is_locked + + def test_lock_mix_recurse_nonrecurse(self): + # lock with recurse + level3 = Composite() + level2 = Composite(level3=level3) + level1 = Composite(level2=level2) + level0 = Composite(level1=level1) + # locking level0 locks them all + level0.lock_(recurse=True) + new_spec = Composite(level2=level2) + new_spec.lock_(recurse=True) + + # Unlock with recurse=False + with pytest.raises(RuntimeError, match="Cannot unlock"): + level3.unlock_(recurse=False) + assert level3.is_locked + assert level2.is_locked + assert new_spec.is_locked + with pytest.raises(RuntimeError, match="Cannot unlock"): + level2.unlock_(recurse=False) + with pytest.raises(RuntimeError, match="Cannot unlock"): + level1.unlock_(recurse=False) + level0.unlock_(recurse=False) + assert level3.is_locked + assert level2.is_locked + assert level1.is_locked + new_spec.unlock_(recurse=False) + assert level3.is_locked + assert level2.is_locked + assert level1.is_locked def test_keys_to_empty_composite_spec(): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index a7914f4f1d7..ba1c171fda6 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -7,8 +7,10 @@ import abc import enum +import gc import math import warnings +import weakref from collections.abc import Iterable from copy import deepcopy from dataclasses import dataclass @@ -4428,7 +4430,7 @@ class Composite(TensorSpec): @classmethod def __new__(cls, *args, **kwargs): cls._device = None - cls._locked = False + cls._is_locked = False return super().__new__(cls) @property @@ -5170,14 +5172,67 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def lock_(self, recurse=False): - """Locks the Composite and prevents modification of its content. + # Locking functionality + @property + def is_locked(self) -> bool: + return self._is_locked - This is only a first-level lock, unless specified otherwise through the - ``recurse`` arg. + @is_locked.setter + def is_locked(self, value: bool) -> None: + if value: + self.lock_() + else: + self.unlock_() - Leaf specs can always be modified in place, but they cannot be replaced - in their Composite parent. + def _propagate_lock( + self, *, recurse: bool, lock_parents_weakrefs=None, is_compiling + ): + """Registers the parent composite that handles the lock.""" + self._is_locked = True + if lock_parents_weakrefs is not None: + lock_parents_weakrefs = [ + ref + for ref in lock_parents_weakrefs + if not any(refref is ref for refref in self._lock_parents_weakrefs) + ] + if not is_compiling: + is_root = lock_parents_weakrefs is None + if is_root: + lock_parents_weakrefs = [] + else: + self._lock_parents_weakrefs = ( + self._lock_parents_weakrefs + lock_parents_weakrefs + ) + lock_parents_weakrefs = list(lock_parents_weakrefs) + lock_parents_weakrefs.append(weakref.ref(self)) + + if recurse: + for value in self.values(): + if isinstance(value, Composite): + value._propagate_lock( + recurse=True, + lock_parents_weakrefs=lock_parents_weakrefs, + is_compiling=is_compiling, + ) + + @property + def _lock_parents_weakrefs(self): + _lock_parents_weakrefs = self.__dict__.get("__lock_parents_weakrefs") + if _lock_parents_weakrefs is None: + self.__dict__["__lock_parents_weakrefs"] = [] + _lock_parents_weakrefs = self.__dict__["__lock_parents_weakrefs"] + return _lock_parents_weakrefs + + @_lock_parents_weakrefs.setter + def _lock_parents_weakrefs(self, value: list): + self.__dict__["__lock_parents_weakrefs"] = value + + def lock_(self, recurse: bool | None = None) -> T: + """Locks the Composite and prevents modification of its content. + + The recurse argument control whether the lock will be propagated to sub-specs. + The current default is ``False`` but it will be turned to ``True`` for consistency + with the TensorDict API in v0.8. Examples: >>> shape = [3, 4, 5] @@ -5211,30 +5266,99 @@ def lock_(self, recurse=False): failed! """ - self._locked = True + if self.is_locked: + return self + is_comp = is_compiling() + if is_comp: + # TODO: See what to do when compiling + pass + if recurse is None: + warnings.warn( + "You have not specified a value for recurse when calling CompositeSpec.lock_(). " + "The current default is False but it will be turned to True in v0.8. To adapt to these changes " + "and silence this warning, pass the value of recurse explicitly.", + category=DeprecationWarning, + ) + recurse = False + self._propagate_lock(recurse=recurse, is_compiling=is_comp) + return self + + def _propagate_unlock(self, recurse: bool): + # if we end up here, we can clear the graph associated with this td + self._is_locked = False + + self._is_shared = False + self._is_memmap = False + if recurse: + sub_specs = [] for value in self.values(): if isinstance(value, Composite): - value.lock_(recurse) - return self + sub_specs.extend(value._propagate_unlock(recurse=recurse)) + sub_specs.append(value) + return sub_specs + return [] + + def _check_unlock(self, first_attempt=True): + if not first_attempt: + gc.collect() + obj = None + for ref in self._lock_parents_weakrefs: + obj = ref() + # check if the locked parent exists and if it's locked + # we check _is_locked because it can be False or None in the case of Lazy stacks, + # but if we check obj.is_locked it will be True for this class. + if obj is not None and obj._is_locked: + break + + else: + try: + self._lock_parents_weakrefs = [] + except AttributeError: + # Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref + pass + return - def unlock_(self, recurse=False): + if first_attempt: + del obj + return self._check_unlock(False) + raise RuntimeError( + "Cannot unlock a Composite that is part of a locked graph. " + "Graphs are locked when a Composite is locked with recurse=True. " + "Unlock the root Composite first. If the Composite is part of multiple graphs, " + "group the graphs under a common Composite an unlock this root. " + f"self: {self}, obj: {obj}" + ) + + def unlock_(self, recurse: bool | None = None) -> T: """Unlocks the Composite and allows modification of its content. This is only a first-level lock modification, unless specified otherwise through the ``recurse`` arg. """ - self._locked = False - if recurse: - for value in self.values(): - if isinstance(value, Composite): - value.unlock_(recurse) + try: + if recurse is None: + warnings.warn( + "You have not specified a value for recurse when calling CompositeSpec.unlock_(). " + "The current default is False but it will be turned to True in v0.8. To adapt to these changes " + "and silence this warning, pass the value of recurse explicitly.", + category=DeprecationWarning, + ) + recurse = False + sub_specs = self._propagate_unlock(recurse=recurse) + if recurse: + for sub_spec in sub_specs: + sub_spec._check_unlock() + self._check_unlock() + except RuntimeError as err: + self.lock_() + raise err return self @property def locked(self): - return self._locked + return self._is_locked class StackedComposite(_LazyStackedMixin[Composite], Composite): diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py index fb37639ad37..0a64c395126 100644 --- a/torchrl/envs/libs/isaacgym.py +++ b/torchrl/envs/libs/isaacgym.py @@ -80,9 +80,9 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 specs = make_composite_from_td(data) obs_spec = self.observation_spec - obs_spec.unlock_() + obs_spec.unlock_(recurse=True) obs_spec.update(specs) - obs_spec.lock_() + obs_spec.lock_(recurse=True) def _output_transform(self, output): obs, reward, done, info = output diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b10c139a9b1..602b7f8c1f9 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -78,12 +78,7 @@ Unbounded, UnboundedContinuous, ) -from torchrl.envs.common import ( - _do_nothing, - _EnvPostInit, - EnvBase, - make_tensordict, -) +from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict from torchrl.envs.transforms import functional as F from torchrl.envs.transforms.utils import ( _get_reset,