diff --git a/examples/gsd/_03_gsd_evaluation.py b/examples/gsd/_03_gsd_evaluation.py
index f0fccc187..2b05f6ba4 100644
--- a/examples/gsd/_03_gsd_evaluation.py
+++ b/examples/gsd/_03_gsd_evaluation.py
@@ -7,6 +7,7 @@
 This example shows how to apply evaluation algorithms to GSD and thus how to
 rate the performance of a GSD algorithm.
 """
+# %%
 import pandas as pd
 from mobgap.data import LabExampleDataset
 from mobgap.gsd import GsdIluz
@@ -272,3 +273,160 @@ def load_reference(single_test_data):
 # In general, it is a good idea to use ``cross_validation`` also for algorithms that do not have tunable parameters.
 # This way you can ensure that the performance of the algorithm is stable across different splits of the data, and it
 # allows the direct comparison between tunable and non-tunable algorithms.
+
+
+# %%
+# Calculate the performance index after "Running a full evaluation pipeline"
+# --------------------------------------------------------------------------
+# Bonci et al. (2020) (https://www.mdpi.com/1424-8220/20/22/6509) suggest a methodology to determine a performance
+# index that combines multiple metrics into a single value.
+import numpy as np
+from mobgap.gsd.evaluation import calc_gs_duration_icc, calc_performance_index
+
+# Get a dictionary-like series of the available scoring metrics.
+evaluation_results_dict = evaluation_results.drop(
+    ["single_reference", "single_detected"], axis=1
+).T[0]
+evaluation_results_dict
+
+
+# %%
+# Define the metrics that are used to calculate the performance index.
+# For each metric, the underlying score, the criterion (cost/benefit), the aggregation (e.g., mean, std, ...), and
+# the weight need to be defined.
+# ``weighting_factor_micoamigo`` defines the weights of the metrics as suggested by Micó-Amigo et al.
+# (https://pubmed.ncbi.nlm.nih.gov/37316858/).
+weighting_factor_micoamigo = {
+    "recall_mean": {
+        "metric": "single_recall",
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": np.mean,
+        "weight": 0.117,
+    },
+    "specificity_mean": {
+        "metric": "single_specificity",
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": np.mean,
+        "weight": 0.178,
+    },
+    "precision_mean": {
+        "metric": "single_precision",
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": np.mean,
+        "weight": 0.105,
+    },
+    "accuracy_mean": {
+        "metric": "single_accuracy",
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": np.mean,
+        "weight": 0.160,
+    },
+    "gs_absolute_relative_duration_error_mean": {
+        "metric": "single_gs_absolute_relative_duration_error_log",
+        "criterion": "cost",
+        "normalization": "exponential",
+        "aggregation": np.mean,
+        "weight": 0.122,
+    },
+    "gs_absolute_relative_duration_error_std": {
+        "metric": "single_gs_absolute_relative_duration_error_log",
+        "criterion": "cost",
+        "normalization": "exponential",
+        "aggregation": np.std,
+        "weight": 0.122,
+    },
+    "icc_mean": {
+        "metric": [
+            "single_detected_gs_duration_s",
+            "single_reference_gs_duration_s",
+        ],
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": calc_gs_duration_icc,
+        "weight": 0.196,
+    },
+}
+
+# %%
+# Calculate the performance index.
+performance_index = calc_performance_index(
+    evaluation_results=evaluation_results_dict,
+    weighting_factor=weighting_factor_micoamigo,
+)
+performance_index
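+
+
+# %%
+# To make the weighting scheme transparent, the cell below re-implements by hand the weighted sum that
+# ``calc_performance_index`` computes internally: every metric is normalized according to its criterion,
+# aggregated, and multiplied by its weight.
+# This is purely an illustrative sketch; it imports the private ``_normalize`` helper only to mirror the
+# internal computation.
+from mobgap.gsd.evaluation import _normalize
+
+manual_index = sum(
+    spec["aggregation"](
+        _normalize(
+            evaluation_results_dict[spec["metric"]],
+            criterion=spec["criterion"],
+            normalization=spec["normalization"],
+        )
+    )
+    * spec["weight"]
+    for spec in weighting_factor_micoamigo.values()
+)
+manual_index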
+
+
+# %%
+# Define the metrics that are used to calculate the performance index.
+# For each metric, the underlying score, the criterion (cost/benefit), the aggregation (e.g., mean, std, ...), and
+# the weight need to be defined.
+# ``weighting_factor_kluge`` defines the weights of the metrics as suggested by Kluge et al.
+# (https://doi.org/10.2196/50035) (see also Multimedia Appendix 1).
+weighting_factor_kluge = {
+    "recall_mean": {
+        "metric": "single_recall",
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": np.mean,
+        "weight": 0.100,
+    },
+    "specificity_mean": {
+        "metric": "single_specificity",
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": np.mean,
+        "weight": 0.151,
+    },
+    "precision_mean": {
+        "metric": "single_precision",
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": np.mean,
+        "weight": 0.089,
+    },
+    "accuracy_mean": {
+        "metric": "single_accuracy",
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": np.mean,
+        "weight": 0.135,
+    },
+    "gs_absolute_relative_duration_error_mean": {
+        "metric": "single_gs_absolute_relative_duration_error_log",
+        "criterion": "cost",
+        "normalization": "exponential",
+        "aggregation": np.mean,
+        "weight": 0.104,
+    },
+    "gs_absolute_relative_duration_error_std": {
+        "metric": "single_gs_absolute_relative_duration_error_log",
+        "criterion": "cost",
+        "normalization": "exponential",
+        "aggregation": np.std,
+        "weight": 0.104,
+    },
+    "icc_mean": {
+        "metric": [
+            "single_detected_gs_duration_s",
+            "single_reference_gs_duration_s",
+        ],
+        "criterion": "benefit",
+        "normalization": None,
+        "aggregation": calc_gs_duration_icc,
+        "weight": 0.167,
+    },
+    "gs_nr_mean": {
+        "metric": "single_num_gs_absolute_relative_error_log",
+        "criterion": "cost",
+        "normalization": "exponential",
+        "aggregation": np.mean,
+        "weight": 0.150,
+    },
+}
+
+performance_index = calc_performance_index(
+    evaluation_results=evaluation_results_dict,
+    weighting_factor=weighting_factor_kluge,
+)
+performance_index
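+
+
+# %%
+# The two weighting schemes rate the same evaluation results slightly differently.
+# As a quick illustrative comparison, we compute both indices side by side:
+pd.Series(
+    {
+        "Micó-Amigo et al.": calc_performance_index(
+            evaluation_results=evaluation_results_dict,
+            weighting_factor=weighting_factor_micoamigo,
+        ),
+        "Kluge et al.": calc_performance_index(
+            evaluation_results=evaluation_results_dict,
+            weighting_factor=weighting_factor_kluge,
+        ),
+    },
+    name="performance_index",
+)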
diff --git a/mobgap/gsd/evaluation.py b/mobgap/gsd/evaluation.py
index b885d70a1..675ddc61c 100644
--- a/mobgap/gsd/evaluation.py
+++ b/mobgap/gsd/evaluation.py
@@ -10,6 +10,7 @@
 from intervaltree.interval import Interval
 from matplotlib.axes import Axes
 from matplotlib.figure import Figure
+from pingouin import intraclass_corr
 from typing_extensions import Unpack
 
 from mobgap.utils.evaluation import (
@@ -93,7 +94,12 @@
 
     # estimate performance metrics
     precision_recall_f1 = precision_recall_f1_score(matches, zero_division=zero_division)
-    gsd_metrics = {"tp_samples": tp_samples, "fp_samples": fp_samples, "fn_samples": fn_samples, **precision_recall_f1}
+    gsd_metrics = {
+        "tp_samples": tp_samples,
+        "fp_samples": fp_samples,
+        "fn_samples": fn_samples,
+        **precision_recall_f1,
+    }
 
     # tn-dependent metrics
     if tn_samples != 0:
@@ -228,7 +234,10 @@
 
 
 def categorize_intervals(
-    *, gsd_list_detected: pd.DataFrame, gsd_list_reference: pd.DataFrame, n_overall_samples: Optional[int] = None
+    *,
+    gsd_list_detected: pd.DataFrame,
+    gsd_list_reference: pd.DataFrame,
+    n_overall_samples: Optional[int] = None,
 ) -> pd.DataFrame:
     """
     Evaluate detected gait sequence intervals against a reference on a sample-wise level.
@@ -367,7 +376,10 @@ def _check_input_sanity(
         raise TypeError("`gsd_list_detected` and `gsd_list_reference` must be of type `pandas.DataFrame`.")
     # check if start and end columns are present
     try:
-        detected, reference = gsd_list_detected[["start", "end"]], gsd_list_reference[["start", "end"]]
+        detected, reference = (
+            gsd_list_detected[["start", "end"]],
+            gsd_list_reference[["start", "end"]],
+        )
     except KeyError as e:
         raise ValueError(
             "`gsd_list_detected` and `gsd_list_reference` must have columns named 'start' and 'end'."
@@ -403,7 +415,10 @@ def _get_false_matches_from_overlap_data(overlaps: list[Interval], interval: Int
 
 
 def find_matches_with_min_overlap(
-    *, gsd_list_detected: pd.DataFrame, gsd_list_reference: pd.DataFrame, overlap_threshold: float = 0.8
+    *,
+    gsd_list_detected: pd.DataFrame,
+    gsd_list_reference: pd.DataFrame,
+    overlap_threshold: float = 0.8,
 ) -> pd.DataFrame:
     """
     Find all matches of `gsd_list_detected` in `gsd_list_reference` with at least ``overlap_threshold`` overlap.
@@ -516,15 +531,35 @@ def _get_tn_intervals(categorized_intervals: pd.DataFrame, n_overall_samples: Un
 
 
 def plot_categorized_intervals(
-    gsd_list_detected: pd.DataFrame, gsd_list_reference: pd.DataFrame, categorized_intervals: pd.DataFrame
+    gsd_list_detected: pd.DataFrame,
+    gsd_list_reference: pd.DataFrame,
+    categorized_intervals: pd.DataFrame,
 ) -> Figure:
     """Plot the categorized intervals together with the detected and reference intervals."""
     fig, ax = plt.subplots(figsize=(10, 3))
     _plot_intervals_from_df(gsd_list_reference, 3, ax, color="orange")
     _plot_intervals_from_df(gsd_list_detected, 2, ax, color="blue")
-    _plot_intervals_from_df(categorized_intervals.query("match_type == 'tp'"), 1, ax, color="green", label="TP")
-    _plot_intervals_from_df(categorized_intervals.query("match_type == 'fp'"), 1, ax, color="red", label="FP")
-    _plot_intervals_from_df(categorized_intervals.query("match_type == 'fn'"), 1, ax, color="purple", label="FN")
+    _plot_intervals_from_df(
+        categorized_intervals.query("match_type == 'tp'"),
+        1,
+        ax,
+        color="green",
+        label="TP",
+    )
+    _plot_intervals_from_df(
+        categorized_intervals.query("match_type == 'fp'"),
+        1,
+        ax,
+        color="red",
+        label="FP",
+    )
+    _plot_intervals_from_df(
+        categorized_intervals.query("match_type == 'fn'"),
+        1,
+        ax,
+        color="purple",
+        label="FN",
+    )
     plt.yticks([1, 2, 3], ["Categorized", "Detected", "Reference"])
     plt.ylim(0, 4)
     plt.xlabel("Index")
@@ -546,10 +581,126 @@ def _plot_intervals_from_df(df: pd.DataFrame, y: int, ax: Axes, **kwargs: Unpack
         ax.hlines(y, row["start"], row["end"], lw=20, **kwargs)
 
 
+def _normalize(
+    x: np.ndarray,
+    criterion: Literal["benefit", "cost"] = "benefit",
+    normalization: Literal["minmax", "sigmoid", "exponential", None] = None,
+) -> np.ndarray:
+    """
+    Normalize a given array of values following the methodology of Bonci et al. (2020).
+
+    Parameters
+    ----------
+    x : array-like
+        The input array to be normalized.
+    criterion : str, optional
+        Whether the metric is interpreted as a "cost" or a "benefit" (default).
+    normalization : str, optional
+        Which normalization to perform.
+        Valid options are "minmax", "sigmoid", "exponential", or None (default).
+
+    Returns
+    -------
+    array-like
+        The normalized array.
+
+    Raises
+    ------
+    ValueError
+        If the criterion is not specified as either 'benefit' or 'cost'.
+
+    Examples
+    --------
+    >>> x = [1, 2, 3, 4, 5]
+    >>> _normalize(x, normalization="minmax")
+    array([0.  , 0.25, 0.5 , 0.75, 1.  ])
+
+    >>> _normalize(x, criterion="benefit", normalization="sigmoid")
+    array([0.73105858, 0.88079708, 0.95257413, 0.98201379, 0.99330715])
+
+    >>> _normalize(x, criterion="cost", normalization="sigmoid")
+    array([0.26894142, 0.11920292, 0.04742587, 0.01798621, 0.00669285])
+
+    >>> _normalize(x, criterion="benefit", normalization="exponential")
+    array([0.63212056, 0.86466472, 0.95021293, 0.98168436, 0.99326205])
+    """
+    x = np.array(x)
+
+    if normalization == "minmax":
+        # Scale linearly to [0, 1]. Note: this assumes x is not constant.
+        x_norm = (x - min(x)) / (max(x) - min(x))
+    elif normalization == "sigmoid":
+        x_norm = 1 / (1 + np.exp(-x))
+    elif normalization == "exponential":
+        x_norm = 1 - np.exp(-x)
+    else:
+        x_norm = x
+
+    # For cost metrics, invert the normalized values so that higher is always better.
+    if criterion == "benefit":
+        x_criterion = x_norm
+    elif criterion == "cost":
+        x_criterion = 1 - x_norm
+    else:
+        raise ValueError("criterion needs to be specified as either 'benefit' or 'cost'.")
+
+    return x_criterion
+
+
+def calc_gs_duration_icc(x: np.ndarray) -> float:
+    """Calculate the Intraclass Correlation Coefficient (ICC) of the gait sequence durations of two systems.
+
+    Expects ``x`` to contain the per-trial gait sequence durations of the two systems
+    (e.g., detected and reference) and returns the ICC(1) as computed by
+    :func:`pingouin.intraclass_corr`.
+    """
+    # Bring the data into long format: one row per system and trial.
+    x_df = (
+        pd.DataFrame(x)
+        .rename(columns={0: "duration_s"})
+        .assign(trial_id=lambda df_: df_.duration_s.map(lambda durations: range(1, len(durations) + 1)))
+        .explode(["duration_s", "trial_id"])
+        .assign(
+            duration_s=lambda df_: df_.duration_s.astype(float),
+            trial_id=lambda df_: df_.trial_id.astype(int),
+        )
+        .rename_axis("system")
+        .reset_index()
+    )
+    # Calculate the ICC with the systems as raters and the trials as targets
+    x_icc = intraclass_corr(x_df, ratings="duration_s", raters="system", targets="trial_id")
+
+    # Return the ICC(1) value (first row of the pingouin result)
+    return x_icc.loc[0].ICC
+
+
+def calc_performance_index(evaluation_results: dict, weighting_factor: dict) -> float:
+    """
+    Calculate the performance index based on evaluation results and weighting factors.
+
+    The performance index is the weighted sum of the metrics defined in ``weighting_factor``,
+    each normalized and aggregated as specified.
+
+    Parameters
+    ----------
+    evaluation_results : dict
+        A dictionary containing the evaluation results for different metrics.
+    weighting_factor : dict
+        A dictionary mapping each component of the performance index to its "metric", "criterion",
+        "normalization", "aggregation", and "weight".
+
+    Returns
+    -------
+    float
+        The calculated performance index.
+
+    """
+    performance_index = sum(
+        weighting_factor[key]["aggregation"](
+            _normalize(
+                evaluation_results[weighting_factor[key]["metric"]],
+                criterion=weighting_factor[key]["criterion"],
+                normalization=weighting_factor[key]["normalization"],
+            )
+        )
+        * weighting_factor[key]["weight"]
+        for key in weighting_factor
+    )
+
+    return performance_index
+
+
 __all__ = [
     "categorize_intervals",
     "find_matches_with_min_overlap",
     "calculate_matched_gsd_performance_metrics",
     "calculate_unmatched_gsd_performance_metrics",
     "plot_categorized_intervals",
+    "calc_gs_duration_icc",
+    "calc_performance_index",
 ]
diff --git a/pyproject.toml b/pyproject.toml
index 94f74c12e..d898d2f06 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ pywavelets = ">=1.5.0"
 pooch = ">=1.8.1"
 # TODO: Remove this dependency at some point
 gaitmap = ">=2.5.0"
+pingouin = "^0.5.4"
 
 [tool.poetry.group.dev.dependencies]
 poethepoet = "^0.22.0"
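
For reviewers: a minimal, self-contained sketch of the ICC computation that the new
``calc_gs_duration_icc`` helper performs via ``pingouin`` (the duration values below are
toy data for illustration only):

    import pandas as pd
    from pingouin import intraclass_corr

    # Long-format frame: every gait sequence duration "rated" by the two systems
    # (detected vs. reference), the same shape calc_gs_duration_icc builds internally.
    durations_long = pd.DataFrame(
        {
            "system": ["detected"] * 4 + ["reference"] * 4,
            "trial_id": [1, 2, 3, 4] * 2,
            "duration_s": [10.2, 8.5, 12.1, 9.8, 10.0, 9.0, 12.0, 10.0],
        }
    )
    icc = intraclass_corr(durations_long, targets="trial_id", raters="system", ratings="duration_s")
    print(icc.loc[0, "ICC"])  # ICC(1), the value calc_gs_duration_icc returns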