Skip to content

Commit

Permalink
Move Query to pure Python (#2106)
Browse files Browse the repository at this point in the history
* Move Query to pure Python
* Add type hints, add docstring, improve code clarity
  • Loading branch information
kounelisagis committed Dec 20, 2024
1 parent 66c401d commit 35ca257
Show file tree
Hide file tree
Showing 11 changed files with 301 additions and 292 deletions.
2 changes: 1 addition & 1 deletion doc/source/python-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ Sparse Array
Query
---------------

.. autoclass:: tiledb.libtiledb.Query
.. autoclass:: tiledb.Query
:members:

Query Condition
Expand Down
2 changes: 1 addition & 1 deletion tiledb/cc/query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void init_query(py::module &m) {
q.set_validity_buffer(name, (uint8_t *)(a.data()), a.size());
})

.def("submit", &Query::submit, py::call_guard<py::gil_scoped_release>())
.def("_submit", &Query::submit, py::call_guard<py::gil_scoped_release>())

/** hackery from another branch... */
//.def("set_fragment_uri", &Query::set_fragment_uri)
Expand Down
6 changes: 3 additions & 3 deletions tiledb/dense_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
replace_ellipsis,
replace_scalars_slice,
)
from .libtiledb import Query
from .query import Query
from .subarray import Subarray


Expand Down Expand Up @@ -173,7 +173,7 @@ def query(
attrs=attrs,
cond=cond,
dims=dims,
coords=coords,
has_coords=coords,
order=order,
use_arrow=use_arrow,
return_arrow=return_arrow,
Expand Down Expand Up @@ -666,7 +666,7 @@ def read_subarray(self, subarray):
if has_labels:
label_query = Query(self, self.ctx)
label_query.set_subarray(subarray)
label_query.submit()
label_query._submit()
if not label_query.is_complete():
raise tiledb.TileDBError("Failed to get dimension ranges from labels")
result_subarray = Subarray(self, self.ctx)
Expand Down
2 changes: 1 addition & 1 deletion tiledb/domain_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __getitem__(self, idx):
self.query.attrs if self.query.attrs else attr_names
) # query.attrs might be None -> all
attr_cond = self.query.attr_cond
coords = self.query.coords
coords = self.query.has_coords

if coords:
attr_names = [
Expand Down
1 change: 0 additions & 1 deletion tiledb/libtiledb.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,6 @@ cdef extern from "tiledb/tiledb_experimental.h":
)

# Free helper functions
cpdef unicode ustring(object s)
cdef _raise_tiledb_error(tiledb_error_t* err_ptr)
cdef _raise_ctx_err(tiledb_ctx_t* ctx_ptr, int rc)

Expand Down
269 changes: 0 additions & 269 deletions tiledb/libtiledb.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,9 @@ from cpython.version cimport PY_MAJOR_VERSION

include "common.pxi"
include "indexing.pyx"
import collections.abc
from json import loads as json_loads

from .cc import TileDBError
from .ctx import Config, Ctx, default_ctx
from .domain_indexer import DomainIndexer
from .vfs import VFS
from .sparse_array import SparseArrayImpl
from .dense_array import DenseArrayImpl
from .array import Array

###############################################################################
Expand Down Expand Up @@ -349,269 +343,6 @@ cdef _raise_ctx_err(tiledb_ctx_t* ctx_ptr, int rc):
_raise_tiledb_error(err_ptr)


cpdef unicode ustring(object s):
    """
    Coerce a Python object to a unicode string.

    :param s: object to coerce
    :returns: ``s`` itself when it is exactly ``unicode``; the ASCII-decoded
        text when ``s`` is ``bytes`` and ``PY_MAJOR_VERSION < 3``; otherwise
        ``unicode(s)`` for instances of ``unicode`` subclasses
    :raises TypeError: if none of the cases above applies
    """
    # Fast path: an exact unicode instance is returned unchanged.
    if type(s) is unicode:
        return <unicode> s

    # Bytes are only accepted (and decoded as ASCII) when the major version
    # is below 3.
    if isinstance(s, bytes) and PY_MAJOR_VERSION < 3:
        return (<bytes> s).decode('ascii')

    # Subclass instances are converted through the unicode constructor.
    if isinstance(s, unicode):
        return unicode(s)

    message = (
        "ustring() must be a string or a bytes-like object"
        ", not {0!r}"
    ).format(type(s))
    raise TypeError(message)


cdef bytes unicode_path(object path):
    """
    Return the UTF-8 encoded bytes of a URI path string.

    The path is first coerced with ``ustring()``, so any ``TypeError`` it
    raises propagates to the caller.
    """
    cdef unicode text = ustring(path)
    return text.encode('UTF-8')


###############################################################################
# #
# CLASS DEFINITIONS #
# #
###############################################################################


cdef class Query(object):
    """
    Proxy object returned by query() to index into the original array
    on a subselection of attributes in a defined layout order.

    See documentation of Array.query
    """

    def __init__(self, array, attrs=None, cond=None, dims=None,
                 coords=False, index_col=True, order=None,
                 use_arrow=None, return_arrow=False, return_incomplete=False):
        """
        Validate the query parameters and store them on the proxy.

        :param array: the array being queried; must be open in 'r' or 'd' mode
        :param attrs: names of attributes to select, or None
        :param cond: query condition forwarded to the array on indexing
        :param dims: ``False``, ``True``, ``None`` or an iterable of
            dimension names to select
        :param coords: when true, all dimensions of the array domain are
            selected; cannot be combined with ``dims``
        :param index_col: stored as-is on the query
        :param order: layout order; when None, 'U' is used for sparse arrays
            and 'C' for dense arrays
        :param use_arrow: stored on the query; defaults to True when
            ``return_arrow`` is set and ``use_arrow`` is None
        :param return_arrow: requires ``use_arrow`` to be enabled
        :param return_incomplete: only accepted for sparse arrays
        :raises ValueError: if the array mode is not 'r' or 'd', or if both
            ``dims`` and ``coords=True`` are given
        :raises TileDBError: if a selected dimension or attribute does not
            exist, if ``return_arrow`` is combined with ``use_arrow=False``,
            or if ``return_incomplete`` is requested for a dense array
        """
        if array.mode not in ('r', 'd'):
            raise ValueError("array mode must be read or delete mode")

        if dims is not None and coords == True:
            raise ValueError("Cannot pass both dims and coords=True to Query")

        if dims is False:
            self.dims = False
        elif dims is not None and dims != True:
            # Explicit dimension selection: every name must exist in the domain.
            domain = array.schema.domain
            for dname in dims:
                if not domain.has_dim(dname):
                    raise TileDBError(f"Selected dimension does not exist: '{dname}'")
            self.dims = [unicode(dname) for dname in dims]
        elif coords == True or dims == True:
            # Select every dimension of the domain.
            domain = array.schema.domain
            self.dims = [domain.dim(i).name for i in range(domain.ndim)]

        if attrs is not None:
            for name in attrs:
                if not array.schema.has_attr(name):
                    raise TileDBError(f"Selected attribute does not exist: '{name}'")
        self.attrs = attrs
        self.cond = cond

        if order is None:
            if array.schema.sparse:
                self.order = 'U'  # unordered
            else:
                self.order = 'C'  # row-major
        else:
            self.order = order

        # reference to the array we are querying
        self.array = array
        self.coords = coords
        self.index_col = index_col
        self.return_arrow = return_arrow
        if return_arrow:
            # Returning Arrow data implies using Arrow unless explicitly disabled.
            if use_arrow is None:
                use_arrow = True
            if not use_arrow:
                raise TileDBError("Cannot initialize return_arrow with use_arrow=False")
        self.use_arrow = use_arrow

        if return_incomplete and not array.schema.sparse:
            raise TileDBError("Incomplete queries are only supported for sparse arrays at this time")

        self.return_incomplete = return_incomplete

        self.domain_index = DomainIndexer(array, query=self)

    def __getitem__(self, object selection):
        """
        Index the underlying array with the stored query parameters.

        :param selection: selection forwarded to ``Array.subarray``
        :raises TileDBError: if the query was created with ``return_arrow=True``
        """
        if self.return_arrow:
            raise TileDBError("`return_arrow=True` requires `.df` indexer")

        return self.array.subarray(selection,
                                   attrs=self.attrs,
                                   cond=self.cond,
                                   coords=self.coords if self.coords else self.dims,
                                   order=self.order)

    def agg(self, aggs):
        """
        Calculate an aggregate operation for a given attribute. Available
        operations are sum, min, max, mean, count, and null_count (for nullable
        attributes only). Aggregates may be combined with other query operations
        such as query conditions and slicing.

        The input may be a single operation, a list of operations, or a
        dictionary with attribute mapping to a single operation or list of
        operations.

        For undefined operations on max and min, which can occur when a nullable
        attribute contains only nulled data at the given coordinates or when
        there is no data read for the given query (e.g. query conditions that do
        not match any values or coordinates that contain no data), invalid
        results are represented as np.nan for attributes of floating point types
        and None for integer types.

        >>> import tiledb, tempfile, numpy as np
        >>> path = tempfile.mkdtemp()

        >>> with tiledb.from_numpy(path, np.arange(1, 10)) as A:
        ...     pass

        >>> # Note that tiledb.from_numpy creates anonymous attributes, so the
        >>> # name of the attribute is represented as an empty string

        >>> with tiledb.open(path, 'r') as A:
        ...     A.query().agg("sum")[:]
        45

        >>> with tiledb.open(path, 'r') as A:
        ...     A.query(cond="attr('') < 5").agg(["count", "mean"])[:]
        {'count': 9, 'mean': 2.5}

        >>> with tiledb.open(path, 'r') as A:
        ...     A.query().agg({"": ["max", "min"]})[2:7]
        {'max': 7, 'min': 3}

        :param aggs: The input attributes and operations to apply aggregations on
        :returns: single value for single operation on one attribute, a dictionary
            of attribute keys associated with a single value for a single operation
            across multiple attributes, or a dictionary of attribute keys that maps
            to a dictionary of operation labels with the associated value
        """
        schema = self.array.schema
        attr_to_aggs_map = {}
        if isinstance(aggs, dict):
            # Normalize every value to a tuple of operation names.
            attr_to_aggs_map = {
                a: (
                    tuple([aggs[a]])
                    if isinstance(aggs[a], str)
                    else tuple(aggs[a])
                )
                for a in aggs
            }
        elif isinstance(aggs, str):
            # A single operation applies to every attribute in the schema.
            attrs = tuple(schema.attr(i).name for i in range(schema.nattr))
            attr_to_aggs_map = {a: (aggs,) for a in attrs}
        elif isinstance(aggs, collections.abc.Sequence):
            # A sequence of operations applies to every attribute in the schema.
            attrs = tuple(schema.attr(i).name for i in range(schema.nattr))
            attr_to_aggs_map = {a: tuple(aggs) for a in attrs}

        from .aggregation import Aggregation
        return Aggregation(self, attr_to_aggs_map)

    @property
    def array(self):
        """The array this query was created for."""
        return self.array

    @property
    def attrs(self):
        """List of attributes to include in Query."""
        return self.attrs

    @property
    def cond(self):
        """QueryCondition used to filter attributes or dimensions in Query."""
        return self.cond

    @property
    def dims(self):
        """List of dimensions to include in Query."""
        return self.dims

    @property
    def coords(self):
        """
        True if query should include (return) coordinate values.

        :rtype: bool
        """
        return self.coords

    @property
    def order(self):
        """Layout order of the query: the ``order`` argument if given,
        otherwise 'U' for sparse arrays and 'C' for dense arrays."""
        return self.order

    @property
    def index_col(self):
        """List of columns to set as index for dataframe queries, or None."""
        return self.index_col

    @property
    def use_arrow(self):
        """The ``use_arrow`` setting stored at construction time."""
        return self.use_arrow

    @property
    def return_arrow(self):
        """The ``return_arrow`` setting stored at construction time."""
        return self.return_arrow

    @property
    def return_incomplete(self):
        """The ``return_incomplete`` setting stored at construction time."""
        return self.return_incomplete

    @property
    def domain_index(self):
        """Apply Array.domain_index with query parameters."""
        return self.domain_index

    def label_index(self, labels):
        """Apply Array.label_index with query parameters."""
        # Delayed to avoid circular import
        from .multirange_indexing import LabelIndexer
        return LabelIndexer(self.array, tuple(labels), query=self)

    @property
    def multi_index(self):
        """Apply Array.multi_index with query parameters."""
        # Delayed to avoid circular import
        from .multirange_indexing import MultiRangeIndexer
        return MultiRangeIndexer(self.array, query=self)

    @property
    def df(self):
        """Apply Array.multi_index with query parameters and return result
        as a Pandas dataframe."""
        # Delayed to avoid circular import
        from .multirange_indexing import DataFrameIndexer
        return DataFrameIndexer(self.array, query=self, use_arrow=self.use_arrow)

    def get_stats(self, print_out=True, json=False):
        """Retrieves the stats from a TileDB query.

        :param print_out: Print string to console (default True), or return as string
        :param json: Return stats JSON object (default: False)
        :returns: an empty string when the array has no ``pyquery``; the stats
            when ``print_out`` is false; otherwise None (the stats are printed)
        """
        pyquery = self.array.pyquery
        if pyquery is None:
            return ""

        stats = pyquery.get_stats()
        if json:
            stats = json_loads(stats)

        if print_out:
            print(stats)
        else:
            return stats

    def submit(self):
        """An alias for calling the regular indexer [:]"""
        return self[:]

def write_direct_dense(self: Array, np.ndarray array not None, **kw):
"""
Write directly to given array attribute with minimal checks,
Expand Down
8 changes: 4 additions & 4 deletions tiledb/multirange_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
from .array_schema import ArraySchema
from .cc import TileDBError
from .dataframe_ import check_dataframe_deps
from .libtiledb import Query as QueryProxy
from .main import PyAgg, PyQuery, increment_stat, use_stats
from .metadata import Metadata
from .query import Query
from .query import Query as QueryProxy
from .query_condition import QueryCondition
from .subarray import Subarray

Expand Down Expand Up @@ -422,7 +422,7 @@ def __init__(
check_dataframe_deps()
# we need to use a Query in order to get coords for a dense array
if not query:
query = QueryProxy(array, coords=True)
query = QueryProxy(array, has_coords=True)
use_arrow = (
bool(importlib.util.find_spec("pyarrow"))
if use_arrow is None
Expand Down Expand Up @@ -586,7 +586,7 @@ def _run_query(self) -> Dict[str, np.ndarray]:
# If querying by label and the label query is not yet complete, run the label
# query and update the pyquery with the actual dimensions.
if self.label_query is not None and not self.label_query.is_complete():
self.label_query.submit()
self.label_query._submit()

if not self.label_query.is_complete():
raise TileDBError("failed to get dimension ranges from labels")
Expand Down Expand Up @@ -687,7 +687,7 @@ def _iter_dim_names(
if query is not None:
if query.dims is not None:
return iter(query.dims or ())
if query.coords is False:
if query.has_coords is False:
return iter(())
if not schema.sparse:
return iter(())
Expand Down
Loading

0 comments on commit 35ca257

Please sign in to comment.