From ec2e0a937bbcb4056a9b230a311fb64f82daabc3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 16:55:20 +0000 Subject: [PATCH 01/10] wip --- narwhals/_polars/namespace.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 32e53b372..c58d1dfa0 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -14,6 +14,7 @@ from narwhals._polars.utils import extract_args_kwargs from narwhals._polars.utils import narwhals_to_native_dtype from narwhals.utils import Implementation +from narwhals.dtypes import DType if TYPE_CHECKING: from typing_extensions import Self @@ -218,11 +219,15 @@ def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: import polars as pl from narwhals._polars.expr import PolarsExpr - + native_dtypes = [] + for dtype in dtypes: + native_dtype_instantiated = narwhals_to_native_dtype(dtype, self._version) + if issubclass(dtype, DType): + native_dtypes.append(native_dtype_instantiated.__class__) + else: + native_dtypes.append(native_dtype_instantiated) return PolarsExpr( - pl.selectors.by_dtype( - [narwhals_to_native_dtype(dtype, self._version) for dtype in dtypes] - ), + pl.selectors.by_dtype(native_dtypes), version=self._version, backend_version=self._backend_version, ) From 61685a6a1701bb510468c3959dccb5273a400c62 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:01:06 +0000 Subject: [PATCH 02/10] fixup selectors for polars --- narwhals/_polars/namespace.py | 8 ++++---- tests/selectors_test.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index c58d1dfa0..205e61fd7 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -13,8 +13,8 @@ from narwhals._expression_parsing import parse_into_exprs from narwhals._polars.utils import extract_args_kwargs from narwhals._polars.utils import narwhals_to_native_dtype -from narwhals.utils import Implementation from narwhals.dtypes import DType +from narwhals.utils import Implementation if TYPE_CHECKING: from typing_extensions import Self @@ -23,7 +23,6 @@ from narwhals._polars.dataframe import PolarsLazyFrame from narwhals._polars.expr import PolarsExpr from narwhals._polars.typing import IntoPolarsExpr - from narwhals.dtypes import DType from narwhals.utils import Version @@ -219,10 +218,11 @@ def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: import polars as pl from narwhals._polars.expr import PolarsExpr - native_dtypes = [] + + native_dtypes: list[pl.DataType | type[pl.DataType]] = [] for dtype in dtypes: native_dtype_instantiated = narwhals_to_native_dtype(dtype, self._version) - if issubclass(dtype, DType): + if isinstance(dtype, type) and issubclass(dtype, DType): native_dtypes.append(native_dtype_instantiated.__class__) else: native_dtypes.append(native_dtype_instantiated) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 36ea15a4f..628ceaa19 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from datetime import datetime import pytest @@ -117,3 +118,13 @@ def test_set_ops_invalid(constructor: Constructor) -> None: match=re.escape("unsupported operand type(s) for op: ('Selector' + 'Selector')"), ): df.select(boolean() + numeric()) + + +def test_tz_aware(constructor: Constructor) -> None: + data = {"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)], "c": [4, 5]} + df = nw.from_native(constructor(data)).with_columns( + b=nw.col("a").dt.replace_time_zone("Asia/Katmandu") + ) + result = df.select(nw.selectors.by_dtype(nw.Datetime)).collect_schema().names() + expected = ["a", "b"] + assert result == expected From eee3bdd472b826604cb3670ca64b05dbecdf69cf Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:04:03 +0000 Subject: [PATCH 03/10] fixup --- tests/selectors_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index f93e2067b..94c6d8c23 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -128,7 +128,11 @@ def test_set_ops_invalid(constructor: Constructor) -> None: df.select(boolean() + numeric()) -def test_tz_aware(constructor: Constructor) -> None: +def test_tz_aware(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "polars" in str(constructor) and POLARS_VERSION < (1, 19): + # bug in old polars + request.applymarker(pytest.mark.xfail) + data = {"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)], "c": [4, 5]} df = nw.from_native(constructor(data)).with_columns( b=nw.col("a").dt.replace_time_zone("Asia/Katmandu") From eb66047d7c95a7e41d953001a591e2ccea00ba14 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:10:01 +0000 Subject: [PATCH 04/10] ci --- tests/selectors_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 94c6d8c23..15b5e63b1 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -16,6 +16,7 @@ from tests.utils import PYARROW_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data +from tests.utils import is_windows data = { "a": [1, 1, 2], @@ -128,10 +129,14 @@ def test_set_ops_invalid(constructor: Constructor) -> None: df.select(boolean() + numeric()) +@pytest.mark.skipif(is_windows()) def test_tz_aware(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "polars" in str(constructor) and POLARS_VERSION < (1, 19): # bug in old polars request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor) or "pyspark" in str(constructor): + # replace_time_zone not implemented + request.applymarker(pytest.mark.xfail) data = {"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)], "c": [4, 5]} df = nw.from_native(constructor(data)).with_columns( From 0d3f4fd74242dca79334745b4508602d64e8e8ed Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:11:06 +0000 Subject: [PATCH 05/10] win --- tests/selectors_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 15b5e63b1..fce236f38 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -129,7 +129,7 @@ def test_set_ops_invalid(constructor: Constructor) -> None: df.select(boolean() + numeric()) -@pytest.mark.skipif(is_windows()) +@pytest.mark.skipif(is_windows(), reason="windows is what it is") def test_tz_aware(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "polars" in str(constructor) and POLARS_VERSION < (1, 19): # bug in old polars From 7247bb3ebfae33a52407d8cf754f8c08153aa5b2 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:13:58 +0000 Subject: [PATCH 06/10] win --- tests/selectors_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index fce236f38..e337d234e 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -134,6 +134,9 @@ def test_tz_aware(constructor: Constructor, request: pytest.FixtureRequest) -> N if "polars" in str(constructor) and POLARS_VERSION < (1, 19): # bug in old polars request.applymarker(pytest.mark.xfail) + if "pyarrow_table" in str(constructor) and POLARS_VERSION < (1, 12): + # bug in old pyarrow + request.applymarker(pytest.mark.xfail) if "duckdb" in str(constructor) or "pyspark" in str(constructor): # replace_time_zone not implemented request.applymarker(pytest.mark.xfail) From 55edfe7b82ee19fc80893c621fb9401f35781230 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:14:09 +0000 Subject: [PATCH 07/10] win --- tests/selectors_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index e337d234e..3f2c3a1a3 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -134,7 +134,7 @@ def test_tz_aware(constructor: Constructor, request: pytest.FixtureRequest) -> N if "polars" in str(constructor) and POLARS_VERSION < (1, 19): # bug in old polars request.applymarker(pytest.mark.xfail) - if "pyarrow_table" in str(constructor) and POLARS_VERSION < (1, 12): + if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (1, 12): # bug in old pyarrow request.applymarker(pytest.mark.xfail) if "duckdb" in str(constructor) or "pyspark" in str(constructor): From ba2dd7f0387df8334305a0a781a30730030ead3b Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:24:18 +0000 Subject: [PATCH 08/10] fix --- tests/selectors_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 3f2c3a1a3..4f39b4c39 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -134,7 +134,7 @@ def test_tz_aware(constructor: Constructor, request: pytest.FixtureRequest) -> N if "polars" in str(constructor) and POLARS_VERSION < (1, 19): # bug in old polars request.applymarker(pytest.mark.xfail) - if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (1, 12): + if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,): # bug in old pyarrow request.applymarker(pytest.mark.xfail) if "duckdb" in str(constructor) or "pyspark" in str(constructor): From b780de509d9f5bc18bcbde2ebe358a624dc12142 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:34:50 +0000 Subject: [PATCH 09/10] fix --- tests/selectors_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 4f39b4c39..9e16da5a1 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -148,3 +148,6 @@ def test_tz_aware(constructor: Constructor, request: pytest.FixtureRequest) -> N result = df.select(nw.selectors.by_dtype(nw.Datetime)).collect_schema().names() expected = ["a", "b"] assert result == expected + result = df.select(nw.selectors.by_dtype(nw.Int64())).collect_schema().names() + expected = ["c"] + assert result == expected From 5ee60090adca0ea7177c0ccadadcce47248c1421 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Jan 2025 19:16:01 +0000 Subject: [PATCH 10/10] list-comprehensionify --- narwhals/_polars/namespace.py | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 205e61fd7..005af3604 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -35,8 +35,6 @@ def __init__( self._version = version def __getattr__(self: Self, attr: str) -> Any: - import polars as pl - from narwhals._polars.expr import PolarsExpr def func(*args: Any, **kwargs: Any) -> Any: @@ -50,8 +48,6 @@ def func(*args: Any, **kwargs: Any) -> Any: return func def nth(self: Self, *indices: int) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr if self._backend_version < (1, 0, 0): @@ -62,8 +58,6 @@ def nth(self: Self, *indices: int) -> PolarsExpr: ) def len(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr if self._backend_version < (0, 20, 5): @@ -114,8 +108,6 @@ def concat( ) def lit(self: Self, value: Any, dtype: DType | None) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr if dtype is not None: @@ -129,8 +121,6 @@ def lit(self: Self, value: Any, dtype: DType | None) -> PolarsExpr: ) def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr polars_exprs = cast("list[PolarsExpr]", parse_into_exprs(*exprs, namespace=self)) @@ -155,8 +145,6 @@ def concat_str( separator: str, ignore_nulls: bool, ) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr pl_exprs: list[pl.Expr] = [ @@ -215,17 +203,14 @@ def __init__(self: Self, version: Version, backend_version: tuple[int, ...]) -> self._backend_version = backend_version def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr - native_dtypes: list[pl.DataType | type[pl.DataType]] = [] - for dtype in dtypes: - native_dtype_instantiated = narwhals_to_native_dtype(dtype, self._version) - if isinstance(dtype, type) and issubclass(dtype, DType): - native_dtypes.append(native_dtype_instantiated.__class__) - else: - native_dtypes.append(native_dtype_instantiated) + native_dtypes = [ + narwhals_to_native_dtype(dtype, self._version).__class__ + if isinstance(dtype, type) and issubclass(dtype, DType) + else narwhals_to_native_dtype(dtype, self._version) + for dtype in dtypes + ] return PolarsExpr( pl.selectors.by_dtype(native_dtypes), version=self._version,