Skip to content

Commit

Permalink
Implement multi hypothesis candidate graph computation
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Apr 1, 2024
1 parent 8c401b8 commit 8d12a6a
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 43 deletions.
41 changes: 39 additions & 2 deletions src/motile_toolbox/candidate_graph/iou.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from itertools import combinations

import networkx as nx
import numpy as np
from tqdm import tqdm

from .graph_attributes import EdgeAttr, NodeAttr
from .graph_from_segmentation import _get_node_id


def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]:
Expand Down Expand Up @@ -58,5 +61,39 @@ def add_iou(cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict) -
node_seg_id = cand_graph.nodes[node_id][NodeAttr.SEG_ID.value]
for next_id in next_nodes:
next_seg_id = cand_graph.nodes[next_id][NodeAttr.SEG_ID.value]
iou = ious.get(node_seg_id, {}).get( next_seg_id, 0)
cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou
iou = ious.get(node_seg_id, {}).get(next_seg_id, 0)
cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou


def add_multihypo_iou(
cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict
) -> None:
"""Add IOU to the candidate graph for multi-hypothesis segmentations.
Args:
cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated
segmentation (np.ndarray): Multiple hypothesis segmentation. Dimensions
are (t, h, [z], y, x), where h is the number of hypotheses.
"""
frames = sorted(node_frame_dict.keys())
num_hypotheses = segmentation.shape[1]
for frame in tqdm(frames):
if frame + 1 not in node_frame_dict:
continue
# construct dictionary of ious between node_ids in frame 1 and frame 2
ious: dict[str, dict[str, float]] = {}
for hypo1, hypo2 in combinations(range(num_hypotheses), 2):
hypo_ious = compute_ious(
segmentation[frame][hypo1], segmentation[frame + 1][hypo2]
)
for segid, intersecting_labels in hypo_ious.items():
node_id = _get_node_id(frame, segid, hypo1)
ious[node_id] = {}
for segid2, iou in intersecting_labels.items():
next_id = _get_node_id(frame + 1, segid2, hypo2)
ious[node_id][next_id] = iou
next_nodes = node_frame_dict[frame + 1]
for node_id in node_frame_dict[frame]:
for next_id in next_nodes:
iou = ious.get(node_id, {}).get(next_id, 0)
cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou
101 changes: 60 additions & 41 deletions src/motile_toolbox/candidate_graph/multi_seg_graph.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,87 @@
from typing import Any

from itertools import combinations

import networkx as nx
import numpy as np

from .graph_attributes import EdgeAttr, NodeAttr, add_iou
from .graph_from_segmentation import add_cand_edges, nodes_from_segmentation
from .graph_from_segmentation import (
_get_node_id,
add_cand_edges,
nodes_from_segmentation,
)
from .iou import add_multihypo_iou


def compute_multi_seg_graph(segmentations: list[np.ndarray]) -> tuple[nx.DiGraph, list[set]]:
"""Create a candidate graph from multi hypothesis segmentations. This is not
def compute_multi_seg_graph(
segmentation: np.ndarray,
max_edge_distance: float,
iou: bool = False,
) -> tuple[nx.DiGraph, list[set]]:
"""Create a candidate graph from multi hypothesis segmentation. This is not
tailored for agglomeration approaches with hierarchical merge graphs, it simply
creates a conflict set for any nodes that overlap in the same time frame.
Args:
segmentations (list[np.ndarray]):
segmentations (np.ndarray): Multiple hypothesis segmentation. Dimensions
are (t, h, [z], y, x), where h is the number of hypotheses.
Returns:
nx.DiGraph: _description_
"""
# for each segmentation, get nodes using same method as graph_from_segmentation
# add them all to one big graph
cand_graph, frame_dict = nodes_from_multi_segmentation(segmentations) # TODO: other args
cand_graph = nx.DiGraph()
node_frame_dict = {}
num_hypotheses = segmentation.shape[1]
for hypo_id in range(num_hypotheses):
hypothesis = segmentation[:,hypo_id]
node_graph, frame_dict = nodes_from_segmentation(hypothesis, hypo_id=hypo_id)
cand_graph.update(node_graph)
node_frame_dict.update(frame_dict)

# Compute conflict sets between segmentations
# can use same method as IOU (without the U) to compute conflict sets
conflicts = []
for time, segs in enumerate(segmentations):
conflicts.append(compute_conflict_sets(segs, time))
for time, segs in enumerate(segmentation):
conflicts.extend(compute_conflict_sets(segs, time))

# add edges with same method as before, with slightly different implementation
add_cand_edges(cand_graph) # TODO: other args
if EdgeAttr.IOU in edge_attributes:
# TODO: cross product when calling (need to re-organize add_iou to not assume stuff)
add_iou(cand_graph, segmentation)

return cand_graph



add_cand_edges(cand_graph, max_edge_distance, node_frame_dict)
if iou:
add_multihypo_iou(cand_graph, segmentation, node_frame_dict)

return cand_graph, conflicts


def nodes_from_multi_segmentation(
segmentations: list[np.ndarray],
attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,),
position_keys: tuple[str, ...] | list[str] = ("y", "x"),
frame_key: str = "t",
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
multi_hypo_node_graph = nx.DiGraph()
multi_frame_dict = {}
for layer_id, segmentation in enumerate(segmentations):
node_graph, frame_dict = nodes_from_segmentation(segmentation, layer_id)
# TODO: pass attributes, etc.
# TODO: add multi segmentation attribute to nodes_from_segmentation
# (use in node id and add to attributes)
multi_hypo_node_graph.update(node_graph)
multi_frame_dict.update(frame_dict)
# TODO: Make sure there is no node-id collision

return multi_hypo_node_graph, multi_frame_dict

def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]:
"""Segmentation in one frame only. Return
Args:
segmentation_frame (np.ndarray): One frame of the multiple hypothesis
segmentation. Dimensions are (h, [z], y, x), where h is the number of
hypotheses.
time (int): Time frame, for computing node_ids.
def compute_conflict_sets(segmenations: np.ndarray, time: int) -> list[set]:
"""Segmentations in one frame only. Return list of sets of node ids that conflict."""
# This will look a lot like the IOU code
pass
Returns:
list[set]: list of sets of node ids that overlap
"""
flattened_segs = [seg.flatten() for seg in segmentation_frame]

# get locations where at least two hypotheses have labels
# This approach may be inefficient, but likely doesn't matter compared to np.unique
conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool)
for seg1, seg2 in combinations(flattened_segs, 2):
non_zero_indices = np.logical_and(seg1, seg2)
conflict_indices = np.logical_or(conflict_indices, non_zero_indices)

flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs])
values = np.unique(flattened_stacked, axis=1)

conflict_sets = []
for conflicting_labels in values:
id_set = set()
for hypo_id, label in enumerate(conflicting_labels):
if label != 0:
id_set.add(_get_node_id(time, label, hypo_id))
conflict_sets.append(id_set)
return conflict_sets

0 comments on commit 8d12a6a

Please sign in to comment.