From 2e9b45ae8efcfe7c23875c7825a60b915233123f Mon Sep 17 00:00:00 2001 From: getzze Date: Mon, 30 Sep 2024 12:57:27 +0100 Subject: [PATCH] add img_hw arg to Tracker.track --- sleap/nn/tracking.py | 3 +++ tests/nn/test_tracker_components.py | 1 - tests/nn/test_tracking_integration.py | 3 +-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 59d819643..770028b20 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -542,6 +542,7 @@ def run_step(self, lf: LabeledFrame) -> LabeledFrame: track_args["img"] = lf.video[lf.frame_idx] else: track_args["img"] = None + track_args["img_hw"] = lf.image.shape[-3:-1] return LabeledFrame( frame_idx=lf.frame_idx, @@ -667,6 +668,7 @@ def run_tracker( def track( self, untracked_instances: List[InstanceType], + img_hw: Tuple[int], img: Optional[np.ndarray] = None, t: int = None, ): @@ -1561,6 +1563,7 @@ def cull_function(inst_list): def track( self, untracked_instances: List[InstanceType], + img_hw: Tuple[int], img: Optional[np.ndarray] = None, t: int = None, ) -> List[InstanceType]: diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 9d3b65b38..fa0cc5f51 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -9,7 +9,6 @@ FrameMatches, greedy_matching, ) -from sleap.io.dataset import Labels from sleap.instance import PredictedInstance from sleap.skeleton import Skeleton diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index caebe49ff..c479462f8 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -1,4 +1,3 @@ -import inspect import operator import os import time @@ -7,7 +6,7 @@ import sleap from sleap.nn.inference import main as inference_cli import sleap.nn.tracker.components -from sleap.io.dataset import Labels, LabeledFrame +from sleap.io.dataset import Labels def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path):