Skip to content

Commit

Permalink
Add node scores and convenience pipeline function
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Aug 21, 2024
1 parent 324ce41 commit 44677fc
Showing 1 changed file with 147 additions and 41 deletions.
188 changes: 147 additions & 41 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.15.0
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand Down Expand Up @@ -81,27 +81,20 @@
# ## Import packages

# %%
import time
from pathlib import Path

import skimage
import numpy as np
import napari
import networkx as nx
import scipy


import motile

import zarr
from motile_toolbox.visualization import to_napari_tracks_layer
from motile_toolbox.candidate_graph import graph_to_nx
from motile_toolbox.visualization.napari_utils import assign_tracklet_ids
import motile_plugin.widgets as plugin_widgets
from motile_plugin.backend.motile_run import MotileRun
from napari.layers import Tracks
import traccuracy
from traccuracy import run_metrics
from traccuracy.metrics import CTCMetrics, DivisionMetrics
from traccuracy.matchers import IOUMatcher
from csv import DictReader
Expand All @@ -127,13 +120,13 @@
probabilities = data_root["probs"][:]

# %% [markdown]
# Let's use [napari](https://napari.org/tutorials/fundamentals/getting_started.html) to visualize the data. Napari is a wonderful viewer for imaging data that you can interact with in python, even directly out of jupyter notebooks. If you've never used napari, you might want to take a few minutes to go through [this tutorial](https://napari.org/stable/tutorials/fundamentals/viewer.html).
# Let's use [napari](https://napari.org/tutorials/fundamentals/getting_started.html) to visualize the data. Napari is a wonderful viewer for imaging data that you can interact with in python, even directly out of jupyter notebooks. If you've never used napari, you might want to take a few minutes to go through [this tutorial](https://napari.org/stable/tutorials/fundamentals/viewer.html). Here we visualize the raw data, the predicted segmentations, and the predicted probabilities as separate layers. You can toggle each layer on and off in the layers list on the left.

# %%
viewer = napari.Viewer()
viewer.add_image(probabilities, name="probs", scale=(1, 2, 2))
viewer.add_image(image_data, name="raw")
viewer.add_labels(segmentation, name="seg")
viewer.add_image(probabilities, name="probs", scale=(1, 2, 2))

# %% [markdown]
# After running the previous cell, open NoMachine and check for an open napari window.
Expand Down Expand Up @@ -196,6 +189,7 @@ def read_gt_tracks():
gt_tracks.add_node(_id, **attrs)
if parent_id != -1:
gt_tracks.add_edge(parent_id, _id)

return gt_tracks

gt_tracks = read_gt_tracks()
Expand Down Expand Up @@ -234,6 +228,7 @@ def read_gt_tracks():
# Hint - if your screen is too small, you can "pop out" the lineage tree view into a separate window using the icon that looks like two boxes in the top left of the lineage tree view. You can also close the tree view with the x just above it, and open it again from the menu bar: Plugins -> Motile -> Lineage View (then re-run the below cell to add the data to the lineage view).

# %%
assign_tracklet_ids(gt_tracks)
ground_truth_run = MotileRun(
run_name="ground_truth",
tracks=gt_tracks,
Expand Down Expand Up @@ -264,6 +259,7 @@ def read_gt_tracks():
# <li>The node id is the label of the detection</li>
# <li>Each node has an integer "t" attribute, based on the index into the first dimension of the input segmentation array</li>
# <li>Each node has float "x" and "y" attributes containing the "x" and "y" values from the centroid of the detection region</li>
# <li>Each node has a "score" attribute containing the probability score output from StarDist. The probability map is at half resolution, so you will need to divide the centroid coordinates by 2 before indexing into the probability map.</li>
# <li>The graph has no edges (yet!)</li>
# </ol>
# </div>
Expand Down Expand Up @@ -310,10 +306,13 @@ def nodes_from_segmentation(segmentation: np.ndarray) -> nx.DiGraph:
props = skimage.measure.regionprops(seg_frame)
for regionprop in props:
node_id = regionprop.label
x = float(regionprop.centroid[0])
y = float(regionprop.centroid[1])
attrs = {
"t": t,
"x": float(regionprop.centroid[0]),
"y": float(regionprop.centroid[1]),
"x": x,
"y": y,
"score": float(probabilities[t, int(x // 2), int(y // 2)]),
}
assert node_id not in cand_graph.nodes
cand_graph.add_node(node_id, **attrs)
Expand All @@ -333,6 +332,8 @@ def nodes_from_segmentation(segmentation: np.ndarray) -> nx.DiGraph:
assert type(data["x"]) == float, f"'x' attribute has type {type(data['x'])}, expected 'float'"
assert "y" in data, f"'y' attribute missing for node {node}"
assert type(data["y"]) == float, f"'y' attribute has type {type(data['y'])}, expected 'float'"
assert "score" in data, f"'score' attribute missing for node {node}"
assert type(data["score"]) == float, f"'score' attribute has type {type(data['score'])}, expected 'float'"
print("Your candidate graph passed all the tests!")

# %% [markdown]
Expand Down Expand Up @@ -441,8 +442,6 @@ def add_cand_edges(
# A set of linear constraints ensures that the solution will be a feasible cell tracking graph. For example, if an edge is part of $\tilde{G}$, both its incident nodes have to be part of $\tilde{G}$ as well.
#
# `motile` ([docs here](https://funkelab.github.io/motile/)), makes it easy to link with an ILP in python by implementing common linking constraints and costs.
#
# TODO: delete this?

# %% [markdown]
# ## Task 3 - Basic tracking with motile
Expand All @@ -451,11 +450,10 @@ def add_cand_edges(
#
# Here are some key similarities and differences between the quickstart and our task:
# <ul>
# <li>We do not have scores on our nodes. This means we do not need to include a `NodeSelection` cost.</li>
# <li>We also do not have scores on our edges. However, we can use the edge distance as a cost, so that longer edges are more costly than shorter edges. Instead of using the `EdgeSelection` cost, we can use the <a href=https://funkelab.github.io/motile/api.html#edgedistance>`EdgeDistance`</a> cost with `position_attribute="pos"`. You will want a positive weight, since higher distances should be more costly, unlike in the example when higher scores were good and so we inverted them with a negative weight.</li>
# <li>Because distance is always positive, and you want a positive weight, you will want to include a negative constant on the `EdgeDistance` cost. If there are no negative selection costs, the ILP will always select nothing, because the cost of selecting nothing is zero.</li>
# <li>We want to allow divisions. So, we should pass in 2 to our `MaxChildren` constraint. The `MaxParents` constraint should have 1, the same as the quickstart, because neither task allows merging.</li>
# <li>You should include an Appear cost similar to the one in the quickstart.</li>
# <li>We do not have scores on our edges. However, we can use the edge distance as a cost, so that longer edges are more costly than shorter edges. Instead of using the <code>EdgeSelection</code> cost, we can use the <a href=https://funkelab.github.io/motile/api.html#edgedistance><code>EdgeDistance</code></a> cost with <code>position_attribute=("x", "y")</code>, since our nodes store their position in separate "x" and "y" attributes. You will want a positive weight, since higher distances should be more costly, unlike in the example where higher scores were good and so we inverted them with a negative weight.</li>
# <li>Because distance is always positive, and you want a positive weight, you will want to include a negative constant on the <code>EdgeDistance</code> cost. If there are no negative selection costs, the ILP will always select nothing, because the cost of selecting nothing is zero.</li>
# <li>We want to allow divisions. So, we should pass in 2 to our <code>MaxChildren</code> constraint. The <code>MaxParents</code> constraint should have 1, the same as the quickstart, because neither task allows merging.</li>
# <li>You should include an <code>Appear</code> cost and a <code>NodeSelection</code> cost similar to the ones in the quickstart.</li>
# </ul>
#
# Once you have set up the basic motile optimization task in the function below, you will probably need to adjust the weight and constant values on your costs until you get a solution that looks reasonable.
Expand Down Expand Up @@ -495,11 +493,14 @@ def solve_basic_optimization(cand_graph):
"""
cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute="t")
solver = motile.Solver(cand_trackgraph)

solver.add_cost(
motile.costs.NodeSelection(weight=-1.0, attribute="score")
)
solver.add_cost(
motile.costs.EdgeDistance(weight=1, constant=-20, position_attribute=("x", "y"))
)
solver.add_cost(motile.costs.Appear(constant=1.0))
solver.add_cost(motile.costs.Appear(constant=2.0))
solver.add_cost(motile.costs.Split(constant=1.0))

solver.add_constraint(motile.constraints.MaxParents(1))
solver.add_constraint(motile.constraints.MaxChildren(2))
Expand Down Expand Up @@ -536,7 +537,6 @@ def print_graph_stats(graph, name):
print_graph_stats(gt_tracks, "gt tracks")



# %% [markdown]
# If you haven't selected any nodes or edges in your solution, try adjusting your weight and/or constant values. Make sure you have some negative costs or selecting nothing will always be the best solution!

Expand Down Expand Up @@ -672,7 +672,7 @@ def get_metrics(gt_graph, labels, run, results_df):
segmentation=np.squeeze(run.output_segmentation),
)

results = run_metrics(
results = traccuracy.run_metrics(
gt_data=gt_graph,
pred_data=pred_graph,
matcher=IOUMatcher(iou_threshold=0.3, one_to_one=True),
Expand All @@ -698,6 +698,7 @@ def get_metrics(gt_graph, labels, run, results_df):
results_df = get_metrics(gt_tracks, gt_dets, basic_run, results_df)
results_df


# %% [markdown]
# <div class="alert alert-block alert-warning"><h3>Question 3: Interpret your results based on metrics</h3>
# <p>
Expand All @@ -708,9 +709,90 @@ def get_metrics(gt_graph, labels, run, results_df):

# %% [markdown]
# <div class="alert alert-block alert-success"><h2>Checkpoint 3</h2>
# If you reach this checkpoint with extra time, think about what kinds of improvements you could make to the costs and constraints to fix the issues that you are seeing. You can try tuning your weights and constants, or adding or removing motile Costs and Constraints, and seeing how that changes the output. See how good you can make the results!
# If you reach this checkpoint with extra time, think about what kinds of improvements you could make to the costs and constraints to fix the issues that you are seeing. You can try tuning your weights and constants, or adding or removing motile Costs and Constraints, and seeing how that changes the output. We have added a convenience function in the box below where you can copy your solution from above, adapt it, and run the whole pipeline including visualization and metrics computation.
#
# Do not get frustrated if you cannot get good results yet! Try to think about why and what custom costs we might add.
# </div>

# %% tags=["task"]
def adapt_basic_optimization(cand_graph):
    """Build a motile solver on the candidate graph, solve it, and return the result.

    This is the exercise template: copy your costs and constraints from the
    basic optimization above into the marked spot and adapt them.

    Args:
        cand_graph (nx.DiGraph): The candidate graph.

    Returns:
        nx.DiGraph: The networkx digraph with the selected solution tracks
    """
    # The "t" node attribute holds the frame index of each detection.
    solver = motile.Solver(motile.TrackGraph(cand_graph, frame_attribute="t"))
    ### YOUR CODE HERE ###
    solver.solve()
    # Convert the selected motile subgraph back into a networkx graph.
    return graph_to_nx(solver.get_selected_subgraph())

def run_pipeline(cand_graph, run_name, results_df):
    """Solve, visualize, and score a single tracking run.

    Relies on notebook-level names defined in earlier cells: ``segmentation``,
    ``widget``, ``gt_tracks`` and ``gt_dets``.

    Args:
        cand_graph: The candidate graph handed to ``adapt_basic_optimization``.
        run_name: Name given to the ``MotileRun``; use a distinct name per run
            so runs can be told apart.
        results_df: The results object passed through to ``get_metrics``.

    Returns:
        The value returned by ``get_metrics`` for this run (presumably the
        results table extended with this run -- confirm against ``get_metrics``).
    """
    tracks = adapt_basic_optimization(cand_graph)
    relabeled = relabel_segmentation(tracks, segmentation)
    motile_run = MotileRun(
        run_name=run_name,
        tracks=tracks,
        # need to add a dummy dimension to fit API
        output_segmentation=np.expand_dims(relabeled, axis=1),
    )
    # Push the new run into the napari viewer before computing metrics.
    widget.view_controller.update_napari_layers(
        motile_run, time_attr="t", pos_attr=("x", "y")
    )
    return get_metrics(gt_tracks, gt_dets, motile_run, results_df)

# Don't forget to rename your run below, so you can tell them apart in the results table
results_df = run_pipeline(cand_graph, "basic_solution_2", results_df)
results_df


# %% tags=["solution"]
def adapt_basic_optimization(cand_graph):
    """Build a motile solver with tuned costs, solve it, and return the result.

    Args:
        cand_graph (nx.DiGraph): The candidate graph.

    Returns:
        nx.DiGraph: The networkx digraph with the selected solution tracks
    """
    # The "t" node attribute holds the frame index of each detection.
    solver = motile.Solver(motile.TrackGraph(cand_graph, frame_attribute="t"))

    # Costs are registered in the order listed here.
    costs = (
        motile.costs.NodeSelection(weight=-5.0, constant=2.5, attribute="score"),
        motile.costs.EdgeDistance(
            weight=1, constant=-20, position_attribute=("x", "y")
        ),
        motile.costs.Appear(constant=20.0),
        motile.costs.Split(constant=15.0),
    )
    for cost in costs:
        solver.add_cost(cost)

    # At most one parent per node and at most two children per node.
    constraints = (
        motile.constraints.MaxParents(1),
        motile.constraints.MaxChildren(2),
    )
    for constraint in constraints:
        solver.add_constraint(constraint)

    solver.solve()
    # Convert the selected motile subgraph back into a networkx graph.
    return graph_to_nx(solver.get_selected_subgraph())

def run_pipeline(cand_graph, run_name, results_df):
    """Solve, visualize, and score a single tracking run.

    Relies on notebook-level names defined in earlier cells: ``segmentation``,
    ``widget``, ``gt_tracks`` and ``gt_dets``.

    Args:
        cand_graph: The candidate graph handed to ``adapt_basic_optimization``.
        run_name: Name given to the ``MotileRun``; use a distinct name per run
            so runs can be told apart.
        results_df: The results object passed through to ``get_metrics``.

    Returns:
        The value returned by ``get_metrics`` for this run (presumably the
        results table extended with this run -- confirm against ``get_metrics``).
    """
    tracks = adapt_basic_optimization(cand_graph)
    relabeled = relabel_segmentation(tracks, segmentation)
    motile_run = MotileRun(
        run_name=run_name,
        tracks=tracks,
        # need to add a dummy dimension to fit API
        output_segmentation=np.expand_dims(relabeled, axis=1),
    )
    # Push the new run into the napari viewer before computing metrics.
    widget.view_controller.update_napari_layers(
        motile_run, time_attr="t", pos_attr=("x", "y")
    )
    return get_metrics(gt_tracks, gt_dets, motile_run, results_df)

results_df = run_pipeline(cand_graph, "basic_solution_2", results_df)
results_df


# %% [markdown]
# ## Customizing the Tracking Task
#
Expand Down Expand Up @@ -789,7 +871,22 @@ def solve_drift_optimization(cand_graph):
solution_graph = graph_to_nx(solver.get_selected_subgraph())
return solution_graph

solution_graph = solve_drift_optimization(cand_graph)

def run_pipeline(cand_graph, run_name, results_df):
    """Solve the drift-aware tracking task, then visualize and score the run.

    Relies on notebook-level names defined in earlier cells: ``segmentation``,
    ``widget``, ``gt_tracks`` and ``gt_dets``.

    Args:
        cand_graph: The candidate graph handed to ``solve_drift_optimization``.
        run_name: Name given to the ``MotileRun``; use a distinct name per run
            so runs can be told apart.
        results_df: The results object passed through to ``get_metrics``.

    Returns:
        The value returned by ``get_metrics`` for this run (presumably the
        results table extended with this run -- confirm against ``get_metrics``).
    """
    tracks = solve_drift_optimization(cand_graph)
    relabeled = relabel_segmentation(tracks, segmentation)
    motile_run = MotileRun(
        run_name=run_name,
        tracks=tracks,
        # need to add a dummy dimension to fit API
        output_segmentation=np.expand_dims(relabeled, axis=1),
    )
    # Push the new run into the napari viewer before computing metrics.
    widget.view_controller.update_napari_layers(
        motile_run, time_attr="t", pos_attr=("x", "y")
    )
    return get_metrics(gt_tracks, gt_dets, motile_run, results_df)

# Don't forget to rename your run if you re-run this cell!
results_df = run_pipeline(cand_graph, "drift_dist", results_df)
results_df


# %% tags=["solution"]
Expand All @@ -805,10 +902,14 @@ def solve_drift_optimization(cand_graph):

cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute="t")
solver = motile.Solver(cand_trackgraph)

solver.add_cost(
motile.costs.NodeSelection(weight=-100, constant=75, attribute="score")
)
solver.add_cost(
motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist")
)
solver.add_cost(motile.costs.Appear(constant=40.0))
solver.add_cost(motile.costs.Split(constant=45.0))

solver.add_constraint(motile.constraints.MaxParents(1))
solver.add_constraint(motile.constraints.MaxChildren(2))
Expand All @@ -818,24 +919,27 @@ def solve_drift_optimization(cand_graph):
return solution_graph


# %%
solution_graph = solve_drift_optimization(cand_graph)
solution_seg = relabel_segmentation(solution_graph, segmentation)

# %%
drift_run = MotileRun(
run_name="drift_solution",
tracks=solution_graph,
output_segmentation=np.expand_dims(solution_seg, axis=1) # need to add a dummy dimension to fit API
)

widget.view_controller.update_napari_layers(drift_run, time_attr="t", pos_attr=("x", "y"))
def run_pipeline(cand_graph, run_name, results_df):
    """Solve the drift-aware tracking task, then visualize and score the run.

    Relies on notebook-level names defined in earlier cells: ``segmentation``,
    ``widget``, ``gt_tracks`` and ``gt_dets``.

    Args:
        cand_graph: The candidate graph handed to ``solve_drift_optimization``.
        run_name: Name given to the ``MotileRun``; use a distinct name per run
            so runs can be told apart.
        results_df: The results object passed through to ``get_metrics``.

    Returns:
        The value returned by ``get_metrics`` for this run (presumably the
        results table extended with this run -- confirm against ``get_metrics``).
    """
    tracks = solve_drift_optimization(cand_graph)
    relabeled = relabel_segmentation(tracks, segmentation)
    motile_run = MotileRun(
        run_name=run_name,
        tracks=tracks,
        # need to add a dummy dimension to fit API
        output_segmentation=np.expand_dims(relabeled, axis=1),
    )
    # Push the new run into the napari viewer before computing metrics.
    widget.view_controller.update_napari_layers(
        motile_run, time_attr="t", pos_attr=("x", "y")
    )
    return get_metrics(gt_tracks, gt_dets, motile_run, results_df)

# %%
results_df = get_metrics(gt_tracks, gt_dets, drift_run, results_df)
# Don't forget to rename your run if you re-run this cell!
results_df = run_pipeline(cand_graph, "node_const_75", results_df)
results_df


# %% [markdown]
# Feel free to tinker with the weights and constants manually to try and improve the results.
# You should be able to get something decent now, but this dataset is quite difficult! There are still many custom costs that could be added to improve the results - we will discuss some ideas together shortly.

# %% [markdown]
# <div class="alert alert-block alert-success"><h3>Checkpoint 4</h3>
# That is the end of the main exercise! If you have extra time, feel free to go onto the below bonus exercise to see how to learn the weights of your costs instead of setting them manually.
Expand Down Expand Up @@ -918,7 +1022,9 @@ def get_ssvm_solver(cand_graph):

cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute="t")
solver = motile.Solver(cand_trackgraph)

solver.add_cost(
motile.costs.NodeSelection(weight=-1.0, attribute='score')
)
solver.add_cost(
motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist")
)
Expand Down

0 comments on commit 44677fc

Please sign in to comment.