Skip to content

Commit

Permalink
Add missing super().__init__() involving types wrapped in xla/pytho…
Browse files Browse the repository at this point in the history
…n/sharding.cc

This change is to unblock google/pybind11clif#30095.

Leaving wrapped C++ types uninitialized creates a potential for triggering undefined behavior from Python.

PiperOrigin-RevId: 602787434
  • Loading branch information
Ralf W. Grosse-Kunstleve authored and copybara-github committed Jan 30, 2024
1 parent 114b25a commit 9bef51f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions xla/python/sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,14 @@ void RegisterSharding(py::module& m) {
py::object abc_init = abc_module.attr("_abc_init");

// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<Sharding>(m, "Sharding", py::metaclass(abc_meta));
py::class_<Sharding>(m, "Sharding", py::metaclass(abc_meta))
.def(py::init<>());
abc_init(py::type::of<Sharding>());

// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<XLACompatibleSharding, Sharding>(m, "XLACompatibleSharding",
py::metaclass(abc_meta));
py::metaclass(abc_meta))
.def(py::init<>());
abc_init(py::type::of<XLACompatibleSharding>());

py::class_<NamedSharding, XLACompatibleSharding>(m, "NamedSharding",
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 234
_version = 235

# Version number for MLIR:Python components.
mlir_api_version = 55
Expand Down

0 comments on commit 9bef51f

Please sign in to comment.