feat(python): support DataFrame init from pydantic model data (#8178)
alexander-beedie authored Apr 12, 2023
1 parent a6a2149 commit 4110656
Showing 6 changed files with 105 additions and 34 deletions.
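For orientation before the file-by-file diff, here is a minimal sketch of the behaviour this commit enables; the Trade model and the sample rows are illustrative, not taken from the commit itself:

import pydantic

import polars as pl


class Trade(pydantic.BaseModel):
    ticker: str
    price: float
    size: int


trades = [
    Trade(ticker="AAPL", price=189.5, size=100),
    Trade(ticker="MSFT", price=310.2, size=250),
]

# Pydantic model instances are now accepted anywhere dataclasses and named
# tuples already were: column names come from the model fields, and dtypes
# are inferred from the field annotations.
df = pl.DataFrame(trades)
assert df.columns == ["ticker", "price", "size"]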
15 changes: 12 additions & 3 deletions py-polars/polars/dependencies.py
@@ -8,13 +8,14 @@
 from types import ModuleType
 from typing import TYPE_CHECKING, Any, Hashable, cast
 
+_DELTALAKE_AVAILABLE = True
 _FSSPEC_AVAILABLE = True
+_HYPOTHESIS_AVAILABLE = True
 _NUMPY_AVAILABLE = True
 _PANDAS_AVAILABLE = True
 _PYARROW_AVAILABLE = True
+_PYDANTIC_AVAILABLE = True
 _ZONEINFO_AVAILABLE = True
-_HYPOTHESIS_AVAILABLE = True
-_DELTALAKE_AVAILABLE = True
 
 
 class _LazyModule(ModuleType):
Expand Down Expand Up @@ -154,6 +155,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
import numpy
import pandas
import pyarrow
import pydantic

if sys.version_info >= (3, 9):
import zoneinfo
@@ -167,13 +169,14 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
 pickle, _ = _lazy_import("pickle")
 subprocess, _ = _lazy_import("subprocess")
 
-# heavy third party libs
+# heavy/optional third party libs
 deltalake, _DELTALAKE_AVAILABLE = _lazy_import("deltalake")
 fsspec, _FSSPEC_AVAILABLE = _lazy_import("fsspec")
 hypothesis, _HYPOTHESIS_AVAILABLE = _lazy_import("hypothesis")
 numpy, _NUMPY_AVAILABLE = _lazy_import("numpy")
 pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
 pyarrow, _PYARROW_AVAILABLE = _lazy_import("pyarrow")
+pydantic, _PYDANTIC_AVAILABLE = _lazy_import("pydantic")
 zoneinfo, _ZONEINFO_AVAILABLE = (
     _lazy_import("zoneinfo")
     if sys.version_info >= (3, 9)
@@ -203,6 +206,10 @@ def _check_for_pyarrow(obj: Any) -> bool:
     return _PYARROW_AVAILABLE and _might_be(cast(Hashable, type(obj)), "pyarrow")
 
 
+def _check_for_pydantic(obj: Any) -> bool:
+    return _PYDANTIC_AVAILABLE and _might_be(cast(Hashable, type(obj)), "pydantic")
+
+
 __all__ = [
     # lazy-load rarely-used/heavy builtins (for fast startup)
     "dataclasses",
@@ -215,12 +222,14 @@ def _check_for_pyarrow(obj: Any) -> bool:
     "fsspec",
     "numpy",
     "pandas",
+    "pydantic",
     "pyarrow",
     "zoneinfo",
     # lazy utilities
     "_check_for_numpy",
     "_check_for_pandas",
     "_check_for_pyarrow",
+    "_check_for_pydantic",
     "_LazyModule",
     # exported flags/guards
     "_DELTALAKE_AVAILABLE",
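The new _check_for_pydantic guard mirrors the existing numpy/pandas/pyarrow checks: pydantic stays a lazy _LazyModule proxy, and the real isinstance(value, pydantic.BaseModel) check in _construction.py (below) only fires when the object's class plausibly comes from pydantic. _might_be itself is not part of this diff; assuming it scans the type's MRO for the module name, the pattern is roughly:

# Sketch only: the real _might_be lives elsewhere in dependencies.py; this
# assumes it looks for the module name in the class's MRO.
def _might_be(cls: type, type_: str) -> bool:
    try:
        return any(f"{type_}." in str(o) for o in cls.mro())
    except TypeError:  # e.g. cls is not a proper class
        return False

This keeps constructors cheap for plain Python objects: the string scan fails fast, so pydantic is never actually imported unless a model-like object shows up.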
66 changes: 53 additions & 13 deletions py-polars/polars/utils/_construction.py
@@ -48,7 +48,9 @@
     _NUMPY_AVAILABLE,
     _check_for_numpy,
     _check_for_pandas,
+    _check_for_pydantic,
     dataclasses,
+    pydantic,
 )
 from polars.dependencies import numpy as np
 from polars.dependencies import pandas as pd
@@ -74,22 +76,27 @@
 
 if version_info >= (3, 10):
 
-    def dataclass_type_hints(obj: type) -> dict[str, Any]:
+    def type_hints(obj: type) -> dict[str, Any]:
         return get_type_hints(obj)
 
 else:
 
-    def dataclass_type_hints(obj: type) -> dict[str, Any]:
+    def type_hints(obj: type) -> dict[str, Any]:
         return getattr(obj, "__annotations__", {})
 
 
 def is_namedtuple(value: Any, annotated: bool = False) -> bool:
-    """Infer whether value is a NamedTuple."""
+    """Check whether value is a NamedTuple."""
     if all(hasattr(value, attr) for attr in ("_fields", "_field_defaults", "_replace")):
         return len(value.__annotations__) == len(value._fields) if annotated else True
     return False
 
 
+def is_pydantic_model(value: Any) -> bool:
+    """Check whether value is a pydantic.BaseModel."""
+    return _check_for_pydantic(value) and isinstance(value, pydantic.BaseModel)
+
+
 def include_unknowns(
     schema: SchemaDict, cols: Sequence[str]
 ) -> MutableMapping[str, PolarsDataType]:
@@ -316,7 +323,11 @@ def sequence_to_pyseries(
 
     value = _get_first_non_none(values)
     if value is not None:
-        if dataclasses.is_dataclass(value) or is_namedtuple(value, annotated=True):
+        if (
+            dataclasses.is_dataclass(value)
+            or is_namedtuple(value, annotated=True)
+            or is_pydantic_model(value)
+        ):
             return pli.DataFrame(values).to_struct(name)._s
         elif isinstance(value, range):
             values = [range_to_series("", v) for v in values]
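An aside on the branch above: a sequence of pydantic models now takes the same route as dataclasses and annotated named tuples, materialising as a single struct-typed Series. A hedged sketch, reusing the illustrative Trade model from the top of this page:

# trades is the list of Trade models from the earlier sketch
s = pl.Series("trades", trades)
# Internally this goes through pli.DataFrame(values).to_struct("trades"),
# so the result is one struct-typed Series with a field per model attribute.
print(s.dtype)  # a Struct dtype with fields 'ticker', 'price', 'size'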
@@ -809,6 +820,9 @@ def _sequence_to_pydf_dispatcher(
 
     elif dataclasses.is_dataclass(first_element):
         to_pydf = _sequence_of_dataclasses_to_pydf
+
+    elif is_pydantic_model(first_element):
+        to_pydf = _sequence_of_models_to_pydf
     else:
         to_pydf = _sequence_of_elements_to_pydf
 
@@ -893,10 +907,11 @@ def _sequence_of_tuple_to_pydf(
     orient: Orientation | None,
     infer_schema_length: int | None,
 ) -> PyDataFrame:
-    # infer additional meta information if NAMED tuple...
-    if is_namedtuple(first_element):
+    # infer additional meta information if named tuple or pydantic model...
+    named_tuple = is_namedtuple(first_element)
+    if named_tuple or is_pydantic_model(first_element):
         if schema is None:
-            schema = first_element._fields  # type: ignore[attr-defined]
+            schema = first_element._fields if named_tuple else list(data.__fields__)  # type: ignore[attr-defined]
         if len(first_element.__annotations__) == len(schema):
             schema = [
                 (name, py_type_to_dtype(tp, raise_unmatched=False))
@@ -1003,6 +1018,25 @@ def _sequence_of_pandas_to_pydf(
     return PyDataFrame(data_series)
 
 
+def _sequence_of_models_to_pydf(
+    first_element: Any,
+    data: Sequence[Any],
+    schema: SchemaDefinition | None,
+    schema_overrides: SchemaDict | None,
+    infer_schema_length: int | None,
+    **kwargs: Any,
+) -> PyDataFrame:
+    kwargs["pydantic_model"] = True
+    return _sequence_of_dataclasses_to_pydf(
+        first_element=first_element,
+        data=data,
+        schema=schema,
+        schema_overrides=schema_overrides,
+        infer_schema_length=infer_schema_length,
+        **kwargs,
+    )
+
+
 def _sequence_of_dataclasses_to_pydf(
     first_element: Any,
     data: Sequence[Any],
@@ -1013,6 +1047,7 @@ def _sequence_of_dataclasses_to_pydf(
 ) -> PyDataFrame:
     from dataclasses import astuple
 
+    from_model = kwargs.get("pydantic_model")
     if schema:
         column_names, schema_overrides = _unpack_schema(schema, schema_overrides)
         schema_override = {
@@ -1022,19 +1057,24 @@
         column_names = []
         schema_override = {
             col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown)
-            for col, tp in dataclass_type_hints(first_element.__class__).items()
+            for col, tp in type_hints(first_element.__class__).items()
         }
+        if from_model:
+            schema_override.pop("__slots__", None)
     schema_override.update(schema_overrides or {})
 
     for col, tp in schema_override.items():
         if tp == Categorical:
             schema_override[col] = Utf8
 
-    pydf = PyDataFrame.read_rows(
-        [astuple(dc) for dc in data],
-        infer_schema_length,
-        schema_override or None,
-    )
+    if from_model:
+        pydf = PyDataFrame.read_dicts([md.dict() for md in data], infer_schema_length)
+    else:
+        pydf = PyDataFrame.read_rows(
+            [astuple(dc) for dc in data],
+            infer_schema_length,
+            schema_override or None,
+        )
     if schema_override:
         structs = {c: tp for c, tp in schema_override.items() if isinstance(tp, Struct)}
         pydf = _post_apply_columns(pydf, column_names, structs, schema_overrides)
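The read_dicts/read_rows split at the end is the crux of the reuse: dataclass rows are positional, model rows are named. A small illustration with hypothetical Point classes (.dict() is the pydantic v1 export method the commit relies on):

from dataclasses import astuple, dataclass

import pydantic


@dataclass
class PointDC:
    x: int
    y: int


class PointPD(pydantic.BaseModel):
    x: int
    y: int


astuple(PointDC(x=1, y=2))  # (1, 2) -> positional rows for PyDataFrame.read_rows
PointPD(x=1, y=2).dict()    # {'x': 1, 'y': 2} -> named rows for PyDataFrame.read_dicts

The accompanying schema_override.pop("__slots__", None) drops a "__slots__" entry that model class annotations can otherwise leak into the inferred schema.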
18 changes: 8 additions & 10 deletions py-polars/pyproject.toml
@@ -70,21 +70,19 @@ disable_error_code = [
 
 [[tool.mypy.overrides]]
 module = [
-    "IPython.*",
     "backports",
-    "pyarrow.*",
-    "polars.polars",
-    "matplotlib.*",
-    "fsspec.*",
     "connectorx",
     "deltalake",
+    "IPython.*",
+    "fsspec.*",
+    "matplotlib.*",
+    "polars.polars",
+    "pyarrow.*",
+    "pydantic",
+    "sqlalchemy",
     "xlsx2csv",
-    "xlsxwriter",
-    "xlsxwriter.format",
-    "xlsxwriter.utility",
-    "xlsxwriter.worksheet",
+    "xlsxwriter.*",
     "zoneinfo",
-    "sqlalchemy",
 ]
 ignore_missing_imports = true
 
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
@@ -9,6 +9,7 @@ deltalake >= 0.8.0
 numpy
 pandas
 pyarrow
+pydantic
 backports.zoneinfo; python_version < '3.9'
 tzdata; platform_system == 'Windows'
 xlsx2csv
18 changes: 13 additions & 5 deletions py-polars/tests/unit/test_constructors.py
@@ -117,9 +117,10 @@ def test_init_dataclasses_and_namedtuple(monkeypatch: Any) -> None:
     from dataclasses import dataclass
     from typing import NamedTuple
 
-    monkeypatch.setenv("POLARS_ACTIVATE_DECIMAL", "1")
+    from polars.dependencies import pydantic
+    from polars.utils._construction import type_hints
 
-    from polars.utils._construction import dataclass_type_hints
+    monkeypatch.setenv("POLARS_ACTIVATE_DECIMAL", "1")
 
     @dataclass
     class TradeDC:
@@ -128,6 +129,12 @@ class TradeDC:
         price: Decimal
         size: int | None = None
 
+    class TradePD(pydantic.BaseModel):
+        timestamp: datetime
+        ticker: str
+        price: Decimal
+        size: int
+
     class TradeNT(NamedTuple):
         timestamp: datetime
         ticker: str
@@ -139,9 +146,10 @@ class TradeNT(NamedTuple):
         (datetime(2022, 9, 9, 10, 15, 12), "FLSY", Decimal("10.0"), 1500),
         (datetime(2022, 9, 7, 15, 30), "MU", Decimal("55.5"), 400),
     ]
+    columns = ["timestamp", "ticker", "price", "size"]
 
-    for TradeClass in (TradeDC, TradeNT):
-        trades = [TradeClass(*values) for values in raw_data]
+    for TradeClass in (TradeDC, TradeNT, TradePD):
+        trades = [TradeClass(**dict(zip(columns, values))) for values in raw_data]
 
         for DF in (pl.DataFrame, pl.from_records):
             df = DF(data=trades)  # type: ignore[operator]
@@ -184,7 +192,7 @@ class TradeNT(NamedTuple):
         assert df.rows() == raw_data
 
     # cover a miscellaneous edge-case when detecting the annotations
-    assert dataclass_type_hints(obj=type(None)) == {}
+    assert type_hints(obj=type(None)) == {}
 
 
 def test_init_ndarray(monkeypatch: Any) -> None:
21 changes: 18 additions & 3 deletions py-polars/tests/unit/test_series.py
@@ -150,6 +150,8 @@ def test_init_dataclass_namedtuple() -> None:
     from dataclasses import dataclass
     from typing import NamedTuple
 
+    from polars.dependencies import pydantic
+
     @dataclass
     class TeaShipmentDC:
         exporter: str
@@ -163,11 +165,18 @@ class TeaShipmentNT(NamedTuple):
         product: str
         tonnes: None | int
 
-    for Tea in (TeaShipmentDC, TeaShipmentNT):
+    class TeaShipmentPD(pydantic.BaseModel):
+        exporter: str
+        importer: str
+        product: str
+        tonnes: int
+
+    for Tea in (TeaShipmentDC, TeaShipmentNT, TeaShipmentPD):
         t0 = Tea(exporter="Sri Lanka", importer="USA", product="Ceylon", tonnes=10)
         t1 = Tea(exporter="India", importer="UK", product="Darjeeling", tonnes=25)
+        t2 = Tea(exporter="China", importer="UK", product="Keemum", tonnes=40)
 
-        s = pl.Series("t", [t0, t1])
+        s = pl.Series("t", [t0, t1, t2])
 
         assert isinstance(s, pl.Series)
         assert s.dtype.fields == [  # type: ignore[union-attr]
@@ -189,8 +198,14 @@ class TeaShipmentNT(NamedTuple):
                 "product": "Darjeeling",
                 "tonnes": 25,
             },
+            {
+                "exporter": "China",
+                "importer": "UK",
+                "product": "Keemum",
+                "tonnes": 40,
+            },
         ]
-        assert_frame_equal(s.to_frame(), pl.DataFrame({"t": [t0, t1]}))
+        assert_frame_equal(s.to_frame(), pl.DataFrame({"t": [t0, t1, t2]}))
 
 
 def test_concat() -> None:
