diff --git a/docs/release-notes/3379.performance.md b/docs/release-notes/3379.performance.md new file mode 100644 index 0000000000..061c8addcd --- /dev/null +++ b/docs/release-notes/3379.performance.md @@ -0,0 +1 @@ +In `pp.normalize_total`, the median is now computed in-memory when using Dask {smaller}`S Dicks` diff --git a/src/scanpy/preprocessing/_normalization.py b/src/scanpy/preprocessing/_normalization.py index e1ee3d4822..fd8fc69ca7 100644 --- a/src/scanpy/preprocessing/_normalization.py +++ b/src/scanpy/preprocessing/_normalization.py @@ -32,19 +32,9 @@ def _normalize_data(X, counts, after=None, *, copy: bool = False): X = X.astype(np.float32) # TODO: Check if float64 should be used if after is None: if isinstance(counts, DaskArray): - - def nonzero_median(x): - return np.ma.median(np.ma.masked_array(x, x == 0)).item() - - after = da.from_delayed( - dask.delayed(nonzero_median)(counts), - shape=(), - meta=counts._meta, - dtype=counts.dtype, - ) - else: - counts_greater_than_zero = counts[counts > 0] - after = np.median(counts_greater_than_zero, axis=0) + counts = counts.compute() + counts_greater_than_zero = counts[counts > 0] + after = np.median(counts_greater_than_zero, axis=0) counts = counts / after return axis_mul_or_truediv( X,