diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index 599e636..ceb6431 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -1,8 +1,11 @@ +from itertools import combinations + import networkx as nx import numpy as np from tqdm import tqdm from .graph_attributes import EdgeAttr, NodeAttr +from .graph_from_segmentation import _get_node_id def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]: @@ -58,5 +61,39 @@ def add_iou(cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict) - node_seg_id = cand_graph.nodes[node_id][NodeAttr.SEG_ID.value] for next_id in next_nodes: next_seg_id = cand_graph.nodes[next_id][NodeAttr.SEG_ID.value] - iou = ious.get(node_seg_id, {}).get( next_seg_id, 0) - cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou \ No newline at end of file + iou = ious.get(node_seg_id, {}).get(next_seg_id, 0) + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou + + +def add_multihypo_iou( + cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict +) -> None: + """Add IOU to the candidate graph for multi-hypothesis segmentations. + + Args: + cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated + segmentation (np.ndarray): Multiple hypothesis segmentation. Dimensions + are (t, h, [z], y, x), where h is the number of hypotheses. + """ + frames = sorted(node_frame_dict.keys()) + num_hypotheses = segmentation.shape[1] + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + # construct dictionary of ious between node_ids in frame 1 and frame 2 + ious: dict[str, dict[str, float]] = {} + for hypo1, hypo2 in combinations(range(num_hypotheses), 2): + hypo_ious = compute_ious( + segmentation[frame][hypo1], segmentation[frame + 1][hypo2] + ) + for segid, intersecting_labels in hypo_ious.items(): + node_id = _get_node_id(frame, segid, hypo1) + ious[node_id] = {} + for segid2, iou in intersecting_labels.items(): + next_id = _get_node_id(frame + 1, segid2, hypo2) + ious[node_id][next_id] = iou + next_nodes = node_frame_dict[frame + 1] + for node_id in node_frame_dict[frame]: + for next_id in next_nodes: + iou = ious.get(node_id, {}).get(next_id, 0) + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou diff --git a/src/motile_toolbox/candidate_graph/multi_seg_graph.py b/src/motile_toolbox/candidate_graph/multi_seg_graph.py index eb13085..4635044 100644 --- a/src/motile_toolbox/candidate_graph/multi_seg_graph.py +++ b/src/motile_toolbox/candidate_graph/multi_seg_graph.py @@ -1,68 +1,87 @@ -from typing import Any + +from itertools import combinations import networkx as nx import numpy as np -from .graph_attributes import EdgeAttr, NodeAttr, add_iou -from .graph_from_segmentation import add_cand_edges, nodes_from_segmentation +from .graph_from_segmentation import ( + _get_node_id, + add_cand_edges, + nodes_from_segmentation, +) +from .iou import add_multihypo_iou -def compute_multi_seg_graph(segmentations: list[np.ndarray]) -> tuple[nx.DiGraph, list[set]]: - """Create a candidate graph from multi hypothesis segmentations. This is not +def compute_multi_seg_graph( + segmentation: np.ndarray, + max_edge_distance: float, + iou: bool = False, +) -> tuple[nx.DiGraph, list[set]]: + """Create a candidate graph from multi hypothesis segmentation. This is not tailored for agglomeration approaches with hierarchical merge graphs, it simply creates a conflict set for any nodes that overlap in the same time frame. Args: - segmentations (list[np.ndarray]): + segmentations (np.ndarray): Multiple hypothesis segmentation. Dimensions + are (t, h, [z], y, x), where h is the number of hypotheses. Returns: nx.DiGraph: _description_ """ # for each segmentation, get nodes using same method as graph_from_segmentation # add them all to one big graph - cand_graph, frame_dict = nodes_from_multi_segmentation(segmentations) # TODO: other args + cand_graph = nx.DiGraph() + node_frame_dict = {} + num_hypotheses = segmentation.shape[1] + for hypo_id in range(num_hypotheses): + hypothesis = segmentation[:,hypo_id] + node_graph, frame_dict = nodes_from_segmentation(hypothesis, hypo_id=hypo_id) + cand_graph.update(node_graph) + node_frame_dict.update(frame_dict) # Compute conflict sets between segmentations # can use same method as IOU (without the U) to compute conflict sets conflicts = [] - for time, segs in enumerate(segmentations): - conflicts.append(compute_conflict_sets(segs, time)) + for time, segs in enumerate(segmentation): + conflicts.extend(compute_conflict_sets(segs, time)) # add edges with same method as before, with slightly different implementation - add_cand_edges(cand_graph) # TODO: other args - if EdgeAttr.IOU in edge_attributes: - # TODO: cross product when calling (need to re-organize add_iou to not assume stuff) - add_iou(cand_graph, segmentation) - - return cand_graph - - - + add_cand_edges(cand_graph, max_edge_distance, node_frame_dict) + if iou: + add_multihypo_iou(cand_graph, segmentation, node_frame_dict) + return cand_graph, conflicts -def nodes_from_multi_segmentation( - segmentations: list[np.ndarray], - attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), - position_keys: tuple[str, ...] | list[str] = ("y", "x"), - frame_key: str = "t", -) -> tuple[nx.DiGraph, dict[int, list[Any]]]: - multi_hypo_node_graph = nx.DiGraph() - multi_frame_dict = {} - for layer_id, segmentation in enumerate(segmentations): - node_graph, frame_dict = nodes_from_segmentation(segmentation, layer_id) - # TODO: pass attributes, etc. - # TODO: add multi segmentation attribute to nodes_from_segmentation - # (use in node id and add to attributes) - multi_hypo_node_graph.update(node_graph) - multi_frame_dict.update(frame_dict) - # TODO: Make sure there is no node-id collision - - return multi_hypo_node_graph, multi_frame_dict - +def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]: + """Segmentation in one frame only. Return + Args: + segmentation_frame (np.ndarray): One frame of the multiple hypothesis + segmentation. Dimensions are (h, [z], y, x), where h is the number of + hypotheses. + time (int): Time frame, for computing node_ids. -def compute_conflict_sets(segmenations: np.ndarray, time: int) -> list[set]: - """Segmentations in one frame only. Return list of sets of node ids that conflict.""" - # This will look a lot like the IOU code - pass + Returns: + list[set]: list of sets of node ids that overlap + """ + flattened_segs = [seg.flatten() for seg in segmentation_frame] + + # get locations where at least two hypotheses have labels + # This approach may be inefficient, but likely doesn't matter compared to np.unique + conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool) + for seg1, seg2 in combinations(flattened_segs, 2): + non_zero_indices = np.logical_and(seg1, seg2) + conflict_indices = np.logical_or(conflict_indices, non_zero_indices) + + flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs]) + values = np.unique(flattened_stacked, axis=1) + + conflict_sets = [] + for conflicting_labels in values: + id_set = set() + for hypo_id, label in enumerate(conflicting_labels): + if label != 0: + id_set.add(_get_node_id(time, label, hypo_id)) + conflict_sets.append(id_set) + return conflict_sets