From 84290cb0508041c8e44659333e07d468fe742bf9 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:14:08 -0700 Subject: [PATCH 1/6] Add imageio dependencies for pypi wheel (#1950) Add imagio dependencies for pypi wheel Co-authored-by: roomrys --- pypi_requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pypi_requirements.txt b/pypi_requirements.txt index 62c0c0ddc..ad34e2ad8 100644 --- a/pypi_requirements.txt +++ b/pypi_requirements.txt @@ -6,6 +6,8 @@ # These are also distributed through conda and not pip installed when using conda. attrs>=21.2.0,<=21.4.0 cattrs==1.1.1 +imageio +imageio-ffmpeg # certifi>=2017.4.17,<=2021.10.8 jsmin jsonpickle==1.2 From c090df3a33dc6426d5196055bf3c9d6180ba72e1 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:14:50 -0700 Subject: [PATCH 2/6] Do not always color skeletons table black (#1952) Co-authored-by: roomrys --- sleap/gui/dataviews.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 0a008bea7..f68dc0180 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -413,13 +413,6 @@ def set_item(self, item, key, value): elif key == "symmetry": self.context.setNodeSymmetry(skeleton=self.obj, node=item, symmetry=value) - def get_item_color(self, item: Any, key: str): - if self.skeleton: - color = self.context.app.color_manager.get_item_color( - item, parent_skeleton=self.skeleton - ) - return QtGui.QColor(*color) - class SkeletonEdgesTableModel(GenericTableModel): """Table model for skeleton edges.""" @@ -436,14 +429,6 @@ def object_to_items(self, skeleton: Skeleton): ] return items - def get_item_color(self, item: Any, key: str): - if self.skeleton: - edge_pair = (item["source"], item["destination"]) - color = self.context.app.color_manager.get_item_color( - edge_pair, parent_skeleton=self.skeleton - ) - return QtGui.QColor(*color) - class LabeledFrameTableModel(GenericTableModel): """Table model for listing instances in labeled frame. From e4bb4449ee4907f8315ef9f64511a7aaa0c79155 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:16:03 -0700 Subject: [PATCH 3/6] Remove no module named work error (#1956) * Do not always color skeletons table black * Remove offending (possibly unneeded) line that causes the no module named work error to print in terminal * Remove offending (possibly unneeded) line that causes the no module named work error to print in terminal * Remove accidentally added changes * Add (failing) test to ensure menu-item updates with state change * Reconnect callback for menu-item (using lambda) * Add (failing) test to ensure menu-item updates with state change Do not assume inital state * Reconnect callback for menu-item (using lambda) --------- Co-authored-by: roomrys --- sleap/gui/app.py | 4 +++- tests/gui/test_app.py | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 4c75dac3f..2dbceb3b7 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -377,7 +377,9 @@ def add_menu_item(menu, key: str, name: str, action: Callable): def connect_check(key): self._menu_actions[key].setCheckable(True) self._menu_actions[key].setChecked(self.state[key]) - self.state.connect(key, self._menu_actions[key].setChecked) + self.state.connect( + key, lambda checked: self._menu_actions[key].setChecked(checked) + ) # add checkable menu item connected to state variable def add_menu_check_item(menu, key: str, name: str): diff --git a/tests/gui/test_app.py b/tests/gui/test_app.py index 745989da1..def835b6e 100644 --- a/tests/gui/test_app.py +++ b/tests/gui/test_app.py @@ -414,6 +414,12 @@ def toggle_and_verify_visibility(expected_visibility: bool = True): window.showNormal() vp = window.player + # Change state and ensure menu-item check updates + color_predicted = window.state["color predicted"] + assert window._menu_actions["color predicted"].isChecked() == color_predicted + window.state["color predicted"] = not color_predicted + assert window._menu_actions["color predicted"].isChecked() == (not color_predicted) + # Enable distinct colors window.state["color predicted"] = True From 3c7f5afa8dd952fdee8ef3c8ac94b0689dac9cdb Mon Sep 17 00:00:00 2001 From: DivyaSesh <64513125+gitttt-1234@users.noreply.github.com> Date: Wed, 18 Sep 2024 14:42:45 -0700 Subject: [PATCH 4/6] Add `normalized_instance_similarity` method (#1939) * Add normalize function * Expose normalization function * Fix tests * Expose object keypoint sim function * Fix tests --- docs/guides/cli.md | 2 +- docs/guides/proofreading.md | 2 ++ sleap/config/pipeline_form.yaml | 4 ++-- sleap/nn/inference.py | 2 ++ sleap/nn/tracker/components.py | 16 ++++++++++++++++ sleap/nn/tracking.py | 13 ++++++++++++- tests/nn/test_inference.py | 12 ++++++++++-- tests/nn/test_tracker_components.py | 19 +++++++++++-------- tests/nn/test_tracking_integration.py | 4 +++- 9 files changed, 59 insertions(+), 15 deletions(-) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 03b806903..134461c60 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -207,7 +207,7 @@ optional arguments: --tracking.clean_iou_threshold TRACKING.CLEAN_IOU_THRESHOLD IOU to use when culling instances *after* tracking. (default: 0) --tracking.similarity TRACKING.SIMILARITY - Options: instance, centroid, iou (default: instance) + Options: instance, normalized_instance, object_keypoint, centroid, iou (default: instance) --tracking.match TRACKING.MATCH Options: hungarian, greedy (default: greedy) --tracking.robust TRACKING.ROBUST diff --git a/docs/guides/proofreading.md b/docs/guides/proofreading.md index fea1c5ebc..941b85154 100644 --- a/docs/guides/proofreading.md +++ b/docs/guides/proofreading.md @@ -50,6 +50,8 @@ There are currently three methods for matching instances in frame N against thes - “**centroid**” measures similarity by the distance between the instance centroids - “**iou**” measures similarity by the intersection/overlap of the instance bounding boxes - “**instance**” measures similarity by looking at the distances between corresponding nodes in the instances, normalized by the number of valid nodes in the candidate instance. +- “**normalized_instance**” measures similarity by looking at the distances between corresponding nodes in the instances, normalized by the number of valid nodes in the candidate instance and the keypoints normalized by the image size. +- “**object_keypoint**” measures similarity by measuring the distance between each keypoints from a reference instance and a query instance, takes the exp(-d**2), sum for all the keypoints and divide by the number of visible keypoints in the reference instance. Once SLEAP has measured the similarity between all the candidates and the instances in frame N, you need to choose a way to pair them up. You can do this either by picking the best match, and the picking the best remaining match for each remaining instance in turn—this is “**greedy**” matching—or you can find the way of matching identities which minimizes the total cost (or: maximizes the total similarity)—this is “**Hungarian**” matching. diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index d130b9cb9..1bb930e58 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -439,7 +439,7 @@ inference: label: Similarity Method type: list default: instance - options: "instance,centroid,iou,object keypoint" + options: "instance,normalized_instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list @@ -538,7 +538,7 @@ inference: label: Similarity Method type: list default: instance - options: "instance,centroid,iou,object keypoint" + options: "instance,normalized_instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 421378d56..14e0d5c6f 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -2622,6 +2622,7 @@ def _object_builder(): # Set tracks for predicted instances in this frame. predicted_instances = self.tracker.track( untracked_instances=predicted_instances, + img_hw=ex["image"].shape[-3:-1], img=image, t=frame_ind, ) @@ -3264,6 +3265,7 @@ def _object_builder(): # Set tracks for predicted instances in this frame. predicted_instances = self.tracker.track( untracked_instances=predicted_instances, + img_hw=ex["image"].shape[-3:-1], img=image, t=frame_ind, ) diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index b2f35b21f..0b77f4ac9 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -12,6 +12,7 @@ """ + import operator from collections import defaultdict import logging @@ -29,6 +30,21 @@ InstanceType = TypeVar("InstanceType", Instance, PredictedInstance) +def normalized_instance_similarity( + ref_instance: InstanceType, query_instance: InstanceType, img_hw: Tuple[int] +) -> float: + """Computes similarity between instances with normalized keypoints.""" + + normalize_factors = np.array((img_hw[1], img_hw[0])) + ref_visible = ~(np.isnan(ref_instance.points_array).any(axis=1)) + normalized_query_keypoints = query_instance.points_array / normalize_factors + normalized_ref_keypoints = ref_instance.points_array / normalize_factors + dists = np.sum((normalized_query_keypoints - normalized_ref_keypoints) ** 2, axis=1) + similarity = np.nansum(np.exp(-dists)) / np.sum(ref_visible) + + return similarity + + def instance_similarity( ref_instance: InstanceType, query_instance: InstanceType ) -> float: diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 2b02839de..558aa9309 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -5,6 +5,7 @@ import attr import numpy as np import cv2 +import functools from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple from sleap import Track, LabeledFrame, Skeleton @@ -12,6 +13,7 @@ from sleap.nn.tracker.components import ( factory_object_keypoint_similarity, instance_similarity, + normalized_instance_similarity, centroid_distance, instance_iou, hungarian_matching, @@ -495,7 +497,8 @@ def get_candidates( instance=instance_similarity, centroid=centroid_distance, iou=instance_iou, - object_keypoint=instance_similarity, + normalized_instance=normalized_instance_similarity, + object_keypoint=factory_object_keypoint_similarity, ) match_policies = dict( @@ -639,6 +642,7 @@ def uses_image(self): def track( self, untracked_instances: List[InstanceType], + img_hw: Tuple[int], img: Optional[np.ndarray] = None, t: int = None, ) -> List[InstanceType]: @@ -646,12 +650,18 @@ def track( Args: untracked_instances: List of instances to assign to tracks. + img_hw: (height, width) of the image used to normalize the keypoints. img: Image data of the current frame for flow shifting. t: Current timestep. If not provided, increments from the internal queue. Returns: A list of the instances that were tracked. """ + if self.similarity_function == normalized_instance_similarity: + factory_normalized_instance = functools.partial( + normalized_instance_similarity, img_hw=img_hw + ) + self.similarity_function = factory_normalized_instance if self.candidate_maker is None: return untracked_instances @@ -1520,6 +1530,7 @@ def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[Labele track_args["img"] = lf.video[lf.frame_idx] else: track_args["img"] = None + track_args["img_hw"] = lf.image.shape[-3:-1] new_lf = LabeledFrame( frame_idx=lf.frame_idx, diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index fd615ea81..0a978de0a 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1932,7 +1932,11 @@ def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir): for inst in lf.instances: inst.track = None - track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + track_args = dict( + untracked_instances=lf.instances, + img=lf.video[lf.frame_idx], + img_hw=lf.image.shape[-3:-1], + ) tracker.track(**track_args) # Check that saved instances are pruned to track window @@ -1975,7 +1979,11 @@ def test_max_tracks_matching_queue( for inst in lf.instances: inst.track = None - track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + track_args = dict( + untracked_instances=lf.instances, + img=lf.video[lf.frame_idx], + img_hw=lf.image.shape[-3:-1], + ) tracker.track(**track_args) if trackername == "flowmaxtracks": diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 5786945fb..0c7ba2b0a 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -30,14 +30,17 @@ def tracker_by_name(frames=None, **kwargs): inst.track = None track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) - t.track(**track_args) + t.track(**track_args, img_hw=(1, 1)) t.final_pass(frames) @pytest.mark.parametrize( "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] ) -@pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"]) +@pytest.mark.parametrize( + "similarity", + ["instance", "normalized_instance", "iou", "centroid", "object_keypoint"], +) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) @pytest.mark.parametrize("count", [0, 2]) def test_tracker_by_name( @@ -288,7 +291,7 @@ def test_max_tracking_large_gap_single_track(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -305,7 +308,7 @@ def test_max_tracking_large_gap_single_track(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -352,7 +355,7 @@ def test_max_tracking_small_gap_on_both_tracks(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -369,7 +372,7 @@ def test_max_tracking_small_gap_on_both_tracks(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -421,7 +424,7 @@ def test_max_tracking_extra_detections(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -438,7 +441,7 @@ def test_max_tracking_extra_detections(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index a6592dc4d..625302fd0 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -102,7 +102,7 @@ def run_tracker(frames, tracker): new_lf = LabeledFrame( frame_idx=lf.frame_idx, video=lf.video, - instances=tracker.track(**track_args), + instances=tracker.track(**track_args, img_hw=lf.image.shape[-3:-1]), ) new_lfs.append(new_lf) @@ -138,6 +138,8 @@ def main(f, dir): instance=sleap.nn.tracker.components.instance_similarity, centroid=sleap.nn.tracker.components.centroid_distance, iou=sleap.nn.tracker.components.instance_iou, + normalized_instance=sleap.nn.tracker.components.normalized_instance_similarity, + object_keypoint=sleap.nn.tracker.components.factory_object_keypoint_similarity(), ) scales = ( 1, From ab93b9ed14765d8fe03a8bfddad0769cb842323a Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:08:03 -0700 Subject: [PATCH 5/6] Handle skeleton decoding internally (#1961) * Reorganize (and add) imports * Add (and reorganize) imports * Modify decode_preview_image to return bytes if specified * Implement (minimally tested) replace_jsonpickle_decode * Add support for using idx_to_node map i.e. loading from Labels (slp file) * Ignore None items in reduce_list * Convert large function to SkeletonDecoder class * Update SkeletonDecoder.decode docstring * Move decode_preview_image to SkeletonDecoder * Use SkeletonDecoder instead of jsonpickle in tests * Remove unused imports * Add test for decoding dict vs tuple pystates --- sleap/gui/widgets/docks.py | 8 +- sleap/skeleton.py | 311 +++++++++++++++++- sleap/util.py | 18 - .../fly_skeleton_legs_pystate_dict.json | 1 + tests/fixtures/skeletons.py | 15 +- tests/test_skeleton.py | 31 +- tests/test_util.py | 8 - 7 files changed, 344 insertions(+), 48 deletions(-) create mode 100644 tests/data/skeleton/fly_skeleton_legs_pystate_dict.json diff --git a/sleap/gui/widgets/docks.py b/sleap/gui/widgets/docks.py index 3375e4713..bd20bf79a 100644 --- a/sleap/gui/widgets/docks.py +++ b/sleap/gui/widgets/docks.py @@ -30,10 +30,8 @@ ) from sleap.gui.dialogs.formbuilder import YamlFormWidget from sleap.gui.widgets.views import CollapsibleWidget -from sleap.skeleton import Skeleton -from sleap.util import decode_preview_image, find_files_by_suffix, get_package_file - -# from sleap.gui.app import MainWindow +from sleap.skeleton import Skeleton, SkeletonDecoder +from sleap.util import find_files_by_suffix, get_package_file class DockWidget(QDockWidget): @@ -365,7 +363,7 @@ def create_templates_groupbox(self) -> QGroupBox: def updatePreviewImage(preview_image_bytes: bytes): # Decode the preview image - preview_image = decode_preview_image(preview_image_bytes) + preview_image = SkeletonDecoder.decode_preview_image(preview_image_bytes) # Create a QImage from the Image preview_image = QtGui.QImage( diff --git a/sleap/skeleton.py b/sleap/skeleton.py index eca393b8e..ed083fc0e 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -6,24 +6,25 @@ their connection to each other, and needed meta-data. """ -import attr -import cattr -import numpy as np -import jsonpickle -import json -import h5py +import base64 import copy - +import json import operator from enum import Enum +from io import BytesIO from itertools import count -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Text +from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union +import attr +import cattr +import h5py +import jsonpickle import networkx as nx +import numpy as np from networkx.readwrite import json_graph +from PIL import Image from scipy.io import loadmat - NodeRef = Union[str, "Node"] H5FileRef = Union[str, h5py.File] @@ -85,6 +86,296 @@ def matches(self, other: "Node") -> bool: return other.name == self.name and other.weight == self.weight +class SkeletonDecoder: + """Replace jsonpickle.decode with our own decoder. + + This function will decode the following from jsonpickle's encoded format: + + `Node` objects from + { + "py/object": "sleap.skeleton.Node", + "py/state": { "py/tuple": ["thorax1", 1.0] } + } + to `Node(name="thorax1", weight=1.0)` + + `EdgeType` objects from + { + "py/reduce": [ + { "py/type": "sleap.skeleton.EdgeType" }, + { "py/tuple": [1] } + ] + } + to `EdgeType(1)` + + `bytes` from + { + "py/b64": "aVZC..." + } + to `b"iVBO..."` + + and any repeated objects from + { + "py/id": 1 + } + to the object with the same reconstruction id (from top to bottom). + """ + + def __init__(self): + self.decoded_objects: List[Union[Node, EdgeType]] = [] + + def _decode_id(self, id: int) -> Union[Node, EdgeType]: + """Decode the object with the given `py/id` value of `id`. + + Args: + id: The `py/id` value to decode (1-indexed). + objects: The dictionary of objects that have already been decoded. + + Returns: + The object with the given `py/id` value. + """ + return self.decoded_objects[id - 1] + + @staticmethod + def _decode_state(state: dict) -> Node: + """Reconstruct the `Node` object from 'py/state' key in the serialized nx_graph. + + We support states in either dictionary or tuple format: + { + "py/state": { "py/tuple": ["thorax1", 1.0] } + } + or + { + "py/state": {"name": "thorax1", "weight": 1.0} + } + + Args: + state: The state to decode, i.e. state = dict["py/state"] + + Returns: + The `Node` object reconstructed from the state. + """ + + if "py/tuple" in state: + return Node(*state["py/tuple"]) + + return Node(**state) + + @staticmethod + def _decode_object_dict(object_dict) -> Node: + """Decode dict containing `py/object` key in the serialized nx_graph. + + Args: + object_dict: The dict to decode, i.e. + object_dict = {"py/object": ..., "py/state":...} + + Raises: + ValueError: If object_dict does not have 'py/object' and 'py/state' keys. + ValueError: If object_dict['py/object'] is not 'sleap.skeleton.Node'. + + Returns: + The decoded `Node` object. + """ + + if object_dict["py/object"] != "sleap.skeleton.Node": + raise ValueError("Only 'sleap.skeleton.Node' objects are supported.") + + node: Node = SkeletonDecoder._decode_state(state=object_dict["py/state"]) + return node + + def _decode_node(self, encoded_node: dict) -> Node: + """Decode an item believed to be an encoded `Node` object. + + Also updates the list of decoded objects. + + Args: + encoded_node: The encoded node to decode. + + Returns: + The decoded node and the updated list of decoded objects. + """ + + if isinstance(encoded_node, int): + # Using index mapping to replace the object (load from Labels) + return encoded_node + elif "py/object" in encoded_node: + decoded_node: Node = SkeletonDecoder._decode_object_dict(encoded_node) + self.decoded_objects.append(decoded_node) + elif "py/id" in encoded_node: + decoded_node: Node = self._decode_id(encoded_node["py/id"]) + + return decoded_node + + def _decode_nodes(self, encoded_nodes: List[dict]) -> List[Dict[str, Node]]: + """Decode the 'nodes' key in the serialized nx_graph. + + The encoded_nodes is a list of dictionary of two types: + - A dictionary with 'py/object' and 'py/state' keys. + - A dictionary with 'py/id' key. + + Args: + encoded_nodes: The list of encoded nodes to decode. + + Returns: + The decoded nodes. + """ + + decoded_nodes: List[Dict[str, Node]] = [] + for e_node_dict in encoded_nodes: + e_node = e_node_dict["id"] + d_node = self._decode_node(e_node) + decoded_nodes.append({"id": d_node}) + + return decoded_nodes + + def _decode_reduce_dict(self, reduce_dict: Dict[str, List[dict]]) -> EdgeType: + """Decode the 'reduce' key in the serialized nx_graph. + + The reduce_dict is a dictionary in the following format: + { + "py/reduce": [ + { "py/type": "sleap.skeleton.EdgeType" }, + { "py/tuple": [1] } + ] + } + + Args: + reduce_dict: The dictionary to decode i.e. reduce_dict = {"py/reduce": ...} + + Returns: + The decoded `EdgeType` object. + """ + + reduce_list = reduce_dict["py/reduce"] + has_py_type = has_py_tuple = False + for reduce_item in reduce_list: + if reduce_item is None: + # Sometimes the reduce list has None values, skip them + continue + if ( + "py/type" in reduce_item + and reduce_item["py/type"] == "sleap.skeleton.EdgeType" + ): + has_py_type = True + elif "py/tuple" in reduce_item: + edge_type: int = reduce_item["py/tuple"][0] + has_py_tuple = True + + if not has_py_type or not has_py_tuple: + raise ValueError( + "Only 'sleap.skeleton.EdgeType' objects are supported. " + "The 'py/reduce' list must have dictionaries with 'py/type' and " + "'py/tuple' keys." + f"\n\tHas py/type: {has_py_type}\n\tHas py/tuple: {has_py_tuple}" + ) + + edge = EdgeType(edge_type) + self.decoded_objects.append(edge) + + return edge + + def _decode_edge_type(self, encoded_edge_type: dict) -> EdgeType: + """Decode the 'type' key in the serialized nx_graph. + + Args: + encoded_edge_type: a dictionary with either 'py/id' or 'py/reduce' key. + + Returns: + The decoded `EdgeType` object. + """ + + if "py/reduce" in encoded_edge_type: + edge_type = self._decode_reduce_dict(encoded_edge_type) + else: + # Expect a "py/id" instead of "py/reduce" + edge_type = self._decode_id(encoded_edge_type["py/id"]) + return edge_type + + def _decode_links( + self, links: List[dict] + ) -> List[Dict[str, Union[int, Node, EdgeType]]]: + """Decode the 'links' key in the serialized nx_graph. + + The links are the edges in the graph and will have the following keys: + - source: The source node of the edge. + - target: The destination node of the edge. + - type: The type of the edge (e.g. BODY, SYMMETRY). + and more. + + Args: + encoded_links: The list of encoded links to decode. + """ + + for link in links: + for key, value in link.items(): + if key == "source": + link[key] = self._decode_node(value) + elif key == "target": + link[key] = self._decode_node(value) + elif key == "type": + link[key] = self._decode_edge_type(value) + + return links + + @staticmethod + def decode_preview_image( + img_b64: bytes, return_bytes: bool = False + ) -> Union[Image.Image, bytes]: + """Decode a skeleton preview image byte string representation to a `PIL.Image` + + Args: + img_b64: a byte string representation of a skeleton preview image + return_bytes: whether to return the decoded image as bytes + + Returns: + Either a PIL.Image of the skeleton preview image or the decoded image as bytes + (if `return_bytes` is True). + """ + bytes = base64.b64decode(img_b64) + if return_bytes: + return bytes + + buffer = BytesIO(bytes) + img = Image.open(buffer) + return img + + def _decode(self, json_str: str): + dicts = json.loads(json_str) + + # Enforce same format across template and non-template skeletons + if "nx_graph" not in dicts: + # Non-template skeletons use the dicts as the "nx_graph" + dicts = {"nx_graph": dicts} + + # Decode the graph + nx_graph = dicts["nx_graph"] + + self.decoded_objects = [] # Reset the decoded objects incase reusing decoder + for key, value in nx_graph.items(): + if key == "nodes": + nx_graph[key] = self._decode_nodes(value) + elif key == "links": + nx_graph[key] = self._decode_links(value) + + # Decode the preview image (if it exists) + preview_image = dicts.get("preview_image", None) + if preview_image is not None: + dicts["preview_image"] = SkeletonDecoder.decode_preview_image( + preview_image["py/b64"], return_bytes=True + ) + + return dicts + + @classmethod + def decode(cls, json_str: str) -> Dict: + """Decode the given json string into a dictionary. + + Returns: + A dict with `Node`s, `EdgeType`s, and `bytes` decoded/reconstructed. + """ + decoder = cls() + return decoder._decode(json_str) + + class Skeleton: """The main object for representing animal skeletons. @@ -1071,7 +1362,7 @@ def from_json( Returns: An instance of the `Skeleton` object decoded from the JSON. """ - dicts = jsonpickle.decode(json_str) + dicts: dict = SkeletonDecoder.decode(json_str) nx_graph = dicts.get("nx_graph", dicts) graph = json_graph.node_link_graph(nx_graph) diff --git a/sleap/util.py b/sleap/util.py index eef762ff4..bc3389b7d 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -3,13 +3,11 @@ Try not to put things in here unless they really have no other place. """ -import base64 import json import os import re import shutil from collections import defaultdict -from io import BytesIO from pathlib import Path from typing import Any, Dict, Hashable, Iterable, List, Optional from urllib.parse import unquote, urlparse @@ -26,7 +24,6 @@ from importlib.resources import files # New in 3.9+ except ImportError: from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. -from PIL import Image import sleap.version as sleap_version @@ -374,18 +371,3 @@ def find_files_by_suffix( def parse_uri_path(uri: str) -> str: """Parse a URI starting with 'file:///' to a posix path.""" return Path(url2pathname(urlparse(unquote(uri)).path)).as_posix() - - -def decode_preview_image(img_b64: bytes) -> Image: - """Decode a skeleton preview image byte string representation to a `PIL.Image` - - Args: - img_b64: a byte string representation of a skeleton preview image - - Returns: - A PIL.Image of the skeleton preview - """ - bytes = base64.b64decode(img_b64) - buffer = BytesIO(bytes) - img = Image.open(buffer) - return img diff --git a/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json b/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json new file mode 100644 index 000000000..eae83d6bc --- /dev/null +++ b/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json @@ -0,0 +1 @@ +{"directed": true, "graph": {"name": "skeleton_legs.mat", "num_edges_inserted": 23}, "links": [{"edge_insert_idx": 1, "key": 0, "source": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "neck", "weight": 1.0}}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "head", "weight": 1.0}}, "type": {"py/reduce": [{"py/type": "sleap.skeleton.EdgeType"}, {"py/tuple": [1]}]}}, {"edge_insert_idx": 0, "key": 0, "source": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "thorax", "weight": 1.0}}, "target": {"py/id": 1}, "type": {"py/id": 3}}, {"edge_insert_idx": 2, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "abdomen", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 3, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "wingL", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 4, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "wingR", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 5, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 8, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 11, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 14, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 17, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 20, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 6, "key": 0, "source": {"py/id": 8}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 7, "key": 0, "source": {"py/id": 14}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 9, "key": 0, "source": {"py/id": 9}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 10, "key": 0, "source": {"py/id": 16}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 12, "key": 0, "source": {"py/id": 10}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 13, "key": 0, "source": {"py/id": 18}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 15, "key": 0, "source": {"py/id": 11}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 16, "key": 0, "source": {"py/id": 20}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 18, "key": 0, "source": {"py/id": 12}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 19, "key": 0, "source": {"py/id": 22}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 21, "key": 0, "source": {"py/id": 13}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 22, "key": 0, "source": {"py/id": 24}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR3", "weight": 1.0}}, "type": {"py/id": 3}}], "multigraph": true, "nodes": [{"id": {"py/id": 2}}, {"id": {"py/id": 1}}, {"id": {"py/id": 4}}, {"id": {"py/id": 5}}, {"id": {"py/id": 6}}, {"id": {"py/id": 7}}, {"id": {"py/id": 8}}, {"id": {"py/id": 14}}, {"id": {"py/id": 15}}, {"id": {"py/id": 9}}, {"id": {"py/id": 16}}, {"id": {"py/id": 17}}, {"id": {"py/id": 10}}, {"id": {"py/id": 18}}, {"id": {"py/id": 19}}, {"id": {"py/id": 11}}, {"id": {"py/id": 20}}, {"id": {"py/id": 21}}, {"id": {"py/id": 12}}, {"id": {"py/id": 22}}, {"id": {"py/id": 23}}, {"id": {"py/id": 13}}, {"id": {"py/id": 24}}, {"id": {"py/id": 25}}]} \ No newline at end of file diff --git a/tests/fixtures/skeletons.py b/tests/fixtures/skeletons.py index 311510e6a..b432ca2c7 100644 --- a/tests/fixtures/skeletons.py +++ b/tests/fixtures/skeletons.py @@ -3,14 +3,27 @@ from sleap.skeleton import Skeleton TEST_FLY_LEGS_SKELETON = "tests/data/skeleton/fly_skeleton_legs.json" +TEST_FLY_LEGS_SKELETON_DICT = "tests/data/skeleton/fly_skeleton_legs_pystate_dict.json" @pytest.fixture def fly_legs_skeleton_json(): - """Path to fly_skeleton_legs.json""" + """Path to fly_skeleton_legs.json + + This skeleton json has py/state in tuple format. + """ return TEST_FLY_LEGS_SKELETON +@pytest.fixture +def fly_legs_skeleton_dict_json(): + """Path to fly_skeleton_legs_pystate_dict.json + + This skeleton json has py/state dict format. + """ + return TEST_FLY_LEGS_SKELETON_DICT + + @pytest.fixture def stickman(): diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index 1f7c3a853..2fc32628a 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -1,10 +1,9 @@ -import os import copy +import os -import jsonpickle import pytest -from sleap.skeleton import Skeleton +from sleap.skeleton import Skeleton, SkeletonDecoder def test_add_dupe_node(skeleton): @@ -194,9 +193,9 @@ def test_json(skeleton: Skeleton, tmpdir): ) assert skeleton.is_template == False json_str = skeleton.to_json() - json_dict = jsonpickle.decode(json_str) + json_dict = SkeletonDecoder.decode(json_str) json_dict_keys = list(json_dict.keys()) - assert "nx_graph" not in json_dict_keys + assert "nx_graph" in json_dict_keys # SkeletonDecoder adds this key assert "preview_image" not in json_dict_keys assert "description" not in json_dict_keys @@ -208,7 +207,7 @@ def test_json(skeleton: Skeleton, tmpdir): skeleton._is_template = True json_str = skeleton.to_json() - json_dict = jsonpickle.decode(json_str) + json_dict = SkeletonDecoder.decode(json_str) json_dict_keys = list(json_dict.keys()) assert "nx_graph" in json_dict_keys assert "preview_image" in json_dict_keys @@ -224,6 +223,26 @@ def test_json(skeleton: Skeleton, tmpdir): assert skeleton.matches(skeleton_copy) +def test_decode_preview_image(flies13_skeleton: Skeleton): + skeleton = flies13_skeleton + img_b64 = skeleton.preview_image + img = SkeletonDecoder.decode_preview_image(img_b64) + assert img.mode == "RGBA" + + +def test_skeleton_decoder(fly_legs_skeleton_json, fly_legs_skeleton_dict_json): + """Test that SkeletonDecoder can decode both tuple and dict py/state formats.""" + + skeleton_tuple_pystate = Skeleton.load_json(fly_legs_skeleton_json) + assert isinstance(skeleton_tuple_pystate, Skeleton) + + skeleton_dict_pystate = Skeleton.load_json(fly_legs_skeleton_dict_json) + assert isinstance(skeleton_dict_pystate, Skeleton) + + # These are the same skeleton, so they should match + assert skeleton_dict_pystate.matches(skeleton_tuple_pystate) + + def test_hdf5(skeleton, stickman, tmpdir): filename = os.path.join(tmpdir, "skeleton.h5") diff --git a/tests/test_util.py b/tests/test_util.py index a7916d47f..35b41afa8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,5 +1,4 @@ import pytest -from sleap.skeleton import Skeleton from sleap.util import * @@ -147,10 +146,3 @@ def test_save_dict_to_hdf5(tmpdir): assert f["bar"][-1].decode() == "zop" assert f["cab"]["a"][()] == 2 - - -def test_decode_preview_image(flies13_skeleton: Skeleton): - skeleton = flies13_skeleton - img_b64 = skeleton.preview_image - img = decode_preview_image(img_b64) - assert img.mode == "RGBA" From ef803f65f4c1e50e42a3fe5d0d9f2fde99db508d Mon Sep 17 00:00:00 2001 From: Elizabeth <106755962+eberrigan@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:48:13 -0700 Subject: [PATCH 6/6] Handle skeleton encoding internally (#1970) * start class `SkeletonEncoder` * _encoded_objects need to be a dict to add to * add notebook for testing * format * fix type in docstring * finish classmethod for encoding Skeleton as a json string * test encoded Skeleton as json string by decoding it * add test for decoded encoded skeleton * update jupyter notebook for easy testing * constraining attrs in dev environment to make sure decode format is always the same locally * encode links first then encode source then target then type * save first enconding statically as an input to _get_or_assign_id so that we do not always get py/id * save first encoding statically * first encoding is passed to _get_or_assign_id * use first_encoding variable to determine if we should assign a py/id * add print statements for debugging * update notebook for easy testing * black * remove comment * adding attrs constraint to show this passes for certain attrs version only * add import * switch out jsonpickle.encode * oops remove import * can attrs be unconstrained? * forgot comma * pin attrs for testing * test Skeleton from json, template, with symmetries, and template * use SkeletonEncoder.encode * black * try removing None values in EdgeType reduced * Handle case when nodes are replaced by integer indices from caller * Remove prototyping notebook * Remove attrs pins * Remove sort keys (which flips the neccessary ordering of our py/ids) * Do not add extra indents to encoded file * Only append links after fully encoded (fat-finger) * Remove outdated comment * Lint --------- Co-authored-by: Talmo Pereira Co-authored-by: roomrys <38435167+roomrys@users.noreply.github.com> --- environment.yml | 2 +- environment_no_cuda.yml | 2 +- sleap/skeleton.py | 197 +++++++++++++++++++++++++++++++++++++++- tests/test_skeleton.py | 55 ++++++++++- 4 files changed, 248 insertions(+), 8 deletions(-) diff --git a/environment.yml b/environment.yml index 2aba3c7d2..06a0633d2 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,7 @@ channels: dependencies: # Packages SLEAP uses directly - - conda-forge::attrs >=21.2.0 #,<=21.4.0 + - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin diff --git a/environment_no_cuda.yml b/environment_no_cuda.yml index 2adee7a89..ba2b54a22 100644 --- a/environment_no_cuda.yml +++ b/environment_no_cuda.yml @@ -11,7 +11,7 @@ channels: dependencies: # Packages SLEAP uses directly - - conda-forge::attrs >=21.2.0 #,<=21.4.0 + - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin diff --git a/sleap/skeleton.py b/sleap/skeleton.py index ed083fc0e..f6477cf66 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -376,6 +376,193 @@ def decode(cls, json_str: str) -> Dict: return decoder._decode(json_str) +class SkeletonEncoder: + """Replace jsonpickle.encode with our own encoder. + + The input is a dictionary containing python objects that need to be encoded as + JSON strings. The output is a JSON string that represents the input dictionary. + + `Node(name='neck', weight=1.0)` => + { + "py/object": "sleap.Skeleton.Node", + "py/state": {"py/tuple" ["neck", 1.0]} + } + + `` => + {"py/reduce": [ + {"py/type": "sleap.Skeleton.EdgeType"}, + {"py/tuple": [1] } + ] + }` + + Where `name` and `weight` are the attributes of the `Node` class; weight is always 1.0. + `EdgeType` is an enum with values `BODY = 1` and `SYMMETRY = 2`. + + See sleap.skeleton.Node and sleap.skeleton.EdgeType. + + If the object has been "seen" before, it will not be encoded as the full JSON string + but referenced by its `py/id`, which starts at 1 and indexes the objects in the + order they are seen so that the second time the first object is used, it will be + referenced as `{"py/id": 1}`. + """ + + def __init__(self): + """Initializes a SkeletonEncoder instance.""" + # Maps object id to py/id + self._encoded_objects: Dict[int, int] = {} + + @classmethod + def encode(cls, data: Dict[str, Any]) -> str: + """Encodes the input dictionary as a JSON string. + + Args: + data: The data to encode. + + Returns: + json_str: The JSON string representation of the data. + """ + encoder = cls() + encoded_data = encoder._encode(data) + json_str = json.dumps(encoded_data) + return json_str + + def _encode(self, obj: Any) -> Any: + """Recursively encodes the input object. + + Args: + obj: The object to encode. Can be a dictionary, list, Node, EdgeType or + primitive data type. + + Returns: + The encoded object as a dictionary. + """ + if isinstance(obj, dict): + encoded_obj = {} + for key, value in obj.items(): + if key == "links": + encoded_obj[key] = self._encode_links(value) + else: + encoded_obj[key] = self._encode(value) + return encoded_obj + elif isinstance(obj, list): + return [self._encode(v) for v in obj] + elif isinstance(obj, EdgeType): + return self._encode_edge_type(obj) + elif isinstance(obj, Node): + return self._encode_node(obj) + else: + return obj # Primitive data types + + def _encode_links(self, links: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Encodes the list of links (edges) in the skeleton graph. + + Args: + links: A list of dictionaries, each representing an edge in the graph. + + Returns: + A list of encoded edge dictionaries with keys ordered as specified. + """ + encoded_links = [] + for link in links: + # Use a regular dict (insertion order preserved in Python 3.7+) + encoded_link = {} + + for key, value in link.items(): + if key in ("source", "target"): + encoded_link[key] = self._encode_node(value) + elif key == "type": + encoded_link[key] = self._encode_edge_type(value) + else: + encoded_link[key] = self._encode(value) + encoded_links.append(encoded_link) + + return encoded_links + + def _encode_node(self, node: Union["Node", int]) -> Dict[str, Any]: + """Encodes a Node object. + + Args: + node: The Node object to encode or integer index. The latter requires that + the class has the `idx_to_node` attribute set. + + Returns: + The encoded `Node` object as a dictionary. + """ + if isinstance(node, int): + # We sometimes have the node object already replaced by its index (when + # `node_to_idx` is provided). In this case, the node is already encoded. + return node + + # Check if object has been encoded before + first_encoding = self._is_first_encoding(node) + py_id = self._get_or_assign_id(node, first_encoding) + if first_encoding: + # Full encoding + return { + "py/object": "sleap.skeleton.Node", + "py/state": {"py/tuple": [node.name, node.weight]}, + } + else: + # Reference by py/id + return {"py/id": py_id} + + def _encode_edge_type(self, edge_type: "EdgeType") -> Dict[str, Any]: + """Encodes an EdgeType object. + + Args: + edge_type: The EdgeType object to encode. Either `EdgeType.BODY` or + `EdgeType.SYMMETRY` enum with values 1 and 2 respectively. + + Returns: + The encoded EdgeType object as a dictionary. + """ + # Check if object has been encoded before + first_encoding = self._is_first_encoding(edge_type) + py_id = self._get_or_assign_id(edge_type, first_encoding) + if first_encoding: + # Full encoding + return { + "py/reduce": [ + {"py/type": "sleap.skeleton.EdgeType"}, + {"py/tuple": [edge_type.value]}, + ] + } + else: + # Reference by py/id + return {"py/id": py_id} + + def _get_or_assign_id(self, obj: Any, first_encoding: bool) -> int: + """Gets or assigns a py/id for the object. + + Args: + The object to get or assign a py/id for. + + Returns: + The py/id assigned to the object. + """ + # Object id is unique for each object in the current session + obj_id = id(obj) + # Assign a py/id to the object if it hasn't been assigned one yet + if first_encoding: + py_id = len(self._encoded_objects) + 1 # py/id starts at 1 + # Assign the py/id to the object and store it in _encoded_objects + self._encoded_objects[obj_id] = py_id + return self._encoded_objects[obj_id] + + def _is_first_encoding(self, obj: Any) -> bool: + """Checks if the object is being encoded for the first time. + + Args: + obj: The object to check. + + Returns: + True if this is the first encoding of the object, False otherwise. + """ + obj_id = id(obj) + first_time = obj_id not in self._encoded_objects + return first_time + + class Skeleton: """The main object for representing animal skeletons. @@ -1228,7 +1415,7 @@ def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> D # This is a weird hack to serialize the whole _graph into a dict. # I use the underlying to_json and parse it. - return json.loads(obj.to_json(node_to_idx)) + return json.loads(obj.to_json(node_to_idx=node_to_idx)) @classmethod def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton": @@ -1292,10 +1479,10 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: """ jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4) if node_to_idx is not None: - indexed_node_graph = nx.relabel_nodes( - G=self._graph, mapping=node_to_idx - ) # map nodes to int + # Map Nodes to int + indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx) else: + # Keep graph nodes as Node objects indexed_node_graph = self._graph # Encode to JSON @@ -1314,7 +1501,7 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: else: data = graph - json_str = jsonpickle.encode(data) + json_str = SkeletonEncoder.encode(data) return json_str diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index 2fc32628a..7c5216316 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -1,9 +1,62 @@ import copy import os - import pytest +import json +from networkx.readwrite import json_graph from sleap.skeleton import Skeleton, SkeletonDecoder +from sleap.skeleton import SkeletonEncoder + + +def test_decoded_encoded_Skeleton_from_load_json(fly_legs_skeleton_json): + """ + Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton. + """ + # Get the skeleton from the fixture + skeleton = Skeleton.load_json(fly_legs_skeleton_json) + # Get the graph from the skeleton + indexed_node_graph = skeleton._graph + graph = json_graph.node_link_data(indexed_node_graph) + + # Encode the graph as a json string to test .encode method + encoded_json_str = SkeletonEncoder.encode(graph) + + # Get the skeleton from the encoded json string + decoded_skeleton = Skeleton.from_json(encoded_json_str) + + # Check that the decoded skeleton is the same as the original skeleton + assert skeleton.matches(decoded_skeleton) + + +@pytest.mark.parametrize( + "skeleton_fixture_name", ["flies13_skeleton", "skeleton", "stickman"] +) +def test_decoded_encoded_Skeleton(skeleton_fixture_name, request): + """ + Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton. + """ + # Use request.getfixturevalue to get the actual fixture value by name + skeleton = request.getfixturevalue(skeleton_fixture_name) + + # Get the graph from the skeleton + indexed_node_graph = skeleton._graph + graph = json_graph.node_link_data(indexed_node_graph) + + # Encode the graph as a json string to test .encode method + encoded_json_str = SkeletonEncoder.encode(graph) + + # Get the skeleton from the encoded json string + decoded_skeleton = Skeleton.from_json(encoded_json_str) + + # Check that the decoded skeleton is the same as the original skeleton + assert skeleton.matches(decoded_skeleton) + + # Now make everything into a JSON string + skeleton_json_str = skeleton.to_json() + decoded_skeleton_json_str = decoded_skeleton.to_json() + + # Check that the JSON strings are the same + assert json.loads(skeleton_json_str) == json.loads(decoded_skeleton_json_str) def test_add_dupe_node(skeleton):