diff --git a/pylabianca/analysis.py b/pylabianca/analysis.py index e280ce6..f77e5a7 100644 --- a/pylabianca/analysis.py +++ b/pylabianca/analysis.py @@ -3,7 +3,7 @@ from .utils import ( _get_trial_boundaries, _deal_with_picks, find_index, parse_sub_ses) from .utils.xarr import ( - xr_find_nested_dims, _inherit_metadata, assign_session_coord) + find_nested_dims, _inherit_metadata, assign_session_coord) from .utils.validate import _validate_xarray_for_aggregation @@ -752,7 +752,7 @@ def xarray_to_dict(xarr, ses_name='sub', reduce_coords=True, new_coords = dict() drop_coords = list() if 'cell' in arr.coords and 'trial' in arr.coords: - nested_coords = xr_find_nested_dims(arr, ('cell', 'trial')) + nested_coords = find_nested_dims(arr, ('cell', 'trial')) else: nested_coords = list() diff --git a/pylabianca/decoding.py b/pylabianca/decoding.py index 0e59d16..9e67bc3 100644 --- a/pylabianca/decoding.py +++ b/pylabianca/decoding.py @@ -460,7 +460,8 @@ def _do_permute(arr, decoding_fun, target_coord, average_folds=True, def resample_decoding(decoding_fun, frates=None, target=None, Xs=None, ys=None, time=None, arguments=dict(), n_resamples=20, n_jobs=1, - permute=False, select_trials=None, decim=None): + permute=False, select_trials=None, decim=None, + resampling_fun=None): """Resample a decoding analysis. 
The resampling is done by rearranging trials within each subject, @@ -535,7 +536,7 @@ def resample_decoding(decoding_fun, frates=None, target=None, Xs=None, ys=None, if n_jobs == 1 or n_resamples == 1: score_resamples = [ _do_resample(Xs, ys, decoding_fun, arguments, - time=time) + time=time, resampling_fun=resampling_fun) for resample_idx in range(n_resamples) ] else: @@ -543,7 +544,8 @@ def resample_decoding(decoding_fun, frates=None, target=None, Xs=None, ys=None, score_resamples = Parallel(n_jobs=n_jobs)( delayed(_do_resample)( - Xs, ys, decoding_fun, arguments, time=time) + Xs, ys, decoding_fun, arguments, time=time, + resampling_fun=resampling_fun) for resample_idx in range(n_resamples) ) @@ -557,8 +559,13 @@ def resample_decoding(decoding_fun, frates=None, target=None, Xs=None, ys=None, return score_resamples -def _do_resample(Xs, ys, decoding_fun, arguments, time=None): - X, y = join_subjects(Xs, ys) +def _do_resample(Xs, ys, decoding_fun, arguments, time=None, + resampling_fun=None): + '''Resample a decoding analysis.''' + if resampling_fun is None: + X, y = join_subjects(Xs, ys) + else: + X, y = resampling_fun(Xs, ys) # do the actual decoding return decoding_fun(X, y, time=time, **arguments) diff --git a/pylabianca/selectivity.py b/pylabianca/selectivity.py index 5660a43..6905f18 100644 --- a/pylabianca/selectivity.py +++ b/pylabianca/selectivity.py @@ -4,7 +4,7 @@ import pandas as pd from .analysis import nested_groupby_apply -from .utils import (xr_find_nested_dims, cellinfo_from_xarray, +from .utils import (find_nested_dims, cellinfo_from_xarray, _inherit_metadata_from_xarray, assign_session_coord) @@ -211,7 +211,7 @@ def compute_selectivity_continuous(frate, compare='image', n_perm=500, # add cell coords # TODO: move after Dataset creation - copy_coords = xr_find_nested_dims(frate, 'cell') + copy_coords = find_nested_dims(frate, 'cell') if len(copy_coords) > 0: for key in results.keys(): results[key] = _inherit_metadata_from_xarray( diff --git 
a/pylabianca/test/test_utils.py b/pylabianca/test/test_utils.py index ceca937..ea69e4a 100644 --- a/pylabianca/test/test_utils.py +++ b/pylabianca/test/test_utils.py @@ -221,7 +221,7 @@ def test_turn_spike_rate_to_xarray(): assert (xr.trial.values == np.arange(10)).all() # make sure that metadata is inherited - trial_dims = pln.utils.xr_find_nested_dims(xr, 'trial') + trial_dims = pln.utils.find_nested_dims(xr, 'trial') assert len(trial_dims) == 2 assert 'a' in trial_dims assert 'b' in trial_dims @@ -246,3 +246,21 @@ def test_turn_spike_rate_to_xarray(): assert xr.dims == ('cell', 'time') assert (xr.cell.values == spk.cell_names).all() assert (xr.time.values == times).all() + + +def test_find_nested_dims(): + import xarray as xr + from pylabianca.testing import gen_random_xarr + + n_cells, n_trials, n_times = 5, 24, 100 + tri_coord = np.random.choice(list('abcd'), size=n_trials) + xarr = ( + gen_random_xarr(n_cells, n_trials, n_times) + .drop_vars('trial') + .assign_coords({'cond': ('trial', tri_coord)}) + ) + + sub_dims = pln.utils.xarr.find_nested_dims(xarr, 'trial') + assert isinstance(sub_dims, list) + assert len(sub_dims) == 1 + assert 'cond' in sub_dims diff --git a/pylabianca/utils/__init__.py b/pylabianca/utils/__init__.py index 1b53480..3f58efb 100644 --- a/pylabianca/utils/__init__.py +++ b/pylabianca/utils/__init__.py @@ -14,6 +14,6 @@ _validate_spike_epochs_input) from .xarr import ( _turn_spike_rate_to_xarray, _inherit_metadata, assign_session_coord, - _inherit_metadata_from_xarray, xr_find_nested_dims, cellinfo_from_xarray) + _inherit_metadata_from_xarray, find_nested_dims, cellinfo_from_xarray) from ._compat import (xarray_to_dict, dict_to_xarray, spike_centered_windows, shuffle_trials, read_drop_info) diff --git a/pylabianca/utils/validate.py b/pylabianca/utils/validate.py index e41f11a..b1b60f2 100644 --- a/pylabianca/utils/validate.py +++ b/pylabianca/utils/validate.py @@ -154,8 +154,8 @@ def _validate_cellinfo(spk, cellinfo): def 
_validate_xarray_for_aggregation(arr, groupby, per_cell): if groupby is not None: - from .xarr import xr_find_nested_dims - nested = xr_find_nested_dims(arr, ('cell', 'trial')) + from .xarr import find_nested_dims + nested = find_nested_dims(arr, ('cell', 'trial')) if groupby in nested and per_cell is False: raise ValueError( 'When using `per_cell=False`, the groupby coordinate cannot be' diff --git a/pylabianca/utils/xarr.py b/pylabianca/utils/xarr.py index a293206..ec2b8ef 100644 --- a/pylabianca/utils/xarr.py +++ b/pylabianca/utils/xarr.py @@ -125,7 +125,7 @@ def df_from_xarray_coords(xarr, dim): None is returned. ''' import pandas as pd - use_dims = xr_find_nested_dims(xarr, dim) + use_dims = find_nested_dims(xarr, dim) if len(use_dims) > 1: df = {dim: xarr.coords[dim].values for dim in use_dims} @@ -169,7 +169,7 @@ def _inherit_metadata(coords, metadata, dimname, tri=None): def _inherit_metadata_from_xarray(xarr_from, xarr_to, dimname, copy_coords=None): if copy_coords is None: - copy_coords = xr_find_nested_dims(xarr_from, dimname) + copy_coords = find_nested_dims(xarr_from, dimname) if len(copy_coords) > 0: coords = {coord: (dimname, xarr_from.coords[coord].values) for coord in copy_coords} @@ -177,16 +177,18 @@ def _inherit_metadata_from_xarray(xarr_from, xarr_to, dimname, return xarr_to -def xr_find_nested_dims(arr, dim_name): +def find_nested_dims(arr, dim_name): names = list() coords = list(arr.coords) if isinstance(dim_name, tuple): for dim in dim_name: - coords.remove(dim) + if dim in coords: + coords.remove(dim) sub_dim = dim_name else: - coords.remove(dim_name) + if dim_name in coords: + coords.remove(dim_name) sub_dim = (dim_name,) for coord in coords: diff --git a/pylabianca/viz.py b/pylabianca/viz.py index 337e524..a98ae35 100644 --- a/pylabianca/viz.py +++ b/pylabianca/viz.py @@ -42,6 +42,8 @@ def plot_shaded(arr, reduce_dim=None, groupby=None, ax=None, is ``None`` which uses the default matplotlib color cycle. 
labels : bool Whether to add labels to the axes. + kwargs : dict + Additional keyword arguments for the plot. Returns ------- @@ -130,6 +132,10 @@ def plot_xarray_shaded(arr, reduce_dim=None, x_dim='time', groupby=None, ax=None, legend=True, legend_pos=None, colors=None, **kwargs): """ + Plot xarray with error bar shade. + + Parameters + ---------- arr : xarray.DataArray Xarray with at least two dimensions: one is plotted along the x axis (this is controlled with ``x_dim`` argument); the other is reduced @@ -156,6 +162,13 @@ def plot_xarray_shaded(arr, reduce_dim=None, x_dim='time', groupby=None, List of RGB arrays to use as colors for condition groups. Can also be a dictionary linking condition names / values and RBG arrays. Default is ``None`` which uses the default matplotlib color cycle. + kwargs : dict + Additional keyword arguments for the plot. + + Returns + ------- + ax : matplotlib.Axes + Axis with the plot. """ import matplotlib.pyplot as plt assert reduce_dim is not None @@ -595,9 +608,9 @@ def plot_raster(spk, pick=0, groupby=None, ax=None, colors=None, labels=True, return ax -def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None, - pvals=None, pick=0, p_threshold=0.05, min_pval=0.001, ax=None, - eventplot_kwargs=None): +def plot_spikes(spk, frate, groupby=None, colors=None, df_clst=None, + clusters=None, pvals=None, pick=0, p_threshold=0.05, + min_pval=0.001, ax=None, eventplot_kwargs=None, fontsize=12): '''Plot average spike rate and spike raster. spk : pylabianca.spikes.SpikeEpochs @@ -607,6 +620,8 @@ def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None, ``spk.spike_density()``. groupby : str | None How to group the plots. If None, no grouping is done. + colors : list of str | None + Colors to use for each group. If None, uses the default color cycle. df_clst : pandas.DataFrame | None DataFrame with cluster time ranges and p values. If None, no cluster information is shown.
This argument is to support results obtained @@ -632,6 +647,8 @@ def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None, is used for raster plot. If None, a new figure is created. eventplot_kwargs : dict | None Additional keyword arguments for the eventplot. + fontsize : int + Font size for the labels. Returns ------- @@ -655,7 +672,7 @@ def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None, else: assert(len(ax) == 2) fig = ax[0].figure - plot_shaded(this_frate, groupby=groupby, ax=ax[0]) + plot_shaded(this_frate, groupby=groupby, colors=colors, ax=ax[0]) # add highlight add_highlight = (df_clst is not None) or ( @@ -669,15 +686,15 @@ def plot_spikes(spk, frate, groupby=None, df_clst=None, clusters=None, add_highlights(this_frate, clusters, pvals, ax=ax[0], p_threshold=p_threshold, min_pval=min_pval) - plot_raster(spk.copy().pick_cells(cell_name), pick=0, - groupby=groupby, ax=ax[1], eventplot_kwargs=eventplot_kwargs) + plot_raster(spk.copy().pick_cells(cell_name), pick=0, groupby=groupby, + colors=colors, ax=ax[1], eventplot_kwargs=eventplot_kwargs) ylim = ax[1].get_xlim() ax[0].set_xlim(ylim) ax[0].set_xlabel('') - ax[0].set_ylabel('Spike rate (Hz)', fontsize=12) - ax[1].set_xlabel('Time (s)', fontsize=12) - ax[1].set_ylabel('Trials', fontsize=12) + ax[0].set_ylabel('Spike rate (Hz)', fontsize=fontsize) + ax[1].set_xlabel('Time (s)', fontsize=fontsize) + ax[1].set_ylabel('Trials', fontsize=fontsize) return fig @@ -692,7 +709,7 @@ def _create_mask_from_window_str(window, frate): # TODO: could infer x coords from plot (if lines are already present) def add_highlights(arr, clusters, pvals, p_threshold=0.05, ax=None, min_pval=0.001, bottom_extend=True, pval_text=True, - text_props=None): + pval_fontsize=10, text_props=None): '''Highlight significant clusters along the last array dimension. 
Parameters @@ -729,6 +746,8 @@ def add_highlights(arr, clusters, pvals, p_threshold=0.05, ax=None, Dictionary with text properties for p value text boxes. If None, defaults to ``{'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.75, edgecolor='gray'}``. + pval_fontsize : int + Font size for the p value text. Returns ------- @@ -806,7 +825,7 @@ def add_highlights(arr, clusters, pvals, p_threshold=0.05, ax=None, p_txt = format_pvalue(this_pval) this_text = ax.text( - text_x, text_y, p_txt, + text_x, text_y, p_txt, fontsize=pval_fontsize, bbox=text_props, horizontalalignment='center' ) try: