feat: experience recording wrapper.

jcformanek committed Jan 15, 2024 · 1 parent e752264 · commit 0ef4cc2
Showing 5 changed files with 114 additions and 10 deletions.
8 changes: 4 additions & 4 deletions examples/tf2/main.py
@@ -22,10 +22,10 @@
set_growing_gpu_memory()

FLAGS = flags.FLAGS
flags.DEFINE_string("env", "voltage_control", "Environment name.")
flags.DEFINE_string("scenario", "case33_3min_final", "Environment scenario name.")
flags.DEFINE_string("dataset", "Replay", "Dataset type.: 'Good', 'Medium', 'Poor' or '' for combined. ")
flags.DEFINE_string("system", "iddpg", "System name.")
flags.DEFINE_string("env", "mamujoco", "Environment name.")
flags.DEFINE_string("scenario", "2halfcheetah", "Environment scenario name.")
flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or '' for combined. ")
flags.DEFINE_string("system", "iddpg+cql", "System name.")
flags.DEFINE_integer("seed", 42, "Seed.")
flags.DEFINE_float("trainer_steps", 1e5, "Number of training steps.")
flags.DEFINE_integer("batch_size", 32, "Batch size.")
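
For readers unfamiliar with absl, here is a minimal sketch (not part of the commit) of how flags like these are consumed; the flag names mirror the snippet above, while the main() body is purely illustrative.

# Sketch: how the flags above are typically consumed (not from the commit).
from absl import app, flags

flags.DEFINE_string("env", "mamujoco", "Environment name.")
flags.DEFINE_string("system", "iddpg+cql", "System name.")

FLAGS = flags.FLAGS

def main(_):
    # Flag values are only populated once app.run() has parsed sys.argv.
    print(f"Running {FLAGS.system} on {FLAGS.env}")

if __name__ == "__main__":
    app.run(main)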
4 changes: 3 additions & 1 deletion examples/tf2/online/iddpg_mamujoco.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from og_marl.environments.wrappers import PadObsandActs, Dtype
+from og_marl.environments.wrappers import PadObsandActs, Dtype, ExperienceRecorder
from og_marl.loggers import WandbLogger
from og_marl.tf2.systems.iddpg import IDDPGSystem
from og_marl.environments.gymnasium_mamujoco import MAMuJoCo
@@ -24,6 +24,8 @@

env = Dtype(env, "float32")

+env = ExperienceRecorder(env, "mamujoco")
+
logger = WandbLogger()

system = IDDPGSystem(env, logger)
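
For context, a minimal sketch (an assumption, not code from this commit) of the delegation pattern that Dtype and ExperienceRecorder both follow — store the inner environment, intercept step(), and forward everything else via __getattr__, which is what lets the wrappers compose in any order:

class PassthroughWrapper:
    """Hypothetical wrapper illustrating the shared delegation pattern."""

    def __init__(self, environment):
        self._environment = environment

    def step(self, actions):
        # Intercept the transition, apply any side effect, then pass it through.
        observations, rewards, terminals, truncations, infos = self._environment.step(actions)
        return observations, rewards, terminals, truncations, infos

    def __getattr__(self, name):
        # Anything not overridden here falls back to the wrapped environment.
        return getattr(self._environment, name)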
8 changes: 4 additions & 4 deletions og_marl/environments/gymnasium_mamujoco.py
@@ -60,16 +60,16 @@ def __init__(self, scenario):

    def reset(self):

-        observations, info = self._environment.reset()
+        observations, _ = self._environment.reset()

-        info["state"] = self._environment.state().astype("float32")
+        info = {"state": self._environment.state().astype("float32")}

        return observations, info

    def step(self, actions):
-        observations, rewards, terminals, truncations, info = self._environment.step(actions)
+        observations, rewards, terminals, truncations, _ = self._environment.step(actions)

-        info["state"] = self._environment.state().astype("float32")
+        info = {"state": self._environment.state().astype("float32")}

        return observations, rewards, terminals, truncations, info
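
Rebuilding info from scratch, rather than mutating the dict returned upstream, guarantees that only the centralized state is forwarded. A short consumer sketch (assuming og_marl is installed; the scenario name is taken from the example above):

# Sketch (not from the commit): reading the global state out of info.
from og_marl.environments.gymnasium_mamujoco import MAMuJoCo

env = MAMuJoCo("2halfcheetah")
observations, info = env.reset()
state = info["state"]  # global state as float32, e.g. for a centralized critic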

102 changes: 102 additions & 0 deletions og_marl/environments/wrappers.py
@@ -12,7 +12,109 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import numpy as np
import flashbax as fbx
from flashbax.vault import Vault

BUFFER_TIME_AXIS_LEN = 100_000

class ExperienceRecorder:
    """Wraps an environment and records every transition to a flashbax vault."""

    def __init__(self, environment, vault_name: str, write_to_vault_every=10_000):

self._environment = environment
# self._buffer = fbx.make_trajectory_buffer(
# add_batch_size=1,
# max_length_time_axis=BUFFER_TIME_AXIS_LEN,
# min_length_time_axis=1,
# # Unused, as we are not sampling
# sample_batch_size=1,
# sample_sequence_length=1,
# period=1,
# )
self._buffer = fbx.make_flat_buffer(
max_length=2*10_000,
min_length=1,
# Unused:
sample_batch_size=1,
)
        self._buffer_state = None
        # JIT the buffer add; donating the old state avoids copying it each step.
        self._add_to_buffer = jax.jit(self._buffer.add, donate_argnums=0)

self._vault = None
self.vault_name = vault_name
self._has_initialised = False

self._write_to_vault_every = write_to_vault_every
self._step_count = 0


def _pack_timestep(self, observations, actions, rewards, terminals, truncations, infos):
packed_timestep = {
"observations": observations,
"actions": actions,
"rewards": rewards,
"terminals": terminals,
"truncations": truncations,
"infos": infos,
}
        # Convert every leaf to a numpy array so the buffered data lives on host.
        packed_timestep = jax.tree_map(lambda x: np.array(x), packed_timestep)
return packed_timestep

def reset(self):
observations, infos = self._environment.reset()

self._observations = observations
self._infos = infos

return observations, infos

def step(self, actions):
observations, rewards, terminals, truncations, infos = self._environment.step(actions)

        # Pair the observation that produced these actions with the resulting
        # rewards, terminals and truncations: one (o_t, a_t, r_t, ...) transition.
        packed_timestep = self._pack_timestep(
            observations=self._observations,
            actions=actions,
            rewards=rewards,
            terminals=terminals,
            truncations=truncations,
            infos=self._infos,
        )

# Log stuff to vault/flashbax
if not self._has_initialised:
self._buffer_state = self._buffer.init(packed_timestep)
self._vault = Vault(
vault_name=self.vault_name,
init_fbx_state=self._buffer_state,
)
self._has_initialised = True

self._buffer_state = self._add_to_buffer(
self._buffer_state,
packed_timestep,
# jax.tree_map(lambda x: np.expand_dims(np.expand_dims(np.array(x), axis=0), axis=0), packed_timestep), # NOTE add time dimension and batch dimension. should we use flat buffer?
)

        # Store the latest observations and infos for the next packed timestep.
        self._observations = observations
        self._infos = infos

self._step_count += 1
if self._step_count % self._write_to_vault_every == 0:
self._vault.write(self._buffer_state)

return observations, rewards, terminals, truncations, infos

def __getattr__(self, name: str):
"""Expose any other attributes of the underlying environment."""
if hasattr(self.__class__, name):
return self.__getattribute__(name)
else:
return getattr(self._environment, name)



class Dtype:
    ...
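
A hedged usage sketch for the new recorder. The environment construction and the read-back call are assumptions based on the flashbax Vault API, not part of this commit:

# Usage sketch (not part of the commit).
from flashbax.vault import Vault

from og_marl.environments.gymnasium_mamujoco import MAMuJoCo
from og_marl.environments.wrappers import ExperienceRecorder

env = ExperienceRecorder(MAMuJoCo("2halfcheetah"), vault_name="mamujoco")

observations, info = env.reset()
# Act in the environment as usual; every step() packs one transition into the
# flashbax buffer, and the buffer state is flushed to the vault on disk every
# write_to_vault_every (default 10_000) steps.

# Later, the recorded experience can be read back from disk (an assumption
# based on the Vault API, which reloads the structure from saved metadata):
vault = Vault(vault_name="mamujoco")
experience = vault.read().experience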
2 changes: 1 addition & 1 deletion og_marl/tf2/systems/base.py
@@ -134,7 +134,7 @@ def train_online(self, replay_buffer, max_env_steps=1e6, train_period=20):
                break

            episodes += 1
-            if episodes % 20 == 0:  # TODO: make variable
+            if episodes % 1 == 0:  # TODO: make variable
                self._logger.write({"Episodes": episodes, "Episode Return": episode_return, "Environment Steps": self._env_step_ctr}, force=True)

            if self._env_step_ctr > max_env_steps:
Expand Down

0 comments on commit 0ef4cc2

Please sign in to comment.