From 81c7d7373d63aec191def3efda91989481d89175 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Thu, 14 Dec 2023 14:48:48 +0100 Subject: [PATCH 1/2] add option to store psd for fft, stft and welch method --- py_neuromodulation/nm_oscillatory.py | 93 +++++++++++++++++++++++----- py_neuromodulation/nm_settings.json | 15 +++-- 2 files changed, 85 insertions(+), 23 deletions(-) diff --git a/py_neuromodulation/nm_oscillatory.py b/py_neuromodulation/nm_oscillatory.py index 0ab7e780..a9321c47 100644 --- a/py_neuromodulation/nm_oscillatory.py +++ b/py_neuromodulation/nm_oscillatory.py @@ -89,19 +89,33 @@ def update_KF(self, feature_calc: float, KF_name: str) -> float: feature_calc = self.KF_dict[KF_name].x[0] return feature_calc - def estimate_osc_features(self, features_compute: dict, data: np.ndarray, feature_name: np.ndarray, est_name: str): + def estimate_osc_features( + self, + features_compute: dict, + data: np.ndarray, + feature_name: np.ndarray, + est_name: str, + ): for feature_est_name in list(self.s[est_name]["features"].keys()): if self.s[est_name]["features"][feature_est_name] is True: # switch case for feature_est_name match feature_est_name: case "mean": - features_compute[f"{feature_name}_{feature_est_name}"] = np.nanmean(data) + features_compute[ + f"{feature_name}_{feature_est_name}" + ] = np.nanmean(data) case "median": - features_compute[f"{feature_name}_{feature_est_name}"] = np.nanmedian(data) + features_compute[ + f"{feature_name}_{feature_est_name}" + ] = np.nanmedian(data) case "std": - features_compute[f"{feature_name}_{feature_est_name}"] = np.nanstd(data) + features_compute[ + f"{feature_name}_{feature_est_name}" + ] = np.nanstd(data) case "max": - features_compute[f"{feature_name}_{feature_est_name}"] = np.nanmax(data) + features_compute[ + f"{feature_name}_{feature_est_name}" + ] = np.nanmax(data) return features_compute @@ -122,13 +136,15 @@ def __init__( window_ms = self.s["fft_settings"]["windowlength_ms"] self.window_samples = int(-np.floor(window_ms / 1000 * sfreq)) - freqs = fft.rfftfreq(-self.window_samples, 1 / np.floor(self.sfreq)) + self.freqs = fft.rfftfreq( + -self.window_samples, 1 / np.floor(self.sfreq) + ) self.feature_params = [] for ch_idx, ch_name in enumerate(self.ch_names): for fband, f_range in self.f_ranges_dict.items(): idx_range = np.where( - (freqs >= f_range[0]) & (freqs < f_range[1]) + (self.freqs >= f_range[0]) & (self.freqs < f_range[1]) )[0] feature_name = "_".join([ch_name, "fft", fband]) self.feature_params.append((ch_idx, feature_name, idx_range)) @@ -145,10 +161,20 @@ def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict: Z = np.log10(Z) for ch_idx, feature_name, idx_range in self.feature_params: - Z_ch = Z[ch_idx, idx_range] - features_compute = self.estimate_osc_features(features_compute, Z_ch, feature_name, "fft_settings") + features_compute = self.estimate_osc_features( + features_compute, Z_ch, feature_name, "fft_settings" + ) + + for ch_idx, ch_name in enumerate(self.ch_names): + if self.s["fft_settings"]["return_spectrum"]: + features_compute.update( + { + f"{ch_name}_fft_psd_{str(f)}": Z[ch_idx][idx] + for idx, f in enumerate(self.freqs.astype(int)) + } + ) return features_compute @@ -177,7 +203,7 @@ def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float): ) def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict: - f, Z = signal.welch( + freqs, Z = signal.welch( data, fs=self.sfreq, window="hann", @@ -185,15 +211,31 @@ def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict: noverlap=None, ) + if self.log_transform: + Z = np.log10(Z) + for ch_idx, feature_name, f_range in self.feature_params: Z_ch = Z[ch_idx] - if self.log_transform: - Z_ch = np.log10(Z_ch) + idx_range = np.where((freqs >= f_range[0]) & (freqs <= f_range[1]))[ + 0 + ] - idx_range = np.where((f >= f_range[0]) & (f <= f_range[1]))[0] + features_compute = self.estimate_osc_features( + features_compute, + Z_ch[idx_range], + feature_name, + "welch_settings", + ) - features_compute = self.estimate_osc_features(features_compute, Z_ch[idx_range], feature_name, "welch_settings") + for ch_idx, ch_name in enumerate(self.ch_names): + if self.s["welch_settings"]["return_spectrum"]: + features_compute.update( + { + f"{ch_name}_welch_psd_{str(f)}": Z[ch_idx][idx] + for idx, f in enumerate(freqs.astype(int)) + } + ) return features_compute @@ -223,7 +265,7 @@ def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float): ) def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict: - f, _, Zxx = signal.stft( + freqs, _, Zxx = signal.stft( data, fs=self.sfreq, window="hamming", @@ -235,9 +277,26 @@ def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict: Z = np.log10(Z) for ch_idx, feature_name, f_range in self.feature_params: Z_ch = Z[ch_idx] - idx_range = np.where((f >= f_range[0]) & (f <= f_range[1]))[0] + idx_range = np.where((freqs >= f_range[0]) & (freqs <= f_range[1]))[ + 0 + ] + + features_compute = self.estimate_osc_features( + features_compute, + Z_ch[idx_range, :], + feature_name, + "stft_settings", + ) - features_compute = self.estimate_osc_features(features_compute, Z_ch[idx_range, :], feature_name, "stft_settings") + for ch_idx, ch_name in enumerate(self.ch_names): + if self.s["stft_settings"]["return_spectrum"]: + Z_ch_mean = Z[ch_idx].mean(axis=1) + features_compute.update( + { + f"{ch_name}_stft_psd_{str(f)}": Z_ch_mean[idx] + for idx, f in enumerate(freqs.astype(int)) + } + ) return features_compute diff --git a/py_neuromodulation/nm_settings.json b/py_neuromodulation/nm_settings.json index 61403b0c..bd8e5a4f 100644 --- a/py_neuromodulation/nm_settings.json +++ b/py_neuromodulation/nm_settings.json @@ -93,8 +93,9 @@ "mean": true, "median": false, "std": false, - "max" : false - } + "max": false + }, + "return_spectrum": true }, "welch_settings": { "windowlength_ms": 1000, @@ -103,8 +104,9 @@ "mean": true, "median": false, "std": false, - "max" : false - } + "max": false + }, + "return_spectrum": true }, "stft_settings": { "windowlength_ms": 500, @@ -113,8 +115,9 @@ "mean": true, "median": false, "std": false, - "max" : false - } + "max": false + }, + "return_spectrum": true }, "bandpass_filter_settings": { "segment_lengths_ms": { From df4486c7df0b6b8cd2c3c869ee7e86ef6508f5ee Mon Sep 17 00:00:00 2001 From: timonmerk Date: Thu, 14 Dec 2023 16:11:05 +0100 Subject: [PATCH 2/2] fix segment length tests --- py_neuromodulation/nm_oscillatory.py | 9 +++++++++ tests/test_feature_sampling_rates.py | 19 +++++++++++++++++++ tests/test_osc_features.py | 3 ++- tests/test_sampling.py | 1 + 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/py_neuromodulation/nm_oscillatory.py b/py_neuromodulation/nm_oscillatory.py index a9321c47..e0c065f9 100644 --- a/py_neuromodulation/nm_oscillatory.py +++ b/py_neuromodulation/nm_oscillatory.py @@ -38,6 +38,15 @@ def test_settings_osc( assert isinstance( s[osc_feature_name]["windowlength_ms"], int ), f"windowlength_ms needs to be type int, got {s[osc_feature_name]['windowlength_ms']}" + + assert ( + s[osc_feature_name]["windowlength_ms"] + <= s["segment_length_features_ms"] + ), ( + f"oscillatory feature windowlength_ms = ({s[osc_feature_name]['windowlength_ms']})" + f"needs to be smaller than" + f"s['segment_length_features_ms'] = {s['segment_length_features_ms']}", + ) else: for seg_length in s[osc_feature_name][ "segment_lengths_ms" diff --git a/tests/test_feature_sampling_rates.py b/tests/test_feature_sampling_rates.py index 0540ae78..4179fe0f 100644 --- a/tests/test_feature_sampling_rates.py +++ b/tests/test_feature_sampling_rates.py @@ -92,6 +92,21 @@ def test_different_sampling_rate_0DOT1Hz(): assert np.diff(df["time"].iloc[:2]) / 1000 == (1 / sampling_rate_features) +def test_wrong_initalization_of_segment_length_features_ms_and_osc_window_length(): + segment_length_features_ms = 800 + + arr_test = np.random.random([2, 1200]) + settings, nm_channels = get_example_settings(arr_test) + + settings["segment_length_features_ms"] = 800 + settings["fft_settings"]["windowlength_ms"] = 1000 + + with pytest.raises(Exception): + stream = nm_stream_offline.Stream( + sfreq=1000, nm_channels=nm_channels, settings=settings, verbose=True + ) + + def test_different_segment_lengths(): segment_length_features_ms = 800 @@ -99,6 +114,8 @@ def test_different_segment_lengths(): settings, nm_channels = get_example_settings(arr_test) settings["segment_length_features_ms"] = segment_length_features_ms + settings["fft_settings"]["windowlength_ms"] = segment_length_features_ms + stream = nm_stream_offline.Stream( sfreq=1000, nm_channels=nm_channels, settings=settings, verbose=True ) @@ -111,6 +128,8 @@ def test_different_segment_lengths(): settings, nm_channels = get_example_settings(arr_test) settings["segment_length_features_ms"] = segment_length_features_ms + settings["fft_settings"]["windowlength_ms"] = segment_length_features_ms + stream = nm_stream_offline.Stream( sfreq=1000, nm_channels=nm_channels, settings=settings, verbose=True ) diff --git a/tests/test_osc_features.py b/tests/test_osc_features.py index 5fe72844..b3ce00da 100644 --- a/tests/test_osc_features.py +++ b/tests/test_osc_features.py @@ -69,7 +69,8 @@ def test_fft_zero_data(): features_out = fft_obj.calc_feature(data, {}) for f in features_out.keys(): - assert features_out[f] == 0 + if "psd_0" not in f: + assert np.isclose(features_out[f], 0, atol=1e-6) def test_fft_random_data(): diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 0135302d..1ce4723e 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -16,6 +16,7 @@ def get_features(time_end_ms: int, segment_length_features_ms: int): data = np.random.random([2, time_end_ms]) settings = get_fast_compute_settings() settings["segment_length_features_ms"] = segment_length_features_ms + settings["fft_settings"]["windowlength_ms"] = segment_length_features_ms settings["frequency_ranges_hz"] = { # "high beta" : [20, 35],