From 78177ec08952967a608446283f8c71a690ef79b1 Mon Sep 17 00:00:00 2001 From: dmegaritis Date: Fri, 2 Feb 2024 11:01:18 +0000 Subject: [PATCH] first implementation of HKLee algo --- examples/icd/_03_hklee_algo.py | 100 ++++++++++ gaitlink/icd/_hklee_algo_improved.py | 216 ++++++++++++++++++++++ tests/test_icd/test_icd_hklee_improved.py | 52 ++++++ 3 files changed, 368 insertions(+) create mode 100644 examples/icd/_03_hklee_algo.py create mode 100644 gaitlink/icd/_hklee_algo_improved.py create mode 100644 tests/test_icd/test_icd_hklee_improved.py diff --git a/examples/icd/_03_hklee_algo.py b/examples/icd/_03_hklee_algo.py new file mode 100644 index 000000000..288ef4fe6 --- /dev/null +++ b/examples/icd/_03_hklee_algo.py @@ -0,0 +1,100 @@ + +""" +HKLee algo +========= + +This example shows how to use the improved HKLee algorithm and some examples on how the results compare to the original +matlab implementation. + +""" + +import pandas as pd +from matplotlib import pyplot as plt +from gaitlink.ICD._hklee_algo_improved import IcdHKLeeImproved +from gaitlink.data import LabExampleDataset + +# %% +# Loading data +# ------------ +# .. note :: More infos about data loading can be found in the :ref:`data loading example `. +# We load example data from the lab dataset together with the INDIP reference system. +# We will use the INDIP output for initial contacts ("ic") as ground truth. + +example_data = LabExampleDataset(reference_system="INDIP", reference_para_level="wb") + +single_test = example_data.get_subset(cohort="HA", participant_id="001", test="Test11", trial="Trial1") +imu_data = single_test.data["LowerBack"] +reference_wbs = single_test.reference_parameters_.wb_list + +sampling_rate_hz = single_test.sampling_rate_hz +ref_ics = single_test.reference_parameters_.ic_list + +reference_wbs +# %% +# Applying the algorithm +# ---------------------- +# Below we apply the shin algorithm to a lab trial. +# We will use the `GsIterator` to iterate over the gait sequences and apply the algorithm to each wb. +from gaitlink.pipeline import GsIterator + +iterator = GsIterator() + +for (gs, data), result in iterator.iterate(imu_data, reference_wbs): + result.initial_contacts = IcdHKLeeImproved().detect(data, sampling_rate_hz=sampling_rate_hz).ic_list_ + +detected_ics = iterator.initial_contacts_ + +detected_ics +# %% +# Matlab Outputs +# -------------- +# To check if the algorithm was implemented correctly, we compare the results to the matlab implementation. +import json + +from gaitlink import PACKAGE_ROOT + + +def load_matlab_output(datapoint): + p = datapoint.group_label + with ( + PACKAGE_ROOT.parent + / f"example_data/original_results/icd_shin_improved/lab/{p.cohort}/{p.participant_id}/SD_Output.json" + ).open() as f: + original_results = json.load(f)["SD_Output"][p.time_measure][p.test][p.trial]["SU"]["LowerBack"]["SD"] + + if not isinstance(original_results, list): + original_results = [original_results] + + ics = {} + for i, gs in enumerate(original_results): + ics[i] = pd.DataFrame({"ic": gs["IC"]}).rename_axis(index="ic_id") + + return (pd.concat(ics, names=["wb_id", ics[0].index.name]) * datapoint.sampling_rate_hz).astype(int) + + +detected_ics_matlab = load_matlab_output(single_test) +detected_ics_matlab +# %% +# Plotting the results +# -------------------- +# With that we can compare the python, matlab and ground truth results. +# We zoom in into one of the gait sequences to better see the output. +# +# We can make a couple of main observations: +# +# 1. The python version finds the same ICs as the matlab version, but wil a small shift to the left (around 5-10 +# samples/50-100 ms). +# This is likely due to some differences in the downsampling process. +# 2. Compared to the ground truth reference, both versions detect the IC too early most of the time. +# 3. Both algorithms can not detect the first IC of the gait sequence. +# However, this is expected, as per definition, this first IC marks the start of the WB in the reference system. +# Hence, there are not samples before that point the algorithm can use to detect the IC. + +imu_data.reset_index(drop=True).plot(y="acc_x") + +plt.plot(ref_ics["ic"], imu_data["acc_x"].iloc[ref_ics["ic"]], "o", label="ref") +#plt.plot(detected_ics["ic"], imu_data["acc_x"].iloc[detected_ics["ic"]], "x", label="hklee_algo_py") +plt.plot(detected_ics_matlab["ic"], imu_data["acc_x"].iloc[detected_ics_matlab["ic"]], "+", label="hklee_algo_matlab") +plt.xlim(reference_wbs.iloc[2]["start"] - 50, reference_wbs.iloc[2]["end"] + 50) +plt.legend() +plt.show() diff --git a/gaitlink/icd/_hklee_algo_improved.py b/gaitlink/icd/_hklee_algo_improved.py new file mode 100644 index 000000000..7fa65e995 --- /dev/null +++ b/gaitlink/icd/_hklee_algo_improved.py @@ -0,0 +1,216 @@ +from typing import Any, Literal + +import numpy as np +import pandas as pd +from gaitmap.data_transform import Resample +from gaitmap.utils.array_handling import bool_array_to_start_end_array +from numpy.linalg import norm +from pywt import cwt +from scipy.ndimage import gaussian_filter, grey_closing, grey_opening +from scipy.signal import savgol_filter +from typing_extensions import Self, Unpack + +from gaitlink.data_transform import EpflDedriftedGaitFilter, EpflGaitFilter +from gaitlink.ICD.base import BaseIcDetector, base_icd_docfiller + + +@base_icd_docfiller +class IcdHKLeeImproved(BaseIcDetector): + + """Detect initial contacts using the HKLee [1]_ algorithm, with improvements by Ionescu et al. [2]_. + + This algorithm is designed to detect initial contacts from accelerometer signals within a gait sequence. + The algorithm filters the accelerometer signal down to its primary frequency components + and then employs morphological operations with closing and opening structural elements + to detect signal closings and openings, respectively. + Their difference is analyzed to identify instances where R is greater than 0. + These regions are interpreted as initial contacts. + + This is based on the implementation published as part of the mobilised project [3]_. + However, this implementation deviates from the original implementation in some places. + For details, see the notes section and the examples. + + Parameters + ---------- + axis + selecting which part of the accelerometer signal to be used. Can be 'x', 'y', 'z', or 'norm'. + The default is 'norm', which is also the default in the original implementation. + + Other Parameters + ---------------- + %(other_parameters)s + + Attributes + ---------- + %(ic_list_)s + final_filtered_signal_ (upsampled again in HKLee) + The downsampled signal after all filter steps. + This might be useful for debugging. + ic_list_internal_ + The initial contacts detected on the downsampled signal, before upsampling to the original sampling rate. + This can be useful for debugging in combination with the `final_filtered_signal_` attribute. + + + Notes + ----- + Points of deviation from the original implementation and their reasons: + + - Configurable accelerometer signal: on matlab, all axes are used to calculate ICs, here we provide + the option to select which axis to use. Despite the fact that the Shin algorithm on matlab uses all axes, + here we provide the option of selecting a single axis because other contact detection algorithms use only the + horizontal axis. + - We use a different downsampling method, which should be "more" correct from a signal theory perspective, + but will yield slightly different results. + - only in case the upsampling will be removed #The matlab code upsamples to 120 Hz before the final morphological operations. + #We skip the upsampling of the filtered signal and perform the morphological operations on the downsampled signal. + #To compensate for the "loss of accuracy" due to the downsampling, we use linear interpolation to determine the + #exact position of the 0-crossing, even when it occurs between two samples. + #We then project the interpolated index back to the original sampling rate. + - For CWT and gaussian filter, the actual parameter we pass to the respective functions differ from the matlab + implementation, as the two languages use different implementations of the functions. + However, the similarity of the output was visually confirmed. + - All parameters are expressed in the units used in gaitlink, as opposed to matlab. + Specifically, we use m/s^2 instead of g. + - #Some early testing indicates, that the python version finds all ICs 5-10 samples earlier than the matlab version. + #However, this seems to be a relatively consistent offset. + #Hence, it might be possible to shift/tune this in the future. + + .. [1] Lee, H-K., et al. "Computational methods to detect step events for normal and pathological + gait evaluation using accelerometer." Electronics letters 46.17 (2010): 1185-1187. + .. [2] Paraschiv-Ionescu, A. et al. "Real-world speed estimation using single trunk IMU: + methodological challenges for impaired gait patterns". IEEE EMBC (2020): 4596-4599 + .. [3] https://github.com/mobilise-d/Mobilise-D-TVS-Recommended-Algorithms/blob/master/CADB_CADC/Library/Shin_algo_improved.m + + """ + + axis: Literal["x", "y", "z", "norm"] + + _INTERNAL_FILTER_SAMPLING_RATE_HZ: int = 40 + + final_filtered_signal_: np.ndarray + ic_list_internal_: pd.DataFrame + + def __init__(self, axis: Literal["x", "y", "z", "norm"] = "norm") -> None: + self.axis = axis + + @base_icd_docfiller + def detect(self, data: pd.DataFrame, *, sampling_rate_hz: float, **_: Unpack[dict[str, Any]]) -> Self: + """%(detect_short)s. + + %(detect_info)s + + Parameters + ---------- + %(detect_para)s + + %(detect_return)s + + """ + self.data = data + self.sampling_rate_hz = sampling_rate_hz + + if self.axis not in ["x", "y", "z", "norm"]: + raise ValueError("Invalid axis. Choose 'x', 'y', 'z', or 'norm'.") + + signal = ( + norm(data[["acc_x", "acc_y", "acc_z"]].to_numpy(), axis=1) + if self.axis == "norm" + else data[f"acc_{self.axis}"].to_numpy() + ) + + # Resample to 40Hz to process with filters + signal_downsampled = ( + Resample(self._INTERNAL_FILTER_SAMPLING_RATE_HZ) + .transform(data=signal, sampling_rate_hz=sampling_rate_hz) + .transformed_data_ + ) + + # We need to intitialize the filter once to get the number of coefficients to calculate the padding. + # This is not ideal, but works for now. + # TODO: We should evaluate, if we need the padding at all, or if the filter methods that we use handle that + # correctly anyway. -> filtfilt uses padding automatically and savgol allows to actiavte padding, put uses the + # default mode (polyinomal interpolation) might be suffiecent anyway, cwt might have some edeeffects, but + # usually nothing to worry about. + n_coefficients = len(EpflGaitFilter().coefficients[0]) + + # Padding to cope with short data + len_pad = 4 * n_coefficients + signal_downsampled_padded = np.pad(signal_downsampled, (len_pad, len_pad), "wrap") + + # Filters + # 1 + # TODO (future): Replace svagol and cwt with class implementation to easily expose parameters. + tmp_sig_1 = savgol_filter(signal_downsampled_padded.squeeze(), window_length=21, polyorder=7) + # 2 + tmp_sig_2 = ( + EpflDedriftedGaitFilter() + .filter(tmp_sig_1, sampling_rate_hz=self._INTERNAL_FILTER_SAMPLING_RATE_HZ) + .filtered_data_ + ) + # 3 + # NOTE: Original MATLAB code calls old version of cwt (open wavelet.internal.cwt in MATLAB to inspect) in + # accN_filt3=cwt(accN_filt2,10,'gaus2',1/40); + # Here, 10 is the scale, gaus2 is the second derivative of a Gaussian wavelet, aka a Mexican Hat or Ricker + # wavelet. + # In Python, a scale of 7 matches the MATLAB scale of 10 from visual inspection of plots (likely due to how to + # two languages initialise their wavelets), giving the line below + ricker_width = 10 + tmp_sig_3, _ = cwt( + tmp_sig_2.squeeze(), [ricker_width], "gaus2", sampling_period=1 / self._INTERNAL_FILTER_SAMPLING_RATE_HZ + ) + # 4 + tmp_sig_4 = savgol_filter(tmp_sig_3.squeeze(), window_length=11, polyorder=5) + # 5 + tmp_sig_5, _ = cwt( + tmp_sig_4.squeeze(), [ricker_width], "gaus2", sampling_period=1 / self._INTERNAL_FILTER_SAMPLING_RATE_HZ + ) + # Compared to matlab the python gauss filter needs the matlab window with divided by 5 + tmp_sig_6 = tmp_sig_5 + for sigma in [2, 2, 3]: + tmp_sig_6 = gaussian_filter(tmp_sig_6.squeeze(), sigma) + # Remove padding + final_filtered = tmp_sig_6[len_pad:-len_pad] + + self.final_filtered_signal_ = final_filtered + + #Resample to 100Hz for consistency with the original data (for ICD) or to 120 for consistency with original paper + NEW_SAMPLING_RATE_HZ = 120 + signal_upsampled = ( + Resample(NEW_SAMPLING_RATE_HZ) + .transform(data=final_filtered, sampling_rate_hz=self._INTERNAL_FILTER_SAMPLING_RATE_HZ) + .transformed_data_ + ) + + # Apply morphological filters + SE_closing = np.ones(32, dtype=int) + SE_opening = np.ones(18, dtype=int) + + C = grey_closing(signal_upsampled, structure=SE_closing) + O = grey_opening(C, structure=SE_opening) + R = C - O + + detected_ics = np.array([]) + + if np.any(R > 0): + idx = bool_array_to_start_end_array(R > 0) + detected_ics = np.zeros(len(idx), dtype=float) + for j in range(len(idx)): + start_idx, end_idx = idx[j, 0], idx[j, 1] + values_within_range = R[start_idx:end_idx + 1] + imax = start_idx + np.argmax(values_within_range) + + # Assign the value to the NumPy array + detected_ics[j] = imax + + detected_ics = pd.DataFrame({"ic": detected_ics}).rename_axis(index="ic_id") + + self.ic_list_internal_ = detected_ics + + # Downsample initial contacts to original sampling rate + IC_downsampled = ( + (detected_ics * sampling_rate_hz / NEW_SAMPLING_RATE_HZ).round().astype(int) + ) + + self.ic_list_ = IC_downsampled + + return self diff --git a/tests/test_icd/test_icd_hklee_improved.py b/tests/test_icd/test_icd_hklee_improved.py new file mode 100644 index 000000000..f81d385aa --- /dev/null +++ b/tests/test_icd/test_icd_hklee_improved.py @@ -0,0 +1,52 @@ +import numpy as np +import pandas as pd +import pytest +from tpcp.testing import TestAlgorithmMixin + +from gaitlink.data import LabExampleDataset +from gaitlink.ICD._hklee_algo_improved import IcdHKLeeImproved +from gaitlink.pipeline import GsIterator + + +class TestMetaHKLeeImproved(TestAlgorithmMixin): + __test__ = True + + ALGORITHM_CLASS = IcdHKLeeImproved + + @pytest.fixture() + def after_action_instance(self): + return self.ALGORITHM_CLASS().detect( + pd.DataFrame(np.zeros((1000, 3)), columns=["acc_x", "acc_y", "acc_z"]), sampling_rate_hz=120.0 + ) + + +class TestHKLeeImproved: + def test_invalid_axis_parameter(self): + with pytest.raises(ValueError): + IcdHKLeeImproved(axis="invalid").detect(pd.DataFrame(), sampling_rate_hz=100) + + def test_no_ics_detected(self): + data = pd.DataFrame(np.zeros((1000, 3)), columns=["acc_x", "acc_y", "acc_z"]) + output = IcdHKLeeImproved(axis="x") + output.detect(data, sampling_rate_hz=120.0) + output_ic = output.ic_list_["ic"] + empty_output = {} + assert output_ic.to_dict() == empty_output + +class TestShinImprovedRegression: + @pytest.mark.parametrize("datapoint", LabExampleDataset(reference_system="INDIP", reference_para_level="wb")) + def test_example_lab_data(self, datapoint, snapshot): + data = datapoint.data["LowerBack"] + try: + ref_walk_bouts = datapoint.reference_parameters_.wb_list + except: + pytest.skip("No reference parameters available.") + sampling_rate_hz = datapoint.sampling_rate_hz + + iterator = GsIterator() + + for (gs, data), result in iterator.iterate(data, ref_walk_bouts): + result.initial_contacts = IcdHKLeeImproved().detect(data, sampling_rate_hz=sampling_rate_hz).ic_list_ + + detected_ics = iterator.initial_contacts_ + snapshot.assert_match(detected_ics, str(datapoint.group_label))