Skip to content

Commit

Permalink
Enumerate node and edge attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Mar 12, 2024
1 parent b1d2909 commit bcaabe3
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 24 deletions.
1 change: 1 addition & 0 deletions src/motile_toolbox/candidate_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .graph_attributes import EdgeAttr, NodeAttr
from .graph_from_segmentation import graph_from_segmentation
from .graph_to_nx import graph_to_nx
19 changes: 19 additions & 0 deletions src/motile_toolbox/candidate_graph/graph_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from enum import Enum


class NodeAttr(Enum):
"""Node attributes that can be added to candidate graph using the toolbox.
Note: Motile can flexibly support any custom attributes. The toolbox provides
implementations of commonly used ones, listed here.
"""

SEG_ID = "segmentation_id"


class EdgeAttr(Enum):
"""Edge attributes that can be added to candidate graph using the toolbox.
Note: Motile can flexibly support any custom attributes. The toolbox provides
implementations of commonly used ones, listed here.
"""

DISTANCE = "distance"
34 changes: 10 additions & 24 deletions src/motile_toolbox/candidate_graph/graph_from_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from skimage.measure import regionprops
from tqdm import tqdm

from .graph_attributes import EdgeAttr, NodeAttr

logger = logging.getLogger(__name__)


Expand All @@ -33,7 +35,7 @@ def _get_location(

def nodes_from_segmentation(
segmentation: np.ndarray,
attributes: tuple[str, ...] | list[str] = ("segmentation_id",),
attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,),
position_keys: tuple[str, ...] | list[str] = ("y", "x"),
frame_key: str = "t",
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
Expand Down Expand Up @@ -73,8 +75,8 @@ def nodes_from_segmentation(
attrs = {
frame_key: t,
}
if "segmentation_id" in attributes:
attrs["segmentation_id"] = regionprop.label
if NodeAttr.SEG_ID in attributes:
attrs[NodeAttr.SEG_ID.value] = regionprop.label
centroid = regionprop.centroid # [z,] y, x
for label, value in zip(position_keys, centroid):
attrs[label] = value
Expand All @@ -88,7 +90,7 @@ def nodes_from_segmentation(
def add_cand_edges(
cand_graph: nx.DiGraph,
max_edge_distance: float,
attributes: tuple[str, ...] | list[str] = ("distance",),
attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,),
position_keys: tuple[str, ...] | list[str] = ("y", "x"),
frame_key: str = "t",
node_frame_dict: None | dict[int, list[Any]] = None,
Expand Down Expand Up @@ -133,17 +135,17 @@ def add_cand_edges(
for next_id, next_loc in zip(next_nodes, next_locs):
dist = math.dist(next_loc, loc)
attrs = {}
if "distance" in attributes:
attrs["distance"] = dist
if EdgeAttr.DISTANCE in attributes:
attrs[EdgeAttr.DISTANCE.value] = dist
if dist <= max_edge_distance:
cand_graph.add_edge(node, next_id, **attrs)


def graph_from_segmentation(
segmentation: np.ndarray,
max_edge_distance: float,
node_attributes: tuple[str, ...] | list[str] = ("segmentation_id",),
edge_attributes: tuple[str, ...] | list[str] = ("distance",),
node_attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,),
edge_attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,),
position_keys: tuple[str, ...] | list[str] = ("y", "x"),
frame_key: str = "t",
):
Expand Down Expand Up @@ -181,22 +183,6 @@ def graph_from_segmentation(
arguments, or if the number of position keys provided does not match the
number of position dimensions.
"""
valid_edge_attributes = [
"distance",
]
for attr in edge_attributes:
if attr not in valid_edge_attributes:
raise ValueError(
f"Invalid attribute {attr} (supported attrs: {valid_edge_attributes})"
)
valid_node_attributes = [
"segmentation_id",
]
for attr in node_attributes:
if attr not in valid_node_attributes:
raise ValueError(
f"Invalid attribute {attr} (supported attrs: {valid_node_attributes})"
)
if len(position_keys) != segmentation.ndim - 1:
raise ValueError(
f"Position labels {position_keys} does not match number of spatial dims "
Expand Down
File renamed without changes.

0 comments on commit bcaabe3

Please sign in to comment.