Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to store psd for fft, stft and welch method #276

Merged
merged 2 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 85 additions & 17 deletions py_neuromodulation/nm_oscillatory.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def test_settings_osc(
assert isinstance(
s[osc_feature_name]["windowlength_ms"], int
), f"windowlength_ms needs to be type int, got {s[osc_feature_name]['windowlength_ms']}"

assert (
s[osc_feature_name]["windowlength_ms"]
<= s["segment_length_features_ms"]
), (
f"oscillatory feature windowlength_ms = ({s[osc_feature_name]['windowlength_ms']})"
f"needs to be smaller than"
f"s['segment_length_features_ms'] = {s['segment_length_features_ms']}",
)
else:
for seg_length in s[osc_feature_name][
"segment_lengths_ms"
Expand Down Expand Up @@ -89,19 +98,33 @@ def update_KF(self, feature_calc: float, KF_name: str) -> float:
feature_calc = self.KF_dict[KF_name].x[0]
return feature_calc

def estimate_osc_features(self, features_compute: dict, data: np.ndarray, feature_name: np.ndarray, est_name: str):
def estimate_osc_features(
self,
features_compute: dict,
data: np.ndarray,
feature_name: np.ndarray,
est_name: str,
):
for feature_est_name in list(self.s[est_name]["features"].keys()):
if self.s[est_name]["features"][feature_est_name] is True:
# switch case for feature_est_name
match feature_est_name:
case "mean":
features_compute[f"{feature_name}_{feature_est_name}"] = np.nanmean(data)
features_compute[
f"{feature_name}_{feature_est_name}"
] = np.nanmean(data)
case "median":
features_compute[f"{feature_name}_{feature_est_name}"] = np.nanmedian(data)
features_compute[
f"{feature_name}_{feature_est_name}"
] = np.nanmedian(data)
case "std":
features_compute[f"{feature_name}_{feature_est_name}"] = np.nanstd(data)
features_compute[
f"{feature_name}_{feature_est_name}"
] = np.nanstd(data)
case "max":
features_compute[f"{feature_name}_{feature_est_name}"] = np.nanmax(data)
features_compute[
f"{feature_name}_{feature_est_name}"
] = np.nanmax(data)

return features_compute

Expand All @@ -122,13 +145,15 @@ def __init__(

window_ms = self.s["fft_settings"]["windowlength_ms"]
self.window_samples = int(-np.floor(window_ms / 1000 * sfreq))
freqs = fft.rfftfreq(-self.window_samples, 1 / np.floor(self.sfreq))
self.freqs = fft.rfftfreq(
-self.window_samples, 1 / np.floor(self.sfreq)
)

self.feature_params = []
for ch_idx, ch_name in enumerate(self.ch_names):
for fband, f_range in self.f_ranges_dict.items():
idx_range = np.where(
(freqs >= f_range[0]) & (freqs < f_range[1])
(self.freqs >= f_range[0]) & (self.freqs < f_range[1])
)[0]
feature_name = "_".join([ch_name, "fft", fband])
self.feature_params.append((ch_idx, feature_name, idx_range))
Expand All @@ -145,10 +170,20 @@ def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
Z = np.log10(Z)

for ch_idx, feature_name, idx_range in self.feature_params:

Z_ch = Z[ch_idx, idx_range]

features_compute = self.estimate_osc_features(features_compute, Z_ch, feature_name, "fft_settings")
features_compute = self.estimate_osc_features(
features_compute, Z_ch, feature_name, "fft_settings"
)

for ch_idx, ch_name in enumerate(self.ch_names):
if self.s["fft_settings"]["return_spectrum"]:
features_compute.update(
{
f"{ch_name}_fft_psd_{str(f)}": Z[ch_idx][idx]
for idx, f in enumerate(self.freqs.astype(int))
}
)

return features_compute

Expand Down Expand Up @@ -177,23 +212,39 @@ def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float):
)

def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
f, Z = signal.welch(
freqs, Z = signal.welch(
data,
fs=self.sfreq,
window="hann",
nperseg=self.sfreq,
noverlap=None,
)

if self.log_transform:
Z = np.log10(Z)

for ch_idx, feature_name, f_range in self.feature_params:
Z_ch = Z[ch_idx]

if self.log_transform:
Z_ch = np.log10(Z_ch)
idx_range = np.where((freqs >= f_range[0]) & (freqs <= f_range[1]))[
0
]

idx_range = np.where((f >= f_range[0]) & (f <= f_range[1]))[0]
features_compute = self.estimate_osc_features(
features_compute,
Z_ch[idx_range],
feature_name,
"welch_settings",
)

features_compute = self.estimate_osc_features(features_compute, Z_ch[idx_range], feature_name, "welch_settings")
for ch_idx, ch_name in enumerate(self.ch_names):
if self.s["welch_settings"]["return_spectrum"]:
features_compute.update(
{
f"{ch_name}_welch_psd_{str(f)}": Z[ch_idx][idx]
for idx, f in enumerate(freqs.astype(int))
}
)

return features_compute

Expand Down Expand Up @@ -223,7 +274,7 @@ def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float):
)

def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
f, _, Zxx = signal.stft(
freqs, _, Zxx = signal.stft(
data,
fs=self.sfreq,
window="hamming",
Expand All @@ -235,9 +286,26 @@ def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
Z = np.log10(Z)
for ch_idx, feature_name, f_range in self.feature_params:
Z_ch = Z[ch_idx]
idx_range = np.where((f >= f_range[0]) & (f <= f_range[1]))[0]
idx_range = np.where((freqs >= f_range[0]) & (freqs <= f_range[1]))[
0
]

features_compute = self.estimate_osc_features(
features_compute,
Z_ch[idx_range, :],
feature_name,
"stft_settings",
)

features_compute = self.estimate_osc_features(features_compute, Z_ch[idx_range, :], feature_name, "stft_settings")
for ch_idx, ch_name in enumerate(self.ch_names):
if self.s["stft_settings"]["return_spectrum"]:
Z_ch_mean = Z[ch_idx].mean(axis=1)
features_compute.update(
{
f"{ch_name}_stft_psd_{str(f)}": Z_ch_mean[idx]
for idx, f in enumerate(freqs.astype(int))
}
)

return features_compute

Expand Down
15 changes: 9 additions & 6 deletions py_neuromodulation/nm_settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@
"mean": true,
"median": false,
"std": false,
"max" : false
}
"max": false
},
"return_spectrum": true
},
"welch_settings": {
"windowlength_ms": 1000,
Expand All @@ -103,8 +104,9 @@
"mean": true,
"median": false,
"std": false,
"max" : false
}
"max": false
},
"return_spectrum": true
},
"stft_settings": {
"windowlength_ms": 500,
Expand All @@ -113,8 +115,9 @@
"mean": true,
"median": false,
"std": false,
"max" : false
}
"max": false
},
"return_spectrum": true
},
"bandpass_filter_settings": {
"segment_lengths_ms": {
Expand Down
19 changes: 19 additions & 0 deletions tests/test_feature_sampling_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,30 @@ def test_different_sampling_rate_0DOT1Hz():
assert np.diff(df["time"].iloc[:2]) / 1000 == (1 / sampling_rate_features)


def test_wrong_initalization_of_segment_length_features_ms_and_osc_window_length():
segment_length_features_ms = 800

arr_test = np.random.random([2, 1200])
settings, nm_channels = get_example_settings(arr_test)

settings["segment_length_features_ms"] = 800
settings["fft_settings"]["windowlength_ms"] = 1000

with pytest.raises(Exception):
stream = nm_stream_offline.Stream(
sfreq=1000, nm_channels=nm_channels, settings=settings, verbose=True
)


def test_different_segment_lengths():
segment_length_features_ms = 800

arr_test = np.random.random([2, 1200])
settings, nm_channels = get_example_settings(arr_test)

settings["segment_length_features_ms"] = segment_length_features_ms
settings["fft_settings"]["windowlength_ms"] = segment_length_features_ms

stream = nm_stream_offline.Stream(
sfreq=1000, nm_channels=nm_channels, settings=settings, verbose=True
)
Expand All @@ -111,6 +128,8 @@ def test_different_segment_lengths():
settings, nm_channels = get_example_settings(arr_test)

settings["segment_length_features_ms"] = segment_length_features_ms
settings["fft_settings"]["windowlength_ms"] = segment_length_features_ms

stream = nm_stream_offline.Stream(
sfreq=1000, nm_channels=nm_channels, settings=settings, verbose=True
)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_osc_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def test_fft_zero_data():
features_out = fft_obj.calc_feature(data, {})

for f in features_out.keys():
assert features_out[f] == 0
if "psd_0" not in f:
assert np.isclose(features_out[f], 0, atol=1e-6)


def test_fft_random_data():
Expand Down
1 change: 1 addition & 0 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def get_features(time_end_ms: int, segment_length_features_ms: int):
data = np.random.random([2, time_end_ms])
settings = get_fast_compute_settings()
settings["segment_length_features_ms"] = segment_length_features_ms
settings["fft_settings"]["windowlength_ms"] = segment_length_features_ms

settings["frequency_ranges_hz"] = {
# "high beta" : [20, 35],
Expand Down
Loading