Skip to content

Commit

Permalink
add node and edge metadata bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanhhughes committed Nov 15, 2024
1 parent a9ee4c2 commit ee4fb45
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 17 deletions.
21 changes: 20 additions & 1 deletion python/bindings/src/spark_dsg_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,16 @@ PYBIND11_MODULE(_dsg_bindings, module) {
.def_readwrite("last_update_time_ns", &NodeAttributes::last_update_time_ns)
.def_readwrite("is_active", &NodeAttributes::is_active)
.def_readwrite("is_predicted", &NodeAttributes::is_predicted)
.def("_get_metadata",
[](const NodeAttributes& node) {
std::stringstream ss;
ss << node.metadata;
return ss.str();
})
.def("_set_metadata",
[](NodeAttributes& node, const std::string& data) {
node.metadata = nlohmann::json::parse(data);
})
.def("__repr__", [](const NodeAttributes& attrs) {
std::stringstream ss;
ss << attrs;
Expand Down Expand Up @@ -386,7 +396,16 @@ PYBIND11_MODULE(_dsg_bindings, module) {
py::class_<EdgeAttributes>(module, "EdgeAttributes")
.def(py::init<>())
.def_readwrite("weighted", &EdgeAttributes::weighted)
.def_readwrite("weight", &EdgeAttributes::weight);
.def_readwrite("weight", &EdgeAttributes::weight)
.def("_get_metadata",
[](const EdgeAttributes& edge) {
std::stringstream ss;
ss << edge.metadata;
return ss.str();
})
.def("_set_metadata", [](EdgeAttributes& edge, const std::string& data) {
edge.metadata = nlohmann::json::parse(data);
});

/**************************************************************************************
* Scene graph node
Expand Down
31 changes: 19 additions & 12 deletions python/src/spark_dsg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@

from spark_dsg._dsg_bindings import *
from spark_dsg._dsg_bindings import (BoundingBoxType, DsgLayers,
DynamicSceneGraph, LayerView,
DynamicSceneGraph, EdgeAttributes,
LayerView, NodeAttributes,
SceneGraphLayer,
compute_ancestor_bounding_box)
from spark_dsg.open3d_visualization import render_to_open3d
Expand Down Expand Up @@ -68,17 +69,17 @@ def add_bounding_boxes_to_layer(
node.attributes.bounding_box = bbox


def _get_metadata(G):
def _get_metadata(obj):
"""Get graph metadata."""
data_str = G._get_metadata()
data_str = obj._get_metadata()
metadata = json.loads(data_str)
metadata = dict() if metadata is None else metadata
return types.MappingProxyType(metadata)


def _set_metadata(G, obj):
def _set_metadata(obj, data):
"""Serialize and set graph metadata."""
G._set_metadata(json.dumps(obj))
obj._set_metadata(json.dumps(data))


def _update_nested(contents, other):
Expand All @@ -92,18 +93,24 @@ def _update_nested(contents, other):
contents[key] = value


def _add_metadata(G, obj):
def _add_metadata(obj, data):
"""Serialize and update metadata from passed object."""
data_str = G._get_metadata()
data_str = obj._get_metadata()
metadata = json.loads(data_str)
metadata = dict() if metadata is None else metadata
_update_nested(metadata, obj)
G._set_metadata(json.dumps(metadata))
_update_nested(metadata, data)
obj._set_metadata(json.dumps(metadata))


DynamicSceneGraph.metadata = property(_get_metadata)
DynamicSceneGraph.set_metadata = _set_metadata
DynamicSceneGraph.add_metadata = _add_metadata
def _add_metadata_interface(obj):
obj.metadata = property(_get_metadata)
obj.set_metadata = _set_metadata
obj.add_metadata = _add_metadata


_add_metadata_interface(DynamicSceneGraph)
_add_metadata_interface(NodeAttributes)
_add_metadata_interface(EdgeAttributes)

DynamicSceneGraph.to_torch = scene_graph_to_torch
SceneGraphLayer.to_torch = scene_graph_layer_to_torch
Expand Down
21 changes: 17 additions & 4 deletions python/tests/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
#
#
"""Test that the bindings are working appropriately."""
import json

import numpy as np
import spark_dsg as dsg

Expand Down Expand Up @@ -307,7 +305,7 @@ def test_node_counts(resource_dir):


def test_graph_metadata(tmp_path):
"""Test that graph metadat works as expected."""
"""Test that graph metadata works as expected."""
G = dsg.DynamicSceneGraph()
G.add_metadata({"foo": 5})
G.add_metadata({"bar": [1, 2, 3, 4, 5]})
Expand All @@ -323,9 +321,24 @@ def test_graph_metadata(tmp_path):
}

G.add_metadata({"something": {"b": 643.0, "other": "foo"}})
print(G.metadata)
assert G.metadata == {
"foo": 5,
"bar": [1, 2, 3, 4, 5],
"something": {"a": 13, "b": 643.0, "c": "world", "other": "foo"},
}


def test_attribute_metadata(tmp_path):
"""Test that attribute metadata works as expected."""
G = dsg.DynamicSceneGraph()

attrs = dsg.ObjectNodeAttributes()
attrs.add_metadata({"test": {"a": 5, "c": "hello"}})
attrs.add_metadata({"test": {"a": 6, "b": 42.0}})
G.add_node(dsg.DsgLayers.OBJECTS, dsg.NodeSymbol("O", 1), attrs)

graph_path = tmp_path / "graph.json"
G.save(graph_path)
G_new = dsg.DynamicSceneGraph.load(graph_path)
new_attrs = G_new.get_node(dsg.NodeSymbol("O", 1)).attributes
assert new_attrs.metadata == {"test": {"a": 6, "b": 42.0, "c": "hello"}}

0 comments on commit ee4fb45

Please sign in to comment.