Implement fast ZETA test #22

Open · wants to merge 78 commits into base: main

Commits
f49db9f
STY: use njit instead of jit(nopython=True)
mmagnuski May 30, 2024
a06f321
ENH: add numba _monotonic_find_first and _select_spikes
mmagnuski May 30, 2024
6453fb1
FIX: compat with new ttest_ind_no_p in borsar (for some reason does n…
mmagnuski May 30, 2024
6ce43f1
ENH: some new numba functions + tests
mmagnuski May 30, 2024
8126c9c
zeta test, part 1
mmagnuski Jun 1, 2024
06a5989
zeta test, part 2, multi-unit
mmagnuski Jun 1, 2024
ceda829
small changes
mmagnuski Jun 3, 2024
f89465d
DOC: improve gumbel docstring
mmagnuski Jun 3, 2024
c201e27
separate compute_pvalues function
mmagnuski Jun 3, 2024
6f15277
towards unified API for numpy and numba versions
mmagnuski Jun 3, 2024
6f8ef1c
updates
mmagnuski Jun 5, 2024
953ac81
ENH: n-condition ZETA for numba
mmagnuski Jun 5, 2024
f76817e
ENH: make numpy backend working for 2 and n conditions
mmagnuski Jun 5, 2024
c7d8e34
ENH: switch to object arrays of trial-level spike arrays in numpy
mmagnuski Jun 5, 2024
613e631
ENH: add subsample, return reference_times in the dict
mmagnuski Jun 6, 2024
588db85
DOC: add docstring to ZETA function
mmagnuski Jun 6, 2024
c9787fe
small changes
mmagnuski Jun 6, 2024
cb8e774
STY: rename functions
mmagnuski Jun 29, 2024
ace9049
DOC: add some docstrings
mmagnuski Jun 29, 2024
ed8156c
do not subsample by default
mmagnuski Jun 29, 2024
0cd4993
STY: remove unused functions, improve docstring
mmagnuski Jul 1, 2024
8a733ab
ENH: function to read the zeta example data
mmagnuski Jul 1, 2024
8bd8e97
ENH: cache numba functions, improve function and variable names
mmagnuski Jul 1, 2024
c9fc74e
STY: whitespace
mmagnuski Jul 1, 2024
75f90cf
FIX TST: fix so that searched 9 is always in the array
mmagnuski Aug 9, 2024
7eb94ef
FIX: better output of _monotonic_find_first when value is not found
mmagnuski Aug 9, 2024
723867b
small cleanup
mmagnuski Aug 9, 2024
16af2ac
FIX, ENH: use var in numba implementation for N-condition cases too
mmagnuski Aug 9, 2024
17e3d81
TST: add test for ZETA
mmagnuski Aug 9, 2024
67f67a9
DEV: move pylabianca._zeta.ZETA to pylabianca.selectivity.zeta_test
mmagnuski Aug 13, 2024
f40d86a
TST: adapt tests to the move
mmagnuski Aug 13, 2024
589d369
TST, DEV: reorganize tests
mmagnuski Aug 13, 2024
2c75f8b
DOC: update whats_new.md
mmagnuski Aug 13, 2024
84bd4b7
FIX empirical pvalues calculation
mmagnuski Sep 2, 2024
0daff30
ENH: make sure each cell has the same permutation structure
mmagnuski Sep 3, 2024
875a0f7
TST: fix tests
mmagnuski Sep 3, 2024
774cf39
FIX: event_id can be a single value
mmagnuski Sep 3, 2024
be6b7ff
TST: skip numba testing if numba not available
mmagnuski Sep 3, 2024
beaca35
TST: attempt to fix tests
mmagnuski Sep 3, 2024
297519d
FIX: make sure randomly generated cluster ids are unique
mmagnuski Sep 3, 2024
ff62ec8
TST: fix tests
mmagnuski Sep 3, 2024
c7fcd05
FIX merge conflict
mmagnuski Sep 4, 2024
6de95af
specify np.int64 to avoid errors
mmagnuski Sep 4, 2024
5417cc1
option to permute each cell independently
mmagnuski Nov 27, 2024
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -30,6 +30,7 @@ jobs:
mamba install pytest -y;
hash -r pytest
pip install pytest-cov
+pip install zetapy
- run:
name: Run tests
command: |
185 changes: 174 additions & 11 deletions pylabianca/_numba.py
@@ -1,9 +1,9 @@
import numpy as np
-from numba import jit
+from numba import jit, njit
from numba.extending import overload


-@jit(nopython=True)
+@njit
def _compute_spike_rate_numba(spike_times, spike_trials, time_limits,
n_trials, winlen=0.25, step=0.05):
half_win = winlen / 2
@@ -26,7 +26,7 @@ def _compute_spike_rate_numba(spike_times, spike_trials, time_limits,
return times, frate


-@jit(nopython=True)
+@njit
def _monotonic_unique_counts(values):
n_val = len(values)
if n_val == 0:
@@ -56,6 +56,7 @@ def _monotonic_unique_counts(values):
return uni, cnt



def numba_compare_times(spk, cell_idx1, cell_idx2, spk2=None):
times1 = (spk.timestamps[cell_idx1] / spk.sfreq).astype('float64')

@@ -67,7 +68,7 @@ def numba_compare_times(spk, cell_idx1, cell_idx2, spk2=None):
return res


-@jit(nopython=True)
+@njit
def _numba_compare_times(times1, times2, distances):
n_times1 = times1.shape[0]
n_times2 = times2.shape[0]
@@ -87,7 +88,7 @@ def _numba_compare_times(times1, times2, distances):
return distances


-@jit(nopython=True)
+@njit
def _xcorr_hist_auto_numba(times, bins, batch_size=1_000):
'''Compute auto-correlation histogram for a single cell.

@@ -125,7 +126,7 @@ def _xcorr_hist_auto_numba(times, bins, batch_size=1_000):
return counts


-@jit(nopython=True)
+@njit
def _xcorr_hist_cross_numba(times, times2, bins, batch_size=1_000):
'''Compute cross-correlation histogram for a single cell.

@@ -179,7 +180,7 @@ def _xcorr_hist_cross_numba(times, times2, bins, batch_size=1_000):
return counts


-@jit(nopython=True)
+@njit
def compute_bin(x, bin_edges):
'''Copied from https://numba.pydata.org/numba-examples/examples/density_estimation/histogram/results.html'''
# assuming uniform bins for now
@@ -199,7 +200,7 @@ def compute_bin(x, bin_edges):
return bin


-@jit(nopython=True)
+@njit
def numba_histogram(a, bin_edges):
'''Copied from https://numba.pydata.org/numba-examples/examples/density_estimation/histogram/results.html'''
n_bins = len(bin_edges) - 1
@@ -214,7 +215,7 @@ def numba_histogram(a, bin_edges):
return hist, bin_edges


-@jit(nopython=True, cache=True)
+@njit(cache=True)
def _epoch_spikes_numba(timestamps, event_times, tmin, tmax):
trial_idx = [-1]
n_in_trial = [0]
@@ -274,7 +275,7 @@ def _epoch_spikes_numba(timestamps, event_times, tmin, tmax):
return trial, time


-@jit(nopython=True)
+@njit
def create_trials_from_short(trial_idx, n_in_trial):
n_all = sum(n_in_trial)
trial = np.empty(n_all, dtype=np.int16)
@@ -287,7 +288,7 @@ def create_trials_from_short(trial_idx, n_in_trial):
return trial


-@jit(nopython=True)
+@njit
def concat_times(times, n_in_trial):
n_all = sum(n_in_trial)
time = np.empty(n_all, dtype=np.float64)
@@ -298,3 +299,165 @@ def concat_times(times, n_in_trial):
idx = idx_end

return time


@njit
def _select_spikes_numba(spikes, trials, tri_sel):
'''Assumes both trials and tri_sel are sorted.'''
tri_sel_idx = 0
current_tri = tri_sel[tri_sel_idx]
msk = np.zeros(len(trials), dtype='bool')
for idx, tri in enumerate(trials):
if tri < current_tri:
continue

if tri == current_tri:
msk[idx] = True
elif tri > current_tri:
too_low = True
while too_low:
tri_sel_idx += 1
current_tri = tri_sel[tri_sel_idx]
too_low = tri > current_tri
if tri == current_tri:
msk[idx] = True

return spikes[msk]
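For reference, the selection above (both `trials` and `tri_sel` are assumed sorted) matches a plain NumPy membership mask; a minimal sketch, where `select_spikes_ref` is an illustrative name, not part of the PR:

```python
import numpy as np

def select_spikes_ref(spikes, trials, tri_sel):
    # same result as the two-pointer scan in _select_spikes_numba,
    # expressed as a membership mask over trial indices
    return spikes[np.isin(trials, tri_sel)]

spikes = np.array([0.1, 0.25, 0.3, 0.9])
trials = np.array([0, 1, 1, 3])
selected = select_spikes_ref(spikes, trials, np.array([1, 3]))
```

The numba loop avoids allocating the intermediate `np.isin` lookup and exploits sortedness, which is what makes it worth jitting.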


# TODO: could return error if not found (or be wrapped to return error)
# (or [x] at least return out-of-bounds index)
@njit
def _monotonic_find_first(values, find_val):
n_val = values.shape[0]
for idx in range(n_val):
if values[idx] == find_val:
return idx
return n_val
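The contract (index of the first match, or `len(values)` as an out-of-bounds sentinel when the value is absent) can be sketched in plain NumPy; `monotonic_find_first_ref` is a hypothetical reference helper:

```python
import numpy as np

def monotonic_find_first_ref(values, find_val):
    # first index where values == find_val; len(values) signals
    # "not found", mirroring the sentinel used by the numba version
    hits = np.flatnonzero(values == find_val)
    return int(hits[0]) if hits.size else len(values)
```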


def _get_trial_boundaries(spk, cell_idx):
return _get_trial_boundaries_numba(spk.trial[cell_idx], spk.n_trials)


@njit
def _get_trial_boundaries_numba(trials, n_trials):
'''
Numba implementation of get_trial_boundaries.

Parameters
----------
trials : np.ndarray
Trial indices for each spike.
n_trials : int
Number of trials (the actual number of trials, not only the number
of trials in which spikes of the given cell appear).

Returns
-------
trial_boundaries : np.ndarray
Spike indices where trials start.
trial_ids : np.ndarray
Trial indices (useful when spikes did not appear in some of the
trials for the given cell).
'''
n_spikes = trials.shape[0]
trial_boundaries = np.zeros(n_trials + 1, dtype='int32')
trial_ids = np.zeros(n_trials + 1, dtype='int32')
idx = -1
boundaries_idx = -1
prev_trial = -1
while idx < (n_spikes - 1):
idx += 1
this_trial = trials[idx]
if this_trial > prev_trial:
boundaries_idx += 1
trial_ids[boundaries_idx] = this_trial
trial_boundaries[boundaries_idx] = idx
prev_trial = this_trial

boundaries_idx += 1
trial_ids = trial_ids[:boundaries_idx]
boundaries_idx += 1
trial_boundaries = trial_boundaries[:boundaries_idx]
trial_boundaries[-1] = n_spikes

return trial_boundaries, trial_ids
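Assuming `trials` is sorted (spikes ordered by trial), the same boundaries can be recovered with `np.unique`; a reference sketch under that assumption:

```python
import numpy as np

def get_trial_boundaries_ref(trials):
    # trials: sorted trial index per spike, e.g. [0, 0, 2, 2, 2, 5]
    trial_ids, starts = np.unique(trials, return_index=True)
    # append total spike count so boundaries[i]:boundaries[i + 1]
    # slices out the spikes of trial_ids[i]
    trial_boundaries = np.append(starts, len(trials))
    return trial_boundaries, trial_ids
```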


@njit
def get_condition_indices_and_unique_numba(cnd_values):
n_trials = cnd_values.shape[0]
uni_cnd = np.unique(cnd_values)
n_cnd = uni_cnd.shape[0]
cnd_idx_per_tri = np.zeros(n_trials, dtype='int32')
n_trials_per_cond = np.zeros(n_cnd, dtype='int32')

for idx in range(n_trials):
cnd_val = cnd_values[idx]
cnd_idx = _monotonic_find_first(uni_cnd, cnd_val)
cnd_idx_per_tri[idx] = cnd_idx
n_trials_per_cond[cnd_idx] += 1

return cnd_idx_per_tri, n_trials_per_cond, uni_cnd, n_cnd
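The same mapping from condition values to indices and per-condition counts can be checked against `np.unique(..., return_inverse=True)` plus `np.bincount`; a minimal reference sketch:

```python
import numpy as np

def condition_indices_ref(cnd_values):
    # uni_cnd is sorted, so the inverse indices match the
    # _monotonic_find_first lookups in the numba version
    uni_cnd, cnd_idx_per_tri = np.unique(cnd_values, return_inverse=True)
    n_trials_per_cond = np.bincount(cnd_idx_per_tri)
    return cnd_idx_per_tri, n_trials_per_cond, uni_cnd, len(uni_cnd)
```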


@njit
def depth_of_selectivity_numba(arr, groupby):
avg_by_cond = groupby_mean(arr, groupby)
n_categories = avg_by_cond.shape[0]
selectivity = depth_of_selectivity_numba_low_level(
avg_by_cond, n_categories
)

return selectivity, avg_by_cond


@njit
def depth_of_selectivity_numba_low_level(avg_by_cond, n_categories):
r_max = max_2d_axis_0(avg_by_cond)
numerator = n_categories - (avg_by_cond / r_max).sum(axis=0)
return numerator / (n_categories - 1)
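This is the usual depth-of-selectivity formula, DoS = (n - sum_i(r_i / r_max)) / (n - 1), which is 0 when all condition means are equal and 1 when only one condition responds. A vectorized NumPy check (illustrative, not part of the PR):

```python
import numpy as np

def depth_of_selectivity_ref(avg_by_cond):
    # avg_by_cond: (n_conditions, n_times) mean response per condition
    n = avg_by_cond.shape[0]
    r_max = avg_by_cond.max(axis=0)
    return (n - (avg_by_cond / r_max).sum(axis=0)) / (n - 1)
```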


@njit
def groupby_mean(arr, groupby):
cnd_idx_per_tri, n_trials_per_cond, _, n_cnd = (
get_condition_indices_and_unique_numba(groupby)
)
avg_by_cnd = _groupby_mean_low_level(
arr, cnd_idx_per_tri, n_trials_per_cond, n_cnd)
return avg_by_cnd


@njit
def max_2d_axis_0(arr):
out = np.zeros(arr.shape[1], dtype=arr.dtype)
for idx in range(arr.shape[1]):
out[idx] = arr[:, idx].max()
return out


@njit
def var_2d_axis_0(arr):
out = np.zeros(arr.shape[1], dtype=arr.dtype)
for idx in range(arr.shape[1]):
this_arr = arr[:, idx]
avg = this_arr.mean()
out[idx] = ((this_arr - avg) ** 2).sum() / (this_arr.shape[0] - 1)
return out
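The loop computes the unbiased sample variance per column, i.e. the equivalent of `np.var(arr, axis=0, ddof=1)`; a quick self-check of that equivalence:

```python
import numpy as np

rng = np.random.default_rng(0)
arr = rng.normal(size=(20, 4))

# what var_2d_axis_0 computes for each column ...
manual = ((arr - arr.mean(axis=0)) ** 2).sum(axis=0) / (arr.shape[0] - 1)
# ... equals NumPy's sample variance with ddof=1
ref = np.var(arr, axis=0, ddof=1)
```

numba supports `np.var` only with `ddof=0`, which is presumably why the `ddof=1` loop is written out by hand here.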


@njit
def _groupby_mean_low_level(arr, cnd_idx_per_tri, n_trials_per_cond, n_cnd):
n_trials = arr.shape[0]
nd2 = arr.shape[1]
avg_by_cnd = np.zeros((n_cnd, nd2), dtype=arr.dtype)
for idx in range(n_trials):
cnd_idx = cnd_idx_per_tri[idx]
avg_by_cnd[cnd_idx] += arr[idx]

for cnd_idx in range(n_cnd):
avg_by_cnd[cnd_idx] /= n_trials_per_cond[cnd_idx]

return avg_by_cnd
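Together, `groupby_mean` and `_groupby_mean_low_level` amount to a per-condition mean over rows; a compact NumPy reference (conditions sorted, matching `np.unique` ordering), with `groupby_mean_ref` being an illustrative name:

```python
import numpy as np

def groupby_mean_ref(arr, groupby):
    # mean of the rows of arr within each condition value
    uni = np.unique(groupby)
    return np.stack([arr[groupby == u].mean(axis=0) for u in uni])
```

The two-pass accumulate-then-divide form in the PR avoids building a boolean mask per condition, which matters when jitted over many permutations.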