diff --git a/tf_agents/environments/batched_py_environment.py b/tf_agents/environments/batched_py_environment.py index 99fbc2b39..b50413747 100644 --- a/tf_agents/environments/batched_py_environment.py +++ b/tf_agents/environments/batched_py_environment.py @@ -26,7 +26,7 @@ from multiprocessing import dummy as mp_threads from multiprocessing import pool # pylint: enable=line-too-long -from typing import Sequence, Optional +from typing import Any, Optional, Sequence import gin import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import @@ -182,6 +182,21 @@ def _step(self, actions): ) return nest_utils.stack_nested_arrays(time_steps) + def seed(self, seed: types.Seed) -> Any: + """Seeds the environment.""" + return self._execute(lambda env: env.seed(seed), self._envs) + + def get_state(self) -> Any: + """Returns the `state` of the environment.""" + return self._execute(lambda env: env.get_state(), self._envs) + + def set_state(self, state: Sequence[Any]) -> None: + """Restores the environment to a given `state`.""" + self._execute( + lambda env_state: env_state[0].set_state(env_state[1]), + zip(self._envs, state) + ) + def render(self, mode="rgb_array") -> Optional[types.NestedArray]: if self._num_envs == 1: img = self._envs[0].render(mode) diff --git a/tf_agents/environments/batched_py_environment_test.py b/tf_agents/environments/batched_py_environment_test.py index 9cdf96378..3fc6e4a4d 100644 --- a/tf_agents/environments/batched_py_environment_test.py +++ b/tf_agents/environments/batched_py_environment_test.py @@ -38,10 +38,21 @@ class GymWrapperEnvironmentMock(random_py_environment.RandomPyEnvironment): def __init__(self, *args, **kwargs): super(GymWrapperEnvironmentMock, self).__init__(*args, **kwargs) self._info = {} + self._state = {'seed': 0} def get_info(self): return self._info + def seed(self, seed): + self._state['seed'] = seed + return super(GymWrapperEnvironmentMock, self).seed(seed) + + def get_state(self): + return self._state + + def set_state(self, state): + self._state = state + def _step(self, action): self._info['last_action'] = action return super(GymWrapperEnvironmentMock, self)._step(action) @@ -116,6 +127,32 @@ def test_get_info_gym_env(self, multithreading): self.assertAllEqual(info['last_action'], action) gym_env.close() + @parameterized.parameters(*COMMON_PARAMETERS) + def test_seed_gym_env(self, multithreading): + num_envs = 5 + gym_env = self._make_batched_mock_gym_py_environment( + multithreading, num_envs=num_envs + ) + + gym_env.seed(42) + + actual_seeds = [state['seed'] for state in gym_env.get_state()] + self.assertEqual(actual_seeds, [42] * num_envs) + gym_env.close() + + @parameterized.parameters(*COMMON_PARAMETERS) + def test_state_gym_env(self, multithreading): + num_envs = 5 + gym_env = self._make_batched_mock_gym_py_environment( + multithreading, num_envs=num_envs + ) + state = [{'value': i * 10} for i in range(num_envs)] + + gym_env.set_state(state) + + self.assertEqual(gym_env.get_state(), state) + gym_env.close() + @parameterized.parameters(*COMMON_PARAMETERS) def test_step(self, multithreading): num_envs = 5