diff --git a/src/motile_toolbox/utils/__init__.py b/src/motile_toolbox/utils/__init__.py index e69de29..eebaf83 100644 --- a/src/motile_toolbox/utils/__init__.py +++ b/src/motile_toolbox/utils/__init__.py @@ -0,0 +1 @@ +from .saving_utils import relabel_segmentation diff --git a/src/motile_toolbox/utils/saving_utils.py b/src/motile_toolbox/utils/saving_utils.py new file mode 100644 index 0000000..7419bf3 --- /dev/null +++ b/src/motile_toolbox/utils/saving_utils.py @@ -0,0 +1,40 @@ +import networkx as nx +import numpy as np + +from motile_toolbox.candidate_graph import NodeAttr + + +def relabel_segmentation( + solution_nx_graph: nx.DiGraph, + segmentation: np.array, + frame_key="t", +) -> np.array: + """Relabel a segmentation based on tracking results so that nodes in same + track share the same id. IDs do change at division. + + Args: + solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use + for relabeling. Nodes not in graph will be removed from seg. Original + segmentation ids have to be stored in the graph so we can map them back. + segmentation (np.array): Original segmentation with labels ids that correspond + to segmentation id in graph. + frame_key (str, optional): Time frame key in networkx graph. Defaults to "t". + + Returns: + np.array: Relabeled segmentation array where nodes in same track share same id. + """ + tracked_masks = np.zeros_like(segmentation) + id_counter = 1 + parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] + soln_copy = solution_nx_graph.copy() + for parent_node in parent_nodes: + out_edges = solution_nx_graph.out_edges(parent_node) + soln_copy.remove_edges_from(out_edges) + for node_set in nx.weakly_connected_components(soln_copy): + for node in node_set: + time_frame = solution_nx_graph.nodes[node][frame_key] + previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value] + previous_seg_mask = segmentation[time_frame] == previous_seg_id + tracked_masks[time_frame][previous_seg_mask] = id_counter + id_counter += 1 + return tracked_masks diff --git a/tests/test_utils/test_saving_utils.py b/tests/test_utils/test_saving_utils.py new file mode 100644 index 0000000..c4ff2ac --- /dev/null +++ b/tests/test_utils/test_saving_utils.py @@ -0,0 +1,61 @@ +import networkx as nx +import numpy as np +import pytest +from motile_toolbox.utils import relabel_segmentation +from numpy.testing import assert_array_equal +from skimage.draw import disk + + +@pytest.fixture +def segmentation_2d(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 2 + # second cell centered at (60, 45) with label 3 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 2 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 3 + + return segmentation + + +@pytest.fixture +def graph_2d(): + graph = nx.DiGraph() + nodes = [ + ("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}), + ("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 2}), + ] + edges = [ + ("0_1", "1_1", {"distance": 42.43}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def test_relabel_segmentation(segmentation_2d, graph_2d): + frame_shape = (100, 100) + expected = np.zeros(segmentation_2d.shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + expected[0][rr, cc] = 1 + + # make frame with cell centered at (20, 80) with label 1 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + expected[1][rr, cc] = 1 + + relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d) + print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}") + print(f"Nonzero expected: {np.count_nonzero(expected)}") + print(f"Max relabeled: {np.max(relabeled_seg)}") + print(f"Max expected: {np.max(expected)}") + + assert_array_equal(relabeled_seg, expected)