Skip to content

Commit

Permalink
allow filtering by ids
Browse files Browse the repository at this point in the history
  • Loading branch information
melonora committed Jul 9, 2024
1 parent cf7a84d commit fc09cc0
Showing 1 changed file with 54 additions and 1 deletion.
55 changes: 54 additions & 1 deletion src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import hashlib
import os
import warnings
from collections.abc import Generator
from collections.abc import Generator, Iterable
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -2143,6 +2143,59 @@ def __delitem__(self, key: str) -> None:
element_type, _, _ = self._find_element(key)
getattr(self, element_type).__delitem__(key)

def filter_elements_by_instances(
    self,
    element_names: Iterable[str],
    instances: Iterable[int | str],
    region_names: Iterable[str] | str | None = None,
) -> dict[str, DaskDataFrame | GeoDataFrame | AnnData]:
    """
    Filter elements to contain only certain instances.

    This filters both SpatialElements (points and shapes) as well as tables to only contain certain IDs.
    Points are filtered on their instance key column, shapes on their index, and tables on the instance key
    column of ``table.obs`` (not on ``table.obs.index``). Filtering labels by ID is currently not supported
    as this is an expensive operation. Should you require this please open an issue on
    github.com/scverse/spatialdata. Lastly, tables not annotating an element cannot be filtered.

    Parameters
    ----------
    element_names
        Name(s) of either points, shapes or table elements within the SpatialData object. A single string is
        treated as one element name.
    instances
        The instance IDs to filter the elements on. The iterable is materialized once, so one-shot iterators
        (e.g. generators) are applied to every requested element. A single string is treated as one ID.
    region_names
        If filtering instances in a table, indicate the region_names (the names of the SpatialElement) for
        which you want to filter the instances of the table. If not specified (or empty), the table instances
        for all regions annotated by the table will be filtered by the given instances.

    Returns
    -------
    A dictionary mapping each requested element name to its filtered element. Filtered tables are copies
    whose region metadata lists the regions still present after filtering.

    Raises
    ------
    KeyError
        If a name in ``element_names`` is not an element in the SpatialData object.
    TypeError
        If an element is neither a points, shapes nor table element.
    """
    # A bare string is itself an iterable of characters; treat it as a single value instead.
    element_names = [element_names] if isinstance(element_names, str) else list(element_names)
    # Materialize once up front: `isin` consumes its argument, so a generator would otherwise be
    # exhausted after the first element and silently filter every later element against nothing.
    instances = [instances] if isinstance(instances, str) else list(instances)
    if region_names is None:
        regions_to_keep: list[str] = []
    elif isinstance(region_names, str):
        regions_to_keep = [region_names]
    else:
        regions_to_keep = list(region_names)

    element_dict = {}
    for element_name in element_names:
        element = self.get(element_name)
        if element is None:
            raise KeyError(f"`{element_name}` is not an element in the SpatialData object.")

        model = get_model(element)
        if model == PointsModel:
            instance_key = element.attrs[PointsModel.ATTRS_KEY][PointsModel.INSTANCE_KEY]
            element_dict[element_name] = element[element[instance_key].isin(instances)]
        elif model == ShapesModel:
            element_dict[element_name] = element[element.index.isin(instances)]
        elif model == TableModel:
            table_attrs = element.uns[TableModel.ATTRS_KEY]
            instance_key = table_attrs[TableModel.INSTANCE_KEY]
            region_key = table_attrs[TableModel.REGION_KEY_KEY]
            if regions_to_keep:
                element = element[element.obs[region_key].isin(regions_to_keep)]
            filtered_table = element[element.obs[instance_key].isin(instances)].copy()
            # Derive the annotated regions from the rows that actually remain, so the metadata is always
            # defined (also when no region_names were given) and consistent with the returned table.
            remaining_regions = filtered_table.obs[region_key].cat.remove_unused_categories()
            filtered_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = (
                remaining_regions.cat.categories.tolist()
            )
            TableModel().validate(filtered_table)
            element_dict[element_name] = filtered_table
        else:
            raise TypeError(f"`{model}` is not a valid model for filtering of instances.")
    return element_dict


class QueryManager:
"""Perform queries on SpatialData objects."""
Expand Down

0 comments on commit fc09cc0

Please sign in to comment.