-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bcd16a4
commit 78177ec
Showing
3 changed files
with
368 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
|
||
""" | ||
HKLee algo
==========
This example shows how to use the improved HKLee algorithm and some examples on how the results compare to the original | ||
matlab implementation. | ||
""" | ||
|
||
import pandas as pd | ||
from matplotlib import pyplot as plt | ||
from gaitlink.ICD._hklee_algo_improved import IcdHKLeeImproved | ||
from gaitlink.data import LabExampleDataset | ||
|
||
# %%
# Loading data
# ------------
# .. note :: More information about data loading can be found in the
#            :ref:`data loading example <data_loading_example>`.
#
# We load example data from the lab dataset together with the INDIP reference system.
# We will use the INDIP output for initial contacts ("ic") as ground truth.

example_data = LabExampleDataset(reference_system="INDIP", reference_para_level="wb")

# Select a single trial of one participant; all following steps operate on this trial only.
single_test = example_data.get_subset(cohort="HA", participant_id="001", test="Test11", trial="Trial1")
imu_data = single_test.data["LowerBack"]
reference_wbs = single_test.reference_parameters_.wb_list

sampling_rate_hz = single_test.sampling_rate_hz
ref_ics = single_test.reference_parameters_.ic_list

# Bare expression so that the gallery renders the reference walking bouts.
reference_wbs
# %%
# Applying the algorithm
# ----------------------
# Below we apply the HKLee algorithm to a lab trial.
# We will use the `GsIterator` to iterate over the gait sequences and apply the algorithm to each wb.
from gaitlink.pipeline import GsIterator

iterator = GsIterator()

# The algorithm is applied to each reference walking bout separately; the iterator collects the per-bout results.
for (gs, data), result in iterator.iterate(imu_data, reference_wbs):
    result.initial_contacts = IcdHKLeeImproved().detect(data, sampling_rate_hz=sampling_rate_hz).ic_list_

detected_ics = iterator.initial_contacts_

# Bare expression so that the gallery renders the detected initial contacts.
detected_ics
# %% | ||
# Matlab Outputs | ||
# -------------- | ||
# To check if the algorithm was implemented correctly, we compare the results to the matlab implementation. | ||
import json | ||
|
||
from gaitlink import PACKAGE_ROOT | ||
|
||
|
||
def load_matlab_output(datapoint):
    """Load the initial contacts of the original matlab implementation for one trial.

    The values stored in the ``IC`` field are multiplied by the sampling rate of the datapoint and rounded to the
    nearest sample, so that they can be used as sample indices.

    Parameters
    ----------
    datapoint
        A single trial of the example dataset.
        Its ``group_label`` (``cohort``, ``participant_id``, ``time_measure``, ``test``, ``trial``) selects the
        result file and the entry within the file; its ``sampling_rate_hz`` is used for the unit conversion.

    Returns
    -------
    pd.DataFrame
        A dataframe with the integer column ``ic`` and the index levels ``wb_id`` and ``ic_id``.
        The dataframe is empty if no gait sequence is stored for the trial.

    """
    p = datapoint.group_label
    # NOTE(review): The results are read from the "icd_shin_improved" folder, although this example is about the
    #               HKLee algorithm. Confirm that this is the intended set of reference results.
    result_file = (
        PACKAGE_ROOT.parent
        / f"example_data/original_results/icd_shin_improved/lab/{p.cohort}/{p.participant_id}/SD_Output.json"
    )
    with result_file.open() as f:
        original_results = json.load(f)["SD_Output"][p.time_measure][p.test][p.trial]["SU"]["LowerBack"]["SD"]

    # A single gait sequence is stored as a plain object instead of a list of objects.
    if not isinstance(original_results, list):
        original_results = [original_results]

    index_names = ["wb_id", "ic_id"]
    if not original_results:
        empty_index = pd.MultiIndex.from_arrays([[], []], names=index_names)
        return pd.DataFrame({"ic": []}, index=empty_index).astype(int)

    ics = {}
    for i, gs in enumerate(original_results):
        ic_values = gs["IC"]
        # Presumably a single IC is stored as a plain scalar as well (like a single gait sequence above).
        if not isinstance(ic_values, list):
            ic_values = [ic_values]
        ics[i] = pd.DataFrame({"ic": ic_values}).rename_axis(index="ic_id")

    # Round before casting: a plain integer cast truncates, so floating point errors (e.g. 28.999...) would end up
    # one sample too early.
    return (pd.concat(ics, names=index_names) * datapoint.sampling_rate_hz).round().astype(int)
|
||
|
||
detected_ics_matlab = load_matlab_output(single_test)
# Bare expression so that the gallery renders the matlab results.
detected_ics_matlab
# %%
# Plotting the results
# --------------------
# With that we can compare the python, matlab and ground truth results.
# We zoom in into one of the gait sequences to better see the output.
#
# We can make a couple of main observations:
#
# 1. The python version finds the same ICs as the matlab version, but with a small shift to the left (around 5-10
#    samples/50-100 ms).
#    This is likely due to some differences in the downsampling process.
# 2. Compared to the ground truth reference, both versions detect the IC too early most of the time.
# 3. Both algorithms cannot detect the first IC of the gait sequence.
#    However, this is expected, as per definition, this first IC marks the start of the WB in the reference system.
#    Hence, there are no samples before that point the algorithm can use to detect the IC.

imu_data.reset_index(drop=True).plot(y="acc_x")

plt.plot(ref_ics["ic"], imu_data["acc_x"].iloc[ref_ics["ic"]], "o", label="ref")
# NOTE(review): The python results are currently not plotted (the line below is disabled), although the text above
#               discusses them. Re-enable the line or adapt the description.
# plt.plot(detected_ics["ic"], imu_data["acc_x"].iloc[detected_ics["ic"]], "x", label="hklee_algo_py")
plt.plot(detected_ics_matlab["ic"], imu_data["acc_x"].iloc[detected_ics_matlab["ic"]], "+", label="hklee_algo_matlab")
# Zoom in to the third reference walking bout (plus 50 samples of margin on each side).
plt.xlim(reference_wbs.iloc[2]["start"] - 50, reference_wbs.iloc[2]["end"] + 50)
plt.legend()
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
from typing import Any, Literal | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from gaitmap.data_transform import Resample | ||
from gaitmap.utils.array_handling import bool_array_to_start_end_array | ||
from numpy.linalg import norm | ||
from pywt import cwt | ||
from scipy.ndimage import gaussian_filter, grey_closing, grey_opening | ||
from scipy.signal import savgol_filter | ||
from typing_extensions import Self, Unpack | ||
|
||
from gaitlink.data_transform import EpflDedriftedGaitFilter, EpflGaitFilter | ||
from gaitlink.ICD.base import BaseIcDetector, base_icd_docfiller | ||
|
||
|
||
@base_icd_docfiller
class IcdHKLeeImproved(BaseIcDetector):
    """Detect initial contacts using the HKLee [1]_ algorithm, with improvements by Ionescu et al. [2]_.

    This algorithm is designed to detect initial contacts from accelerometer signals within a gait sequence.
    The algorithm filters the accelerometer signal down to its primary frequency components
    and then employs morphological operations with closing and opening structural elements
    to detect signal closings and openings, respectively.
    Their difference (R) is analyzed to identify regions where R is greater than 0.
    Within each of these regions, the position of the maximum of R is interpreted as an initial contact.

    This is based on the implementation published as part of the mobilised project [3]_.
    However, this implementation deviates from the original implementation in some places.
    For details, see the notes section and the examples.

    Parameters
    ----------
    axis
        selecting which part of the accelerometer signal to be used. Can be 'x', 'y', 'z', or 'norm'.
        The default is 'norm', which is also the default in the original implementation.

    Other Parameters
    ----------------
    %(other_parameters)s

    Attributes
    ----------
    %(ic_list_)s
    final_filtered_signal_
        The downsampled signal after all filter steps (before it is upsampled again for the morphological
        operations).
        This might be useful for debugging.
    ic_list_internal_
        The initial contacts detected on the internally upsampled signal, before they are projected back to the
        original sampling rate.
        This can be useful for debugging.

    Notes
    -----
    Points of deviation from the original implementation and their reasons:

    - Configurable accelerometer signal: on matlab, all axes are used to calculate ICs, here we provide
      the option to select which axis to use.
      Despite the fact that the matlab implementation uses all axes, here we provide the option of selecting a
      single axis because other contact detection algorithms use only the horizontal axis.
    - We use a different downsampling method, which should be "more" correct from a signal theory perspective,
      but will yield slightly different results.
    - For CWT and gaussian filter, the actual parameter we pass to the respective functions differ from the matlab
      implementation, as the two languages use different implementations of the functions.
      However, the similarity of the output was visually confirmed.
    - All parameters are expressed in the units used in gaitlink, as opposed to matlab.
      Specifically, we use m/s^2 instead of g.

    Further remarks:

    - Like the matlab code, the filtered signal is upsampled to 120 Hz before the final morphological operations.
      Only in case the upsampling will be removed in the future, the morphological operations would have to be
      performed on the downsampled signal and an interpolation would be required to compensate for the
      "loss of accuracy" before projecting the indices back to the original sampling rate.
    - Some early testing (not yet confirmed for this algorithm) indicates, that the python version finds all ICs
      5-10 samples earlier than the matlab version.
      However, this seems to be a relatively consistent offset.
      Hence, it might be possible to shift/tune this in the future.

    .. [1] Lee, H-K., et al. "Computational methods to detect step events for normal and pathological
        gait evaluation using accelerometer." Electronics letters 46.17 (2010): 1185-1187.
    .. [2] Paraschiv-Ionescu, A. et al. "Real-world speed estimation using single trunk IMU:
        methodological challenges for impaired gait patterns". IEEE EMBC (2020): 4596-4599
    .. [3] https://github.com/mobilise-d/Mobilise-D-TVS-Recommended-Algorithms/blob/master/CADB_CADC/Library/Shin_algo_improved.m

    """

    # NOTE(review): Reference [3] in the docstring links to ``Shin_algo_improved.m``. Confirm whether the HKLee
    #               matlab file should be linked instead.

    axis: Literal["x", "y", "z", "norm"]

    # Sampling rate at which all filter steps are applied.
    _INTERNAL_FILTER_SAMPLING_RATE_HZ: int = 40
    # Sampling rate the filtered signal is upsampled to before the morphological operations.
    _UPSAMPLED_SAMPLING_RATE_HZ: int = 120

    final_filtered_signal_: np.ndarray
    ic_list_internal_: pd.DataFrame

    def __init__(self, axis: Literal["x", "y", "z", "norm"] = "norm") -> None:
        self.axis = axis

    @base_icd_docfiller
    def detect(self, data: pd.DataFrame, *, sampling_rate_hz: float, **_: Unpack[dict[str, Any]]) -> Self:
        """%(detect_short)s.

        %(detect_info)s

        Parameters
        ----------
        %(detect_para)s

        %(detect_return)s
        """
        self.data = data
        self.sampling_rate_hz = sampling_rate_hz

        # Validate before touching the data, so that an invalid configuration fails independent of the input.
        if self.axis not in ["x", "y", "z", "norm"]:
            raise ValueError("Invalid axis. Choose 'x', 'y', 'z', or 'norm'.")

        signal = (
            norm(data[["acc_x", "acc_y", "acc_z"]].to_numpy(), axis=1)
            if self.axis == "norm"
            else data[f"acc_{self.axis}"].to_numpy()
        )

        # Resample to 40Hz to process with filters
        signal_downsampled = (
            Resample(self._INTERNAL_FILTER_SAMPLING_RATE_HZ)
            .transform(data=signal, sampling_rate_hz=sampling_rate_hz)
            .transformed_data_
        )

        # We need to initialize the filter once to get the number of coefficients to calculate the padding.
        # This is not ideal, but works for now.
        # TODO: We should evaluate, if we need the padding at all, or if the filter methods that we use handle that
        #  correctly anyway. -> filtfilt uses padding automatically and savgol allows to activate padding, but the
        #  default mode (polynomial interpolation) might be sufficient anyway, cwt might have some edge effects, but
        #  usually nothing to worry about.
        n_coefficients = len(EpflGaitFilter().coefficients[0])

        # Padding to cope with short data
        len_pad = 4 * n_coefficients
        signal_downsampled_padded = np.pad(signal_downsampled, (len_pad, len_pad), "wrap")

        # Filters
        # 1
        # TODO (future): Replace savgol and cwt with class implementation to easily expose parameters.
        tmp_sig_1 = savgol_filter(signal_downsampled_padded.squeeze(), window_length=21, polyorder=7)
        # 2
        tmp_sig_2 = (
            EpflDedriftedGaitFilter()
            .filter(tmp_sig_1, sampling_rate_hz=self._INTERNAL_FILTER_SAMPLING_RATE_HZ)
            .filtered_data_
        )
        # 3
        # NOTE: Original MATLAB code calls old version of cwt (open wavelet.internal.cwt in MATLAB to inspect) in
        #   accN_filt3=cwt(accN_filt2,10,'gaus2',1/40);
        #   Here, 10 is the scale, gaus2 is the second derivative of a Gaussian wavelet, aka a Mexican Hat or Ricker
        #   wavelet.
        # NOTE(review): This note was originally written for a Python scale of 7 (said to match the MATLAB scale of
        #   10 from visual inspection of plots), but the code below passes 10. Confirm which value is intended.
        ricker_width = 10
        sampling_period = 1 / self._INTERNAL_FILTER_SAMPLING_RATE_HZ
        tmp_sig_3, _ = cwt(tmp_sig_2.squeeze(), [ricker_width], "gaus2", sampling_period=sampling_period)
        # 4
        tmp_sig_4 = savgol_filter(tmp_sig_3.squeeze(), window_length=11, polyorder=5)
        # 5
        tmp_sig_5, _ = cwt(tmp_sig_4.squeeze(), [ricker_width], "gaus2", sampling_period=sampling_period)
        # 6
        # Compared to matlab the python gauss filter needs the matlab window width divided by 5
        tmp_sig_6 = tmp_sig_5
        for sigma in [2, 2, 3]:
            tmp_sig_6 = gaussian_filter(tmp_sig_6.squeeze(), sigma)
        # Remove padding
        final_filtered = tmp_sig_6[len_pad:-len_pad]

        self.final_filtered_signal_ = final_filtered

        # Upsample to 120 Hz before the morphological operations (see the Notes section of the class docstring of
        # this algorithm for the reasoning).
        signal_upsampled = (
            Resample(self._UPSAMPLED_SAMPLING_RATE_HZ)
            .transform(data=final_filtered, sampling_rate_hz=self._INTERNAL_FILTER_SAMPLING_RATE_HZ)
            .transformed_data_
        )

        # Apply morphological filters
        closing_structure = np.ones(32, dtype=int)
        opening_structure = np.ones(18, dtype=int)

        closed = grey_closing(signal_upsampled, structure=closing_structure)
        opened = grey_opening(closed, structure=opening_structure)
        residual = closed - opened

        if np.any(residual > 0):
            regions = bool_array_to_start_end_array(residual > 0)
            # One IC per region: the position of the maximum of the residual within the region.
            # NOTE(review): The slice end `end + 1` is kept from the original code. Confirm whether the end index
            #   returned by `bool_array_to_start_end_array` is inclusive or exclusive.
            ic_positions = np.array(
                [start + np.argmax(residual[start : end + 1]) for start, end in regions], dtype=float
            )
        else:
            ic_positions = np.array([])

        detected_ics = pd.DataFrame({"ic": ic_positions}).rename_axis(index="ic_id")

        self.ic_list_internal_ = detected_ics

        # Project the initial contacts back to the original sampling rate
        self.ic_list_ = (detected_ics * sampling_rate_hz / self._UPSAMPLED_SAMPLING_RATE_HZ).round().astype(int)

        return self
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
from tpcp.testing import TestAlgorithmMixin | ||
|
||
from gaitlink.data import LabExampleDataset | ||
from gaitlink.ICD._hklee_algo_improved import IcdHKLeeImproved | ||
from gaitlink.pipeline import GsIterator | ||
|
||
|
||
class TestMetaHKLeeImproved(TestAlgorithmMixin):
    """Hook `IcdHKLeeImproved` into the shared `TestAlgorithmMixin` test suite from tpcp."""

    __test__ = True

    ALGORITHM_CLASS = IcdHKLeeImproved

    @pytest.fixture()
    def after_action_instance(self):
        """Provide an algorithm instance on which ``detect`` was already called with an all-zero signal."""
        zero_acc = pd.DataFrame(np.zeros((1000, 3)), columns=["acc_x", "acc_y", "acc_z"])
        algorithm = self.ALGORITHM_CLASS()
        return algorithm.detect(zero_acc, sampling_rate_hz=120.0)
|
||
|
||
class TestHKLeeImproved:
    """Tests for the basic behaviour of `IcdHKLeeImproved`."""

    def test_invalid_axis_parameter(self):
        """An unknown axis must be rejected with the dedicated error message."""
        # `match` makes sure that the test only passes for the axis validation and not for an unrelated ValueError
        # (e.g. one triggered by the empty dataframe).
        with pytest.raises(ValueError, match="Invalid axis"):
            IcdHKLeeImproved(axis="invalid").detect(pd.DataFrame(), sampling_rate_hz=100)

    def test_no_ics_detected(self):
        """An all-zero signal must result in an empty list of initial contacts."""
        data = pd.DataFrame(np.zeros((1000, 3)), columns=["acc_x", "acc_y", "acc_z"])
        algorithm = IcdHKLeeImproved(axis="x")
        algorithm.detect(data, sampling_rate_hz=120.0)
        detected = algorithm.ic_list_["ic"]
        assert detected.to_dict() == {}
|
||
class TestShinImprovedRegression:
    """Snapshot regression test of `IcdHKLeeImproved` on the lab example data.

    NOTE(review): The class name still refers to the Shin algorithm. It is kept unchanged here, as the names of
    stored snapshots might depend on it.
    """

    @pytest.mark.parametrize("datapoint", LabExampleDataset(reference_system="INDIP", reference_para_level="wb"))
    def test_example_lab_data(self, datapoint, snapshot):
        """Compare the ICs detected within the reference walking bouts of each trial against the snapshot."""
        imu_data = datapoint.data["LowerBack"]
        try:
            ref_walk_bouts = datapoint.reference_parameters_.wb_list
        except Exception:  # noqa: BLE001
            # Not all trials have reference data. We deliberately do not catch BaseException, so that
            # KeyboardInterrupt/SystemExit are not turned into a skipped test.
            pytest.skip("No reference parameters available.")
        sampling_rate_hz = datapoint.sampling_rate_hz

        iterator = GsIterator()

        for (_, gs_data), result in iterator.iterate(imu_data, ref_walk_bouts):
            result.initial_contacts = IcdHKLeeImproved().detect(gs_data, sampling_rate_hz=sampling_rate_hz).ic_list_

        detected_ics = iterator.initial_contacts_
        snapshot.assert_match(detected_ics, str(datapoint.group_label))