Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change implementation of the __init__() must be called when overriding __init__ safety feature to work for any metaclass. #30095

Merged
merged 16 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 67 additions & 15 deletions include/pybind11/detail/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include "../attr.h"
#include "../options.h"

#include <cassert>
#include <unordered_map>

PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN(detail)

Expand Down Expand Up @@ -179,6 +182,36 @@ extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name
return PyType_Type.tp_getattro(obj, name);
}

// Ensure that the base __init__ function(s) were called.
// Set TypeError and return false if not.
// CALLER IS RESPONSIBLE for managing the self refcount appropriately.
inline bool ensure_base_init_functions_were_called(PyObject *self) {
values_and_holders vhs(self);
for (const auto &vh : vhs) {
if (!vh.holder_constructed() && !vhs.is_redundant_value_and_holder(vh)) {
PyErr_Format(PyExc_TypeError,
"%.200s.__init__() must be called when overriding __init__",
get_fully_qualified_tp_name(vh.type->type).c_str());
return false;
}
}
return true;
}

// See google/pywrapcc#30095 for background.
#if !defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT) \
&& !defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
# if !defined(PYPY_VERSION)
// With CPython the safety checks work for any metaclass.
// However, with PyPy this implementation does not work at all.
# define PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT
# else
// With this the safety checks work only for the default `py::metaclass()`.
# define PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS
# endif
#endif

#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
/// metaclass `__call__` function that is used to create all pybind11 objects.
extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, PyObject *kwargs) {

Expand All @@ -188,20 +221,14 @@ extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, P
return nullptr;
}

// Ensure that the base __init__ function(s) were called
values_and_holders vhs(self);
for (const auto &vh : vhs) {
if (!vh.holder_constructed() && !vhs.is_redundant_value_and_holder(vh)) {
PyErr_Format(PyExc_TypeError,
"%.200s.__init__() must be called when overriding __init__",
get_fully_qualified_tp_name(vh.type->type).c_str());
Py_DECREF(self);
return nullptr;
}
if (!ensure_base_init_functions_were_called(self)) {
Py_DECREF(self);
return nullptr;
}

return self;
}
#endif

/// Cleanup the type-info for a pybind11-registered type.
extern "C" inline void pybind11_meta_dealloc(PyObject *obj) {
Expand Down Expand Up @@ -268,7 +295,9 @@ inline PyTypeObject *make_default_metaclass() {
type->tp_base = type_incref(&PyType_Type);
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;

#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
type->tp_call = pybind11_meta_call;
#endif

type->tp_setattro = pybind11_meta_setattro;
type->tp_getattro = pybind11_meta_getattro;
Expand Down Expand Up @@ -340,6 +369,33 @@ inline bool deregister_instance(instance *self, void *valptr, const type_info *t
return ret;
}

#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT)

using derived_tp_init_registry_type = std::unordered_map<PyTypeObject *, initproc>;

inline derived_tp_init_registry_type *derived_tp_init_registry() {
// Intentionally leak the unordered_map:
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
static auto *singleton = new derived_tp_init_registry_type();
return singleton;
}

extern "C" inline int tp_init_with_safety_checks(PyObject *self, PyObject *args, PyObject *kw) {
assert(PyType_Check(self) == 0);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this condition guaranteed? Looks like we don't ensure it when intercepting tp_init, based on my initial read of the pybind11object_new method.

PS: if it was not for consisitency with pyclif naming, I would recommend calling this safe_tp_init.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this condition guaranteed?

It's not. I put the assert() here only to be sure my code isn't doing something unexpected. This PR is already TGP tested, which makes me believe it is very unlikely that the assert() will ever fire, but if it does, I'll have a concrete situation to fix (as opposed to anticipating/guessing).

PS: if it was not for consisitency with pyclif naming,

Name changed to tp_init_with_safety_checks.

(After this PR is merged I'll go back and change the PyCLIF code accordingly. I also want to backport the weakref-based cleanup.)

const auto derived_tp_init = derived_tp_init_registry()->find(Py_TYPE(self));
if (derived_tp_init == derived_tp_init_registry()->end()) {
pybind11_fail("FATAL: Internal consistency check failed at " __FILE__
":" PYBIND11_TOSTRING(__LINE__));
}
int status = (*derived_tp_init->second)(self, args, kw);
if (status == 0 && !ensure_base_init_functions_were_called(self)) {
return -1; // No Py_DECREF here.
}
return status;
}

#endif // PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT

/// Instance creation function for all pybind11 types. It allocates the internal instance layout
/// for holding C++ objects and holders. Allocation is done lazily (the first time the instance is
/// cast to a reference or pointer), and initialization is done by an `__init__` function.
Expand All @@ -360,11 +416,7 @@ inline PyObject *make_new_instance(PyTypeObject *type) {
return self;
}

/// Instance creation function for all pybind11 types. It only allocates space for the
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
return make_new_instance(type);
}
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *);

/// An `__init__` function constructs the C++ object. Users should provide at least one
/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
Expand Down
27 changes: 27 additions & 0 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,33 @@ class cpp_function : public function {
}
};

PYBIND11_NAMESPACE_BEGIN(detail)

/// Instance creation function for all pybind11 types. It only allocates space for the
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe the ifdef for pybind11_meta_call should also be in the body like here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decided on chat: leaving as is.

if (type->tp_init != pybind11_object_init && type->tp_init != tp_init_with_safety_checks
&& derived_tp_init_registry()->count(type) == 0) {
weakref((PyObject *) type, cpp_function([type](handle wr) {
auto num_erased = derived_tp_init_registry()->erase(type);
if (num_erased != 1) {
pybind11_fail("FATAL: Internal consistency check failed at " __FILE__
":" PYBIND11_TOSTRING(__LINE__) ": num_erased="
+ std::to_string(num_erased));
}
wr.dec_ref();
}))
.release();
(*derived_tp_init_registry())[type] = type->tp_init;
type->tp_init = tp_init_with_safety_checks;
}
#endif // PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT
return make_new_instance(type);
}

PYBIND11_NAMESPACE_END(detail)

/// Wrapper for Python extension modules
class module_ : public object {
public:
Expand Down
46 changes: 30 additions & 16 deletions tests/test_python_multiple_inheritance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace test_python_multiple_inheritance {
// Copied from:
// https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python_multiple_inheritance.h

template <int> // Using int as a trick to easily generate a series of types.
struct CppBase {
explicit CppBase(int value) : base_value(value) {}
int get_base_value() const { return base_value; }
Expand All @@ -14,32 +15,45 @@ struct CppBase {
int base_value;
};

struct CppDrvd : CppBase {
explicit CppDrvd(int value) : CppBase(value), drvd_value(value * 3) {}
template <int SerNo>
struct CppDrvd : CppBase<SerNo> {
explicit CppDrvd(int value) : CppBase<SerNo>(value), drvd_value(value * 3) {}
int get_drvd_value() const { return drvd_value; }
void reset_drvd_value(int new_value) { drvd_value = new_value; }

int get_base_value_from_drvd() const { return get_base_value(); }
void reset_base_value_from_drvd(int new_value) { reset_base_value(new_value); }
int get_base_value_from_drvd() const { return CppBase<SerNo>::get_base_value(); }
void reset_base_value_from_drvd(int new_value) { CppBase<SerNo>::reset_base_value(new_value); }

private:
int drvd_value;
};

template <int SerNo, typename... Extra>
void wrap_classes(py::module_ &m, const char *name_base, const char *name_drvd, Extra... extra) {
py::class_<CppBase<SerNo>>(m, name_base, std::forward<Extra>(extra)...)
.def(py::init<int>())
.def("get_base_value", &CppBase<SerNo>::get_base_value)
.def("reset_base_value", &CppBase<SerNo>::reset_base_value);

py::class_<CppDrvd<SerNo>, CppBase<SerNo>>(m, name_drvd, std::forward<Extra>(extra)...)
.def(py::init<int>())
.def("get_drvd_value", &CppDrvd<SerNo>::get_drvd_value)
.def("reset_drvd_value", &CppDrvd<SerNo>::reset_drvd_value)
.def("get_base_value_from_drvd", &CppDrvd<SerNo>::get_base_value_from_drvd)
.def("reset_base_value_from_drvd", &CppDrvd<SerNo>::reset_base_value_from_drvd);
}

} // namespace test_python_multiple_inheritance

TEST_SUBMODULE(python_multiple_inheritance, m) {
using namespace test_python_multiple_inheritance;

py::class_<CppBase>(m, "CppBase")
.def(py::init<int>())
.def("get_base_value", &CppBase::get_base_value)
.def("reset_base_value", &CppBase::reset_base_value);

py::class_<CppDrvd, CppBase>(m, "CppDrvd")
.def(py::init<int>())
.def("get_drvd_value", &CppDrvd::get_drvd_value)
.def("reset_drvd_value", &CppDrvd::reset_drvd_value)
.def("get_base_value_from_drvd", &CppDrvd::get_base_value_from_drvd)
.def("reset_base_value_from_drvd", &CppDrvd::reset_base_value_from_drvd);
wrap_classes<0>(m, "CppBase0", "CppDrvd0");
wrap_classes<1>(m, "CppBase1", "CppDrvd1", py::metaclass((PyObject *) &PyType_Type));

m.attr("if_defined_PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS") =
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
true;
#else
false;
#endif
}
117 changes: 110 additions & 7 deletions tests/test_python_multiple_inheritance.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,78 @@
# Adapted from:
# https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python/python_multiple_inheritance_test.py
# https://github.com/google/clif/blob/7d388e1de7db5beeb3d7429c18a2776d8188f44f/clif/testing/python/python_multiple_inheritance_test.py

import pytest

from pybind11_tests import python_multiple_inheritance as m

#
# Using default py::metaclass() (used with py::class_<> for CppBase0, CppDrvd0):
#


class PC0(m.CppBase0):
pass


class PPCC0(PC0, m.CppDrvd0):
pass


class PCExplicitInitWithSuper0(m.CppBase0):
def __init__(self, value):
super().__init__(value + 1)


class PCExplicitInitMissingSuper0(m.CppBase0):
def __init__(self, value):
del value


class PCExplicitInitMissingSuperB0(m.CppBase0):
def __init__(self, value):
del value


#
# Using py::metaclass((PyObject *) &PyType_Type) (used with py::class_<> for CppBase1, CppDrvd1):
# COPY-PASTE block from above, replace 0 with 1:
#

class PC(m.CppBase):

class PC1(m.CppBase1):
pass


class PPCC(PC, m.CppDrvd):
class PPCC1(PC1, m.CppDrvd1):
pass


def test_PC():
d = PC(11)
class PCExplicitInitWithSuper1(m.CppBase1):
def __init__(self, value):
super().__init__(value + 1)


class PCExplicitInitMissingSuper1(m.CppBase1):
def __init__(self, value):
del value


class PCExplicitInitMissingSuperB1(m.CppBase1):
def __init__(self, value):
del value


@pytest.mark.parametrize(("pc_type"), [PC0, PC1])
def test_PC(pc_type):
d = pc_type(11)
assert d.get_base_value() == 11
d.reset_base_value(13)
assert d.get_base_value() == 13


def test_PPCC():
d = PPCC(11)
@pytest.mark.parametrize(("ppcc_type"), [PPCC0, PPCC1])
def test_PPCC(ppcc_type):
d = ppcc_type(11)
assert d.get_drvd_value() == 33
d.reset_drvd_value(55)
assert d.get_drvd_value() == 55
Expand All @@ -33,3 +85,54 @@ def test_PPCC():
d.reset_base_value_from_drvd(30)
assert d.get_base_value() == 30
assert d.get_base_value_from_drvd() == 30


@pytest.mark.parametrize(
("pc_type"), [PCExplicitInitWithSuper0, PCExplicitInitWithSuper1]
)
def testPCExplicitInitWithSuper(pc_type):
d = pc_type(14)
assert d.get_base_value() == 15


@pytest.mark.parametrize(
("derived_type"),
[
PCExplicitInitMissingSuper0,
PCExplicitInitMissingSuperB0,
PCExplicitInitMissingSuper1,
PCExplicitInitMissingSuperB1,
],
)
def testPCExplicitInitMissingSuper(derived_type):
if (
m.if_defined_PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS
and derived_type
in (
PCExplicitInitMissingSuper1,
PCExplicitInitMissingSuperB1,
)
):
pytest.skip(
"PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS is defined"
)
with pytest.raises(TypeError) as excinfo:
derived_type(0)
assert str(excinfo.value).endswith(
".__init__() must be called when overriding __init__"
)


def test_derived_tp_init_registry_weakref_based_cleanup():
def nested_function(i):
class NestedClass(m.CppBase0):
def __init__(self, value):
super().__init__(value + 3)

d1 = NestedClass(i + 7)
d2 = NestedClass(i + 8)
return (d1.get_base_value(), d2.get_base_value())

for _ in range(100):
assert nested_function(0) == (10, 11)
assert nested_function(3) == (13, 14)
Loading