diff --git a/src/pyrovelocity/plots/_rainbow.py b/src/pyrovelocity/plots/_rainbow.py index 8a9265fe8..bf3d642e1 100644 --- a/src/pyrovelocity/plots/_rainbow.py +++ b/src/pyrovelocity/plots/_rainbow.py @@ -223,16 +223,31 @@ def plot_gene_on_embedding( show_data: bool, st_std: Optional[NDArray[Any]] = None, dotsize: int = 1, + plot_individual_obs: bool = False, + gridsize: int = 100, ): (index,) = np.where(adata.var_names == gene) - im = axes_dict[f"predictive_{n}"].scatter( - adata.obsm[f"X_{basis}"][:, 0], - adata.obsm[f"X_{basis}"][:, 1], - s=dotsize, - c=st[:, index].flatten(), - cmap="cividis", - edgecolors="none", - ) + + if plot_individual_obs: + im = axes_dict[f"predictive_{n}"].scatter( + adata.obsm[f"X_{basis}"][:, 0], + adata.obsm[f"X_{basis}"][:, 1], + s=dotsize, + c=st[:, index].flatten(), + cmap="cividis", + edgecolors="none", + ) + else: + im = axes_dict[f"predictive_{n}"].hexbin( + adata.obsm[f"X_{basis}"][:, 0], + adata.obsm[f"X_{basis}"][:, 1], + C=st[:, index].flatten(), + gridsize=gridsize, + cmap="cividis", + edgecolors="none", + reduce_C_function=np.mean, + ) + set_colorbar( im, axes_dict[f"predictive_{n}"], @@ -244,14 +259,26 @@ def plot_gene_on_embedding( axes_dict[f"predictive_{n}"].axis("off") if st_std is not None: - im = axes_dict[f"cv_{n}"].scatter( - adata.obsm[f"X_{basis}"][:, 0], - adata.obsm[f"X_{basis}"][:, 1], - s=dotsize, - c=st_std[:, index].flatten() / st[:, index].flatten(), - cmap="cividis", - edgecolors="none", - ) + if plot_individual_obs: + im = axes_dict[f"cv_{n}"].scatter( + adata.obsm[f"X_{basis}"][:, 0], + adata.obsm[f"X_{basis}"][:, 1], + s=dotsize, + c=st_std[:, index].flatten() / st[:, index].flatten(), + cmap="cividis", + edgecolors="none", + ) + else: + im = axes_dict[f"cv_{n}"].hexbin( + adata.obsm[f"X_{basis}"][:, 0], + adata.obsm[f"X_{basis}"][:, 1], + C=st_std[:, index].flatten() / st[:, index].flatten(), + gridsize=gridsize, + cmap="cividis", + edgecolors="none", + reduce_C_function=np.mean, + ) + set_colorbar( im, axes_dict[f"cv_{n}"], @@ -263,14 +290,26 @@ def plot_gene_on_embedding( axes_dict[f"cv_{n}"].axis("off") if show_data: - im = axes_dict[f"data_{n}"].scatter( - adata.obsm[f"X_{basis}"][:, 0], - adata.obsm[f"X_{basis}"][:, 1], - s=dotsize, - c=ensure_numpy_array(adata[:, index].X).flatten(), - cmap="cividis", - edgecolors="none", - ) + if plot_individual_obs: + im = axes_dict[f"data_{n}"].scatter( + adata.obsm[f"X_{basis}"][:, 0], + adata.obsm[f"X_{basis}"][:, 1], + s=dotsize, + c=ensure_numpy_array(adata[:, index].X).flatten(), + cmap="cividis", + edgecolors="none", + ) + else: + im = axes_dict[f"data_{n}"].hexbin( + adata.obsm[f"X_{basis}"][:, 0], + adata.obsm[f"X_{basis}"][:, 1], + C=ensure_numpy_array(adata[:, index].X).flatten(), + gridsize=gridsize, + cmap="cividis", + edgecolors="none", + reduce_C_function=np.mean, + ) + set_colorbar( im, axes_dict[f"data_{n}"],