diff --git a/examples/tf2/main.py b/examples/tf2/main.py
index 8dc44d90..1f504181 100644
--- a/examples/tf2/main.py
+++ b/examples/tf2/main.py
@@ -22,10 +22,10 @@ set_growing_gpu_memory()
 
 FLAGS = flags.FLAGS
-flags.DEFINE_string("env", "voltage_control", "Environment name.")
-flags.DEFINE_string("scenario", "case33_3min_final", "Environment scenario name.")
-flags.DEFINE_string("dataset", "Replay", "Dataset type.: 'Good', 'Medium', 'Poor' or '' for combined. ")
-flags.DEFINE_string("system", "iddpg", "System name.")
+flags.DEFINE_string("env", "mamujoco", "Environment name.")
+flags.DEFINE_string("scenario", "2halfcheetah", "Environment scenario name.")
+flags.DEFINE_string("dataset", "Good", "Dataset type: 'Good', 'Medium', 'Poor' or '' for combined.")
+flags.DEFINE_string("system", "iddpg+cql", "System name.")
 flags.DEFINE_integer("seed", 42, "Seed.")
 flags.DEFINE_float("trainer_steps", 1e5, "Number of training steps.")
 flags.DEFINE_integer("batch_size", 32, "Number of training steps.")
diff --git a/examples/tf2/online/iddpg_mamujoco.py b/examples/tf2/online/iddpg_mamujoco.py
index 453f54ea..5aa53def 100644
--- a/examples/tf2/online/iddpg_mamujoco.py
+++ b/examples/tf2/online/iddpg_mamujoco.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from og_marl.environments.wrappers import PadObsandActs, Dtype
+from og_marl.environments.wrappers import PadObsandActs, Dtype, ExperienceRecorder
 from og_marl.loggers import WandbLogger
 from og_marl.tf2.systems.iddpg import IDDPGSystem
 from og_marl.environments.gymnasium_mamujoco import MAMuJoCo
@@ -24,6 +24,8 @@
 
 env = Dtype(env, "float32")
 
+env = ExperienceRecorder(env, "mamujoco")
+
 logger = WandbLogger()
 
 system = IDDPGSystem(env, logger)
diff --git a/og_marl/environments/gymnasium_mamujoco.py b/og_marl/environments/gymnasium_mamujoco.py
index 9f63e850..75aed458 100644
--- a/og_marl/environments/gymnasium_mamujoco.py
+++ b/og_marl/environments/gymnasium_mamujoco.py
@@ -60,16 +60,16 @@ def __init__(self, scenario):
 
     def reset(self):
-        observations, info = self._environment.reset()
+        observations, _ = self._environment.reset()
 
-        info["state"] = self._environment.state().astype("float32")
+        info = {"state": self._environment.state().astype("float32")}
 
         return observations, info
 
     def step(self, actions):
-        observations, rewards, terminals, trunctations, info = self._environment.step(actions)
+        observations, rewards, terminals, trunctations, _ = self._environment.step(actions)
 
-        info["state"] = self._environment.state().astype("float32")
+        info = {"state": self._environment.state().astype("float32")}
 
         return observations, rewards, terminals, trunctations, info
diff --git a/og_marl/environments/wrappers.py b/og_marl/environments/wrappers.py
index 2baea203..aa1f340a 100644
--- a/og_marl/environments/wrappers.py
+++ b/og_marl/environments/wrappers.py
@@ -12,7 +12,109 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import jax
 import numpy as np
+import flashbax as fbx
+from flashbax.vault import Vault
+
+BUFFER_TIME_AXIS_LEN = 100_000
+
+
+class ExperienceRecorder:
+    """Environment wrapper that records every transition to a Flashbax Vault."""
+
+    def __init__(self, environment, vault_name: str, write_to_vault_every=10_000):
+
+        self._environment = environment
+        # self._buffer = fbx.make_trajectory_buffer(
+        #     add_batch_size=1,
+        #     max_length_time_axis=BUFFER_TIME_AXIS_LEN,
+        #     min_length_time_axis=1,
+        #     # Unused, as we are not sampling
+        #     sample_batch_size=1,
+        #     sample_sequence_length=1,
+        #     period=1,
+        # )
+        self._buffer = fbx.make_flat_buffer(
+            max_length=2 * 10_000,
+            min_length=1,
+            # Unused:
+            sample_batch_size=1,
+        )
+        self._buffer_state = None
+        self._add_to_buffer = jax.jit(self._buffer.add, donate_argnums=0)
+
+        self._vault = None
+        self.vault_name = vault_name
+        self._has_initialised = False
+
+        self._write_to_vault_every = write_to_vault_every
+        self._step_count = 0
+
+    def _pack_timestep(self, observations, actions, rewards, terminals, truncations, infos):
+        packed_timestep = {
+            "observations": observations,
+            "actions": actions,
+            "rewards": rewards,
+            "terminals": terminals,
+            "truncations": truncations,
+            "infos": infos,
+        }
+        packed_timestep = jax.tree_map(lambda x: np.array(x), packed_timestep)
+        return packed_timestep
+
+    def reset(self):
+        observations, infos = self._environment.reset()
+
+        self._observations = observations
+        self._infos = infos
+
+        return observations, infos
+
+    def step(self, actions):
+        observations, rewards, terminals, truncations, infos = self._environment.step(actions)
+
+        # Pair the observations/infos from before this step with the action taken.
+        packed_timestep = self._pack_timestep(
+            observations=self._observations,
+            actions=actions,
+            rewards=rewards,
+            terminals=terminals,
+            truncations=truncations,
+            infos=self._infos,
+        )
+
+        # Log stuff to vault/flashbax
+        if not self._has_initialised:
+            self._buffer_state = self._buffer.init(packed_timestep)
+            self._vault = Vault(
+                vault_name=self.vault_name,
+                init_fbx_state=self._buffer_state,
+            )
+            self._has_initialised = True
+
+        self._buffer_state = self._add_to_buffer(
+            self._buffer_state,
+            packed_timestep,
+            # jax.tree_map(lambda x: np.expand_dims(np.expand_dims(np.array(x), axis=0), axis=0), packed_timestep),  # NOTE add time dimension and batch dimension. should we use flat buffer?
+        )
+
+        # Store new observations and infos
+        self._observations = observations
+        self._infos = infos
+
+        self._step_count += 1
+        if self._step_count % self._write_to_vault_every == 0:
+            self._vault.write(self._buffer_state)
+
+        return observations, rewards, terminals, truncations, infos
+
+    def __getattr__(self, name: str):
+        """Expose any other attributes of the underlying environment."""
+        if hasattr(self.__class__, name):
+            return self.__getattribute__(name)
+        else:
+            return getattr(self._environment, name)
+
+
 class Dtype:
diff --git a/og_marl/tf2/systems/base.py b/og_marl/tf2/systems/base.py
index 4255fa2b..ef11595c 100644
--- a/og_marl/tf2/systems/base.py
+++ b/og_marl/tf2/systems/base.py
@@ -134,7 +134,7 @@ def train_online(self, replay_buffer, max_env_steps=1e6, train_period=20):
                 break
 
             episodes += 1
-            if episodes % 20 == 0: # TODO: make variable
+            if episodes % 1 == 0: # TODO: make variable
                 self._logger.write({"Episodes": episodes, "Episode Return": episode_return, "Environment Steps": self._env_step_ctr}, force=True)
 
             if self._env_step_ctr > max_env_steps:
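
For quick verification, a minimal sketch (not part of the diff) of how the experience written by ExperienceRecorder could be read back. It assumes flashbax's Vault.read() API and the default vault directory; the vault_uid value is a placeholder for the timestamped uid that Vault creates when the recorder first initialises it.

import jax
from flashbax.vault import Vault

# Placeholder uid: replace with the directory name created under the default
# vault location (./vaults/mamujoco/...) when ExperienceRecorder first ran.
vault = Vault(vault_name="mamujoco", vault_uid="<recorded-uid>")
buffer_state = vault.read()

# buffer_state.experience holds the packed transitions (observations, actions,
# rewards, terminals, truncations, infos), typically with leading (batch, time) axes.
print(jax.tree_util.tree_map(lambda x: x.shape, buffer_state.experience))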