From 8d88668bda4ef9636fb4bf9ab5e58fc16a5e2077 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Thu, 10 Aug 2023 16:44:53 -0700 Subject: [PATCH 01/26] Initial commit --- sleap/nn/tracking.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index b861c359f..e962be9f3 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -413,6 +413,7 @@ class Tracker(BaseTracker): similarity_function: Optional[Callable] = instance_similarity matching_function: Callable = greedy_matching candidate_maker: object = attr.ib(factory=FlowCandidateMaker) + tracks: dict() # Hold tracks, each as a deque with length as track_window cleaner: Optional[Callable] = None # todo: deprecate target_instance_count: int = 0 From 3b23e1eefdab7b721bfbbd701cb572806f35e15a Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Mon, 14 Aug 2023 11:55:01 -0700 Subject: [PATCH 02/26] format files --- sleap/nn/tracking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index e962be9f3..47e86ecf4 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -413,7 +413,7 @@ class Tracker(BaseTracker): similarity_function: Optional[Callable] = instance_similarity matching_function: Callable = greedy_matching candidate_maker: object = attr.ib(factory=FlowCandidateMaker) - tracks: dict() # Hold tracks, each as a deque with length as track_window + tracks: dict() # Hold tracks, each as a deque with length as track_window cleaner: Optional[Callable] = None # todo: deprecate target_instance_count: int = 0 From e32b8e3b87d3188c72f676a53c5a4a36aa431847 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Mon, 14 Aug 2023 17:35:04 -0700 Subject: [PATCH 03/26] [wip] adding local deque for tracks --- sleap/nn/tracking.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 47e86ecf4..c4baee1f6 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -413,9 +413,10 @@ class Tracker(BaseTracker): similarity_function: Optional[Callable] = instance_similarity matching_function: Callable = greedy_matching candidate_maker: object = attr.ib(factory=FlowCandidateMaker) - tracks: dict() # Hold tracks, each as a deque with length as track_window + track_local_deque: bool = False + tracks: Dict[int, Deque[Track]] # Hold tracks, each as a deque with length as track_window - cleaner: Optional[Callable] = None # todo: deprecate + cleaner: Optional[Callable] = None # TODO: deprecate target_instance_count: int = 0 pre_cull_function: Optional[Callable] = None post_connect_single_breaks: bool = False @@ -606,6 +607,7 @@ def make_tracker_by_name( robust: float = 1.0, min_new_track_points: int = 0, min_match_points: int = 0, + track_loacl_deque: bool = False, # Optical flow options img_scale: float = 1.0, of_window_size: int = 21, From fe5aeef1a5f2fbabb22cff1680aded5b6ccc2d13 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Mon, 14 Aug 2023 17:38:02 -0700 Subject: [PATCH 04/26] format files --- sleap/nn/tracking.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index c4baee1f6..bf9851422 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -414,7 +414,9 @@ class Tracker(BaseTracker): matching_function: Callable = greedy_matching candidate_maker: object = attr.ib(factory=FlowCandidateMaker) track_local_deque: bool = False - tracks: Dict[int, Deque[Track]] # Hold tracks, each as a deque with length as track_window + tracks: Dict[ + int, Deque[Track] + ] # Hold tracks, each as a deque with length as track_window cleaner: Optional[Callable] = None # TODO: deprecate target_instance_count: int = 0 @@ -607,7 +609,7 @@ def make_tracker_by_name( robust: float = 1.0, min_new_track_points: int = 0, min_match_points: int = 0, - track_loacl_deque: bool = False, + track_local_deque: bool = False, # Optical flow options img_scale: float = 1.0, of_window_size: int = 21, From f494cad741f15355876867c340821bf105f12fbf Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Tue, 15 Aug 2023 08:49:18 -0700 Subject: [PATCH 05/26] [wip] adding local deque for tracks --- sleap/nn/tracking.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index bf9851422..4abf72ceb 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -409,14 +409,14 @@ class Tracker(BaseTracker): 0.95 is a good value. """ + track_map: Dict[ + int, Deque[InstanceType] + ] # Hold tracks, each as a deque with length as track_window track_window: int = 5 similarity_function: Optional[Callable] = instance_similarity matching_function: Callable = greedy_matching candidate_maker: object = attr.ib(factory=FlowCandidateMaker) track_local_deque: bool = False - tracks: Dict[ - int, Deque[Track] - ] # Hold tracks, each as a deque with length as track_window cleaner: Optional[Callable] = None # TODO: deprecate target_instance_count: int = 0 From 6352e2a3c5a2e2c45984712a41d7d722cabe929a Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Mon, 21 Aug 2023 19:10:38 -0700 Subject: [PATCH 06/26] [wip] Add max tracking for simpletracker --- sleap/instance.py | 3 + sleap/nn/tracking.py | 108 ++++++++++++++++++++++---- tests/nn/test_tracking_integration.py | 31 ++++++-- 3 files changed, 121 insertions(+), 21 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index c14038552..be18f705e 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -333,6 +333,9 @@ def matches(self, other: "Track"): """ return attr.asdict(self) == attr.asdict(other) + def __hash__(self) -> int: + return hash(self.spawned_on, self.name) + # NOTE: # Instance cannot be a slotted class at the moment. This is because it creates diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 4abf72ceb..470b095c8 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -5,7 +5,7 @@ import attr import numpy as np import cv2 -from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple +from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple, TypeVar from sleap import Track, LabeledFrame, Skeleton @@ -88,6 +88,13 @@ class MatchedFrameInstances: img_t: Optional[np.ndarray] = None +@attr.s(auto_attribs=True, slots=True) +class MatchedFrameInstance: + t: int + instance_t: InstanceType + img_t: Optional[np.ndarray] = None + + @attr.s(auto_attribs=True, slots=True) class MatchedShiftedFrameInstances: ref_t: int @@ -334,9 +341,40 @@ def get_candidates( return candidate_instances +@attr.s(auto_attribs=True) +class SimpleMaxTracksCandidateMaker: + """Class to generate instances based on the maximum number of tracks from prior frames.""" + + min_points: int = 0 + + @property + def uses_image(self): + return False + + def get_candidates( + self, + track_matching_queue_dict: dict(), + *args, + **kwargs, + ) -> List[InstanceType]: + # Create set of matchable candidate instances from each track for max number of tracks. + candidate_instances = [] + number_of_tracks = 0 + for track in track_matching_queue_dict.keys(): + if track and number_of_tracks <= max_tracks: + number_of_tracks += 1 + for matched_item in track: + ref_t, ref_instances = matched_item.t, matched_item.instances_t + for ref_instance in ref_instances: + if ref_instance.n_visible_points >= self.min_points: + candidate_instances.append(ref_instance) + return candidate_instances + + tracker_policies = dict( simple=SimpleCandidateMaker, flow=FlowCandidateMaker, + simplemaxtracks=SimpleMaxTracksCandidateMaker, ) similarity_policies = dict( @@ -409,14 +447,12 @@ class Tracker(BaseTracker): 0.95 is a good value. """ - track_map: Dict[ - int, Deque[InstanceType] - ] # Hold tracks, each as a deque with length as track_window + max_tracks: int track_window: int = 5 similarity_function: Optional[Callable] = instance_similarity matching_function: Callable = greedy_matching candidate_maker: object = attr.ib(factory=FlowCandidateMaker) - track_local_deque: bool = False + max_tracking: bool = False # To enable maximum tracking. cleaner: Optional[Callable] = None # TODO: deprecate target_instance_count: int = 0 @@ -428,6 +464,10 @@ class Tracker(BaseTracker): track_matching_queue: Deque[MatchedFrameInstances] = attr.ib() + # Hold track, instances with each instances as a deque with length as track_window + track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]] = attr.ib( + factory=dict + ) spawned_tracks: List[Track] = attr.ib(factory=list) save_tracked_instances: bool = False @@ -447,7 +487,11 @@ def _init_matching_queue(self): return deque(maxlen=self.track_window) def reset_candidates(self): - self.track_matching_queue = deque(maxlen=self.track_window) + if self.max_tracking and self.track_matching_queue_dict: + for track, candidates in self.track_matching_queue_dict: + candidates = deque(maxlen=self.track_window) + else: + self.track_matching_queue = deque(maxlen=self.track_window) @property def unique_tracks_in_queue(self) -> List[Track]: @@ -458,6 +502,10 @@ def unique_tracks_in_queue(self) -> List[Track]: for instance in match_item.instances_t: unique_tracks.add(instance.track) + if self.max_tracking: + for track in self.track_matching_queue_dict.keys(): + unique_tracks.add(track) + return list(unique_tracks) @property @@ -507,11 +555,19 @@ def track( self.pre_cull_function(untracked_instances) # Build a pool of matchable candidate instances. - candidate_instances = self.candidate_maker.get_candidates( - track_matching_queue=self.track_matching_queue, - t=t, - img=img, - ) + if self.max_tracking: + candidate_instances = self.candidate_maker.get_candidates( + track_matching_queue_dict=self.track_matching_queue_dict, + max_tracks=self.max_tracks, + t=t, + img=img, + ) + else: + candidate_instances = self.candidate_maker.get_candidates( + track_matching_queue=self.track_matching_queue, + t=t, + img=img, + ) # Determine matches for untracked instances in current frame. frame_matches = FrameMatches.from_candidate_instances( @@ -535,10 +591,18 @@ def track( self.spawn_for_untracked_instances(frame_matches.unmatched_instances, t) ) - # Add the tracked instances to the matching buffer. - self.track_matching_queue.append( - MatchedFrameInstances(t, tracked_instances, img) - ) + # Add the tracked instances to the dictionary of matched instances. + if self.max_tracking: + for tracked_instance in tracked_instances: + if self.track_matching_queue_dict[tracked_instance]: + self.track_matching_queue_dict[tracked_instance.track].append( + MatchedFrameInstance(t, tracked_instance, img) + ) + else: + # Add the tracked instances to the matching buffer. + self.track_matching_queue.append( + MatchedFrameInstances(t, tracked_instances, img) + ) # Save tracked instances internally. if self.save_tracked_instances: @@ -609,7 +673,7 @@ def make_tracker_by_name( robust: float = 1.0, min_new_track_points: int = 0, min_match_points: int = 0, - track_local_deque: bool = False, + max_tracking: bool = False, # Optical flow options img_scale: float = 1.0, of_window_size: int = 21, @@ -682,6 +746,7 @@ def pre_cull_function(inst_list): candidate_maker=candidate_maker, cleaner=cleaner, pre_cull_function=pre_cull_function, + max_tracking=max_tracking, target_instance_count=target_instance_count, post_connect_single_breaks=post_connect_single_breaks, ) @@ -868,6 +933,17 @@ class SimpleTracker(Tracker): candidate_maker: object = attr.ib(factory=SimpleCandidateMaker) +@attr.s(auto_attribs=True) +class SimpleMaxTracker(Tracker): + """A tracked pre-configured to use simple, non-image-based candidates but with a maximum number of tracks.""" + + max_tracks: int = attr.ib(kw_only=True) + similarity_function: Callable = instance_iou + matching_function: Callable = hungarian_matching + candidate_maker: object = attr.ib(factory=SimpleMaxTracksCandidateMaker) + max_tracking: bool = True + + @attr.s(auto_attribs=True) class KalmanInitSet: init_frame_count: int diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 829b7c3cb..34f802605 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -95,6 +95,7 @@ def main(f, dir): trackers = dict( simple=sleap.nn.tracker.simple.SimpleTracker, flow=sleap.nn.tracker.flow.FlowTracker, + simplemaxtracks=sleap.nn.tracker.SimpleMaxTracker, ) matchers = dict( hungarian=sleap.nn.tracker.components.hungarian_matching, @@ -110,11 +111,21 @@ def main(f, dir): 0.25, ) - def make_tracker(tracker_name, matcher_name, sim_name, scale=0): - tracker = trackers[tracker_name]( - matching_function=matchers[matcher_name], - similarity_function=similarities[sim_name], - ) + def make_tracker( + tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0 + ): + if tracker_name == "simplemaxtracks": + tracker[tracker_name]( + matching_function=matchers[matcher_name], + similarity_function=similarities[sim_name], + max_tracks=max_tracks, + max_tracking=max_tracking, + ) + else: + tracker = trackers[tracker_name]( + matching_function=matchers[matcher_name], + similarity_function=similarities[sim_name], + ) if scale: tracker.candidate_maker.img_scale = scale return tracker @@ -145,6 +156,16 @@ def make_tracker_and_filename(*args, **kwargs): scale=scale, ) f(frames, tracker, gt_filename) + elif tracker_name == "simplemaxtracks": + tracker, gt_filename = make_tracker_and_filename( + tracker_name=tracker_name, + matcher_name=matcher_name, + sim_name=sim_name, + max_tracks=5, + max_tracking=True, + scale=0, + ) + f(frames, tracker, gt_filename) else: tracker, gt_filename = make_tracker_and_filename( tracker_name=tracker_name, From 4ddcbc9e11d03119a1f002ac50d5f069715a8572 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Mon, 21 Aug 2023 19:52:55 -0700 Subject: [PATCH 07/26] [wip] Add max tracking for simple tracker --- sleap/instance.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index be18f705e..c14038552 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -333,9 +333,6 @@ def matches(self, other: "Track"): """ return attr.asdict(self) == attr.asdict(other) - def __hash__(self) -> int: - return hash(self.spawned_on, self.name) - # NOTE: # Instance cannot be a slotted class at the moment. This is because it creates From 6ff43c648e75ad26e9291298fa34873286046cb9 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Tue, 22 Aug 2023 14:43:06 -0700 Subject: [PATCH 08/26] [wip] add missing argument --- sleap/nn/tracking.py | 1 + tests/nn/test_tracker_components.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 470b095c8..8e353fe3b 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -747,6 +747,7 @@ def pre_cull_function(inst_list): cleaner=cleaner, pre_cull_function=pre_cull_function, max_tracking=max_tracking, + max_tracks=max_tracks, target_instance_count=target_instance_count, post_connect_single_breaks=post_connect_single_breaks, ) diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 869ebc85c..644aad5f1 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -14,7 +14,7 @@ from sleap.skeleton import Skeleton -@pytest.mark.parametrize("tracker", ["simple", "flow"]) +@pytest.mark.parametrize("tracker", ["simple", "flow", "simplemaxtracks"]) @pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"]) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) @pytest.mark.parametrize("count", [0, 2]) From 216f3a3e549f91da88b2f376d6613a2714ba64fd Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Wed, 23 Aug 2023 16:51:37 -0700 Subject: [PATCH 09/26] [wip] Add and modify test functions --- sleap/nn/tracking.py | 43 ++++++++++++++++++++++++++++++++++---- tests/nn/test_inference.py | 6 +++++- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 8e353fe3b..3b0822ccb 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -345,6 +345,7 @@ def get_candidates( class SimpleMaxTracksCandidateMaker: """Class to generate instances based on the maximum number of tracks from prior frames.""" + max_tracks: int = None min_points: int = 0 @property @@ -361,7 +362,7 @@ def get_candidates( candidate_instances = [] number_of_tracks = 0 for track in track_matching_queue_dict.keys(): - if track and number_of_tracks <= max_tracks: + if track and number_of_tracks <= self.max_tracks: number_of_tracks += 1 for matched_item in track: ref_t, ref_instances = matched_item.t, matched_item.instances_t @@ -447,7 +448,7 @@ class Tracker(BaseTracker): 0.95 is a good value. """ - max_tracks: int + max_tracks: int = None track_window: int = 5 similarity_function: Optional[Callable] = instance_similarity matching_function: Callable = greedy_matching @@ -534,6 +535,22 @@ def track( # Infer timestep if not provided. if t is None: + if self.max_tracking: + if len(self.track_matching_queue_dict) > 0: + + # Default to last timestep + 1 if available. Here we find the track that has the most instances. + track_with_max_instances = max( + self.track_matching_queue_dict, + track=lambda track: len(self.track_matching_queue_dict[track]), + ) + t = ( + self.track_matching_queue_dict[track_with_max_instances][-1].t + + 1 + ) + + else: + t = 0 + if len(self.track_matching_queue) > 0: # Default to last timestep + 1 if available. @@ -594,7 +611,7 @@ def track( # Add the tracked instances to the dictionary of matched instances. if self.max_tracking: for tracked_instance in tracked_instances: - if self.track_matching_queue_dict[tracked_instance]: + if tracked_instance.track in self.track_matching_queue_dict: self.track_matching_queue_dict[tracked_instance.track].append( MatchedFrameInstance(t, tracked_instance, img) ) @@ -666,6 +683,7 @@ def get_name(self): @classmethod def make_tracker_by_name( cls, + # Tracker options tracker: str = "flow", similarity: str = "instance", match: str = "greedy", @@ -673,7 +691,6 @@ def make_tracker_by_name( robust: float = 1.0, min_new_track_points: int = 0, min_match_points: int = 0, - max_tracking: bool = False, # Optical flow options img_scale: float = 1.0, of_window_size: int = 21, @@ -691,6 +708,9 @@ def make_tracker_by_name( # Kalman filter options kf_init_frame_count: int = 0, kf_node_indices: Optional[list] = None, + # Max tracking options + max_tracks: int = None, + max_tracking: bool = False, **kwargs, ) -> BaseTracker: @@ -721,6 +741,9 @@ def make_tracker_by_name( candidate_maker.save_shifted_instances = save_shifted_instances candidate_maker.track_window = track_window + if tracker == "simplemaxtracks": + candidate_maker.max_tracks = max_tracks + cleaner = None if clean_instance_count: cleaner = TrackCleaner( @@ -779,6 +802,18 @@ def get_by_name_factory_options(cls): ] options.append(option) + option = dict(name="max_tracking", default=False) + option["type"] = bool + option[ + "help" + ] = "If true then the tracker will cap the max number of tracks created or tracked." + options.append(option) + + option = dict(name="max_tracks", default=0) + option["type"] = None + option["help"] = "Maximum number of tracks to be tracked by the tracker." + options.append(option) + option = dict(name="target_instance_count", default=0) option["type"] = int option["help"] = "Target number of instances to track per frame." diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index cc65ac3fe..450a6f356 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1335,7 +1335,8 @@ def test_topdown_id_predictor_save( @pytest.mark.parametrize( - "output_path,tracker_method", [("not_default", "flow"), (None, "simple")] + "output_path,tracker_method", + [("not_default", "flow"), (None, "simple"), (None, "simplemaxtracks")], ) def test_retracking( centered_pair_predictions: Labels, tmpdir, output_path, tracker_method @@ -1350,6 +1351,9 @@ def test_retracking( ) if tracker_method == "flow": cmd += " --tracking.save_shifted_instances 1" + elif tracker_method == "simplemaxtracks": + cmd += " --tracking.max_tracking 1" + cmd += " --tracking.max_tracks 2" if output_path == "not_default": output_path = Path(tmpdir, "tracked_slp.slp") cmd += f" --output {output_path}" From d8971706bf63ec5a5edb0148cdfce7ab9c9c3aef Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Wed, 23 Aug 2023 17:13:41 -0700 Subject: [PATCH 10/26] [wip] Add and modify test functions --- sleap/nn/tracking.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 3b0822ccb..f12985beb 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -550,14 +550,14 @@ def track( else: t = 0 + else: + if len(self.track_matching_queue) > 0: - if len(self.track_matching_queue) > 0: - - # Default to last timestep + 1 if available. - t = self.track_matching_queue[-1].t + 1 + # Default to last timestep + 1 if available. + t = self.track_matching_queue[-1].t + 1 - else: - t = 0 + else: + t = 0 # Initialize containers for tracked instances at the current timestep. tracked_instances = [] From 9c5579f1ca48e7426cc489c5419dbb796cca70d8 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Thu, 24 Aug 2023 17:55:44 -0700 Subject: [PATCH 11/26] Bug fix and refactoring code --- sleap/nn/tracking.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index f12985beb..d1f6d5139 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -361,14 +361,12 @@ def get_candidates( # Create set of matchable candidate instances from each track for max number of tracks. candidate_instances = [] number_of_tracks = 0 - for track in track_matching_queue_dict.keys(): + for track, matched_instances in track_matching_queue_dict.items(): if track and number_of_tracks <= self.max_tracks: number_of_tracks += 1 - for matched_item in track: - ref_t, ref_instances = matched_item.t, matched_item.instances_t - for ref_instance in ref_instances: - if ref_instance.n_visible_points >= self.min_points: - candidate_instances.append(ref_instance) + for ref_instance in matched_instances: + if ref_instance.instance_t.n_visible_points >= self.min_points: + candidate_instances.append(ref_instance.instance_t) return candidate_instances @@ -489,7 +487,7 @@ def _init_matching_queue(self): def reset_candidates(self): if self.max_tracking and self.track_matching_queue_dict: - for track, candidates in self.track_matching_queue_dict: + for track, candidates in self.track_matching_queue_dict.items(): candidates = deque(maxlen=self.track_window) else: self.track_matching_queue = deque(maxlen=self.track_window) @@ -541,7 +539,7 @@ def track( # Default to last timestep + 1 if available. Here we find the track that has the most instances. track_with_max_instances = max( self.track_matching_queue_dict, - track=lambda track: len(self.track_matching_queue_dict[track]), + key=lambda track: len(self.track_matching_queue_dict[track]), ) t = ( self.track_matching_queue_dict[track_with_max_instances][-1].t @@ -615,6 +613,14 @@ def track( self.track_matching_queue_dict[tracked_instance.track].append( MatchedFrameInstance(t, tracked_instance, img) ) + elif len(self.track_matching_queue_dict) < self.max_tracks: + self.track_matching_queue_dict[tracked_instance.track] = deque( + maxlen=self.track_window + ) + self.track_matching_queue_dict[tracked_instance.track].append( + MatchedFrameInstance(t, tracked_instance, img) + ) + else: # Add the tracked instances to the matching buffer. self.track_matching_queue.append( @@ -809,8 +815,8 @@ def get_by_name_factory_options(cls): ] = "If true then the tracker will cap the max number of tracks created or tracked." options.append(option) - option = dict(name="max_tracks", default=0) - option["type"] = None + option = dict(name="max_tracks", default=None) + option["type"] = int option["help"] = "Maximum number of tracks to be tracked by the tracker." options.append(option) From 380038f00b953c7ec2f6d49c785e930c8da6df52 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Mon, 28 Aug 2023 11:53:00 -0700 Subject: [PATCH 12/26] [wip] Add max tracking for flow tracker. --- sleap/nn/tracking.py | 133 +++++++++++++++++++++++++- tests/nn/test_inference.py | 9 +- tests/nn/test_tracker_components.py | 4 +- tests/nn/test_tracking_integration.py | 17 +++- 4 files changed, 155 insertions(+), 8 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index d1f6d5139..7f68ff28a 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -318,6 +318,119 @@ def flow_shift_instances( return shifted_instances +@attr.s(auto_attribs=True) +class FlowMaxTracksCandidateMaker(FlowCandidateMaker): + """Class for producing optical flow shift matching candidates with a maximum on the number of tracks. + + Attributes: + min_points: Minimum number of points that must be detected in the new frame in + order to generate a new shifted instance. + img_scale: Factor to scale the images by when computing optical flow. Decrease + this to increase performance at the cost of finer accuracy. Sometimes + decreasing the image scale can improve performance with fast movements. + of_window_size: Optical flow window size to consider at each pyramid scale + level. + of_max_levels: Number of pyramid scale levels to consider. This is different + from the scale parameter, which determines the initial image scaling. + save_shifted_instances: If True, save the shifted instances between elapsed + frames. + track_window: How many frames back to look for candidate instances to match + instances in the current frame against. + + """ + + max_tracks: int = None + min_points: int = 0 + img_scale: float = 1.0 + of_window_size: int = 21 + of_max_levels: int = 3 + save_shifted_instances: bool = False + track_window: int = 5 + + shifted_instances: Dict[ + Tuple[int, int], List[ShiftedInstance] # keyed by (src_t, dst_t) + ] = attr.ib(factory=dict) + + @property + def uses_image(self): + return True + + def get_candidates( + self, + track_matching_queue_dict: dict(), + t: int, + img: np.ndarray, + *args, + **kwargs, + ) -> List[ShiftedInstance]: + candidate_instances = [] + + # Prune old shifted instances to save time and memory + self.prune_shifted_instances(t) + number_of_tracks = 0 + # -> List[ShiftedInstance] + def get_ref_instances(r_t, r_img) -> List[InstanceType]: + instances = [] + for track, matched_items in track_matching_queue_dict.items(): + instances + [ + item.instance_t + for item in matched_items + if item.t == r_t and np.all(item.img_t == r_img) + ] + return instances + + for track, matched_items in track_matching_queue_dict.items(): + if track and number_of_tracks <= self.max_tracks: + number_of_tracks += 1 + for matched_item in matched_items: + ref_t, ref_img = ( + matched_item.t, + matched_item.img_t, + ) + ref_instances = get_ref_instances(r_t=ref_t, r_img=ref_img) + + # Check if shifted instance was computed at earlier time + if self.save_shifted_instances: + for ti in reversed(range(ref_t, t)): + if (ref_t, ti) in self.shifted_instances: + ref_shifted_instances = self.shifted_instances[ + (ref_t, ti) + ] + # Use shifted instance as a reference + if len(ref_shifted_instances.instances_t) > 0: + ref_img = ref_shifted_instances.img_t + ref_instances = ref_shifted_instances.instances_t + break + + if len(ref_instances) > 0: + # Flow shift reference instances to current frame. + shifted_instances = self.flow_shift_instances( + ref_instances, + ref_img, + img, + min_shifted_points=self.min_points, + scale=self.img_scale, + window_size=self.of_window_size, + max_levels=self.of_max_levels, + ) + + # Add to candidate pool. + candidate_instances.extend(shifted_instances) + + # Save shifted instances. + if self.save_shifted_instances: + self.shifted_instances[ + (ref_t, t) + ] = MatchedShiftedFrameInstances( + ref_t, + t, + shifted_instances, + img, + ) + + return candidate_instances + + @attr.s(auto_attribs=True) class SimpleCandidateMaker: """Class for producing list of matching candidates from prior frames.""" @@ -374,6 +487,7 @@ def get_candidates( simple=SimpleCandidateMaker, flow=FlowCandidateMaker, simplemaxtracks=SimpleMaxTracksCandidateMaker, + flowmaxtracks=FlowMaxTracksCandidateMaker, ) similarity_policies = dict( @@ -463,7 +577,7 @@ class Tracker(BaseTracker): track_matching_queue: Deque[MatchedFrameInstances] = attr.ib() - # Hold track, instances with each instances as a deque with length as track_window + # Hold track, instances with instances as a deque with length as track_window. track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]] = attr.ib( factory=dict ) @@ -747,7 +861,7 @@ def make_tracker_by_name( candidate_maker.save_shifted_instances = save_shifted_instances candidate_maker.track_window = track_window - if tracker == "simplemaxtracks": + if tracker == "simplemaxtracks" or tracker == "flowmaxtracks": candidate_maker.max_tracks = max_tracks cleaner = None @@ -966,6 +1080,19 @@ class FlowTracker(Tracker): candidate_maker: object = attr.ib(factory=FlowCandidateMaker) +attr.s(auto_attribs=True) + + +class FlowMaxTracker(Tracker): + """A Tracker pre-configured to use optical flow shifted candidates with a maximum limit on tracks.""" + + max_tracks: int = attr.ib(kw_only=True) + similarity_function: Callable = instance_similarity + matching_function: Callable = greedy_matching + candidate_maker: object = attr.ib(factory=FlowMaxTracksCandidateMaker) + max_tracking: bool = True + + @attr.s(auto_attribs=True) class SimpleTracker(Tracker): """A Tracker pre-configured to use simple, non-image-based candidates.""" @@ -977,7 +1104,7 @@ class SimpleTracker(Tracker): @attr.s(auto_attribs=True) class SimpleMaxTracker(Tracker): - """A tracked pre-configured to use simple, non-image-based candidates but with a maximum number of tracks.""" + """A tracked pre-configured to use simple, non-image-based candidates with a maximum number of tracks.""" max_tracks: int = attr.ib(kw_only=True) similarity_function: Callable = instance_iou diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 450a6f356..6273eac59 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1336,7 +1336,12 @@ def test_topdown_id_predictor_save( @pytest.mark.parametrize( "output_path,tracker_method", - [("not_default", "flow"), (None, "simple"), (None, "simplemaxtracks")], + [ + ("not_default", "flow"), + ("not_default", "flowmaxtracks"), + (None, "simple"), + (None, "simplemaxtracks"), + ], ) def test_retracking( centered_pair_predictions: Labels, tmpdir, output_path, tracker_method @@ -1351,7 +1356,7 @@ def test_retracking( ) if tracker_method == "flow": cmd += " --tracking.save_shifted_instances 1" - elif tracker_method == "simplemaxtracks": + elif tracker_method == "simplemaxtracks" or tracker_method == "flowmaxtracks": cmd += " --tracking.max_tracking 1" cmd += " --tracking.max_tracks 2" if output_path == "not_default": diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 644aad5f1..7b4dafd63 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -14,7 +14,9 @@ from sleap.skeleton import Skeleton -@pytest.mark.parametrize("tracker", ["simple", "flow", "simplemaxtracks"]) +@pytest.mark.parametrize( + "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] +) @pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"]) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) @pytest.mark.parametrize("count", [0, 2]) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 34f802605..310001cb9 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -96,6 +96,7 @@ def main(f, dir): simple=sleap.nn.tracker.simple.SimpleTracker, flow=sleap.nn.tracker.flow.FlowTracker, simplemaxtracks=sleap.nn.tracker.SimpleMaxTracker, + flowmaxtracks=sleap.nn.tracker.FlowMaxTracker, ) matchers = dict( hungarian=sleap.nn.tracker.components.hungarian_matching, @@ -114,7 +115,7 @@ def main(f, dir): def make_tracker( tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0 ): - if tracker_name == "simplemaxtracks": + if tracker_name == "simplemaxtracks" or tracker_name == "flowmaxtracks": tracker[tracker_name]( matching_function=matchers[matcher_name], similarity_function=similarities[sim_name], @@ -156,12 +157,24 @@ def make_tracker_and_filename(*args, **kwargs): scale=scale, ) f(frames, tracker, gt_filename) + elif tracker_name == "flowmaxtracks": + # If this tracker supports scale, try multiple scales + for scale in scales: + tracker, gt_filename = make_tracker_and_filename( + tracker_name=tracker_name, + matcher_name=matcher_name, + sim_name=sim_name, + max_tracks=2, + max_tracking=True, + scale=scale, + ) + f(frames, tracker, gt_filename) elif tracker_name == "simplemaxtracks": tracker, gt_filename = make_tracker_and_filename( tracker_name=tracker_name, matcher_name=matcher_name, sim_name=sim_name, - max_tracks=5, + max_tracks=2, max_tracking=True, scale=0, ) From 9413bfde54b29ec1607591cc0db9b7e51bf226e4 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Thu, 31 Aug 2023 12:45:21 -0700 Subject: [PATCH 13/26] [wip] Including suggested changes --- sleap/nn/tracking.py | 67 ++++++++++++++++---------------------- tests/nn/test_inference.py | 3 +- 2 files changed, 30 insertions(+), 40 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 7f68ff28a..a8d2d128d 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -5,7 +5,7 @@ import attr import numpy as np import cv2 -from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple, TypeVar +from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple from sleap import Track, LabeledFrame, Skeleton @@ -340,20 +340,22 @@ class FlowMaxTracksCandidateMaker(FlowCandidateMaker): """ max_tracks: int = None - min_points: int = 0 - img_scale: float = 1.0 - of_window_size: int = 21 - of_max_levels: int = 3 - save_shifted_instances: bool = False - track_window: int = 5 - shifted_instances: Dict[ - Tuple[int, int], List[ShiftedInstance] # keyed by (src_t, dst_t) - ] = attr.ib(factory=dict) - - @property - def uses_image(self): - return True + @staticmethod + def get_ref_instances( + ref_t: int, + ref_img: np.ndarray, + track_matching_queue_dict: dict(), + ) -> List[InstanceType]: + """Generates a list of instances based on the reference time and reference image.""" + instances = [] + for track, matched_items in track_matching_queue_dict.items(): + instances += [ + item.instance_t + for item in matched_items + if item.t == ref_t and np.all(item.img_t == ref_img) + ] + return instances def get_candidates( self, @@ -367,27 +369,19 @@ def get_candidates( # Prune old shifted instances to save time and memory self.prune_shifted_instances(t) - number_of_tracks = 0 - # -> List[ShiftedInstance] - def get_ref_instances(r_t, r_img) -> List[InstanceType]: - instances = [] - for track, matched_items in track_matching_queue_dict.items(): - instances + [ - item.instance_t - for item in matched_items - if item.t == r_t and np.all(item.img_t == r_img) - ] - return instances + tracks = [] for track, matched_items in track_matching_queue_dict.items(): - if track and number_of_tracks <= self.max_tracks: - number_of_tracks += 1 + if len(tracks) <= self.max_tracks: + tracks.append(track) for matched_item in matched_items: ref_t, ref_img = ( matched_item.t, matched_item.img_t, ) - ref_instances = get_ref_instances(r_t=ref_t, r_img=ref_img) + ref_instances = self.get_ref_instances( + ref_t, ref_img, track_matching_queue_dict + ) # Check if shifted instance was computed at earlier time if self.save_shifted_instances: @@ -455,15 +449,10 @@ def get_candidates( @attr.s(auto_attribs=True) -class SimpleMaxTracksCandidateMaker: +class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker): """Class to generate instances based on the maximum number of tracks from prior frames.""" max_tracks: int = None - min_points: int = 0 - - @property - def uses_image(self): - return False def get_candidates( self, @@ -473,10 +462,10 @@ def get_candidates( ) -> List[InstanceType]: # Create set of matchable candidate instances from each track for max number of tracks. candidate_instances = [] - number_of_tracks = 0 + tracks = [] for track, matched_instances in track_matching_queue_dict.items(): - if track and number_of_tracks <= self.max_tracks: - number_of_tracks += 1 + if len(tracks) <= self.max_tracks: + tracks.append(track) for ref_instance in matched_instances: if ref_instance.instance_t.n_visible_points >= self.min_points: candidate_instances.append(ref_instance.instance_t) @@ -601,8 +590,8 @@ def _init_matching_queue(self): def reset_candidates(self): if self.max_tracking and self.track_matching_queue_dict: - for track, candidates in self.track_matching_queue_dict.items(): - candidates = deque(maxlen=self.track_window) + for track in self.track_matching_queue_dict: + self.track_matching_queue_dict[track] = deque(maxlen=self.track_window) else: self.track_matching_queue = deque(maxlen=self.track_window) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 6273eac59..34e4bc67e 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -51,7 +51,8 @@ main as sleap_track, export_cli as sleap_export, ) -from sleap.nn.tracking import FlowCandidateMaker, Tracker +from sleap.nn.tracking import FlowCandidateMaker, FlowMaxTracksCandidateMaker, Tracker +from sleap.instance import Track sleap.nn.system.use_cpu_only() From 1969d2bb4dd7f58d360a80353ba42de592ed16dd Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Fri, 1 Sep 2023 11:13:00 -0700 Subject: [PATCH 14/26] [wip] refactor code --- sleap/instance.py | 3 + sleap/nn/tracking.py | 143 +++++++++++++++++-------------------- tests/nn/test_inference.py | 7 +- 3 files changed, 74 insertions(+), 79 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index c14038552..385f5f3aa 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -333,6 +333,9 @@ def matches(self, other: "Track"): """ return attr.asdict(self) == attr.asdict(other) + def __hash__(self) -> int: + return hash((self.spawned_on, self.name)) + # NOTE: # Instance cannot be a slotted class at the moment. This is because it creates diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index a8d2d128d..45a7f66f3 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -139,6 +139,51 @@ class FlowCandidateMaker: def uses_image(self): return True + def get_shifted_instances_from_earlier_time( + self, ref_t: int, ref_img: np.ndarray, ref_instances: List[InstanceType], t: int + ) -> (np.ndarray, List[InstanceType]): + """Check if shifted instance was computed at earlier time and return instances and corresponding image.""" + for ti in reversed(range(ref_t, t)): + if (ref_t, ti) in self.shifted_instances: + ref_shifted_instances = self.shifted_instances[(ref_t, ti)] + # Use shifted instance as a reference + if len(ref_shifted_instances.instances_t) > 0: + ref_img = ref_shifted_instances.img_t + ref_instances = ref_shifted_instances.instances_t + break + return [ref_img, ref_instances] + + def get_shifted_instances( + self, + ref_instances: List[InstanceType], + ref_img: np.ndarray, + ref_t: int, + img: np.ndarray, + t: int, + ) -> List[ShiftedInstance]: + """Returns a list of shifted instances and saves the shifted instances if necessary.""" + # Flow shift reference instances to current frame. + shifted_instances = self.flow_shift_instances( + ref_instances, + ref_img, + img, + min_shifted_points=self.min_points, + scale=self.img_scale, + window_size=self.of_window_size, + max_levels=self.of_max_levels, + ) + + # Save shifted instances. + if self.save_shifted_instances: + self.shifted_instances[(ref_t, t)] = MatchedShiftedFrameInstances( + ref_t, + t, + shifted_instances, + img, + ) + + return shifted_instances + def get_candidates( self, track_matching_queue: Deque[MatchedFrameInstances], @@ -159,39 +204,15 @@ def get_candidates( # Check if shifted instance was computed at earlier time if self.save_shifted_instances: - for ti in reversed(range(ref_t, t)): - if (ref_t, ti) in self.shifted_instances: - ref_shifted_instances = self.shifted_instances[(ref_t, ti)] - # Use shifted instance as a reference - if len(ref_shifted_instances.instances_t) > 0: - ref_img = ref_shifted_instances.img_t - ref_instances = ref_shifted_instances.instances_t - break + ref_img, ref_instances = self.get_shifted_instances_from_earlier_time( + ref_t, ref_img, ref_instances, t + ) if len(ref_instances) > 0: - # Flow shift reference instances to current frame. - shifted_instances = self.flow_shift_instances( - ref_instances, - ref_img, - img, - min_shifted_points=self.min_points, - scale=self.img_scale, - window_size=self.of_window_size, - max_levels=self.of_max_levels, + candidate_instances.extend( + self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t) ) - # Add to candidate pool. - candidate_instances.extend(shifted_instances) - - # Save shifted instances. - if self.save_shifted_instances: - self.shifted_instances[(ref_t, t)] = MatchedShiftedFrameInstances( - ref_t, - t, - shifted_instances, - img, - ) - return candidate_instances def prune_shifted_instances(self, t: int): @@ -320,22 +341,10 @@ def flow_shift_instances( @attr.s(auto_attribs=True) class FlowMaxTracksCandidateMaker(FlowCandidateMaker): - """Class for producing optical flow shift matching candidates with a maximum on the number of tracks. + """Class for producing optical flow shift matching candidates with cap on the number of tracks. Attributes: - min_points: Minimum number of points that must be detected in the new frame in - order to generate a new shifted instance. - img_scale: Factor to scale the images by when computing optical flow. Decrease - this to increase performance at the cost of finer accuracy. Sometimes - decreasing the image scale can improve performance with fast movements. - of_window_size: Optical flow window size to consider at each pyramid scale - level. - of_max_levels: Number of pyramid scale levels to consider. This is different - from the scale parameter, which determines the initial image scaling. - save_shifted_instances: If True, save the shifted instances between elapsed - frames. - track_window: How many frames back to look for candidate instances to match - instances in the current frame against. + max_tracks: The maximum number of tracks that needs to be maintained in order to avoid redundant/irrelevant tracks. """ @@ -345,7 +354,7 @@ class FlowMaxTracksCandidateMaker(FlowCandidateMaker): def get_ref_instances( ref_t: int, ref_img: np.ndarray, - track_matching_queue_dict: dict(), + track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], ) -> List[InstanceType]: """Generates a list of instances based on the reference time and reference image.""" instances = [] @@ -359,7 +368,7 @@ def get_ref_instances( def get_candidates( self, - track_matching_queue_dict: dict(), + track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], t: int, img: np.ndarray, *args, @@ -369,6 +378,7 @@ def get_candidates( # Prune old shifted instances to save time and memory self.prune_shifted_instances(t) + # Storing the tracks from the dictionary for counting purpose. tracks = [] for track, matched_items in track_matching_queue_dict.items(): @@ -385,42 +395,19 @@ def get_candidates( # Check if shifted instance was computed at earlier time if self.save_shifted_instances: - for ti in reversed(range(ref_t, t)): - if (ref_t, ti) in self.shifted_instances: - ref_shifted_instances = self.shifted_instances[ - (ref_t, ti) - ] - # Use shifted instance as a reference - if len(ref_shifted_instances.instances_t) > 0: - ref_img = ref_shifted_instances.img_t - ref_instances = ref_shifted_instances.instances_t - break - - if len(ref_instances) > 0: - # Flow shift reference instances to current frame. - shifted_instances = self.flow_shift_instances( - ref_instances, + ( ref_img, - img, - min_shifted_points=self.min_points, - scale=self.img_scale, - window_size=self.of_window_size, - max_levels=self.of_max_levels, + ref_instances, + ) = self.get_shifted_instances_from_earlier_time( + ref_t, ref_img, ref_instances, t ) - # Add to candidate pool. - candidate_instances.extend(shifted_instances) - - # Save shifted instances. - if self.save_shifted_instances: - self.shifted_instances[ - (ref_t, t) - ] = MatchedShiftedFrameInstances( - ref_t, - t, - shifted_instances, - img, + if len(ref_instances) > 0: + candidate_instances.extend( + self.get_shifted_instances( + ref_instances, ref_img, ref_t, img, t ) + ) return candidate_instances @@ -456,7 +443,7 @@ class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker): def get_candidates( self, - track_matching_queue_dict: dict(), + track_matching_queue_dict: Dict, *args, **kwargs, ) -> List[InstanceType]: diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 34e4bc67e..8f6809001 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -51,7 +51,12 @@ main as sleap_track, export_cli as sleap_export, ) -from sleap.nn.tracking import FlowCandidateMaker, FlowMaxTracksCandidateMaker, Tracker +from sleap.nn.tracking import ( + MatchedFrameInstance, + FlowCandidateMaker, + FlowMaxTracksCandidateMaker, + Tracker, +) from sleap.instance import Track From c927090ccaf9bebd158d0ec1f85a0ce03a4734bf Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Fri, 1 Sep 2023 15:55:34 -0700 Subject: [PATCH 15/26] Add test function to check max tracks --- tests/nn/test_inference.py | 52 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 8f6809001..fe848bb1c 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1492,6 +1492,58 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): assert abs(key[0] - key[1]) <= track_window # References within window +@pytest.mark.parametrize( + "max_tracks, trackername", + [ + (2, "flowmaxtracks"), + (2, "simplemaxtracks"), + ], +) +def test_max_tracks_matching_queue( + centered_pair_predictions: Labels, max_tracks, trackername +): + """Test flow max tracks instance generation.""" + labels: Labels = centered_pair_predictions + max_tracking = True + track_window = 5 + + # Setup flow max tracker + tracker: Tracker = Tracker.make_tracker_by_name( + tracker=trackername, + track_window=track_window, + save_shifted_instances=True, + max_tracking=max_tracking, + max_tracks=max_tracks, + ) + + tracker.candidate_maker = cast(FlowMaxTracksCandidateMaker, tracker.candidate_maker) + + # Run tracking + frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) + + for lf in frames[:20]: + + # Clear the tracks + for inst in lf.instances: + inst.track = None + + track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + tracker.track(**track_args) + + if trackername == "flowmaxtracks": + # Check that saved instances are pruned to track window + for key in tracker.candidate_maker.shifted_instances.keys(): + assert lf.frame_idx - key[0] <= track_window # Keys are pruned + assert abs(key[0] - key[1]) <= track_window + + # Check if the length of each of the tracks is not more than the track window + for track in tracker.track_matching_queue_dict.keys(): + assert len(tracker.track_matching_queue_dict[track]) <= track_window + + # Check if number of tracks that are generated are not more than the maximum tracks + assert len(tracker.track_matching_queue_dict) <= max_tracks + + def test_movenet_inference(movenet_video): inference_layer = MoveNetInferenceLayer(model_name="lightning") inference_model = MoveNetInferenceModel(inference_layer) From 2e4fecea9943e92feaed4f5e108ee5ca798e3d56 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam Date: Fri, 1 Sep 2023 17:29:08 -0700 Subject: [PATCH 16/26] Added suggestions and feedback --- sleap/instance.py | 3 --- sleap/nn/tracking.py | 52 +++++++++++++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index 385f5f3aa..c14038552 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -333,9 +333,6 @@ def matches(self, other: "Track"): """ return attr.asdict(self) == attr.asdict(other) - def __hash__(self) -> int: - return hash((self.spawned_on, self.name)) - # NOTE: # Instance cannot be a slotted class at the moment. This is because it creates diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 45a7f66f3..345e525a6 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -142,7 +142,14 @@ def uses_image(self): def get_shifted_instances_from_earlier_time( self, ref_t: int, ref_img: np.ndarray, ref_instances: List[InstanceType], t: int ) -> (np.ndarray, List[InstanceType]): - """Check if shifted instance was computed at earlier time and return instances and corresponding image.""" + """Generate shifted instances and corresponding image from earlier time. + + Args: + ref_instances: Reference instances in the previous frame. + ref_img: Previous frame image as a numpy array. + ref_t: Previous frame time instance. + t: Current time instance. + """ for ti in reversed(range(ref_t, t)): if (ref_t, ti) in self.shifted_instances: ref_shifted_instances = self.shifted_instances[(ref_t, ti)] @@ -161,7 +168,15 @@ def get_shifted_instances( img: np.ndarray, t: int, ) -> List[ShiftedInstance]: - """Returns a list of shifted instances and saves the shifted instances if necessary.""" + """Returns a list of shifted instances and save shifted instances if needed. + + Args: + ref_instances: Reference instances in the previous frame. + ref_img: Previous frame image as a numpy array. + ref_t: Previous frame time instance. + img: Current frame image as a numpy array. + t: Current time instance. + """ # Flow shift reference instances to current frame. shifted_instances = self.flow_shift_instances( ref_instances, @@ -341,10 +356,10 @@ def flow_shift_instances( @attr.s(auto_attribs=True) class FlowMaxTracksCandidateMaker(FlowCandidateMaker): - """Class for producing optical flow shift matching candidates with cap on the number of tracks. + """Class for producing optical flow shift matching candidates with maximum tracks. Attributes: - max_tracks: The maximum number of tracks that needs to be maintained in order to avoid redundant/irrelevant tracks. + max_tracks: The maximum number of tracks to avoid redundant tracks. """ @@ -356,7 +371,14 @@ def get_ref_instances( ref_img: np.ndarray, track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], ) -> List[InstanceType]: - """Generates a list of instances based on the reference time and reference image.""" + """Generates a list of instances based on the reference time and image. + + Args: + ref_t: Previous frame time instance. + ref_img: Previous frame image as a numpy array. + track_matching_queue_dict: A dictionary of mapping between the tracks + and the corresponding instances associated with the track. + """ instances = [] for track, matched_items in track_matching_queue_dict.items(): instances += [ @@ -437,7 +459,7 @@ def get_candidates( @attr.s(auto_attribs=True) class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker): - """Class to generate instances based on the maximum number of tracks from prior frames.""" + """Class to generate instances with maximum number of tracks from prior frames.""" max_tracks: int = None @@ -447,7 +469,7 @@ def get_candidates( *args, **kwargs, ) -> List[InstanceType]: - # Create set of matchable candidate instances from each track for max number of tracks. + # Create set of matchable candidate instances from each track. candidate_instances = [] tracks = [] for track, matched_instances in track_matching_queue_dict.items(): @@ -534,6 +556,7 @@ class Tracker(BaseTracker): use a robust quantile similarity score for the track. If the value is 1, use the max similarity (non-robust). For selecting a robust score, 0.95 is a good value. + max_tracking: Max tracking is incorporated when this is set to true. """ max_tracks: int = None @@ -576,7 +599,7 @@ def _init_matching_queue(self): return deque(maxlen=self.track_window) def reset_candidates(self): - if self.max_tracking and self.track_matching_queue_dict: + if self.max_tracking: for track in self.track_matching_queue_dict: self.track_matching_queue_dict[track] = deque(maxlen=self.track_window) else: @@ -626,7 +649,8 @@ def track( if self.max_tracking: if len(self.track_matching_queue_dict) > 0: - # Default to last timestep + 1 if available. Here we find the track that has the most instances. + # Default to last timestep + 1 if available. + # Here we find the track that has the most instances. track_with_max_instances = max( self.track_matching_queue_dict, key=lambda track: len(self.track_matching_queue_dict[track]), @@ -805,7 +829,7 @@ def make_tracker_by_name( kf_init_frame_count: int = 0, kf_node_indices: Optional[list] = None, # Max tracking options - max_tracks: int = None, + max_tracks: Optional[int] = None, max_tracking: bool = False, **kwargs, ) -> BaseTracker: @@ -900,9 +924,7 @@ def get_by_name_factory_options(cls): option = dict(name="max_tracking", default=False) option["type"] = bool - option[ - "help" - ] = "If true then the tracker will cap the max number of tracks created or tracked." + option["help"] = "If true then the tracker will cap the max number of tracks." options.append(option) option = dict(name="max_tracks", default=None) @@ -1060,7 +1082,7 @@ class FlowTracker(Tracker): class FlowMaxTracker(Tracker): - """A Tracker pre-configured to use optical flow shifted candidates with a maximum limit on tracks.""" + """Pre-configured tracker to use optical flow shifted candidates with max tracks.""" max_tracks: int = attr.ib(kw_only=True) similarity_function: Callable = instance_similarity @@ -1080,7 +1102,7 @@ class SimpleTracker(Tracker): @attr.s(auto_attribs=True) class SimpleMaxTracker(Tracker): - """A tracked pre-configured to use simple, non-image-based candidates with a maximum number of tracks.""" + """Pre-configured tracker to use simple, non-image-based candidates with max tracks.""" max_tracks: int = attr.ib(kw_only=True) similarity_function: Callable = instance_iou From bff3333912ba88e2933bd4b2233e1bfff1c3365b Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Wed, 6 Sep 2023 14:51:40 -0700 Subject: [PATCH 17/26] Prevent the creation of more than max tracks when we have unmatched detections --- sleap/nn/tracking.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 345e525a6..9865b7db5 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -771,6 +771,13 @@ def spawn_for_untracked_instances( if inst.n_visible_points < self.min_new_track_points: continue + # Skip if we've reached the maximum number of tracks. + if ( + self.max_tracking + and len(self.track_matching_queue_dict) >= self.max_tracks + ): + break + # Spawn new track. new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}") self.spawned_tracks.append(new_track) From 0492ff729e4b1e3190b52fa6a3bbba99cb78d24c Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Wed, 6 Sep 2023 14:51:47 -0700 Subject: [PATCH 18/26] Add tests --- tests/nn/test_tracker_components.py | 199 ++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 7b4dafd63..040a7f3b8 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -168,3 +168,202 @@ def test_frame_match_object(): assert matches[1].track == "track b" assert matches[1].instance == "instance b" + + +def test_max_tracking(): + skel = Skeleton.from_names_and_edge_inds( + ["A", "B", "C"], edge_inds=[[0, 1], [1, 2]] + ) + + def make_inst(x, y): + pts = np.array([[-0.1, -0.1], [0.0, 0.0], [0.1, 0.1]]) + np.array([[x, y]]) + return PredictedInstance.from_numpy(pts, [1, 1, 1], 1, skel) + + # Track 2 instances with gap > window size + preds = [ + [ + make_inst(0, 0), + make_inst(0, 1), + ], + [ + make_inst(0.1, 0), + make_inst(0.1, 1), + ], + [ + make_inst(0.2, 0), + make_inst(0.2, 1), + ], + [ + make_inst(0.3, 0), + ], + [ + make_inst(0.4, 0), + ], + [ + make_inst(0.5, 0), + make_inst(0.5, 1), + ], + [ + make_inst(0.6, 0), + make_inst(0.6, 1), + ], + ] + + tracker = Tracker.make_tracker_by_name( + tracker="simple", + # tracker="simplemaxtracks", + match="hungarian", + track_window=2, + # max_tracks=2, + # max_tracking=True, + ) + + tracked = [] + for insts in preds: + tracked_insts = tracker.track(insts) + tracked.append(tracked_insts) + all_tracks = list(set([inst.track for frame in tracked for inst in frame])) + + assert len(all_tracks) == 3 + + tracker = Tracker.make_tracker_by_name( + # tracker="simple", + tracker="simplemaxtracks", + match="hungarian", + track_window=2, + max_tracks=2, + max_tracking=True, + ) + + tracked = [] + for insts in preds: + tracked_insts = tracker.track(insts) + tracked.append(tracked_insts) + all_tracks = list(set([inst.track for frame in tracked for inst in frame])) + + assert len(all_tracks) == 2 + + # Test 2 instances with both tracks with gap > window size + preds = [ + [ + make_inst(0, 0), + make_inst(0, 1), + ], + [ + make_inst(0.1, 0), + make_inst(0.1, 1), + ], + [ + make_inst(0.2, 0), + make_inst(0.2, 1), + ], + [], + [], + [ + make_inst(0.5, 0), + make_inst(0.5, 1), + ], + [ + make_inst(0.6, 0), + make_inst(0.6, 1), + ], + ] + + tracker = Tracker.make_tracker_by_name( + tracker="simple", + # tracker="simplemaxtracks", + match="hungarian", + track_window=2, + # max_tracks=2, + # max_tracking=True, + ) + + tracked = [] + for insts in preds: + tracked_insts = tracker.track(insts) + tracked.append(tracked_insts) + all_tracks = list(set([inst.track for frame in tracked for inst in frame])) + + assert len(all_tracks) == 4 + + tracker = Tracker.make_tracker_by_name( + # tracker="simple", + tracker="simplemaxtracks", + match="hungarian", + track_window=2, + max_tracks=2, + max_tracking=True, + ) + + tracked = [] + for insts in preds: + tracked_insts = tracker.track(insts) + tracked.append(tracked_insts) + all_tracks = list(set([inst.track for frame in tracked for inst in frame])) + + assert len(all_tracks) == 2 + + # Test having more than 2 detected instances in a frame + preds = [ + [ + make_inst(0, 0), + make_inst(0, 1), + ], + [ + make_inst(0.1, 0), + make_inst(0.1, 1), + ], + [ + make_inst(0.2, 0), + make_inst(0.2, 1), + ], + [ + make_inst(0.3, 0), + ], + [ + make_inst(0.4, 0), + ], + [ + make_inst(0.5, 0), + make_inst(0.5, 1), + ], + [ + make_inst(0.6, 0), + make_inst(0.6, 1), + make_inst(0.6, 0.5), + ], + ] + + tracker = Tracker.make_tracker_by_name( + tracker="simple", + # tracker="simplemaxtracks", + match="hungarian", + track_window=2, + # max_tracks=2, + # max_tracking=True, + ) + + tracked = [] + for insts in preds: + tracked_insts = tracker.track(insts) + tracked.append(tracked_insts) + all_tracks = list(set([inst.track for frame in tracked for inst in frame])) + + assert len(all_tracks) == 4 + + tracker = Tracker.make_tracker_by_name( + # tracker="simple", + tracker="simplemaxtracks", + match="hungarian", + track_window=2, + max_tracks=2, + max_tracking=True, + ) + + tracked = [] + for insts in preds: + tracked_insts = tracker.track(insts) + tracked.append(tracked_insts) + all_tracks = list(set([inst.track for frame in tracked for inst in frame])) + + assert len(all_tracks) == 2 From 89b04af76850616fb3c9c104717de0485fc27ecb Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Wed, 6 Sep 2023 14:59:39 -0700 Subject: [PATCH 19/26] Use maximum tracking by default when loading model via high level API --- sleap/nn/inference.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 222a80bda..6ca1a20f6 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -68,7 +68,7 @@ ) from sleap.nn.utils import reset_input_layer from sleap.io.dataset import Labels -from sleap.util import frame_list +from sleap.util import frame_list, make_scoped_dictionary from sleap.instance import PredictedInstance, LabeledFrame from tensorflow.python.framework.convert_to_constants import ( @@ -4773,8 +4773,7 @@ def load_model( be performed. tracker_window: Number of frames of history to use when tracking. No effect when `tracker` is `None`. - tracker_max_instances: If not `None`, discard instances beyond this count when - tracking. No effect when `tracker` is `None`. + tracker_max_instances: If not `None`, create at most this many tracks. disable_gpu_preallocation: If `True` (the default), initialize the GPU and disable preallocation of memory. This is necessary to prevent freezing on some systems with low GPU memory and has negligible impact on performance. @@ -4863,11 +4862,18 @@ def unpack_sleap_model(model_path): ) predictor.verbosity = progress_reporting if tracker is not None: + use_max_tracker = tracker_max_instances is not None + if use_max_tracker and not tracker.endswith("maxtracks"): + # Append maxtracks to the tracker name to use the right tracker variants. + tracker += "maxtracks" + predictor.tracker = Tracker.make_tracker_by_name( tracker=tracker, track_window=tracker_window, post_connect_single_breaks=True, - clean_instance_count=tracker_max_instances, + max_tracking=use_max_tracker, + max_tracks=tracker_max_instances, + # clean_instance_count=tracker_max_instances, ) # Remove temp dirs. @@ -5335,7 +5341,7 @@ def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]: Returns: An instance of `Tracker` or `None` if tracking method was not specified. """ - policy_args = sleap.util.make_scoped_dictionary(vars(args), exclude_nones=True) + policy_args = make_scoped_dictionary(vars(args), exclude_nones=True) if "tracking" in policy_args: tracker = Tracker.make_tracker_by_name(**policy_args["tracking"]) return tracker From e62db0f19019adad121aad0af4d6f46acc48d715 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 8 Sep 2023 10:42:58 -0700 Subject: [PATCH 20/26] Lint --- sleap/nn/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 6ca1a20f6..6d7d24f8c 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4866,7 +4866,7 @@ def unpack_sleap_model(model_path): if use_max_tracker and not tracker.endswith("maxtracks"): # Append maxtracks to the tracker name to use the right tracker variants. tracker += "maxtracks" - + predictor.tracker = Tracker.make_tracker_by_name( tracker=tracker, track_window=tracker_window, From 3a72e1a4e13012a1386450aacf1b899d3a9f1d39 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 8 Sep 2023 10:44:36 -0700 Subject: [PATCH 21/26] Fix integration test --- tests/nn/test_tracking_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 310001cb9..4923a048b 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -116,7 +116,7 @@ def make_tracker( tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0 ): if tracker_name == "simplemaxtracks" or tracker_name == "flowmaxtracks": - tracker[tracker_name]( + tracker = tracker[tracker_name]( matching_function=matchers[matcher_name], similarity_function=similarities[sim_name], max_tracks=max_tracks, From 74ea9807c335c67472febce1cec651754ba96027 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 8 Sep 2023 10:45:43 -0700 Subject: [PATCH 22/26] Refactor max tracker tests --- tests/nn/test_tracker_components.py | 178 ++++++++++++++++------------ 1 file changed, 99 insertions(+), 79 deletions(-) diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 040a7f3b8..f861241ee 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -170,7 +170,7 @@ def test_frame_match_object(): assert matches[1].instance == "instance b" -def test_max_tracking(): +def make_insts(trx): skel = Skeleton.from_names_and_edge_inds( ["A", "B", "C"], edge_inds=[[0, 1], [1, 2]] ) @@ -179,35 +179,47 @@ def make_inst(x, y): pts = np.array([[-0.1, -0.1], [0.0, 0.0], [0.1, 0.1]]) + np.array([[x, y]]) return PredictedInstance.from_numpy(pts, [1, 1, 1], 1, skel) + insts = [] + for frame in trx: + insts_frame = [] + for x, y in frame: + insts_frame.append(make_inst(x, y)) + insts.append(insts_frame) + return insts + + +def test_max_tracking_large_gap_single_track(): # Track 2 instances with gap > window size - preds = [ - [ - make_inst(0, 0), - make_inst(0, 1), - ], - [ - make_inst(0.1, 0), - make_inst(0.1, 1), - ], - [ - make_inst(0.2, 0), - make_inst(0.2, 1), - ], - [ - make_inst(0.3, 0), - ], + preds = make_insts( [ - make_inst(0.4, 0), - ], - [ - make_inst(0.5, 0), - make_inst(0.5, 1), - ], - [ - make_inst(0.6, 0), - make_inst(0.6, 1), - ], - ] + [ + (0, 0), + (0, 1), + ], + [ + (0.1, 0), + (0.1, 1), + ], + [ + (0.2, 0), + (0.2, 1), + ], + [ + (0.3, 0), + ], + [ + (0.4, 0), + ], + [ + (0.5, 0), + (0.5, 1), + ], + [ + (0.6, 0), + (0.6, 1), + ], + ] + ) tracker = Tracker.make_tracker_by_name( tracker="simple", @@ -243,31 +255,35 @@ def make_inst(x, y): assert len(all_tracks) == 2 + +def test_max_tracking_small_gap_on_both_tracks(): # Test 2 instances with both tracks with gap > window size - preds = [ - [ - make_inst(0, 0), - make_inst(0, 1), - ], - [ - make_inst(0.1, 0), - make_inst(0.1, 1), - ], + preds = make_insts( [ - make_inst(0.2, 0), - make_inst(0.2, 1), - ], - [], - [], - [ - make_inst(0.5, 0), - make_inst(0.5, 1), - ], - [ - make_inst(0.6, 0), - make_inst(0.6, 1), - ], - ] + [ + (0, 0), + (0, 1), + ], + [ + (0.1, 0), + (0.1, 1), + ], + [ + (0.2, 0), + (0.2, 1), + ], + [], + [], + [ + (0.5, 0), + (0.5, 1), + ], + [ + (0.6, 0), + (0.6, 1), + ], + ] + ) tracker = Tracker.make_tracker_by_name( tracker="simple", @@ -303,36 +319,40 @@ def make_inst(x, y): assert len(all_tracks) == 2 + +def test_max_tracking_extra_detections(): # Test having more than 2 detected instances in a frame - preds = [ - [ - make_inst(0, 0), - make_inst(0, 1), - ], - [ - make_inst(0.1, 0), - make_inst(0.1, 1), - ], + preds = make_insts( [ - make_inst(0.2, 0), - make_inst(0.2, 1), - ], - [ - make_inst(0.3, 0), - ], - [ - make_inst(0.4, 0), - ], - [ - make_inst(0.5, 0), - make_inst(0.5, 1), - ], - [ - make_inst(0.6, 0), - make_inst(0.6, 1), - make_inst(0.6, 0.5), - ], - ] + [ + (0, 0), + (0, 1), + ], + [ + (0.1, 0), + (0.1, 1), + ], + [ + (0.2, 0), + (0.2, 1), + ], + [ + (0.3, 0), + ], + [ + (0.4, 0), + ], + [ + (0.5, 0), + (0.5, 1), + ], + [ + (0.6, 0), + (0.6, 1), + (0.6, 0.5), + ], + ] + ) tracker = Tracker.make_tracker_by_name( tracker="simple", From f8254a418bd2e7e3e77e4f159c0c0b68005cca4b Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 8 Sep 2023 13:14:00 -0700 Subject: [PATCH 23/26] Add integration test for CLI --- tests/nn/test_tracking_integration.py | 32 +++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 4923a048b..6470a0cdb 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -3,10 +3,42 @@ import os import time +import sleap +from sleap.nn.inference import main as inference_cli import sleap.nn.tracker.components from sleap.io.dataset import Labels, LabeledFrame +def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): + cli = ( + "--tracking.tracker simple " + "--frames 200-300 " + f"-o {tmpdir}/simpletracks.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file(f"{tmpdir}/simpletracks.slp") + assert len(labels.tracks) == 27 + + +def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path): + cli = ( + "--tracking.tracker simplemaxtracks " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + "--frames 200-300 " + f"-o {tmpdir}/simplemaxtracks.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file(f"{tmpdir}/simplemaxtracks.slp") + assert len(labels.tracks) == 2 + + +# TODO: Refactor the below things into a real test suite. + + def make_ground_truth(frames, tracker, gt_filename): t0 = time.time() new_labels = run_tracker(frames, tracker) From e1e7285293ad8fda071a0d6537c7abeb6cf7c9ac Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 8 Sep 2023 13:14:31 -0700 Subject: [PATCH 24/26] typo --- tests/nn/test_tracking_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 6470a0cdb..a6592dc4d 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -148,7 +148,7 @@ def make_tracker( tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0 ): if tracker_name == "simplemaxtracks" or tracker_name == "flowmaxtracks": - tracker = tracker[tracker_name]( + tracker = trackers[tracker_name]( matching_function=matchers[matcher_name], similarity_function=similarities[sim_name], max_tracks=max_tracks, From 91ad0b64495406f1053ddb82fa69db43a8695b31 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 8 Sep 2023 13:43:03 -0700 Subject: [PATCH 25/26] Add max tracks to the tracking GUI --- sleap/config/pipeline_form.yaml | 96 ++++++++++++++++++++------------- sleap/gui/learning/runners.py | 12 +++++ 2 files changed, 71 insertions(+), 37 deletions(-) diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index 77722f0d4..cbcea2be5 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -376,28 +376,39 @@ inference: none: flow: - - type: text - text: 'Pre-tracker data cleaning:' - - name: tracking.target_instance_count - label: Target Number of Instances Per Frame - type: optional_int - none_label: No target - default_disabled: true - range: 1,100 - default: 1 - - name: tracking.pre_cull_to_target - label: Cull to Target Instance Count - type: bool - default: false - - name: tracking.pre_cull_iou_threshold - label: Cull using IoU Threshold - type: double - default: 0.8 + # - type: text + # text: 'Pre-tracker data cleaning:' + # - name: tracking.target_instance_count + # label: Target Number of Instances Per Frame + # type: optional_int + # none_label: No target + # default_disabled: true + # range: 1,100 + # default: 1 + # - name: tracking.pre_cull_to_target + # label: Cull to Target Instance Count + # type: bool + # default: false + # - name: tracking.pre_cull_iou_threshold + # label: Cull using IoU Threshold + # type: double + # default: 0.8 - type: text text: 'Tracking with optical flow:
This tracker "shifts" instances from previous frames using optical flow before matching instances in each frame to the shifted instances from prior frames.' + # - name: tracking.max_tracking + # label: Limit max number of tracks + # type: bool + default: false + - name: tracking.max_tracks + label: Max number of tracks + type: optional_int + none_label: No limit + default_disabled: true + range: 1,100 + default: 1 - name: tracking.similarity label: Similarity Method type: list @@ -422,10 +433,10 @@ inference: none_label: Use max (non-robust) range: 0,1 default: 0.95 - - name: tracking.save_shifted_instances - label: Save shifted instances - type: bool - default: false + # - name: tracking.save_shifted_instances + # label: Save shifted instances + # type: bool + # default: false - type: text text: 'Kalman filter-based tracking:
Uses the above tracking options to track instances for an initial @@ -449,27 +460,38 @@ inference: default: false simple: + # - type: text + # text: 'Pre-tracker data cleaning:' + # - name: tracking.target_instance_count + # label: Target Number of Instances Per Frame + # type: optional_int + # none_label: No target + # default_disabled: true + # range: 1,100 + # default: 1 + # - name: tracking.pre_cull_to_target + # label: Cull to Target Instance Count + # type: bool + # default: false + # - name: tracking.pre_cull_iou_threshold + # label: Cull using IoU Threshold + # type: double + # default: 0.8 - type: text - text: 'Pre-tracker data cleaning:' - - name: tracking.target_instance_count - label: Target Number of Instances Per Frame + text: 'Tracking:
+ This tracker assigns track identities by matching instances from prior + frames to instances on subsequent frames.' + # - name: tracking.max_tracking + # label: Limit max number of tracks + # type: bool + # default: false + - name: tracking.max_tracks + label: Max number of tracks type: optional_int - none_label: No target + none_label: No limit default_disabled: true range: 1,100 default: 1 - - name: tracking.pre_cull_to_target - label: Cull to Target Instance Count - type: bool - default: false - - name: tracking.pre_cull_iou_threshold - label: Cull using IoU Threshold - type: double - default: 0.8 - - type: text - text: 'Tracking:
- This tracker assigns track identities by matching instances from prior - frames to instances on subsequent frames.' - name: tracking.similarity label: Similarity Method type: list diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index 3909f1019..ca60c4127 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -224,6 +224,7 @@ def make_predict_cli_call( optional_items_as_nones = ( "tracking.target_instance_count", + "tracking.max_tracks", "tracking.kf_init_frame_count", "tracking.robust", "max_instances", @@ -233,6 +234,16 @@ def make_predict_cli_call( if key in self.inference_params and self.inference_params[key] is None: del self.inference_params[key] + # Setting max_tracks to True means we want to use the max_tracking mode. + if "tracking.max_tracks" in self.inference_params: + self.inference_params["tracking.max_tracking"] = True + + # Hacky: Update the tracker name to include "maxtracks" suffix. + if self.inference_params["tracking.tracker"] in ("simple", "flow"): + self.inference_params["tracking.tracker"] = ( + self.inference_params["tracking.tracker"] + "maxtracks" + ) + # --tracking.kf_init_frame_count enables the kalman filter tracking # so if not set, then remove other (unused) args if "tracking.kf_init_frame_count" not in self.inference_params: @@ -241,6 +252,7 @@ def make_predict_cli_call( bool_items_as_ints = ( "tracking.pre_cull_to_target", + "tracking.max_tracking", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", ) From 0b2fdd9282dee40cd275bb040f91e575754bb048 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 8 Sep 2023 14:02:37 -0700 Subject: [PATCH 26/26] Update CLI docs and add examples --- docs/guides/cli.md | 220 +++++++++++++++++++++++---------------------- 1 file changed, 114 insertions(+), 106 deletions(-) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 0c08e9b17..35ea52171 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -118,158 +118,166 @@ optional arguments: If you specify how many identities there should be in a frame (i.e., the number of animals) with the {code}`--tracking.clean_instance_count` argument, then we will use a heuristic method to connect "breaks" in the track identities where we lose one identity and spawn another. This can be used as part of the inference pipeline (if models are specified), as part of the tracking-only pipeline (if the predictions file is specified and no models are specified), or by itself on predictions with pre-tracked identities (if you specify {code}`--tracking.tracker none`). See {ref}`proofreading` for more details on tracking. ```none -usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] - [--only-suggested-frames] [-o OUTPUT] [--no-empty-frames] - [--verbosity {none,rich,json}] - [--video.dataset VIDEO.DATASET] - [--video.input_format VIDEO.INPUT_FORMAT] - [--video.index VIDEO.INDEX] - [--cpu | --first-gpu | --last-gpu | --gpu GPU] - [--peak_threshold PEAK_THRESHOLD] [--batch_size BATCH_SIZE] - [--open-in-gui] [--tracking.tracker TRACKING.TRACKER] - [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT] - [--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] - [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD] +usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [--only-suggested-frames] [-o OUTPUT] [--no-empty-frames] + [--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT] + [--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO] + [--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD] + [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING] + [--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT] + [--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD] [--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS] - [--tracking.clean_instance_count TRACKING.CLEAN_INSTANCE_COUNT] - [--tracking.clean_iou_threshold TRACKING.CLEAN_IOU_THRESHOLD] - [--tracking.similarity TRACKING.SIMILARITY] - [--tracking.match TRACKING.MATCH] - [--tracking.track_window TRACKING.TRACK_WINDOW] - [--tracking.save_shifted_instances TRACKING.SAVE_SHIFTED_INSTANCES] - [--tracking.min_new_track_points TRACKING.MIN_NEW_TRACK_POINTS] - [--tracking.min_match_points TRACKING.MIN_MATCH_POINTS] - [--tracking.img_scale TRACKING.IMG_SCALE] - [--tracking.of_window_size TRACKING.OF_WINDOW_SIZE] - [--tracking.of_max_levels TRACKING.OF_MAX_LEVELS] - [--tracking.kf_node_indices TRACKING.KF_NODE_INDICES] + [--tracking.clean_instance_count TRACKING.CLEAN_INSTANCE_COUNT] [--tracking.clean_iou_threshold TRACKING.CLEAN_IOU_THRESHOLD] + [--tracking.similarity TRACKING.SIMILARITY] [--tracking.match TRACKING.MATCH] [--tracking.robust TRACKING.ROBUST] + [--tracking.track_window TRACKING.TRACK_WINDOW] [--tracking.min_new_track_points TRACKING.MIN_NEW_TRACK_POINTS] + [--tracking.min_match_points TRACKING.MIN_MATCH_POINTS] [--tracking.img_scale TRACKING.IMG_SCALE] + [--tracking.of_window_size TRACKING.OF_WINDOW_SIZE] [--tracking.of_max_levels TRACKING.OF_MAX_LEVELS] + [--tracking.save_shifted_instances TRACKING.SAVE_SHIFTED_INSTANCES] [--tracking.kf_node_indices TRACKING.KF_NODE_INDICES] [--tracking.kf_init_frame_count TRACKING.KF_INIT_FRAME_COUNT] [data_path] positional arguments: - data_path Path to data to predict on. This can be a labels - (.slp) file or any supported video format. + data_path Path to data to predict on. This can be a labels (.slp) file or any supported video format. optional arguments: -h, --help show this help message and exit -m MODELS, --model MODELS - Path to trained model directory (with - training_config.json). Multiple models can be - specified, each preceded by --model. - --frames FRAMES List of frames to predict when running on a video. Can - be specified as a comma separated list (e.g. 1,2,3) or - a range separated by hyphen (e.g., 1-3, for 1,2,3). If - not provided, defaults to predicting on the entire - video. + Path to trained model directory (with training_config.json). Multiple models can be specified, each preceded by --model. + --frames FRAMES List of frames to predict when running on a video. Can be specified as a comma separated list (e.g. 1,2,3) or a range + separated by hyphen (e.g., 1-3, for 1,2,3). If not provided, defaults to predicting on the entire video. --only-labeled-frames - Only run inference on user labeled frames when running - on labels dataset. This is useful for generating - predictions to compare against ground truth. + Only run inference on user labeled frames when running on labels dataset. This is useful for generating predictions to compare + against ground truth. --only-suggested-frames - Only run inference on unlabeled suggested frames when - running on labels dataset. This is useful for - generating predictions for initialization during - labeling. + Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for + initialization during labeling. -o OUTPUT, --output OUTPUT - The output filename to use for the predicted data. If - not provided, defaults to - '[data_path].predictions.slp' if generating predictions or - '[data_path].[tracker].[similarity method].[matching method].slp' - if retracking predictions. - --no-empty-frames Clear any empty frames that did not have any detected - instances before saving to output. - -n, --max_instances MAX_INSTANCES - Limit maximum number of instances in multi-instance models. - Not available for ID models. Defaults to None. + The output filename to use for the predicted data. If not provided, defaults to '[data_path].predictions.slp'. + --no-empty-frames Clear any empty frames that did not have any detected instances before saving to output. --verbosity {none,rich,json} - Verbosity of inference progress reporting. 'none' does - not output anything during inference, 'rich' displays - an updating progress bar, and 'json' outputs the - progress as a JSON encoded response to the console. + Verbosity of inference progress reporting. 'none' does not output anything during inference, 'rich' displays an updating + progress bar, and 'json' outputs the progress as a JSON encoded response to the console. --video.dataset VIDEO.DATASET The dataset for HDF5 videos. --video.input_format VIDEO.INPUT_FORMAT The input_format for HDF5 videos. --video.index VIDEO.INDEX - The index of the video to run inference on. Only used if - data_path points to a labels file. - --cpu Run inference only on CPU. If not specified, will use - available GPU. + Integer index of video in .slp file to predict on. To be used with an .slp path as an alternative to specifying the video + path. + --cpu Run inference only on CPU. If not specified, will use available GPU. --first-gpu Run inference on the first GPU, if available. --last-gpu Run inference on the last GPU, if available. - --gpu GPU Run training on the i-th GPU on the system. If 'auto', run on - the GPU with the highest percentage of available memory. + --gpu GPU Run training on the i-th GPU on the system. If 'auto', run on the GPU with the highest percentage of available memory. --max_edge_length_ratio MAX_EDGE_LENGTH_RATIO - The maximum expected length of a connected pair of points as a - fraction of the image size. Candidate connections longer than - this length will be penalized during matching. Only applies to - bottom-up (PAF) models. + The maximum expected length of a connected pair of points as a fraction of the image size. Candidate connections longer than + this length will be penalized during matching. Only applies to bottom-up (PAF) models. --dist_penalty_weight DIST_PENALTY_WEIGHT - A coefficient to scale weight of the distance penalty. Set to - values greater than 1.0 to enforce the distance penalty more + A coefficient to scale weight of the distance penalty. Set to values greater than 1.0 to enforce the distance penalty more strictly. Only applies to bottom-up (PAF) models. - --peak_threshold PEAK_THRESHOLD - Minimum confidence map value to consider a peak as - valid. --batch_size BATCH_SIZE - Number of frames to predict at a time. Larger values - result in faster inference speeds, but require more - memory. - --open-in-gui Open the resulting predictions in the GUI when - finished. + Number of frames to predict at a time. Larger values result in faster inference speeds, but require more memory. + --open-in-gui Open the resulting predictions in the GUI when finished. + --peak_threshold PEAK_THRESHOLD + Minimum confidence map value to consider a peak as valid. + -n MAX_INSTANCES, --max_instances MAX_INSTANCES + Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None. --tracking.tracker TRACKING.TRACKER - Options: simple, flow, None (default: None) + Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None) + --tracking.max_tracking TRACKING.MAX_TRACKING + If true then the tracker will cap the max number of tracks. (default: False) + --tracking.max_tracks TRACKING.MAX_TRACKS + Maximum number of tracks to be tracked by the tracker. (default: None) --tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT - Target number of instances to track per frame. - (default: 0) + Target number of instances to track per frame. (default: 0) --tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET - If non-zero and target_instance_count is also non- - zero, then cull instances over target count per frame - *before* tracking. (default: 0) + If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking. + (default: 0) --tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD - If non-zero and pre_cull_to_target also set, then use - IOU threshold to remove overlapping instances over - count *before* tracking. (default: 0) + If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* + tracking. (default: 0) --tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS - If non-zero and target_instance_count is also non- - zero, then connect track breaks when exactly one track - is lost and exactly one track is spawned in frame. - (default: 0) + If non-zero and target_instance_count is also non-zero, then connect track breaks when exactly one track is lost and exactly + one track is spawned in frame. (default: 0) --tracking.clean_instance_count TRACKING.CLEAN_INSTANCE_COUNT - Target number of instances to clean *after* tracking. - (default: 0) + Target number of instances to clean *after* tracking. (default: 0) --tracking.clean_iou_threshold TRACKING.CLEAN_IOU_THRESHOLD - IOU to use when culling instances *after* tracking. - (default: 0) + IOU to use when culling instances *after* tracking. (default: 0) --tracking.similarity TRACKING.SIMILARITY Options: instance, centroid, iou (default: instance) --tracking.match TRACKING.MATCH Options: hungarian, greedy (default: greedy) + --tracking.robust TRACKING.ROBUST + Robust quantile of similarity score for instance matching. If equal to 1, keep the max similarity score (non-robust). + (default: 1) --tracking.track_window TRACKING.TRACK_WINDOW How many frames back to look for matches (default: 5) - --tracking.save_shifted_instances TRACKING.SAVE_SHIFTED_INSTANCES - For optical-flow: Save the shifted instances between - elapsed frames for optimal comparison (default: 0) --tracking.min_new_track_points TRACKING.MIN_NEW_TRACK_POINTS - Minimum number of instance points for spawning new - track (default: 0) + Minimum number of instance points for spawning new track (default: 0) --tracking.min_match_points TRACKING.MIN_MATCH_POINTS Minimum points for match candidates (default: 0) --tracking.img_scale TRACKING.IMG_SCALE For optical-flow: Image scale (default: 1.0) --tracking.of_window_size TRACKING.OF_WINDOW_SIZE - For optical-flow: Optical flow window size to consider - at each pyramid (default: 21) + For optical-flow: Optical flow window size to consider at each pyramid (default: 21) --tracking.of_max_levels TRACKING.OF_MAX_LEVELS - For optical-flow: Number of pyramid scale levels to - consider (default: 3) + For optical-flow: Number of pyramid scale levels to consider (default: 3) + --tracking.save_shifted_instances TRACKING.SAVE_SHIFTED_INSTANCES + If non-zero and tracking.tracker is set to flow, save the shifted instances between elapsed frames (default: 0) --tracking.kf_node_indices TRACKING.KF_NODE_INDICES - For Kalman filter: Indices of nodes to track. - (default: ) + For Kalman filter: Indices of nodes to track. (default: ) --tracking.kf_init_frame_count TRACKING.KF_INIT_FRAME_COUNT - For Kalman filter: Number of frames to track with - other tracker. 0 means no Kalman filters will be used. - (default: 0) + For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0) +``` + +#### Examples: + +**1. Simple inference without tracking:** + +```none +sleap-track -m "models/my_model" -o "output_predictions.slp" "input_video.mp4" +``` + +**2. Inference with multi-model pipelines (e.g., top-down):** + +```none +sleap-track -m "models/centroid" -m "models/centered_instance" -o "output_predictions.slp" "input_video.mp4" +``` + +**3. Inference on suggested frames of a labeling project:** + +```none +sleap-track -m "models/my_model" --only-suggested-frames -o "labels_with_predictions.slp" "labels.v005.slp" +``` + +The resulting `labels_with_predictions.slp` can then merged into the base labels project from the SLEAP GUI via **File** --> **Merge into project...**. + +**4. Inference with simple tracking:** + +```none +sleap-track -m "models/my_model" --tracking.tracker simple -o "output_predictions.slp" "input_video.mp4" +``` + +**5. Inference with max tracks limit:** + +```none +sleap-track -m "models/my_model" --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4" +``` + +**6. Re-tracking without pose inference:** + +```none +sleap-track --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp" +``` + +**7. Select GPU for pose inference:** + +```none +sleap-track --gpu 1 ... +``` + +**8. Select subset of frames to predict on:** + +```none +sleap-track -m "models/my_model" --frames 1000-2000 "input_video.mp4" ``` ## Dataset files