From 1d73abcc8da486e81ea353b2d6b8352eeb2955da Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 5 Aug 2024 17:41:18 -0400 Subject: [PATCH] Add SSVM working code (no explanation) --- solution.py | 63 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/solution.py b/solution.py index 9ccd938..df267fe 100644 --- a/solution.py +++ b/solution.py @@ -920,7 +920,7 @@ def solve_drift_optimization(cand_graph): solver.add_cost( motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist") ) - solver.add_cost(motile.costs.Appear(constant=0, ignore_attribute="ignore_appear")) + solver.add_cost(motile.costs.Appear(constant=100, ignore_attribute="ignore_appear")) solver.add_constraint(motile.constraints.MaxParents(1)) solver.add_constraint(motile.constraints.MaxChildren(2)) @@ -941,6 +941,7 @@ def solve_drift_optimization(cand_graph): # %% get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg) + # %% [markdown] # ## Checkpoint 4 #

Checkpoint 4

@@ -954,7 +955,61 @@ def solve_drift_optimization(cand_graph): # # %% -# get a ground truth track -# match it to the candidate graph and annotate TP nodes and edges +def get_cand_id(gt_node, gt_track, cand_segmentation): + data = gt_track.nodes[gt_node] + return cand_segmentation[data["t"], int(data["x"])][int(data["y"])] + +for gt_node in gt_tracks.nodes(): + cand_id = get_cand_id(gt_node, gt_tracks, segmentation) + if cand_id != 0: + cand_graph.nodes[cand_id]["gt"] = True + succs = gt_tracks.successors(gt_node) + for succ in succs: + succ_id = get_cand_id(succ, gt_tracks, segmentation) + if succ_id != 0: + cand_graph.edges[(cand_id, succ_id)]["gt"] = True + + +# %% +import logging -# find a negative node in napari +logging.basicConfig(level=logging.INFO) + + +def solve_SSVM_optimization(cand_graph): + """Set up and solve the network flow problem. + + Args: + cand_graph (nx.DiGraph): The candidate graph. + + Returns: + nx.DiGraph: The networkx digraph with the selected solution tracks + """ + + cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute="t") + solver = motile.Solver(cand_trackgraph) + + solver.add_cost( + motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist") + ) + solver.add_cost(motile.costs.Appear(constant=0, ignore_attribute="ignore_appear")) + + solver.add_constraint(motile.constraints.MaxParents(1)) + solver.add_constraint(motile.constraints.MaxChildren(2)) + + solver.fit_weights(gt_attribute="gt", regularizer_weight=0.00001, max_iterations=1000) + print(solver.weights) + solver.solve() + solution_graph = graph_to_nx(solver.get_selected_subgraph()) + return solution_graph + + +# %% +solution_graph = solve_SSVM_optimization(cand_graph) + +# %% + +solution_seg = relabel_segmentation(solution_graph, segmentation) +get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg) + +# %%