Skip to content

Commit

Permalink
add color maps for splits
Browse files Browse the repository at this point in the history
  • Loading branch information
adamovanja committed Apr 29, 2024
1 parent 61caa7b commit 3d1b155
Showing 1 changed file with 38 additions and 4 deletions.
42 changes: 38 additions & 4 deletions q2_ritme/evaluate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@
from q2_ritme.feature_space.transform_features import transform_features
from q2_ritme.model_space._static_trainables import NeuralNet

plt.rcParams.update({"font.family": "DejaVu Sans"})
plt.style.use("seaborn-v0_8-pastel")

# custom color map
color_map = {
"train": "lightskyblue",
"test": "peachpuff",
"rmse_train": "lightskyblue",
"rmse_val": "plum",
}


def _get_checkpoint_path(result: Result) -> str:
"""
Expand Down Expand Up @@ -166,7 +177,15 @@ def plot_rmse_over_experiments(preds_dic, save_loc, dpi=400):

plt.figure(dpi=dpi) # Increase the resolution by setting a higher dpi
rmse_df = pd.DataFrame(rmse_dic).T
rmse_df.plot(kind="bar", title="Overall", ylabel="RMSE")
rmse_df = rmse_df[
sorted(rmse_df.columns, key=lambda x: 0 if "train" in x else 1)
] # Enforce column order
rmse_df.plot(
kind="bar",
title="Overall",
ylabel="RMSE",
color=[color_map.get(col, "gray") for col in rmse_df.columns],
)
path_to_save = os.path.join(save_loc, "rmse_over_experiments_train_test.png")
plt.tight_layout()
plt.savefig(path_to_save, dpi=dpi)
Expand Down Expand Up @@ -195,10 +214,19 @@ def plot_rmse_over_time(preds_dic, ls_model_types, save_loc, dpi=300):
if split is not None:
grouped_df = grouped_df[[split]].copy()

# Enforce column order
grouped_df = grouped_df[
sorted(grouped_df.columns, key=lambda x: 0 if "train" in x else 1)
]

# Plot
plt.figure(dpi=dpi)
grouped_df.plot(
kind="bar", title=f"Model: {model_type}", ylabel="RMSE", figsize=(10, 5)
kind="bar",
title=f"Model: {model_type}",
ylabel="RMSE",
figsize=(10, 5),
color=[color_map.get(col, "gray") for col in grouped_df.columns],
)
path_to_save = os.path.join(
save_loc, f"rmse_over_time_train_test_{model_type}.png"
Expand Down Expand Up @@ -261,9 +289,15 @@ def plot_best_models_comparison(
title (str): Title of the plot.
dpi (int): Resolution of the plot.
"""
plt.style.use("seaborn-v0_8-colorblind")
df2plot = df_metrics.sort_values(by="rmse_val", ascending=True)
df2plot.plot(kind="bar", figsize=(12, 6))
df2plot = df2plot[
sorted(df2plot.columns, key=lambda x: 0 if "train" in x else 1)
] # Enforce column order
df2plot.plot(
kind="bar",
figsize=(12, 6),
color=[color_map.get(col, "gray") for col in df2plot.columns],
)
plt.xticks(rotation=45)
plt.ylabel("RMSE")
plt.xlabel("Model type (order: increasing val score)")
Expand Down

0 comments on commit 3d1b155

Please sign in to comment.