From 146341bf997446d26d59a7a36fb15aeba81435c5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 29 Feb 2024 20:43:35 -0500 Subject: [PATCH] [BugFix] Fix broken gym tests (#1980) --- .github/unittest/linux/scripts/run_all.sh | 4 +- .../linux_distributed/scripts/run_test.sh | 2 +- .../linux_libs/scripts_gym/batch_scripts.sh | 75 +- .../linux_libs/scripts_gym/run_test.sh | 2 +- test/_utils_internal.py | 79 +- test/conftest.py | 26 +- test/smoke_test_deps.py | 2 +- test/test_collector.py | 10 +- test/test_env.py | 199 +++-- test/test_libs.py | 698 ++++++++++-------- test/test_trainer.py | 4 +- test/test_transforms.py | 112 +-- torchrl/_utils.py | 1 + torchrl/data/tensor_specs.py | 4 +- torchrl/envs/batched_envs.py | 12 +- torchrl/envs/libs/gym.py | 34 +- 16 files changed, 711 insertions(+), 553 deletions(-) diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index ea32113462d..cbdcf4ede03 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -194,11 +194,11 @@ pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_contro if [ "${CU_VERSION:-}" != cpu ] ; then python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \ --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \ - --timeout=120 + --timeout=120 --mp_fork_if_no_cuda else python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \ --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py \ - --timeout=120 + --timeout=120 --mp_fork_if_no_cuda fi coverage combine diff --git a/.github/unittest/linux_distributed/scripts/run_test.sh b/.github/unittest/linux_distributed/scripts/run_test.sh index fe7d1ba1ea3..176fefcd73d 100755 --- a/.github/unittest/linux_distributed/scripts/run_test.sh +++ b/.github/unittest/linux_distributed/scripts/run_test.sh @@ -23,6 +23,6 @@ export BATCHED_PIPE_TIMEOUT=60 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb' -python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py --instafail -v --durations 200 +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py --instafail -v --durations 200 --mp_fork_if_no_cuda coverage combine coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh index 321da982d2e..06bb33b6ac1 100755 --- a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh +++ b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh @@ -19,10 +19,10 @@ apt-get update && apt-get install -y git wget libglew-dev libx11-dev x11proto-de # solves "'extras_require' must be a dictionary" pip install setuptools==65.3.0 -mkdir third_party -cd third_party -git clone https://github.com/vmoens/gym -cd .. +#mkdir -p third_party +#cd third_party +#git clone https://github.com/vmoens/gym +#cd .. # This version is installed initially (see environment.yml) for GYM_VERSION in '0.13' @@ -38,7 +38,7 @@ do # delete the conda copy conda deactivate - conda env remove --prefix ./cloned_env + conda env remove --prefix ./cloned_env -y done # gym[atari]==0.19 is broken, so we install only gym without dependencies. @@ -57,7 +57,7 @@ do # delete the conda copy conda deactivate - conda env remove --prefix ./cloned_env + conda env remove --prefix ./cloned_env -y done # gym[atari]==0.20 installs ale-py==0.8, but this version is not compatible with gym<0.26, so we downgrade it. @@ -76,7 +76,7 @@ do # delete the conda copy conda deactivate - conda env remove --prefix ./cloned_env + conda env remove --prefix ./cloned_env -y done for GYM_VERSION in '0.25' @@ -92,7 +92,7 @@ do # delete the conda copy conda deactivate - conda env remove --prefix ./cloned_env + conda env remove --prefix ./cloned_env -y done # For this version "gym[accept-rom-license]" is required. @@ -104,18 +104,17 @@ do conda activate ./cloned_env echo "Testing gym version: ${GYM_VERSION}" - pip3 install 'gym[accept-rom-license]'==$GYM_VERSION - pip3 install 'gym[atari]'==$GYM_VERSION + pip3 install 'gym[atari,accept-rom-license]'==$GYM_VERSION pip3 install gym-super-mario-bros $DIR/run_test.sh # delete the conda copy conda deactivate - conda env remove --prefix ./cloned_env + conda env remove --prefix ./cloned_env -y done # For this version "gym[accept-rom-license]" is required. -for GYM_VERSION in '0.27' +for GYM_VERSION in '0.27' '0.28' do # Create a copy of the conda env and work with this conda deactivate @@ -123,46 +122,24 @@ do conda activate ./cloned_env echo "Testing gym version: ${GYM_VERSION}" - pip3 install 'gymnasium[accept-rom-license]'==$GYM_VERSION - - - if [[ $OSTYPE != 'darwin'* ]]; then - # install ale-py: manylinux names are broken for CentOS so we need to manually download and - # rename them - PY_VERSION=$(python --version) - if [[ $PY_VERSION == *"3.7"* ]]; then - wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.8"* ]]; then - wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.9"* ]]; then - wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.10"* ]]; then - wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.11"* ]]; then - wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - mv ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - fi - pip install gymnasium[atari] - else - pip install gymnasium[atari] - fi - pip install mo-gymnasium - pip install gymnasium-robotics + pip3 install 'gymnasium[atari,accept-rom-license,ale-py]'==$GYM_VERSION $DIR/run_test.sh # delete the conda copy conda deactivate - conda env remove --prefix ./cloned_env + conda env remove --prefix ./cloned_env -y done + +# Latest gymnasium +conda deactivate +conda create --prefix ./cloned_env --clone ./env -y +conda activate ./cloned_env + +pip3 install 'gymnasium[accept-rom-license,ale-py,atari]' mo-gymnasium gymnasium-robotics -U + +$DIR/run_test.sh + +# delete the conda copy +conda deactivate +conda env remove --prefix ./cloned_env -y diff --git a/.github/unittest/linux_libs/scripts_gym/run_test.sh b/.github/unittest/linux_libs/scripts_gym/run_test.sh index d59c5ce6213..1183109d497 100755 --- a/.github/unittest/linux_libs/scripts_gym/run_test.sh +++ b/.github/unittest/linux_libs/scripts_gym/run_test.sh @@ -23,6 +23,6 @@ python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_te export DISPLAY=':99.0' Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 & -python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym and not isaac" --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym and not isaac" --error-for-skips --mp_fork coverage combine coverage xml -i diff --git a/test/_utils_internal.py b/test/_utils_internal.py index ec73812844b..6c267768044 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -24,7 +24,7 @@ from torchrl.envs import MultiThreadedEnv, ObservationNorm from torchrl.envs.batched_envs import ParallelEnv, SerialEnv from torchrl.envs.libs.envpool import _has_envpool -from torchrl.envs.libs.gym import _has_gym, GymEnv +from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv from torchrl.envs.transforms import ( Compose, RewardClipping, @@ -35,41 +35,72 @@ # Specified for test_utils.py __version__ = "0.3" -# Default versions of the environments. -CARTPOLE_VERSIONED = "CartPole-v1" -HALFCHEETAH_VERSIONED = "HalfCheetah-v4" -PENDULUM_VERSIONED = "Pendulum-v1" -PONG_VERSIONED = "ALE/Pong-v5" + +def CARTPOLE_VERSIONED(): + # load gym + if gym_backend() is not None: + _set_gym_environments() + return _CARTPOLE_VERSIONED + + +def HALFCHEETAH_VERSIONED(): + # load gym + if gym_backend() is not None: + _set_gym_environments() + return _HALFCHEETAH_VERSIONED + + +def PONG_VERSIONED(): + # load gym + if gym_backend() is not None: + _set_gym_environments() + return _PONG_VERSIONED + + +def PENDULUM_VERSIONED(): + # load gym + if gym_backend() is not None: + _set_gym_environments() + return _PENDULUM_VERSIONED + + +def _set_gym_environments(): + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED + + _CARTPOLE_VERSIONED = None + _HALFCHEETAH_VERSIONED = None + _PENDULUM_VERSIONED = None + _PONG_VERSIONED = None @implement_for("gym", None, "0.21.0") def _set_gym_environments(): # noqa: F811 - global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED - CARTPOLE_VERSIONED = "CartPole-v0" - HALFCHEETAH_VERSIONED = "HalfCheetah-v2" - PENDULUM_VERSIONED = "Pendulum-v0" - PONG_VERSIONED = "Pong-v4" + _CARTPOLE_VERSIONED = "CartPole-v0" + _HALFCHEETAH_VERSIONED = "HalfCheetah-v2" + _PENDULUM_VERSIONED = "Pendulum-v0" + _PONG_VERSIONED = "Pong-v4" @implement_for("gym", "0.21.0", None) def _set_gym_environments(): # noqa: F811 - global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED - CARTPOLE_VERSIONED = "CartPole-v1" - HALFCHEETAH_VERSIONED = "HalfCheetah-v4" - PENDULUM_VERSIONED = "Pendulum-v1" - PONG_VERSIONED = "ALE/Pong-v5" + _CARTPOLE_VERSIONED = "CartPole-v1" + _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" + _PENDULUM_VERSIONED = "Pendulum-v1" + _PONG_VERSIONED = "ALE/Pong-v5" @implement_for("gymnasium") def _set_gym_environments(): # noqa: F811 - global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED - CARTPOLE_VERSIONED = "CartPole-v1" - HALFCHEETAH_VERSIONED = "HalfCheetah-v4" - PENDULUM_VERSIONED = "Pendulum-v1" - PONG_VERSIONED = "ALE/Pong-v5" + _CARTPOLE_VERSIONED = "CartPole-v1" + _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" + _PENDULUM_VERSIONED = "Pendulum-v1" + _PONG_VERSIONED = "ALE/Pong-v5" if _has_gym: @@ -171,7 +202,7 @@ def create_env_fn(): return GymEnv(env_name, frame_skip=frame_skip, device=device) else: - if env_name == PONG_VERSIONED: + if env_name == PONG_VERSIONED(): def create_env_fn(): base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) @@ -250,7 +281,7 @@ def _make_multithreaded_env( torch.manual_seed(0) multithreaded_kwargs = ( - {"frame_skip": frame_skip} if env_name == PONG_VERSIONED else {} + {"frame_skip": frame_skip} if env_name == PONG_VERSIONED() else {} ) env_multithread = MultiThreadedEnv( N, @@ -274,7 +305,7 @@ def _make_multithreaded_env( def get_transform_out(env_name, transformed_in, obs_key=None): - if env_name == PONG_VERSIONED: + if env_name == PONG_VERSIONED(): if obs_key is None: obs_key = "pixels" diff --git a/test/conftest.py b/test/conftest.py index 2dcd369003a..bbf61cb84dd 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,9 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - import os - +import sys import time import warnings from collections import defaultdict @@ -12,6 +11,7 @@ import pytest CALL_TIMES = defaultdict(lambda: 0.0) +IS_OSX = sys.platform == "darwin" def pytest_sessionfinish(maxprint=50): @@ -97,6 +97,20 @@ def pytest_addoption(parser): "--runslow", action="store_true", default=False, help="run slow tests" ) + parser.addoption( + "--mp_fork", + action="store_true", + default=False, + help="Use 'fork' start method for mp dedicated tests.", + ) + + parser.addoption( + "--mp_fork_if_no_cuda", + action="store_true", + default=False, + help="Use 'fork' start method for mp dedicated tests only if there is no cuda device available.", + ) + def pytest_configure(config): config.addinivalue_line("markers", "slow: mark test as slow to run") @@ -110,3 +124,11 @@ def pytest_collection_modifyitems(config, items): for item in items: if "slow" in item.keywords: item.add_marker(skip_slow) + + +@pytest.fixture +def maybe_fork_ParallelEnv(request): + # Feature available from 0.4 only + from torchrl.envs import ParallelEnv + + return ParallelEnv diff --git a/test/smoke_test_deps.py b/test/smoke_test_deps.py index 1f153958ff5..e0730565709 100644 --- a/test/smoke_test_deps.py +++ b/test/smoke_test_deps.py @@ -48,7 +48,7 @@ def test_gym(): assert _has_gym from _utils_internal import PONG_VERSIONED - env = GymEnv(PONG_VERSIONED) + env = GymEnv(PONG_VERSIONED()) env.reset() diff --git a/test/test_collector.py b/test/test_collector.py index 09c6ee293c3..9f544dba194 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -598,7 +598,7 @@ def make_env(): # This is currently necessary as the methods in GymWrapper may have mismatching backend # versions. with set_gym_backend(gym_backend()): - return TransformedEnv(GymEnv(PONG_VERSIONED, frame_skip=4), StepCounter()) + return TransformedEnv(GymEnv(PONG_VERSIONED(), frame_skip=4), StepCounter()) if parallel: env = ParallelEnv(2, make_env) @@ -1076,7 +1076,9 @@ def test_collector_vecnorm_envcreator(static_seed): from torchrl.envs.libs.gym import GymEnv num_envs = 4 - env_make = EnvCreator(lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED), VecNorm())) + env_make = EnvCreator( + lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm()) + ) env_make = ParallelEnv(num_envs, env_make) policy = RandomPolicy(env_make.action_spec) @@ -1293,7 +1295,7 @@ def test_collector_output_keys( policy = SafeModule(**policy_kwargs) - env_maker = lambda: GymEnv(PENDULUM_VERSIONED) + env_maker = lambda: GymEnv(PENDULUM_VERSIONED()) policy(env_maker().reset()) @@ -1432,7 +1434,7 @@ class TestAutoWrap: def env_maker(self): from torchrl.envs.libs.gym import GymEnv - return lambda: GymEnv(PENDULUM_VERSIONED) + return lambda: GymEnv(PENDULUM_VERSIONED()) def _create_collector_kwargs(self, env_maker, collector_class, policy): collector_kwargs = { diff --git a/test/test_env.py b/test/test_env.py index e6f5a5dc25f..9bf9ba09d24 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -108,6 +108,7 @@ atari_confs = defaultdict(lambda: "") IS_OSX = platform == "darwin" +IS_WIN = platform == "win32" ## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between ## the serial and parallel batched envs @@ -158,6 +159,7 @@ @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, CARTPOLE_VERSIONED]) @pytest.mark.parametrize("frame_skip", [1, 4]) def test_env_seed(env_name, frame_skip, seed=0): + env_name = env_name() env = GymEnv(env_name, frame_skip=frame_skip) action = env.action_spec.rand() @@ -190,6 +192,7 @@ def test_env_seed(env_name, frame_skip, seed=0): @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED]) @pytest.mark.parametrize("frame_skip", [1, 4]) def test_rollout(env_name, frame_skip, seed=0): + env_name = env_name() env = GymEnv(env_name, frame_skip=frame_skip) torch.manual_seed(seed) @@ -278,7 +281,10 @@ def test_rollout_predictability(device): @pytest.mark.parametrize("frame_skip", [1]) @pytest.mark.parametrize("truncated_key", ["truncated", "done"]) @pytest.mark.parametrize("parallel", [False, True]) -def test_rollout_reset(env_name, frame_skip, parallel, truncated_key, seed=0): +def test_rollout_reset( + env_name, frame_skip, parallel, truncated_key, maybe_fork_ParallelEnv, seed=0 +): + env_name = env_name() envs = [] for horizon in [20, 30, 40]: envs.append( @@ -288,7 +294,7 @@ def test_rollout_reset(env_name, frame_skip, parallel, truncated_key, seed=0): ) ) if parallel: - env = ParallelEnv(3, envs) + env = maybe_fork_ParallelEnv(3, envs) else: env = SerialEnv(3, envs) env.set_seed(100) @@ -404,9 +410,11 @@ class TestParallel: @pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"]) @pytest.mark.parametrize("edevice", ["cpu", "cuda"]) @pytest.mark.parametrize("bwad", [True, False]) - def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad): + def test_parallel_devices( + self, parallel, hetero, pdevice, edevice, bwad, maybe_fork_ParallelEnv + ): if parallel: - cls = ParallelEnv + cls = maybe_fork_ParallelEnv else: cls = SerialEnv if not hetero: @@ -438,26 +446,32 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad): env.shared_tensordict_parent.device.type == torch.device(edevice).type ) - def test_serial_for_single(self): + def test_serial_for_single(self, maybe_fork_ParallelEnv): env = ParallelEnv(1, ContinuousActionVecMockEnv, serial_for_single=True) assert isinstance(env, SerialEnv) - env = ParallelEnv(1, ContinuousActionVecMockEnv) + env = maybe_fork_ParallelEnv(1, ContinuousActionVecMockEnv) assert isinstance(env, ParallelEnv) - env = ParallelEnv(2, ContinuousActionVecMockEnv, serial_for_single=True) + env = maybe_fork_ParallelEnv( + 2, ContinuousActionVecMockEnv, serial_for_single=True + ) assert isinstance(env, ParallelEnv) @pytest.mark.parametrize("num_parallel_env", [1, 10]) @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)]) - def test_env_with_batch_size(self, num_parallel_env, env_batch_size): + def test_env_with_batch_size( + self, num_parallel_env, env_batch_size, maybe_fork_ParallelEnv + ): env = MockBatchedLockedEnv(device="cpu", batch_size=torch.Size(env_batch_size)) env.set_seed(1) - parallel_env = ParallelEnv(num_parallel_env, lambda: env) + parallel_env = maybe_fork_ParallelEnv(num_parallel_env, lambda: env) assert parallel_env.batch_size == (num_parallel_env, *env_batch_size) @pytest.mark.skipif(not _has_dmc, reason="no dm_control") @pytest.mark.parametrize("env_task", ["stand,stand,stand", "stand,walk,stand"]) @pytest.mark.parametrize("share_individual_td", [True, False]) - def test_multi_task_serial_parallel(self, env_task, share_individual_td): + def test_multi_task_serial_parallel( + self, env_task, share_individual_td, maybe_fork_ParallelEnv + ): tasks = env_task.split(",") if len(tasks) == 1: single_task = True @@ -482,13 +496,17 @@ def env_make(): with pytest.raises( ValueError, match="share_individual_td must be set to None" ): - ParallelEnv(3, env_make, share_individual_td=share_individual_td) + maybe_fork_ParallelEnv( + 3, env_make, share_individual_td=share_individual_td + ) return env_serial = SerialEnv(3, env_make, share_individual_td=share_individual_td) env_serial.start() assert env_serial._single_task is single_task - env_parallel = ParallelEnv(3, env_make, share_individual_td=share_individual_td) + env_parallel = maybe_fork_ParallelEnv( + 3, env_make, share_individual_td=share_individual_td + ) env_parallel.start() assert env_parallel._single_task is single_task @@ -503,7 +521,7 @@ def env_make(): assert_allclose_td(td_serial, td_parallel) @pytest.mark.skipif(not _has_dmc, reason="no dm_control") - def test_multitask(self): + def test_multitask(self, maybe_fork_ParallelEnv): env1 = DMControlEnv("humanoid", "stand") env1_obs_keys = list(env1.observation_spec.keys()) env2 = DMControlEnv("humanoid", "walk") @@ -538,7 +556,7 @@ def env2_maker(): ), ) - env = ParallelEnv(2, [env1_maker, env2_maker]) + env = maybe_fork_ParallelEnv(2, [env1_maker, env2_maker]) # env = SerialEnv(2, [env1_maker, env2_maker]) assert not env._single_task @@ -564,6 +582,7 @@ def env2_maker(): def test_parallel_env( self, env_name, frame_skip, transformed_in, transformed_out, T=10, N=3 ): + env_name = env_name() env_parallel, env_serial, _, env0 = _make_envs( env_name, frame_skip, @@ -610,6 +629,7 @@ def test_parallel_env_with_policy( T=10, N=3, ): + env_name = env_name() env_parallel, env_serial, _, env0 = _make_envs( env_name, frame_skip, @@ -664,7 +684,9 @@ def test_parallel_env_with_policy( @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.parametrize("heterogeneous", [False, True]) - def test_transform_env_transform_no_device(self, heterogeneous): + def test_transform_env_transform_no_device( + self, heterogeneous, maybe_fork_ParallelEnv + ): # Tests non-regression on 1865 def make_env(): return TransformedEnv( @@ -675,7 +697,7 @@ def make_env(): make_envs = [EnvCreator(make_env), EnvCreator(make_env)] else: make_envs = make_env - penv = ParallelEnv(2, make_envs) + penv = maybe_fork_ParallelEnv(2, make_envs) r = penv.rollout(6, break_when_any_done=False) assert r.shape == (2, 6) try: @@ -688,9 +710,7 @@ def make_env(): @pytest.mark.skipif(not _has_gym, reason="no gym") @pytest.mark.parametrize( "env_name", - [ - PENDULUM_VERSIONED, - ], + [PENDULUM_VERSIONED], ) # PONG_VERSIONED]) # 1226: efficiency @pytest.mark.parametrize("frame_skip", [4]) @pytest.mark.parametrize( @@ -700,6 +720,7 @@ def make_env(): def test_parallel_env_seed( self, env_name, frame_skip, transformed_in, transformed_out, static_seed ): + env_name = env_name() env_parallel, env_serial, _, _ = _make_envs( env_name, frame_skip, transformed_in, transformed_out, 5 ) @@ -738,9 +759,9 @@ def test_parallel_env_seed( env_serial.close() @pytest.mark.skipif(not _has_gym, reason="no gym") - def test_parallel_env_shutdown(self): - env_make = EnvCreator(lambda: GymEnv(PENDULUM_VERSIONED)) - env = ParallelEnv(4, env_make) + def test_parallel_env_shutdown(self, maybe_fork_ParallelEnv): + env_make = EnvCreator(lambda: GymEnv(PENDULUM_VERSIONED())) + env = maybe_fork_ParallelEnv(4, env_make) env.reset() assert not env.is_closed env.rand_step() @@ -752,11 +773,11 @@ def test_parallel_env_shutdown(self): env.close() @pytest.mark.parametrize("parallel", [True, False]) - def test_parallel_env_custom_method(self, parallel): + def test_parallel_env_custom_method(self, parallel, maybe_fork_ParallelEnv): # define env if parallel: - env = ParallelEnv(2, lambda: DiscreteActionVecMockEnv()) + env = maybe_fork_ParallelEnv(2, lambda: DiscreteActionVecMockEnv()) else: env = SerialEnv(2, lambda: DiscreteActionVecMockEnv()) @@ -792,6 +813,7 @@ def test_parallel_env_cast( open_before, N=3, ): + env_name = env_name() # tests casting to device env_parallel, env_serial, _, env0 = _make_envs( env_name, @@ -884,6 +906,7 @@ def test_parallel_env_cast( def test_parallel_env_device( self, env_name, frame_skip, transformed_in, transformed_out, device ): + env_name = env_name() # tests creation on device torch.manual_seed(0) N = 3 @@ -924,6 +947,7 @@ def test_parallel_env_device( [torch.device("cuda:0") if torch.cuda.device_count() else torch.device("cpu")], ) def test_parallel_env_transform_consistency(self, env_name, frame_skip, device): + env_name = env_name() env_parallel_in, env_serial_in, _, env0_in = _make_envs( env_name, frame_skip, @@ -971,7 +995,7 @@ def test_parallel_env_transform_consistency(self, env_name, frame_skip, device): env0_in.close() @pytest.mark.parametrize("parallel", [True, False]) - def test_parallel_env_kwargs_set(self, parallel): + def test_parallel_env_kwargs_set(self, parallel, maybe_fork_ParallelEnv): num_env = 2 def make_make_env(): @@ -983,7 +1007,7 @@ def make_transformed_env(seed=None): return make_transformed_env - _class = ParallelEnv if parallel else SerialEnv + _class = maybe_fork_ParallelEnv if parallel else SerialEnv def env_fn1(seed): env = _class( @@ -1013,9 +1037,11 @@ def env_fn2(seed): @pytest.mark.parametrize("batch_size", [(32, 5), (4,), (1,), ()]) @pytest.mark.parametrize("n_workers", [2, 1]) - def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): + def test_parallel_env_reset_flag( + self, batch_size, n_workers, maybe_fork_ParallelEnv, max_steps=3 + ): torch.manual_seed(1) - env = ParallelEnv( + env = maybe_fork_ParallelEnv( n_workers, lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size) ) env.set_seed(1) @@ -1059,6 +1085,7 @@ def test_parallel_env_nested( nested_done, nested_reward, env_type, + maybe_fork_ParallelEnv, n_envs=2, batch_size=(32,), nested_dim=5, @@ -1072,42 +1099,55 @@ def test_parallel_env_nested( batch_size=batch_size, nested_dim=nested_dim, ) + if env_type == "serial": env = SerialEnv(n_envs, env_fn) else: - env = ParallelEnv(n_envs, env_fn) - env.set_seed(seed) - - batch_size = (n_envs, *batch_size) - - td = env.reset() - assert td.batch_size == batch_size - if nested_done or nested_obs_action: - assert td["data"].batch_size == (*batch_size, nested_dim) - if not nested_done and not nested_reward and not nested_obs_action: - assert "data" not in td.keys() + env = maybe_fork_ParallelEnv(n_envs, env_fn) - policy = CountingEnvCountPolicy(env.action_spec, env.action_key) - td = env.rollout(rollout_length, policy) - assert td.batch_size == (*batch_size, rollout_length) - if nested_done or nested_obs_action: - assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim) - if nested_reward or nested_done or nested_obs_action: - assert td["next", "data"].batch_size == ( - *batch_size, - rollout_length, - nested_dim, - ) - if not nested_done and not nested_reward and not nested_obs_action: - assert "data" not in td.keys() - assert "data" not in td["next"].keys() + try: + env.set_seed(seed) + + batch_size = (n_envs, *batch_size) + + td = env.reset() + assert td.batch_size == batch_size + if nested_done or nested_obs_action: + assert td["data"].batch_size == (*batch_size, nested_dim) + if not nested_done and not nested_reward and not nested_obs_action: + assert "data" not in td.keys() + + policy = CountingEnvCountPolicy(env.action_spec, env.action_key) + td = env.rollout(rollout_length, policy) + assert td.batch_size == (*batch_size, rollout_length) + if nested_done or nested_obs_action: + assert td["data"].batch_size == ( + *batch_size, + rollout_length, + nested_dim, + ) + if nested_reward or nested_done or nested_obs_action: + assert td["next", "data"].batch_size == ( + *batch_size, + rollout_length, + nested_dim, + ) + if not nested_done and not nested_reward and not nested_obs_action: + assert "data" not in td.keys() + assert "data" not in td["next"].keys() - if nested_obs_action: - assert "observation" not in td.keys() - assert (td[..., -1]["data", "states"] == 2).all() - else: - assert ("data", "states") not in td.keys(True, True) - assert (td[..., -1]["observation"] == 2).all() + if nested_obs_action: + assert "observation" not in td.keys() + assert (td[..., -1]["data", "states"] == 2).all() + else: + assert ("data", "states") not in td.keys(True, True) + assert (td[..., -1]["observation"] == 2).all() + finally: + try: + env.close() + del env + except Exception: + pass @pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)]) @@ -1146,13 +1186,13 @@ def test_env_base_reset_flag(batch_size, max_steps=3): @pytest.mark.skipif(not _has_gym, reason="no gym") def test_seed(): torch.manual_seed(0) - env1 = GymEnv(PENDULUM_VERSIONED) + env1 = GymEnv(PENDULUM_VERSIONED()) env1.set_seed(0) state0_1 = env1.reset() state1_1 = env1.step(state0_1.set("action", env1.action_spec.rand())) torch.manual_seed(0) - env2 = GymEnv(PENDULUM_VERSIONED) + env2 = GymEnv(PENDULUM_VERSIONED()) env2.set_seed(0) state0_2 = env2.reset() state1_2 = env2.step(state0_2.set("action", env2.action_spec.rand())) @@ -1695,7 +1735,7 @@ def test_info_dict_reader(self, device, seed=0): except ModuleNotFoundError: import gym - env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED), device=device) + env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED()), device=device) env.set_info_dict_reader(default_info_dict_reader(["x_position"])) assert "x_position" in env.observation_spec.keys() @@ -1734,13 +1774,13 @@ def test_info_dict_reader(self, device, seed=0): reason="older versions of half-cheetah do not have 'x_position' info key.", ) @pytest.mark.parametrize("device", get_default_devices()) - def test_auto_register(self, device): + def test_auto_register(self, device, maybe_fork_ParallelEnv): try: import gymnasium as gym except ModuleNotFoundError: import gym - env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED), device=device) + env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED()), device=device) check_env_specs(env) env.set_info_dict_reader() with pytest.raises( @@ -1748,21 +1788,21 @@ def test_auto_register(self, device): ): check_env_specs(env) - env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED), device=device) + env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED()), device=device) env = env.auto_register_info_dict() check_env_specs(env) # check that the env can be executed in parallel - penv = ParallelEnv( + penv = maybe_fork_ParallelEnv( 2, lambda: GymWrapper( - gym.make(HALFCHEETAH_VERSIONED), device=device + gym.make(HALFCHEETAH_VERSIONED()), device=device ).auto_register_info_dict(), ) - senv = ParallelEnv( + senv = maybe_fork_ParallelEnv( 2, lambda: GymWrapper( - gym.make(HALFCHEETAH_VERSIONED), device=device + gym.make(HALFCHEETAH_VERSIONED()), device=device ).auto_register_info_dict(), ) try: @@ -2201,7 +2241,14 @@ def test_rollout(self, batch_size, rollout_steps, max_steps, seed): @pytest.mark.parametrize("env_type", ["serial", "parallel"]) @pytest.mark.parametrize("max_steps", [2, 5]) def test_parallel( - self, batch_size, rollout_steps, env_type, max_steps, seed, n_workers=2 + self, + batch_size, + rollout_steps, + env_type, + max_steps, + seed, + maybe_fork_ParallelEnv, + n_workers=2, ): torch.manual_seed(seed) env_fun = lambda: MultiKeyCountingEnv( @@ -2210,7 +2257,7 @@ def test_parallel( if env_type == "serial": vec_env = SerialEnv(n_workers, env_fun) else: - vec_env = ParallelEnv(n_workers, env_fun) + vec_env = maybe_fork_ParallelEnv(n_workers, env_fun) # check_env_specs(vec_env) policy = MultiKeyCountingEnvPolicy( @@ -2429,16 +2476,16 @@ def test_num_threads(self): IS_OSX, reason="setting different threads across workers can randomly fail on OSX.", ) - def test_auto_num_threads(self): + def test_auto_num_threads(self, maybe_fork_ParallelEnv): init_threads = torch.get_num_threads() try: - env3 = ParallelEnv(3, ContinuousActionVecMockEnv) + env3 = maybe_fork_ParallelEnv(3, ContinuousActionVecMockEnv) env3.rollout(2) assert torch.get_num_threads() == max(1, init_threads - 3) - env2 = ParallelEnv(2, ContinuousActionVecMockEnv) + env2 = maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv) env2.rollout(2) assert torch.get_num_threads() == max(1, init_threads - 5) @@ -2512,7 +2559,7 @@ def test_auto_cast_to_device(break_when_any_done): @pytest.mark.parametrize("device", get_default_devices()) -def test_backprop(device): +def test_backprop(device, maybe_fork_ParallelEnv): # Tests that backprop through a series of single envs and through a serial env are identical # Also tests that no backprop can be achieved with parallel env. class DifferentiableEnv(EnvBase): @@ -2581,7 +2628,7 @@ def make_env(seed, device=device): ) torch.testing.assert_close(g, g_serial) - p_env = ParallelEnv( + p_env = maybe_fork_ParallelEnv( 2, [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], device=device, diff --git a/test/test_libs.py b/test/test_libs.py index 427eef522d0..a3b466674fa 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - import importlib from contextlib import nullcontext @@ -79,7 +78,6 @@ DoubleToFloat, EnvBase, EnvCreator, - ParallelEnv, RemoveEmptySpecs, RenameTransform, ) @@ -92,6 +90,7 @@ _has_gym, _is_from_pixels, _torchrl_to_gym_spec_transform, + gym_backend, GymEnv, GymWrapper, MOGymEnv, @@ -120,25 +119,40 @@ _has_gymnasium = importlib.util.find_spec("gymnasium") is not None _has_gym_regular = importlib.util.find_spec("gym") is not None +if _has_gymnasium: + set_gym_backend("gymnasium").set() + import gymnasium + + assert gym_backend() is gymnasium +elif _has_gym: + set_gym_backend("gym").set() + import gym + + assert gym_backend() is gym + + +def get_gym_pixel_wrapper(): + try: + # works whenever gym_version > version.parse("0.19") + PixelObservationWrapper = gym_backend( + "wrappers.pixel_observation" + ).PixelObservationWrapper + except Exception as err: + from torchrl.envs.libs.utils import ( + GymPixelObservationWrapper as PixelObservationWrapper, + ) + return PixelObservationWrapper + if _has_gym: try: - import gymnasium as gym from gymnasium import __version__ as gym_version gym_version = version.parse(gym_version) - from gymnasium.wrappers.pixel_observation import PixelObservationWrapper except ModuleNotFoundError: - import gym - - gym_version = version.parse(gym.__version__) - if gym_version > version.parse("0.19"): - from gym.wrappers.pixel_observation import PixelObservationWrapper - else: - from torchrl.envs.libs.utils import ( - GymPixelObservationWrapper as PixelObservationWrapper, - ) + from gym import __version__ as gym_version + gym_version = version.parse(gym_version) if _has_dmc: from dm_control import suite @@ -290,162 +304,148 @@ def test_gym_spec_cast(self, categorical): def test_torchrl_to_gym(self, backend, numpy): from torchrl.envs.libs.gym import gym_backend, set_gym_backend - EnvBase.register_gym( - f"Dummy-{numpy}-{backend}-v0", - entry_point=self.DummyEnv, - to_numpy=numpy, - backend=backend, - arg1=1, - arg2=2, - ) - - with set_gym_backend(backend) if backend is not None else nullcontext(): - envgym = gym_backend().make(f"Dummy-{numpy}-{backend}-v0") - envgym.reset() - obs, *_ = envgym.step(envgym.action_space.sample()) - assert "observation" in obs - assert "other" in obs - if numpy: - assert all(isinstance(val, np.ndarray) for val in tree_flatten(obs)[0]) - else: - assert all( - isinstance(val, torch.Tensor) for val in tree_flatten(obs)[0] - ) - - # with a transform - transform = Compose( - CatTensors(["observation", ("other", "another_other")]), - RemoveEmptySpecs(), - ) - envgym = gym_backend().make( + gb = gym_backend() + try: + EnvBase.register_gym( f"Dummy-{numpy}-{backend}-v0", - transform=transform, + entry_point=self.DummyEnv, + to_numpy=numpy, + backend=backend, + arg1=1, + arg2=2, ) - envgym.reset() - obs, *_ = envgym.step(envgym.action_space.sample()) - assert "observation_other" not in obs - assert "observation" not in obs - assert "other" not in obs - if numpy: - assert all(isinstance(val, np.ndarray) for val in tree_flatten(obs)[0]) - else: - assert all( - isinstance(val, torch.Tensor) for val in tree_flatten(obs)[0] - ) - # register with transform - transform = Compose( - CatTensors(["observation", ("other", "another_other")]), RemoveEmptySpecs() - ) - EnvBase.register_gym( - f"Dummy-{numpy}-{backend}-transform-v0", - entry_point=self.DummyEnv, - backend=backend, - to_numpy=numpy, - arg1=1, - arg2=2, - transform=transform, - ) - - with set_gym_backend(backend) if backend is not None else nullcontext(): - envgym = gym_backend().make(f"Dummy-{numpy}-{backend}-transform-v0") - envgym.reset() - obs, *_ = envgym.step(envgym.action_space.sample()) - assert "observation_other" not in obs - assert "observation" not in obs - assert "other" not in obs - if numpy: - assert all(isinstance(val, np.ndarray) for val in tree_flatten(obs)[0]) - else: - assert all( - isinstance(val, torch.Tensor) for val in tree_flatten(obs)[0] - ) + with set_gym_backend(backend) if backend is not None else nullcontext(): + envgym = gym_backend().make(f"Dummy-{numpy}-{backend}-v0") + envgym.reset() + obs, *_ = envgym.step(envgym.action_space.sample()) + assert "observation" in obs + assert "other" in obs + if numpy: + assert all( + isinstance(val, np.ndarray) for val in tree_flatten(obs)[0] + ) + else: + assert all( + isinstance(val, torch.Tensor) for val in tree_flatten(obs)[0] + ) - # register with transform - EnvBase.register_gym( - f"Dummy-{numpy}-{backend}-noarg-v0", - entry_point=self.DummyEnv, - backend=backend, - to_numpy=numpy, - ) - with set_gym_backend(backend) if backend is not None else nullcontext(): - with pytest.raises(AssertionError): + # with a transform + transform = Compose( + CatTensors(["observation", ("other", "another_other")]), + RemoveEmptySpecs(), + ) envgym = gym_backend().make( - f"Dummy-{numpy}-{backend}-noarg-v0", arg1=None, arg2=None + f"Dummy-{numpy}-{backend}-v0", + transform=transform, ) - envgym = gym_backend().make( - f"Dummy-{numpy}-{backend}-noarg-v0", arg1=1, arg2=2 - ) - - # Get info dict - gym_info_at_reset = version.parse(gym_backend().__version__) >= version.parse( - "0.26.0" - ) - with set_gym_backend(backend) if backend is not None else nullcontext(): - envgym = gym_backend().make( - f"Dummy-{numpy}-{backend}-noarg-v0", - arg1=1, - arg2=2, - info_keys=("other",), - ) - if gym_info_at_reset: - out, info = envgym.reset() + envgym.reset() + obs, *_ = envgym.step(envgym.action_space.sample()) + assert "observation_other" not in obs + assert "observation" not in obs + assert "other" not in obs if numpy: assert all( - isinstance(val, np.ndarray) - for val in tree_flatten((obs, info))[0] + isinstance(val, np.ndarray) for val in tree_flatten(obs)[0] ) else: assert all( - isinstance(val, torch.Tensor) - for val in tree_flatten((obs, info))[0] + isinstance(val, torch.Tensor) for val in tree_flatten(obs)[0] ) - else: - out = envgym.reset() - info = {} + + # register with transform + transform = Compose( + CatTensors(["observation", ("other", "another_other")]), + RemoveEmptySpecs(), + ) + EnvBase.register_gym( + f"Dummy-{numpy}-{backend}-transform-v0", + entry_point=self.DummyEnv, + backend=backend, + to_numpy=numpy, + arg1=1, + arg2=2, + transform=transform, + ) + + with set_gym_backend(backend) if backend is not None else nullcontext(): + envgym = gym_backend().make(f"Dummy-{numpy}-{backend}-transform-v0") + envgym.reset() + obs, *_ = envgym.step(envgym.action_space.sample()) + assert "observation_other" not in obs + assert "observation" not in obs + assert "other" not in obs if numpy: assert all( - isinstance(val, np.ndarray) - for val in tree_flatten((obs, info))[0] + isinstance(val, np.ndarray) for val in tree_flatten(obs)[0] ) else: assert all( - isinstance(val, torch.Tensor) - for val in tree_flatten((obs, info))[0] + isinstance(val, torch.Tensor) for val in tree_flatten(obs)[0] ) - assert "observation" in out - assert "other" not in out - if gym_info_at_reset: - assert "other" in info - - out, *_, info = envgym.step(envgym.action_space.sample()) - assert "observation" in out - assert "other" not in out - assert "other" in info - if numpy: - assert all( - isinstance(val, np.ndarray) for val in tree_flatten((obs, info))[0] - ) - else: - assert all( - isinstance(val, torch.Tensor) - for val in tree_flatten((obs, info))[0] + # register with transform + EnvBase.register_gym( + f"Dummy-{numpy}-{backend}-noarg-v0", + entry_point=self.DummyEnv, + backend=backend, + to_numpy=numpy, + ) + with set_gym_backend(backend) if backend is not None else nullcontext(): + with pytest.raises(AssertionError): + envgym = gym_backend().make( + f"Dummy-{numpy}-{backend}-noarg-v0", arg1=None, arg2=None + ) + envgym = gym_backend().make( + f"Dummy-{numpy}-{backend}-noarg-v0", arg1=1, arg2=2 ) - EnvBase.register_gym( - f"Dummy-{numpy}-{backend}-info-v0", - entry_point=self.DummyEnv, - backend=backend, - to_numpy=numpy, - info_keys=("other",), - ) - with set_gym_backend(backend) if backend is not None else nullcontext(): - envgym = gym_backend().make( - f"Dummy-{numpy}-{backend}-info-v0", arg1=1, arg2=2 - ) - if gym_info_at_reset: - out, info = envgym.reset() + # Get info dict + gym_info_at_reset = version.parse( + gym_backend().__version__ + ) >= version.parse("0.26.0") + with set_gym_backend(backend) if backend is not None else nullcontext(): + envgym = gym_backend().make( + f"Dummy-{numpy}-{backend}-noarg-v0", + arg1=1, + arg2=2, + info_keys=("other",), + ) + if gym_info_at_reset: + out, info = envgym.reset() + if numpy: + assert all( + isinstance(val, np.ndarray) + for val in tree_flatten((obs, info))[0] + ) + else: + assert all( + isinstance(val, torch.Tensor) + for val in tree_flatten((obs, info))[0] + ) + else: + out = envgym.reset() + info = {} + if numpy: + assert all( + isinstance(val, np.ndarray) + for val in tree_flatten((obs, info))[0] + ) + else: + assert all( + isinstance(val, torch.Tensor) + for val in tree_flatten((obs, info))[0] + ) + assert "observation" in out + assert "other" not in out + + if gym_info_at_reset: + assert "other" in info + + out, *_, info = envgym.step(envgym.action_space.sample()) + assert "observation" in out + assert "other" not in out + assert "other" in info if numpy: assert all( isinstance(val, np.ndarray) @@ -456,9 +456,53 @@ def test_torchrl_to_gym(self, backend, numpy): isinstance(val, torch.Tensor) for val in tree_flatten((obs, info))[0] ) - else: - out = envgym.reset() - info = {} + + EnvBase.register_gym( + f"Dummy-{numpy}-{backend}-info-v0", + entry_point=self.DummyEnv, + backend=backend, + to_numpy=numpy, + info_keys=("other",), + ) + with set_gym_backend(backend) if backend is not None else nullcontext(): + envgym = gym_backend().make( + f"Dummy-{numpy}-{backend}-info-v0", arg1=1, arg2=2 + ) + if gym_info_at_reset: + out, info = envgym.reset() + if numpy: + assert all( + isinstance(val, np.ndarray) + for val in tree_flatten((obs, info))[0] + ) + else: + assert all( + isinstance(val, torch.Tensor) + for val in tree_flatten((obs, info))[0] + ) + else: + out = envgym.reset() + info = {} + if numpy: + assert all( + isinstance(val, np.ndarray) + for val in tree_flatten((obs, info))[0] + ) + else: + assert all( + isinstance(val, torch.Tensor) + for val in tree_flatten((obs, info))[0] + ) + assert "observation" in out + assert "other" not in out + + if gym_info_at_reset: + assert "other" in info + + out, *_, info = envgym.step(envgym.action_space.sample()) + assert "observation" in out + assert "other" not in out + assert "other" in info if numpy: assert all( isinstance(val, np.ndarray) @@ -469,31 +513,14 @@ def test_torchrl_to_gym(self, backend, numpy): isinstance(val, torch.Tensor) for val in tree_flatten((obs, info))[0] ) - assert "observation" in out - assert "other" not in out - - if gym_info_at_reset: - assert "other" in info - - out, *_, info = envgym.step(envgym.action_space.sample()) - assert "observation" in out - assert "other" not in out - assert "other" in info - if numpy: - assert all( - isinstance(val, np.ndarray) for val in tree_flatten((obs, info))[0] - ) - else: - assert all( - isinstance(val, torch.Tensor) - for val in tree_flatten((obs, info))[0] - ) + finally: + set_gym_backend(gb).set() @pytest.mark.parametrize( "env_name", [ - HALFCHEETAH_VERSIONED, - PONG_VERSIONED, + HALFCHEETAH_VERSIONED(), + PONG_VERSIONED(), # PENDULUM_VERSIONED, ], ) @@ -507,12 +534,15 @@ def test_torchrl_to_gym(self, backend, numpy): ], ) def test_gym(self, env_name, frame_skip, from_pixels, pixels_only): - if env_name == PONG_VERSIONED and not from_pixels: + + if env_name == PONG_VERSIONED() and not from_pixels: # raise pytest.skip("already pixel") # we don't skip because that would raise an exception return elif ( - env_name != PONG_VERSIONED and from_pixels and torch.cuda.device_count() < 1 + env_name != PONG_VERSIONED() + and from_pixels + and torch.cuda.device_count() < 1 ): raise pytest.skip("no cuda device") @@ -567,13 +597,14 @@ def non_null_obs(batched_td): final_seed0, final_seed1 = final_seed assert final_seed0 == final_seed1 - if env_name == PONG_VERSIONED: - base_env = gym.make(env_name, frameskip=frame_skip) + if env_name == PONG_VERSIONED(): + base_env = gym_backend().make(env_name, frameskip=frame_skip) frame_skip = 1 else: base_env = _make_gym_environment(env_name) if from_pixels and not _is_from_pixels(base_env): + PixelObservationWrapper = get_gym_pixel_wrapper() base_env = PixelObservationWrapper(base_env, pixels_only=pixels_only) assert type(base_env) is env_type @@ -606,9 +637,9 @@ def non_null_obs(batched_td): @pytest.mark.parametrize( "env_name", [ - PONG_VERSIONED, + PONG_VERSIONED(), # PENDULUM_VERSIONED, - HALFCHEETAH_VERSIONED, + HALFCHEETAH_VERSIONED(), ], ) @pytest.mark.parametrize("frame_skip", [1, 3]) @@ -621,11 +652,11 @@ def non_null_obs(batched_td): ], ) def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only): - if env_name == PONG_VERSIONED and not from_pixels: + if env_name == PONG_VERSIONED() and not from_pixels: # raise pytest.skip("already pixel") return elif ( - env_name != PONG_VERSIONED + env_name != PONG_VERSIONED() and from_pixels and (not torch.has_cuda or not torch.cuda.device_count()) ): @@ -679,36 +710,39 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_info_reader(self): + def test_info_reader_mario(self): try: import gym_super_mario_bros as mario_gym except ImportError as err: try: - import gym + gym = gym_backend() # with 0.26 we must have installed gym_super_mario_bros # Since we capture the skips as errors, we raise a skip in this case # Otherwise, we just return - if ( - version.parse("0.26.0") - <= version.parse(gym.__version__) - < version.parse("0.27.0") - ): + gym_version = version.parse(gym.__version__) + if version.parse( + "0.26.0" + ) <= gym_version and gym_version < version.parse("0.27"): raise pytest.skip(f"no super mario bros: error=\n{err}") except ImportError: pass return - env = mario_gym.make("SuperMarioBros-v0", apply_api_compatibility=True) - env = GymWrapper(env) + gb = gym_backend() + try: + with set_gym_backend("gym"): + env = mario_gym.make("SuperMarioBros-v0") + env = GymWrapper(env) + check_env_specs(env) - def info_reader(info, tensordict): - assert isinstance(info, dict) # failed before bugfix + def info_reader(info, tensordict): + assert isinstance(info, dict) # failed before bugfix - env.info_dict_reader = info_reader - env.reset() - env.rand_step() - env.rollout(3) + env.info_dict_reader = info_reader + check_env_specs(env) + finally: + set_gym_backend(gb).set() @implement_for("gymnasium") def test_one_hot_and_categorical(self): @@ -764,22 +798,26 @@ def test_vecenvs_wrapper(self, envname): ) @pytest.mark.flaky(reruns=5, reruns_delay=1) def test_vecenvs_env(self, envname): - with set_gym_backend("gymnasium"): - env = GymEnv(envname, num_envs=2, from_pixels=False) - env.set_seed(0) - assert env.get_library_name(env._env) == "gymnasium" - # rollouts can be executed without decorator - check_env_specs(env) - rollout = env.rollout(100, break_when_any_done=False) - for obs_key in env.observation_spec.keys(True, True): - rollout_consistency_assertion( - rollout, - done_key="done", - observation_key=obs_key, - done_strict="CartPole" in envname, - ) - env.close() - del env + gb = gym_backend() + try: + with set_gym_backend("gymnasium"): + env = GymEnv(envname, num_envs=2, from_pixels=False) + env.set_seed(0) + assert env.get_library_name(env._env) == "gymnasium" + # rollouts can be executed without decorator + check_env_specs(env) + rollout = env.rollout(100, break_when_any_done=False) + for obs_key in env.observation_spec.keys(True, True): + rollout_consistency_assertion( + rollout, + done_key="done", + observation_key=obs_key, + done_strict="CartPole" in envname, + ) + env.close() + del env + finally: + set_gym_backend(gb).set() @implement_for("gym", "0.18") @pytest.mark.parametrize( @@ -788,8 +826,7 @@ def test_vecenvs_env(self, envname): ) @pytest.mark.flaky(reruns=5, reruns_delay=1) def test_vecenvs_wrapper(self, envname): # noqa: F811 - import gym - + gym = gym_backend() # we can't use parametrize with implement_for for envname in ["CartPole-v1", "HalfCheetah-v4"]: env = GymWrapper( @@ -812,34 +849,42 @@ def test_vecenvs_wrapper(self, envname): # noqa: F811 @implement_for("gym", "0.18") @pytest.mark.parametrize( "envname", - ["CartPole-v1", "HalfCheetah-v4"], + ["cp", "hc"], ) @pytest.mark.flaky(reruns=5, reruns_delay=1) def test_vecenvs_env(self, envname): # noqa: F811 - with set_gym_backend("gym"): - env = GymEnv(envname, num_envs=2, from_pixels=False) - env.set_seed(0) - assert env.get_library_name(env._env) == "gym" - # rollouts can be executed without decorator - check_env_specs(env) - rollout = env.rollout(100, break_when_any_done=False) - for obs_key in env.observation_spec.keys(True, True): - rollout_consistency_assertion( - rollout, - done_key="done", - observation_key=obs_key, - done_strict="CartPole" in envname, - ) - env.close() - del env - if envname != "CartPole-v1": + gb = gym_backend() + try: with set_gym_backend("gym"): - env = GymEnv(envname, num_envs=2, from_pixels=True) + if envname == "hc": + envname = HALFCHEETAH_VERSIONED() + else: + envname = CARTPOLE_VERSIONED() + env = GymEnv(envname, num_envs=2, from_pixels=False) env.set_seed(0) + assert env.get_library_name(env._env) == "gym" # rollouts can be executed without decorator check_env_specs(env) + rollout = env.rollout(100, break_when_any_done=False) + for obs_key in env.observation_spec.keys(True, True): + rollout_consistency_assertion( + rollout, + done_key="done", + observation_key=obs_key, + done_strict="CartPole" in envname, + ) env.close() del env + if envname != "CartPole-v1": + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=True) + env.set_seed(0) + # rollouts can be executed without decorator + check_env_specs(env) + env.close() + del env + finally: + set_gym_backend(gb).set() @implement_for("gym", None, "0.18") @pytest.mark.parametrize( @@ -863,88 +908,101 @@ def test_vecenvs_env(self, envname): # noqa: F811 @pytest.mark.parametrize("wrapper", [True, False]) def test_gym_output_num(self, wrapper): # gym has 4 outputs, no truncation - import gym - - if wrapper: - env = GymWrapper(gym.make(PENDULUM_VERSIONED)) - else: - with set_gym_backend("gym"): - env = GymEnv(PENDULUM_VERSIONED) - # truncated is read from the info - assert "truncated" in env.done_keys - assert "terminated" in env.done_keys - assert "done" in env.done_keys - check_env_specs(env) + gym = gym_backend() + try: + if wrapper: + env = GymWrapper(gym.make(PENDULUM_VERSIONED())) + else: + with set_gym_backend("gym"): + env = GymEnv(PENDULUM_VERSIONED()) + # truncated is read from the info + assert "truncated" in env.done_keys + assert "terminated" in env.done_keys + assert "done" in env.done_keys + check_env_specs(env) + finally: + set_gym_backend(gym).set() @implement_for("gym", "0.26") @pytest.mark.parametrize("wrapper", [True, False]) def test_gym_output_num(self, wrapper): # noqa: F811 # gym has 5 outputs, with truncation - import gym - - if wrapper: - env = GymWrapper(gym.make(PENDULUM_VERSIONED)) - else: - with set_gym_backend("gym"): - env = GymEnv(PENDULUM_VERSIONED) - assert "truncated" in env.done_keys - assert "terminated" in env.done_keys - assert "done" in env.done_keys - check_env_specs(env) + gym = gym_backend() + try: + if wrapper: + env = GymWrapper(gym.make(PENDULUM_VERSIONED())) + else: + with set_gym_backend("gym"): + env = GymEnv(PENDULUM_VERSIONED()) + assert "truncated" in env.done_keys + assert "terminated" in env.done_keys + assert "done" in env.done_keys + check_env_specs(env) - if wrapper: - # let's further test with a wrapper that exposes the env with old API - from gym.wrappers.compatibility import EnvCompatibility + if wrapper: + # let's further test with a wrapper that exposes the env with old API + from gym.wrappers.compatibility import EnvCompatibility - with pytest.raises( - ValueError, - match="GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility", - ): - GymWrapper(EnvCompatibility(gym.make("CartPole-v1"))) + with pytest.raises( + ValueError, + match="GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility", + ): + GymWrapper(EnvCompatibility(gym.make("CartPole-v1"))) + finally: + set_gym_backend(gym).set() @implement_for("gymnasium") @pytest.mark.parametrize("wrapper", [True, False]) def test_gym_output_num(self, wrapper): # noqa: F811 # gym has 5 outputs, with truncation - import gymnasium as gym - - if wrapper: - env = GymWrapper(gym.make(PENDULUM_VERSIONED)) - else: - with set_gym_backend("gymnasium"): - env = GymEnv(PENDULUM_VERSIONED) - assert "truncated" in env.done_keys - assert "terminated" in env.done_keys - assert "done" in env.done_keys - check_env_specs(env) + gym = gym_backend() + try: + if wrapper: + env = GymWrapper(gym.make(PENDULUM_VERSIONED())) + else: + with set_gym_backend("gymnasium"): + env = GymEnv(PENDULUM_VERSIONED()) + assert "truncated" in env.done_keys + assert "terminated" in env.done_keys + assert "done" in env.done_keys + check_env_specs(env) + finally: + set_gym_backend(gym).set() - def test_gym_gymnasium_parallel(self): + def test_gym_gymnasium_parallel(self, maybe_fork_ParallelEnv): # tests that both gym and gymnasium work with wrappers without # decorating with set_gym_backend during execution - if importlib.util.find_spec("gym") is not None: - import gym + gym = gym_backend() + try: + if importlib.util.find_spec("gym") is not None: + with set_gym_backend("gym"): + gym = gym_backend() - old_api = version.parse(gym.__version__) < version.parse("0.26") - make_fun = EnvCreator(lambda: GymWrapper(gym.make(PENDULUM_VERSIONED))) - elif importlib.util.find_spec("gymnasium") is not None: - import gymnasium + old_api = version.parse(gym.__version__) < version.parse("0.26") + make_fun = EnvCreator( + lambda: GymWrapper(gym.make(PENDULUM_VERSIONED())) + ) + elif importlib.util.find_spec("gymnasium") is not None: + import gymnasium - old_api = False - make_fun = EnvCreator( - lambda: GymWrapper(gymnasium.make(PENDULUM_VERSIONED)) - ) - else: - raise ImportError # unreachable under pytest.skipif - penv = ParallelEnv(2, make_fun) - rollout = penv.rollout(2) - if old_api: - assert "terminated" in rollout.keys() - # truncated is read from info - assert "truncated" in rollout.keys() - else: - assert "terminated" in rollout.keys() - assert "truncated" in rollout.keys() - check_env_specs(penv) + old_api = False + make_fun = EnvCreator( + lambda: GymWrapper(gymnasium.make(PENDULUM_VERSIONED())) + ) + else: + raise ImportError # unreachable under pytest.skipif + penv = maybe_fork_ParallelEnv(2, make_fun) + rollout = penv.rollout(2) + if old_api: + assert "terminated" in rollout.keys() + # truncated is read from info + assert "truncated" in rollout.keys() + else: + assert "terminated" in rollout.keys() + assert "truncated" in rollout.keys() + check_env_specs(penv) + finally: + set_gym_backend(gym).set() @implement_for("gym", None, "0.22.0") def test_vecenvs_nan(self): # noqa: F811 @@ -1000,7 +1058,7 @@ def test_vecenvs_nan(self): # noqa: F811 def test_vecenvs_nan(self): # noqa: F811 # new versions of gym must never return nan for next values when there is a done state torch.manual_seed(0) - env = GymEnv("CartPole-v0", num_envs=2) + env = GymEnv("CartPole-v1", num_envs=2) env.set_seed(0) rollout = env.rollout(200) assert torch.isfinite(rollout.get("observation")).all() @@ -1009,7 +1067,7 @@ def test_vecenvs_nan(self): # noqa: F811 del env # same with collector - env = GymEnv("CartPole-v0", num_envs=2) + env = GymEnv("CartPole-v1", num_envs=2) env.set_seed(0) c = SyncDataCollector( env, RandomPolicy(env.action_spec), total_frames=2000, frames_per_batch=200 @@ -1024,16 +1082,19 @@ def test_vecenvs_nan(self): # noqa: F811 @implement_for("gym", None, "0.26") def _make_gym_environment(env_name): # noqa: F811 + gym = gym_backend() return gym.make(env_name) @implement_for("gym", "0.26", None) def _make_gym_environment(env_name): # noqa: F811 + gym = gym_backend() return gym.make(env_name, render_mode="rgb_array") @implement_for("gymnasium") def _make_gym_environment(env_name): # noqa: F811 + gym = gym_backend() return gym.make(env_name, render_mode="rgb_array") @@ -1138,8 +1199,8 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only): if _has_gym: params += [ # [GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": True}], - [GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": False}], - [GymEnv, (PONG_VERSIONED,), {}], + [GymEnv, (HALFCHEETAH_VERSIONED(),), {"from_pixels": False}], + [GymEnv, (PONG_VERSIONED(),), {}], ] @@ -1195,6 +1256,7 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs): ) class TestCollectorLib: def test_collector_run(self, env_lib, env_args, env_kwargs, device): + env_args = tuple(arg() if callable(arg) else arg for arg in env_args) if not _has_dmc and env_lib is DMControlEnv: raise pytest.skip("no dmc") if not _has_gym and env_lib is GymEnv: @@ -1338,11 +1400,11 @@ def test_jumanji_consistency(self, envname, batch_size): ENVPOOL_CLASSIC_CONTROL_ENVS = [ - PENDULUM_VERSIONED, + PENDULUM_VERSIONED(), "MountainCar-v0", "MountainCarContinuous-v0", "Acrobot-v1", - CARTPOLE_VERSIONED, + CARTPOLE_VERSIONED(), ] ENVPOOL_ATARI_ENVS = [] # PONG_VERSIONED] ENVPOOL_GYM_ENVS = ENVPOOL_CLASSIC_CONTROL_ENVS + ENVPOOL_ATARI_ENVS @@ -1569,7 +1631,7 @@ def test_multithreaded_env_seed( # Check that results are different if seed is different # Skip Pong, since there different actions can lead to the same result - if env_name != PONG_VERSIONED: + if env_name != PONG_VERSIONED(): env.set_seed( seed=seed + 10, ) @@ -1584,7 +1646,7 @@ def test_multithreaded_env_seed( @pytest.mark.skipif(not _has_gym, reason="no gym") def test_multithread_env_shutdown(self): env = _make_multithreaded_env( - PENDULUM_VERSIONED, + PENDULUM_VERSIONED(), 1, transformed_out=False, N=3, @@ -1753,14 +1815,16 @@ def test_brax_grad(self, envname, batch_size): @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) @pytest.mark.parametrize("parallel", [False, True]) - def test_brax_parallel(self, envname, batch_size, parallel, n=1): + def test_brax_parallel( + self, envname, batch_size, parallel, maybe_fork_ParallelEnv, n=1 + ): def make_brax(): env = BraxEnv(envname, batch_size=batch_size, requires_grad=False) env.set_seed(1) return env if parallel: - env = ParallelEnv(n, make_brax) + env = maybe_fork_ParallelEnv(n, make_brax) else: env = SerialEnv(n, make_brax) check_env_specs(env) @@ -1951,6 +2015,7 @@ def test_vmas_parallel( num_envs, n_workers, continuous_actions, + maybe_fork_ParallelEnv, n_agents=5, n_rollout_samples=3, ): @@ -1966,7 +2031,7 @@ def make_vmas(): env.set_seed(0) return env - env = ParallelEnv(n_workers, make_vmas) + env = maybe_fork_ParallelEnv(n_workers, make_vmas) tensordict = env.rollout(max_steps=n_rollout_samples) assert tensordict.shape == torch.Size( @@ -1984,6 +2049,7 @@ def test_vmas_reset( scenario_name, num_envs, n_workers, + maybe_fork_ParallelEnv, n_agents=5, n_rollout_samples=3, max_steps=3, @@ -2000,7 +2066,7 @@ def make_vmas(): env.set_seed(0) return env - env = ParallelEnv(n_workers, make_vmas) + env = maybe_fork_ParallelEnv(n_workers, make_vmas) tensordict = env.rollout(max_steps=n_rollout_samples) assert ( @@ -2056,13 +2122,15 @@ def make_vmas(): @pytest.mark.parametrize("n_envs", [1, 4]) @pytest.mark.parametrize("n_workers", [1, 2]) @pytest.mark.parametrize("n_agents", [1, 3]) - def test_collector(self, n_envs, n_workers, n_agents, frames_per_batch=80): + def test_collector( + self, n_envs, n_workers, n_agents, maybe_fork_ParallelEnv, frames_per_batch=80 + ): torch.manual_seed(1) env_fun = lambda: VmasEnv( scenario="flocking", num_envs=n_envs, n_agents=n_agents, max_steps=7 ) - env = ParallelEnv(n_workers, env_fun) + env = maybe_fork_ParallelEnv(n_workers, env_fun) n_actions_per_agent = env.action_spec.shape[-1] n_observations_per_agent = env.observation_spec["agents", "observation"].shape[ @@ -2957,14 +3025,14 @@ def test_envs_more_groups_aec(self, task): @pytest.mark.parametrize("task", ["knights_archers_zombies_v10", "pistonball_v6"]) @pytest.mark.parametrize("parallel", [True, False]) - def test_vec_env(self, task, parallel): + def test_vec_env(self, task, parallel, maybe_fork_ParallelEnv): env_fun = lambda: PettingZooEnv( task=task, parallel=parallel, seed=0, use_mask=not parallel, ) - vec_env = ParallelEnv(2, create_env_fn=env_fun) + vec_env = maybe_fork_ParallelEnv(2, create_env_fn=env_fun) vec_env.rollout(100, break_when_any_done=False) @pytest.mark.parametrize("task", ["knights_archers_zombies_v10", "pistonball_v6"]) @@ -3065,9 +3133,9 @@ def test_env(self, map: str, categorical_actions): check_env_specs(env, seed=None) env.close() - def test_parallel_env(self): + def test_parallel_env(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv( + maybe_fork_ParallelEnv( num_workers=2, create_env_fn=lambda: SMACv2Env( map_name="3s_vs_5z", diff --git a/test/test_trainer.py b/test/test_trainer.py index aeb6c971dd8..a799449aba7 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -837,7 +837,7 @@ def test_subsampler_state_dict(self): class TestRecorder: def _get_args(self): args = Namespace() - args.env_name = PONG_VERSIONED + args.env_name = PONG_VERSIONED() args.env_task = "" args.grayscale = True args.env_library = "gym" @@ -895,7 +895,7 @@ def test_recorder(self, N=8): }, ) ea.Reload() - img = ea.Images(f"tmp_{PONG_VERSIONED}_video") + img = ea.Images(f"tmp_{PONG_VERSIONED()}_video") try: assert len(img) == N // args.record_interval break diff --git a/test/test_transforms.py b/test/test_transforms.py index c2fb7fca41c..07e24d36a54 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -225,8 +225,8 @@ def test_serial_trans_env_check(self): check_env_specs(env) env.close() - def test_parallel_trans_env_check(self): - env = ParallelEnv( + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( 2, lambda: TransformedEnv(ContinuousActionVecMockEnv(), BinarizeReward()) ) try: @@ -243,9 +243,10 @@ def test_trans_serial_env_check(self): finally: env.close() - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), BinarizeReward() + maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), + BinarizeReward(), ) try: check_env_specs(env) @@ -538,7 +539,7 @@ def test_transform_no_env(self): assert data["reward"] == 2 assert data["reward_clip"] == 0.1 - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = ContinuousActionVecMockEnv() return TransformedEnv( @@ -551,7 +552,7 @@ def make_env(): ), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -573,10 +574,10 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = ContinuousActionVecMockEnv() env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), ClipTransform( in_keys=["observation", "reward"], in_keys_inv=["observation_orig"], @@ -622,8 +623,8 @@ def test_serial_trans_env_check(self): ) check_env_specs(env) - def test_parallel_trans_env_check(self): - env = ParallelEnv( + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( 2, lambda: TransformedEnv( ContinuousActionVecMockEnv(), @@ -649,9 +650,9 @@ def test_trans_serial_env_check(self): ), ) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), + maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), CatFrames(dim=-1, N=3, in_keys=["observation"]), ) try: @@ -662,11 +663,16 @@ def test_trans_parallel_env_check(self): @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @pytest.mark.parametrize("break_when_any_done", [True, False]) - def test_catframes_batching(self, batched_class, break_when_any_done): + def test_catframes_batching( + self, batched_class, break_when_any_done, maybe_fork_ParallelEnv + ): from _utils_internal import CARTPOLE_VERSIONED + if batched_class is ParallelEnv: + batched_class = maybe_fork_ParallelEnv + env = TransformedEnv( - batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED)), + batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), CatFrames( dim=-1, N=3, in_keys=["observation"], out_keys=["observation_cat"] ), @@ -678,7 +684,7 @@ def test_catframes_batching(self, batched_class, break_when_any_done): env = batched_class( 2, lambda: TransformedEnv( - GymEnv(CARTPOLE_VERSIONED), + GymEnv(CARTPOLE_VERSIONED()), CatFrames( dim=-1, N=3, in_keys=["observation"], out_keys=["observation_cat"] ), @@ -725,7 +731,7 @@ def test_nested(self, nested_dim=3, batch_size=(32, 1), rollout_length=6, cat_N= @pytest.mark.skipif(not _has_gym, reason="Gym not available") def test_transform_env(self): env = TransformedEnv( - GymEnv(PENDULUM_VERSIONED, frame_skip=4), + GymEnv(PENDULUM_VERSIONED(), frame_skip=4), CatFrames(dim=-1, N=3, in_keys=["observation"]), ) td = env.reset() @@ -745,7 +751,7 @@ def test_transform_env(self): @pytest.mark.skipif(not _has_gym, reason="Gym not available") def test_transform_env_clone(self): env = TransformedEnv( - GymEnv(PENDULUM_VERSIONED, frame_skip=4), + GymEnv(PENDULUM_VERSIONED(), frame_skip=4), CatFrames(dim=-1, N=3, in_keys=["observation"]), ) td = env.reset() @@ -1526,7 +1532,7 @@ def test_r3m_spec_against_real(self, model, tensor_pixels_key, device): class TestStepCounter(TransformBase): @pytest.mark.skipif(not _has_gym, reason="no gym detected") def test_step_count_gym(self): - env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), StepCounter(max_steps=30)) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), StepCounter(max_steps=30)) env.rollout(1000) check_env_specs(env) @@ -1534,7 +1540,7 @@ def test_step_count_gym(self): def test_step_count_gym_doublecount(self): # tests that 2 truncations can be used together env = TransformedEnv( - GymEnv(PENDULUM_VERSIONED), + GymEnv(PENDULUM_VERSIONED()), Compose( StepCounter(max_steps=2), StepCounter(max_steps=3), # this one will be ignored @@ -1559,7 +1565,7 @@ def test_stepcount_batching(self, batched_class, break_when_any_done): from _utils_internal import CARTPOLE_VERSIONED env = TransformedEnv( - batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED)), + batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), StepCounter(max_steps=15), ) torch.manual_seed(0) @@ -1569,7 +1575,7 @@ def test_stepcount_batching(self, batched_class, break_when_any_done): env = batched_class( 2, lambda: TransformedEnv( - GymEnv(CARTPOLE_VERSIONED), StepCounter(max_steps=15) + GymEnv(CARTPOLE_VERSIONED()), StepCounter(max_steps=15) ), ) torch.manual_seed(0) @@ -1626,7 +1632,7 @@ def test_trans_serial_env_check(self): @pytest.mark.skipif(not _has_gym, reason="Gym not found") def test_transform_env(self): - env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), StepCounter(10)) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), StepCounter(10)) td = env.rollout(100, break_when_any_done=False) assert td["step_count"].max() == 9 assert td.shape[-1] == 100 @@ -2029,7 +2035,7 @@ def test_transform_env(self, del_keys, out_key): dim=-1, del_keys=del_keys, ) - env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), ct) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), ct) assert env.observation_spec[out_key] if del_keys: assert "observation" not in env.observation_spec @@ -2303,7 +2309,7 @@ def test_transform_env(self, out_key): ct = Compose( ToTensorImage(), CenterCrop(out_keys=out_key, w=20, h=20, in_keys=keys) ) - env = TransformedEnv(GymEnv(PONG_VERSIONED), ct) + env = TransformedEnv(GymEnv(PONG_VERSIONED()), ct) td = env.reset() if out_key is None: assert td["pixels"].shape == torch.Size([3, 20, 20]) @@ -3307,7 +3313,7 @@ def test_transform_compose(self, keys, size, nchannels, batch, device): ) def test_transform_env(self, out_keys): env = TransformedEnv( - GymEnv(PONG_VERSIONED), FlattenObservation(-3, -1, out_keys=out_keys) + GymEnv(PONG_VERSIONED()), FlattenObservation(-3, -1, out_keys=out_keys) ) check_env_specs(env) if out_keys: @@ -3415,9 +3421,9 @@ def test_transform_env(self, skip): return else: fs = FrameSkipTransform(skip) - base_env = GymEnv(PENDULUM_VERSIONED, frame_skip=skip) + base_env = GymEnv(PENDULUM_VERSIONED(), frame_skip=skip) tensordicts = TensorDict({"action": base_env.action_spec.rand((10,))}, [10]) - env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), fs) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), fs) base_env.set_seed(0) env.base_env.set_seed(0) td1 = base_env.reset() @@ -3476,9 +3482,9 @@ def test_frame_skip_transform_unroll(self, skip): return else: fs = FrameSkipTransform(skip) - base_env = GymEnv(PENDULUM_VERSIONED) + base_env = GymEnv(PENDULUM_VERSIONED()) tensordicts = TensorDict({"action": base_env.action_spec.rand((10,))}, [10]) - env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), fs) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), fs) base_env.set_seed(0) env.base_env.set_seed(0) td1 = base_env.reset() @@ -4108,7 +4114,7 @@ def test_transform_inverse(self, out_key, out_key_inv): standard_normal=standard_normal, ) ) - base_env = GymEnv(PENDULUM_VERSIONED) + base_env = GymEnv(PENDULUM_VERSIONED()) env = TransformedEnv(base_env, t) td = env.rollout(3) check_env_specs(env) @@ -4483,7 +4489,7 @@ def test_trans_parallel_env_check(self): @pytest.mark.parametrize("out_key", ["pixels", ("agents", "pixels")]) def test_transform_env(self, out_key): env = TransformedEnv( - GymEnv(PONG_VERSIONED), + GymEnv(PONG_VERSIONED()), Compose( ToTensorImage(), Resize(20, 21, in_keys=["pixels"], out_keys=[out_key]) ), @@ -4571,7 +4577,7 @@ def test_transform_compose(self): @pytest.mark.skipif(not _has_gym, reason="No Gym") def test_transform_env(self): t = Compose(RewardClipping(-0.1, 0.1)) - env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), t) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t) td = env.rollout(3) assert (td["next", "reward"] <= 0.1).all() assert (td["next", "reward"] >= -0.1).all() @@ -4721,7 +4727,7 @@ def test_transform_env(self, standard_normal): loc = 0.5 scale = 1.5 t = Compose(RewardScaling(0.5, 1.5, standard_normal=standard_normal)) - env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), t) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t) torch.manual_seed(0) env.set_seed(0) td = env.rollout(3) @@ -4907,7 +4913,7 @@ def test_transform_compose( @pytest.mark.parametrize("out_key", ["reward_sum", ("some", "nested")]) def test_transform_env(self, out_key): t = Compose(RewardSum(in_keys=["reward"], out_keys=[out_key])) - env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), t) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t) env.set_seed(0) torch.manual_seed(0) td = env.rollout(3) @@ -4926,14 +4932,14 @@ def test_rewardsum_batching(self, batched_class, break_when_any_done): from _utils_internal import CARTPOLE_VERSIONED env = TransformedEnv( - batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED)), RewardSum() + batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), RewardSum() ) torch.manual_seed(0) env.set_seed(0) r0 = env.rollout(100, break_when_any_done=break_when_any_done) env = batched_class( - 2, lambda: TransformedEnv(GymEnv(CARTPOLE_VERSIONED), RewardSum()) + 2, lambda: TransformedEnv(GymEnv(CARTPOLE_VERSIONED()), RewardSum()) ) torch.manual_seed(0) env.set_seed(0) @@ -5619,7 +5625,7 @@ def test_transform_rb(self, rbclass, out_keys, unsqueeze_dim): @pytest.mark.skipif(not _has_gym, reason="No gym") def test_transform_inverse(self): env = TransformedEnv( - GymEnv(HALFCHEETAH_VERSIONED), + GymEnv(HALFCHEETAH_VERSIONED()), # the order is inverted Compose( UnsqueezeTransform( @@ -5890,11 +5896,11 @@ def test_transform_rb(self, out_keys, rbclass): @pytest.mark.skipif(not _has_gym, reason="No Gym") def test_transform_inverse(self): env = TransformedEnv( - GymEnv(HALFCHEETAH_VERSIONED), self._inv_circular_transform + GymEnv(HALFCHEETAH_VERSIONED()), self._inv_circular_transform ) check_env_specs(env) r = env.rollout(3) - r2 = GymEnv(HALFCHEETAH_VERSIONED).rollout(3) + r2 = GymEnv(HALFCHEETAH_VERSIONED()).rollout(3) assert (r.zero_() == r2.zero_()).all() @@ -6019,7 +6025,7 @@ def test_targetreturn_batching(self, batched_class, break_when_any_done): from _utils_internal import CARTPOLE_VERSIONED env = TransformedEnv( - batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED)), + batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), TargetReturn(target_return=10.0, mode="reduce"), ) torch.manual_seed(0) @@ -6029,7 +6035,7 @@ def test_targetreturn_batching(self, batched_class, break_when_any_done): env = batched_class( 2, lambda: TransformedEnv( - GymEnv(CARTPOLE_VERSIONED), + GymEnv(CARTPOLE_VERSIONED()), TargetReturn(target_return=10.0, mode="reduce"), ), ) @@ -6491,7 +6497,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): from _utils_internal import CARTPOLE_VERSIONED env = TransformedEnv( - batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED)), + batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), ) torch.manual_seed(0) @@ -6501,7 +6507,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): env = batched_class( 2, lambda: TransformedEnv( - GymEnv(CARTPOLE_VERSIONED), + GymEnv(CARTPOLE_VERSIONED()), TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])), ), ) @@ -6633,7 +6639,7 @@ def test_timemax_batching(self, batched_class, break_when_any_done): from _utils_internal import CARTPOLE_VERSIONED env = TransformedEnv( - batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED)), + batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), TimeMaxPool( in_keys=["observation"], out_keys=["observation_max"], @@ -6647,7 +6653,7 @@ def test_timemax_batching(self, batched_class, break_when_any_done): env = batched_class( 2, lambda: TransformedEnv( - GymEnv(CARTPOLE_VERSIONED), + GymEnv(CARTPOLE_VERSIONED()), TimeMaxPool( in_keys=["observation"], out_keys=["observation_max"], @@ -6664,7 +6670,7 @@ def test_timemax_batching(self, batched_class, break_when_any_done): @pytest.mark.parametrize("out_keys", [None, ["obs2"], [("some", "other")]]) def test_transform_env(self, out_keys): env = TransformedEnv( - GymEnv(PENDULUM_VERSIONED, frame_skip=4), + GymEnv(PENDULUM_VERSIONED(), frame_skip=4), TimeMaxPool( in_keys=["observation"], out_keys=out_keys, @@ -6878,7 +6884,7 @@ def test_transform_compose(self): @pytest.mark.skipif(not _has_gym, reason="no gym") def test_transform_env(self): env = TransformedEnv( - GymEnv(PENDULUM_VERSIONED), gSDENoise(state_dim=3, action_dim=1) + GymEnv(PENDULUM_VERSIONED()), gSDENoise(state_dim=3, action_dim=1) ) check_env_specs(env) assert (env.reset()["_eps_gSDE"] != 0.0).all() @@ -7655,7 +7661,7 @@ def test_vecnorm_parallel_auto(self, nprc): prcs = [] if _has_gym: make_env = EnvCreator( - lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED), VecNorm(decay=1.0)) + lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm(decay=1.0)) ) else: make_env = EnvCreator( @@ -7756,7 +7762,7 @@ def _run_parallelenv(parallel_env, queue_in, queue_out): def test_parallelenv_vecnorm(self): if _has_gym: make_env = EnvCreator( - lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED), VecNorm()) + lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm()) ) else: make_env = EnvCreator( @@ -7808,14 +7814,14 @@ def test_vecnorm_rollout(self, parallel, thr=0.2, N=200): torch.manual_seed(self.SEED) if parallel is None: - env = GymEnv(PENDULUM_VERSIONED) + env = GymEnv(PENDULUM_VERSIONED()) elif parallel: env = ParallelEnv( - num_workers=5, create_env_fn=lambda: GymEnv(PENDULUM_VERSIONED) + num_workers=5, create_env_fn=lambda: GymEnv(PENDULUM_VERSIONED()) ) else: env = SerialEnv( - num_workers=5, create_env_fn=lambda: GymEnv(PENDULUM_VERSIONED) + num_workers=5, create_env_fn=lambda: GymEnv(PENDULUM_VERSIONED()) ) env.set_seed(self.SEED) @@ -8762,7 +8768,7 @@ def test_init_gym( self, ): env = TransformedEnv( - GymEnv(PENDULUM_VERSIONED), + GymEnv(PENDULUM_VERSIONED()), Compose(StepCounter(max_steps=30), InitTracker()), ) env.rollout(1000) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index ae01556f0e6..d582462b7ed 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -363,6 +363,7 @@ def import_module(cls, module_name: Union[Callable, str]) -> str: _lazy_impl = collections.defaultdict(list) def _delazify(self, func_name): + out = None for local_call in implement_for._lazy_impl[func_name]: out = local_call() return out diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index efe928856b9..69f3c7b3efd 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -576,9 +576,9 @@ def encode( ): val = val.copy() if not ignore_device: - val = torch.tensor(val, device=self.device, dtype=self.dtype) + val = torch.as_tensor(val, device=self.device, dtype=self.dtype) else: - val = torch.tensor(val, dtype=self.dtype) + val = torch.as_tensor(val, dtype=self.dtype) if val.shape != self.shape: # if val.shape[-len(self.shape) :] != self.shape: # option 1: add a singleton dim at the end diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f472533b9f3..9c4340f8f25 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1052,6 +1052,8 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): """ def _start_workers(self) -> None: + self._timeout = 10.0 + from torchrl.envs.env_creator import EnvCreator if self.num_threads is None: @@ -1149,7 +1151,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: for i, channel in enumerate(self.parent_channels): channel.send(("load_state_dict", state_dict[f"worker{i}"])) for event in self._events: - event.wait() + event.wait(self._timeout) event.clear() @torch.no_grad() @@ -1183,7 +1185,7 @@ def step_and_maybe_reset( for i in range(self.num_workers): event = self._events[i] - event.wait() + event.wait(self._timeout) event.clear() # We must pass a clone of the tensordict, as the values of this tensordict @@ -1246,7 +1248,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for i in range(self.num_workers): event = self._events[i] - event.wait() + event.wait(self._timeout) event.clear() # We must pass a clone of the tensordict, as the values of this tensordict @@ -1333,7 +1335,7 @@ def tentative_update(val, other): for i in workers: event = self._events[i] - event.wait() + event.wait(self._timeout) event.clear() selected_output_keys = self._selected_reset_keys_filt @@ -1367,7 +1369,7 @@ def _shutdown_workers(self) -> None: if self._verbose: torchrl_logger.info(f"closing {i}") channel.send(("close", None)) - self._events[i].wait() + self._events[i].wait(self._timeout) self._events[i].clear() del self.shared_tensordicts, self.shared_tensordict_parent diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 59730c6df8c..a419b013722 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -494,12 +494,11 @@ def _get_gym_envs(): # noqa: F811 def _is_from_pixels(env): - gym = gym_backend() observation_spec = env.observation_space try: PixelObservationWrapper = gym_backend( - "wrappers.pixel_observation.PixelObservationWrapper" - ) + "wrappers.pixel_observation" + ).PixelObservationWrapper except ModuleNotFoundError: class PixelObservationWrapper: @@ -509,22 +508,33 @@ class PixelObservationWrapper: GymPixelObservationWrapper as LegacyPixelObservationWrapper, ) + gDict = gym_backend("spaces").dict.Dict + Box = gym_backend("spaces").Box + if isinstance(observation_spec, (Dict,)): if "pixels" in set(observation_spec.keys()): return True - if isinstance(observation_spec, (gym.spaces.dict.Dict,)): + if isinstance(observation_spec, (gDict,)): if "pixels" in set(observation_spec.spaces.keys()): return True elif ( - isinstance(observation_spec, gym.spaces.Box) + isinstance(observation_spec, Box) and (observation_spec.low == 0).all() and (observation_spec.high == 255).all() and observation_spec.low.shape[-1] == 3 and observation_spec.low.ndim == 3 ): return True - elif isinstance(env, (LegacyPixelObservationWrapper, PixelObservationWrapper)): - return True + else: + while True: + if isinstance( + env, (LegacyPixelObservationWrapper, PixelObservationWrapper) + ): + return True + if hasattr(env, "env"): + env = env.env + else: + break return False @@ -1114,14 +1124,6 @@ def rebuild_with_kwargs(self, **new_kwargs): self._env = self._build_env(**self._constructor_kwargs) self._make_specs(self._env) - @property - def info_dict_reader(self): - return self._info_dict_reader - - @info_dict_reader.setter - def info_dict_reader(self, value: callable): - self._info_dict_reader = value - def _reset( self, tensordict: TensorDictBase | None = None, **kwargs ) -> TensorDictBase: @@ -1265,7 +1267,7 @@ def _set_gym_args( # noqa: F811 ) -> None: kwargs.setdefault("disable_env_checker", True) - @implement_for("gymnasium", "0.27.0", None) + @implement_for("gymnasium") def _set_gym_args( # noqa: F811 self, kwargs,