Skip to content

Commit

Permalink
Chore: function typing for analyse_vaults.
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeyers committed Aug 20, 2024
1 parent eb43b4a commit 329323d
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions og_marl/vault_utils/analyse_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Tuple

import jax
import jax.numpy as jnp
Expand All @@ -26,14 +26,15 @@
from tabulate import tabulate


# obtain "sanity" info about the vault -
def get_structure_descriptors(experience, n_head=1):
def get_structure_descriptors(
experience: Dict[str, Array], n_head: int = 1
) -> Tuple[Dict[str, Array], Dict[str, Array], int]:
struct = jax.tree_map(lambda x: x.shape, experience)

head = jax.tree_map(lambda x: x[0, :n_head, ...], experience)

terminal_flag = experience["terminals"][0, :, ...].all(axis=-1)
num_episodes = jnp.sum(terminal_flag)
num_episodes = int(jnp.sum(terminal_flag))

return struct, head, num_episodes

Expand Down Expand Up @@ -74,7 +75,9 @@ def describe_structure(
return heads


def get_episode_return_descriptors(experience):
def get_episode_return_descriptors(
experience: Dict[str, Array],
) -> Tuple[float, float, float, float, Array]:
episode_returns = calculate_returns(experience)

mean = jnp.mean(episode_returns)
Expand All @@ -83,7 +86,9 @@ def get_episode_return_descriptors(experience):
return mean, stddev, jnp.max(episode_returns), jnp.min(episode_returns), episode_returns


def plot_eps_returns_violin(all_uid_returns, vault_name, save_path=""):
def plot_eps_returns_violin(
all_uid_returns: Dict[str, Array], vault_name: str, save_path: str = ""
) -> None:
sns.set_theme(style="whitegrid") # Set seaborn theme with a light blue background
plt.figure(figsize=(8, 6)) # Adjust figsize as needed

Expand All @@ -99,8 +104,13 @@ def plot_eps_returns_violin(all_uid_returns, vault_name, save_path=""):


def plot_eps_returns_hist(
all_uid_returns, vault_name, n_bins, min_return, max_return, save_path=""
):
all_uid_returns: Dict[str, Array],
vault_name: str,
n_bins: int,
min_return: float,
max_return: float,
save_path: str = "",
) -> None:
vault_uids = list(all_uid_returns.keys())

sns.set_theme(style="whitegrid") # Set seaborn theme with a light blue background
Expand Down Expand Up @@ -141,7 +151,7 @@ def describe_episode_returns(
plot_violin: bool = True,
save_violin: bool = False,
n_bins: Optional[int] = 50,
) -> Dict[str, Array]:
) -> None:
# get all uids if not specified
if len(vault_uids) == 0:
vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")
Expand Down Expand Up @@ -236,7 +246,7 @@ def scan_fn(carry: Array, inputs: Array) -> Array:
return episode_returns


def get_saco(experience: Dict[str, Array]):
def get_saco(experience: Dict[str, Array]) -> Tuple[float, Array, Array]:
"""Calculate the joint SACo in a dataset of experience.
Args:
Expand All @@ -261,7 +271,9 @@ def get_saco(experience: Dict[str, Array]):
return saco, count_vals, count_freq


def plot_count_frequencies(all_count_vals, all_count_freq, save_path=""):
def plot_count_frequencies(
all_count_vals: Dict[str, Array], all_count_freq: Dict[str, Array], save_path: str = ""
) -> None:
vault_uids = list(all_count_vals.keys())
colors = sns.color_palette()

Expand Down Expand Up @@ -291,7 +303,7 @@ def describe_coverage(
rel_dir: str = "vaults",
plot_count_freq: bool = True,
save_plot: bool = False,
) -> Dict[str, Array]:
) -> None:
# get all uids if not specified
if len(vault_uids) == 0:
vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")
Expand Down

0 comments on commit 329323d

Please sign in to comment.