Skip to content

Commit

Permalink
feat(typing): utilize nw.LazyFrame working TypeVar
Browse files Browse the repository at this point in the history
Possible since narwhals-dev/narwhals#1930

@MarcoGorelli if you're interested what that PR did (besides fix warnings 😉)
  • Loading branch information
dangotbanned committed Feb 5, 2025
1 parent 790ff10 commit 8e53848
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 18 deletions.
2 changes: 1 addition & 1 deletion altair/datasets/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ class _SupportsScanMetadata(Protocol):

def _scan_metadata(
self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata]
) -> nw.LazyFrame: ...
) -> nw.LazyFrame[Any]: ...


class DatasetCache:
Expand Down
5 changes: 2 additions & 3 deletions altair/datasets/_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import pandas as pd
import polars as pl
import pyarrow as pa
from narwhals.stable import v1 as nw

from altair.datasets._cache import DatasetCache
from altair.datasets._reader import Reader
Expand Down Expand Up @@ -58,13 +57,13 @@ def from_backend(
@classmethod
def from_backend(
cls, backend_name: Literal["pandas", "pandas[pyarrow]"], /
) -> Loader[pd.DataFrame, nw.LazyFrame]: ...
) -> Loader[pd.DataFrame, pd.DataFrame]: ...

@overload
@classmethod
def from_backend(
cls, backend_name: Literal["pyarrow"], /
) -> Loader[pa.Table, nw.LazyFrame]: ...
) -> Loader[pa.Table, pa.Table]: ...

@classmethod
def from_backend(
Expand Down
19 changes: 11 additions & 8 deletions altair/datasets/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,13 @@ def _merge_kwds(self, meta: Metadata, kwds: dict[str, Any], /) -> Mapping[str, A
return kwds

@property
def _metadata_frame(self) -> nw.LazyFrame:
def _metadata_frame(self) -> nw.LazyFrame[IntoFrameT]:
fp = self._metadata_path
return nw.from_native(self.scan_fn(fp)(fp)).lazy()

def _scan_metadata(
self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata]
) -> nw.LazyFrame:
) -> nw.LazyFrame[IntoFrameT]:
if predicates or constraints:
return self._metadata_frame.filter(*predicates, **constraints)
return self._metadata_frame
Expand Down Expand Up @@ -360,7 +360,7 @@ def csv_cache(self) -> CsvCache:
return self._csv_cache

@property
def _metadata_frame(self) -> nw.LazyFrame:
def _metadata_frame(self) -> nw.LazyFrame[IntoFrameT]:
data = cast("dict[str, Any]", self.csv_cache.rotated)
impl = self._implementation
return nw.maybe_convert_dtypes(nw.from_dict(data, backend=impl)).lazy()
Expand All @@ -373,7 +373,7 @@ def reader(
*,
name: str | None = ...,
implementation: nw.Implementation = ...,
) -> Reader[IntoDataFrameT, nw.LazyFrame]: ...
) -> Reader[IntoDataFrameT, nw.LazyFrame[IntoDataFrameT]]: ...


@overload
Expand All @@ -392,7 +392,10 @@ def reader(
*,
name: str | None = None,
implementation: nw.Implementation = nw.Implementation.UNKNOWN,
) -> Reader[IntoDataFrameT, IntoFrameT] | Reader[IntoDataFrameT, nw.LazyFrame]:
) -> (
Reader[IntoDataFrameT, IntoFrameT]
| Reader[IntoDataFrameT, nw.LazyFrame[IntoDataFrameT]]
):
name = name or Counter(el._inferred_package for el in read_fns).most_common(1)[0][0]
if implementation is nw.Implementation.UNKNOWN:
implementation = _into_implementation(Requirement(name))
Expand Down Expand Up @@ -429,9 +432,9 @@ def infer_backend(
@overload
def _from_backend(name: _Polars, /) -> Reader[pl.DataFrame, pl.LazyFrame]: ...
@overload
def _from_backend(name: _PandasAny, /) -> Reader[pd.DataFrame, nw.LazyFrame]: ...
def _from_backend(name: _PandasAny, /) -> Reader[pd.DataFrame, pd.DataFrame]: ...
@overload
def _from_backend(name: _PyArrow, /) -> Reader[pa.Table, nw.LazyFrame]: ...
def _from_backend(name: _PyArrow, /) -> Reader[pa.Table, pa.Table]: ...


# FIXME: The order this is defined in makes splitting the module complicated
Expand Down Expand Up @@ -512,7 +515,7 @@ def _into_suffix(obj: Path | str, /) -> Any:

def _steal_eager_parquet(
read_fns: Sequence[Read[IntoDataFrameT]], /
) -> Sequence[Scan[nw.LazyFrame]] | None:
) -> Sequence[Scan[nw.LazyFrame[IntoDataFrameT]]] | None:
if convertable := next((rd for rd in read_fns if rd.include <= is_parquet), None):
return (_readimpl.into_scan(convertable),)
return None
Expand Down
14 changes: 8 additions & 6 deletions altair/datasets/_readimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@
R = TypeVar("R", bound="nwt.IntoFrame")
IntoFrameT = TypeVar(
"IntoFrameT",
bound="nwt.NativeFrame | nw.DataFrame[Any] | nw.LazyFrame | nwt.DataFrameLike",
default=nw.LazyFrame,
bound="nwt.NativeFrame | nw.DataFrame[Any] | nw.LazyFrame[Any] | nwt.DataFrameLike",
default=nw.LazyFrame[Any],
)
Read = TypeAliasType("Read", "BaseImpl[IntoDataFrameT]", type_params=(IntoDataFrameT,))
"""An *eager* file read function."""
Expand Down Expand Up @@ -214,15 +214,17 @@ def scan(
return BaseImpl(fn, include, exclude, kwds)


def into_scan(impl: Read[IntoDataFrameT], /) -> Scan[nw.LazyFrame]:
def scan_fn(fn: Callable[..., IntoDataFrameT], /) -> Callable[..., nw.LazyFrame]:
def into_scan(impl: Read[IntoDataFrameT], /) -> Scan[nw.LazyFrame[IntoDataFrameT]]:
def scan_fn(
fn: Callable[..., IntoDataFrameT], /
) -> Callable[..., nw.LazyFrame[IntoDataFrameT]]:
@wraps(_unwrap_partial(fn))
def wrapper(*args: Any, **kwds: Any) -> nw.LazyFrame:
def wrapper(*args: Any, **kwds: Any) -> nw.LazyFrame[IntoDataFrameT]:
return nw.from_native(fn(*args, **kwds)).lazy()

return wrapper

return BaseImpl(scan_fn(impl.fn), impl.include, impl.exclude, {})
return scan(scan_fn(impl.fn), impl.include, impl.exclude)


def is_available(
Expand Down

0 comments on commit 8e53848

Please sign in to comment.