diff --git a/setup.sh b/setup.sh index f5f8668..427a148 100644 --- a/setup.sh +++ b/setup.sh @@ -26,6 +26,7 @@ pip install matplotlib pip install ipywidgets pip install nbformat pip install pandas +pip install git+https://github.com/funkelab/motile_napari_plugin.git@track-viewer#egg=motile_plugin # Make environment discoverable by Jupyter pip install ipykernel diff --git a/solution.py b/solution.py index df267fe..9ab1deb 100644 --- a/solution.py +++ b/solution.py @@ -62,9 +62,6 @@ # TODO: remove import motile - - - # %% import time from pathlib import Path @@ -73,16 +70,16 @@ import numpy as np import napari import networkx as nx -import plotly.io as pio import scipy -pio.renderers.default = "vscode" import motile import zarr from motile_toolbox.visualization import to_napari_tracks_layer from motile_toolbox.candidate_graph import graph_to_nx +import motile_plugin.widgets as plugin_widgets +from motile_plugin.backend.motile_run import MotileRun from napari.layers import Tracks import traccuracy from traccuracy import run_metrics @@ -200,10 +197,17 @@ def read_gt_tracks(): # We can also use the helper function `to_napari_tracks_layer` to visualize the ground truth tracks in our napari viewer. # %% -tracks_layer = to_napari_tracks_layer( - gt_tracks, frame_key="time", location_key="pos", name="gt_tracks" + +widget = plugin_widgets.TreeWidget(viewer) +viewer.window.add_dock_widget(widget, name="Lineage View", area="bottom") + +# %% +ground_truth_run = MotileRun( + run_name="ground_truth", + tracks=gt_tracks, ) -viewer.add_layer(tracks_layer) + +widget.view_controller.update_napari_layers(ground_truth_run, time_attr="t", pos_attr=("x", "y")) # %% [markdown] # ## Build a candidate graph from the detections @@ -521,32 +525,10 @@ def print_graph_stats(graph, name): # # Note that bad tracking results at this point does not mean that you implemented anything wrong! We still need to customize our costs and constraints to the task before we can get good results. As long as your pipeline selects something, and you can kind of interepret why it is going wrong, that is all that is needed at this point. -# %% -# Add a tracks layer -tracks_layer = to_napari_tracks_layer(solution_graph, frame_key="t", location_key="pos", name="solution_tracks") -viewer.add_layer(tracks_layer) - - -# %% -def filter_segmentation( - solution_nx_graph: nx.DiGraph, - segmentation: np.ndarray, -) -> np.ndarray: - filtered_masks = np.zeros_like(segmentation) - for node in solution_nx_graph.nodes(): - time_frame = solution_nx_graph.nodes[node]["t"] - seg_mask = ( - segmentation[time_frame] == node - ) - filtered_masks[time_frame][seg_mask] = node - return filtered_masks - -filtered_segmentation = filter_segmentation(solution_graph, segmentation) - - # %% # recolor the segmentation +from motile_toolbox.visualization.napari_utils import assign_tracklet_ids def relabel_segmentation( solution_nx_graph: nx.DiGraph, segmentation: np.ndarray, @@ -567,28 +549,28 @@ def relabel_segmentation( np.ndarray: Relabeled segmentation array where nodes in same track share same id with shape (t,1,[z],y,x) """ + assign_tracklet_ids(solution_nx_graph) tracked_masks = np.zeros_like(segmentation) - id_counter = 1 - parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] - soln_copy = solution_nx_graph.copy() - for parent_node in parent_nodes: - out_edges = solution_nx_graph.out_edges(parent_node) - soln_copy.remove_edges_from(out_edges) - for node_set in nx.weakly_connected_components(soln_copy): - for node in node_set: - time_frame = solution_nx_graph.nodes[node]["t"] - previous_seg_id = node - previous_seg_mask = ( - segmentation[time_frame] == previous_seg_id - ) - tracked_masks[time_frame][previous_seg_mask] = id_counter - solution_graph.nodes[node]["label"] = id_counter - id_counter += 1 + for node, data in solution_nx_graph.nodes(data=True): + time_frame = solution_nx_graph.nodes[node]["t"] + previous_seg_id = node + track_id = solution_nx_graph.nodes[node]["tracklet_id"] + previous_seg_mask = ( + segmentation[time_frame] == previous_seg_id + ) + tracked_masks[time_frame][previous_seg_mask] = track_id return tracked_masks - solution_seg = relabel_segmentation(solution_graph, segmentation) -viewer.add_labels(solution_seg, name="solution_seg") + +# %% +basic_run = MotileRun( + run_name="basic_solution_test", + tracks=solution_graph, + output_segmentation=np.expand_dims(solution_seg, axis=1) # need to add a dummy dimension to fit API +) + +widget.view_controller.update_napari_layers(basic_run, time_attr="t", pos_attr=("x", "y")) # %% [markdown] #

Question 2: Interpret your results based on visualization

@@ -618,6 +600,7 @@ def make_gt_detections(data_shape, gt_tracks, radius): for node, data in gt_tracks.nodes(data=True): pos = (data["x"], data["y"]) time = data["t"] + gt_tracks.nodes[node]["label"] = node rr, cc = disk(center=pos, radius=radius, shape=frame_shape) segmentation[time][rr, cc] = node return segmentation @@ -626,12 +609,6 @@ def make_gt_detections(data_shape, gt_tracks, radius): # viewer.add_image(gt_dets) -# %% - -for node in gt_tracks.nodes: - gt_tracks.nodes[node]["label"] = node - - # %% def get_metrics(gt_graph, labels, pred_graph, pred_segmentation): """Calculate metrics for linked tracks by comparing to ground truth. @@ -657,7 +634,7 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation): pred_graph = traccuracy.TrackingGraph( graph=pred_graph, frame_key="t", - label_key="label", + label_key="tracklet_id", location_keys=("x", "y"), segmentation=pred_segmentation, ) @@ -672,34 +649,6 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation): return results -# %% - -gt_graph = traccuracy.TrackingGraph( - graph=gt_tracks, - frame_key="t", - label_key="label", - location_keys=("x", "y"), - segmentation=gt_dets, -) -print(gt_dets.shape) -pred_graph = traccuracy.TrackingGraph( - graph=solution_graph, - frame_key="t", - label_key="label", - location_keys=("x", "y"), - segmentation=solution_seg.astype(np.uint32), -) -print(solution_seg.astype(np.uint32).shape) -print(isinstance(gt_graph, traccuracy.TrackingGraph)) -print(isinstance(pred_graph, traccuracy.TrackingGraph)) - -matcher = IOUMatcher(iou_threshold=0.3, one_to_one=False) -matched = matcher._compute_mapping(gt_graph, pred_graph) -CTCMetrics().compute(matched).to_dict() - -# %% -DivisionMetrics().compute(matched) - # %% get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg.astype(np.uint32)) @@ -759,7 +708,6 @@ def add_appear_ignore_attr(cand_graph): cand_graph.nodes[node]["ignore_appear"] = True add_appear_ignore_attr(cand_graph) -cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute="time") # %% [markdown] @@ -818,16 +766,17 @@ def solve_appear_optimization(cand_graph): # %% solution_graph = solve_appear_optimization(cand_graph) +solution_seg = relabel_segmentation(solution_graph, segmentation) # %% +appear_run = MotileRun( + run_name="appear_solution", + tracks=solution_graph, + output_segmentation=np.expand_dims(solution_seg, axis=1) # need to add a dummy dimension to fit API +) -tracks_layer = to_napari_tracks_layer(solution_graph, frame_key="time", location_key="pos", name="solution_appear_tracks") -viewer.add_layer(tracks_layer) - +widget.view_controller.update_napari_layers(appear_run, time_attr="t", pos_attr=("x", "y")) -# %% -solution_seg = relabel_segmentation(solution_graph, segmentation) -viewer.add_labels(solution_seg, name="solution_appear_seg") # %% get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg) @@ -900,7 +849,7 @@ def solve_drift_optimization(cand_graph): solution_graph = graph_to_nx(solver.get_selected_subgraph()) return solution_graph -solution_graph = solve_drift_optimization(cand_trackgraph, 1, -20) +solution_graph = solve_drift_optimization(cand_graph) # %% tags=["solution"] @@ -921,6 +870,7 @@ def solve_drift_optimization(cand_graph): motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist") ) solver.add_cost(motile.costs.Appear(constant=100, ignore_attribute="ignore_appear")) + solver.add_cost(motile.costs.Split(constant=20)) solver.add_constraint(motile.constraints.MaxParents(1)) solver.add_constraint(motile.constraints.MaxChildren(2)) @@ -932,11 +882,16 @@ def solve_drift_optimization(cand_graph): # %% solution_graph = solve_drift_optimization(cand_graph) -# tracks_layer = to_napari_tracks_layer(solution_graph, frame_key="time", location_key="pos", name="solution_tracks_with_drift") -# viewer.add_layer(tracks_layer) - solution_seg = relabel_segmentation(solution_graph, segmentation) -viewer.add_labels(solution_seg, name="solution_seg_with_drift") + +# %% +drift_run = MotileRun( + run_name="drift_solution", + tracks=solution_graph, + output_segmentation=np.expand_dims(solution_seg, axis=1) # need to add a dummy dimension to fit API +) + +widget.view_controller.update_napari_layers(drift_run, time_attr="t", pos_attr=("x", "y")) # %% get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg) @@ -959,32 +914,49 @@ def get_cand_id(gt_node, gt_track, cand_segmentation): data = gt_track.nodes[gt_node] return cand_segmentation[data["t"], int(data["x"])][int(data["y"])] -for gt_node in gt_tracks.nodes(): - cand_id = get_cand_id(gt_node, gt_tracks, segmentation) - if cand_id != 0: - cand_graph.nodes[cand_id]["gt"] = True - succs = gt_tracks.successors(gt_node) - for succ in succs: - succ_id = get_cand_id(succ, gt_tracks, segmentation) - if succ_id != 0: - cand_graph.edges[(cand_id, succ_id)]["gt"] = True +def add_gt_annotations(gt_tracks, cand_graph, segmentation): + for gt_node in gt_tracks.nodes(): + cand_id = get_cand_id(gt_node, gt_tracks, segmentation) + if cand_id != 0: + if cand_id in cand_graph: + cand_graph.nodes[cand_id]["gt"] = True + gt_succs = gt_tracks.successors(gt_node) + gt_succ_matches = [get_cand_id(gt_succ, gt_tracks, segmentation) for gt_succ in gt_succs] + cand_succs = cand_graph.successors(cand_id) + for succ in cand_succs: + if succ in gt_succ_matches: + cand_graph.edges[(cand_id, succ)]["gt"] = True + else: + cand_graph.edges[(cand_id, succ)]["gt"] = False + for node in cand_graph.nodes(): + if "gt" not in cand_graph.nodes[node]: + cand_graph.nodes[node]["gt"] = False # %% -import logging +validation_times = [0, 3] +validation_nodes = [node for node, data in cand_graph.nodes(data=True) + if (data["t"] >= validation_times[0] and data["t"] < validation_times[1])] +print(len(validation_nodes)) +validation_graph = cand_graph.subgraph(validation_nodes).copy() +add_gt_annotations(gt_tracks, validation_graph, segmentation) -logging.basicConfig(level=logging.INFO) +# %% +gt_pos_nodes = [node_id for node_id, data in validation_graph.nodes(data=True) if "gt" in data and data["gt"] is True] +gt_neg_nodes = [node_id for node_id, data in validation_graph.nodes(data=True) if "gt" in data and data["gt"] is False] +gt_pos_edges = [(source, target) for source, target, data in validation_graph.edges(data=True) if "gt" in data and data["gt"] is True] +gt_neg_edges = [(source, target) for source, target, data in validation_graph.edges(data=True) if "gt" in data and data["gt"] is False] -def solve_SSVM_optimization(cand_graph): - """Set up and solve the network flow problem. +print(f"{len(gt_pos_nodes) + len(gt_neg_nodes)} annotated: {len(gt_pos_nodes)} True, {len(gt_neg_nodes)} False") +print(f"{len(gt_pos_edges) + len(gt_neg_edges)} annotated: {len(gt_pos_edges)} True, {len(gt_neg_edges)} False") - Args: - cand_graph (nx.DiGraph): The candidate graph. +# %% +import logging - Returns: - nx.DiGraph: The networkx digraph with the selected solution tracks - """ +logging.basicConfig(level=logging.INFO) + +def get_ssvm_solver(cand_graph): cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute="t") solver = motile.Solver(cand_trackgraph) @@ -992,24 +964,45 @@ def solve_SSVM_optimization(cand_graph): solver.add_cost( motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist") ) - solver.add_cost(motile.costs.Appear(constant=0, ignore_attribute="ignore_appear")) + solver.add_cost(motile.costs.Appear(constant=20, ignore_attribute="ignore_appear")) + solver.add_cost(motile.costs.Split(constant=20)) solver.add_constraint(motile.constraints.MaxParents(1)) solver.add_constraint(motile.constraints.MaxChildren(2)) + return solver + + +# %% +ssvm_solver = get_ssvm_solver(validation_graph) +ssvm_solver.fit_weights(gt_attribute="gt", regularizer_weight=100, max_iterations=50) +optimal_weights = ssvm_solver.weights +optimal_weights - solver.fit_weights(gt_attribute="gt", regularizer_weight=0.00001, max_iterations=1000) - print(solver.weights) + +# %% +def get_ssvm_solution(cand_graph, solver_weights): + solver = get_ssvm_solver(cand_graph) + solver.weights = solver_weights solver.solve() solution_graph = graph_to_nx(solver.get_selected_subgraph()) return solution_graph +solution_graph = get_ssvm_solution(cand_graph, optimal_weights) + # %% -solution_graph = solve_SSVM_optimization(cand_graph) +solution_seg = relabel_segmentation(solution_graph, segmentation) +get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg) # %% +ssvm_run = MotileRun( + run_name="ssvm_solution", + tracks=solution_graph, + output_segmentation=np.expand_dims(solution_seg, axis=1) # need to add a dummy dimension to fit API +) -solution_seg = relabel_segmentation(solution_graph, segmentation) -get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg) +widget.view_controller.update_napari_layers(ssvm_run, time_attr="t", pos_attr=("x", "y")) + +# %% # %%