diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 22cdad1b479..6c9c1c43cf5 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -80,7 +80,7 @@ export DISPLAY=:0 export SDL_VIDEODRIVER=dummy # legacy from bash scripts: remove? -conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG +conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG TOKENIZERS_PARALLELISM=true pip3 install pip --upgrade pip install virtualenv diff --git a/.github/unittest/linux_distributed/scripts/setup_env.sh b/.github/unittest/linux_distributed/scripts/setup_env.sh index 2a48ab21459..4344c136994 100755 --- a/.github/unittest/linux_distributed/scripts/setup_env.sh +++ b/.github/unittest/linux_distributed/scripts/setup_env.sh @@ -69,7 +69,8 @@ conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin \ SDL_VIDEODRIVER=dummy \ MUJOCO_GL=$PRIVATE_MUJOCO_GL \ - PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL + PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL \ + TOKENIZERS_PARALLELISM=true # Software rendering requires GLX and OSMesa. if [ $PRIVATE_MUJOCO_GL == 'egl' ] || [ $PRIVATE_MUJOCO_GL == 'osmesa' ] ; then diff --git a/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh b/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh index 58ec8becf2e..f1775a0375a 100755 --- a/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh @@ -92,6 +92,7 @@ conda env config vars set \ MUJOCO_PY_MJKEY_PATH=$root_dir/.mujoco/mjkey.txt \ SDL_VIDEODRIVER=dummy \ MUJOCO_GL=$PRIVATE_MUJOCO_GL \ - PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL + PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL \ + TOKENIZERS_PARALLELISM=true conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/unittest/linux_libs/scripts_gym/setup_env.sh b/.github/unittest/linux_libs/scripts_gym/setup_env.sh index 8804370aa6d..163a26fbdf8 100755 --- a/.github/unittest/linux_libs/scripts_gym/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_gym/setup_env.sh @@ -80,6 +80,7 @@ conda env config vars set \ MUJOCO_PY_MJKEY_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/mjkey.txt \ MUJOCO_PY_MUJOCO_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/linux/mujoco210 \ LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/pytorch/rl/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin + TOKENIZERS_PARALLELISM=true # LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin # make env variables apparent diff --git a/.github/unittest/linux_libs/scripts_habitat/run_test.sh b/.github/unittest/linux_libs/scripts_habitat/run_test.sh index a60fffd8f45..b03fd0823a9 100755 --- a/.github/unittest/linux_libs/scripts_habitat/run_test.sh +++ b/.github/unittest/linux_libs/scripts_habitat/run_test.sh @@ -10,7 +10,7 @@ conda activate ./env # https://stackoverflow.com/questions/72540359/glibcxx-3-4-30-not-found-for-librosa-in-conda-virtual-environment-after-tryin #conda install -y -c conda-forge gcc=12.1.0 conda install -y -c conda-forge libstdcxx-ng=12 -conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC +conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC TOKENIZERS_PARALLELISM=true ## find libstdc STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) @@ -36,7 +36,7 @@ export MKL_THREADING_LAYER=GNU #wget https://github.com/openai/mujoco-py/blob/master/vendor/10_nvidia.json #mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json -conda env config vars set MAGNUM_LOG=quiet HABITAT_SIM_LOG=quiet +conda env config vars set MAGNUM_LOG=quiet HABITAT_SIM_LOG=quiet TOKENIZERS_PARALLELISM=true conda deactivate && conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh index 6ad970c3f47..e436a0c9bf0 100755 --- a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh @@ -41,7 +41,7 @@ fi conda activate "${env_dir}" # set debug variables -conda env config vars set MAGNUM_LOG=debug HABITAT_SIM_LOG=debug +conda env config vars set MAGNUM_LOG=debug HABITAT_SIM_LOG=debug TOKENIZERS_PARALLELISM=true conda deactivate && conda activate "${env_dir}" pip3 install "cython<3" diff --git a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh index 38e6d350354..4e3bc93bf03 100755 --- a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh @@ -67,7 +67,8 @@ conda env config vars set \ PYOPENGL_PLATFORM=egl \ NVIDIA_PATH=/usr/src/nvidia-470.63.01 \ sim_backend=MUJOCO \ - LAZY_LEGACY_OP=False + LAZY_LEGACY_OP=False \ + TOKENIZERS_PARALLELISM=true # make env variables apparent conda deactivate && conda activate "${env_dir}" diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh index 9e360c4b9c4..bba6b9d8ecf 100755 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh @@ -79,7 +79,8 @@ conda env config vars set \ NVIDIA_PATH=/usr/src/nvidia-470.63.01 \ MUJOCO_PY_MJKEY_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/mjkey.txt \ MUJOCO_PY_MUJOCO_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/linux/mujoco210 \ - LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin \ + TOKENIZERS_PARALLELISM=true # make env variables apparent conda deactivate && conda activate "${env_dir}" diff --git a/.github/unittest/linux_sota/scripts/run_all.sh b/.github/unittest/linux_sota/scripts/run_all.sh index 5da5256de99..e8a7423c9d3 100755 --- a/.github/unittest/linux_sota/scripts/run_all.sh +++ b/.github/unittest/linux_sota/scripts/run_all.sh @@ -83,7 +83,8 @@ conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ SDL_VIDEODRIVER=dummy \ MUJOCO_GL=egl \ PYOPENGL_PLATFORM=egl \ - BATCHED_PIPE_TIMEOUT=60 + BATCHED_PIPE_TIMEOUT=60 \ + TOKENIZERS_PARALLELISM=true pip install pip --upgrade @@ -100,7 +101,7 @@ pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl conda install -y -c conda-forge libstdcxx-ng=12 ## find libstdc STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) -conda env config vars set LD_PRELOAD=${root_dir}/$STDC_LOC +conda env config vars set LD_PRELOAD=${root_dir}/$STDC_LOC TOKENIZERS_PARALLELISM=true # compile mujoco-py (bc it's done at runtime for whatever reason someone thought it was a good idea) python -c """import gym;import d4rl""" diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index d75a0e67c54..20b2802591c 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -41,19 +41,35 @@ Each env will have the following attributes: the done-flag spec. See the section on trajectory termination below. - :obj:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`). - It is locked and should not be modified directly. - :obj:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`). - It is locked and should not be modified directly. -If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensorSpec` +If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensor` instance can be used. +Env specs: locks and batch size +------------------------------- + +.. _Environment-lock: + +Environment specs are locked by default (through a ``spec_locked`` arg passed to the env constructor). +Locking specs means that any modification of the spec (or its children if it is a :class:`~torchrl.data.Composite` +instance) will require to unlock it. This can be done via the :meth:`~torchrl.envs.EnvBase.set_spec_lock_`. +The reason specs are locked by default is that it makes it easy to cache values such as action or reset keys and the +likes. +Unlocking an env should only be done if it expected that the specs will be modified often (which, in principle, should +be avoided). +Modifications of the specs such as `env.observation_spec = new_spec` are allowed: under the hood, TorchRL will erase +the cache, unlock the specs, make the modification and relock the specs if the env was previously locked. + Importantly, the environment spec shapes should contain the batch size, e.g. an environment with :obj:`env.batch_size == torch.Size([4])` should have an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])`. This is helpful when preallocation tensors, checking shape consistency etc. +Env methods +----------- + With these, the following methods are implemented: - :meth:`env.reset`: a reset method that may (but not necessarily requires to) take diff --git a/test/test_env.py b/test/test_env.py index f3b71910ce7..1197f5b1d02 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -9,11 +9,13 @@ import gc import importlib import os.path +import pickle import random import re from collections import defaultdict from functools import partial from sys import platform +from typing import Optional import numpy as np import pytest @@ -246,6 +248,41 @@ def test_run_type_checks(self): with pytest.raises(TypeError): check_env_specs(env) + class MyEnv(EnvBase): + def __init__(self): + super().__init__() + self.observation_spec = Unbounded(()) + self.action_spec = Unbounded(()) + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + ... + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + ... + + def _set_seed(self, seed: Optional[int]): + ... + + def test_env_lock(self): + + env = self.MyEnv() + for _ in range(2): + assert env.is_spec_locked + assert env.output_spec.is_locked + assert env.input_spec.is_locked + with pytest.raises(RuntimeError, match="lock"): + env.input_spec["full_action_spec", "action"] = Unbounded(()) + env = pickle.loads(pickle.dumps(env)) + + env = self.MyEnv(spec_locked=False) + assert not env.is_spec_locked + assert not env.output_spec.is_locked + assert not env.input_spec.is_locked + env.input_spec["full_action_spec", "action"] = Unbounded(()) + def test_single_env_spec(self): env = NestedCountingEnv(batch_size=[3, 1, 7]) assert not env.full_action_spec_unbatched.shape @@ -2294,8 +2331,9 @@ def test_multi_purpose_env(self, serial): env = SerialEnv(2, ContinuousActionVecMockEnv) else: env = ContinuousActionVecMockEnv() + env.set_spec_lock_() env.rollout(10) - assert env._step_mdp.validate(None) + assert env._step_mdp.validated c = SyncDataCollector( env, env.rand_action, frames_per_batch=10, total_frames=20 ) @@ -3387,12 +3425,15 @@ def policy(td): class TestEnvWithDynamicSpec: def test_dynamic_rollout(self): env = EnvWithDynamicSpec() + rollout = env.rollout(4) + assert isinstance(rollout, LazyStackedTensorDict) + rollout = env.rollout(4, return_contiguous=False) + assert isinstance(rollout, LazyStackedTensorDict) with pytest.raises( RuntimeError, match="The environment specs are dynamic. Call rollout with return_contiguous=False", ): - rollout = env.rollout(4) - rollout = env.rollout(4, return_contiguous=False) + rollout = env.rollout(4, return_contiguous=True) check_env_specs(env, return_contiguous=False) @pytest.mark.skipif(not _has_gym, reason="requires gym to be installed") diff --git a/test/test_transforms.py b/test/test_transforms.py index aba41ba614f..f57bc58221d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1100,7 +1100,7 @@ def test_catframes_transform_observation_spec(self): } ) - result = cat_frames.transform_observation_spec(observation_spec) + result = cat_frames.transform_observation_spec(observation_spec.clone()) observation_spec = Composite( { key: Bounded(space_min, space_max, (1, 3, 3), dtype=torch.double) @@ -1665,7 +1665,9 @@ def test_r3mnet_transform_observation_spec( {key: Unbounded(r3m_net.outdim, device) for key in out_keys} ) - observation_spec_out = r3m_net.transform_observation_spec(observation_spec) + observation_spec_out = r3m_net.transform_observation_spec( + observation_spec.clone() + ) for key in in_keys: assert key not in observation_spec_out @@ -1681,7 +1683,9 @@ def test_r3mnet_transform_observation_spec( ts_dict[key] = Unbounded(r3m_net.outdim, device) exp_ts = Composite(ts_dict) - observation_spec_out = r3m_net.transform_observation_spec(observation_spec) + observation_spec_out = r3m_net.transform_observation_spec( + observation_spec.clone() + ) for key in in_keys + out_keys: assert observation_spec_out[key].shape == exp_ts[key].shape @@ -2059,7 +2063,7 @@ class TestTrajCounter(TransformBase): def test_single_trans_env_check(self): torch.manual_seed(0) env = TransformedEnv(CountingEnv(max_steps=4), TrajCounter()) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) check_env_specs(env) @pytest.mark.parametrize("predefined", [True, False]) @@ -2073,7 +2077,9 @@ def make_env(max_steps=4, t=t): if t is None: t = TrajCounter() env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone()) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec( + env.base_env.observation_spec.clone() + ) return env if predefined: @@ -2109,7 +2115,9 @@ def make_env(max_steps=4, t=t): else: t = t.clone() env = TransformedEnv(CountingEnv(max_steps=max_steps), t) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec( + env.base_env.observation_spec.clone() + ) return env if predefined: @@ -2137,7 +2145,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): ), TrajCounter(), ) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) r = env.rollout( 100, lambda td: td.set("action", torch.ones(env.shape + (1,))), @@ -2153,7 +2161,7 @@ def test_trans_serial_env_check(self): ), TrajCounter(), ) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) r = env.rollout( 100, lambda td: td.set("action", torch.ones(env.shape + (1,))), @@ -2165,7 +2173,7 @@ def test_trans_serial_env_check(self): def test_transform_env(self): torch.manual_seed(0) env = TransformedEnv(CountingEnv(max_steps=4), TrajCounter()) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) r = env.rollout(100, lambda td: td.set("action", 1), break_when_any_done=False) assert r["traj_count"].max() == 19 @@ -2178,7 +2186,7 @@ def test_nested(self): TrajCounter(out_key=(("nested"), (("traj_count",),))), ), ) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) r = env.rollout(100, lambda td: td.set("action", 1), break_when_any_done=False) assert r["nested", "traj_count"].max() == 19 @@ -2210,7 +2218,9 @@ def test_collector_match(self): def make_env(max_steps=4): env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone()) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec( + env.base_env.observation_spec.clone() + ) return env collector = MultiSyncDataCollector( @@ -3283,13 +3293,17 @@ def test_transform_no_env(self, keys, device, out_key): if len(keys) == 1: observation_spec = Bounded(0, 1, (1, 4, 32)) - observation_spec = cattensors.transform_observation_spec(observation_spec) + observation_spec = cattensors.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == torch.Size([1, len(keys) * 4, 32]) else: observation_spec = Composite( {key: Bounded(0, 1, (1, 4, 32)) for key in keys} ) - observation_spec = cattensors.transform_observation_spec(observation_spec) + observation_spec = cattensors.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec[out_key].shape == torch.Size([1, len(keys) * 4, 32]) @pytest.mark.parametrize("device", get_default_devices()) @@ -3429,13 +3443,13 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = crop.transform_observation_spec(observation_spec) + observation_spec = crop.transform_observation_spec(observation_spec.clone()) assert observation_spec.shape == torch.Size([nchannels, 20, h]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = crop.transform_observation_spec(observation_spec) + observation_spec = crop.transform_observation_spec(observation_spec.clone()) for key in keys: assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) @@ -3636,13 +3650,13 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = cc.transform_observation_spec(observation_spec) + observation_spec = cc.transform_observation_spec(observation_spec.clone()) assert observation_spec.shape == torch.Size([nchannels, 20, h]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = cc.transform_observation_spec(observation_spec) + observation_spec = cc.transform_observation_spec(observation_spec.clone()) for key in keys: assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) @@ -3994,7 +4008,9 @@ def test_double2float(self, keys, keys_inv, device): observation_spec = Composite( {key: Bounded(0, 1, (1, 3, 3), dtype=torch.double) for key in keys} ) - observation_spec = double2float.transform_observation_spec(observation_spec) + observation_spec = double2float.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].dtype == torch.float, key @@ -4041,7 +4057,7 @@ def test_double2float_auto(self, keys, keys_inv, device): assert td_modif.get(key).dtype == torch.double def test_single_env_no_inkeys(self): - base_env = ContinuousActionVecMockEnv() + base_env = ContinuousActionVecMockEnv(spec_locked=False) for key, spec in list(base_env.observation_spec.items(True, True)): base_env.observation_spec[key] = spec.to(torch.float64) for key, spec in list(base_env.state_spec.items(True, True)): @@ -4052,6 +4068,7 @@ def test_single_env_no_inkeys(self): env = TransformedEnv( base_env, DoubleToFloat(), + spec_locked=False, ) for spec in env.observation_spec.values(True, True): assert spec.dtype == torch.float32 @@ -4773,13 +4790,17 @@ def test_transform_no_env(self, keys, size, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16)) - observation_spec = flatten.transform_observation_spec(observation_spec) + observation_spec = flatten.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape[-3] == expected_size else: observation_spec = Composite( {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys} ) - observation_spec = flatten.transform_observation_spec(observation_spec) + observation_spec = flatten.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape[-3] == expected_size @@ -4813,13 +4834,17 @@ def test_transform_compose(self, keys, size, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16)) - observation_spec = flatten.transform_observation_spec(observation_spec) + observation_spec = flatten.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape[-3] == expected_size else: observation_spec = Composite( {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys} ) - observation_spec = flatten.transform_observation_spec(observation_spec) + observation_spec = flatten.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape[-3] == expected_size @@ -5055,13 +5080,13 @@ def test_transform_no_env(self, keys, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = gs.transform_observation_spec(observation_spec) + observation_spec = gs.transform_observation_spec(observation_spec.clone()) assert observation_spec.shape == torch.Size([1, 16, 16]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = gs.transform_observation_spec(observation_spec) + observation_spec = gs.transform_observation_spec(observation_spec.clone()) for key in keys: assert observation_spec[key].shape == torch.Size([1, 16, 16]) @@ -5092,13 +5117,13 @@ def test_transform_compose(self, keys, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = gs.transform_observation_spec(observation_spec) + observation_spec = gs.transform_observation_spec(observation_spec.clone()) assert observation_spec.shape == torch.Size([1, 16, 16]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = gs.transform_observation_spec(observation_spec) + observation_spec = gs.transform_observation_spec(observation_spec.clone()) for key in keys: assert observation_spec[key].shape == torch.Size([1, 16, 16]) @@ -5702,7 +5727,7 @@ def test_observationnorm( if len(keys) == 1: observation_spec = Bounded(0, 1, (nchannels, 16, 16), device=device) - observation_spec = on.transform_observation_spec(observation_spec) + observation_spec = on.transform_observation_spec(observation_spec.clone()) if standard_normal: assert (observation_spec.space.low == -loc / scale).all() assert (observation_spec.space.high == (1 - loc) / scale).all() @@ -5714,7 +5739,7 @@ def test_observationnorm( observation_spec = Composite( {key: Bounded(0, 1, (nchannels, 16, 16), device=device) for key in keys} ) - observation_spec = on.transform_observation_spec(observation_spec) + observation_spec = on.transform_observation_spec(observation_spec.clone()) for key in keys: if standard_normal: assert (observation_spec[key].space.low == -loc / scale).all() @@ -5919,13 +5944,17 @@ def test_transform_no_env(self, interpolation, keys, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = resize.transform_observation_spec(observation_spec) + observation_spec = resize.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == torch.Size([nchannels, 20, 21]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = resize.transform_observation_spec(observation_spec) + observation_spec = resize.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == torch.Size([nchannels, 20, 21]) @@ -5956,13 +5985,17 @@ def test_transform_compose(self, interpolation, keys, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = resize.transform_observation_spec(observation_spec) + observation_spec = resize.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == torch.Size([nchannels, 20, 21]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = resize.transform_observation_spec(observation_spec) + observation_spec = resize.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == torch.Size([nchannels, 20, 21]) @@ -6951,7 +6984,9 @@ def test_transform_no_env(self, keys, size, nchannels, batch, device, dim): if len(keys) == 1: observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) - observation_spec = unsqueeze.transform_observation_spec(observation_spec) + observation_spec = unsqueeze.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == expected_size else: observation_spec = Composite( @@ -6960,7 +6995,9 @@ def test_transform_no_env(self, keys, size, nchannels, batch, device, dim): for key in keys } ) - observation_spec = unsqueeze.transform_observation_spec(observation_spec) + observation_spec = unsqueeze.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == expected_size @@ -7107,7 +7144,9 @@ def test_transform_compose(self, keys, size, nchannels, batch, device, dim): if len(keys) == 1: observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) - observation_spec = unsqueeze.transform_observation_spec(observation_spec) + observation_spec = unsqueeze.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == expected_size else: observation_spec = Composite( @@ -7116,7 +7155,9 @@ def test_transform_compose(self, keys, size, nchannels, batch, device, dim): for key in keys } ) - observation_spec = unsqueeze.transform_observation_spec(observation_spec) + observation_spec = unsqueeze.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == expected_size @@ -7713,7 +7754,7 @@ def test_transform_no_env(self, keys, batch, device): if len(keys) == 1: observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) observation_spec = totensorimage.transform_observation_spec( - observation_spec + observation_spec.clone() ) assert observation_spec.shape == torch.Size([3, 16, 16]) assert (observation_spec.space.low == 0).all() @@ -7723,7 +7764,7 @@ def test_transform_no_env(self, keys, batch, device): {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( - observation_spec + observation_spec.clone() ) for key in keys: assert observation_spec[key].shape == torch.Size([3, 16, 16]) @@ -7759,7 +7800,7 @@ def test_transform_compose(self, keys, batch, device): if len(keys) == 1: observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) observation_spec = totensorimage.transform_observation_spec( - observation_spec + observation_spec.clone() ) assert observation_spec.shape == torch.Size([3, 16, 16]) assert (observation_spec.space.low == 0).all() @@ -7769,7 +7810,7 @@ def test_transform_compose(self, keys, batch, device): {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( - observation_spec + observation_spec.clone() ) for key in keys: assert observation_spec[key].shape == torch.Size([3, 16, 16]) @@ -9048,7 +9089,9 @@ def test_vipnet_transform_observation_spec( if del_keys: exp_ts = Composite({key: Unbounded(1024, device) for key in out_keys}) - observation_spec_out = vip_net.transform_observation_spec(observation_spec) + observation_spec_out = vip_net.transform_observation_spec( + observation_spec.clone() + ) for key in in_keys: assert key not in observation_spec_out @@ -9064,7 +9107,9 @@ def test_vipnet_transform_observation_spec( ts_dict[key] = Unbounded(1024, device) exp_ts = Composite(ts_dict) - observation_spec_out = vip_net.transform_observation_spec(observation_spec) + observation_spec_out = vip_net.transform_observation_spec( + observation_spec.clone() + ) for key in in_keys + out_keys: assert observation_spec_out[key].shape == exp_ts[key].shape @@ -9528,7 +9573,8 @@ def test_parallelenv_vecnorm(self): lambda: TransformedEnv( GymEnv(PENDULUM_VERSIONED()), Compose( - self.rename_t, VecNorm(in_keys=[("some", "obs"), "reward"]) + self.rename_t, + VecNorm(in_keys=[("some", "obs"), "reward"]), ), ) ) @@ -9537,7 +9583,8 @@ def test_parallelenv_vecnorm(self): lambda: TransformedEnv( ContinuousActionVecMockEnv(), Compose( - self.rename_t, VecNorm(in_keys=[("some", "obs"), "reward"]) + self.rename_t, + VecNorm(in_keys=[("some", "obs"), "reward"]), ), ) ) @@ -9934,13 +9981,17 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): if len(keys) == 1: observation_spec = Bounded(0, 255, (nchannels, 16, 16)) # StepCounter does not want non composite specs - observation_spec = compose[:2].transform_observation_spec(observation_spec) + observation_spec = compose[:2].transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == torch.Size([nchannels * N, 16, 16]) else: observation_spec = Composite( {key: Bounded(0, 255, (nchannels, 16, 16)) for key in keys} ) - observation_spec = compose.transform_observation_spec(observation_spec) + observation_spec = compose.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == torch.Size( [nchannels * N, 16, 16] diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ba1c171fda6..0296b55f972 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4961,6 +4961,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite: return self.__class__(**kwargs, device=_device, shape=self.shape) def clone(self) -> Composite: + """Clones the Composite spec. + + Locked specs will not produce locked clones. + """ try: device = self.device except RuntimeError: @@ -5184,6 +5188,21 @@ def is_locked(self, value: bool) -> None: else: self.unlock_() + def __getstate__(self): + result = self.__dict__.copy() + __lock_parents_weakrefs = result.pop("__lock_parents_weakrefs", None) + if __lock_parents_weakrefs is not None: + result["_lock_recurse"] = True + return result + + def __setstate__(self, state): + _lock_recurse = state.pop("_lock_recurse", False) + for key, value in state.items(): + setattr(self, key, value) + if self._is_locked: + self._is_locked = False + self.lock_(recurse=_lock_recurse) + def _propagate_lock( self, *, recurse: bool, lock_parents_weakrefs=None, is_compiling ): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 51331a86346..9db6949cb37 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -12,7 +12,7 @@ import os import weakref from collections import OrderedDict -from copy import copy, deepcopy +from copy import deepcopy from functools import wraps from multiprocessing import connection from multiprocessing.synchronize import Lock as MpLock @@ -371,6 +371,8 @@ def __init__( ) self._mp_start_method = mp_start_method + is_spec_locked = EnvBase.is_spec_locked + @property def non_blocking(self): nb = self._non_blocking @@ -471,7 +473,7 @@ def find_all_worker_devices(item): return _do_nothing, _do_nothing def __getstate__(self): - out = copy(self.__dict__) + out = self.__dict__.copy() out["_sync_m2w_value"] = None out["_sync_w2m_value"] = None return out @@ -933,8 +935,9 @@ def _start_workers(self) -> None: "environments!" ) weakref_set.add(wr) - self._envs.append(env) + self._envs.append(env.set_spec_lock_()) self.is_closed = False + self.set_spec_lock_() @_check_start def state_dict(self) -> OrderedDict: @@ -1458,6 +1461,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): for channel in self.parent_channels: channel.send(("init", None)) self.is_closed = False + self.set_spec_lock_() @_check_start def state_dict(self) -> OrderedDict: @@ -2164,6 +2168,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): ) env = env_fun del env_fun + env.set_spec_lock_() i = -1 import torchrl diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index f4913d05a67..2df9d433486 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -20,7 +20,7 @@ TensorDictBase, unravel_key, ) -from tensordict.base import _is_leaf_nontensor +from tensordict.base import _is_leaf_nontensor, NO_DEFAULT from tensordict.utils import is_non_tensor, NestedKey from torchrl._utils import ( _ends_with, @@ -63,24 +63,34 @@ def _tensor_to_np(t): } +def _maybe_unlock(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + is_locked = self.is_spec_locked + try: + if is_locked: + self.set_spec_lock_(False) + result = func(self, *args, **kwargs) + finally: + if is_locked: + self.set_spec_lock_(True) + return result + + return wrapper + + def _cache_value(func): """Caches the result of the decorated function in env._cache dictionary.""" - # func_name = func.__name__ + func_name = func.__name__ @wraps(func) def wrapper(self, *args, **kwargs): - # result = self._cache.get(func_name, NO_DEFAULT) - # if result is NO_DEFAULT: - result = func(self, *args, **kwargs) - # Ideally we'd like to cache all the `_keys` attributes but there's a catch: one can modify the specs at - # any time so this will not run as expected. - # The solution should be: - # - optionally lock the specs in the env, like we do with tensordict. - # - Locked specs will behave like locked tensordict: we lock the root spec, meaning that all the sub-specs - # will be locked, and no __setattr__ will be allowed within the env unless it's unlocked. - # We cannot just guard spec.__setattr__ because `spec[key0][key1] = smth` will not call a setattr - # on the root spec so there's a chance we miss it. - # self._cache[func_name] = result + if not self.is_spec_locked: + return func(self, *args, **kwargs) + result = self._cache.get(func_name, NO_DEFAULT) + if result is NO_DEFAULT: + result = func(self, *args, **kwargs) + self._cache[func_name] = result return result return wrapper @@ -215,9 +225,16 @@ def to(self, device: DEVICE_TYPING) -> EnvMetaData: class _EnvPostInit(abc.ABCMeta): def __call__(cls, *args, **kwargs): + spec_locked = kwargs.pop("spec_locked", True) auto_reset = kwargs.pop("auto_reset", False) auto_reset_replace = kwargs.pop("auto_reset_replace", True) instance: EnvBase = super().__call__(*args, **kwargs) + + if spec_locked: + instance.input_spec.lock_(recurse=True) + instance.output_spec.lock_(recurse=True) + instance._is_spec_locked = spec_locked + # we create the done spec by adding a done/terminated entry if one is missing instance._create_done_specs() # we access lazy attributed to make sure they're built properly. @@ -276,6 +293,20 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): at every reset and every step. Defaults to ``False``. allow_done_after_reset (bool, optional): if ``True``, an environment can be done after a call to :meth:`~.reset` is made. Defaults to ``False``. + spec_locked (bool, optional): if ``True``, the specs are locked and can only be + modified if :meth:`~torchrl.envs.EnvBase.set_spec_lock_` is called. + + .. note:: The locking is achieved by the `EnvBase` metaclass. It does not appear in the + `__init__` method and is included in the keyword arguments strictly for type-hinting purpose. + + .. seealso:: :ref:`Locking environment specs `. + + Defaults to ``True``. + auto_reset (bool, optional): if ``True``, the env is assumed to reset automatically + when done. Defaults to ``False``. + + .. note:: The auto-resetting is achieved by the `EnvBase` metaclass. It does not appear in the + `__init__` method and is included in the keyword arguments strictly for type-hinting purpose. Attributes: done_spec (Composite): equivalent to ``full_done_spec`` as all @@ -306,6 +337,8 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): batch_size (torch.Size): The batch-size of the environment. device (torch.device): the device where the input/outputs of the environment are to be expected. Can be ``None``. + is_spec_locked (bool): returns ``True`` if the specs are locked. See the :attr:`spec_locked` + argument above. Methods: step (TensorDictBase -> TensorDictBase): step in the environment @@ -406,6 +439,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): _batch_size: torch.Size | None _device: torch.device | None + _is_spec_locked: bool = False def __init__( self, @@ -414,6 +448,8 @@ def __init__( batch_size: Optional[torch.Size] = None, run_type_checks: bool = False, allow_done_after_reset: bool = False, + spec_locked: bool = True, + auto_reset: bool = False, ): super().__init__() @@ -436,7 +472,7 @@ def __init__( if output_spec is None: output_spec = self.__dict__["_output_spec"] = Composite( shape=batch_size, device=device - ).lock_() + ) elif self._output_spec.device != device and device is not None: self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to( self.device @@ -445,12 +481,12 @@ def __init__( if input_spec is None: input_spec = self.__dict__["_input_spec"] = Composite( shape=batch_size, device=device - ).lock_() + ) elif self._input_spec.device != device and device is not None: self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(self.device) - output_spec.unlock_() - input_spec.unlock_() + output_spec.unlock_(recurse=True) + input_spec.unlock_(recurse=True) if "full_observation_spec" not in output_spec: output_spec["full_observation_spec"] = Composite() if "full_done_spec" not in output_spec: @@ -461,8 +497,6 @@ def __init__( input_spec["full_state_spec"] = Composite() if "full_action_spec" not in input_spec: input_spec["full_action_spec"] = Composite() - output_spec.lock_() - input_spec.lock_() if "is_closed" not in self.__dir__(): self.is_closed = True @@ -470,6 +504,52 @@ def __init__( self._allow_done_after_reset = allow_done_after_reset self._cache = {} + def set_spec_lock_(self, mode: bool = True) -> EnvBase: + """Locks or unlocks the environment's specs. + + Args: + mode (bool): Whether to lock (`True`) or unlock (`False`) the specs. Defaults to `True`. + + Returns: + EnvBase: The environment instance itself. + + .. seealso:: :ref:`Locking environment specs `. + + """ + output_spec = self.__dict__.get("_output_spec") + input_spec = self.__dict__.get("_input_spec") + if mode: + if output_spec is not None: + output_spec.lock_(recurse=True) + if input_spec is not None: + input_spec.lock_(recurse=True) + else: + self._cache.clear() + if output_spec is not None: + output_spec.unlock_(recurse=True) + if input_spec is not None: + input_spec.unlock_(recurse=False) + self.__dict__["_is_spec_locked"] = mode + return self + + @property + def is_spec_locked(self): + """Gets whether the environment's specs are locked. + + This property can be modified directly. + + Returns: + bool: True if the specs are locked, False otherwise. + + .. seealso:: :ref:`Locking environment specs `. + + """ + return self.__dict__.get("_is_spec_locked", False) + + @is_spec_locked.setter + def is_spec_locked(self, value: bool): + self.set_spec_lock_(value) + def auto_specs_( self, policy: Callable[[TensorDictBase], TensorDictBase], @@ -700,19 +780,16 @@ def batch_size(self) -> torch.Size: return _batch_size @batch_size.setter + @_maybe_unlock def batch_size(self, value: torch.Size) -> None: self._batch_size = torch.Size(value) if ( hasattr(self, "output_spec") and self.output_spec.shape[: len(value)] != value ): - self.output_spec.unlock_() self.output_spec.shape = value - self.output_spec.lock_() if hasattr(self, "input_spec") and self.input_spec.shape[: len(value)] != value: - self.input_spec.unlock_() self.input_spec.shape = value - self.input_spec.lock_() @property def shape(self): @@ -803,12 +880,17 @@ def input_spec(self) -> TensorSpec: """ input_spec = self.__dict__.get("_input_spec") if input_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) input_spec = Composite( full_state_spec=None, shape=self.batch_size, device=self.device, - ).lock_() + ) self.__dict__["_input_spec"] = input_spec + if is_locked: + self.set_spec_lock_(True) return input_spec @input_spec.setter @@ -863,11 +945,16 @@ def output_spec(self) -> TensorSpec: """ output_spec = self.__dict__.get("_output_spec") if output_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) output_spec = Composite( shape=self.batch_size, device=self.device, - ).lock_() + ) self.__dict__["_output_spec"] = output_spec + if is_locked: + self.set_spec_lock_(True) return output_spec @output_spec.setter @@ -1024,28 +1111,25 @@ def action_spec(self) -> TensorSpec: return out @action_spec.setter + @_maybe_unlock def action_spec(self, value: TensorSpec) -> None: - try: - self.input_spec.unlock_() - device = self.input_spec._device - if not hasattr(value, "shape"): - raise TypeError( - f"action_spec of type {type(value)} do not have a shape attribute." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " - "Please use `env.action_spec_unbatched = value` to set unbatched versions instead." - ) + device = self.input_spec._device + if not hasattr(value, "shape"): + raise TypeError( + f"action_spec of type {type(value)} do not have a shape attribute." + ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " + "Please use `env.action_spec_unbatched = value` to set unbatched versions instead." + ) - if not isinstance(value, Composite): - value = Composite( - action=value.to(device), shape=self.batch_size, device=device - ) + if not isinstance(value, Composite): + value = Composite( + action=value.to(device), shape=self.batch_size, device=device + ) - self.input_spec["full_action_spec"] = value.to(device) - finally: - self.input_spec.lock_() + self.input_spec["full_action_spec"] = value.to(device) @property def full_action_spec(self) -> Composite: @@ -1073,10 +1157,13 @@ def full_action_spec(self) -> Composite: """ full_action_spec = self.input_spec.get("full_action_spec", None) if full_action_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) full_action_spec = Composite(shape=self.batch_size, device=self.device) - self.input_spec.unlock_() self.input_spec["full_action_spec"] = full_action_spec - self.input_spec.lock_() + if is_locked: + self.set_spec_lock_(True) return full_action_spec @full_action_spec.setter @@ -1211,36 +1298,31 @@ def reward_spec(self) -> TensorSpec: return reward_spec[self.reward_keys[0]] @reward_spec.setter - @_clear_cache_when_set + @_maybe_unlock def reward_spec(self, value: TensorSpec) -> None: - try: - self.output_spec.unlock_() - device = self.output_spec._device - if not hasattr(value, "shape"): - raise TypeError( - f"reward_spec of type {type(value)} do not have a shape " - f"attribute." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " - "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead." - ) - if not isinstance(value, Composite): - value = Composite( - reward=value.to(device), shape=self.batch_size, device=device + device = self.output_spec._device + if not hasattr(value, "shape"): + raise TypeError( + f"reward_spec of type {type(value)} do not have a shape " f"attribute." + ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " + "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead." + ) + if not isinstance(value, Composite): + value = Composite( + reward=value.to(device), shape=self.batch_size, device=device + ) + for leaf in value.values(True, True): + if len(leaf.shape) == 0: + raise RuntimeError( + "the reward_spec's leafs shape cannot be empty (this error" + " usually comes from trying to set a reward_spec" + " with a null number of dimensions. Try using a multidimensional" + " spec instead, for instance with a singleton dimension at the tail)." ) - for leaf in value.values(True, True): - if len(leaf.shape) == 0: - raise RuntimeError( - "the reward_spec's leafs shape cannot be empty (this error" - " usually comes from trying to set a reward_spec" - " with a null number of dimensions. Try using a multidimensional" - " spec instead, for instance with a singleton dimension at the tail)." - ) - self.output_spec["full_reward_spec"] = value.to(device) - finally: - self.output_spec.lock_() + self.output_spec["full_reward_spec"] = value.to(device) @property def full_reward_spec(self) -> Composite: @@ -1281,7 +1363,7 @@ def full_reward_spec(self) -> Composite: return self.output_spec["full_reward_spec"] @full_reward_spec.setter - @_clear_cache_when_set + @_maybe_unlock def full_reward_spec(self, spec: Composite) -> None: self.reward_spec = spec.to(self.device) if self.device is not None else spec @@ -1344,7 +1426,7 @@ def full_done_spec(self) -> Composite: return self.output_spec["full_done_spec"] @full_done_spec.setter - @_clear_cache_when_set + @_maybe_unlock def full_done_spec(self, spec: Composite) -> None: self.done_spec = spec.to(self.device) if self.device is not None else spec @@ -1418,6 +1500,7 @@ def done_spec(self) -> TensorSpec: done_spec = self.output_spec["full_done_spec"] return done_spec + @_maybe_unlock def _create_done_specs(self): """Reads through the done specs and makes it so that it's complete. @@ -1446,9 +1529,7 @@ def _create_done_specs(self): dtype=torch.bool, device=self.device, ) - self.output_spec.unlock_() self.output_spec["full_done_spec"] = full_done_spec - self.output_spec.lock_() return def check_local_done(spec): @@ -1482,46 +1563,44 @@ def check_local_done(spec): n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device ) - self.output_spec.unlock_() + if_locked = self.is_spec_locked + if if_locked: + self.is_spec_locked = False check_local_done(full_done_spec) self.output_spec["full_done_spec"] = full_done_spec - self.output_spec.lock_() + if if_locked: + self.is_spec_locked = True return @done_spec.setter - @_clear_cache_when_set + @_maybe_unlock def done_spec(self, value: TensorSpec) -> None: - try: - self.output_spec.unlock_() - device = self.output_spec.device - if not hasattr(value, "shape"): - raise TypeError( - f"done_spec of type {type(value)} do not have a shape " - f"attribute." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - if not isinstance(value, Composite): - value = Composite( - done=value.to(device), - terminated=value.to(device), - shape=self.batch_size, - device=device, + device = self.output_spec.device + if not hasattr(value, "shape"): + raise TypeError( + f"done_spec of type {type(value)} do not have a shape " f"attribute." + ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + ) + if not isinstance(value, Composite): + value = Composite( + done=value.to(device), + terminated=value.to(device), + shape=self.batch_size, + device=device, + ) + for leaf in value.values(True, True): + if len(leaf.shape) == 0: + raise RuntimeError( + "the done_spec's leafs shape cannot be empty (this error" + " usually comes from trying to set a reward_spec" + " with a null number of dimensions. Try using a multidimensional" + " spec instead, for instance with a singleton dimension at the tail)." ) - for leaf in value.values(True, True): - if len(leaf.shape) == 0: - raise RuntimeError( - "the done_spec's leafs shape cannot be empty (this error" - " usually comes from trying to set a reward_spec" - " with a null number of dimensions. Try using a multidimensional" - " spec instead, for instance with a singleton dimension at the tail)." - ) - self.output_spec["full_done_spec"] = value.to(device) - self._create_done_specs() - finally: - self.output_spec.lock_() + self.output_spec["full_done_spec"] = value.to(device) + self._create_done_specs() # observation spec: observation specs belong to output_spec @property @@ -1555,40 +1634,44 @@ def observation_spec(self) -> Composite: """ observation_spec = self.output_spec.get("full_observation_spec", default=None) if observation_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) observation_spec = Composite(shape=self.batch_size, device=self.device) - self.output_spec.unlock_() self.output_spec["full_observation_spec"] = observation_spec - self.output_spec.lock_() + if is_locked: + self.set_spec_lock_(True) + return observation_spec @observation_spec.setter - @_clear_cache_when_set + @_maybe_unlock def observation_spec(self, value: TensorSpec) -> None: - try: - self.output_spec.unlock_() - if not isinstance(value, Composite): - raise TypeError("The type of an observation_spec must be Composite.") - elif value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - device = self.output_spec._device - self.output_spec["full_observation_spec"] = ( - value.to(device) if device is not None else value + if not isinstance(value, Composite): + value = Composite( + observation=value, + device=self.device, + batch_size=self.output_spec.batch_size, ) - finally: - self.output_spec.lock_() + elif value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + ) + device = self.output_spec._device + self.output_spec["full_observation_spec"] = ( + value.to(device) if device is not None else value + ) @property def full_observation_spec(self) -> Composite: return self.observation_spec @full_observation_spec.setter - @_clear_cache_when_set + @_maybe_unlock def full_observation_spec(self, spec: Composite): self.observation_spec = spec @@ -1628,38 +1711,37 @@ def state_spec(self) -> Composite: """ state_spec = self.input_spec["full_state_spec"] if state_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) state_spec = Composite(shape=self.batch_size, device=self.device) - self.input_spec.unlock_() self.input_spec["full_state_spec"] = state_spec - self.input_spec.lock_() + if is_locked: + self.set_spec_lock_(True) return state_spec @state_spec.setter - @_clear_cache_when_set + @_maybe_unlock def state_spec(self, value: Composite) -> None: - try: - self.input_spec.unlock_() - if value is None: - self.input_spec["full_state_spec"] = Composite( - device=self.device, shape=self.batch_size + if value is None: + self.input_spec["full_state_spec"] = Composite( + device=self.device, shape=self.batch_size + ) + else: + device = self.input_spec.device + if not isinstance(value, Composite): + raise TypeError("The type of an state_spec must be Composite.") + elif value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - else: - device = self.input_spec.device - if not isinstance(value, Composite): - raise TypeError("The type of an state_spec must be Composite.") - elif value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - self.input_spec["full_state_spec"] = ( - value.to(device) if device is not None else value + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - finally: - self.input_spec.lock_() + self.input_spec["full_state_spec"] = ( + value.to(device) if device is not None else value + ) @property def full_state_spec(self) -> Composite: @@ -1689,6 +1771,7 @@ def full_state_spec(self) -> Composite: return self.state_spec @full_state_spec.setter + @_maybe_unlock def full_state_spec(self, spec: Composite) -> None: self.state_spec = spec @@ -1710,6 +1793,7 @@ def full_action_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_action_spec) @full_action_spec_unbatched.setter + @_maybe_unlock def full_action_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_action_spec = spec @@ -1720,6 +1804,7 @@ def action_spec_unbatched(self) -> TensorSpec: return self._make_single_env_spec(self.action_spec) @action_spec_unbatched.setter + @_maybe_unlock def action_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.action_spec = spec @@ -1730,6 +1815,7 @@ def full_observation_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_observation_spec) @full_observation_spec_unbatched.setter + @_maybe_unlock def full_observation_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_observation_spec = spec @@ -1740,6 +1826,7 @@ def observation_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.observation_spec) @observation_spec_unbatched.setter + @_maybe_unlock def observation_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.observation_spec = spec @@ -1750,6 +1837,7 @@ def full_reward_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_reward_spec) @full_reward_spec_unbatched.setter + @_maybe_unlock def full_reward_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_reward_spec = spec @@ -1760,6 +1848,7 @@ def reward_spec_unbatched(self) -> TensorSpec: return self._make_single_env_spec(self.reward_spec) @reward_spec_unbatched.setter + @_maybe_unlock def reward_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.reward_spec = spec @@ -1770,6 +1859,7 @@ def full_done_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_done_spec) @full_done_spec_unbatched.setter + @_maybe_unlock def full_done_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_done_spec = spec @@ -1780,6 +1870,7 @@ def done_spec_unbatched(self) -> TensorSpec: return self._make_single_env_spec(self.done_spec) @done_spec_unbatched.setter + @_maybe_unlock def done_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.done_spec = spec @@ -1790,6 +1881,7 @@ def output_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.output_spec) @output_spec_unbatched.setter + @_maybe_unlock def output_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.output_spec = spec @@ -1800,6 +1892,7 @@ def input_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.input_spec) @input_spec_unbatched.setter + @_maybe_unlock def input_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.input_spec = spec @@ -1810,6 +1903,7 @@ def full_state_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_state_spec) @full_state_spec_unbatched.setter + @_maybe_unlock def full_state_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_state_spec = spec @@ -1820,6 +1914,7 @@ def state_spec_unbatched(self) -> TensorSpec: return self._make_single_env_spec(self.state_spec) @state_spec_unbatched.setter + @_maybe_unlock def state_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.state_spec = spec @@ -2715,7 +2810,7 @@ def specs(self) -> Composite: output_spec=self.output_spec, input_spec=self.input_spec, shape=self.batch_size, - ).lock_() + ) @property @_cache_value @@ -3063,7 +3158,7 @@ def rollout( out_td.refine_names(..., "time") return out_td - @_clear_cache_when_set + @_maybe_unlock def add_truncated_keys(self) -> EnvBase: """Adds truncated keys to the environment.""" i = 0 @@ -3450,12 +3545,13 @@ def __del__(self): # __del__ will not affect the program. pass + @_maybe_unlock def to(self, device: DEVICE_TYPING) -> EnvBase: device = _make_ordinal_device(torch.device(device)) if device == self.device: return self - self.__dict__["_input_spec"] = self.input_spec.to(device).lock_() - self.__dict__["_output_spec"] = self.output_spec.to(device).lock_() + self.__dict__["_input_spec"] = self.input_spec.to(device) + self.__dict__["_output_spec"] = self.output_spec.to(device) self._device = device return super().to(device) @@ -3526,12 +3622,14 @@ def __init__( device: DEVICE_TYPING = None, batch_size: Optional[torch.Size] = None, allow_done_after_reset: bool = False, + spec_locked: bool = True, **kwargs, ): super().__init__( device=device, batch_size=batch_size, allow_done_after_reset=allow_done_after_reset, + spec_locked=spec_locked, ) if len(args): raise ValueError( diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index bb849847f3a..5e6dd55be8b 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -16,7 +16,7 @@ from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded -from torchrl.envs.common import _EnvWrapper, EnvBase +from torchrl.envs.common import _EnvWrapper, _maybe_unlock, EnvBase class BaseInfoDictReader(metaclass=abc.ABCMeta): @@ -434,6 +434,7 @@ def _output_transform( def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple: ... + @_maybe_unlock def set_info_dict_reader( self, info_dict_reader: BaseInfoDictReader | None = None, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 602b7f8c1f9..33e7d463611 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -60,7 +60,6 @@ _ends_with, _make_ordinal_device, _replace_last, - implement_for, logger as torchrl_logger, ) @@ -78,7 +77,13 @@ Unbounded, UnboundedContinuous, ) -from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict +from torchrl.envs.common import ( + _do_nothing, + _EnvPostInit, + _maybe_unlock, + EnvBase, + make_tensordict, +) from torchrl.envs.transforms import functional as F from torchrl.envs.transforms.utils import ( _get_reset, @@ -571,6 +576,21 @@ def container(self): container = container_weakref return container + def __getstate__(self): + state = self.__dict__.copy() + container_weakref = state.pop("_container", None) + if container_weakref is not None: + container = container_weakref() + else: + container = container_weakref + state["_container"] = container + return state + + def __setstate__(self, state): + container = state.pop("_container", None) + state["_container"] = weakref.ref(container) if container is not None else None + self.__dict__.update(state) + @property def parent(self) -> Optional[EnvBase]: """Returns the parent env of the transform. @@ -832,36 +852,43 @@ def _inplace_update(self): @property def output_spec(self) -> TensorSpec: """Observation spec of the transformed environment.""" - if not self.cache_specs or self.__dict__.get("_output_spec", None) is None: - output_spec = self.base_env.output_spec.clone() - - # remove cached key values, but not _input_spec - super().empty_cache() - output_spec = output_spec.unlock_() - output_spec = self.transform.transform_output_spec(output_spec) - output_spec.lock_() - if self.cache_specs: - self.__dict__["_output_spec"] = output_spec - else: - output_spec = self.__dict__.get("_output_spec", None) + if self.cache_specs: + output_spec = self.__dict__.get("_output_spec") + if output_spec is not None: + return output_spec + output_spec = self._make_output_spec() + return output_spec + + @_maybe_unlock + def _make_output_spec(self): + output_spec = self.base_env.output_spec.clone() + + # remove cached key values, but not _input_spec + super().empty_cache() + output_spec = self.transform.transform_output_spec(output_spec) + if self.cache_specs: + self.__dict__["_output_spec"] = output_spec return output_spec @property def input_spec(self) -> TensorSpec: - """Action spec of the transformed environment.""" - if self.__dict__.get("_input_spec", None) is None or not self.cache_specs: - input_spec = self.base_env.input_spec.clone() - - # remove cached key values but not _output_spec - super().empty_cache() - - input_spec.unlock_() - input_spec = self.transform.transform_input_spec(input_spec) - input_spec.lock_() - if self.cache_specs: - self.__dict__["_input_spec"] = input_spec - else: - input_spec = self.__dict__.get("_input_spec", None) + """Observation spec of the transformed environment.""" + if self.cache_specs: + input_spec = self.__dict__.get("_input_spec") + if input_spec is not None: + return input_spec + input_spec = self._make_input_spec() + return input_spec + + @_maybe_unlock + def _make_input_spec(self): + input_spec = self.base_env.input_spec.clone() + + # remove cached key values, but not _input_spec + super().empty_cache() + input_spec = self.transform.transform_input_spec(input_spec) + if self.cache_specs: + self.__dict__["_input_spec"] = input_spec return input_spec def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict: @@ -1718,7 +1745,7 @@ def reset_key(self): f"Got more than one reset key in env {self.container}, cannot infer which one to use. Consider providing the reset key in the {type(self)} constructor." ) reset_key = reset_keys[0] - self._reset_key = reset_keys + self._reset_key = reset_key return reset_key @reset_key.setter @@ -6489,7 +6516,7 @@ def __repr__(self) -> str: ) def __getstate__(self) -> Dict[str, Any]: - state = self.__dict__.copy() + state = super().__getstate__() _lock = state.pop("lock", None) if _lock is not None: state["lock_placeholder"] = None @@ -6500,7 +6527,7 @@ def __setstate__(self, state: Dict[str, Any]): state.pop("lock_placeholder") _lock = mp.Lock() state["lock"] = _lock - self.__dict__.update(state) + super().__setstate__(state) @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: @@ -9932,14 +9959,7 @@ def __init__(self, out_key: NestedKey = "traj_count"): def _make_shared_value(self): self._traj_count = mp.Value("i", 0) - @implement_for("torch", None, "2.1") def __getstate__(self): - state = self.__dict__.copy() - state["_traj_count"] = None - return state - - @implement_for("torch", "2.1") - def __getstate__(self): # noqa: F811 state = super().__getstate__() state["_traj_count"] = None return state