Skip to content

Commit

Permalink
CHore: typing functions in combine_vaults.
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeyers committed Aug 20, 2024
1 parent e01afdc commit 3cfffca
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions og_marl/vault_utils/combine_vaults.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import List
from git import Optional

import jax
import pickle
import flashbax as fbx
Expand All @@ -9,7 +12,11 @@
)


def get_all_vaults(rel_dir, vault_name, vault_uids=[]):
def get_all_vaults(
vault_name: str,
vault_uids: Optional[List[str]] = [],
rel_dir: str = "vaults",
) -> List[Vault]:
if len(vault_uids) == 0:
vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")

Expand All @@ -19,7 +26,7 @@ def get_all_vaults(rel_dir, vault_name, vault_uids=[]):
return vlts


def stitch_vault_from_many(vlts, vault_name, vault_uid, rel_dir):
def stitch_vault_from_many(vlts: List[Vault], vault_name: str, vault_uid: str, rel_dir: str) -> int:
all_data = vlts[0].read()
offline_data = all_data.experience

Expand Down Expand Up @@ -63,7 +70,7 @@ def stitch_vault_from_many(vlts, vault_name, vault_uid, rel_dir):
return tot_timesteps


def combine_vaults(rel_dir, vault_name, vault_uids=[]):
def combine_vaults(rel_dir: str, vault_name: str, vault_uids: List[str] = []) -> str:
# check that the vault to be combined exists
if not check_directory_exists_and_not_empty(f"./{rel_dir}/{vault_name}"):
print(f"Vault './{rel_dir}/{vault_name}' does not exist and cannot be combined.")
Expand Down

0 comments on commit 3cfffca

Please sign in to comment.