MOD: Improve performance of DBNStore.to_ndarray
nmacholl committed Nov 10, 2023
1 parent 4048e4c commit e53d3ca
Showing 2 changed files with 122 additions and 23 deletions.
137 changes: 118 additions & 19 deletions databento/common/dbnstore.py
@@ -10,7 +10,16 @@
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, BinaryIO, Callable, Literal, overload
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    BinaryIO,
+    Callable,
+    Literal,
+    Protocol,
+    overload,
+)
 
 import databento_dbn
 import numpy as np
@@ -1072,20 +1081,43 @@ def to_ndarray(
         """
         schema = validate_maybe_enum(schema, Schema, "schema")
-        if schema is None:
-            if self.schema is None:
+        ndarray_iter: NDArrayIterator
+
+        if self.schema is None:
+            # If schema is None, we're handling heterogeneous data from the live client.
+            # This is less performant because the records of a given schema are not contiguous in memory.
+            if schema is None:
                 raise ValueError("a schema must be specified for mixed DBN data")
-            schema = self.schema
 
-        dtype = SCHEMA_DTYPES_MAP[schema]
-        ndarray_iter = NDArrayIterator(
-            filter(lambda r: isinstance(r, SCHEMA_STRUCT_MAP[schema]), self),
-            dtype,
-            count,
-        )
+            schema_struct = SCHEMA_STRUCT_MAP[schema]
+            schema_dtype = SCHEMA_DTYPES_MAP[schema]
+            schema_filter = filter(lambda r: isinstance(r, schema_struct), self)
+
+            ndarray_iter = NDArrayBytesIterator(
+                records=map(bytes, schema_filter),
+                dtype=schema_dtype,
+                count=count,
+            )
+        else:
+            # If schema is set, we're handling homogeneous historical data.
+            schema_dtype = SCHEMA_DTYPES_MAP[self.schema]
+
+            if self._metadata.ts_out:
+                schema_dtype.append(("ts_out", "u8"))
+
+            if schema is not None and schema != self.schema:
+                # This is to maintain identical behavior with NDArrayBytesIterator
+                ndarray_iter = iter([np.empty([0, 1], dtype=schema_dtype)])
+            else:
+                ndarray_iter = NDArrayStreamIterator(
+                    reader=self.reader,
+                    dtype=schema_dtype,
+                    offset=self._metadata_length,
+                    count=count,
+                )
 
         if count is None:
-            return next(ndarray_iter, np.empty([0, 1], dtype=dtype))
+            return next(ndarray_iter, np.empty([0, 1], dtype=schema_dtype))
 
         return ndarray_iter
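
The branch structure above is the core of the speedup: homogeneous historical data is now decoded directly from the underlying byte stream, while mixed-schema live data keeps the slower per-record path. A minimal sketch of the fast path's idea, with an illustrative two-field layout standing in for the real SCHEMA_DTYPES_MAP entries:

    import numpy as np

    # Illustrative fixed-width record layout; real DBN schemas have more fields.
    record_dtype = np.dtype([("ts_event", "u8"), ("price", "i8")])  # 16 bytes

    # Homogeneous case: records sit back to back, so one frombuffer call
    # decodes everything without creating per-record Python objects.
    raw = b"\x00" * (record_dtype.itemsize * 4)  # four zeroed records
    array = np.frombuffer(raw, dtype=record_dtype)
    assert array.shape == (4,)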

@@ -1124,10 +1156,66 @@ def _transcode(
         transcoder.flush()
 
 
-class NDArrayIterator:
+class NDArrayIterator(Protocol):
+    @abc.abstractmethod
+    def __iter__(self) -> NDArrayIterator:
+        ...
+
+    @abc.abstractmethod
+    def __next__(self) -> np.ndarray[Any, Any]:
+        ...
+
+
+class NDArrayStreamIterator(NDArrayIterator):
+    """
+    Iterator for homogeneous byte streams of DBN records.
+    """
+
+    def __init__(
+        self,
+        reader: IO[bytes],
+        dtype: list[tuple[str, str]],
+        offset: int = 0,
+        count: int | None = None,
+    ) -> None:
+        self._reader = reader
+        self._dtype = np.dtype(dtype)
+        self._offset = offset
+        self._count = count
+
+        self._reader.seek(offset)
+
+    def __iter__(self) -> NDArrayStreamIterator:
+        return self
+
+    def __next__(self) -> np.ndarray[Any, Any]:
+        if self._count is None:
+            read_size = -1
+        else:
+            read_size = self._dtype.itemsize * max(self._count, 1)
+
+        if buffer := self._reader.read(read_size):
+            try:
+                return np.frombuffer(
+                    buffer=buffer,
+                    dtype=self._dtype,
+                )
+            except ValueError:
+                raise BentoError(
+                    "DBN file is truncated or contains an incomplete record",
+                )
+
+        raise StopIteration
+
+
+class NDArrayBytesIterator(NDArrayIterator):
+    """
+    Iterator for heterogeneous streams of DBN records.
+    """
+
     def __init__(
         self,
-        records: Iterator[DBNRecord],
+        records: Iterator[bytes],
         dtype: list[tuple[str, str]],
        count: int | None,
    ):
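
Making NDArrayIterator a typing.Protocol lets to_ndarray declare one return type while the concrete iterators satisfy it structurally. A rough sketch of the pattern (shortened names, not the library's actual code):

    from typing import Any, Protocol

    import numpy as np

    class ArrayIter(Protocol):
        def __iter__(self) -> "ArrayIter": ...
        def __next__(self) -> np.ndarray[Any, Any]: ...

    def total_rows(chunks: ArrayIter) -> int:
        # Accepts NDArrayStreamIterator, NDArrayBytesIterator, or anything
        # else exposing the same __iter__/__next__ shape.
        return sum(len(chunk) for chunk in chunks)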
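
NDArrayStreamIterator relies on historical DBN files storing fixed-size records contiguously after the metadata header: it seeks past the header once, then each __next__ reads count * itemsize bytes and reinterprets them in place. A standalone sketch of that loop over an in-memory stream, assuming a hypothetical 16-byte record:

    from io import BytesIO

    import numpy as np

    dtype = np.dtype([("ts_event", "u8"), ("size", "u4"), ("flags", "u4")])  # illustrative
    stream = BytesIO(b"\x00" * (dtype.itemsize * 10))  # ten fake records
    batch = 4

    while buffer := stream.read(dtype.itemsize * batch):
        # A buffer that is not a whole number of records would raise
        # ValueError here, which the class above surfaces as BentoError.
        chunk = np.frombuffer(buffer, dtype=dtype)
        print(len(chunk))  # 4, 4, then 2 for the final partial batch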
@@ -1144,22 +1232,33 @@ def __next__(self) -> np.ndarray[Any, Any]:
         num_records = 0
         for record in itertools.islice(self._records, self._count):
             num_records += 1
-            record_bytes.write(bytes(record))
+            record_bytes.write(record)
 
         if num_records == 0:
             if self._first_next:
                 return np.empty([0, 1], dtype=self._dtype)
             raise StopIteration
 
         self._first_next = False
-        return np.frombuffer(
-            record_bytes.getvalue(),
-            dtype=self._dtype,
-            count=num_records,
-        )
+
+        try:
+            return np.frombuffer(
+                record_bytes.getbuffer(),
+                dtype=self._dtype,
+                count=num_records,
+            )
+        except ValueError:
+            raise BentoError(
+                "DBN file is truncated or contains an incomplete record",
+            )
 
 
 class DataFrameIterator:
+    """
+    Iterator for DataFrames that supports batching and column formatting for
+    DBN records.
+    """
+
     def __init__(
         self,
         records: Iterator[np.ndarray[Any, Any]],
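
On the mixed-schema path, NDArrayBytesIterator accumulates each batch of record bytes in a BytesIO and decodes the batch with a single frombuffer call; itertools.islice caps a batch at count records, and getbuffer() avoids the extra copy getvalue() would make. A condensed sketch of that batching pattern with an illustrative one-field dtype:

    import itertools
    from io import BytesIO

    import numpy as np

    dtype = np.dtype([("ts_event", "u8")])
    records = iter([b"\x00" * dtype.itemsize] * 10)  # ten fake records
    count = 3

    while True:
        batch, num_records = BytesIO(), 0
        for record in itertools.islice(records, count):
            num_records += 1
            batch.write(record)
        if num_records == 0:
            break
        chunk = np.frombuffer(batch.getbuffer(), dtype=dtype, count=num_records)
        print(len(chunk))  # 3, 3, 3, then 1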
8 changes: 4 additions & 4 deletions tests/test_historical_bento.py
@@ -905,8 +905,8 @@ def test_dbnstore_to_ndarray_with_count(
     # Act
     dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
 
-    nd_iter = dbnstore.to_ndarray(count=count)
     expected = dbnstore.to_ndarray()
+    nd_iter = dbnstore.to_ndarray(count=count)
 
     # Assert
     aggregator: list[np.ndarray[Any, Any]] = []
@@ -935,8 +935,8 @@ def test_dbnstore_to_ndarray_with_schema(
     # Act
     dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
 
-    actual = dbnstore.to_ndarray(schema=schema)
     expected = dbnstore.to_ndarray()
+    actual = dbnstore.to_ndarray(schema=schema)
 
     # Assert
     for i, row in enumerate(actual):
@@ -1014,8 +1014,8 @@ def test_dbnstore_to_df_with_count(
     # Act
     dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
 
-    df_iter = dbnstore.to_df(count=count)
     expected = dbnstore.to_df()
+    df_iter = dbnstore.to_df(count=count)
 
     # Assert
     aggregator: list[pd.DataFrame] = []
@@ -1048,8 +1048,8 @@ def test_dbnstore_to_df_with_schema(
     # Act
     dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
 
-    actual = dbnstore.to_df(schema=schema)
     expected = dbnstore.to_df()
+    actual = dbnstore.to_df(schema=schema)
 
     # Assert
     assert actual.equals(expected)
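
The test edits are pure reorderings: expected is now materialized before the iterator under test is constructed. This is presumably because the new stream iterator reads lazily from the store's shared reader, so creating it first and then calling to_ndarray() or to_df() again could interleave reads on the same stream, roughly:

    # Hypothetical hazard with a lazy iterator over a shared reader:
    nd_iter = dbnstore.to_ndarray(count=2)  # seeks/holds the shared reader
    expected = dbnstore.to_ndarray()        # reads the same underlying stream
    next(nd_iter)                           # position may not be where nd_iter left it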
