diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index efd4cbb..d06fd4e 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -1,2 +1,3 @@ +from .graph_attributes import EdgeAttr, NodeAttr from .graph_from_segmentation import graph_from_segmentation from .graph_to_nx import graph_to_nx diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py new file mode 100644 index 0000000..767e023 --- /dev/null +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -0,0 +1,19 @@ +from enum import Enum + + +class NodeAttr(Enum): + """Node attributes that can be added to candidate graph using the toolbox. + Note: Motile can flexibly support any custom attributes. The toolbox provides + implementations of commonly used ones, listed here. + """ + + SEG_ID = "segmentation_id" + + +class EdgeAttr(Enum): + """Edge attributes that can be added to candidate graph using the toolbox. + Note: Motile can flexibly support any custom attributes. The toolbox provides + implementations of commonly used ones, listed here. + """ + + DISTANCE = "distance" diff --git a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py index 96c0879..42fe480 100644 --- a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py +++ b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py @@ -7,6 +7,8 @@ from skimage.measure import regionprops from tqdm import tqdm +from .graph_attributes import EdgeAttr, NodeAttr + logger = logging.getLogger(__name__) @@ -33,7 +35,7 @@ def _get_location( def nodes_from_segmentation( segmentation: np.ndarray, - attributes: tuple[str, ...] | list[str] = ("segmentation_id",), + attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), position_keys: tuple[str, ...] | list[str] = ("y", "x"), frame_key: str = "t", ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: @@ -73,8 +75,8 @@ def nodes_from_segmentation( attrs = { frame_key: t, } - if "segmentation_id" in attributes: - attrs["segmentation_id"] = regionprop.label + if NodeAttr.SEG_ID in attributes: + attrs[NodeAttr.SEG_ID.value] = regionprop.label centroid = regionprop.centroid # [z,] y, x for label, value in zip(position_keys, centroid): attrs[label] = value @@ -88,7 +90,7 @@ def nodes_from_segmentation( def add_cand_edges( cand_graph: nx.DiGraph, max_edge_distance: float, - attributes: tuple[str, ...] | list[str] = ("distance",), + attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,), position_keys: tuple[str, ...] | list[str] = ("y", "x"), frame_key: str = "t", node_frame_dict: None | dict[int, list[Any]] = None, @@ -133,8 +135,8 @@ def add_cand_edges( for next_id, next_loc in zip(next_nodes, next_locs): dist = math.dist(next_loc, loc) attrs = {} - if "distance" in attributes: - attrs["distance"] = dist + if EdgeAttr.DISTANCE in attributes: + attrs[EdgeAttr.DISTANCE.value] = dist if dist <= max_edge_distance: cand_graph.add_edge(node, next_id, **attrs) @@ -142,8 +144,8 @@ def add_cand_edges( def graph_from_segmentation( segmentation: np.ndarray, max_edge_distance: float, - node_attributes: tuple[str, ...] | list[str] = ("segmentation_id",), - edge_attributes: tuple[str, ...] | list[str] = ("distance",), + node_attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), + edge_attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,), position_keys: tuple[str, ...] | list[str] = ("y", "x"), frame_key: str = "t", ): @@ -181,22 +183,6 @@ def graph_from_segmentation( arguments, or if the number of position keys provided does not match the number of position dimensions. """ - valid_edge_attributes = [ - "distance", - ] - for attr in edge_attributes: - if attr not in valid_edge_attributes: - raise ValueError( - f"Invalid attribute {attr} (supported attrs: {valid_edge_attributes})" - ) - valid_node_attributes = [ - "segmentation_id", - ] - for attr in node_attributes: - if attr not in valid_node_attributes: - raise ValueError( - f"Invalid attribute {attr} (supported attrs: {valid_node_attributes})" - ) if len(position_keys) != segmentation.ndim - 1: raise ValueError( f"Position labels {position_keys} does not match number of spatial dims " diff --git a/tests/utils/test_loading_utils.py b/tests/test_utils/test_loading_utils.py similarity index 100% rename from tests/utils/test_loading_utils.py rename to tests/test_utils/test_loading_utils.py