From 6484c77032f637e62683e8ce6788b583243ed9ae Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Fri, 9 Aug 2024 17:33:25 -0400 Subject: [PATCH] Add table of results --- solution.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/solution.py b/solution.py index 606551f..5445ba9 100644 --- a/solution.py +++ b/solution.py @@ -622,7 +622,7 @@ def make_gt_detections(data_shape, gt_tracks, radius): # %% -def get_metrics(gt_graph, labels, pred_graph, pred_segmentation): +def get_metrics(gt_graph, labels, run, results_df): """Calculate metrics for linked tracks by comparing to ground truth. Args: @@ -634,7 +634,6 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation): Returns: results (dict): Dictionary of metric results. """ - gt_graph = traccuracy.TrackingGraph( graph=gt_graph, frame_key="t", @@ -644,11 +643,11 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation): ) pred_graph = traccuracy.TrackingGraph( - graph=pred_graph, + graph=run.tracks, frame_key="t", label_key="tracklet_id", location_keys=("x", "y"), - segmentation=pred_segmentation, + segmentation=np.squeeze(run.output_segmentation), ) results = run_metrics( @@ -658,11 +657,24 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation): metrics=[CTCMetrics(), DivisionMetrics()], ) - return results + results_filtered = {} + results_filtered.update(results[0]["results"]) + results_filtered.update(results[1]["results"]["Frame Buffer 0"]) + results_filtered["name"] = run.run_name + current_result = pd.DataFrame(results_filtered, index=[0])[columns] + + if results_df is None: + results_df = current_result + else: + results_df = pd.concat([results_df, current_result]) + + return results_df # %% -get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg.astype(np.uint32)) +results_df = None +results_df = get_metrics(gt_tracks, gt_dets, basic_run, results_df) +results_df # %% [markdown] #

Question 3: Interpret your results based on metrics

@@ -798,7 +810,8 @@ def solve_drift_optimization(cand_graph): widget.view_controller.update_napari_layers(drift_run, time_attr="t", pos_attr=("x", "y")) # %% -get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg) +results_df = get_metrics(gt_tracks, gt_dets, drift_run, results_df) +results_df # %% [markdown] @@ -925,7 +938,6 @@ def get_ssvm_solution(cand_graph, solver_weights): # %% solution_seg = relabel_segmentation(solution_graph, segmentation) -get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg) # %% ssvm_run = MotileRun( @@ -936,6 +948,11 @@ def get_ssvm_solution(cand_graph, solver_weights): widget.view_controller.update_napari_layers(ssvm_run, time_attr="t", pos_attr=("x", "y")) +# %% + +results_df = get_metrics(gt_tracks, gt_dets, ssvm_run, results_df) +results_df + # %% [markdown] #

Bonus Question: Interpret SSVM results

#