Skip to content

Commit

Permalink
add default argument for match_sdata_to_table(); fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMarconato committed Feb 2, 2025
1 parent 23ac075 commit d79e533
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
4 changes: 3 additions & 1 deletion src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,8 +775,8 @@ def match_element_to_table(

def match_sdata_to_table(
sdata: SpatialData,
table: AnnData,
table_name: str,
table: AnnData | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
) -> SpatialData:
"""
Expand All @@ -795,6 +795,8 @@ def match_sdata_to_table(
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".
"""
if table is None:
table = sdata[table_name]
_, region_key, instance_key = get_table_keys(table)
annotated_regions = SpatialData.get_annotated_regions(table)
filtered_elements, filtered_table = join_spatialelement_table(
Expand Down
38 changes: 30 additions & 8 deletions tests/core/query/test_relational_query_match_sdata_to_table.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import pytest

from spatialdata import concatenate, match_sdata_to_table
from spatialdata import SpatialData, concatenate, match_sdata_to_table
from spatialdata.datasets import blobs_annotating_element

# constructing the example data; let's use a global variable as we can reuse the same object for all the tests

def _make_test_data() -> SpatialData:
sdata1 = blobs_annotating_element("blobs_polygons")
sdata2 = blobs_annotating_element("blobs_polygons")
sdata = concatenate({"sdata1": sdata1, "sdata2": sdata2}, concatenate_tables=True)
sdata["table"].obs["value"] = list(range(sdata["table"].obs.shape[0]))
return sdata


# constructing the example data; let's use a global variable as we can reuse the same object on most tests
# without having to recreate it
sdata1 = blobs_annotating_element("blobs_polygons")
sdata2 = blobs_annotating_element("blobs_polygons")
sdata = concatenate({"sdata1": sdata1, "sdata2": sdata2}, concatenate_tables=True)
sdata["table"].obs["value"] = list(range(sdata["table"].obs.shape[0]))
sdata = _make_test_data()


def test_match_sdata_to_table_filter_specific_instances():
Expand Down Expand Up @@ -74,6 +80,7 @@ def test_match_sdata_to_table_shapes_and_points():
The function works both for shapes (examples above) and points.
Changes the target of the table to labels.
"""
sdata = _make_test_data()
sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "points"))
sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category")
sdata.set_table_annotates_spatialelement(
Expand All @@ -100,7 +107,8 @@ def test_match_sdata_to_table_match_labels_error():
match_sdata_to_table() uses the join operations; so when trying to match labels, the error will be raised by the
join.
"""
sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("points", "labels"))
sdata = _make_test_data()
sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "labels"))
sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category")
sdata.set_table_annotates_spatialelement(
table_name="table",
Expand All @@ -109,7 +117,10 @@ def test_match_sdata_to_table_match_labels_error():
instance_key="instance_id",
)

with pytest.warns(UserWarning, match="Element type `labels` not supported for 'right' join. Skipping "):
with pytest.warns(
UserWarning,
match="Element type `labels` not supported for 'right' join. Skipping ",
):
matched = match_sdata_to_table(
sdata,
table=sdata["table"],
Expand All @@ -120,3 +131,14 @@ def test_match_sdata_to_table_match_labels_error():
assert "blobs_labels-sdata1" in matched
assert "blobs_labels-sdata2" in matched
assert "blobs_points-sdata1" not in matched


def test_match_sdata_to_table_no_table_argument():
"""
If no table argument is passed, the table_name argument will be used to match the table.
"""
matched = match_sdata_to_table(sdata=sdata, table_name="table")

assert len(matched["table"]) == 10
assert "blobs_polygons-sdata1" in matched
assert "blobs_polygons-sdata2" in matched

0 comments on commit d79e533

Please sign in to comment.