Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle skeleton encoding internally #1970

Merged
merged 40 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
034cffb
start class `SkeletonEncoder`
eberrigan Sep 20, 2024
dd4865d
start class `SkeletonEncoder`
eberrigan Sep 21, 2024
f66a12d
_encoded_objects need to be a dict to add to
eberrigan Sep 21, 2024
7c48903
add notebook for testing
eberrigan Sep 23, 2024
4a75924
format
eberrigan Sep 23, 2024
57ec9d3
fix type in docstring
eberrigan Sep 23, 2024
58b2398
finish classmethod for encoding Skeleton as a json string
eberrigan Sep 23, 2024
8aedebd
test encoded Skeleton as json string by decoding it
eberrigan Sep 23, 2024
b9973c2
add test for decoded encoded skeleton
eberrigan Sep 24, 2024
4d77a68
update jupyter notebook for easy testing
eberrigan Sep 24, 2024
6c96a0e
constraining attrs in dev environment to make sure decode format is a…
eberrigan Sep 24, 2024
c78e9a9
encode links first then encode source then target then type
eberrigan Sep 24, 2024
5935b6a
save first enconding statically as an input to _get_or_assign_id so t…
eberrigan Sep 24, 2024
cab40c2
save first encoding statically
eberrigan Sep 24, 2024
2068fdd
first encoding is passed to _get_or_assign_id
eberrigan Sep 24, 2024
61c7cc0
use first_encoding variable to determine if we should assign a py/id
eberrigan Sep 24, 2024
c4f5be9
add print statements for debugging
eberrigan Sep 24, 2024
839d67c
update notebook for easy testing
eberrigan Sep 24, 2024
5bcea83
black
eberrigan Sep 24, 2024
167ef77
remove comment
eberrigan Sep 24, 2024
28f0c61
adding attrs constraint to show this passes for certain attrs version…
eberrigan Sep 24, 2024
fd14ad0
add import
eberrigan Sep 24, 2024
7fa9517
switch out jsonpickle.encode
eberrigan Sep 24, 2024
1d98177
oops remove import
eberrigan Sep 24, 2024
949fbe6
can attrs be unconstrained?
eberrigan Sep 24, 2024
6490d1f
forgot comma
eberrigan Sep 24, 2024
c199061
pin attrs for testing
eberrigan Sep 24, 2024
345dbc0
test Skeleton from json, template, with symmetries, and template
eberrigan Sep 24, 2024
4a8c326
use SkeletonEncoder.encode
eberrigan Sep 24, 2024
c57e64d
black
eberrigan Sep 24, 2024
4c8bdd6
try removing None values in EdgeType reduced
eberrigan Sep 24, 2024
6444378
Handle case when nodes are replaced by integer indices from caller
eberrigan Sep 25, 2024
c232ae6
Remove prototyping notebook
talmo Sep 25, 2024
2a56e88
Merge branch 'develop' into elizabeth/handle-skeleton-encoding-intern…
talmo Sep 25, 2024
5319e5a
Remove attrs pins
talmo Sep 25, 2024
b1d757b
Remove sort keys (which flips the neccessary ordering of our py/ids)
roomrys Sep 25, 2024
e7fb00a
Do not add extra indents to encoded file
roomrys Sep 25, 2024
743e406
Only append links after fully encoded (fat-finger)
roomrys Sep 25, 2024
02865f8
Remove outdated comment
roomrys Sep 25, 2024
83a2704
Lint
roomrys Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ channels:

dependencies:
# Packages SLEAP uses directly
- conda-forge::attrs >=21.2.0 #,<=21.4.0
- conda-forge::attrs >=21.2.0
- conda-forge::cattrs ==1.1.1
- conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg
- conda-forge::jsmin
Expand Down
2 changes: 1 addition & 1 deletion environment_no_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ channels:

dependencies:
# Packages SLEAP uses directly
- conda-forge::attrs >=21.2.0 #,<=21.4.0
- conda-forge::attrs >=21.2.0
- conda-forge::cattrs ==1.1.1
- conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg
- conda-forge::jsmin
Expand Down
198 changes: 193 additions & 5 deletions sleap/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,194 @@ def decode(cls, json_str: str) -> Dict:
return decoder._decode(json_str)


class SkeletonEncoder:
"""Replace jsonpickle.encode with our own encoder.

The input is a dictionary containing python objects that need to be encoded as
JSON strings. The output is a JSON string that represents the input dictionary.

`Node(name='neck', weight=1.0)` =>
{
"py/object": "sleap.Skeleton.Node",
"py/state": {"py/tuple" ["neck", 1.0]}
}

`<EdgeType.BODY: 1>` =>
{"py/reduce": [
{"py/type": "sleap.Skeleton.EdgeType"},
{"py/tuple": [1] }
]
}`

Where `name` and `weight` are the attributes of the `Node` class; weight is always 1.0.
`EdgeType` is an enum with values `BODY = 1` and `SYMMETRY = 2`.

See sleap.skeleton.Node and sleap.skeleton.EdgeType.

If the object has been "seen" before, it will not be encoded as the full JSON string
but referenced by its `py/id`, which starts at 1 and indexes the objects in the
order they are seen so that the second time the first object is used, it will be
referenced as `{"py/id": 1}`.
"""

def __init__(self):
"""Initializes a SkeletonEncoder instance."""
# Maps object id to py/id
self._encoded_objects: Dict[int, int] = {}

@classmethod
def encode(cls, data: Dict[str, Any]) -> str:
"""Encodes the input dictionary as a JSON string.

Args:
data: The data to encode.

Returns:
json_str: The JSON string representation of the data.
"""
encoder = cls()
encoded_data = encoder._encode(data)
json_str = json.dumps(encoded_data)
return json_str

def _encode(self, obj: Any) -> Any:
"""Recursively encodes the input object.

Args:
obj: The object to encode. Can be a dictionary, list, Node, EdgeType or
primitive data type.

Returns:
The encoded object as a dictionary.
"""
if isinstance(obj, dict):
encoded_obj = {}
for key, value in obj.items():
if key == "links":
encoded_obj[key] = self._encode_links(value)
else:
encoded_obj[key] = self._encode(value)
return encoded_obj
elif isinstance(obj, list):
return [self._encode(v) for v in obj]
elif isinstance(obj, EdgeType):
return self._encode_edge_type(obj)
elif isinstance(obj, Node):
return self._encode_node(obj)
else:
return obj # Primitive data types

def _encode_links(self, links: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Encodes the list of links (edges) in the skeleton graph.

Args:
links: A list of dictionaries, each representing an edge in the graph.

Returns:
A list of encoded edge dictionaries with keys ordered as specified.
"""
encoded_links = []
for link in links:
# Use a regular dict (insertion order preserved in Python 3.7+)
encoded_link = {}

# Encode in specific order of appearance
roomrys marked this conversation as resolved.
Show resolved Hide resolved
for key, value in link.items():
if key in ("source", "target"):
encoded_link[key] = self._encode_node(value)
elif key == "type":
encoded_link[key] = self._encode_edge_type(value)
else:
encoded_link[key] = self._encode(value)
encoded_links.append(encoded_link)

return encoded_links

def _encode_node(self, node: Union["Node", int]) -> Dict[str, Any]:
"""Encodes a Node object.

Args:
node: The Node object to encode or integer index. The latter requires that
the class has the `idx_to_node` attribute set.

Returns:
The encoded `Node` object as a dictionary.
"""
if isinstance(node, int):
# We sometimes have the node object already replaced by its index (when
# `node_to_idx` is provided). In this case, the node is already encoded.
return node

# Check if object has been encoded before
first_encoding = self._is_first_encoding(node)
py_id = self._get_or_assign_id(node, first_encoding)
if first_encoding:
# Full encoding
return {
"py/object": "sleap.skeleton.Node",
"py/state": {"py/tuple": [node.name, node.weight]},
}
else:
# Reference by py/id
return {"py/id": py_id}

def _encode_edge_type(self, edge_type: "EdgeType") -> Dict[str, Any]:
"""Encodes an EdgeType object.

Args:
edge_type: The EdgeType object to encode. Either `EdgeType.BODY` or
`EdgeType.SYMMETRY` enum with values 1 and 2 respectively.

Returns:
The encoded EdgeType object as a dictionary.
"""
# Check if object has been encoded before
first_encoding = self._is_first_encoding(edge_type)
py_id = self._get_or_assign_id(edge_type, first_encoding)
if first_encoding:
# Full encoding
return {
"py/reduce": [
{"py/type": "sleap.skeleton.EdgeType"},
{"py/tuple": [edge_type.value]},
]
}
else:
# Reference by py/id
return {"py/id": py_id}

def _get_or_assign_id(self, obj: Any, first_encoding: bool) -> int:
"""Gets or assigns a py/id for the object.

Args:
The object to get or assign a py/id for.

Returns:
The py/id assigned to the object.
"""
# Object id is unique for each object in the current session
obj_id = id(obj)
# Assign a py/id to the object if it hasn't been assigned one yet
if first_encoding:
py_id = len(self._encoded_objects) + 1 # py/id starts at 1
# Assign the py/id to the object and store it in _encoded_objects
self._encoded_objects[obj_id] = py_id
return self._encoded_objects[obj_id]

def _is_first_encoding(self, obj: Any) -> bool:
"""Checks if the object is being encoded for the first time.

Args:
obj: The object to check.

Returns:
True if this is the first encoding of the object, False otherwise.
"""
obj_id = id(obj)
first_time = obj_id not in self._encoded_objects
return first_time


class Skeleton:
"""The main object for representing animal skeletons.

Expand Down Expand Up @@ -1228,7 +1416,7 @@ def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> D

# This is a weird hack to serialize the whole _graph into a dict.
# I use the underlying to_json and parse it.
return json.loads(obj.to_json(node_to_idx))
return json.loads(obj.to_json(node_to_idx=node_to_idx))

@classmethod
def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton":
Expand Down Expand Up @@ -1292,10 +1480,10 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
"""
jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4)
if node_to_idx is not None:
indexed_node_graph = nx.relabel_nodes(
G=self._graph, mapping=node_to_idx
) # map nodes to int
# Map Nodes to int
indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx)
else:
# Keep graph nodes as Node objects
indexed_node_graph = self._graph

# Encode to JSON
Expand All @@ -1314,7 +1502,7 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
else:
data = graph

json_str = jsonpickle.encode(data)
json_str = SkeletonEncoder.encode(data)

return json_str

Expand Down
55 changes: 54 additions & 1 deletion tests/test_skeleton.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,62 @@
import copy
import os

import pytest
import json

from networkx.readwrite import json_graph
from sleap.skeleton import Skeleton, SkeletonDecoder
from sleap.skeleton import SkeletonEncoder


def test_decoded_encoded_Skeleton_from_load_json(fly_legs_skeleton_json):
"""
Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton.
"""
# Get the skeleton from the fixture
skeleton = Skeleton.load_json(fly_legs_skeleton_json)
# Get the graph from the skeleton
indexed_node_graph = skeleton._graph
graph = json_graph.node_link_data(indexed_node_graph)

# Encode the graph as a json string to test .encode method
encoded_json_str = SkeletonEncoder.encode(graph)

# Get the skeleton from the encoded json string
decoded_skeleton = Skeleton.from_json(encoded_json_str)

# Check that the decoded skeleton is the same as the original skeleton
assert skeleton.matches(decoded_skeleton)
Comment on lines +11 to +28
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance test coverage with additional assertions

The test function effectively verifies that the encoded and decoded Skeleton matches the original. However, we can improve it by adding more specific assertions:

  1. Assert that the encoded JSON string is not empty.
  2. Compare the number of nodes and edges in the original and decoded skeletons.
  3. Verify that the node names and edge connections are preserved.

Consider adding these assertions to strengthen the test:

def test_decoded_encoded_Skeleton_from_load_json(fly_legs_skeleton_json):
    skeleton = Skeleton.load_json(fly_legs_skeleton_json)
    indexed_node_graph = skeleton._graph
    graph = json_graph.node_link_data(indexed_node_graph)

    encoded_json_str = SkeletonEncoder.encode(graph)
    assert encoded_json_str, "Encoded JSON string should not be empty"

    decoded_skeleton = Skeleton.from_json(encoded_json_str)

    assert skeleton.matches(decoded_skeleton)
    assert len(skeleton.nodes) == len(decoded_skeleton.nodes), "Number of nodes should match"
    assert len(skeleton.edges) == len(decoded_skeleton.edges), "Number of edges should match"
    assert set(n.name for n in skeleton.nodes) == set(n.name for n in decoded_skeleton.nodes), "Node names should match"
    assert set(skeleton.edge_names) == set(decoded_skeleton.edge_names), "Edge connections should match"



@pytest.mark.parametrize(
"skeleton_fixture_name", ["flies13_skeleton", "skeleton", "stickman"]
)
def test_decoded_encoded_Skeleton(skeleton_fixture_name, request):
"""
Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton.
"""
# Use request.getfixturevalue to get the actual fixture value by name
skeleton = request.getfixturevalue(skeleton_fixture_name)

# Get the graph from the skeleton
indexed_node_graph = skeleton._graph
graph = json_graph.node_link_data(indexed_node_graph)

# Encode the graph as a json string to test .encode method
encoded_json_str = SkeletonEncoder.encode(graph)

# Get the skeleton from the encoded json string
decoded_skeleton = Skeleton.from_json(encoded_json_str)

# Check that the decoded skeleton is the same as the original skeleton
assert skeleton.matches(decoded_skeleton)

# Now make everything into a JSON string
skeleton_json_str = skeleton.to_json()
decoded_skeleton_json_str = decoded_skeleton.to_json()

# Check that the JSON strings are the same
assert json.loads(skeleton_json_str) == json.loads(decoded_skeleton_json_str)
Comment on lines +31 to +59
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance parameterized test with additional assertions and error messages

The test function effectively verifies that the encoded and decoded Skeleton matches the original across multiple fixtures. To further improve its robustness:

  1. Add assertions for the number of nodes and edges.
  2. Verify that node names and edge connections are preserved.
  3. Include more descriptive error messages in assertions.

Consider enhancing the test function as follows:

@pytest.mark.parametrize(
    "skeleton_fixture_name", ["flies13_skeleton", "skeleton", "stickman"]
)
def test_decoded_encoded_Skeleton(skeleton_fixture_name, request):
    skeleton = request.getfixturevalue(skeleton_fixture_name)
    indexed_node_graph = skeleton._graph
    graph = json_graph.node_link_data(indexed_node_graph)

    encoded_json_str = SkeletonEncoder.encode(graph)
    assert encoded_json_str, f"Encoded JSON string for {skeleton_fixture_name} should not be empty"

    decoded_skeleton = Skeleton.from_json(encoded_json_str)

    assert skeleton.matches(decoded_skeleton), f"Decoded {skeleton_fixture_name} should match the original"
    assert len(skeleton.nodes) == len(decoded_skeleton.nodes), f"Number of nodes in {skeleton_fixture_name} should match"
    assert len(skeleton.edges) == len(decoded_skeleton.edges), f"Number of edges in {skeleton_fixture_name} should match"
    assert set(n.name for n in skeleton.nodes) == set(n.name for n in decoded_skeleton.nodes), f"Node names in {skeleton_fixture_name} should match"
    assert set(skeleton.edge_names) == set(decoded_skeleton.edge_names), f"Edge connections in {skeleton_fixture_name} should match"

    skeleton_json_str = skeleton.to_json()
    decoded_skeleton_json_str = decoded_skeleton.to_json()

    assert json.loads(skeleton_json_str) == json.loads(decoded_skeleton_json_str), f"JSON representations of {skeleton_fixture_name} should match"

These changes will provide more detailed information if a test fails, making it easier to identify and fix issues.



def test_add_dupe_node(skeleton):
Expand Down
Loading