diff --git a/environment.yml b/environment.yml index 2aba3c7d2..06a0633d2 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,7 @@ channels: dependencies: # Packages SLEAP uses directly - - conda-forge::attrs >=21.2.0 #,<=21.4.0 + - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin diff --git a/environment_no_cuda.yml b/environment_no_cuda.yml index 2adee7a89..ba2b54a22 100644 --- a/environment_no_cuda.yml +++ b/environment_no_cuda.yml @@ -11,7 +11,7 @@ channels: dependencies: # Packages SLEAP uses directly - - conda-forge::attrs >=21.2.0 #,<=21.4.0 + - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin diff --git a/sleap/skeleton.py b/sleap/skeleton.py index ed083fc0e..f6477cf66 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -376,6 +376,193 @@ def decode(cls, json_str: str) -> Dict: return decoder._decode(json_str) +class SkeletonEncoder: + """Replace jsonpickle.encode with our own encoder. + + The input is a dictionary containing python objects that need to be encoded as + JSON strings. The output is a JSON string that represents the input dictionary. + + `Node(name='neck', weight=1.0)` => + { + "py/object": "sleap.Skeleton.Node", + "py/state": {"py/tuple" ["neck", 1.0]} + } + + `` => + {"py/reduce": [ + {"py/type": "sleap.Skeleton.EdgeType"}, + {"py/tuple": [1] } + ] + }` + + Where `name` and `weight` are the attributes of the `Node` class; weight is always 1.0. + `EdgeType` is an enum with values `BODY = 1` and `SYMMETRY = 2`. + + See sleap.skeleton.Node and sleap.skeleton.EdgeType. + + If the object has been "seen" before, it will not be encoded as the full JSON string + but referenced by its `py/id`, which starts at 1 and indexes the objects in the + order they are seen so that the second time the first object is used, it will be + referenced as `{"py/id": 1}`. + """ + + def __init__(self): + """Initializes a SkeletonEncoder instance.""" + # Maps object id to py/id + self._encoded_objects: Dict[int, int] = {} + + @classmethod + def encode(cls, data: Dict[str, Any]) -> str: + """Encodes the input dictionary as a JSON string. + + Args: + data: The data to encode. + + Returns: + json_str: The JSON string representation of the data. + """ + encoder = cls() + encoded_data = encoder._encode(data) + json_str = json.dumps(encoded_data) + return json_str + + def _encode(self, obj: Any) -> Any: + """Recursively encodes the input object. + + Args: + obj: The object to encode. Can be a dictionary, list, Node, EdgeType or + primitive data type. + + Returns: + The encoded object as a dictionary. + """ + if isinstance(obj, dict): + encoded_obj = {} + for key, value in obj.items(): + if key == "links": + encoded_obj[key] = self._encode_links(value) + else: + encoded_obj[key] = self._encode(value) + return encoded_obj + elif isinstance(obj, list): + return [self._encode(v) for v in obj] + elif isinstance(obj, EdgeType): + return self._encode_edge_type(obj) + elif isinstance(obj, Node): + return self._encode_node(obj) + else: + return obj # Primitive data types + + def _encode_links(self, links: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Encodes the list of links (edges) in the skeleton graph. + + Args: + links: A list of dictionaries, each representing an edge in the graph. + + Returns: + A list of encoded edge dictionaries with keys ordered as specified. + """ + encoded_links = [] + for link in links: + # Use a regular dict (insertion order preserved in Python 3.7+) + encoded_link = {} + + for key, value in link.items(): + if key in ("source", "target"): + encoded_link[key] = self._encode_node(value) + elif key == "type": + encoded_link[key] = self._encode_edge_type(value) + else: + encoded_link[key] = self._encode(value) + encoded_links.append(encoded_link) + + return encoded_links + + def _encode_node(self, node: Union["Node", int]) -> Dict[str, Any]: + """Encodes a Node object. + + Args: + node: The Node object to encode or integer index. The latter requires that + the class has the `idx_to_node` attribute set. + + Returns: + The encoded `Node` object as a dictionary. + """ + if isinstance(node, int): + # We sometimes have the node object already replaced by its index (when + # `node_to_idx` is provided). In this case, the node is already encoded. + return node + + # Check if object has been encoded before + first_encoding = self._is_first_encoding(node) + py_id = self._get_or_assign_id(node, first_encoding) + if first_encoding: + # Full encoding + return { + "py/object": "sleap.skeleton.Node", + "py/state": {"py/tuple": [node.name, node.weight]}, + } + else: + # Reference by py/id + return {"py/id": py_id} + + def _encode_edge_type(self, edge_type: "EdgeType") -> Dict[str, Any]: + """Encodes an EdgeType object. + + Args: + edge_type: The EdgeType object to encode. Either `EdgeType.BODY` or + `EdgeType.SYMMETRY` enum with values 1 and 2 respectively. + + Returns: + The encoded EdgeType object as a dictionary. + """ + # Check if object has been encoded before + first_encoding = self._is_first_encoding(edge_type) + py_id = self._get_or_assign_id(edge_type, first_encoding) + if first_encoding: + # Full encoding + return { + "py/reduce": [ + {"py/type": "sleap.skeleton.EdgeType"}, + {"py/tuple": [edge_type.value]}, + ] + } + else: + # Reference by py/id + return {"py/id": py_id} + + def _get_or_assign_id(self, obj: Any, first_encoding: bool) -> int: + """Gets or assigns a py/id for the object. + + Args: + The object to get or assign a py/id for. + + Returns: + The py/id assigned to the object. + """ + # Object id is unique for each object in the current session + obj_id = id(obj) + # Assign a py/id to the object if it hasn't been assigned one yet + if first_encoding: + py_id = len(self._encoded_objects) + 1 # py/id starts at 1 + # Assign the py/id to the object and store it in _encoded_objects + self._encoded_objects[obj_id] = py_id + return self._encoded_objects[obj_id] + + def _is_first_encoding(self, obj: Any) -> bool: + """Checks if the object is being encoded for the first time. + + Args: + obj: The object to check. + + Returns: + True if this is the first encoding of the object, False otherwise. + """ + obj_id = id(obj) + first_time = obj_id not in self._encoded_objects + return first_time + + class Skeleton: """The main object for representing animal skeletons. @@ -1228,7 +1415,7 @@ def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> D # This is a weird hack to serialize the whole _graph into a dict. # I use the underlying to_json and parse it. - return json.loads(obj.to_json(node_to_idx)) + return json.loads(obj.to_json(node_to_idx=node_to_idx)) @classmethod def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton": @@ -1292,10 +1479,10 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: """ jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4) if node_to_idx is not None: - indexed_node_graph = nx.relabel_nodes( - G=self._graph, mapping=node_to_idx - ) # map nodes to int + # Map Nodes to int + indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx) else: + # Keep graph nodes as Node objects indexed_node_graph = self._graph # Encode to JSON @@ -1314,7 +1501,7 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: else: data = graph - json_str = jsonpickle.encode(data) + json_str = SkeletonEncoder.encode(data) return json_str diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index 2fc32628a..7c5216316 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -1,9 +1,62 @@ import copy import os - import pytest +import json +from networkx.readwrite import json_graph from sleap.skeleton import Skeleton, SkeletonDecoder +from sleap.skeleton import SkeletonEncoder + + +def test_decoded_encoded_Skeleton_from_load_json(fly_legs_skeleton_json): + """ + Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton. + """ + # Get the skeleton from the fixture + skeleton = Skeleton.load_json(fly_legs_skeleton_json) + # Get the graph from the skeleton + indexed_node_graph = skeleton._graph + graph = json_graph.node_link_data(indexed_node_graph) + + # Encode the graph as a json string to test .encode method + encoded_json_str = SkeletonEncoder.encode(graph) + + # Get the skeleton from the encoded json string + decoded_skeleton = Skeleton.from_json(encoded_json_str) + + # Check that the decoded skeleton is the same as the original skeleton + assert skeleton.matches(decoded_skeleton) + + +@pytest.mark.parametrize( + "skeleton_fixture_name", ["flies13_skeleton", "skeleton", "stickman"] +) +def test_decoded_encoded_Skeleton(skeleton_fixture_name, request): + """ + Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton. + """ + # Use request.getfixturevalue to get the actual fixture value by name + skeleton = request.getfixturevalue(skeleton_fixture_name) + + # Get the graph from the skeleton + indexed_node_graph = skeleton._graph + graph = json_graph.node_link_data(indexed_node_graph) + + # Encode the graph as a json string to test .encode method + encoded_json_str = SkeletonEncoder.encode(graph) + + # Get the skeleton from the encoded json string + decoded_skeleton = Skeleton.from_json(encoded_json_str) + + # Check that the decoded skeleton is the same as the original skeleton + assert skeleton.matches(decoded_skeleton) + + # Now make everything into a JSON string + skeleton_json_str = skeleton.to_json() + decoded_skeleton_json_str = decoded_skeleton.to_json() + + # Check that the JSON strings are the same + assert json.loads(skeleton_json_str) == json.loads(decoded_skeleton_json_str) def test_add_dupe_node(skeleton):