Skip to content

Commit

Permalink
Enable pickling of objects (#243)
Browse files Browse the repository at this point in the history
* Cleanup helpers
* Optionally pass attribute list for dict export
* Geometry and track objects become dataclasses
* Change properties to/from functions
* Pickling
* Move copy bindings to helper
  • Loading branch information
sarlinpe authored Jan 23, 2024
1 parent 0c064ad commit 6f0096d
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 62 deletions.
25 changes: 16 additions & 9 deletions pycolmap/geometry/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@ using namespace pybind11::literals;
void BindGeometry(py::module& m) {
BindHomographyGeometry(m);

py::class_<Eigen::Quaterniond>(m, "Rotation3d")
.def(py::init([]() { return Eigen::Quaterniond::Identity(); }))
py::class_<Eigen::Quaterniond> PyRotation3d(m, "Rotation3d");
PyRotation3d.def(py::init([]() { return Eigen::Quaterniond::Identity(); }))
.def(py::init<const Eigen::Vector4d&>(), "xyzw"_a)
.def(py::init<const Eigen::Matrix3d&>(), "rotmat"_a)
.def(py::self * Eigen::Quaterniond())
.def(py::self * Eigen::Vector3d())
.def_property("quat",
py::overload_cast<>(&Eigen::Quaterniond::coeffs),
[](Eigen::Quaterniond& self, const Eigen::Vector4d& quat) {
self.coeffs() = quat;
})
.def("normalize", &Eigen::Quaterniond::normalize)
.def("matrix", &Eigen::Quaterniond::toRotationMatrix)
.def("quat", py::overload_cast<>(&Eigen::Quaterniond::coeffs))
.def("norm", &Eigen::Quaterniond::norm)
.def("inverse", &Eigen::Quaterniond::inverse)
.def("__repr__", [](const Eigen::Quaterniond& self) {
Expand All @@ -36,16 +40,17 @@ void BindGeometry(py::module& m) {
return ss.str();
});
py::implicitly_convertible<py::array, Eigen::Quaterniond>();
MakeDataclass(PyRotation3d);

py::class_<Rigid3d>(m, "Rigid3d")
.def(py::init<>())
py::class_<Rigid3d> PyRigid3d(m, "Rigid3d");
PyRigid3d.def(py::init<>())
.def(py::init<const Eigen::Quaterniond&, const Eigen::Vector3d&>())
.def(py::init([](const Eigen::Matrix3x4d& matrix) {
return Rigid3d(Eigen::Quaterniond(matrix.leftCols<3>()), matrix.col(3));
}))
.def_readwrite("rotation", &Rigid3d::rotation)
.def_readwrite("translation", &Rigid3d::translation)
.def_property_readonly("matrix", &Rigid3d::ToMatrix)
.def("matrix", &Rigid3d::ToMatrix)
.def(py::self * Eigen::Vector3d())
.def(py::self * Rigid3d())
.def("inverse", static_cast<Rigid3d (*)(const Rigid3d&)>(&Inverse))
Expand All @@ -58,16 +63,17 @@ void BindGeometry(py::module& m) {
return ss.str();
});
py::implicitly_convertible<py::array, Rigid3d>();
MakeDataclass(PyRigid3d);

py::class_<Sim3d>(m, "Sim3d")
.def(py::init<>())
py::class_<Sim3d> PySim3d(m, "Sim3d");
PySim3d.def(py::init<>())
.def(
py::init<double, const Eigen::Quaterniond&, const Eigen::Vector3d&>())
.def(py::init(&Sim3d::FromMatrix))
.def_readwrite("scale", &Sim3d::scale)
.def_readwrite("rotation", &Sim3d::rotation)
.def_readwrite("translation", &Sim3d::translation)
.def_property_readonly("matrix", &Sim3d::ToMatrix)
.def("matrix", &Sim3d::ToMatrix)
.def(py::self * Eigen::Vector3d())
.def(py::self * Sim3d())
.def("transform_camera_world", &TransformCameraWorld)
Expand All @@ -81,4 +87,5 @@ void BindGeometry(py::module& m) {
return ss.str();
});
py::implicitly_convertible<py::array, Sim3d>();
MakeDataclass(PySim3d);
}
81 changes: 57 additions & 24 deletions pycolmap/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const Eigen::IOFormat vec_fmt(Eigen::StreamPrecision,
", ");

template <typename T>
inline T pyStringToEnum(const py::enum_<T>& enm, const std::string& value) {
T pyStringToEnum(const py::enum_<T>& enm, const std::string& value) {
const auto values = enm.attr("__members__").template cast<py::dict>();
const auto str_val = py::str(value);
if (values.contains(str_val)) {
Expand All @@ -45,14 +45,14 @@ inline T pyStringToEnum(const py::enum_<T>& enm, const std::string& value) {
}

template <typename T>
inline void AddStringToEnumConstructor(py::enum_<T>& enm) {
void AddStringToEnumConstructor(py::enum_<T>& enm) {
enm.def(py::init([enm](const std::string& value) {
return pyStringToEnum(enm, py::str(value)); // str constructor
}));
py::implicitly_convertible<std::string, T>();
}

inline void UpdateFromDict(py::object& self, const py::dict& dict) {
void UpdateFromDict(py::object& self, const py::dict& dict) {
for (const auto& it : dict) {
if (!py::isinstance<py::str>(it.first)) {
const std::string msg = "Dictionary key is not a string: " +
Expand Down Expand Up @@ -125,34 +125,47 @@ inline void UpdateFromDict(py::object& self, const py::dict& dict) {
}
}

bool AttributeIsFunction(const std::string& name, const py::object& attribute) {
bool AttributeIsFunction(const std::string& name, const py::object& value) {
return (name.find("__") == 0 || name.rfind("__") != std::string::npos ||
py::hasattr(attribute, "__func__") ||
py::hasattr(attribute, "__call__"));
py::hasattr(value, "__func__") || py::hasattr(value, "__call__"));
}

template <typename T, typename... options>
inline py::dict ConvertToDict(const T& self) {
const auto pyself = py::cast(self);
py::dict dict;
std::vector<std::string> ListObjectAttributes(const py::object& pyself) {
std::vector<std::string> attributes;
for (const auto& handle : pyself.attr("__dir__")()) {
const py::str name = py::reinterpret_borrow<py::str>(handle);
const auto attribute = pyself.attr(name);
if (AttributeIsFunction(name, attribute)) {
const py::str attribute = py::reinterpret_borrow<py::str>(handle);
const auto value = pyself.attr(attribute);
if (AttributeIsFunction(attribute, value)) {
continue;
}
if (py::hasattr(attribute, "todict")) {
dict[name] =
attribute.attr("todict").attr("__call__")().template cast<py::dict>();
attributes.push_back(attribute);
}
return attributes;
}

template <typename T, typename... options>
py::dict ConvertToDict(const T& self,
std::vector<std::string> attributes,
const bool recursive) {
const py::object pyself = py::cast(self);
if (attributes.empty()) {
attributes = ListObjectAttributes(pyself);
}
py::dict dict;
for (const auto& attr : attributes) {
const auto value = pyself.attr(attr.c_str());
if (recursive && py::hasattr(value, "todict")) {
dict[attr.c_str()] =
value.attr("todict").attr("__call__")().template cast<py::dict>();
} else {
dict[name] = attribute;
dict[attr.c_str()] = value;
}
}
return dict;
}

template <typename T, typename... options>
inline std::string CreateSummary(const T& self, bool write_type) {
std::string CreateSummary(const T& self, bool write_type) {
std::stringstream ss;
auto pyself = py::cast(self);
const std::string prefix = " ";
Expand Down Expand Up @@ -230,25 +243,45 @@ void AddDefaultsToDocstrings(py::class_<T, options...> cls) {
}

template <typename T, typename... options>
inline void MakeDataclass(py::class_<T, options...> cls) {
void MakeDataclass(py::class_<T, options...> cls,
const std::vector<std::string>& attributes = {}) {
AddDefaultsToDocstrings(cls);
cls.def("mergedict", &UpdateFromDict);
if (!py::hasattr(cls, "summary")) {
cls.def("summary", &CreateSummary<T>, "write_type"_a = false);
}
cls.def("todict", &ConvertToDict<T>);
cls.def("mergedict", &UpdateFromDict);
cls.def(
"todict",
[attributes](const T& self, const bool recursive) {
return ConvertToDict(self, attributes, recursive);
},
"recursive"_a = true);

cls.def(py::init([cls](const py::dict& dict) {
auto self = py::object(cls());
py::object self = cls();
self.attr("mergedict").attr("__call__")(dict);
return self.cast<T>();
}));
cls.def(py::init([cls](const py::kwargs& kwargs) {
py::dict dict = kwargs.cast<py::dict>();
auto self = py::object(cls(dict));
return self.cast<T>();
return cls(dict).template cast<T>();
}));
py::implicitly_convertible<py::dict, T>();
py::implicitly_convertible<py::kwargs, T>();

cls.def("__copy__", [](const T& self) { return T(self); });
cls.def("__deepcopy__",
[](const T& self, const py::dict&) { return T(self); });

cls.def(py::pickle(
[attributes](const T& self) {
return ConvertToDict(self, attributes, /*recursive=*/false);
},
[cls](const py::dict& dict) {
py::object self = cls();
self.attr("mergedict").attr("__call__")(dict);
return self.cast<T>();
}));
}

// Catch python keyboard interrupts
Expand Down
11 changes: 7 additions & 4 deletions pycolmap/scene/camera.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,12 @@ void BindCamera(py::module& m) {
"Rescale camera dimensions by given factor and accordingly the "
"focal length and\n"
"and the principal point.")
.def("__copy__", [](const Camera& self) { return Camera(self); })
.def("__deepcopy__",
[](const Camera& self, const py::dict&) { return Camera(self); })
.def("__repr__", &PrintCamera);
MakeDataclass(PyCamera);
MakeDataclass(PyCamera,
{"camera_id",
"model",
"width",
"height",
"params",
"has_prior_focal_length"});
}
9 changes: 3 additions & 6 deletions pycolmap/scene/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ void BindImage(py::module& m) {
&Image::IsRegistered,
&Image::SetRegistered,
"Whether image is registered in the reconstruction.")
.def_property_readonly("num_points2D",
&Image::NumPoints2D,
"Get the number of image points (keypoints).")
.def("num_points2D",
&Image::NumPoints2D,
"Get the number of image points (keypoints).")
.def_property_readonly(
"num_points3D",
&Image::NumPoints3D,
Expand Down Expand Up @@ -289,9 +289,6 @@ void BindImage(py::module& m) {
},
"Project list of image points (with depth) to world coordinate "
"frame.")
.def("__copy__", [](const Image& self) { return Image(self); })
.def("__deepcopy__",
[](const Image& self, const py::dict&) { return Image(self); })
.def("__repr__", &PrintImage);
MakeDataclass(PyImage);
}
3 changes: 0 additions & 3 deletions pycolmap/scene/point2D.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ void BindPoint2D(py::module& m) {
.def_readwrite("xy", &Point2D::xy)
.def_readwrite("point3D_id", &Point2D::point3D_id)
.def("has_point3D", &Point2D::HasPoint3D)
.def("__copy__", [](const Point2D& self) { return Point2D(self); })
.def("__deepcopy__",
[](const Point2D& self, const py::dict&) { return Point2D(self); })
.def("__repr__", &PrintPoint2D);
MakeDataclass(PyPoint2D);
}
3 changes: 0 additions & 3 deletions pycolmap/scene/point3D.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ void BindPoint3D(py::module& m) {
.def_readwrite("color", &Point3D::color)
.def_readwrite("error", &Point3D::error)
.def_readwrite("track", &Point3D::track)
.def("__copy__", [](const Point3D& self) { return Point3D(self); })
.def("__deepcopy__",
[](const Point3D& self, const py::dict&) { return Point3D(self); })
.def("__repr__", [](const Point3D& self) {
std::stringstream ss;
ss << "Point3D(xyz=[" << self.xyz.format(vec_fmt) << "], color=["
Expand Down
21 changes: 8 additions & 13 deletions pycolmap/scene/track.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "colmap/util/misc.h"
#include "colmap/util/types.h"

#include "pycolmap/helpers.h"
#include "pycolmap/log_exceptions.h"

#include <memory>
Expand All @@ -18,24 +19,20 @@ using namespace pybind11::literals;
namespace py = pybind11;

void BindTrack(py::module& m) {
py::class_<TrackElement, std::shared_ptr<TrackElement>>(m, "TrackElement")
.def(py::init<>())
py::class_<TrackElement, std::shared_ptr<TrackElement>> PyTrackElement(
m, "TrackElement");
PyTrackElement.def(py::init<>())
.def(py::init<image_t, point2D_t>())
.def_readwrite("image_id", &TrackElement::image_id)
.def_readwrite("point2D_idx", &TrackElement::point2D_idx)
.def("__copy__",
[](const TrackElement& self) { return TrackElement(self); })
.def("__deepcopy__",
[](const TrackElement& self, const py::dict&) {
return TrackElement(self);
})
.def("__repr__", [](const TrackElement& self) {
return "TrackElement(image_id=" + std::to_string(self.image_id) +
", point2D_idx=" + std::to_string(self.point2D_idx) + ")";
});
MakeDataclass(PyTrackElement);

py::class_<Track, std::shared_ptr<Track>>(m, "Track")
.def(py::init<>())
py::class_<Track, std::shared_ptr<Track>> PyTrack(m, "Track");
PyTrack.def(py::init<>())
.def(py::init([](const std::vector<TrackElement>& elements) {
auto track = std::make_shared<Track>();
track->AddElements(elements);
Expand Down Expand Up @@ -67,10 +64,8 @@ void BindTrack(py::module& m) {
py::overload_cast<const image_t, const point2D_t>(
&Track::DeleteElement),
"Remove TrackElement with (image_id,point2D_idx).")
.def("__copy__", [](const Track& self) { return Track(self); })
.def("__deepcopy__",
[](const Track& self, const py::dict&) { return Track(self); })
.def("__repr__", [](const Track& self) {
return "Track(length=" + std::to_string(self.Length()) + ")";
});
MakeDataclass(PyTrack);
}

0 comments on commit 6f0096d

Please sign in to comment.