From edca2ed860f2b5a7e5fb34d9f0d193926e48cf84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oskar=20B=C3=B6rjesson?= Date: Tue, 22 Oct 2024 14:06:10 +0200 Subject: [PATCH 1/6] Fix mapping key (#129) --- scarf/mapping_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scarf/mapping_utils.py b/scarf/mapping_utils.py index 7bfd81e..cd10f5c 100644 --- a/scarf/mapping_utils.py +++ b/scarf/mapping_utils.py @@ -102,6 +102,7 @@ def _order_features( filter_null: bool, exclude_missing: bool, nthreads: int, + target_cell_key: str = "I", ) -> Tuple[np.ndarray, np.ndarray]: s_ids = pd.Series(s_assay.feats.fetch_all("ids")) t_ids = pd.Series(t_assay.feats.fetch_all("ids")) @@ -119,7 +120,7 @@ def _order_features( t_idx[t_idx] = ( controlled_compute( t_assay.rawData[:, list(t_idx[t_idx].index)][ - t_assay.cells.active_index("I"), : + t_assay.cells.active_index(target_cell_key), : ].sum(axis=0), nthreads, ) @@ -181,6 +182,7 @@ def align_features( filter_null, exclude_missing, nthreads, + target_cell_key, ) logger.info(f"{(t_idx == -1).sum()} features missing in target data") normed_loc = f"normed__{source_cell_key}__{source_feat_key}" From ca7d6cc5bbc87a80674b4924b324b077c6e2f5c1 Mon Sep 17 00:00:00 2001 From: Gautam Ahuja Date: Wed, 23 Oct 2024 20:38:39 +0530 Subject: [PATCH 2/6] CrDirReader Update [Polars -> Pandas] (#130) * Revert to Pandas from Polars for improved efficiency * Upadted Pandas and Code clean-up * Refactor process_batch to use Polars for faster aggregation and filtering --- scarf/readers.py | 130 ++++++++++++++++++++++++++--------------------- 1 file changed, 73 insertions(+), 57 deletions(-) diff --git a/scarf/readers.py b/scarf/readers.py index 1fc905c..69ef4f9 100644 --- a/scarf/readers.py +++ b/scarf/readers.py @@ -9,11 +9,11 @@ - LoomReader: A class to read in data in the form of a Loom file. """ +import math import os from abc import ABC, abstractmethod -from typing import Generator, Dict, List, Optional, Tuple -from typing import IO -import math +from typing import IO, Dict, Generator, List, Optional, Tuple + import h5py import numpy as np import pandas as pd @@ -392,38 +392,47 @@ def _read_dataset(self, key: Optional[str] = None): vals = None return vals - def read_header(self) -> pl.DataFrame: - header = pl.read_csv( + def read_header(self) -> pd.DataFrame: + header = pd.read_csv( self.matFn, - comment_prefix = '%', - separator=self.sep, - has_header=False, - n_rows=1, - new_columns=["nFeatures", "nCells", "nCounts"], + comment="%", + sep=self.sep, + header=None, + nrows=1, + names=["nFeatures", "nCells", "nCounts"], ) - if header['nCells'][0] == 0 and self.nCells > 0: - raise ValueError("ERROR: Barcode count in MTX header is 0 but barcodes are present in the barcodes file") - if header['nCells'][0] > 0 and self.nCells == 0: - raise ValueError("ERROR: Barcode count in MTX header is greater than 0 but no barcodes are present in the barcodes file") - if header['nCells'][0] == 0 and self.nCells == 0: - raise ValueError("ERROR: Barcode count in MTX header and barcodes file is 0. No data to read") + if header["nCells"][0] == 0 and self.nCells > 0: + raise ValueError( + "ERROR: Barcode count in MTX header is 0 but barcodes are present in the barcodes file" + ) + if header["nCells"][0] > 0 and self.nCells == 0: + raise ValueError( + "ERROR: Barcode count in MTX header is greater than 0 but no barcodes are present in the barcodes file" + ) + if header["nCells"][0] == 0 and self.nCells == 0: + raise ValueError( + "ERROR: Barcode count in MTX header and barcodes file is 0. No data to read" + ) return header - def process_batch(self, dfs: pl.DataFrame, filtering_cutoff: int) -> List: + def process_batch(self, dfs: List[pd.DataFrame], filtering_cutoff: int) -> np.array: """Returns a list of valid barcodes after filtering out background barcodes for a given batch. Args: dfs: A Polar DataFrame containing a chunk of data from the MTX file. filtering_cutoff: The cutoff value for filtering out background barcodes """ - dfs_ = dfs.group_by('barcode').agg(pl.sum('count')) + pl_dfs = [pl.DataFrame(df) for df in dfs] + pl_dfs = pl.concat(pl_dfs) + dfs_ = pl_dfs.group_by('barcode').agg(pl.sum('count')) dfs_ = dfs_.filter(pl.col('count') > filtering_cutoff) return np.sort(dfs_['barcode']) def _get_valid_barcodes( - self, filtering_cutoff: int, - batch_size: int = int(10e4), - lines_in_mem: int = int(10e6) + self, + filtering_cutoff: int, + batch_size: int = int(10e3), + lines_in_mem: int = int(10e6), ) -> np.ndarray: """Returns a list of valid barcodes after filtering out background barcodes. @@ -433,48 +442,53 @@ def _get_valid_barcodes( lines_in_mem: The number of lines to read into memory """ test_counter = 0 - matrixIO = pl.scan_csv( - self.matFn, - comment_prefix='%', - # skip_rows=3, - skip_rows_after_header=1, - separator=self.sep, - has_header=False, + matrixIO = pd.read_csv( + self.matFn, + comment="%", + sep=self.sep, + header=0, + chunksize=lines_in_mem, + names=["gene", "barcode", "count"], ) - assert len(matrixIO.collect_schema().names()) == 3 - matrixIO = matrixIO.rename({'column_1': 'gene', 'column_2': 'barcode', 'column_3': 'count'}) + header = self.read_header() nChunks = math.ceil(header["nCounts"][0] / lines_in_mem) test_counter = 0 valid_idx = [] start = 1 - dfs = pl.DataFrame() - for i in tqdmbar( - range(nChunks), desc="Filtering out background barcodes" + + dfs = [] + for chunk in tqdmbar( + # range(nChunks), + matrixIO, + total=nChunks, + desc="Filtering out background barcodes", ): - chunk = matrixIO.slice(i*lines_in_mem, lines_in_mem).collect() - # Check if we've reached or exceeded the current batch boundary - if (chunk[-1]['barcode'][0] - start) >= batch_size: # If the last "cell id" is greater than the start + batch size + if ( + (chunk.iloc[-1]["barcode"] - start) >= batch_size + ): # If the last "cell id" is greater than the start + batch size # Filter rows in the current chunk that belong to the current batch - idx = np.array(chunk['barcode'] < (batch_size + start)) # This is the crucial line. This makes sure that if any cell ID is spread over multiple chunks, it is not missed, as any cell ID that is less than the batch size + start is included. + idx = np.array( + chunk["barcode"].values < (batch_size + start) + ) # This is the crucial line. This makes sure that if any cell ID is spread over multiple chunks, it is not missed, as any cell ID that is less than the batch size + start is included. # If no rows belong to the current batch, move to the next batch. if idx.sum() == 0: - dfs = pl.concat([dfs, chunk]) + dfs.append(chunk) start += batch_size test_counter += len(chunk) continue # Process the rows belonging to the current batch mask_pos = np.where(idx)[0] mask_neg = np.where(~idx)[0] - dfs = pl.concat([dfs, chunk[mask_pos]]) + dfs.append(chunk.iloc[mask_pos]) valid_idx.append(self.process_batch(dfs, filtering_cutoff)) # Prepare for the next batch del dfs - dfs = chunk[mask_neg] + dfs = [chunk.iloc[mask_neg]] start += batch_size else: # If we haven't reached the batch boundary, accumulate the chunk - dfs = pl.concat([dfs, chunk]) + dfs.append(chunk) test_counter += len(chunk) # Process any remaining data after the main loop if len(dfs) > 0: @@ -512,7 +526,7 @@ def cell_names(self) -> List[str]: def rename_batches(self, collect: List[pl.DataFrame], batch_size: int) -> List: df = pl.concat(collect) - barcodes = np.array(df['barcode']) + barcodes = np.array(df["barcode"]) count_hash = {} for i, x in enumerate(np.unique(barcodes)): count_hash[x] = i @@ -535,14 +549,14 @@ def consume( dtype: The data type of the matrix. """ matrixIO = pl.read_csv_batched( - self.matFn, - has_header=False, + self.matFn, + has_header=False, separator=self.sep, comment_prefix="%", - skip_rows_after_header=1, - new_columns=['gene', 'barcode', 'count'], - schema_overrides={'gene': pl.Int64, 'barcode': pl.Int64, 'count': pl.Int64}, - batch_size=lines_in_mem + skip_rows_after_header=1, + new_columns=["gene", "barcode", "count"], + schema_overrides={"gene": pl.Int64, "barcode": pl.Int64, "count": pl.Int64}, + batch_size=lines_in_mem, ) unique_list = [] collect = [] @@ -551,20 +565,20 @@ def consume( if chunk is None: break chunk = chunk[0] - chunk = chunk.filter(pl.col('barcode').is_in(self.validBarcodeIdx)) - in_uniques = np.unique(chunk['barcode']) + chunk = chunk.filter(pl.col("barcode").is_in(self.validBarcodeIdx)) + in_uniques = np.unique(chunk["barcode"]) unique_list.extend(in_uniques) unique_list = list(set(unique_list)) if len(unique_list) > batch_size: diff = batch_size - (len(unique_list) - len(in_uniques)) mask_pos = in_uniques[:diff] mask_neg = in_uniques[diff:] - extra = chunk.filter(pl.col('barcode').is_in(mask_pos)) + extra = chunk.filter(pl.col("barcode").is_in(mask_pos)) collect.append(extra) collect = self.rename_batches(collect, batch_size) mtx = self.to_sparse(np.array(collect), dtype=dtype) yield mtx - left_out = chunk.filter(pl.col('barcode').is_in(mask_neg)) + left_out = chunk.filter(pl.col("barcode").is_in(mask_neg)) collect = [] unique_list = list(mask_neg) collect.append(left_out) @@ -635,8 +649,9 @@ def __init__( self.obsmAttrsKey: self._validate_group(self.obsmAttrsKey), self.matrixKey: self._validate_group(self.matrixKey), } - self.nCells, self.nFeatures = self._get_n(self.cellAttrsKey), self._get_n( - self.featureAttrsKey + self.nCells, self.nFeatures = ( + self._get_n(self.cellAttrsKey), + self._get_n(self.featureAttrsKey), ) self.cellIdsKey = self._fix_name_key(self.cellAttrsKey, cell_ids_key) self.featIdsKey = self._fix_name_key(self.featureAttrsKey, feature_ids_key) @@ -809,8 +824,9 @@ def _get_col_data( if i in ignore_keys: continue if isinstance(self.h5[group][i], h5py.Dataset): - yield i, self._replace_category_values( - self.h5[group][i][:], i, group + yield ( + i, + self._replace_category_values(self.h5[group][i][:], i, group), ) def _get_obsm_data( @@ -832,7 +848,7 @@ def _get_obsm_data( yield f"{i}{j+1}", g[:, j] else: logger.warning( - f"Reading of obsm failed because it either does not exist or is not in expected format" # noqa: F541 + f"Reading of obsm failed because it either does not exist or is not in expected format" # noqa: F541 ) def get_cell_columns(self) -> Generator[Tuple[str, np.ndarray], None, None]: From 31fe77208ba71aeff842e1003d707d1680d67b9f Mon Sep 17 00:00:00 2001 From: Parashar Date: Wed, 23 Oct 2024 18:13:38 +0200 Subject: [PATCH 3/6] Version bump --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 88f8ee8..f95b087 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.29.5 +0.29.7 From fde515c4d758318777c9fa677376b16fd07fdb6d Mon Sep 17 00:00:00 2001 From: Gautam Ahuja Date: Mon, 16 Dec 2024 16:44:44 +0530 Subject: [PATCH 4/6] run get_feat_suffix() only when feature names and feature ids are same (#135) --- scarf/merge.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/scarf/merge.py b/scarf/merge.py index 0e5f9b5..0c60133 100644 --- a/scarf/merge.py +++ b/scarf/merge.py @@ -15,11 +15,7 @@ from dask.array.core import Array as daskArrayType from scipy.sparse import coo_matrix -from .assay import ( - ADTassay, - ATACassay, - RNAassay, -) +from .assay import Assay from .datastore.datastore import DataStore from .metadata import MetaData from .utils import ( @@ -96,7 +92,7 @@ class AssayMerge: def __init__( self, zarr_path: ZARRLOC, - assays: List[Union[RNAassay, ATACassay, ADTassay]], + assays: List[Assay], names: List[str], merge_assay_name: str, in_workspaces: Union[list[str], None] = None, @@ -124,10 +120,10 @@ def __init__( ) self.nCells: int = self.mergedCells.shape[0] self.featCollection: List[Dict[str, str]] = self._get_feat_ids(assays) - self.feat_suffix: Dict[int, int] = self.get_feat_suffix() self.feat_name_ids_same: bool = self.check_feat_ids(self.featCollection) if self.feat_name_ids_same is True: + self.feat_suffix: Dict[int, int] = self.get_feat_suffix() self.featCollection = self.update_feat_ids() self.featCollection_map: List[Dict[str, str]] = ( self.update_feat_ids_for_map() @@ -197,7 +193,7 @@ def perform_randomization_rows( for i in range(len(permutations)): in__dict: dict[int, np.ndarray] = {} last_key = i - 1 if i > 0 else 0 - offset = nCells[last_key] + offset if i > 0 else 0 + offset = nCells[last_key] + offset if i > 0 else 0 # noqa: F821 for j, arr in enumerate(permutations[i]): in__dict[j] = arr + offset permutations_rows_offset[i] = in__dict @@ -580,7 +576,9 @@ def _ini_cell_data(self, overwrite) -> None: f"cellData already exists so skipping _ini_cell_data" # noqa: F541 ) - def _dask_to_coo(self, d_arr, order: np.ndarray, order_map: np.ndarray, n_threads: int) -> coo_matrix: + def _dask_to_coo( + self, d_arr, order: np.ndarray, order_map: np.ndarray, n_threads: int + ) -> coo_matrix: """ Convert a Dask array to a sparse COO matrix. Args: @@ -780,7 +778,9 @@ def generate_dummy_assay(self, ds: DataStore, assay_name: str) -> DummyAssay: # Create a dummy assay with zero counts and matching features dummy_shape = (ds.cells.N, reference_assay.feats.N) - dummy_counts = zarr.zeros(dummy_shape, chunks=chunkShape, dtype=reference_assay.rawData.dtype) + dummy_counts = zarr.zeros( + dummy_shape, chunks=chunkShape, dtype=reference_assay.rawData.dtype + ) dummy_counts = from_array(dummy_counts, chunks=chunkShape) dummy_assay = DummyAssay( ds, dummy_counts, reference_assay.feats, reference_assay.name From fd40926ce3d14c76475d7a7b43126b1e4bd88ac3 Mon Sep 17 00:00:00 2001 From: Gautam Ahuja Date: Tue, 17 Dec 2024 14:07:31 +0530 Subject: [PATCH 5/6] Updated consume() in CrDirReader [Polars -> Pandas] (#134) * Updated consume() in CrDirReader. Fixed polar issue for reading compressed file * Comment Cleanup --- scarf/readers.py | 43 +++++++++++++++++++------------------------ 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/scarf/readers.py b/scarf/readers.py index 69ef4f9..148059f 100644 --- a/scarf/readers.py +++ b/scarf/readers.py @@ -424,15 +424,15 @@ def process_batch(self, dfs: List[pd.DataFrame], filtering_cutoff: int) -> np.ar """ pl_dfs = [pl.DataFrame(df) for df in dfs] pl_dfs = pl.concat(pl_dfs) - dfs_ = pl_dfs.group_by('barcode').agg(pl.sum('count')) - dfs_ = dfs_.filter(pl.col('count') > filtering_cutoff) - return np.sort(dfs_['barcode']) + dfs_ = pl_dfs.group_by("barcode").agg(pl.sum("count")) + dfs_ = dfs_.filter(pl.col("count") > filtering_cutoff) + return np.sort(dfs_["barcode"]) def _get_valid_barcodes( self, filtering_cutoff: int, batch_size: int = int(10e3), - lines_in_mem: int = int(10e6), + lines_in_mem: int = int(10e5), ) -> np.ndarray: """Returns a list of valid barcodes after filtering out background barcodes. @@ -524,7 +524,8 @@ def cell_names(self) -> List[str]: vals = vals[(self.validBarcodeIdx + self.indexOffset)] return list(vals) - def rename_batches(self, collect: List[pl.DataFrame], batch_size: int) -> List: + def rename_batches(self, collect: List[pd.DataFrame]) -> List: + collect = [pl.DataFrame(df) for df in collect] df = pl.concat(collect) barcodes = np.array(df["barcode"]) count_hash = {} @@ -548,44 +549,38 @@ def consume( lines_in_mem: The number of lines to read into memory. dtype: The data type of the matrix. """ - matrixIO = pl.read_csv_batched( + matrixIO = pd.read_csv( self.matFn, - has_header=False, - separator=self.sep, - comment_prefix="%", - skip_rows_after_header=1, - new_columns=["gene", "barcode", "count"], - schema_overrides={"gene": pl.Int64, "barcode": pl.Int64, "count": pl.Int64}, - batch_size=lines_in_mem, + comment="%", + sep=self.sep, + header=0, + chunksize=lines_in_mem, + names=["gene", "barcode", "count"], ) unique_list = [] collect = [] - while True: - chunk = matrixIO.next_batches(1) - if chunk is None: - break - chunk = chunk[0] - chunk = chunk.filter(pl.col("barcode").is_in(self.validBarcodeIdx)) - in_uniques = np.unique(chunk["barcode"]) + for chunk in matrixIO: + chunk = chunk[chunk["barcode"].isin(self.validBarcodeIdx)] + in_uniques = np.unique(chunk["barcode"].values) unique_list.extend(in_uniques) unique_list = list(set(unique_list)) if len(unique_list) > batch_size: diff = batch_size - (len(unique_list) - len(in_uniques)) mask_pos = in_uniques[:diff] mask_neg = in_uniques[diff:] - extra = chunk.filter(pl.col("barcode").is_in(mask_pos)) + extra = chunk[chunk["barcode"].isin(mask_pos)] collect.append(extra) - collect = self.rename_batches(collect, batch_size) + collect = self.rename_batches(collect) mtx = self.to_sparse(np.array(collect), dtype=dtype) yield mtx - left_out = chunk.filter(pl.col("barcode").is_in(mask_neg)) + left_out = chunk[chunk["barcode"].isin(mask_neg)] collect = [] unique_list = list(mask_neg) collect.append(left_out) else: collect.append(chunk) if len(collect) > 0: - collect = self.rename_batches(collect, batch_size) + collect = self.rename_batches(collect) mtx = self.to_sparse(np.array(collect), dtype=dtype) yield mtx From 4e93e149f0d76288de6d4a3a61f7307405f4beef Mon Sep 17 00:00:00 2001 From: Gautam Ahuja Date: Tue, 17 Dec 2024 14:23:54 +0530 Subject: [PATCH 6/6] Update Mark HVGS (#133) * Separated summary stat from mark_hvgs. Checked type issues. * Added check for lowess_frac between 0 and 1 --- scarf/assay.py | 89 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 66 insertions(+), 23 deletions(-) diff --git a/scarf/assay.py b/scarf/assay.py index 477d8c0..83cf143 100644 --- a/scarf/assay.py +++ b/scarf/assay.py @@ -9,17 +9,18 @@ method for feature selection. """ -from typing import Tuple, List, Generator, Optional, Union +from typing import Generator, List, Optional, Tuple, Union import numpy as np import pandas as pd +import zarr from dask.array.core import Array as daskArrayType from dask.array.core import from_zarr from scipy.sparse import csr_matrix, vstack from zarr import hierarchy as z_hierarchy from .metadata import MetaData -from .utils import show_dask_progress, controlled_compute, logger +from .utils import controlled_compute, logger, show_dask_progress zarrGroup = z_hierarchy.Group @@ -279,12 +280,18 @@ def _verify_keys(self, cell_key: str, feat_key: str) -> None: feat_key: Name of the key (column) from feature attribute table Returns: None + + Note on type checking /GA: + 1. ds.cells.get_dtype(cell_key) == bool returns True because dtype('bool') (from numpy) is conceptually equivalent to Python's bool. + 2. isinstance(ds.cells.get_dtype(cell_key), bool) returns False because dtype('bool') is a numpy.dtype object, not the native Python bool type. + 3. Reason: dtype('bool') is a numpy object, and isinstance checks for the exact class, which is numpy.dtype, not bool. + """ - if cell_key not in self.cells.columns or self.cells.get_dtype(cell_key) != bool: + if cell_key not in self.cells.columns or self.cells.get_dtype(cell_key) != bool: # noqa: E721 raise ValueError( f"ERROR: Either {cell_key} does not exist or is not bool type" ) - if feat_key not in self.feats.columns or self.feats.get_dtype(feat_key) != bool: + if feat_key not in self.feats.columns or self.feats.get_dtype(feat_key) != bool: # noqa: E721 raise ValueError( f"ERROR: Either {feat_key} does not exist or is not bool type" ) @@ -526,9 +533,10 @@ def iter_normed_feature_wise( columns=feat_idx[chunk], ) else: - yield controlled_compute(data[:, chunk], self.nthreads).T, feat_idx[ - chunk - ] + yield ( + controlled_compute(data[:, chunk], self.nthreads).T, + feat_idx[chunk], + ) def save_normed_for_query( self, feat_key: Optional[str], batch_size: int, overwrite: bool = True @@ -549,6 +557,7 @@ def save_normed_for_query( None """ from joblib import Parallel, delayed + from .writers import create_zarr_obj_array def write_wrapper(idx: str, v: np.ndarray) -> None: @@ -563,7 +572,8 @@ def write_wrapper(idx: str, v: np.ndarray) -> None: None, feat_key, batch_size, "Saving features", False ): Parallel(n_jobs=self.nthreads)( - delayed(write_wrapper)(inds[i], mat[i]) for i in range(len(inds)) # type: ignore + delayed(write_wrapper)(inds[i], mat[i]) + for i in range(len(inds)) # type: ignore ) def save_aggregated_ordering( @@ -888,6 +898,51 @@ def set_feature_stats(self, cell_key: str) -> None: self.feats.unmount_location(identifier) return None + def set_summary_stats( + self, cell_key: str = None, n_bins: int = 200, lowess_frac: float = 0.1 + ) -> Tuple[str, str]: + """Calculates summary statistics for the features of the assay using only cells that are marked True by the 'cell_key' parameter. + + Args: + cell_key: Name of the key (column) from cell attribute table. + n_bins: Number of bins to divide the data into. + lowess_frac: Between 0 and 1. The fraction of the data used when estimating the fit between mean and + variance. This is same as `frac` in statsmodels.nonparametric.smoothers_lowess.lowess + + Returns: + A tuple of two strings. + identifier: The text that will be prepended to column names when summary statistics are loaded onto the feature attributes table. + c_var_col: The name of the column in the feature attribute table that contains the corrected variance values. + """ + + def col_renamer(x): + return f"{identifier}_{x}" + + if cell_key is None: + cell_key = "I" + + # check lowess_frac is between 0 and 1 + if not 0 <= lowess_frac <= 1: + raise ValueError("lowess_frac must be between 0 and 1") + + self.set_feature_stats(cell_key) + identifier = self._load_stats_loc(cell_key) + c_var_col = f"c_var__{n_bins}__{lowess_frac}" + if col_renamer(c_var_col) in self.feats.columns: + logger.info("Using existing corrected dispersion values") + else: + slots = ["normed_tot", "avg", "nz_mean", "sigmas", "normed_n"] + for i in slots: + i = col_renamer(i) + if i not in self.feats.columns: + raise KeyError(f"ERROR: {i} not found in feature metadata") + c_var = self.feats.remove_trend( + col_renamer("avg"), col_renamer("sigmas"), n_bins, lowess_frac + ) + self.feats.insert(c_var_col, c_var, overwrite=True, location=identifier) + + return identifier, c_var_col + # maybe we should return plot here? If one wants to modify it. /raz def mark_hvgs( self, @@ -950,21 +1005,9 @@ def mark_hvgs( def col_renamer(x): return f"{identifier}_{x}" - self.set_feature_stats(cell_key) - identifier = self._load_stats_loc(cell_key) - c_var_col = f"c_var__{n_bins}__{lowess_frac}" - if col_renamer(c_var_col) in self.feats.columns: - logger.info("Using existing corrected dispersion values") - else: - slots = ["normed_tot", "avg", "nz_mean", "sigmas", "normed_n"] - for i in slots: - i = col_renamer(i) - if i not in self.feats.columns: - raise KeyError(f"ERROR: {i} not found in feature metadata") - c_var = self.feats.remove_trend( - col_renamer("avg"), col_renamer("sigmas"), n_bins, lowess_frac - ) - self.feats.insert(c_var_col, c_var, overwrite=True, location=identifier) + logger.info("Calculating summary statistics") + identifier, c_var_col = self.set_summary_stats(cell_key, n_bins, lowess_frac) + logger.info("Calculating HVGs") if max_mean != np.inf: max_mean = 2**max_mean