Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…ulation into main
  • Loading branch information
timonmerk committed Jan 28, 2024
2 parents 018d18b + 906f50d commit 8df33fc
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 86 deletions.
17 changes: 5 additions & 12 deletions py_neuromodulation/nm_bursts.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,25 +167,18 @@ def get_burst_amplitude_length(
bursts = np.zeros((beta_averp_norm.shape[0] + 1), dtype=bool)
bursts[1:] = beta_averp_norm >= burst_thr
deriv = np.diff(bursts)
isburst = False
burst_length = []
burst_amplitude = []
burst_start = 0

for index, burst_state in enumerate(deriv):
if burst_state == True:
if isburst == True:
burst_length.append(index - burst_start)
burst_amplitude.append(beta_averp_norm[burst_start:index])
burst_time_points = np.where(deriv==True)[0]

isburst = False
else:
burst_start = index
isburst = True
for i in range(burst_time_points.size//2):
burst_length.append(burst_time_points[2 * i + 1] - burst_time_points[2 * i])
burst_amplitude.append(beta_averp_norm[burst_time_points[2 * i] : burst_time_points[2 * i + 1]])

# the last burst length (in case isburst == True) is omitted,
# since the true burst length cannot be estimated

burst_length = np.array(burst_length) / sfreq

return burst_amplitude, burst_length

124 changes: 50 additions & 74 deletions py_neuromodulation/nm_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

from sklearn import preprocessing
import numpy as np


class NORM_METHODS(Enum):
MEAN = "mean"
MEDIAN = "median"
Expand Down Expand Up @@ -138,6 +136,17 @@ def process(self, data: np.ndarray) -> np.ndarray:

return data

"""
Functions to check for NaN's before deciding which Numpy function to call
"""
def nan_mean(data, axis):
return np.nanmean(data, axis=axis) if np.any(np.isnan(sum(data))) else np.mean(data, axis=axis)

def nan_std(data, axis):
return np.nanstd(data, axis=axis) if np.any(np.isnan(sum(data))) else np.std(data, axis=axis)

def nan_median(data, axis):
return np.nanmedian(data, axis=axis) if np.any(np.isnan(sum(data))) else np.median(data, axis=axis)

def _normalize_and_clip(
current: np.ndarray,
Expand All @@ -147,82 +156,49 @@ def _normalize_and_clip(
description: str,
) -> tuple[np.ndarray, np.ndarray]:
"""Normalize data."""
if method == NORM_METHODS.MEAN.value:
mean = np.nanmean(previous, axis=0)
current = (current - mean) / mean
elif method == NORM_METHODS.MEDIAN.value:
median = np.nanmedian(previous, axis=0)
current = (current - median) / median
elif method == NORM_METHODS.ZSCORE.value:
mean = np.nanmean(previous, axis=0)
current = (current - mean) / np.nanstd(previous, axis=0)
elif method == NORM_METHODS.ZSCORE_MEDIAN.value:
current = (current - np.nanmedian(previous, axis=0)) / np.nanstd(
previous, axis=0
)
# For the following methods we check for the shape of current
# when current is a 1D array, then it is the post-processing normalization,
# and we need to expand, and take the [0, :] component
# When current is a 2D array, then it is pre-processing normalization, and
# there's no need for expanding.
elif method == NORM_METHODS.QUANTILE.value:
if len(current.shape) == 1:
current = (
preprocessing.QuantileTransformer(n_quantiles=300)
.fit(np.nan_to_num(previous))
.transform(np.expand_dims(current, axis=0))[0, :]
)
else:
current = (
preprocessing.QuantileTransformer(n_quantiles=300)
.fit(np.nan_to_num(previous))
.transform(current)
)
elif method == NORM_METHODS.ROBUST.value:
if len(current.shape) == 1:
current = (
preprocessing.RobustScaler()
.fit(np.nan_to_num(previous))
.transform(np.expand_dims(current, axis=0))[0, :]
)
else:
match method:
case NORM_METHODS.MEAN.value:
mean = nan_mean(previous, axis=0)
current = (current - mean) / mean
case NORM_METHODS.MEDIAN.value:
median = nan_median(previous, axis=0)
current = (current - median) / median
case NORM_METHODS.ZSCORE.value:
current = (current - nan_mean(previous, axis=0)) / nan_std(previous, axis=0)
case NORM_METHODS.ZSCORE_MEDIAN.value:
current = (current - nan_median(previous, axis=0)) / nan_std(previous, axis=0)
# For the following methods we check for the shape of current
# when current is a 1D array, then it is the post-processing normalization,
# and we need to expand, and remove the extra dimension afterwards
# When current is a 2D array, then it is pre-processing normalization, and
# there's no need for expanding.
case (NORM_METHODS.QUANTILE.value |
NORM_METHODS.ROBUST.value |
NORM_METHODS.MINMAX.value |
NORM_METHODS.POWER.value):

norm_methods = {
NORM_METHODS.QUANTILE.value : lambda: preprocessing.QuantileTransformer(n_quantiles=300),
NORM_METHODS.ROBUST.value : preprocessing.RobustScaler,
NORM_METHODS.MINMAX.value : preprocessing.MinMaxScaler,
NORM_METHODS.POWER.value : preprocessing.PowerTransformer
}

current = (
preprocessing.RobustScaler()
norm_methods[method]()
.fit(np.nan_to_num(previous))
.transform(current)
.transform(
# if post-processing: pad dimensions to 2
np.reshape(current, (2-len(current.shape))*(1,) + current.shape)
)
.squeeze() # if post-processing: remove extra dimension
)

elif method == NORM_METHODS.MINMAX.value:
if len(current.shape) == 1:
current = (
preprocessing.MinMaxScaler()
.fit(np.nan_to_num(previous))
.transform(np.expand_dims(current, axis=0))[0, :]
)
else:
current = (
preprocessing.MinMaxScaler()
.fit(np.nan_to_num(previous))
.transform(current)
)
elif method == NORM_METHODS.POWER.value:
if len(current.shape) == 1:
current = (
preprocessing.PowerTransformer()
.fit(np.nan_to_num(previous))
.transform(np.expand_dims(current, axis=0))[0, :]

case _:
raise ValueError(
f"Only {[e.value for e in NORM_METHODS]} are supported as "
f"{description} normalization methods. Got {method}."
)
else:
current = (
preprocessing.PowerTransformer()
.fit(np.nan_to_num(previous))
.transform(current)
)
else:
raise ValueError(
f"Only {[e.value for e in NORM_METHODS]} are supported as "
f"{description} normalization methods. Got {method}."
)

if clip:
current = _clip(data=current, clip=clip)
Expand Down

0 comments on commit 8df33fc

Please sign in to comment.