Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Older fixes #36

Merged
merged 7 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pylabianca/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .utils import (
_get_trial_boundaries, _deal_with_picks, find_index, parse_sub_ses)
from .utils.xarr import (
xr_find_nested_dims, _inherit_metadata, assign_session_coord)
find_nested_dims, _inherit_metadata, assign_session_coord)
from .utils.validate import _validate_xarray_for_aggregation


Expand Down Expand Up @@ -752,7 +752,7 @@ def xarray_to_dict(xarr, ses_name='sub', reduce_coords=True,
new_coords = dict()
drop_coords = list()
if 'cell' in arr.coords and 'trial' in arr.coords:
nested_coords = xr_find_nested_dims(arr, ('cell', 'trial'))
nested_coords = find_nested_dims(arr, ('cell', 'trial'))
else:
nested_coords = list()

Expand Down
17 changes: 12 additions & 5 deletions pylabianca/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,8 @@ def _do_permute(arr, decoding_fun, target_coord, average_folds=True,

def resample_decoding(decoding_fun, frates=None, target=None, Xs=None, ys=None,
time=None, arguments=dict(), n_resamples=20, n_jobs=1,
permute=False, select_trials=None, decim=None):
permute=False, select_trials=None, decim=None,
resampling_fun=None):
"""Resample a decoding analysis.

The resampling is done by rearranging trials within each subject,
Expand Down Expand Up @@ -535,15 +536,16 @@ def resample_decoding(decoding_fun, frates=None, target=None, Xs=None, ys=None,
if n_jobs == 1 or n_resamples == 1:
score_resamples = [
_do_resample(Xs, ys, decoding_fun, arguments,
time=time)
time=time, resampling_fun=resampling_fun)
for resample_idx in range(n_resamples)
]
else:
from joblib import Parallel, delayed

score_resamples = Parallel(n_jobs=n_jobs)(
delayed(_do_resample)(
Xs, ys, decoding_fun, arguments, time=time)
Xs, ys, decoding_fun, arguments, time=time,
resampling_fun=resampling_fun)
for resample_idx in range(n_resamples)
)

Expand All @@ -557,8 +559,13 @@ def resample_decoding(decoding_fun, frates=None, target=None, Xs=None, ys=None,
return score_resamples


def _do_resample(Xs, ys, decoding_fun, arguments, time=None):
X, y = join_subjects(Xs, ys)
def _do_resample(Xs, ys, decoding_fun, arguments, time=None,
resampling_fun=None):
'''Resample a decoding analysis.'''
if resampling_fun is None:
X, y = join_subjects(Xs, ys)
else:
X, y = resampling_fun(Xs, ys)

# do the actual decoding
return decoding_fun(X, y, time=time, **arguments)
Expand Down
4 changes: 2 additions & 2 deletions pylabianca/selectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd

from .analysis import nested_groupby_apply
from .utils import (xr_find_nested_dims, cellinfo_from_xarray,
from .utils import (find_nested_dims, cellinfo_from_xarray,
_inherit_metadata_from_xarray, assign_session_coord)


Expand Down Expand Up @@ -211,7 +211,7 @@ def compute_selectivity_continuous(frate, compare='image', n_perm=500,

# add cell coords
# TODO: move after Dataset creation
copy_coords = xr_find_nested_dims(frate, 'cell')
copy_coords = find_nested_dims(frate, 'cell')
if len(copy_coords) > 0:
for key in results.keys():
results[key] = _inherit_metadata_from_xarray(
Expand Down
20 changes: 19 additions & 1 deletion pylabianca/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_turn_spike_rate_to_xarray():
assert (xr.trial.values == np.arange(10)).all()

# make sure that metadata is inherited
trial_dims = pln.utils.xr_find_nested_dims(xr, 'trial')
trial_dims = pln.utils.find_nested_dims(xr, 'trial')
assert len(trial_dims) == 2
assert 'a' in trial_dims
assert 'b' in trial_dims
Expand All @@ -246,3 +246,21 @@ def test_turn_spike_rate_to_xarray():
assert xr.dims == ('cell', 'time')
assert (xr.cell.values == spk.cell_names).all()
assert (xr.time.values == times).all()


def test_find_nested_dims():
import xarray as xr
from pylabianca.testing import gen_random_xarr

n_cells, n_trials, n_times = 5, 24, 100
tri_coord = np.random.choice(list('abcd'), size=n_trials)
xarr = (
gen_random_xarr(n_cells, n_trials, n_times)
.drop_vars('trial')
.assign_coords({'cond': ('trial', tri_coord)})
)

sub_dims = pln.utils.xarr.find_nested_dims(xarr, 'trial')
assert isinstance(sub_dims, list)
assert len(sub_dims) == 1
assert 'cond' in sub_dims
2 changes: 1 addition & 1 deletion pylabianca/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
_validate_spike_epochs_input)
from .xarr import (
_turn_spike_rate_to_xarray, _inherit_metadata, assign_session_coord,
_inherit_metadata_from_xarray, xr_find_nested_dims, cellinfo_from_xarray)
_inherit_metadata_from_xarray, find_nested_dims, cellinfo_from_xarray)
from ._compat import (xarray_to_dict, dict_to_xarray, spike_centered_windows,
shuffle_trials, read_drop_info)
4 changes: 2 additions & 2 deletions pylabianca/utils/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def _validate_cellinfo(spk, cellinfo):

def _validate_xarray_for_aggregation(arr, groupby, per_cell):
if groupby is not None:
from .xarr import xr_find_nested_dims
nested = xr_find_nested_dims(arr, ('cell', 'trial'))
from .xarr import find_nested_dims
nested = find_nested_dims(arr, ('cell', 'trial'))
if groupby in nested and per_cell is False:
raise ValueError(
'When using `per_cell=False`, the groupby coordinate cannot be'
Expand Down
12 changes: 7 additions & 5 deletions pylabianca/utils/xarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def df_from_xarray_coords(xarr, dim):
None is returned.
'''
import pandas as pd
use_dims = xr_find_nested_dims(xarr, dim)
use_dims = find_nested_dims(xarr, dim)

if len(use_dims) > 1:
df = {dim: xarr.coords[dim].values for dim in use_dims}
Expand Down Expand Up @@ -169,24 +169,26 @@ def _inherit_metadata(coords, metadata, dimname, tri=None):
def _inherit_metadata_from_xarray(xarr_from, xarr_to, dimname,
copy_coords=None):
if copy_coords is None:
copy_coords = xr_find_nested_dims(xarr_from, dimname)
copy_coords = find_nested_dims(xarr_from, dimname)
if len(copy_coords) > 0:
coords = {coord: (dimname, xarr_from.coords[coord].values)
for coord in copy_coords}
xarr_to = xarr_to.assign_coords(coords)
return xarr_to


def xr_find_nested_dims(arr, dim_name):
def find_nested_dims(arr, dim_name):
names = list()
coords = list(arr.coords)

if isinstance(dim_name, tuple):
for dim in dim_name:
coords.remove(dim)
if dim in coords:
coords.remove(dim)
sub_dim = dim_name
else:
coords.remove(dim_name)
if dim_name in coords:
coords.remove(dim_name)
sub_dim = (dim_name,)

for coord in coords:
Expand Down
41 changes: 30 additions & 11 deletions pylabianca/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def plot_shaded(arr, reduce_dim=None, groupby=None, ax=None,
is ``None`` which uses the default matplotlib color cycle.
labels : bool
Whether to add labels to the axes.
kwargs : dict
Additional keyword arguments for the plot.

Returns
-------
Expand Down Expand Up @@ -130,6 +132,10 @@ def plot_xarray_shaded(arr, reduce_dim=None, x_dim='time', groupby=None,
ax=None, legend=True, legend_pos=None, colors=None,
**kwargs):
"""
Plot xarray with error bar shade.

Parameters
----------
arr : xarray.DataArray
Xarray with at least two dimensions: one is plotted along the x axis
(this is controlled with ``x_dim`` argument); the other is reduced
Expand All @@ -156,6 +162,13 @@ def plot_xarray_shaded(arr, reduce_dim=None, x_dim='time', groupby=None,
List of RGB arrays to use as colors for condition groups. Can also be
a dictionary linking condition names / values and RBG arrays. Default
is ``None`` which uses the default matplotlib color cycle.
kwargs : dict
Additional keyword arguments for the plot.

Returns
-------
ax : matplotlib.Axes
Axis with the plot.
"""
import matplotlib.pyplot as plt
assert reduce_dim is not None
Expand Down Expand Up @@ -595,9 +608,9 @@ def plot_raster(spk, pick=0, groupby=None, ax=None, colors=None, labels=True,
return ax


def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None,
pvals=None, pick=0, p_threshold=0.05, min_pval=0.001, ax=None,
eventplot_kwargs=None):
def plot_spikes(spk, frate, groupby=None, colors=None, df_clst=None,
clusters=None, pvals=None, pick=0, p_threshold=0.05,
min_pval=0.001, ax=None, eventplot_kwargs=None, fontsize=12):
'''Plot average spike rate and spike raster.

spk : pylabianca.spikes.SpikeEpochs
Expand All @@ -607,6 +620,8 @@ def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None,
``spk.spike_density()``.
groupby : str | None
How to group the plots. If None, no grouping is done.
colors : list of str | None
List of colors to use for each group. If None uses the default
df_clst : pandas.DataFrame | None
DataFrame with cluster time ranges and p values. If None, no cluster
information is shown. This argument is to support results obtained
Expand All @@ -632,6 +647,8 @@ def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None,
is used for raster plot. If None, a new figure is created.
eventplot_kwargs : dict | None
Additional keyword arguments for the eventplot.
fontsize : int
Font size for the labels.

Returns
-------
Expand All @@ -655,7 +672,7 @@ def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None,
else:
assert(len(ax) == 2)
fig = ax[0].figure
plot_shaded(this_frate, groupby=groupby, ax=ax[0])
plot_shaded(this_frate, groupby=groupby, colors=colors, ax=ax[0])

# add highlight
add_highlight = (df_clst is not None) or (
Expand All @@ -669,15 +686,15 @@ def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None,
add_highlights(this_frate, clusters, pvals, ax=ax[0],
p_threshold=p_threshold, min_pval=min_pval)

plot_raster(spk.copy().pick_cells(cell_name), pick=0,
groupby=groupby, ax=ax[1], eventplot_kwargs=eventplot_kwargs)
plot_raster(spk.copy().pick_cells(cell_name), pick=0, groupby=groupby,
colors=colors, ax=ax[1], eventplot_kwargs=eventplot_kwargs)
ylim = ax[1].get_xlim()
ax[0].set_xlim(ylim)

ax[0].set_xlabel('')
ax[0].set_ylabel('Spike rate (Hz)', fontsize=12)
ax[1].set_xlabel('Time (s)', fontsize=12)
ax[1].set_ylabel('Trials', fontsize=12)
ax[0].set_ylabel('Spike rate (Hz)', fontsize=fontsize)
ax[1].set_xlabel('Time (s)', fontsize=fontsize)
ax[1].set_ylabel('Trials', fontsize=fontsize)

return fig

Expand All @@ -692,7 +709,7 @@ def _create_mask_from_window_str(window, frate):
# TODO: could infer x coords from plot (if lines are already present)
def add_highlights(arr, clusters, pvals, p_threshold=0.05, ax=None,
min_pval=0.001, bottom_extend=True, pval_text=True,
text_props=None):
pval_fontsize=10, text_props=None):
'''Highlight significant clusters along the last array dimension.

Parameters
Expand Down Expand Up @@ -729,6 +746,8 @@ def add_highlights(arr, clusters, pvals, p_threshold=0.05, ax=None,
Dictionary with text properties for p value text boxes. If None,
defaults to
``{'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.75, edgecolor='gray'}``.
pval_fontsize : int
Font size for the p value text.

Returns
-------
Expand Down Expand Up @@ -806,7 +825,7 @@ def add_highlights(arr, clusters, pvals, p_threshold=0.05, ax=None,
p_txt = format_pvalue(this_pval)

this_text = ax.text(
text_x, text_y, p_txt,
text_x, text_y, p_txt, fontsize=pval_fontsize,
bbox=text_props, horizontalalignment='center'
)
try:
Expand Down