Skip to content

Commit

Permalink
Rework of LR pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Apr 15, 2024
1 parent 90bda72 commit 849d280
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 182 deletions.
7 changes: 4 additions & 3 deletions mobgap/gsd/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base class for GSD detectors."""

from collections.abc import Iterable
from typing import Any, Union

import pandas as pd
Expand Down Expand Up @@ -97,10 +98,10 @@ def detect(self, data: pd.DataFrame, *, sampling_rate_hz: float, **kwargs: Unpac

def self_optimize(
self,
data: list[pd.DataFrame],
reference_gsd_list: list[pd.DataFrame],
data: Iterable[pd.DataFrame],
reference_gsd_list: Iterable[pd.DataFrame],
*,
sampling_rate_hz: Union[float, list[float]],
sampling_rate_hz: Union[float, Iterable[float]],
**kwargs: Unpack[dict[str, Any]],
) -> Self:
"""Optimize the internal parameters of the algorithm.
Expand Down
10 changes: 5 additions & 5 deletions mobgap/gsd/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from intervaltree.interval import Interval
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from tpcp import OptimizablePipeline
from tpcp import OptimizableParameter, OptimizablePipeline
from tpcp.validate import Aggregator, NoAgg
from typing_extensions import Self, Unpack

Expand Down Expand Up @@ -532,7 +532,7 @@ class GsdEvaluationPipeline(OptimizablePipeline[BaseGaitDatasetWithReference]):
"""

algo: BaseGsDetector
algo: OptimizableParameter[BaseGsDetector]

algo_: BaseGsDetector

Expand Down Expand Up @@ -590,9 +590,9 @@ def self_optimize(self, dataset: BaseGaitDatasetWithReference, **kwargs: Unpack[
The pipeline instance with the optimized GSD algorithm.
"""
all_data = [d.data_ss for d in dataset]
reference_wbs = [d.reference_parameters_.wb_list for d in dataset]
sampling_rate_hz = [d.sampling_rate_hz for d in dataset]
all_data = (d.data_ss for d in dataset)
reference_wbs = (d.reference_parameters_.wb_list for d in dataset)
sampling_rate_hz = (d.sampling_rate_hz for d in dataset)

self.algo.self_optimize(all_data, reference_wbs, sampling_rate_hz=sampling_rate_hz, **kwargs)

Expand Down
94 changes: 0 additions & 94 deletions mobgap/lrd/_lr_optimizer.py

This file was deleted.

66 changes: 29 additions & 37 deletions mobgap/lrd/_lrd_ml.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import typing
from functools import cache
from importlib.resources import files
from typing import Any, Optional
from typing import Any

import joblib
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, is_classifier
from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import MinMaxScaler
from sklearn.utils.validation import check_is_fitted
from tpcp import cf
from tpcp.misc import classproperty, set_defaults
Expand All @@ -17,6 +16,7 @@
from mobgap.data_transform import ButterworthFilter
from mobgap.data_transform.base import BaseFilter
from mobgap.lrd.base import BaseLRDetector, base_lrd_docfiller
from mobgap.utils._sklearn_protocol_types import SklearnClassifier, SklearnScaler


@cache
Expand Down Expand Up @@ -75,8 +75,8 @@ class LrdUllrich(BaseLRDetector):
available at: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9630653
"""

model: Optional[ClassifierMixin]
scaler: Optional[MinMaxScaler]
model: SklearnClassifier
scaler: SklearnScaler
smoothing_filter: BaseFilter

feature_matrix_: pd.DataFrame
Expand All @@ -97,8 +97,8 @@ class PredefinedParameters:

class _ModelConfig(TypedDict):
smoothing_filter: BaseFilter
model: ClassifierMixin
scaler: TransformerMixin
model: SklearnClassifier
scaler: SklearnScaler

@classmethod
def _load_model_config(cls, model_name: str) -> _ModelConfig:
Expand Down Expand Up @@ -136,13 +136,21 @@ def msproject_ms(cls) -> _ModelConfig: # noqa: N805
def __init__(
self,
smoothing_filter: BaseFilter,
model: ClassifierMixin,
scaler: MinMaxScaler,
model: SklearnClassifier,
scaler: SklearnScaler,
) -> None:
self.smoothing_filter = smoothing_filter
self.model = model
self.scaler = scaler

def _check_model(self, model: SklearnClassifier) -> None:
if not isinstance(model, ClassifierMixin) or not isinstance(model, BaseEstimator) or not is_classifier(model):
raise TypeError(
f"Unknown model type {type(model).__name__}."
"The model must inherit from ClassifierDtype and BaseEstimator. "
"Any valid scikit-learn classifier can be used."
)

@base_lrd_docfiller
def detect(
self,
Expand Down Expand Up @@ -176,16 +184,7 @@ def detect(
self.feature_matrix_ = pd.DataFrame(columns=self.feature_matrix_.columns)
return self

if (
not isinstance(self.model, ClassifierMixin)
or not isinstance(self.model, BaseEstimator)
and is_classifier(self.model)
):
raise TypeError(
f"Unknown model type {type(self.model).__name__}."
"The model must inherit from ClassifierMixin and BaseEstimator. "
"Any valid scikit-learn classifier can be used."
)
self._check_model(self.model)

# create a copy of ic_list, otherwise, they will get modified when adding the predicted labels
# We also remove the "lr_label" column, if it exists, to avoid conflicts
Expand All @@ -212,10 +211,11 @@ def detect(

def self_optimize(
self,
data_list: list[pd.DataFrame],
ic_list: list[pd.DataFrame],
label_list: list[pd.DataFrame],
sampling_rate_hz: float,
data: typing.Iterable[pd.DataFrame],
ic_list: typing.Iterable[pd.DataFrame],
reference_ic_lr_list: typing.Iterable[pd.DataFrame],
*,
sampling_rate_hz: typing.Union[float, typing.Iterable[float]],
) -> Self:
"""
Model optimization method based on the provided gait data, initial contact list, and the reference label list.
Expand All @@ -234,27 +234,19 @@ def self_optimize(
Returns
-------
self
The instance of the LrdUllrich class.
Optimized instance of the provided class.
"""
if not isinstance(data_list, list):
raise TypeError("'data' must be a list of pandas DataFrames")
self._check_model(self.model)

if not isinstance(ic_list, list):
raise TypeError("'ic_list' must be a list of pandas DataFrame")

if not isinstance(label_list, list):
raise TypeError("'label_list' must be a list of pandas DataFrame")

features = [self.extract_features(data, ic, sampling_rate_hz) for data, ic in zip(data_list, ic_list)]
features = [self.extract_features(data, ic, sampling_rate_hz) for data, ic in zip(data, ic_list)]
all_features = pd.concat(features, axis=0, ignore_index=True) if len(features) > 1 else features[0]
all_features = pd.DataFrame(
self.scaler.fit_transform(all_features.values), columns=all_features.columns, index=all_features.index
)
all_features = self.scaler.fit_transform(all_features)

# Concatenate the labels if there is more than one GS
label_list = [ic_lr_list["lr_label"] for ic_lr_list in reference_ic_lr_list]
all_labels = pd.concat(label_list, axis=0, ignore_index=True) if len(label_list) > 1 else label_list[0]

self.model.fit(all_features.to_numpy(), np.ravel(all_labels))
self.model.fit(all_features, all_labels)

return self

Expand Down
42 changes: 0 additions & 42 deletions mobgap/lrd/_utils.py

This file was deleted.

Loading

0 comments on commit 849d280

Please sign in to comment.