Skip to content

Commit

Permalink
Add table of results
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Aug 9, 2024
1 parent f19a2f9 commit 6484c77
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def make_gt_detections(data_shape, gt_tracks, radius):


# %%
def get_metrics(gt_graph, labels, pred_graph, pred_segmentation):
def get_metrics(gt_graph, labels, run, results_df):
"""Calculate metrics for linked tracks by comparing to ground truth.
Args:
Expand All @@ -634,7 +634,6 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation):
Returns:
results (dict): Dictionary of metric results.
"""

gt_graph = traccuracy.TrackingGraph(
graph=gt_graph,
frame_key="t",
Expand All @@ -644,11 +643,11 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation):
)

pred_graph = traccuracy.TrackingGraph(
graph=pred_graph,
graph=run.tracks,
frame_key="t",
label_key="tracklet_id",
location_keys=("x", "y"),
segmentation=pred_segmentation,
segmentation=np.squeeze(run.output_segmentation),
)

results = run_metrics(
Expand All @@ -658,11 +657,24 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation):
metrics=[CTCMetrics(), DivisionMetrics()],
)

return results
results_filtered = {}
results_filtered.update(results[0]["results"])
results_filtered.update(results[1]["results"]["Frame Buffer 0"])
results_filtered["name"] = run.run_name
current_result = pd.DataFrame(results_filtered, index=[0])[columns]

if results_df is None:
results_df = current_result
else:
results_df = pd.concat([results_df, current_result])

return results_df


# %%
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg.astype(np.uint32))
results_df = None
results_df = get_metrics(gt_tracks, gt_dets, basic_run, results_df)
results_df

# %% [markdown]
# <div class="alert alert-block alert-warning"><h3>Question 3: Interpret your results based on metrics</h3>
Expand Down Expand Up @@ -798,7 +810,8 @@ def solve_drift_optimization(cand_graph):
widget.view_controller.update_napari_layers(drift_run, time_attr="t", pos_attr=("x", "y"))

# %%
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg)
results_df = get_metrics(gt_tracks, gt_dets, drift_run, results_df)
results_df


# %% [markdown]
Expand Down Expand Up @@ -925,7 +938,6 @@ def get_ssvm_solution(cand_graph, solver_weights):

# %%
solution_seg = relabel_segmentation(solution_graph, segmentation)
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg)

# %%
ssvm_run = MotileRun(
Expand All @@ -936,6 +948,11 @@ def get_ssvm_solution(cand_graph, solver_weights):

widget.view_controller.update_napari_layers(ssvm_run, time_attr="t", pos_attr=("x", "y"))

# %%

results_df = get_metrics(gt_tracks, gt_dets, ssvm_run, results_df)
results_df

# %% [markdown]
# <div class="alert alert-block alert-warning"><h3>Bonus Question: Interpret SSVM results</h3>
# <p>
Expand Down

0 comments on commit 6484c77

Please sign in to comment.