Skip to content

Commit

Permalink
FIX: Fix handling of None in symbol list parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
nmacholl committed Aug 7, 2024
1 parent f61adb2 commit f4f8c3f
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 44 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## 0.39.1 - TBD

#### Bug fixes
- Fixed an issue where a symbol list which contained a `None` would produce a convoluted exception

## 0.39.0 - 2024-07-30

#### Enhancements
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The library is fully compatible with the latest distribution of Anaconda 3.8 and
The minimum dependencies as found in the `pyproject.toml` are also listed below:
- python = "^3.8"
- aiohttp = "^3.8.3"
- databento-dbn = "0.19.1"
- databento-dbn = "0.20.0"
- numpy= ">=1.23.5"
- pandas = ">=1.5.3"
- pip-system-certs = ">=4.0" (Windows only)
Expand Down
62 changes: 36 additions & 26 deletions databento/common/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,60 +58,70 @@ def optional_values_list_to_string(
return values_list_to_string(values)


@singledispatch
def optional_symbols_list_to_list(
symbols: Iterable[str | int | Integral] | str | int | Integral | None,
stype_in: SType,
) -> list[str]:
"""
Create a list from a symbols string or iterable of symbol strings (if not
None).
Create a list from an optional symbols string or iterable of symbol
strings. If symbols is `None`, this function returns `[ALL_SYMBOLS]`.
Parameters
----------
symbols : Iterable of str or int or Number, or str or int or Number, optional
The symbols to concatenate.
The symbols to concatenate; or `None`.
stype_in : SType
The input symbology type for the request.
Returns
-------
list[str]
Notes
-----
If None is given, [ALL_SYMBOLS] is returned.
See Also
--------
symbols_list_to_list
"""
raise TypeError(
f"`{symbols}` is not a valid type for symbol input; "
"allowed types are Iterable[str | int], str, int, and None.",
)
if symbols is None:
return [ALL_SYMBOLS]
return symbols_list_to_list(symbols, stype_in)


@optional_symbols_list_to_list.register(cls=type(None))
def _(_: None, __: SType) -> list[str]:
@singledispatch
def symbols_list_to_list(
symbols: Iterable[str | int | Integral] | str | int | Integral,
stype_in: SType,
) -> list[str]:
"""
Dispatch method for optional_symbols_list_to_list. Handles None which
defaults to [ALL_SYMBOLS].
Create a list from a symbols string or iterable of symbol strings.
See Also
--------
optional_symbols_list_to_list
Parameters
----------
symbols : Iterable of str or int or Number, or str or int or Number
The symbols to concatenate.
stype_in : SType
The input symbology type for the request.
Returns
-------
list[str]
"""
return [ALL_SYMBOLS]
raise TypeError(
f"`{symbols}` is not a valid type for symbol input; "
"allowed types are Iterable[str | int], str, and int.",
)


@optional_symbols_list_to_list.register(cls=Integral)
@symbols_list_to_list.register(cls=Integral)
def _(symbols: Integral, stype_in: SType) -> list[str]:
"""
Dispatch method for optional_symbols_list_to_list. Handles integral types,
alerting when an integer is given for STypes that expect strings.
See Also
--------
optional_symbols_list_to_list
symbols_list_to_list
"""
if stype_in == SType.INSTRUMENT_ID:
Expand All @@ -122,15 +132,15 @@ def _(symbols: Integral, stype_in: SType) -> list[str]:
)


@optional_symbols_list_to_list.register(cls=str)
@symbols_list_to_list.register(cls=str)
def _(symbols: str, stype_in: SType) -> list[str]:
"""
Dispatch method for optional_symbols_list_to_list. Handles str, splitting
on commas and validating smart symbology.
See Also
--------
optional_symbols_list_to_list
symbols_list_to_list
"""
if not symbols:
Expand All @@ -147,19 +157,19 @@ def _(symbols: str, stype_in: SType) -> list[str]:
return list(map(str.upper, map(str.strip, symbol_list)))


@optional_symbols_list_to_list.register(cls=Iterable)
@symbols_list_to_list.register(cls=Iterable)
def _(symbols: Iterable[Any], stype_in: SType) -> list[str]:
"""
Dispatch method for optional_symbols_list_to_list. Handles Iterables by
dispatching the individual members.
See Also
--------
optional_symbols_list_to_list
symbols_list_to_list
"""
symbol_to_list = partial(
optional_symbols_list_to_list,
symbols_list_to_list,
stype_in=stype_in,
)
aggregated: list[str] = []
Expand Down
4 changes: 2 additions & 2 deletions databento/historical/api/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
from databento.common.http import check_http_error
from databento.common.parsing import datetime_to_string
from databento.common.parsing import optional_datetime_to_string
from databento.common.parsing import optional_symbols_list_to_list
from databento.common.parsing import optional_values_list_to_string
from databento.common.parsing import symbols_list_to_list
from databento.common.publishers import Dataset
from databento.common.types import Default
from databento.common.validation import validate_enum
Expand Down Expand Up @@ -147,7 +147,7 @@ def submit_job(
"""
stype_in_valid = validate_enum(stype_in, SType, "stype_in")
symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid)
symbols_list = symbols_list_to_list(symbols, stype_in_valid)
data: dict[str, object | None] = {
"dataset": validate_semantic_string(dataset, "dataset"),
"start": datetime_to_string(start),
Expand Down
4 changes: 2 additions & 2 deletions databento/live/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from databento.common.error import BentoError
from databento.common.iterator import chunk
from databento.common.parsing import optional_datetime_to_unix_nanoseconds
from databento.common.parsing import optional_symbols_list_to_list
from databento.common.parsing import symbols_list_to_list
from databento.common.publishers import Dataset
from databento.common.types import DBNRecord
from databento.common.validation import validate_enum
Expand Down Expand Up @@ -310,7 +310,7 @@ def subscribe(
)

stype_in_valid = validate_enum(stype_in, SType, "stype_in")
symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid)
symbols_list = symbols_list_to_list(symbols, stype_in_valid)

subscriptions: list[SubscriptionRequest] = []
for batch in chunk(symbols_list, SYMBOL_LIST_BATCH_SIZE):
Expand Down
63 changes: 52 additions & 11 deletions tests/test_common_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import numpy as np
import pandas as pd
import pytest
from databento.common.constants import ALL_SYMBOLS
from databento.common.parsing import optional_date_to_string
from databento.common.parsing import optional_datetime_to_string
from databento.common.parsing import optional_datetime_to_unix_nanoseconds
from databento.common.parsing import optional_symbols_list_to_list
from databento.common.parsing import optional_values_list_to_string
from databento.common.parsing import symbols_list_to_list
from databento_dbn import SType


Expand Down Expand Up @@ -50,7 +52,9 @@ def test_maybe_values_list_to_string_given_valid_inputs_returns_expected(
@pytest.mark.parametrize(
"stype, symbols, expected",
[
pytest.param(SType.RAW_SYMBOL, None, ["ALL_SYMBOLS"]),
pytest.param(SType.RAW_SYMBOL, None, TypeError),
pytest.param(SType.PARENT, ["ES", None], TypeError),
pytest.param(SType.PARENT, ["ES", [None]], TypeError),
pytest.param(SType.PARENT, "ES.fut", ["ES.FUT"]),
pytest.param(SType.PARENT, "ES,CL", ["ES", "CL"]),
pytest.param(SType.PARENT, "ES,CL,", ["ES", "CL"]),
Expand All @@ -67,22 +71,24 @@ def test_maybe_values_list_to_string_given_valid_inputs_returns_expected(
pytest.param(SType.PARENT, 123458, ValueError),
],
)
def test_optional_symbols_list_to_list_given_valid_inputs_returns_expected(
def test_symbols_list_to_list_given_valid_inputs_returns_expected(
stype: SType,
symbols: list[str] | None,
expected: list[object] | type[Exception],
) -> None:
# Arrange, Act, Assert
if isinstance(expected, list):
assert optional_symbols_list_to_list(symbols, stype) == expected
assert symbols_list_to_list(symbols, stype) == expected
else:
with pytest.raises(expected):
optional_symbols_list_to_list(symbols, stype)
symbols_list_to_list(symbols, stype)


@pytest.mark.parametrize(
"symbols, stype, expected",
[
pytest.param([12345, None], SType.INSTRUMENT_ID, TypeError),
pytest.param([12345, [None]], SType.INSTRUMENT_ID, TypeError),
pytest.param(12345, SType.INSTRUMENT_ID, ["12345"]),
pytest.param("67890", SType.INSTRUMENT_ID, ["67890"]),
pytest.param([12345, " 67890"], SType.INSTRUMENT_ID, ["12345", "67890"]),
Expand All @@ -104,7 +110,7 @@ def test_optional_symbols_list_to_list_given_valid_inputs_returns_expected(
pytest.param(12345, SType.CONTINUOUS, ValueError),
],
)
def test_optional_symbols_list_to_list_int(
def test_symbols_list_to_list_int(
symbols: list[int] | int | None,
stype: SType,
expected: list[object] | type[Exception],
Expand All @@ -117,15 +123,18 @@ def test_optional_symbols_list_to_list_int(
"""
# Arrange, Act, Assert
if isinstance(expected, list):
assert optional_symbols_list_to_list(symbols, stype) == expected
assert symbols_list_to_list(symbols, stype) == expected
else:
with pytest.raises(expected):
optional_symbols_list_to_list(symbols, stype)
symbols_list_to_list(symbols, stype)


@pytest.mark.parametrize(
"symbols, stype, expected",
[
pytest.param(None, SType.INSTRUMENT_ID, TypeError),
pytest.param([np.byte(120), None], SType.INSTRUMENT_ID, TypeError),
pytest.param([np.byte(120), [None]], SType.INSTRUMENT_ID, TypeError),
pytest.param(np.byte(120), SType.INSTRUMENT_ID, ["120"]),
pytest.param(np.short(32_000), SType.INSTRUMENT_ID, ["32000"]),
pytest.param(
Expand All @@ -140,7 +149,7 @@ def test_optional_symbols_list_to_list_int(
),
],
)
def test_optional_symbols_list_to_list_numpy(
def test_symbols_list_to_list_numpy(
symbols: list[int] | int | None,
stype: SType,
expected: list[object] | type[Exception],
Expand All @@ -153,15 +162,18 @@ def test_optional_symbols_list_to_list_numpy(
"""
# Arrange, Act, Assert
if isinstance(expected, list):
assert optional_symbols_list_to_list(symbols, stype) == expected
assert symbols_list_to_list(symbols, stype) == expected
else:
with pytest.raises(expected):
optional_symbols_list_to_list(symbols, stype)
symbols_list_to_list(symbols, stype)


@pytest.mark.parametrize(
"symbols, stype, expected",
[
pytest.param(None, SType.RAW_SYMBOL, TypeError),
pytest.param(["NVDA", None], SType.RAW_SYMBOL, TypeError),
pytest.param(["NVDA", [None]], SType.RAW_SYMBOL, TypeError),
pytest.param("NVDA", SType.RAW_SYMBOL, ["NVDA"]),
pytest.param(" nvda ", SType.RAW_SYMBOL, ["NVDA"]),
pytest.param("NVDA,amd", SType.RAW_SYMBOL, ["NVDA", "AMD"]),
Expand All @@ -179,7 +191,7 @@ def test_optional_symbols_list_to_list_numpy(
pytest.param(["NVDA", [""]], SType.RAW_SYMBOL, ValueError),
],
)
def test_optional_symbols_list_to_list_raw_symbol(
def test_symbols_list_to_list_raw_symbol(
symbols: list[int] | int | None,
stype: SType,
expected: list[object] | type[Exception],
Expand All @@ -188,6 +200,35 @@ def test_optional_symbols_list_to_list_raw_symbol(
Test that str are allowed for SType.RAW_SYMBOL.
"""
# Arrange, Act, Assert
if isinstance(expected, list):
assert symbols_list_to_list(symbols, stype) == expected
else:
with pytest.raises(expected):
symbols_list_to_list(symbols, stype)


@pytest.mark.parametrize(
"symbols, stype, expected",
[
pytest.param(None, SType.RAW_SYMBOL, [ALL_SYMBOLS]),
pytest.param(["NVDA", None], SType.RAW_SYMBOL, TypeError),
pytest.param([12345, None], SType.INSTRUMENT_ID, TypeError),
pytest.param("NVDA", SType.RAW_SYMBOL, ["NVDA"]),
pytest.param(["NVDA", "TSLA"], SType.RAW_SYMBOL, ["NVDA", "TSLA"]),
pytest.param(12345, SType.INSTRUMENT_ID, ["12345"]),
pytest.param([12345, "67890"], SType.INSTRUMENT_ID, ["12345", "67890"]),
],
)
def test_optional_symbols_list_to_list_raw_symbol(
symbols: list[int | str] | int | str | None,
stype: SType,
expected: list[object] | type[Exception],
) -> None:
"""
Test an optional symbols list converts a value of `None` to `[ALL_SYMBOLS]`
and handles other symbols lists.
"""
# Arrange, Act, Assert
if isinstance(expected, list):
assert optional_symbols_list_to_list(symbols, stype) == expected
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_live_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def test_live_async_iteration_after_start(
[
pytest.param("NVDA", id="str"),
pytest.param("ES,CL", id="str-list"),
pytest.param(None, id="all-symbols"),
pytest.param(ALL_SYMBOLS, id="all-symbols"),
],
)
@pytest.mark.parametrize(
Expand All @@ -459,7 +459,7 @@ async def test_live_subscribe(
mock_live_server: MockLiveServerInterface,
schema: Schema,
stype_in: SType,
symbols: str | None,
symbols: str,
start: str,
) -> None:
"""
Expand Down

0 comments on commit f4f8c3f

Please sign in to comment.