From 9e5f23e9a58832e924eb9f89fb52f57989163809 Mon Sep 17 00:00:00 2001 From: Benjamin Marks Date: Sun, 21 Jan 2024 09:40:24 -0500 Subject: [PATCH 1/2] Add seed, get_state, and set_state to BatchedPyEnvironment. --- .../environments/batched_py_environment.py | 21 +++++++++- .../batched_py_environment_test.py | 39 +++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/tf_agents/environments/batched_py_environment.py b/tf_agents/environments/batched_py_environment.py index 99fbc2b39..33a3be1b3 100644 --- a/tf_agents/environments/batched_py_environment.py +++ b/tf_agents/environments/batched_py_environment.py @@ -26,7 +26,7 @@ from multiprocessing import dummy as mp_threads from multiprocessing import pool # pylint: enable=line-too-long -from typing import Sequence, Optional +from typing import Any, Optional, Sequence import gin import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import @@ -182,6 +182,25 @@ def _step(self, actions): ) return nest_utils.stack_nested_arrays(time_steps) + def seed(self, seed: types.Seed) -> Any: + """Seeds the environment. + + Args: + seed: Value to use as seed for the environment. + """ + return self._execute(lambda env: env.seed(seed), self._envs) + + def get_state(self) -> Any: + """Returns the `state` of the environment.""" + return self._execute(lambda env: env.get_state(), self._envs) + + def set_state(self, state: Sequence[Any]) -> None: + """Restores the environment to a given `state`.""" + self._execute( + lambda env_and_state: env_and_state[0].set_state(env_and_state[1]), + zip(self._envs, state) + ) + def render(self, mode="rgb_array") -> Optional[types.NestedArray]: if self._num_envs == 1: img = self._envs[0].render(mode) diff --git a/tf_agents/environments/batched_py_environment_test.py b/tf_agents/environments/batched_py_environment_test.py index 9cdf96378..617ba0356 100644 --- a/tf_agents/environments/batched_py_environment_test.py +++ b/tf_agents/environments/batched_py_environment_test.py @@ -38,10 +38,21 @@ class GymWrapperEnvironmentMock(random_py_environment.RandomPyEnvironment): def __init__(self, *args, **kwargs): super(GymWrapperEnvironmentMock, self).__init__(*args, **kwargs) self._info = {} + self._state = {'seed': 0} def get_info(self): return self._info + def seed(self, seed): + self._state['seed'] = seed + return super(GymWrapperEnvironmentMock, self).seed(seed) + + def get_state(self): + return self._state + + def set_state(self, state): + self._state = state + def _step(self, action): self._info['last_action'] = action return super(GymWrapperEnvironmentMock, self)._step(action) @@ -116,6 +127,34 @@ def test_get_info_gym_env(self, multithreading): self.assertAllEqual(info['last_action'], action) gym_env.close() + @parameterized.parameters(*COMMON_PARAMETERS) + def test_seed_gym_env(self, multithreading): + num_envs = 5 + rng = np.random.RandomState() + gym_env = self._make_batched_mock_gym_py_environment( + multithreading, num_envs=num_envs + ) + + gym_env.seed(42) + + actual_seeds = [state['seed'] for state in gym_env.get_state()] + self.assertEqual(actual_seeds, [42] * num_envs) + gym_env.close() + + @parameterized.parameters(*COMMON_PARAMETERS) + def test_state_gym_env(self, multithreading): + num_envs = 5 + rng = np.random.RandomState() + gym_env = self._make_batched_mock_gym_py_environment( + multithreading, num_envs=num_envs + ) + state = [{'value': i * 10} for i in range(num_envs)] + + gym_env.set_state(state) + + self.assertEqual(gym_env.get_state(), state) + gym_env.close() + @parameterized.parameters(*COMMON_PARAMETERS) def test_step(self, multithreading): num_envs = 5 From 97de036c3155c5318dc503e1d70b751068802b71 Mon Sep 17 00:00:00 2001 From: Benjamin Marks Date: Sun, 21 Jan 2024 10:23:37 -0500 Subject: [PATCH 2/2] Fix lint errors. --- tf_agents/environments/batched_py_environment.py | 10 +++------- tf_agents/environments/batched_py_environment_test.py | 2 -- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tf_agents/environments/batched_py_environment.py b/tf_agents/environments/batched_py_environment.py index 33a3be1b3..b50413747 100644 --- a/tf_agents/environments/batched_py_environment.py +++ b/tf_agents/environments/batched_py_environment.py @@ -183,11 +183,7 @@ def _step(self, actions): return nest_utils.stack_nested_arrays(time_steps) def seed(self, seed: types.Seed) -> Any: - """Seeds the environment. - - Args: - seed: Value to use as seed for the environment. - """ + """Seeds the environment.""" return self._execute(lambda env: env.seed(seed), self._envs) def get_state(self) -> Any: @@ -197,8 +193,8 @@ def get_state(self) -> Any: def set_state(self, state: Sequence[Any]) -> None: """Restores the environment to a given `state`.""" self._execute( - lambda env_and_state: env_and_state[0].set_state(env_and_state[1]), - zip(self._envs, state) + lambda env_state: env_state[0].set_state(env_state[1]), + zip(self._envs, state) ) def render(self, mode="rgb_array") -> Optional[types.NestedArray]: diff --git a/tf_agents/environments/batched_py_environment_test.py b/tf_agents/environments/batched_py_environment_test.py index 617ba0356..3fc6e4a4d 100644 --- a/tf_agents/environments/batched_py_environment_test.py +++ b/tf_agents/environments/batched_py_environment_test.py @@ -130,7 +130,6 @@ def test_get_info_gym_env(self, multithreading): @parameterized.parameters(*COMMON_PARAMETERS) def test_seed_gym_env(self, multithreading): num_envs = 5 - rng = np.random.RandomState() gym_env = self._make_batched_mock_gym_py_environment( multithreading, num_envs=num_envs ) @@ -144,7 +143,6 @@ def test_seed_gym_env(self, multithreading): @parameterized.parameters(*COMMON_PARAMETERS) def test_state_gym_env(self, multithreading): num_envs = 5 - rng = np.random.RandomState() gym_env = self._make_batched_mock_gym_py_environment( multithreading, num_envs=num_envs )