Skip to content

Commit

Permalink
Add function to relabel segmentation from tracking solution
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Mar 12, 2024
1 parent 6af0109 commit 05c0a8e
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/motile_toolbox/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .saving_utils import relabel_segmentation
40 changes: 40 additions & 0 deletions src/motile_toolbox/utils/saving_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import networkx as nx
import numpy as np

from motile_toolbox.candidate_graph import NodeAttr


def relabel_segmentation(
solution_nx_graph: nx.DiGraph,
segmentation: np.array,
frame_key="t",
) -> np.array:
"""Relabel a segmentation based on tracking results so that nodes in same
track share the same id. IDs do change at division.
Args:
solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use
for relabeling. Nodes not in graph will be removed from seg. Original
segmentation ids have to be stored in the graph so we can map them back.
segmentation (np.array): Original segmentation with labels ids that correspond
to segmentation id in graph.
frame_key (str, optional): Time frame key in networkx graph. Defaults to "t".
Returns:
np.array: Relabeled segmentation array where nodes in same track share same id.
"""
tracked_masks = np.zeros_like(segmentation)
id_counter = 1
parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1]
soln_copy = solution_nx_graph.copy()
for parent_node in parent_nodes:
out_edges = solution_nx_graph.out_edges(parent_node)
soln_copy.remove_edges_from(out_edges)
for node_set in nx.weakly_connected_components(soln_copy):
for node in node_set:
time_frame = solution_nx_graph.nodes[node][frame_key]
previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value]
previous_seg_mask = segmentation[time_frame] == previous_seg_id
tracked_masks[time_frame][previous_seg_mask] = id_counter
id_counter += 1
return tracked_masks
61 changes: 61 additions & 0 deletions tests/test_utils/test_saving_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import networkx as nx
import numpy as np
import pytest
from motile_toolbox.utils import relabel_segmentation
from numpy.testing import assert_array_equal
from skimage.draw import disk


@pytest.fixture
def segmentation_2d():
frame_shape = (100, 100)
total_shape = (2, *frame_shape)
segmentation = np.zeros(total_shape, dtype="int32")
# make frame with one cell in center with label 1
rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100))
segmentation[0][rr, cc] = 1

# make frame with two cells
# first cell centered at (20, 80) with label 2
# second cell centered at (60, 45) with label 3
rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape)
segmentation[1][rr, cc] = 2
rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape)
segmentation[1][rr, cc] = 3

return segmentation


@pytest.fixture
def graph_2d():
graph = nx.DiGraph()
nodes = [
("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}),
("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 2}),
]
edges = [
("0_1", "1_1", {"distance": 42.43}),
]
graph.add_nodes_from(nodes)
graph.add_edges_from(edges)
return graph


def test_relabel_segmentation(segmentation_2d, graph_2d):
frame_shape = (100, 100)
expected = np.zeros(segmentation_2d.shape, dtype="int32")
# make frame with one cell in center with label 1
rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100))
expected[0][rr, cc] = 1

# make frame with cell centered at (20, 80) with label 1
rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape)
expected[1][rr, cc] = 1

relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d)
print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}")
print(f"Nonzero expected: {np.count_nonzero(expected)}")
print(f"Max relabeled: {np.max(relabeled_seg)}")
print(f"Max expected: {np.max(expected)}")

assert_array_equal(relabeled_seg, expected)

0 comments on commit 05c0a8e

Please sign in to comment.