Commit

Merge branch 'parashardhapola:master' into scarfMetrics
Gautam8387 authored Dec 17, 2024
2 parents da1f414 + 4e93e14 commit 9e80661
Showing 5 changed files with 161 additions and 105 deletions.
VERSION: 2 changes (1 addition, 1 deletion)
@@ -1 +1 @@
0.29.5
0.29.7
scarf/assay.py: 89 changes (66 additions, 23 deletions)
@@ -9,17 +9,18 @@
method for feature selection.
"""

from typing import Tuple, List, Generator, Optional, Union
from typing import Generator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import zarr
from dask.array.core import Array as daskArrayType
from dask.array.core import from_zarr
from scipy.sparse import csr_matrix, vstack
from zarr import hierarchy as z_hierarchy

from .metadata import MetaData
from .utils import show_dask_progress, controlled_compute, logger
from .utils import controlled_compute, logger, show_dask_progress

zarrGroup = z_hierarchy.Group

@@ -279,12 +280,18 @@ def _verify_keys(self, cell_key: str, feat_key: str) -> None:
feat_key: Name of the key (column) from feature attribute table
Returns: None
Note on type checking /GA:
1. ds.cells.get_dtype(cell_key) == bool returns True because numpy's dtype('bool') compares equal to Python's bool.
2. isinstance(ds.cells.get_dtype(cell_key), bool) returns False because dtype('bool') is a numpy.dtype object, not the native Python bool type.
3. Reason: isinstance checks the actual class of the object, which here is numpy.dtype, not bool.
"""
if cell_key not in self.cells.columns or self.cells.get_dtype(cell_key) != bool:
if cell_key not in self.cells.columns or self.cells.get_dtype(cell_key) != bool: # noqa: E721
raise ValueError(
f"ERROR: Either {cell_key} does not exist or is not bool type"
)
if feat_key not in self.feats.columns or self.feats.get_dtype(feat_key) != bool:
if feat_key not in self.feats.columns or self.feats.get_dtype(feat_key) != bool: # noqa: E721
raise ValueError(
f"ERROR: Either {feat_key} does not exist or is not bool type"
)
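
The behaviour described in the note can be checked directly with numpy alone; a minimal sketch:

    import numpy as np

    dt = np.dtype("bool")
    print(dt == bool)            # True: a numpy dtype compares equal to its Python counterpart
    print(isinstance(dt, bool))  # False: dt is a numpy.dtype instance, not a built-in bool
    print(type(dt))              # numpy.dtype (the exact repr varies with the numpy version)

Hence the equality check, with E721 suppressed, rather than isinstance.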
@@ -526,9 +533,10 @@ def iter_normed_feature_wise(
columns=feat_idx[chunk],
)
else:
yield controlled_compute(data[:, chunk], self.nthreads).T, feat_idx[
chunk
]
yield (
controlled_compute(data[:, chunk], self.nthreads).T,
feat_idx[chunk],
)
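
The reformatted yield still produces the same (values, feature-indices) pair; a hypothetical consumer, mirroring the positional arguments visible in save_normed_for_query below:

    # Hypothetical sketch: `assay` is assumed to be an Assay from this module, and the
    # positional arguments appear to be (cell_key, feat_key, batch_size, msg, as_dataframe).
    for values, feat_ids in assay.iter_normed_feature_wise(None, "I", 1000, "Iterating", False):
        ...  # values: transposed normalized chunk (features x cells); feat_ids: matching indices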

def save_normed_for_query(
self, feat_key: Optional[str], batch_size: int, overwrite: bool = True
@@ -549,6 +557,7 @@ def save_normed_for_query(
None
"""
from joblib import Parallel, delayed

from .writers import create_zarr_obj_array

def write_wrapper(idx: str, v: np.ndarray) -> None:
@@ -563,7 +572,8 @@ def write_wrapper(idx: str, v: np.ndarray) -> None:
None, feat_key, batch_size, "Saving features", False
):
Parallel(n_jobs=self.nthreads)(
delayed(write_wrapper)(inds[i], mat[i]) for i in range(len(inds)) # type: ignore
delayed(write_wrapper)(inds[i], mat[i])
for i in range(len(inds)) # type: ignore
)
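
The Parallel/delayed idiom used above fans each feature vector out to a worker; a self-contained sketch of the same pattern (write_one is a hypothetical stand-in for write_wrapper):

    import numpy as np
    from joblib import Parallel, delayed

    def write_one(idx: int, v: np.ndarray) -> None:
        # Stand-in for write_wrapper: persist one feature vector under its index.
        print(idx, v.sum())

    mat = np.random.rand(4, 10)  # four feature vectors
    Parallel(n_jobs=2)(delayed(write_one)(i, mat[i]) for i in range(mat.shape[0]))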

def save_aggregated_ordering(
@@ -888,6 +898,51 @@ def set_feature_stats(self, cell_key: str) -> None:
self.feats.unmount_location(identifier)
return None

def set_summary_stats(
self, cell_key: Optional[str] = None, n_bins: int = 200, lowess_frac: float = 0.1
) -> Tuple[str, str]:
"""Calculates summary statistics for the features of the assay using only cells that are marked True by the 'cell_key' parameter.
Args:
cell_key: Name of the key (column) from cell attribute table.
n_bins: Number of bins to divide the data into.
lowess_frac: Between 0 and 1. The fraction of the data used when estimating the fit between mean and
variance. This is the same as `frac` in statsmodels.nonparametric.smoothers_lowess.lowess.
Returns:
A tuple of two strings.
identifier: The text that will be prepended to column names when summary statistics are loaded onto the feature attributes table.
c_var_col: The name of the column in the feature attribute table that contains the corrected variance values.
"""

def col_renamer(x):
return f"{identifier}_{x}"

if cell_key is None:
cell_key = "I"

# check lowess_frac is between 0 and 1
if not 0 <= lowess_frac <= 1:
raise ValueError("lowess_frac must be between 0 and 1")

self.set_feature_stats(cell_key)
identifier = self._load_stats_loc(cell_key)
c_var_col = f"c_var__{n_bins}__{lowess_frac}"
if col_renamer(c_var_col) in self.feats.columns:
logger.info("Using existing corrected dispersion values")
else:
slots = ["normed_tot", "avg", "nz_mean", "sigmas", "normed_n"]
for i in slots:
i = col_renamer(i)
if i not in self.feats.columns:
raise KeyError(f"ERROR: {i} not found in feature metadata")
c_var = self.feats.remove_trend(
col_renamer("avg"), col_renamer("sigmas"), n_bins, lowess_frac
)
self.feats.insert(c_var_col, c_var, overwrite=True, location=identifier)

return identifier, c_var_col
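
A hypothetical call against a loaded DataStore (the name ds and the zarr path are assumptions, not part of this diff):

    import scarf

    ds = scarf.DataStore("data.zarr")
    identifier, c_var_col = ds.RNA.set_summary_stats(cell_key="I", n_bins=200, lowess_frac=0.1)
    # The corrected variance is now available as the f"{identifier}_{c_var_col}" column
    # of the feature attribute table (ds.RNA.feats).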

# maybe we should return the plot here, in case one wants to modify it? /raz
def mark_hvgs(
self,
@@ -950,21 +1005,9 @@ def mark_hvgs(
def col_renamer(x):
return f"{identifier}_{x}"

self.set_feature_stats(cell_key)
identifier = self._load_stats_loc(cell_key)
c_var_col = f"c_var__{n_bins}__{lowess_frac}"
if col_renamer(c_var_col) in self.feats.columns:
logger.info("Using existing corrected dispersion values")
else:
slots = ["normed_tot", "avg", "nz_mean", "sigmas", "normed_n"]
for i in slots:
i = col_renamer(i)
if i not in self.feats.columns:
raise KeyError(f"ERROR: {i} not found in feature metadata")
c_var = self.feats.remove_trend(
col_renamer("avg"), col_renamer("sigmas"), n_bins, lowess_frac
)
self.feats.insert(c_var_col, c_var, overwrite=True, location=identifier)
logger.info("Calculating summary statistics")
identifier, c_var_col = self.set_summary_stats(cell_key, n_bins, lowess_frac)
logger.info("Calculating HVGs")

if max_mean != np.inf:
max_mean = 2**max_mean
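With this refactor, mark_hvgs delegates the dispersion correction to set_summary_stats and then applies the HVG thresholds; note that max_mean is supplied in log2 scale and converted with 2**max_mean unless it is np.inf. A hypothetical call, using only the parameters visible in this hunk:

    # Hypothetical: ds is a loaded scarf.DataStore as in the earlier sketch.
    ds.RNA.mark_hvgs(cell_key="I", n_bins=200, lowess_frac=0.1, max_mean=6)  # 2**6 = 64
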
scarf/mapping_utils.py: 4 changes (3 additions, 1 deletion)
@@ -102,6 +102,7 @@ def _order_features(
filter_null: bool,
exclude_missing: bool,
nthreads: int,
target_cell_key: str = "I",
) -> Tuple[np.ndarray, np.ndarray]:
s_ids = pd.Series(s_assay.feats.fetch_all("ids"))
t_ids = pd.Series(t_assay.feats.fetch_all("ids"))
@@ -119,7 +120,7 @@
t_idx[t_idx] = (
controlled_compute(
t_assay.rawData[:, list(t_idx[t_idx].index)][
t_assay.cells.active_index("I"), :
t_assay.cells.active_index(target_cell_key), :
].sum(axis=0),
nthreads,
)
@@ -181,6 +182,7 @@ def align_features(
filter_null,
exclude_missing,
nthreads,
target_cell_key,
)
logger.info(f"{(t_idx == -1).sum()} features missing in target data")
normed_loc = f"normed__{source_cell_key}__{source_feat_key}"
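
The new target_cell_key parameter replaces the previously hard-coded "I" when summing target counts, so feature alignment can respect any boolean cell subset; the gated sum, in isolation (t_assay is assumed to be the target Assay):

    # Sketch of the effect of target_cell_key, extracted from _order_features above:
    active = t_assay.cells.active_index(target_cell_key)  # row indices where the column is True
    counts = t_assay.rawData[active, :].sum(axis=0)       # per-feature totals over that subset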
scarf/merge.py: 20 changes (10 additions, 10 deletions)
@@ -15,11 +15,7 @@
from dask.array.core import Array as daskArrayType
from scipy.sparse import coo_matrix

from .assay import (
ADTassay,
ATACassay,
RNAassay,
)
from .assay import Assay
from .datastore.datastore import DataStore
from .metadata import MetaData
from .utils import (
@@ -96,7 +92,7 @@ class AssayMerge:
def __init__(
self,
zarr_path: ZARRLOC,
assays: List[Union[RNAassay, ATACassay, ADTassay]],
assays: List[Assay],
names: List[str],
merge_assay_name: str,
in_workspaces: Union[list[str], None] = None,
@@ -124,10 +120,10 @@ def __init__(
)
self.nCells: int = self.mergedCells.shape[0]
self.featCollection: List[Dict[str, str]] = self._get_feat_ids(assays)
self.feat_suffix: Dict[int, int] = self.get_feat_suffix()
self.feat_name_ids_same: bool = self.check_feat_ids(self.featCollection)

if self.feat_name_ids_same is True:
self.feat_suffix: Dict[int, int] = self.get_feat_suffix()
self.featCollection = self.update_feat_ids()
self.featCollection_map: List[Dict[str, str]] = (
self.update_feat_ids_for_map()
@@ -197,7 +193,7 @@ def perform_randomization_rows(
for i in range(len(permutations)):
in__dict: dict[int, np.ndarray] = {}
last_key = i - 1 if i > 0 else 0
offset = nCells[last_key] + offset if i > 0 else 0
offset = nCells[last_key] + offset if i > 0 else 0 # noqa: F821
for j, arr in enumerate(permutations[i]):
in__dict[j] = arr + offset
permutations_rows_offset[i] = in__dict
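
The offset bookkeeping shifts each dataset's permuted row indices by the cumulative number of cells in all preceding datasets; the # noqa: F821 silences flake8, which cannot see that offset is only read when i > 0. A self-contained sketch of the same logic (toy values):

    import numpy as np

    nCells = {0: 3, 1: 2, 2: 4}  # per-dataset cell counts
    permutations = {0: [np.array([2, 0, 1])], 1: [np.array([1, 0])], 2: [np.array([3, 2, 0, 1])]}

    offset = 0
    rows_offset = {}
    for i in range(len(permutations)):
        if i > 0:
            offset += nCells[i - 1]  # cells contributed by all previous datasets
        rows_offset[i] = {j: arr + offset for j, arr in enumerate(permutations[i])}

    # rows_offset[1][0] == array([4, 3]); rows_offset[2][0] == array([8, 7, 5, 6])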
@@ -580,7 +576,9 @@ def _ini_cell_data(self, overwrite) -> None:
f"cellData already exists so skipping _ini_cell_data" # noqa: F541
)

def _dask_to_coo(self, d_arr, order: np.ndarray, order_map: np.ndarray, n_threads: int) -> coo_matrix:
def _dask_to_coo(
self, d_arr, order: np.ndarray, order_map: np.ndarray, n_threads: int
) -> coo_matrix:
"""
Convert a Dask array to a sparse COO matrix.
Args:
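
A minimal sketch of the dask-to-COO conversion this helper performs (the order/order_map reshuffling is omitted; only dask and scipy are assumed):

    import dask.array as da
    import numpy as np
    from scipy.sparse import coo_matrix

    d_arr = da.from_array(np.eye(4), chunks=2)  # stand-in for a chunk of raw counts
    dense = d_arr.compute()                     # scarf routes this through controlled_compute
    sparse = coo_matrix(dense)                  # keep only the non-zero entries
    print(sparse.nnz)                           # 4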
@@ -780,7 +778,9 @@ def generate_dummy_assay(self, ds: DataStore, assay_name: str) -> DummyAssay:

# Create a dummy assay with zero counts and matching features
dummy_shape = (ds.cells.N, reference_assay.feats.N)
dummy_counts = zarr.zeros(dummy_shape, chunks=chunkShape, dtype=reference_assay.rawData.dtype)
dummy_counts = zarr.zeros(
dummy_shape, chunks=chunkShape, dtype=reference_assay.rawData.dtype
)
dummy_counts = from_array(dummy_counts, chunks=chunkShape)
dummy_assay = DummyAssay(
ds, dummy_counts, reference_assay.feats, reference_assay.name
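The dummy assay is backed entirely by zeros, shaped to the merged cell axis and the reference assay's feature axis; a standalone sketch of the zarr-to-dask construction used here (shapes and dtype are toy values):

    import zarr
    from dask.array import from_array

    dummy_shape = (100, 500)  # (cells in merged store, features in reference assay)
    chunk_shape = (50, 500)
    dummy_counts = zarr.zeros(dummy_shape, chunks=chunk_shape, dtype="uint32")
    dummy_counts = from_array(dummy_counts, chunks=chunk_shape)
    assert dummy_counts.sum().compute() == 0  # every cell has zero counts in the dummy assay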