Skip to content

Commit

Permalink
Add SSVM working code (no explanation)
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Aug 5, 2024
1 parent 0caf79b commit 1d73abc
Showing 1 changed file with 59 additions and 4 deletions.
63 changes: 59 additions & 4 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ def solve_drift_optimization(cand_graph):
solver.add_cost(
motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist")
)
solver.add_cost(motile.costs.Appear(constant=0, ignore_attribute="ignore_appear"))
solver.add_cost(motile.costs.Appear(constant=100, ignore_attribute="ignore_appear"))

solver.add_constraint(motile.constraints.MaxParents(1))
solver.add_constraint(motile.constraints.MaxChildren(2))
Expand All @@ -941,6 +941,7 @@ def solve_drift_optimization(cand_graph):
# %%
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg)


# %% [markdown]
# ## Checkpoint 4
# <div class="alert alert-block alert-success"><h3>Checkpoint 4</h3>
Expand All @@ -954,7 +955,61 @@ def solve_drift_optimization(cand_graph):
#

# %%
# get a ground truth track
# match it to the candidate graph and annotate TP nodes and edges
def get_cand_id(gt_node, gt_track, cand_segmentation):
data = gt_track.nodes[gt_node]
return cand_segmentation[data["t"], int(data["x"])][int(data["y"])]

for gt_node in gt_tracks.nodes():
cand_id = get_cand_id(gt_node, gt_tracks, segmentation)
if cand_id != 0:
cand_graph.nodes[cand_id]["gt"] = True
succs = gt_tracks.successors(gt_node)
for succ in succs:
succ_id = get_cand_id(succ, gt_tracks, segmentation)
if succ_id != 0:
cand_graph.edges[(cand_id, succ_id)]["gt"] = True


# %%
import logging

# find a negative node in napari
logging.basicConfig(level=logging.INFO)


def solve_SSVM_optimization(cand_graph):
"""Set up and solve the network flow problem.
Args:
cand_graph (nx.DiGraph): The candidate graph.
Returns:
nx.DiGraph: The networkx digraph with the selected solution tracks
"""

cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute="t")
solver = motile.Solver(cand_trackgraph)

solver.add_cost(
motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist")
)
solver.add_cost(motile.costs.Appear(constant=0, ignore_attribute="ignore_appear"))

solver.add_constraint(motile.constraints.MaxParents(1))
solver.add_constraint(motile.constraints.MaxChildren(2))

solver.fit_weights(gt_attribute="gt", regularizer_weight=0.00001, max_iterations=1000)
print(solver.weights)
solver.solve()
solution_graph = graph_to_nx(solver.get_selected_subgraph())
return solution_graph


# %%
solution_graph = solve_SSVM_optimization(cand_graph)

# %%

solution_seg = relabel_segmentation(solution_graph, segmentation)
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg)

# %%

0 comments on commit 1d73abc

Please sign in to comment.