-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add function to relabel segmentation from tracking solution
- Loading branch information
1 parent
6af0109
commit 05c0a8e
Showing
3 changed files
with
102 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .saving_utils import relabel_segmentation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import networkx as nx | ||
import numpy as np | ||
|
||
from motile_toolbox.candidate_graph import NodeAttr | ||
|
||
|
||
def relabel_segmentation( | ||
solution_nx_graph: nx.DiGraph, | ||
segmentation: np.array, | ||
frame_key="t", | ||
) -> np.array: | ||
"""Relabel a segmentation based on tracking results so that nodes in same | ||
track share the same id. IDs do change at division. | ||
Args: | ||
solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use | ||
for relabeling. Nodes not in graph will be removed from seg. Original | ||
segmentation ids have to be stored in the graph so we can map them back. | ||
segmentation (np.array): Original segmentation with labels ids that correspond | ||
to segmentation id in graph. | ||
frame_key (str, optional): Time frame key in networkx graph. Defaults to "t". | ||
Returns: | ||
np.array: Relabeled segmentation array where nodes in same track share same id. | ||
""" | ||
tracked_masks = np.zeros_like(segmentation) | ||
id_counter = 1 | ||
parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] | ||
soln_copy = solution_nx_graph.copy() | ||
for parent_node in parent_nodes: | ||
out_edges = solution_nx_graph.out_edges(parent_node) | ||
soln_copy.remove_edges_from(out_edges) | ||
for node_set in nx.weakly_connected_components(soln_copy): | ||
for node in node_set: | ||
time_frame = solution_nx_graph.nodes[node][frame_key] | ||
previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value] | ||
previous_seg_mask = segmentation[time_frame] == previous_seg_id | ||
tracked_masks[time_frame][previous_seg_mask] = id_counter | ||
id_counter += 1 | ||
return tracked_masks |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import networkx as nx | ||
import numpy as np | ||
import pytest | ||
from motile_toolbox.utils import relabel_segmentation | ||
from numpy.testing import assert_array_equal | ||
from skimage.draw import disk | ||
|
||
|
||
@pytest.fixture | ||
def segmentation_2d(): | ||
frame_shape = (100, 100) | ||
total_shape = (2, *frame_shape) | ||
segmentation = np.zeros(total_shape, dtype="int32") | ||
# make frame with one cell in center with label 1 | ||
rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) | ||
segmentation[0][rr, cc] = 1 | ||
|
||
# make frame with two cells | ||
# first cell centered at (20, 80) with label 2 | ||
# second cell centered at (60, 45) with label 3 | ||
rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) | ||
segmentation[1][rr, cc] = 2 | ||
rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) | ||
segmentation[1][rr, cc] = 3 | ||
|
||
return segmentation | ||
|
||
|
||
@pytest.fixture | ||
def graph_2d(): | ||
graph = nx.DiGraph() | ||
nodes = [ | ||
("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}), | ||
("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 2}), | ||
] | ||
edges = [ | ||
("0_1", "1_1", {"distance": 42.43}), | ||
] | ||
graph.add_nodes_from(nodes) | ||
graph.add_edges_from(edges) | ||
return graph | ||
|
||
|
||
def test_relabel_segmentation(segmentation_2d, graph_2d): | ||
frame_shape = (100, 100) | ||
expected = np.zeros(segmentation_2d.shape, dtype="int32") | ||
# make frame with one cell in center with label 1 | ||
rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) | ||
expected[0][rr, cc] = 1 | ||
|
||
# make frame with cell centered at (20, 80) with label 1 | ||
rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) | ||
expected[1][rr, cc] = 1 | ||
|
||
relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d) | ||
print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}") | ||
print(f"Nonzero expected: {np.count_nonzero(expected)}") | ||
print(f"Max relabeled: {np.max(relabeled_seg)}") | ||
print(f"Max expected: {np.max(expected)}") | ||
|
||
assert_array_equal(relabeled_seg, expected) |