Skip to content

Commit

Permalink
fillneutral as option
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanharvey1 committed Jan 8, 2025
1 parent 4e7e6d8 commit 58ec8db
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions neuro_py/ensemble/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,12 @@ class PairwiseBias(object):
Fit the model with task data and transform the post-task data.
"""

def __init__(self, num_shuffles: int = 300, n_jobs: int = 10):
def __init__(
self, num_shuffles: int = 300, n_jobs: int = 10, fillneutral: float = np.nan
):
self.num_shuffles = num_shuffles
self.n_jobs = n_jobs
self.fillneutral = fillneutral
self.total_neurons = None
self.task_normalized = None
self.observed_correlation_ = None
Expand Down Expand Up @@ -564,7 +567,10 @@ def observed_and_shuffled_correlation(
filtered_neurons = post_neurons[start_idx:end_idx]

post_bias_matrix = self.bias_matrix(
filtered_spikes, filtered_neurons, self.total_neurons
filtered_spikes,
filtered_neurons,
self.total_neurons,
fillneutral=self.fillneutral,
)

post_normalized = self.normalize_bias_matrix(post_bias_matrix)
Expand All @@ -577,7 +583,10 @@ def observed_and_shuffled_correlation(
for _ in range(self.num_shuffles):
shuffled_neurons = np.random.permutation(filtered_neurons)
shuffled_bias_matrix = self.bias_matrix(
filtered_spikes, shuffled_neurons, self.total_neurons
filtered_spikes,
shuffled_neurons,
self.total_neurons,
fillneutral=self.fillneutral,
)
shuffled_normalized = self.normalize_bias_matrix(shuffled_bias_matrix)
shuffled_correlation.append(
Expand Down Expand Up @@ -620,7 +629,10 @@ def fit(
if task_intervals is None:
# Compute bias matrix for task data and normalize
task_bias_matrix = self.bias_matrix(
task_spikes, task_neurons, self.total_neurons
task_spikes,
task_neurons,
self.total_neurons,
fillneutral=self.fillneutral,
)
self.task_normalized = self.normalize_bias_matrix(task_bias_matrix)
else:
Expand All @@ -638,7 +650,10 @@ def fit(

# Compute the bias matrix for the interval
bias_matrix = self.bias_matrix(
interval_spikes, interval_neurons, self.total_neurons
interval_spikes,
interval_neurons,
self.total_neurons,
fillneutral=self.fillneutral,
)
bias_matrix = self.normalize_bias_matrix(bias_matrix)
task_normalized_matrices.append(bias_matrix)
Expand Down

0 comments on commit 58ec8db

Please sign in to comment.