Skip to content

Commit

Permalink
first implementation of HKLee algo
Browse files Browse the repository at this point in the history
  • Loading branch information
DMegaritis committed Feb 2, 2024
1 parent bcd16a4 commit 78177ec
Show file tree
Hide file tree
Showing 3 changed files with 368 additions and 0 deletions.
100 changes: 100 additions & 0 deletions examples/icd/_03_hklee_algo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@

"""
HKLee algo
==========
This example shows how to use the improved HKLee algorithm and some examples on how the results compare to the original
matlab implementation.
"""

import pandas as pd
from matplotlib import pyplot as plt
from gaitlink.ICD._hklee_algo_improved import IcdHKLeeImproved
from gaitlink.data import LabExampleDataset

# %%
# Loading data
# ------------
# .. note :: More information about data loading can be found in the :ref:`data loading example <data_loading_example>`.
# We load example data from the lab dataset together with the INDIP reference system.
# We will use the INDIP output for initial contacts ("ic") as ground truth.

# Example lab dataset, configured with the INDIP reference system and "wb"-level reference parameters.
example_data = LabExampleDataset(reference_system="INDIP", reference_para_level="wb")

# Select one specific trial and pull out the lower-back IMU recording and its sampling rate.
single_test = example_data.get_subset(cohort="HA", participant_id="001", test="Test11", trial="Trial1")
sampling_rate_hz = single_test.sampling_rate_hz
imu_data = single_test.data["LowerBack"]

# Reference walking bouts and reference initial contacts of the selected trial.
reference_parameters = single_test.reference_parameters_
ref_ics = reference_parameters.ic_list
reference_wbs = reference_parameters.wb_list

reference_wbs
# %%
# Applying the algorithm
# ----------------------
# Below we apply the HKLee algorithm to a lab trial.
# We will use the `GsIterator` to iterate over the gait sequences and apply the algorithm to each wb.
from gaitlink.pipeline import GsIterator

iterator = GsIterator()

# Run a fresh detector on the data of every reference walking bout and hand its ICs to the iterator result.
for (gs, data), result in iterator.iterate(imu_data, reference_wbs):
    icd = IcdHKLeeImproved()
    icd.detect(data, sampling_rate_hz=sampling_rate_hz)
    result.initial_contacts = icd.ic_list_

# Combined initial contacts of all iterations, as exposed by the iterator.
detected_ics = iterator.initial_contacts_

detected_ics
# %%
# Matlab Outputs
# --------------
# To check if the algorithm was implemented correctly, we compare the results to the matlab implementation.
import json

from gaitlink import PACKAGE_ROOT


def load_matlab_output(datapoint):
    """Load the initial contacts of the original matlab implementation for one datapoint.

    Parameters
    ----------
    datapoint
        A single datapoint of the example dataset.
        Its ``group_label`` (``cohort``, ``participant_id``, ``time_measure``, ``test``, ``trial``) selects the
        result file and the entry within it; its ``sampling_rate_hz`` is used to scale the stored values.

    Returns
    -------
    pd.DataFrame
        One row per initial contact with the column ``ic`` (integer, scaled by the sampling rate) and a
        ``(wb_id, ic_id)`` MultiIndex.
        The frame is empty if the result file contains no gait sequences for the datapoint.

    """
    p = datapoint.group_label
    # NOTE(review): the path still points to the "icd_shin_improved" results - confirm that this is the intended
    # reference output for the HKLee algorithm.
    result_file = (
        PACKAGE_ROOT.parent
        / f"example_data/original_results/icd_shin_improved/lab/{p.cohort}/{p.participant_id}/SD_Output.json"
    )
    with result_file.open() as f:
        original_results = json.load(f)["SD_Output"][p.time_measure][p.test][p.trial]["SU"]["LowerBack"]["SD"]

    # A single gait sequence is stored as a plain object instead of a one-element list.
    if not isinstance(original_results, list):
        original_results = [original_results]

    ics = {}
    for i, gs in enumerate(original_results):
        gs_ics = gs["IC"]
        # Same normalisation as above: a single IC may be stored as a scalar, which `pd.DataFrame` can not turn
        # into a column without an explicit index.
        if not isinstance(gs_ics, list):
            gs_ics = [gs_ics]
        ics[i] = pd.DataFrame({"ic": gs_ics}).rename_axis(index="ic_id")

    if not ics:
        # `pd.concat` raises on an empty mapping, so we build the empty result explicitly.
        empty_index = pd.MultiIndex.from_arrays([[], []], names=["wb_id", "ic_id"])
        return pd.DataFrame({"ic": []}, index=empty_index).astype(int)

    # The stored values are multiplied by the sampling rate (presumably they are in seconds - TODO confirm).
    # We round before casting, as a plain int cast truncates and would turn e.g. 112.9999... into 112.
    return (pd.concat(ics, names=["wb_id", "ic_id"]) * datapoint.sampling_rate_hz).round().astype(int)


detected_ics_matlab = load_matlab_output(single_test)
detected_ics_matlab
# %%
# Plotting the results
# --------------------
# With that we can compare the python, matlab and ground truth results.
# We zoom in into one of the gait sequences to better see the output.
#
# We can make a couple of main observations:
#
# 1. The python version finds the same ICs as the matlab version, but with a small shift to the left (around 5-10
#    samples/50-100 ms).
#    This is likely due to some differences in the downsampling process.
# 2. Compared to the ground truth reference, both versions detect the IC too early most of the time.
# 3. Neither algorithm can detect the first IC of the gait sequence.
#    However, this is expected, as per definition, this first IC marks the start of the WB in the reference system.
#    Hence, there are no samples before that point the algorithm can use to detect the IC.

# The x-axis acceleration is used to place the IC markers on top of the signal.
acc_x = imu_data["acc_x"]
imu_data.reset_index(drop=True).plot(y="acc_x")

plt.plot(ref_ics["ic"], acc_x.iloc[ref_ics["ic"]], "o", label="ref")
# NOTE(review): plotting of the python results is currently disabled - confirm whether it should be re-enabled:
# plt.plot(detected_ics["ic"], imu_data["acc_x"].iloc[detected_ics["ic"]], "x", label="hklee_algo_py")
plt.plot(detected_ics_matlab["ic"], acc_x.iloc[detected_ics_matlab["ic"]], "+", label="hklee_algo_matlab")

# Zoom into the third reference walking bout, with a margin of 50 samples on both sides.
zoom_wb = reference_wbs.iloc[2]
plt.xlim(zoom_wb["start"] - 50, zoom_wb["end"] + 50)
plt.legend()
plt.show()
216 changes: 216 additions & 0 deletions gaitlink/icd/_hklee_algo_improved.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from typing import Any, Literal

import numpy as np
import pandas as pd
from gaitmap.data_transform import Resample
from gaitmap.utils.array_handling import bool_array_to_start_end_array
from numpy.linalg import norm
from pywt import cwt
from scipy.ndimage import gaussian_filter, grey_closing, grey_opening
from scipy.signal import savgol_filter
from typing_extensions import Self, Unpack

from gaitlink.data_transform import EpflDedriftedGaitFilter, EpflGaitFilter
from gaitlink.ICD.base import BaseIcDetector, base_icd_docfiller


@base_icd_docfiller
class IcdHKLeeImproved(BaseIcDetector):
    """Detect initial contacts using the HKLee [1]_ algorithm, with improvements by Ionescu et al. [2]_.

    This algorithm is designed to detect initial contacts from accelerometer signals within a gait sequence.
    The algorithm filters the accelerometer signal down to its primary frequency components
    and then employs morphological operations with closing and opening structural elements
    to detect signal closings and openings, respectively.
    Their difference ``R`` is analyzed to identify regions where ``R`` is greater than 0.
    The position of the maximum of ``R`` within each of these regions is interpreted as an initial contact.

    This is based on the implementation published as part of the mobilised project [3]_.
    However, this implementation deviates from the original implementation in some places.
    For details, see the notes section and the examples.

    Parameters
    ----------
    axis
        selecting which part of the accelerometer signal to be used. Can be 'x', 'y', 'z', or 'norm'.
        The default is 'norm', which is also the default in the original implementation.

    Other Parameters
    ----------------
    %(other_parameters)s

    Attributes
    ----------
    %(ic_list_)s
    final_filtered_signal_
        The signal after all filter steps at the internal filter sampling rate (40 Hz), i.e. before it is upsampled
        for the morphological operations.
        This might be useful for debugging.
    ic_list_internal_
        The initial contacts detected on the upsampled (120 Hz) signal, before they are converted to the original
        sampling rate.
        As the values refer to the 120 Hz signal, they can not be used directly as indices into
        ``final_filtered_signal_``.

    Notes
    -----
    Points of deviation from the original implementation and their reasons:

    - Configurable accelerometer signal: on matlab, all axes are used to calculate ICs, here we provide
      the option to select which axis to use.
      We provide the option of selecting a single axis, because other contact detection algorithms use only the
      horizontal axis.
    - We use a different downsampling method, which should be "more" correct from a signal theory perspective,
      but will yield slightly different results.
    - Like the matlab code, we upsample the filtered signal to 120 Hz before the final morphological operations.
      The detected positions are then converted to the original sampling rate and rounded to full samples.
    - For CWT and gaussian filter, the actual parameter we pass to the respective functions differ from the matlab
      implementation, as the two languages use different implementations of the functions.
      However, the similarity of the output was visually confirmed.
    - All parameters are expressed in the units used in gaitlink, as opposed to matlab.
      Specifically, we use m/s^2 instead of g.

    .. [1] Lee, H-K., et al. "Computational methods to detect step events for normal and pathological
        gait evaluation using accelerometer." Electronics letters 46.17 (2010): 1185-1187.
    .. [2] Paraschiv-Ionescu, A. et al. "Real-world speed estimation using single trunk IMU:
        methodological challenges for impaired gait patterns". IEEE EMBC (2020): 4596-4599
    .. [3] https://github.com/mobilise-d/Mobilise-D-TVS-Recommended-Algorithms/blob/master/CADB_CADC/Library/Shin_algo_improved.m

    """

    # NOTE(review): reference [3] above links to the Shin implementation - confirm the link for the HKLee algorithm.

    axis: Literal["x", "y", "z", "norm"]

    # Sampling rate at which all filter steps are performed.
    _INTERNAL_FILTER_SAMPLING_RATE_HZ: int = 40
    # Sampling rate the filtered signal is upsampled to before the morphological operations.
    _MORPHOLOGY_SAMPLING_RATE_HZ: int = 120
    # Widths (in samples of the upsampled signal) of the flat structuring elements for closing and opening.
    _CLOSING_WIDTH_SAMPLES: int = 32
    _OPENING_WIDTH_SAMPLES: int = 18

    final_filtered_signal_: np.ndarray
    ic_list_internal_: pd.DataFrame

    def __init__(self, axis: Literal["x", "y", "z", "norm"] = "norm") -> None:
        self.axis = axis

    @base_icd_docfiller
    def detect(self, data: pd.DataFrame, *, sampling_rate_hz: float, **_: Unpack[dict[str, Any]]) -> Self:
        """%(detect_short)s.

        %(detect_info)s

        Parameters
        ----------
        %(detect_para)s

        %(detect_return)s
        """
        self.data = data
        self.sampling_rate_hz = sampling_rate_hz

        if self.axis not in ["x", "y", "z", "norm"]:
            raise ValueError("Invalid axis. Choose 'x', 'y', 'z', or 'norm'.")

        signal = (
            norm(data[["acc_x", "acc_y", "acc_z"]].to_numpy(), axis=1)
            if self.axis == "norm"
            else data[f"acc_{self.axis}"].to_numpy()
        )

        # Resample to 40Hz to process with filters
        signal_downsampled = (
            Resample(self._INTERNAL_FILTER_SAMPLING_RATE_HZ)
            .transform(data=signal, sampling_rate_hz=sampling_rate_hz)
            .transformed_data_
        )

        # We need to initialize the filter once to get the number of coefficients to calculate the padding.
        # This is not ideal, but works for now.
        # TODO: We should evaluate, if we need the padding at all, or if the filter methods that we use handle that
        #  correctly anyway. -> filtfilt uses padding automatically and savgol allows to activate padding, but the
        #  default mode (polynomial interpolation) might be sufficient anyway, cwt might have some edge effects, but
        #  usually nothing to worry about.
        n_coefficients = len(EpflGaitFilter().coefficients[0])

        # Padding to cope with short data
        len_pad = 4 * n_coefficients
        signal_downsampled_padded = np.pad(signal_downsampled, (len_pad, len_pad), "wrap")

        # Filters
        # 1
        # TODO (future): Replace savgol and cwt with class implementation to easily expose parameters.
        tmp_sig_1 = savgol_filter(signal_downsampled_padded.squeeze(), window_length=21, polyorder=7)
        # 2
        tmp_sig_2 = (
            EpflDedriftedGaitFilter()
            .filter(tmp_sig_1, sampling_rate_hz=self._INTERNAL_FILTER_SAMPLING_RATE_HZ)
            .filtered_data_
        )
        # 3
        # NOTE: Original MATLAB code calls old version of cwt (open wavelet.internal.cwt in MATLAB to inspect) in
        #   accN_filt3=cwt(accN_filt2,10,'gaus2',1/40);
        #   Here, 10 is the scale, gaus2 is the second derivative of a Gaussian wavelet, aka a Mexican Hat or Ricker
        #   wavelet.
        # NOTE(review): an earlier version of this comment stated that a python scale of 7 matches the MATLAB scale
        #   of 10, but the code uses 10 - confirm which value is intended.
        ricker_width = 10
        tmp_sig_3, _ = cwt(
            tmp_sig_2.squeeze(), [ricker_width], "gaus2", sampling_period=1 / self._INTERNAL_FILTER_SAMPLING_RATE_HZ
        )
        # 4
        tmp_sig_4 = savgol_filter(tmp_sig_3.squeeze(), window_length=11, polyorder=5)
        # 5
        tmp_sig_5, _ = cwt(
            tmp_sig_4.squeeze(), [ricker_width], "gaus2", sampling_period=1 / self._INTERNAL_FILTER_SAMPLING_RATE_HZ
        )
        # 6
        # Compared to matlab, the python gauss filter needs the matlab window width divided by 5.
        tmp_sig_6 = tmp_sig_5
        for sigma in [2, 2, 3]:
            tmp_sig_6 = gaussian_filter(tmp_sig_6.squeeze(), sigma)
        # Remove padding
        final_filtered = tmp_sig_6[len_pad:-len_pad]

        self.final_filtered_signal_ = final_filtered

        # Upsample to 120 Hz before the morphological operations for consistency with the original implementation.
        signal_upsampled = (
            Resample(self._MORPHOLOGY_SAMPLING_RATE_HZ)
            .transform(data=final_filtered, sampling_rate_hz=self._INTERNAL_FILTER_SAMPLING_RATE_HZ)
            .transformed_data_
        )

        # Apply morphological filters.
        # We deliberately use flat structuring elements (`footprint`) instead of `structure=np.ones(...)`:
        # With `structure`, scipy adds the structure values during dilation and subtracts them during erosion.
        # Mathematically the offsets of +1/-1 cancel out, but `(x - 1) + 1` is not exact in floating point.
        # The resulting rounding noise in `closed - opened` could spuriously fulfil the `> 0` condition below.
        closed = grey_closing(signal_upsampled, footprint=np.ones(self._CLOSING_WIDTH_SAMPLES, dtype=bool))
        opened = grey_opening(closed, footprint=np.ones(self._OPENING_WIDTH_SAMPLES, dtype=bool))
        residual = closed - opened

        # Each connected region with a positive residual contains one IC at the position of the residual maximum.
        is_ic_region = residual > 0
        ic_regions = bool_array_to_start_end_array(is_ic_region) if np.any(is_ic_region) else []
        # We slice up to `end + 1`, so that the end sample is covered even if the helper returns inclusive ends.
        # If the ends are exclusive, the one additional sample is not part of the region (residual <= 0) and can
        # hence never be the maximum.
        ic_positions = np.array(
            [start + np.argmax(residual[start : end + 1]) for start, end in ic_regions], dtype=float
        )

        detected_ics = pd.DataFrame({"ic": ic_positions}).rename_axis(index="ic_id")

        self.ic_list_internal_ = detected_ics

        # Convert the initial contacts from the 120 Hz signal back to the original sampling rate.
        self.ic_list_ = (
            (detected_ics * sampling_rate_hz / self._MORPHOLOGY_SAMPLING_RATE_HZ).round().astype(int)
        )

        return self
52 changes: 52 additions & 0 deletions tests/test_icd/test_icd_hklee_improved.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
import pandas as pd
import pytest
from tpcp.testing import TestAlgorithmMixin

from gaitlink.data import LabExampleDataset
from gaitlink.ICD._hklee_algo_improved import IcdHKLeeImproved
from gaitlink.pipeline import GsIterator


class TestMetaHKLeeImproved(TestAlgorithmMixin):
    """Run the checks provided by ``TestAlgorithmMixin`` against ``IcdHKLeeImproved``."""

    __test__ = True

    ALGORITHM_CLASS = IcdHKLeeImproved

    @pytest.fixture()
    def after_action_instance(self):
        """Provide an algorithm instance on which ``detect`` was called with all-zero dummy data."""
        dummy_data = pd.DataFrame(np.zeros((1000, 3)), columns=["acc_x", "acc_y", "acc_z"])
        algorithm = self.ALGORITHM_CLASS()
        # `detect` returns the instance itself.
        return algorithm.detect(dummy_data, sampling_rate_hz=120.0)


class TestHKLeeImproved:
    """Unit tests for parameter validation and edge cases of ``IcdHKLeeImproved``."""

    def test_invalid_axis_parameter(self):
        """An unknown ``axis`` value must be rejected when ``detect`` is called."""
        # `match` makes sure that we actually hit the axis validation and not some unrelated ValueError caused by
        # the empty input data.
        with pytest.raises(ValueError, match="Invalid axis"):
            IcdHKLeeImproved(axis="invalid").detect(pd.DataFrame(), sampling_rate_hz=100)

    def test_no_ics_detected(self):
        """A constant (all-zero) signal must not result in any detected initial contacts."""
        data = pd.DataFrame(np.zeros((1000, 3)), columns=["acc_x", "acc_y", "acc_z"])
        output = IcdHKLeeImproved(axis="x")
        output.detect(data, sampling_rate_hz=120.0)
        output_ic = output.ic_list_["ic"]
        assert output_ic.to_dict() == {}

class TestShinImprovedRegression:
    """Snapshot regression test of ``IcdHKLeeImproved`` on the lab example data."""

    # NOTE(review): the class name still refers to the Shin algorithm, although it tests the HKLee algorithm.
    #   It is kept unchanged here to not alter the test ids - consider renaming.

    @pytest.mark.parametrize("datapoint", LabExampleDataset(reference_system="INDIP", reference_para_level="wb"))
    def test_example_lab_data(self, datapoint, snapshot):
        """Detect ICs in all reference walking bouts of a datapoint and compare them to the stored snapshot."""
        imu_data = datapoint.data["LowerBack"]
        try:
            ref_walk_bouts = datapoint.reference_parameters_.wb_list
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        # TODO: Narrow this down to the exception that is actually raised for missing reference parameters.
        except Exception:  # noqa: BLE001
            pytest.skip("No reference parameters available.")
        sampling_rate_hz = datapoint.sampling_rate_hz

        iterator = GsIterator()

        # A separate name for the per-gait-sequence data avoids shadowing the full recording.
        for (_, gs_data), result in iterator.iterate(imu_data, ref_walk_bouts):
            result.initial_contacts = IcdHKLeeImproved().detect(gs_data, sampling_rate_hz=sampling_rate_hz).ic_list_

        detected_ics = iterator.initial_contacts_
        snapshot.assert_match(detected_ics, str(datapoint.group_label))

0 comments on commit 78177ec

Please sign in to comment.