diff --git a/pycolmap/geometry/bindings.h b/pycolmap/geometry/bindings.h index 4ad3b6a..6db1746 100644 --- a/pycolmap/geometry/bindings.h +++ b/pycolmap/geometry/bindings.h @@ -19,15 +19,19 @@ using namespace pybind11::literals; void BindGeometry(py::module& m) { BindHomographyGeometry(m); - py::class_(m, "Rotation3d") - .def(py::init([]() { return Eigen::Quaterniond::Identity(); })) + py::class_ PyRotation3d(m, "Rotation3d"); + PyRotation3d.def(py::init([]() { return Eigen::Quaterniond::Identity(); })) .def(py::init(), "xyzw"_a) .def(py::init(), "rotmat"_a) .def(py::self * Eigen::Quaterniond()) .def(py::self * Eigen::Vector3d()) + .def_property("quat", + py::overload_cast<>(&Eigen::Quaterniond::coeffs), + [](Eigen::Quaterniond& self, const Eigen::Vector4d& quat) { + self.coeffs() = quat; + }) .def("normalize", &Eigen::Quaterniond::normalize) .def("matrix", &Eigen::Quaterniond::toRotationMatrix) - .def("quat", py::overload_cast<>(&Eigen::Quaterniond::coeffs)) .def("norm", &Eigen::Quaterniond::norm) .def("inverse", &Eigen::Quaterniond::inverse) .def("__repr__", [](const Eigen::Quaterniond& self) { @@ -36,16 +40,17 @@ void BindGeometry(py::module& m) { return ss.str(); }); py::implicitly_convertible(); + MakeDataclass(PyRotation3d); - py::class_(m, "Rigid3d") - .def(py::init<>()) + py::class_ PyRigid3d(m, "Rigid3d"); + PyRigid3d.def(py::init<>()) .def(py::init()) .def(py::init([](const Eigen::Matrix3x4d& matrix) { return Rigid3d(Eigen::Quaterniond(matrix.leftCols<3>()), matrix.col(3)); })) .def_readwrite("rotation", &Rigid3d::rotation) .def_readwrite("translation", &Rigid3d::translation) - .def_property_readonly("matrix", &Rigid3d::ToMatrix) + .def("matrix", &Rigid3d::ToMatrix) .def(py::self * Eigen::Vector3d()) .def(py::self * Rigid3d()) .def("inverse", static_cast(&Inverse)) @@ -58,16 +63,17 @@ void BindGeometry(py::module& m) { return ss.str(); }); py::implicitly_convertible(); + MakeDataclass(PyRigid3d); - py::class_(m, "Sim3d") - .def(py::init<>()) + py::class_ PySim3d(m, "Sim3d"); + PySim3d.def(py::init<>()) .def( py::init()) .def(py::init(&Sim3d::FromMatrix)) .def_readwrite("scale", &Sim3d::scale) .def_readwrite("rotation", &Sim3d::rotation) .def_readwrite("translation", &Sim3d::translation) - .def_property_readonly("matrix", &Sim3d::ToMatrix) + .def("matrix", &Sim3d::ToMatrix) .def(py::self * Eigen::Vector3d()) .def(py::self * Sim3d()) .def("transform_camera_world", &TransformCameraWorld) @@ -81,4 +87,5 @@ void BindGeometry(py::module& m) { return ss.str(); }); py::implicitly_convertible(); + MakeDataclass(PySim3d); } diff --git a/pycolmap/helpers.h b/pycolmap/helpers.h index c4709db..c726a95 100644 --- a/pycolmap/helpers.h +++ b/pycolmap/helpers.h @@ -31,7 +31,7 @@ const Eigen::IOFormat vec_fmt(Eigen::StreamPrecision, ", "); template -inline T pyStringToEnum(const py::enum_& enm, const std::string& value) { +T pyStringToEnum(const py::enum_& enm, const std::string& value) { const auto values = enm.attr("__members__").template cast(); const auto str_val = py::str(value); if (values.contains(str_val)) { @@ -45,14 +45,14 @@ inline T pyStringToEnum(const py::enum_& enm, const std::string& value) { } template -inline void AddStringToEnumConstructor(py::enum_& enm) { +void AddStringToEnumConstructor(py::enum_& enm) { enm.def(py::init([enm](const std::string& value) { return pyStringToEnum(enm, py::str(value)); // str constructor })); py::implicitly_convertible(); } -inline void UpdateFromDict(py::object& self, const py::dict& dict) { +void UpdateFromDict(py::object& self, const py::dict& dict) { for (const auto& it : dict) { if (!py::isinstance(it.first)) { const std::string msg = "Dictionary key is not a string: " + @@ -125,34 +125,47 @@ inline void UpdateFromDict(py::object& self, const py::dict& dict) { } } -bool AttributeIsFunction(const std::string& name, const py::object& attribute) { +bool AttributeIsFunction(const std::string& name, const py::object& value) { return (name.find("__") == 0 || name.rfind("__") != std::string::npos || - py::hasattr(attribute, "__func__") || - py::hasattr(attribute, "__call__")); + py::hasattr(value, "__func__") || py::hasattr(value, "__call__")); } -template -inline py::dict ConvertToDict(const T& self) { - const auto pyself = py::cast(self); - py::dict dict; +std::vector ListObjectAttributes(const py::object& pyself) { + std::vector attributes; for (const auto& handle : pyself.attr("__dir__")()) { - const py::str name = py::reinterpret_borrow(handle); - const auto attribute = pyself.attr(name); - if (AttributeIsFunction(name, attribute)) { + const py::str attribute = py::reinterpret_borrow(handle); + const auto value = pyself.attr(attribute); + if (AttributeIsFunction(attribute, value)) { continue; } - if (py::hasattr(attribute, "todict")) { - dict[name] = - attribute.attr("todict").attr("__call__")().template cast(); + attributes.push_back(attribute); + } + return attributes; +} + +template +py::dict ConvertToDict(const T& self, + std::vector attributes, + const bool recursive) { + const py::object pyself = py::cast(self); + if (attributes.empty()) { + attributes = ListObjectAttributes(pyself); + } + py::dict dict; + for (const auto& attr : attributes) { + const auto value = pyself.attr(attr.c_str()); + if (recursive && py::hasattr(value, "todict")) { + dict[attr.c_str()] = + value.attr("todict").attr("__call__")().template cast(); } else { - dict[name] = attribute; + dict[attr.c_str()] = value; } } return dict; } template -inline std::string CreateSummary(const T& self, bool write_type) { +std::string CreateSummary(const T& self, bool write_type) { std::stringstream ss; auto pyself = py::cast(self); const std::string prefix = " "; @@ -230,25 +243,45 @@ void AddDefaultsToDocstrings(py::class_ cls) { } template -inline void MakeDataclass(py::class_ cls) { +void MakeDataclass(py::class_ cls, + const std::vector& attributes = {}) { AddDefaultsToDocstrings(cls); - cls.def("mergedict", &UpdateFromDict); if (!py::hasattr(cls, "summary")) { cls.def("summary", &CreateSummary, "write_type"_a = false); } - cls.def("todict", &ConvertToDict); + cls.def("mergedict", &UpdateFromDict); + cls.def( + "todict", + [attributes](const T& self, const bool recursive) { + return ConvertToDict(self, attributes, recursive); + }, + "recursive"_a = true); + cls.def(py::init([cls](const py::dict& dict) { - auto self = py::object(cls()); + py::object self = cls(); self.attr("mergedict").attr("__call__")(dict); return self.cast(); })); cls.def(py::init([cls](const py::kwargs& kwargs) { py::dict dict = kwargs.cast(); - auto self = py::object(cls(dict)); - return self.cast(); + return cls(dict).template cast(); })); py::implicitly_convertible(); py::implicitly_convertible(); + + cls.def("__copy__", [](const T& self) { return T(self); }); + cls.def("__deepcopy__", + [](const T& self, const py::dict&) { return T(self); }); + + cls.def(py::pickle( + [attributes](const T& self) { + return ConvertToDict(self, attributes, /*recursive=*/false); + }, + [cls](const py::dict& dict) { + py::object self = cls(); + self.attr("mergedict").attr("__call__")(dict); + return self.cast(); + })); } // Catch python keyboard interrupts diff --git a/pycolmap/scene/camera.h b/pycolmap/scene/camera.h index cbe8c2b..882dedd 100644 --- a/pycolmap/scene/camera.h +++ b/pycolmap/scene/camera.h @@ -195,9 +195,12 @@ void BindCamera(py::module& m) { "Rescale camera dimensions by given factor and accordingly the " "focal length and\n" "and the principal point.") - .def("__copy__", [](const Camera& self) { return Camera(self); }) - .def("__deepcopy__", - [](const Camera& self, const py::dict&) { return Camera(self); }) .def("__repr__", &PrintCamera); - MakeDataclass(PyCamera); + MakeDataclass(PyCamera, + {"camera_id", + "model", + "width", + "height", + "params", + "has_prior_focal_length"}); } diff --git a/pycolmap/scene/image.h b/pycolmap/scene/image.h index 0d4296c..93a3d69 100644 --- a/pycolmap/scene/image.h +++ b/pycolmap/scene/image.h @@ -182,9 +182,9 @@ void BindImage(py::module& m) { &Image::IsRegistered, &Image::SetRegistered, "Whether image is registered in the reconstruction.") - .def_property_readonly("num_points2D", - &Image::NumPoints2D, - "Get the number of image points (keypoints).") + .def("num_points2D", + &Image::NumPoints2D, + "Get the number of image points (keypoints).") .def_property_readonly( "num_points3D", &Image::NumPoints3D, @@ -289,9 +289,6 @@ void BindImage(py::module& m) { }, "Project list of image points (with depth) to world coordinate " "frame.") - .def("__copy__", [](const Image& self) { return Image(self); }) - .def("__deepcopy__", - [](const Image& self, const py::dict&) { return Image(self); }) .def("__repr__", &PrintImage); MakeDataclass(PyImage); } diff --git a/pycolmap/scene/point2D.h b/pycolmap/scene/point2D.h index 3760b30..5d4595f 100644 --- a/pycolmap/scene/point2D.h +++ b/pycolmap/scene/point2D.h @@ -56,9 +56,6 @@ void BindPoint2D(py::module& m) { .def_readwrite("xy", &Point2D::xy) .def_readwrite("point3D_id", &Point2D::point3D_id) .def("has_point3D", &Point2D::HasPoint3D) - .def("__copy__", [](const Point2D& self) { return Point2D(self); }) - .def("__deepcopy__", - [](const Point2D& self, const py::dict&) { return Point2D(self); }) .def("__repr__", &PrintPoint2D); MakeDataclass(PyPoint2D); } diff --git a/pycolmap/scene/point3D.h b/pycolmap/scene/point3D.h index d008613..97df736 100644 --- a/pycolmap/scene/point3D.h +++ b/pycolmap/scene/point3D.h @@ -34,9 +34,6 @@ void BindPoint3D(py::module& m) { .def_readwrite("color", &Point3D::color) .def_readwrite("error", &Point3D::error) .def_readwrite("track", &Point3D::track) - .def("__copy__", [](const Point3D& self) { return Point3D(self); }) - .def("__deepcopy__", - [](const Point3D& self, const py::dict&) { return Point3D(self); }) .def("__repr__", [](const Point3D& self) { std::stringstream ss; ss << "Point3D(xyz=[" << self.xyz.format(vec_fmt) << "], color=[" diff --git a/pycolmap/scene/track.h b/pycolmap/scene/track.h index 1849c14..04b57b7 100644 --- a/pycolmap/scene/track.h +++ b/pycolmap/scene/track.h @@ -4,6 +4,7 @@ #include "colmap/util/misc.h" #include "colmap/util/types.h" +#include "pycolmap/helpers.h" #include "pycolmap/log_exceptions.h" #include @@ -18,24 +19,20 @@ using namespace pybind11::literals; namespace py = pybind11; void BindTrack(py::module& m) { - py::class_>(m, "TrackElement") - .def(py::init<>()) + py::class_> PyTrackElement( + m, "TrackElement"); + PyTrackElement.def(py::init<>()) .def(py::init()) .def_readwrite("image_id", &TrackElement::image_id) .def_readwrite("point2D_idx", &TrackElement::point2D_idx) - .def("__copy__", - [](const TrackElement& self) { return TrackElement(self); }) - .def("__deepcopy__", - [](const TrackElement& self, const py::dict&) { - return TrackElement(self); - }) .def("__repr__", [](const TrackElement& self) { return "TrackElement(image_id=" + std::to_string(self.image_id) + ", point2D_idx=" + std::to_string(self.point2D_idx) + ")"; }); + MakeDataclass(PyTrackElement); - py::class_>(m, "Track") - .def(py::init<>()) + py::class_> PyTrack(m, "Track"); + PyTrack.def(py::init<>()) .def(py::init([](const std::vector& elements) { auto track = std::make_shared(); track->AddElements(elements); @@ -67,10 +64,8 @@ void BindTrack(py::module& m) { py::overload_cast( &Track::DeleteElement), "Remove TrackElement with (image_id,point2D_idx).") - .def("__copy__", [](const Track& self) { return Track(self); }) - .def("__deepcopy__", - [](const Track& self, const py::dict&) { return Track(self); }) .def("__repr__", [](const Track& self) { return "Track(length=" + std::to_string(self.Length()) + ")"; }); + MakeDataclass(PyTrack); }