Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable pickling of objects #243

Merged
merged 7 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions pycolmap/geometry/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@ using namespace pybind11::literals;
void BindGeometry(py::module& m) {
BindHomographyGeometry(m);

py::class_<Eigen::Quaterniond>(m, "Rotation3d")
.def(py::init([]() { return Eigen::Quaterniond::Identity(); }))
py::class_<Eigen::Quaterniond> PyRotation3d(m, "Rotation3d");
PyRotation3d.def(py::init([]() { return Eigen::Quaterniond::Identity(); }))
.def(py::init<const Eigen::Vector4d&>(), "xyzw"_a)
.def(py::init<const Eigen::Matrix3d&>(), "rotmat"_a)
.def(py::self * Eigen::Quaterniond())
.def(py::self * Eigen::Vector3d())
.def_property("quat",
py::overload_cast<>(&Eigen::Quaterniond::coeffs),
[](Eigen::Quaterniond& self, const Eigen::Vector4d& quat) {
self.coeffs() = quat;
})
.def("normalize", &Eigen::Quaterniond::normalize)
.def("matrix", &Eigen::Quaterniond::toRotationMatrix)
.def("quat", py::overload_cast<>(&Eigen::Quaterniond::coeffs))
.def("norm", &Eigen::Quaterniond::norm)
.def("inverse", &Eigen::Quaterniond::inverse)
.def("__repr__", [](const Eigen::Quaterniond& self) {
Expand All @@ -36,16 +40,17 @@ void BindGeometry(py::module& m) {
return ss.str();
});
py::implicitly_convertible<py::array, Eigen::Quaterniond>();
MakeDataclass(PyRotation3d);

py::class_<Rigid3d>(m, "Rigid3d")
.def(py::init<>())
py::class_<Rigid3d> PyRigid3d(m, "Rigid3d");
PyRigid3d.def(py::init<>())
.def(py::init<const Eigen::Quaterniond&, const Eigen::Vector3d&>())
.def(py::init([](const Eigen::Matrix3x4d& matrix) {
return Rigid3d(Eigen::Quaterniond(matrix.leftCols<3>()), matrix.col(3));
}))
.def_readwrite("rotation", &Rigid3d::rotation)
.def_readwrite("translation", &Rigid3d::translation)
.def_property_readonly("matrix", &Rigid3d::ToMatrix)
.def("matrix", &Rigid3d::ToMatrix)
.def(py::self * Eigen::Vector3d())
.def(py::self * Rigid3d())
.def("inverse", static_cast<Rigid3d (*)(const Rigid3d&)>(&Inverse))
Expand All @@ -58,16 +63,17 @@ void BindGeometry(py::module& m) {
return ss.str();
});
py::implicitly_convertible<py::array, Rigid3d>();
MakeDataclass(PyRigid3d);

py::class_<Sim3d>(m, "Sim3d")
.def(py::init<>())
py::class_<Sim3d> PySim3d(m, "Sim3d");
PySim3d.def(py::init<>())
.def(
py::init<double, const Eigen::Quaterniond&, const Eigen::Vector3d&>())
.def(py::init(&Sim3d::FromMatrix))
.def_readwrite("scale", &Sim3d::scale)
.def_readwrite("rotation", &Sim3d::rotation)
.def_readwrite("translation", &Sim3d::translation)
.def_property_readonly("matrix", &Sim3d::ToMatrix)
.def("matrix", &Sim3d::ToMatrix)
.def(py::self * Eigen::Vector3d())
.def(py::self * Sim3d())
.def("transform_camera_world", &TransformCameraWorld)
Expand All @@ -81,4 +87,5 @@ void BindGeometry(py::module& m) {
return ss.str();
});
py::implicitly_convertible<py::array, Sim3d>();
MakeDataclass(PySim3d);
}
83 changes: 58 additions & 25 deletions pycolmap/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const Eigen::IOFormat vec_fmt(Eigen::StreamPrecision,
", ");

template <typename T>
inline T pyStringToEnum(const py::enum_<T>& enm, const std::string& value) {
T pyStringToEnum(const py::enum_<T>& enm, const std::string& value) {
const auto values = enm.attr("__members__").template cast<py::dict>();
const auto str_val = py::str(value);
if (values.contains(str_val)) {
Expand All @@ -45,14 +45,14 @@ inline T pyStringToEnum(const py::enum_<T>& enm, const std::string& value) {
}

template <typename T>
inline void AddStringToEnumConstructor(py::enum_<T>& enm) {
void AddStringToEnumConstructor(py::enum_<T>& enm) {
enm.def(py::init([enm](const std::string& value) {
return pyStringToEnum(enm, py::str(value)); // str constructor
}));
py::implicitly_convertible<std::string, T>();
}

inline void UpdateFromDict(py::object& self, const py::dict& dict) {
void UpdateFromDict(py::object& self, const py::dict& dict) {
for (const auto& it : dict) {
if (!py::isinstance<py::str>(it.first)) {
const std::string msg = "Dictionary key is not a string: " +
Expand Down Expand Up @@ -125,34 +125,47 @@ inline void UpdateFromDict(py::object& self, const py::dict& dict) {
}
}

bool AttributeIsFunction(const std::string& name, const py::object& attribute) {
bool AttributeIsFunction(const std::string& name, const py::object& value) {
return (name.find("__") == 0 || name.rfind("__") != std::string::npos ||
py::hasattr(attribute, "__func__") ||
py::hasattr(attribute, "__call__"));
py::hasattr(value, "__func__") || py::hasattr(value, "__call__"));
}

template <typename T, typename... options>
inline py::dict ConvertToDict(const T& self) {
const auto pyself = py::cast(self);
py::dict dict;
std::vector<std::string> ListObjectAttributes(const py::object& pyself) {
std::vector<std::string> attributes;
for (const auto& handle : pyself.attr("__dir__")()) {
const py::str name = py::reinterpret_borrow<py::str>(handle);
const auto attribute = pyself.attr(name);
if (AttributeIsFunction(name, attribute)) {
const py::str attribute = py::reinterpret_borrow<py::str>(handle);
const auto value = pyself.attr(attribute);
if (AttributeIsFunction(attribute, value)) {
continue;
}
if (py::hasattr(attribute, "todict")) {
dict[name] =
attribute.attr("todict").attr("__call__")().template cast<py::dict>();
attributes.push_back(attribute);
}
return attributes;
}

template <typename T, typename... options>
py::dict ConvertToDict(const T& self,
std::vector<std::string> attributes,
const bool recursive) {
const py::object pyself = py::cast(self);
if (attributes.empty()) {
attributes = ListObjectAttributes(pyself);
}
py::dict dict;
for (const auto& attr : attributes) {
const auto value = pyself.attr(attr.c_str());
if (recursive && py::hasattr(value, "todict")) {
dict[attr.c_str()] =
value.attr("todict").attr("__call__")().template cast<py::dict>();
} else {
dict[name] = attribute;
dict[attr.c_str()] = value;
}
}
return dict;
}

template <typename T, typename... options>
inline std::string CreateSummary(const T& self, bool write_type) {
std::string CreateSummary(const T& self, bool write_type) {
std::stringstream ss;
auto pyself = py::cast(self);
const std::string prefix = " ";
Expand All @@ -175,7 +188,7 @@ inline std::string CreateSummary(const T& self, bool write_type) {
if (!after_subsummary) {
ss << prefix;
}
ss << attribute.template cast<std::string>();
ss << name.template cast<std::string>();
if (py::hasattr(attribute, "summary")) {
std::string summ = attribute.attr("summary")
.attr("__call__")(write_type)
Expand Down Expand Up @@ -230,25 +243,45 @@ void AddDefaultsToDocstrings(py::class_<T, options...> cls) {
}

template <typename T, typename... options>
inline void MakeDataclass(py::class_<T, options...> cls) {
void MakeDataclass(py::class_<T, options...> cls,
const std::vector<std::string>& attributes = {}) {
AddDefaultsToDocstrings(cls);
cls.def("mergedict", &UpdateFromDict);
if (!py::hasattr(cls, "summary")) {
cls.def("summary", &CreateSummary<T>, "write_type"_a = false);
}
cls.def("todict", &ConvertToDict<T>);
cls.def("mergedict", &UpdateFromDict);
cls.def(
"todict",
[attributes](const T& self, const bool recursive) {
return ConvertToDict(self, attributes, recursive);
},
"recursive"_a = true);

cls.def(py::init([cls](const py::dict& dict) {
auto self = py::object(cls());
py::object self = cls();
self.attr("mergedict").attr("__call__")(dict);
return self.cast<T>();
}));
cls.def(py::init([cls](const py::kwargs& kwargs) {
py::dict dict = kwargs.cast<py::dict>();
auto self = py::object(cls(dict));
return self.cast<T>();
return cls(dict).template cast<T>();
}));
py::implicitly_convertible<py::dict, T>();
py::implicitly_convertible<py::kwargs, T>();

cls.def("__copy__", [](const T& self) { return T(self); });
cls.def("__deepcopy__",
[](const T& self, const py::dict&) { return T(self); });

cls.def(py::pickle(
[attributes](const T& self) {
return ConvertToDict(self, attributes, /*recursive=*/false);
},
[cls](const py::dict& dict) {
py::object self = cls();
self.attr("mergedict").attr("__call__")(dict);
return self.cast<T>();
}));
}

// Catch python keyboard interrupts
Expand Down
11 changes: 7 additions & 4 deletions pycolmap/scene/camera.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,12 @@ void BindCamera(py::module& m) {
"Rescale camera dimensions by given factor and accordingly the "
"focal length and\n"
"and the principal point.")
.def("__copy__", [](const Camera& self) { return Camera(self); })
.def("__deepcopy__",
[](const Camera& self, const py::dict&) { return Camera(self); })
.def("__repr__", &PrintCamera);
MakeDataclass(PyCamera);
MakeDataclass(PyCamera,
{"camera_id",
"model",
"width",
"height",
"params",
"has_prior_focal_length"});
}
9 changes: 3 additions & 6 deletions pycolmap/scene/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ void BindImage(py::module& m) {
&Image::IsRegistered,
&Image::SetRegistered,
"Whether image is registered in the reconstruction.")
.def_property_readonly("num_points2D",
&Image::NumPoints2D,
"Get the number of image points (keypoints).")
.def("num_points2D",
&Image::NumPoints2D,
"Get the number of image points (keypoints).")
.def_property_readonly(
"num_points3D",
&Image::NumPoints3D,
Expand Down Expand Up @@ -289,9 +289,6 @@ void BindImage(py::module& m) {
},
"Project list of image points (with depth) to world coordinate "
"frame.")
.def("__copy__", [](const Image& self) { return Image(self); })
.def("__deepcopy__",
[](const Image& self, const py::dict&) { return Image(self); })
.def("__repr__", &PrintImage);
MakeDataclass(PyImage);
}
3 changes: 0 additions & 3 deletions pycolmap/scene/point2D.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ void BindPoint2D(py::module& m) {
.def_readwrite("xy", &Point2D::xy)
.def_readwrite("point3D_id", &Point2D::point3D_id)
.def("has_point3D", &Point2D::HasPoint3D)
.def("__copy__", [](const Point2D& self) { return Point2D(self); })
.def("__deepcopy__",
[](const Point2D& self, const py::dict&) { return Point2D(self); })
.def("__repr__", &PrintPoint2D);
MakeDataclass(PyPoint2D);
}
3 changes: 0 additions & 3 deletions pycolmap/scene/point3D.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ void BindPoint3D(py::module& m) {
.def_readwrite("color", &Point3D::color)
.def_readwrite("error", &Point3D::error)
.def_readwrite("track", &Point3D::track)
.def("__copy__", [](const Point3D& self) { return Point3D(self); })
.def("__deepcopy__",
[](const Point3D& self, const py::dict&) { return Point3D(self); })
.def("__repr__", [](const Point3D& self) {
std::stringstream ss;
ss << "Point3D(xyz=[" << self.xyz.format(vec_fmt) << "], color=["
Expand Down
21 changes: 8 additions & 13 deletions pycolmap/scene/track.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "colmap/util/misc.h"
#include "colmap/util/types.h"

#include "pycolmap/helpers.h"
#include "pycolmap/log_exceptions.h"

#include <memory>
Expand All @@ -18,24 +19,20 @@ using namespace pybind11::literals;
namespace py = pybind11;

void BindTrack(py::module& m) {
py::class_<TrackElement, std::shared_ptr<TrackElement>>(m, "TrackElement")
.def(py::init<>())
py::class_<TrackElement, std::shared_ptr<TrackElement>> PyTrackElement(
m, "TrackElement");
PyTrackElement.def(py::init<>())
.def(py::init<image_t, point2D_t>())
.def_readwrite("image_id", &TrackElement::image_id)
.def_readwrite("point2D_idx", &TrackElement::point2D_idx)
.def("__copy__",
[](const TrackElement& self) { return TrackElement(self); })
.def("__deepcopy__",
[](const TrackElement& self, const py::dict&) {
return TrackElement(self);
})
.def("__repr__", [](const TrackElement& self) {
return "TrackElement(image_id=" + std::to_string(self.image_id) +
", point2D_idx=" + std::to_string(self.point2D_idx) + ")";
});
MakeDataclass(PyTrackElement);

py::class_<Track, std::shared_ptr<Track>>(m, "Track")
.def(py::init<>())
py::class_<Track, std::shared_ptr<Track>> PyTrack(m, "Track");
PyTrack.def(py::init<>())
.def(py::init([](const std::vector<TrackElement>& elements) {
auto track = std::make_shared<Track>();
track->AddElements(elements);
Expand Down Expand Up @@ -67,10 +64,8 @@ void BindTrack(py::module& m) {
py::overload_cast<const image_t, const point2D_t>(
&Track::DeleteElement),
"Remove TrackElement with (image_id,point2D_idx).")
.def("__copy__", [](const Track& self) { return Track(self); })
.def("__deepcopy__",
[](const Track& self, const py::dict&) { return Track(self); })
.def("__repr__", [](const Track& self) {
return "Track(length=" + std::to_string(self.Length()) + ")";
});
MakeDataclass(PyTrack);
}
Loading