Skip to content

Commit

Permalink
Add Enumeration Extension Features
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenv authored Oct 25, 2023
1 parent 56e21b3 commit bd86eb9
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# - this is for builds-from-source
# - release builds are controlled by `misc/azure-release.yml`
# - this should be set to the current core release, not `dev`
TILEDB_VERSION = "2.17.1"
TILEDB_VERSION = "2.17.3"

# allow overriding w/ environment variable
TILEDB_VERSION = os.environ.get("TILEDB_VERSION") or TILEDB_VERSION
Expand Down
45 changes: 44 additions & 1 deletion tiledb/cc/enumeration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@ void init_enumeration(py::module &m) {
py::class_<Enumeration>(m, "Enumeration")
.def(py::init<Enumeration>())

// Constructor overload for an empty enumeration: takes a context, a name,
// a NumPy dtype and the `ordered` flag, and forwards them to
// Enumeration::create_empty.
.def(py::init([](const Context &ctx, const std::string &name,
py::dtype type, bool ordered) {
tiledb_datatype_t data_type;
try {
data_type = np_to_tdb_dtype(type);
} catch (const TileDBPyError &e) {
// A failed dtype conversion is re-thrown as py::type_error, carrying the
// original message.
throw py::type_error(e.what());
}
// NOTE(review): get_ncells is defined elsewhere; presumably it derives the
// number of values per cell from the dtype — confirm against its definition.
py::size_t cell_val_num = get_ncells(type);

return Enumeration::create_empty(ctx, name, data_type, cell_val_num,
ordered);
}))

.def(py::init([](const Context &ctx, const std::string &name,
std::vector<std::string> &values, bool ordered,
tiledb_datatype_t type) {
Expand Down Expand Up @@ -71,7 +85,36 @@ void init_enumeration(py::module &m) {
})

.def("str_values",
[](Enumeration &enmr) { return enmr.as_vector<std::string>(); });
[](Enumeration &enmr) { return enmr.as_vector<std::string>(); })

// "extend" bindings: one .def per Enumeration::extend overload, each picked
// out with a static_cast to the exact member-function type. The integer
// vector overloads (64/32/16/8-bit, signed then unsigned) are registered
// first, followed by the std::vector<std::string>& overload.
// NOTE(review): pybind11 documents that overloads are tried in registration
// order, so the int64_t binding would be attempted first for Python integer
// sequences — confirm this ordering is the intended one.
.def("extend",
static_cast<Enumeration (Enumeration::*)(std::vector<int64_t>)>(
&Enumeration::extend))
.def("extend",
static_cast<Enumeration (Enumeration::*)(std::vector<uint64_t>)>(
&Enumeration::extend))
.def("extend",
static_cast<Enumeration (Enumeration::*)(std::vector<int32_t>)>(
&Enumeration::extend))
.def("extend",
static_cast<Enumeration (Enumeration::*)(std::vector<uint32_t>)>(
&Enumeration::extend))
.def("extend",
static_cast<Enumeration (Enumeration::*)(std::vector<int16_t>)>(
&Enumeration::extend))
.def("extend",
static_cast<Enumeration (Enumeration::*)(std::vector<uint16_t>)>(
&Enumeration::extend))
.def("extend",
static_cast<Enumeration (Enumeration::*)(std::vector<int8_t>)>(
&Enumeration::extend))
.def("extend",
static_cast<Enumeration (Enumeration::*)(std::vector<uint8_t>)>(
&Enumeration::extend))
.def(
"extend",
static_cast<Enumeration (Enumeration::*)(std::vector<std::string> &)>(
&Enumeration::extend));
}

} // namespace libtiledbcpp
25 changes: 20 additions & 5 deletions tiledb/enumeration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import io
from typing import Any, Optional, Sequence

Expand All @@ -16,7 +18,12 @@ class Enumeration(CtxMixin, lt.Enumeration):
"""

def __init__(
self, name: str, ordered: bool, values: Sequence[Any], ctx: Optional[Ctx] = None
self,
name: str,
ordered: bool,
values: Optional[Sequence[Any]] = None,
dtype: Optional[np.dtype] = None,
ctx: Optional[Ctx] = None,
):
"""Class representing the TileDB Enumeration.
Expand All @@ -29,6 +36,11 @@ def __init__(
:param ctx: A TileDB context
:type ctx: tiledb.Ctx
"""
if values is None or len(values) == 0:
if dtype is None:
raise ValueError("dtype must be provied for empty enumeration")
super().__init__(ctx, name, np.dtype(dtype), ordered)

values = np.array(values)
if np.dtype(values.dtype).kind in "US":
dtype = (
Expand Down Expand Up @@ -84,17 +96,20 @@ def values(self) -> NDArray:
else:
return super().values()

def extend(self, values: Sequence[Any]) -> Enumeration:
    """Extend this enumeration with additional values.

    The values are handed to the underlying binding's ``extend`` and the
    object it returns is wrapped with ``Enumeration.from_pybind11`` using
    this enumeration's context.

    :param values: Values forwarded to the binding's ``extend``
    :type values: Sequence[Any]
    :return: The wrapped result of the binding's ``extend``
    :rtype: Enumeration
    """
    extended = super().extend(values)
    return Enumeration.from_pybind11(self._ctx, extended)

def __eq__(self, other):
    """Compare two enumerations field by field.

    Two enumerations are equal only when ALL of name, dtype, cell_val_num,
    ordered and the stored values match. Anything that is not an
    ``Enumeration`` compares unequal.

    :param other: Object to compare against
    :return: True when every compared field matches, False otherwise
    :rtype: bool
    """
    if not isinstance(other, Enumeration):
        return False

    # Cheap scalar fields are checked first so the value arrays are only
    # fetched when everything else already agrees. np.array_equal reduces
    # the comparison to a single bool (a bare ``==`` on arrays is
    # element-wise and also fails on differing shapes).
    return bool(
        self.name == other.name
        and self.dtype == other.dtype
        and self.cell_val_num == other.cell_val_num
        and self.ordered == other.ordered
        and np.array_equal(self.values(), other.values())
    )

Expand Down
12 changes: 12 additions & 0 deletions tiledb/schema_evolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ void init_schema_evolution(py::module &m) {
if (rc != TILEDB_OK) {
TPY_ERROR_LOC(get_last_ctx_err_str(inst.ctx_, rc));
}
})
// Binds "extend_enumeration": takes a Python enumeration object, extracts
// its C handle and registers the extension on this schema evolution.
.def("extend_enumeration",
[](ArraySchemaEvolution &inst, py::object enum_py) {
// The C handle is obtained by calling the Python object's
// `__capsule__()` method and converting the resulting capsule.
tiledb_enumeration_t *enum_c =
(py::capsule)enum_py.attr("__capsule__")();
if (enum_c == nullptr)
TPY_ERROR_LOC("Invalid Enumeration!");
int rc = tiledb_array_schema_evolution_extend_enumeration(
inst.ctx_, inst.evol_, enum_c);
// On a non-OK status, report the context's last error string.
// NOTE(review): TPY_ERROR_LOC is defined elsewhere; presumably it throws.
if (rc != TILEDB_OK) {
TPY_ERROR_LOC(get_last_ctx_err_str(inst.ctx_, rc));
}
});
}

Expand Down
7 changes: 7 additions & 0 deletions tiledb/schema_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def drop_enumeration(self, enmr_name: str):

self.ase.drop_enumeration(enmr_name)

def extend_enumeration(self, enmr: Enumeration):
    """Register an extension of an existing enumeration (matched by name)
    with this schema evolution.

    Nothing is applied at this point: the change only takes effect once
    `ArraySchemaEvolution.array_evolve` is called.
    """
    evolution = self.ase
    evolution.extend_enumeration(enmr)

def array_evolve(self, uri: str):
"""Apply ArraySchemaEvolution actions to Array at given URI."""

Expand Down
26 changes: 26 additions & 0 deletions tiledb/tests/test_schema_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,29 @@ def test_schema_evolution_with_enmr(tmp_path):

with tiledb.open(uri) as A:
assert not A.schema.has_attr("a3")


def test_schema_evolution_extend_enmr(tmp_path):
    """Extending an enumeration through schema evolution.

    Creates an array whose attribute "a" is labelled by an empty string
    enumeration "e", checks that extending it with integers raises
    ``tiledb.TileDBError``, then extends it with strings, evolves the array
    and checks that the stored enumeration equals the extended one.
    """
    uri = str(tmp_path)
    enmr = tiledb.Enumeration("e", True, dtype=str)

    def check_array(expected_enmr):
        # Re-open the array and verify the attribute and its enumeration.
        with tiledb.open(uri) as A:
            assert A.schema.has_attr("a")
            assert A.attr("a").enum_label == "e"
            assert A.enum("e") == expected_enmr

    attrs = [tiledb.Attr(name="a", dtype=int, enum_label="e")]
    domain = tiledb.Domain(tiledb.Dim(domain=(0, 3), dtype=np.uint64))
    tiledb.Array.create(
        uri, tiledb.ArraySchema(domain=domain, attrs=attrs, enums=[enmr])
    )
    check_array(enmr)

    evolution = tiledb.ArraySchemaEvolution()
    # Integer values do not fit a string enumeration.
    with pytest.raises(tiledb.TileDBError):
        enmr.extend([1, 2, 3])
    updated_enmr = enmr.extend(["a", "b", "c"])
    evolution.extend_enumeration(updated_enmr)
    evolution.array_evolve(uri)

    check_array(updated_enmr)

0 comments on commit bd86eb9

Please sign in to comment.