Skip to content

Commit

Permalink
Structure descriptors take term or trunc as dones.
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeyers committed Sep 11, 2024
1 parent 18819c7 commit 1fd731d
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions og_marl/vault_utils/analyse_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from og_marl.vault_utils.download_vault import get_available_uids
from tabulate import tabulate

def extend_dims(jnp_arr):
return


def get_structure_descriptors(
experience: Dict[str, Array], n_head: int = 1, done_flags: list = ["terminals"],
Expand All @@ -32,8 +35,9 @@ def get_structure_descriptors(

head = jax.tree_map(lambda x: x[0, :n_head, ...], experience)

# allow for "terminals" and "truncations" to be combined into one "done"
if len(done_flags==1):
terminal_flag = experience[done_flags[0]][0, :, ...].all(axis=-1)
terminal_flag = experience[done_flags[0]][0, :, ...].all(axis=-1) # .all is for all agents
elif len(done_flags==2):
done_1 = experience[done_flags[0]][0, :, ...].all(axis=-1)
done_2 = experience[done_flags[1]][0, :, ...].all(axis=-1)
Expand All @@ -53,6 +57,7 @@ def describe_structure(
vault_uids: Optional[List[str]] = None,
rel_dir: str = "vaults",
n_head: int = 0,
done_flags: list = ["terminals"],
) -> Dict[str, Array]:
# get all uids if not specified
if vault_uids is None:
Expand All @@ -67,7 +72,7 @@ def describe_structure(
exp = vlt.read().experience
n_trans = exp["actions"].shape[1]

struct, head, n_traj = get_structure_descriptors(exp, n_head)
struct, head, n_traj = get_structure_descriptors(exp, n_head, done_flags)

print(str(uid) + "\n-----")
for key, val in struct.items():
Expand Down

0 comments on commit 1fd731d

Please sign in to comment.