Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: KeysView/ValuesView/ItemsView using Python types. Fix #4529 #4983

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,10 @@ struct handle_type_name {
static constexpr auto name = const_name<T>();
};
template <>
struct handle_type_name<object> {
static constexpr auto name = const_name("object");
};
template <>
struct handle_type_name<bool_> {
static constexpr auto name = const_name("bool");
};
Expand Down
174 changes: 150 additions & 24 deletions include/pybind11/stl_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "operators.h"

#include <algorithm>
#include <complex>
#include <functional>
#include <sstream>
#include <type_traits>

Expand Down Expand Up @@ -483,6 +485,137 @@ void vector_buffer(Class_ &cl) {
cl, detail::any_of<std::is_same<Args, buffer_protocol>...>{});
}

// Issue #3986 and #4529: map C++ types to Python types with typing strings
template <typename T, typename SFINAE = void>
struct type_mapper {
using py_type = T;
static std::string py_name() { return detail::type_info_description(typeid(T)); }
};

template <>
struct type_mapper<std::nullptr_t> {
using py_type = pybind11::none;
static std::string py_name() {
constexpr auto descr = const_name("None");
return descr.text;
}
};

template <>
struct type_mapper<bool> {
using py_type = pybind11::bool_;
static std::string py_name() {
constexpr auto descr = const_name("bool");
return descr.text;
}
};

template <typename T>
struct type_mapper<T, enable_if_t<std::is_arithmetic<T>::value && !is_std_char_type<T>::value>> {
using py_type
= conditional_t<std::is_floating_point<T>::value, pybind11::float_, pybind11::int_>;
static std::string py_name() {
constexpr auto descr = const_name<std::is_integral<T>::value>("int", "float");
return descr.text;
}
};

template <typename T>
struct type_mapper<std::complex<T>> {
using py_type = std::complex<typename type_mapper<T>::py_type>;
static std::string py_name() {
constexpr auto descr = const_name("complex");
return descr.text;
}
};

template <typename T>
struct type_mapper<T, enable_if_t<is_std_char_type<T>::value>> {
using py_type = pybind11::str;
static std::string py_name() {
constexpr auto descr = const_name(PYBIND11_STRING_NAME);
return descr.text;
}
};

template <typename T>
struct type_mapper<T, enable_if_t<is_pyobject<T>::value>> {
using py_type = T;
static std::string py_name() {
constexpr auto descr = handle_type_name<T>::name;
return descr.text;
}
};

template <typename T>
struct type_mapper<std::shared_ptr<T>> : public type_mapper<T> {};

template <typename T, typename Deleter>
struct type_mapper<std::unique_ptr<T, Deleter>> : public type_mapper<T> {};

template <typename CharT, typename Traits, typename Allocator>
struct type_mapper<std::basic_string<CharT, Traits, Allocator>,
enable_if_t<is_std_char_type<CharT>::value>> {
using py_type = pybind11::str;
static std::string py_name() {
constexpr auto descr = const_name(PYBIND11_STRING_NAME);
return descr.text;
}
};

#ifdef PYBIND11_HAS_STRING_VIEW
template <typename CharT, typename Traits>
struct type_mapper<std::basic_string_view<CharT, Traits>,
enable_if_t<is_std_char_type<CharT>::value>> {
using py_type = pybind11::str;
static std::string py_name() {
constexpr auto descr = const_name(PYBIND11_STRING_NAME);
return descr.text;
}
};
#endif

template <typename T1, typename T2>
struct type_mapper<std::pair<T1, T2>> {
using py_type
= std::tuple<typename type_mapper<T1>::py_type, typename type_mapper<T1>::py_type>;
static std::string py_name() {
return "tuple[" + type_mapper<T1>::py_name() + ", " + type_mapper<T2>::py_name() + "]";
}
};

template <typename... Ts>
struct type_mapper<std::tuple<Ts...>> {
using py_type = std::tuple<typename type_mapper<Ts>::py_type...>;
static std::string py_name() {
std::vector<std::string> names = {type_mapper<Ts>::py_name()...};
std::ostringstream s;
s << "tuple[";
for (size_t i = 0; i < names.size(); ++i) {
s << (i != 0 ? ", " : "") << names[i];
}
s << "]";
return s.str();
}
};

template <typename Return, typename... Args>
struct type_mapper<std::function<Return(Args...)>> {
using retval_type = conditional_t<std::is_same<Return, void>::value, std::nullptr_t, Return>;
using py_type = std::function<typename type_mapper<retval_type>::py_type(
typename type_mapper<Args>::py_type...)>;
static std::string py_name() {
std::vector<std::string> names = {type_mapper<Args>::py_name()...};
std::ostringstream s;
s << "Callable[[";
for (size_t i = 0; i < names.size(); ++i) {
s << (i != 0 ? ", " : "") << names[i];
}
s << "], " << type_mapper<retval_type>::py_name() << "]";
return s.str();
}
};

PYBIND11_NAMESPACE_END(detail)

//
Expand Down Expand Up @@ -649,8 +782,7 @@ template <typename KeyType>
struct keys_view {
virtual size_t len() = 0;
virtual iterator iter() = 0;
virtual bool contains(const KeyType &k) = 0;
virtual bool contains(const object &k) = 0;
virtual bool contains(const handle &k) = 0;
virtual ~keys_view() = default;
};

Expand All @@ -673,8 +805,13 @@ struct KeysViewImpl : public KeysView {
explicit KeysViewImpl(Map &map) : map(map) {}
size_t len() override { return map.size(); }
iterator iter() override { return make_key_iterator(map.begin(), map.end()); }
bool contains(const typename Map::key_type &k) override { return map.find(k) != map.end(); }
bool contains(const object &) override { return false; }
bool contains(const handle &k) override {
try {
return map.find(k.template cast<typename Map::key_type>()) != map.end();
} catch (const cast_error &) {
return false;
}
}
Map &map;
};

Expand Down Expand Up @@ -702,9 +839,11 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
using MappedType = typename Map::mapped_type;
using StrippedKeyType = detail::remove_cvref_t<KeyType>;
using StrippedMappedType = detail::remove_cvref_t<MappedType>;
using KeysView = detail::keys_view<StrippedKeyType>;
using ValuesView = detail::values_view<StrippedMappedType>;
using ItemsView = detail::items_view<StrippedKeyType, StrippedMappedType>;
using PyKeyType = typename detail::type_mapper<StrippedKeyType>::py_type;
using PyMappedType = typename detail::type_mapper<StrippedMappedType>::py_type;
using KeysView = detail::keys_view<PyKeyType>;
using ValuesView = detail::values_view<PyMappedType>;
using ItemsView = detail::items_view<PyKeyType, PyMappedType>;
using Class_ = class_<Map, holder_type>;

// If either type is a non-module-local bound type then make the map binding non-local as well;
Expand All @@ -718,20 +857,10 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
}

Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
static constexpr auto key_type_descr = detail::make_caster<KeyType>::name;
static constexpr auto mapped_type_descr = detail::make_caster<MappedType>::name;
std::string key_type_name(key_type_descr.text), mapped_type_name(mapped_type_descr.text);

// If key type isn't properly wrapped, fall back to C++ names
if (key_type_name == "%") {
key_type_name = detail::type_info_description(typeid(KeyType));
}
// Similarly for value type:
if (mapped_type_name == "%") {
mapped_type_name = detail::type_info_description(typeid(MappedType));
}
std::string key_type_name = detail::type_mapper<StrippedKeyType>::py_name();
std::string mapped_type_name = detail::type_mapper<StrippedMappedType>::py_name();

// Wrap KeysView[KeyType] if it wasn't already wrapped
// Wrap KeysView[PyKeyType] if it wasn't already wrapped
if (!detail::get_type_info(typeid(KeysView))) {
class_<KeysView> keys_view(
scope, ("KeysView[" + key_type_name + "]").c_str(), pybind11::module_local(local));
Expand All @@ -741,10 +870,7 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
);
keys_view.def("__contains__",
static_cast<bool (KeysView::*)(const KeyType &)>(&KeysView::contains));
// Fallback for when the object is not of the key type
keys_view.def("__contains__",
static_cast<bool (KeysView::*)(const object &)>(&KeysView::contains));
static_cast<bool (KeysView::*)(const handle &)>(&KeysView::contains));
}
// Similarly for ValuesView:
if (!detail::get_type_info(typeid(ValuesView))) {
Expand Down
29 changes: 29 additions & 0 deletions tests/test_stl_binders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,35 @@ TEST_SUBMODULE(stl_binders, m) {
py::bind_map<std::map<std::string, double>>(m, "MapStringDouble");
py::bind_map<std::unordered_map<std::string, double>>(m, "UnorderedMapStringDouble");

// test_map_view_types
py::bind_map<std::map<std::string, float>>(m, "MapStringFloat");
py::bind_map<std::unordered_map<std::string, float>>(m, "UnorderedMapStringFloat");
py::bind_map<std::map<int16_t, double>>(m, "MapInt16Double");
py::bind_map<std::map<int32_t, double>>(m, "MapInt32Double");
py::bind_map<std::map<int64_t, double>>(m, "MapInt64Double");
py::bind_map<std::map<uint64_t, double>>(m, "MapUInt64Double");
py::bind_map<std::map<std::pair<short, short>, double>>(m, "MapPairShortShortDouble");
py::bind_map<std::map<std::pair<short, long>, std::complex<float>>>(
m, "MapPairShortLongComplexFloat");
py::bind_map<std::map<std::pair<long, short>, std::complex<double>>>(
m, "MapPairLongShortComplexDouble");
py::bind_map<std::map<std::tuple<long, long>, std::complex<double>>>(
m, "MapTupleLongLongComplexDouble");
py::bind_map<std::map<char, std::function<float(int, float)>>>(m,
"MapCharFunctionFloatIntFloat");
py::bind_map<std::map<std::string, std::function<double(long, double)>>>(
m, "MapStringFunctionDoubleLongDouble");
py::bind_map<std::map<std::string, std::function<void(long, double)>>>(
m, "MapStringFunctionVoidLongDouble");
py::bind_map<std::map<std::string, std::nullptr_t>>(m, "MapStringNone");

py::bind_map<std::map<int, std::pair<std::map<int, int>, int>>>(m, "MapIntMapIntIntInt");
py::bind_map<std::map<int, std::pair<std::map<int, int>, long>>>(m, "MapIntMapIntIntLong");
py::bind_map<std::map<int, std::pair<std::map<long, int>, long>>>(m, "MapIntMapLongIntLong");

py::bind_map<std::map<pybind11::int_, int>>(m, "MapPyIntInt");
py::bind_map<std::map<pybind11::int_, pybind11::int_>>(m, "MapPyIntPyInt");

// test_map_string_double_const
py::bind_map<std::map<std::string, double const>>(m, "MapStringDoubleConst");
py::bind_map<std::unordered_map<std::string, double const>>(m,
Expand Down
Loading