Skip to content

Commit

Permalink
fixed time_perm cluster optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaronearlerichardson committed Dec 1, 2023
1 parent 775b516 commit 34d4a10
Showing 1 changed file with 61 additions and 39 deletions.
100 changes: 61 additions & 39 deletions ieeg/calc/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

from ieeg import Doubles
from ieeg.calc.reshape import make_data_same
from ieeg.process import get_mem
import os
from scipy import stats as st
from numba import njit, vectorize, prange
from numba import njit, guvectorize, float64


def weighted_avg_and_std(values, weights, axis=0):
Expand Down Expand Up @@ -440,7 +438,7 @@ def time_perm_cluster(sig1: np.ndarray, sig2: np.ndarray, p_thresh: float,
>>> sig1 = np.array([[0,1,2,3,3,3,3,3,3,3,3,3,2,1,0]
... for _ in range(50)]) - rng.random((50, 15)) * 4
>>> sig2 = np.array([[0] * 15 for _ in range(100)]) + rng.random((100, 15))
>>> time_perm_cluster(sig1, sig2, 0.05, n_perm=100000)
>>> time_perm_cluster(sig1, sig2, 0.05, n_perm=10000)
array([False, False, False, True, True, True, True, True, True,
True, True, True, False, False, False])
>>> time_perm_cluster(sig1, sig2, 0.01, n_perm=10000)
Expand Down Expand Up @@ -484,30 +482,14 @@ def time_perm_cluster(sig1: np.ndarray, sig2: np.ndarray, p_thresh: float,
stat_func)

# Calculate the p value of the permutation distribution
# logger.info('Calculating permutation distribution')
#
# for i in range(diff.shape[0]):
# # p_perm is the probability of observing a difference as large as the
# # other permutations, or larger, by chance
#
# larger = tail_compare(diff[i], diff[np.arange(len(diff)) != i],tails)
# p_perm[i] = np.mean(larger, axis=0)

# The line below accomplishes the same as above twice as fast, but could
# run into memory errors if n_perm is greater than 1000
# p_perm = np.zeros(diff.shape, dtype=np.float32)
# chunksize = int(get_mem() / (np.prod(diff.shape) * diff.dtype.itemsize))
# for chunk in range(0, diff.shape[0], chunksize):
# temp = tail_compare(diff[chunk:chunk + chunksize], diff[:, np.newaxis], tails)
# p_perm[chunk:chunk + chunksize] = np.mean(temp[idx[:, chunk:chunk + chunksize]], axis=0)
# mmapped_diff = np.memmap('temp.npy', dtype='bool', mode='w+',
# shape=(diff.shape[0],) + diff.shape)
mmapped_diff = np.zeros((diff.shape[0],) + diff.shape, dtype=bool)
tail_compare(diff, diff[:, np.newaxis], tails, mmapped_diff)
p_perm = np.sum(mmapped_diff, axis=0) / (diff.shape[0] - 1)
del mmapped_diff
# p_perm = np.mean(tail_compare(diff, diff[:, np.newaxis], tails), axis=axis+1)
# p_perm = _calculate_p_perm(diff, tails)
if tails == 1:
p_perm = _perm_gt(diff)
elif tails == 2:
p_perm = _perm_gt(np.abs(diff))
elif tails == -1:
p_perm = _perm_lt(diff)
else:
raise ValueError('tails must be 1, 2, or -1')

# Create binary clusters using the p value threshold
b_act = tail_compare(1 - p_act, 1 - p_thresh, tails)
Expand All @@ -527,6 +509,28 @@ def time_perm_cluster(sig1: np.ndarray, sig2: np.ndarray, p_thresh: float,
return clusters


@guvectorize([(float64[:], float64[:])], '(n)->(n)', nopython=True)
def _perm_gt(diff, result):
    """Proportion of the other entries that each entry strictly exceeds.

    Generalized ufunc with core signature ``(n)->(n)``: the comparison runs
    along the last axis of the input and is broadcast over any leading axes.

    Parameters
    ----------
    diff : array, shape (..., n)
        Values to rank against one another along the last axis.
    result : array, shape (..., n)
        Output buffer supplied by the gufunc machinery. Entry ``i`` receives
        the number of entries ``j`` with ``diff[i] > diff[j]`` divided by
        ``n - 1``.

    Notes
    -----
    The denominator ``n - 1`` is zero when ``n == 1``; supply at least two
    entries along the core axis.
    """
    n = diff.shape[0]
    denom = n - 1
    for i in range(n):
        # Count in a local and assign once: the output buffer is not
        # guaranteed to be zero-filled, so accumulating into it with ``+=``
        # can start from garbage.
        count = 0
        for j in range(n):
            # A strict comparison is never true for j == i, so the element
            # itself is excluded without an explicit index check.
            if diff[i] > diff[j]:
                count += 1
        result[i] = count / denom


@guvectorize([(float64[:], float64[:])], '(n)->(n)', nopython=True)
def _perm_lt(diff, result):
    """Proportion of the other entries that each entry is strictly below.

    Generalized ufunc with core signature ``(n)->(n)``: the comparison runs
    along the last axis of the input and is broadcast over any leading axes.

    Parameters
    ----------
    diff : array, shape (..., n)
        Values to rank against one another along the last axis.
    result : array, shape (..., n)
        Output buffer supplied by the gufunc machinery. Entry ``i`` receives
        the number of entries ``j`` with ``diff[i] < diff[j]`` divided by
        ``n - 1``.

    Notes
    -----
    The denominator ``n - 1`` is zero when ``n == 1``; supply at least two
    entries along the core axis.
    """
    n = diff.shape[0]
    denom = n - 1
    for i in range(n):
        # Count in a local and assign once: the output buffer is not
        # guaranteed to be zero-filled, so accumulating into it with ``+=``
        # can start from garbage.
        count = 0
        for j in range(n):
            # A strict comparison is never true for j == i, so the element
            # itself is excluded without an explicit index check.
            if diff[i] < diff[j]:
                count += 1
        result[i] = count / denom


def time_cluster(act: np.ndarray, perm: np.ndarray, p_val: float = None,
tails: int = 1) -> np.ndarray:
"""Cluster correction for time series data.
Expand Down Expand Up @@ -593,8 +597,8 @@ def time_cluster(act: np.ndarray, perm: np.ndarray, p_val: float = None,


def tail_compare(diff: np.ndarray | float | int,
obs_diff: np.ndarray | float | int, tails: int = 1,
out: np.ndarray = None) -> np.ndarray | bool:
obs_diff: np.ndarray | float | int, tails: int = 1
) -> np.ndarray | bool:
"""Compare the difference between two groups to the observed difference.
This function applies the appropriate comparison based on the number of
Expand All @@ -608,8 +612,6 @@ def tail_compare(diff: np.ndarray | float | int,
The observed difference between the two groups.
tails : int, optional
The tail of the test. 1 for one-tailed (greater than), -1 for one-tailed
(less than), 2 for two-tailed (compares absolute values).
out : array, shape (..., time), optional
The array to place the output in.
Returns
-------
Expand All @@ -618,21 +620,18 @@ def tail_compare(diff: np.ndarray | float | int,
groups is larger than the observed difference.
"""

if out is None:
out = np.zeros((diff.shape[0],) + diff.shape, dtype=bool)

# Account for one or two tailed test
match tails:
case 1:
np.greater(diff, obs_diff, out=out)
temp = np.greater(diff, obs_diff)
case 2:
np.greater(np.abs(diff), np.abs(obs_diff), out=out)
temp = np.greater(np.abs(diff), np.abs(obs_diff))
case -1:
np.less(diff, obs_diff, out=out)
temp = np.less(diff, obs_diff)
case _:
raise ValueError('tails must be 1, 2, or -1')

return out
return temp


def time_perm_shuffle(sig1: np.ndarray, sig2: np.ndarray, n_perm: int = 1000,
Expand Down Expand Up @@ -806,3 +805,26 @@ def sine_f_test(window_fun: np.ndarray, x_p: np.ndarray
f_stat = num / den

return f_stat, A


if __name__ == '__main__':
    # Ad-hoc benchmark: compiled gufunc versus NumPy broadcasting for the
    # permutation-distribution p-values.
    import numpy as np
    from timeit import timeit

    rng = np.random.default_rng(seed=42)
    sig1 = np.array([[0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 0]
                     for _ in range(50)]) - rng.random((50, 15)) * 4
    sig2 = np.array([[0] * 15 for _ in range(100)]) + rng.random((100, 15))
    p_act, diff = time_perm_shuffle(sig1, sig2, 10000, 1, 0, True)

    # The gufunc ranks along the LAST axis, while the broadcasting reference
    # ranks along axis 0. Transposing in and out makes both expressions
    # compute the same quantity, so the timings are comparable.
    gufunc_stmt = '_perm_gt(diff.T).T'
    numpy_stmt = ('np.sum(diff > diff[:, np.newaxis], axis=0) / '
                  '(diff.shape[0] - 1)')

    # Calculate the p value of the permutation distribution both ways and
    # report whether they agree before timing them.
    p_perm1 = _perm_gt(diff.T).T
    p_perm2 = np.sum(diff > diff[:, np.newaxis], axis=0) / (diff.shape[0] - 1)
    print(f'Results agree: {np.allclose(p_perm1, p_perm2)}')

    # Time the functions
    time2 = timeit(gufunc_stmt, globals=globals(), number=10)
    time1 = timeit(numpy_stmt, globals=globals(), number=10)

    print(f'Time for numpy broadcasting: {time1:.6f} seconds')
    print(f'Time for _perm_gt: {time2:.6f} seconds')

0 comments on commit 34d4a10

Please sign in to comment.