MOD: Upgrade Python client to databento_dbn 0.14.2
nmacholl committed Nov 20, 2023
1 parent 17f1a02 commit f361be3
Showing 69 changed files with 799 additions and 466 deletions.
23 changes: 23 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,28 @@
# Changelog

## 0.24.0 - TBD

This release adds support for DBN v2.

#### Enhancements
- Improved the performance of stream writes in the `Live` client
- Upgraded `databento-dbn` to 0.14.2
- Added the `databento.common.types` module to hold common type annotations (see the sketch just below)
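
A minimal import sketch for the new module (the callback body is illustrative only, and `RecordCallback` is assumed to be a `Callable[[DBNRecord], None]` alias):

```python
from databento.common.types import DBNRecord
from databento.common.types import RecordCallback


def on_record(record: DBNRecord) -> None:
    # Illustrative callback: print the record type name
    print(type(record).__name__)


callback: RecordCallback = on_record
```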

#### Breaking Changes
- `DBNStore` iteration and `DBNStore.replay` will upgrade DBN version 1 messages to version 2 (see the sketch after this list)
- `Live` client iteration and callbacks upgrade DBN version 1 messages to version 2
- Moved the `DBNRecord`, `RecordCallback`, and `ExceptionCallback` types to the `databento.common.types` module
- Moved `AUTH_TIMEOUT_SECONDS` and `CONNECT_TIMEOUT_SECONDS` constants from the `databento.live` module to `databento.live.session`
- Moved `INT64_NULL` from the `databento.common.dbnstore` module to `databento.common.constants`
- Moved `SCHEMA_STRUCT_MAP` from the `databento.common.data` module to `databento.common.constants`
- Removed the `schema` parameter from the `DataFrameIterator` constructor; use `struct_type` instead
- Removed `NON_SCHEMA_RECORD_TYPES` constant as it is no longer used
- Removed `DERIV_SCHEMAS` constant as it is no longer used
- Removed `SCHEMA_COLUMNS` constant as it is no longer used
- Removed `SCHEMA_DTYPES_MAP` constant as it is no longer used
- Removed empty `databento.common.data` module
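
A hedged sketch of the upgrade behavior described above (the file path is hypothetical):

```python
import databento as db

# Reading a DBN v1 file: iteration now yields upgraded DBN v2 records,
# e.g. InstrumentDefMsg rather than InstrumentDefMsgV1.
store = db.DBNStore.from_file("example.definition.dbn.zst")  # hypothetical path
for record in store:
    print(type(record).__name__)
```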

## 0.23.1 - 2023-11-10

#### Enhancements
2 changes: 1 addition & 1 deletion README.md
@@ -32,7 +32,7 @@ The library is fully compatible with the latest distribution of Anaconda 3.8 and
The minimum dependencies as found in the `pyproject.toml` are also listed below:
- python = "^3.8"
- aiohttp = "^3.8.3"
- databento-dbn = "0.13.0"
- databento-dbn = "0.14.2"
- numpy = ">=1.23.5"
- pandas = ">=1.5.3"
- requests = ">=2.24.0"
2 changes: 1 addition & 1 deletion databento/__init__.py
@@ -37,9 +37,9 @@
from databento.common.publishers import Publisher
from databento.common.publishers import Venue
from databento.common.symbology import InstrumentMap
from databento.common.types import DBNRecord
from databento.historical.api import API_VERSION
from databento.historical.client import Historical
from databento.live import DBNRecord
from databento.live.client import Live
from databento.version import __version__ # noqa

60 changes: 59 additions & 1 deletion databento/common/constants.py
@@ -1 +1,59 @@
ALL_SYMBOLS = "ALL_SYMBOLS"
from __future__ import annotations

from typing import Final

import numpy as np
from databento_dbn import ImbalanceMsg
from databento_dbn import InstrumentDefMsg
from databento_dbn import InstrumentDefMsgV1
from databento_dbn import MBOMsg
from databento_dbn import MBP1Msg
from databento_dbn import MBP10Msg
from databento_dbn import OHLCVMsg
from databento_dbn import Schema
from databento_dbn import StatMsg
from databento_dbn import TradeMsg

from databento.common.types import DBNRecord


ALL_SYMBOLS: Final = "ALL_SYMBOLS"


DEFINITION_TYPE_MAX_MAP: Final = {
x[0]: np.iinfo(x[1]).max
for x in InstrumentDefMsg._dtypes
if not isinstance(x[1], str)
}

INT64_NULL: Final = 9223372036854775807

SCHEMA_STRUCT_MAP: Final[dict[Schema, type[DBNRecord]]] = {
Schema.DEFINITION: InstrumentDefMsg,
Schema.IMBALANCE: ImbalanceMsg,
Schema.MBO: MBOMsg,
Schema.MBP_1: MBP1Msg,
Schema.MBP_10: MBP10Msg,
Schema.OHLCV_1S: OHLCVMsg,
Schema.OHLCV_1M: OHLCVMsg,
Schema.OHLCV_1H: OHLCVMsg,
Schema.OHLCV_1D: OHLCVMsg,
Schema.STATISTICS: StatMsg,
Schema.TBBO: MBP1Msg,
Schema.TRADES: TradeMsg,
}

SCHEMA_STRUCT_MAP_V1: Final[dict[Schema, type[DBNRecord]]] = {
Schema.DEFINITION: InstrumentDefMsgV1,
Schema.IMBALANCE: ImbalanceMsg,
Schema.MBO: MBOMsg,
Schema.MBP_1: MBP1Msg,
Schema.MBP_10: MBP10Msg,
Schema.OHLCV_1S: OHLCVMsg,
Schema.OHLCV_1M: OHLCVMsg,
Schema.OHLCV_1H: OHLCVMsg,
Schema.OHLCV_1D: OHLCVMsg,
Schema.STATISTICS: StatMsg,
Schema.TBBO: MBP1Msg,
Schema.TRADES: TradeMsg,
}
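
A hedged usage sketch for the two maps above, assuming they are imported from `databento.common.constants` as arranged in this commit:

```python
from databento_dbn import Schema

from databento.common.constants import SCHEMA_STRUCT_MAP
from databento.common.constants import SCHEMA_STRUCT_MAP_V1

# Resolve the record struct for a schema under each DBN version.
assert SCHEMA_STRUCT_MAP[Schema.TRADES].__name__ == "TradeMsg"
assert SCHEMA_STRUCT_MAP_V1[Schema.DEFINITION].__name__ == "InstrumentDefMsgV1"
```
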
73 changes: 0 additions & 73 deletions databento/common/data.py

This file was deleted.

99 changes: 55 additions & 44 deletions databento/common/dbnstore.py
@@ -29,36 +29,26 @@
from databento_dbn import Compression
from databento_dbn import DBNDecoder
from databento_dbn import Encoding
from databento_dbn import ErrorMsg
from databento_dbn import InstrumentDefMsg
from databento_dbn import InstrumentDefMsgV1
from databento_dbn import Metadata
from databento_dbn import Schema
from databento_dbn import SType
from databento_dbn import SymbolMappingMsg
from databento_dbn import SystemMsg
from databento_dbn import Transcoder
from databento_dbn import VersionUpgradePolicy

from databento.common.data import DEFINITION_TYPE_MAX_MAP
from databento.common.data import SCHEMA_COLUMNS
from databento.common.data import SCHEMA_DTYPES_MAP
from databento.common.data import SCHEMA_STRUCT_MAP
from databento.common.constants import DEFINITION_TYPE_MAX_MAP
from databento.common.constants import INT64_NULL
from databento.common.constants import SCHEMA_STRUCT_MAP
from databento.common.constants import SCHEMA_STRUCT_MAP_V1
from databento.common.error import BentoError
from databento.common.iterator import chunk
from databento.common.symbology import InstrumentMap
from databento.common.types import DBNRecord
from databento.common.validation import validate_enum
from databento.common.validation import validate_file_write_path
from databento.common.validation import validate_maybe_enum
from databento.live import DBNRecord


NON_SCHEMA_RECORD_TYPES = [
ErrorMsg,
SymbolMappingMsg,
SystemMsg,
Metadata,
]

INT64_NULL = 9223372036854775807

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
@@ -380,7 +370,9 @@ def __init__(self, data_source: DataSource) -> None:

def __iter__(self) -> Generator[DBNRecord, None, None]:
reader = self.reader
decoder = DBNDecoder()
decoder = DBNDecoder(
upgrade_policy=VersionUpgradePolicy.UPGRADE,
)
while True:
raw = reader.read(DBNStore.DBN_READ_SIZE)
if raw:
@@ -936,8 +928,8 @@ def to_df(

df_iter = DataFrameIterator(
records=records,
schema=schema,
count=count,
struct_type=self._schema_struct_map[schema],
instrument_map=self._instrument_map,
price_type=price_type,
pretty_ts=pretty_ts,
@@ -1084,13 +1076,13 @@ def to_ndarray(
ndarray_iter: NDArrayIterator

if self.schema is None:
# If schema is None, we're handling heterogeneous data from the live client.
# This is less performant because the records of a given schema are not contiguous in memory.
# If schema is None, we're handling heterogeneous data from the live client
# This is less performant because the records of a given schema are not contiguous in memory
if schema is None:
raise ValueError("a schema must be specified for mixed DBN data")

schema_struct = SCHEMA_STRUCT_MAP[schema]
schema_dtype = SCHEMA_DTYPES_MAP[schema]
schema_struct = self._schema_struct_map[schema]
schema_dtype = schema_struct._dtypes
schema_filter = filter(lambda r: isinstance(r, schema_struct), self)

ndarray_iter = NDArrayBytesIterator(
@@ -1099,8 +1091,9 @@
count=count,
)
else:
# If schema is set, we're handling homogeneous historical data.
schema_dtype = SCHEMA_DTYPES_MAP[self.schema]
# If schema is set, we're handling homogeneous historical data
schema_struct = self._schema_struct_map[self.schema]
schema_dtype = schema_struct._dtypes

if self._metadata.ts_out:
schema_dtype.append(("ts_out", "u8"))
@@ -1145,15 +1138,36 @@ def _transcode(
pretty_ts=pretty_ts,
has_metadata=True,
map_symbols=map_symbols,
symbol_map=symbol_map, # type: ignore [arg-type]
symbol_interval_map=symbol_map, # type: ignore [arg-type]
schema=schema,
)

transcoder.write(bytes(self.metadata))
for records in chunk(self, 2**16):
for record in records:
transcoder.write(bytes(record))
transcoder.flush()
reader = self.reader
transcoder.write(reader.read(self._metadata_length))
while byte_chunk := reader.read(2**16):
transcoder.write(byte_chunk)

if transcoder.buffer():
raise BentoError(
"DBN file is truncated or contains an incomplete record",
)

transcoder.flush()

@property
def _schema_struct_map(self) -> dict[Schema, type[DBNRecord]]:
"""
Return a mapping of Schema variants to DBNRecord types based on the DBN
metadata version.

Returns
-------
dict[Schema, type[DBNRecord]]
"""
if self.metadata.version == 1:
return SCHEMA_STRUCT_MAP_V1
return SCHEMA_STRUCT_MAP


class NDArrayIterator(Protocol):
@@ -1263,31 +1277,30 @@ def __init__(
self,
records: Iterator[np.ndarray[Any, Any]],
count: int | None,
schema: Schema,
struct_type: type[DBNRecord],
instrument_map: InstrumentMap,
price_type: Literal["fixed", "float", "decimal"] = "float",
pretty_ts: bool = True,
map_symbols: bool = True,
):
self._records = records
self._schema = schema
self._count = count
self._struct_type = struct_type
self._price_type = price_type
self._pretty_ts = pretty_ts
self._map_symbols = map_symbols
self._instrument_map = instrument_map
self._struct = SCHEMA_STRUCT_MAP[schema]

def __iter__(self) -> DataFrameIterator:
return self

def __next__(self) -> pd.DataFrame:
df = pd.DataFrame(
next(self._records),
columns=SCHEMA_COLUMNS[self._schema],
columns=self._struct_type._ordered_fields,
)

if self._schema == Schema.DEFINITION:
if self._struct_type in (InstrumentDefMsg, InstrumentDefMsgV1):
self._format_definition_fields(df)

self._format_hidden_fields(df)
@@ -1310,8 +1323,8 @@ def _format_definition_fields(self, df: pd.DataFrame) -> None:
df[column] = df[column].where(df[column] != type_max, np.nan)

def _format_hidden_fields(self, df: pd.DataFrame) -> None:
for column, dtype in SCHEMA_DTYPES_MAP[self._schema]:
hidden_fields = self._struct._hidden_fields
for column, dtype in self._struct_type._dtypes:
hidden_fields = self._struct_type._hidden_fields
if dtype.startswith("S") and column not in hidden_fields:
df[column] = df[column].str.decode("utf-8")

@@ -1328,7 +1341,7 @@ def _format_px(
df: pd.DataFrame,
price_type: Literal["fixed", "float", "decimal"],
) -> None:
px_fields = self._struct._price_fields
px_fields = self._struct_type._price_fields

if price_type == "decimal":
for field in px_fields:
@@ -1343,11 +1356,9 @@ def _format_px(
return # do nothing

def _format_pretty_ts(self, df: pd.DataFrame) -> None:
for field in self._struct._timestamp_fields:
for field in self._struct_type._timestamp_fields:
df[field] = pd.to_datetime(df[field], utc=True, errors="coerce")

def _format_set_index(self, df: pd.DataFrame) -> None:
index_column = (
"ts_event" if self._schema.value.startswith("ohlcv") else "ts_recv"
)
index_column = self._struct_type._ordered_fields[0]
df.set_index(index_column, inplace=True)