Skip to content

Commit

Permalink
Merge pull request #288 from toni-neurosc/sharpwaves_performance_pr
Browse files Browse the repository at this point in the history
Sharpwaves performance fixes

Many thanks @toni-neurosc!
  • Loading branch information
timonmerk authored Feb 3, 2024
2 parents aab04f2 + d7dc768 commit 704ca15
Showing 1 changed file with 103 additions and 136 deletions.
239 changes: 103 additions & 136 deletions py_neuromodulation/nm_sharpwaves.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,23 +105,12 @@ def _get_peaks_around(self, trough_ind, arr_ind_peaks, filtered_dat):
peak_right_val (np.ndarray): value of righ peak
"""

ind_greater = np.where(arr_ind_peaks > trough_ind)[0]
if ind_greater.shape[0] == 0:
raise NoValidTroughException("No valid trough")
val_ind_greater = arr_ind_peaks[ind_greater]
peak_right_idx = arr_ind_peaks[
ind_greater[np.argsort(val_ind_greater)[0]]
]

ind_smaller = np.where(arr_ind_peaks < trough_ind)[0]
if ind_smaller.shape[0] == 0:
raise NoValidTroughException("No valid trough")

val_ind_smaller = arr_ind_peaks[ind_smaller]
peak_left_idx = arr_ind_peaks[
ind_smaller[np.argsort(val_ind_smaller)[-1]]
]
try: peak_right_idx = arr_ind_peaks[arr_ind_peaks > trough_ind][0]
except IndexError: raise NoValidTroughException("No valid trough")

try: peak_left_idx = arr_ind_peaks[arr_ind_peaks < trough_ind][-1]
except IndexError: raise NoValidTroughException("No valid trough")

return (
peak_left_idx,
peak_right_idx,
Expand Down Expand Up @@ -150,12 +139,10 @@ def calc_feature(
"""
for ch_idx, ch_name in enumerate(self.ch_names):
for filter_name, filter in self.list_filter:
if filter_name == "no_filter":
self.data_process_sw = data[ch_idx, :]
else:
self.data_process_sw = signal.convolve(
data[ch_idx, :], filter, mode="same"
)
self.data_process_sw = (data[ch_idx, :]
if filter_name == "no_filter"
else signal.fftconvolve(data[ch_idx, :], filter, mode="same")
)

# check settings if troughs and peaks are analyzed

Expand Down Expand Up @@ -267,123 +254,103 @@ def analyze_waveform(self) -> None:
distance=self.sw_settings["detect_troughs"]["distance_troughs_ms"],
)[0]

for trough_idx in troughs:
try:
(
peak_idx_left,
peak_idx_right,
peak_left,
peak_right,
) = self._get_peaks_around(
trough_idx, peaks, self.data_process_sw
)
except NoValidTroughException:
# in this case there are no adjacent two peaks around this trough
# str(e) could print the exception error message
# print(str(e))
""" Find left and right peak indexes for each trough """
peak_pointer = 0
peak_idx_left = []
peak_idx_right = []
first_valid = last_valid = 0

for i, trough_idx in enumerate(troughs):

# Locate peak right of current trough
while peak_pointer < peaks.size and peaks[peak_pointer] < trough_idx:
peak_pointer += 1

if peak_pointer - 1 < 0:
# If trough has no peak to it's left, it's not valid
first_valid = i + 1 # Try with next one
continue

trough = self.data_process_sw[trough_idx]
self.trough.append(trough)
self.troughs_idx.append(trough_idx)

if self.sw_settings["sharpwave_features"]["interval"] is True:
if len(self.troughs_idx) > 1:
# take the last identified trough idx
# corresponds here to second last trough_idx

interval = (trough_idx - self.troughs_idx[-2]) * (
1000 / self.sfreq
)
else:
# set first interval to zero
interval = 0
self.interval.append(interval)

if self.sw_settings["sharpwave_features"]["peak_left"] is True:
self.peak_left.append(peak_left)

if self.sw_settings["sharpwave_features"]["peak_right"] is True:
self.peak_right.append(peak_right)

if self.sw_settings["sharpwave_features"]["sharpness"] is True:
# check if sharpness can be calculated
# trough_idx 5 ms need to be consistent
if (trough_idx - int(5 * (1000 / self.sfreq)) <= 0) or (
trough_idx + int(5 * (1000 / self.sfreq))
>= self.data_process_sw.shape[0]
):
continue

sharpness = (
(
self.data_process_sw[trough_idx]
- self.data_process_sw[
trough_idx - int(5 * (1000 / self.sfreq))
]
)
+ (
self.data_process_sw[trough_idx]
- self.data_process_sw[
trough_idx + int(5 * (1000 / self.sfreq))
]
)
) / 2

self.sharpness.append(sharpness)

if self.sw_settings["sharpwave_features"]["rise_steepness"] is True:
# steepness is calculated as the first derivative
# from peak/trough to trough/peak
# here + 1 due to python syntax, s.t. the last element is included
rise_steepness = np.max(
np.diff(
self.data_process_sw[peak_idx_left : trough_idx + 1]
)
)
self.rise_steepness.append(rise_steepness)

if (
self.sw_settings["sharpwave_features"]["decay_steepness"]
is True
):
decay_steepness = np.max(
np.diff(
self.data_process_sw[trough_idx : peak_idx_right + 1]
)
)
self.decay_steepness.append(decay_steepness)

if (
self.sw_settings["sharpwave_features"]["rise_steepness"] is True
and self.sw_settings["sharpwave_features"]["decay_steepness"]
is True
and self.sw_settings["sharpwave_features"]["slope_ratio"]
is True
):
self.slope_ratio.append(rise_steepness - decay_steepness)

if self.sw_settings["sharpwave_features"]["prominence"] is True:
self.prominence.append(
np.abs(
(peak_right + peak_left) / 2
- self.data_process_sw[trough_idx]
)
)

if self.sw_settings["sharpwave_features"]["decay_time"] is True:
self.decay_time.append(
(peak_idx_left - trough_idx) * (1000 / self.sfreq)
) # ms

if self.sw_settings["sharpwave_features"]["rise_time"] is True:
self.rise_time.append(
(peak_idx_right - trough_idx) * (1000 / self.sfreq)
) # ms

if self.sw_settings["sharpwave_features"]["width"] is True:
self.width.append(peak_idx_right - peak_idx_left) # ms
if peak_pointer == peaks.size:
# If we went past the end of the peaks list, trough had no peak to its right
continue

last_valid = i
peak_idx_left.append(peaks[peak_pointer - 1])
peak_idx_right.append(peaks[peak_pointer])

troughs = troughs[first_valid:last_valid + 1] # Remove non valid troughs

peak_idx_left = np.array(peak_idx_left, dtype=np.integer)
peak_idx_right = np.array(peak_idx_right, dtype=np.integer)

peak_left = self.data_process_sw[peak_idx_left]
peak_right = self.data_process_sw[peak_idx_right]
trough_values = self.data_process_sw[troughs]

# No need to store trough data as it is not used anywhere else in the program
# self.trough.append(trough)
# self.troughs_idx.append(trough_idx)

""" Calculate features (vectorized) """

if self.sw_settings["sharpwave_features"]["interval"]:
self.interval = np.concatenate(([0], np.diff(troughs))) * (1000 / self.sfreq)

if self.sw_settings["sharpwave_features"]["peak_left"]:
self.peak_left = peak_left

if self.sw_settings["sharpwave_features"]["peak_right"]:
self.peak_right = peak_right

if self.sw_settings["sharpwave_features"]["sharpness"]:
# sharpess is calculated on a +- 5 ms window
# valid troughs need 5 ms of margin on both siddes
troughs_valid = troughs[np.logical_and(
troughs - int(5 * (1000 / self.sfreq)) > 0,
troughs + int(5 * (1000 / self.sfreq)) < self.data_process_sw.shape[0])]

self.sharpness = (
(self.data_process_sw[troughs_valid] - self.data_process_sw[troughs_valid - int(5 * (1000 / self.sfreq))]) +
(self.data_process_sw[troughs_valid] - self.data_process_sw[troughs_valid + int(5 * (1000 / self.sfreq))])
) / 2

if (self.sw_settings["sharpwave_features"]["rise_steepness"] or
self.sw_settings["sharpwave_features"]["decay_steepness"]):

# steepness is calculated as the first derivative
steepness = np.concatenate(([0],np.diff(self.data_process_sw)))

if self.sw_settings["sharpwave_features"]["rise_steepness"]: # left peak -> trough
# + 1 due to python syntax, s.t. the last element is included
self.rise_steepness = np.array([
np.max(np.abs(steepness[peak_idx_left[i] : troughs[i] + 1]))
for i in range(trough_idx.size)
])

if self.sw_settings["sharpwave_features"]["decay_steepness"]: # trough -> right peak
self.decay_steepness = np.array([
np.max(np.abs(steepness[troughs[i] : peak_idx_right[i] + 1]))
for i in range(trough_idx.size)
])

if (self.sw_settings["sharpwave_features"]["rise_steepness"] and
self.sw_settings["sharpwave_features"]["decay_steepness"] and
self.sw_settings["sharpwave_features"]["slope_ratio"]):
self.slope_ratio = self.rise_steepness - self.decay_steepness

if self.sw_settings["sharpwave_features"]["prominence"]:
self.prominence = np.abs((peak_right + peak_left) / 2 - trough_values)

if self.sw_settings["sharpwave_features"]["decay_time"]:
self.decay_time = (peak_idx_left - troughs) * (1000 / self.sfreq) # ms

if self.sw_settings["sharpwave_features"]["rise_time"]:
self.rise_time = (peak_idx_right - troughs) * (1000 / self.sfreq) # ms

if self.sw_settings["sharpwave_features"]["width"]:
self.width = peak_idx_right - peak_idx_left # ms

@staticmethod
def test_settings(
s: dict,
Expand Down

0 comments on commit 704ca15

Please sign in to comment.