From 441df4b407d9805eab2ae17f5d47670ad0acc2ff Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 27 Oct 2023 20:30:27 -0700 Subject: [PATCH] Add gym_kwargs option to suite_atari --- tf_agents/environments/suite_atari.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tf_agents/environments/suite_atari.py b/tf_agents/environments/suite_atari.py index 8b02a5bb4..4415703f4 100644 --- a/tf_agents/environments/suite_atari.py +++ b/tf_agents/environments/suite_atari.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function -from typing import Dict, Optional, Sequence, Text +from typing import Dict, Optional, Sequence, Text, Any import ale_py # pylint: disable=unused-import import gin @@ -84,13 +84,15 @@ def load( ] = DEFAULT_ATARI_GYM_WRAPPERS, env_wrappers: Sequence[types.PyEnvWrapper] = (), spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None, + gym_kwargs: Optional[Dict[str, Any]] = None, ) -> py_environment.PyEnvironment: """Loads the selected environment and wraps it with the specified wrappers.""" if spec_dtype_map is None: spec_dtype_map = {gym.spaces.Box: np.uint8} + gym_kwargs = gym_kwargs if gym_kwargs else {} gym_spec = gym.spec(environment_name) - gym_env = gym_spec.make() + gym_env = gym_spec.make(**gym_kwargs) if max_episode_steps is None and gym_spec.max_episode_steps is not None: max_episode_steps = gym_spec.max_episode_steps