diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index a83588be..877349d4 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -3,7 +3,7 @@ import pytest from anndata import AnnData -from spatialdata import get_values, match_table_to_element +from spatialdata import SpatialData, get_values, match_table_to_element from spatialdata._core.query.relational_query import ( _locate_value, _ValueOrigin, @@ -11,6 +11,7 @@ join_spatialelement_table, ) from spatialdata.models.models import TableModel +from spatialdata.testing import assert_anndata_equal, assert_geodataframe_equal def test_match_table_to_element(sdata_query_aggregation): @@ -376,20 +377,45 @@ def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation): assert all(indices == reversed_instance_id) -def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation): +# TODO: 'left_exclusive' is currently not working (reported in an issue; the link is missing here), so it is excluded from the parametrization below. +@pytest.mark.parametrize("join_type", ["left", "right", "inner", "right_exclusive"]) +def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation: SpatialData, join_type: str) -> None: sdata = sdata_query_aggregation sdata["table"].obs.index = ["a"] * sdata["table"].n_obs - sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] - sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5] + sdata["values_circles"] = sdata_query_aggregation["values_circles"][:4] + sdata["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] element_dict, table = join_spatialelement_table( sdata=sdata, spatial_element_names=["values_circles", "values_polygons"], table_name="table", - how="inner", - ) - - assert table.n_obs == 10 + how=join_type, + ) + + if join_type in ["left", "inner"]: + # table check + assert table.n_obs == 9 + assert
np.array_equal(table.obs["instance_id"][:4], sdata["values_circles"].index) + assert np.array_equal(table.obs["instance_id"][4:], sdata["values_polygons"].index) + # shapes check + assert_geodataframe_equal(element_dict["values_circles"], sdata["values_circles"]) + assert_geodataframe_equal(element_dict["values_polygons"], sdata["values_polygons"]) + elif join_type == "right": + # table check + assert_anndata_equal(table.obs, sdata["table"].obs) + # shapes check + assert_geodataframe_equal(element_dict["values_circles"], sdata["values_circles"]) + assert_geodataframe_equal(element_dict["values_polygons"], sdata["values_polygons"]) + elif join_type == "left_exclusive": + # TODO: 'left_exclusive' is currently not working (reported in an issue; link missing). This branch is unreachable until 'left_exclusive' is added to the parametrization. + pass + else: + assert join_type == "right_exclusive" + # table check + assert table.n_obs == sdata["table"].n_obs - len(sdata["values_circles"]) - len(sdata["values_polygons"]) + # shapes check + assert element_dict["values_circles"] is None + assert element_dict["values_polygons"] is None # TODO: there is a lot of dublicate code, simplify with a function that tests both the case sdata=None and sdata=sdata