Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMarconato committed Jan 13, 2025
1 parent 64c5ffc commit 4ca78e5
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions tests/core/query/test_relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import pytest
from anndata import AnnData

from spatialdata import get_values, match_table_to_element
from spatialdata import SpatialData, get_values, match_table_to_element
from spatialdata._core.query.relational_query import (
_locate_value,
_ValueOrigin,
get_element_annotators,
join_spatialelement_table,
)
from spatialdata.models.models import TableModel
from spatialdata.testing import assert_anndata_equal, assert_geodataframe_equal


def test_match_table_to_element(sdata_query_aggregation):
Expand Down Expand Up @@ -376,20 +377,45 @@ def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation):
assert all(indices == reversed_instance_id)


def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation):
# TODO: 'left_exclusive' is currently not working, reported in this issue:
@pytest.mark.parametrize("join_type", ["left", "right", "inner", "right_exclusive"])
def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation: SpatialData, join_type: str) -> None:
sdata = sdata_query_aggregation
sdata["table"].obs.index = ["a"] * sdata["table"].n_obs
sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5]
sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5]
sdata["values_circles"] = sdata_query_aggregation["values_circles"][:4]
sdata["values_polygons"] = sdata_query_aggregation["values_polygons"][:5]

element_dict, table = join_spatialelement_table(
sdata=sdata,
spatial_element_names=["values_circles", "values_polygons"],
table_name="table",
how="inner",
)

assert table.n_obs == 10
how=join_type,
)

if join_type in ["left", "inner"]:
# table check
assert table.n_obs == 9
assert np.array_equal(table.obs["instance_id"][:4], sdata["values_circles"].index)
assert np.array_equal(table.obs["instance_id"][4:], sdata["values_polygons"].index)
# shapes check
assert_geodataframe_equal(element_dict["values_circles"], sdata["values_circles"])
assert_geodataframe_equal(element_dict["values_polygons"], sdata["values_polygons"])
elif join_type == "right":
# table check
assert_anndata_equal(table.obs, sdata["table"].obs)
# shapes check
assert_geodataframe_equal(element_dict["values_circles"], sdata["values_circles"])
assert_geodataframe_equal(element_dict["values_polygons"], sdata["values_polygons"])
elif join_type == "left_exclusive":
# TODO: currently not working, reported in this issue
pass
else:
assert join_type == "right_exclusive"
# table check
assert table.n_obs == sdata["table"].n_obs - len(sdata["values_circles"]) - len(sdata["values_polygons"])
# shapes check
assert element_dict["values_circles"] is None
assert element_dict["values_polygons"] is None


# TODO: there is a lot of dublicate code, simplify with a function that tests both the case sdata=None and sdata=sdata
Expand Down

0 comments on commit 4ca78e5

Please sign in to comment.