From 3061fe926d7e7cfe79c2ff7218a9365f75095276 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 8 Dec 2024 22:46:13 +0100 Subject: [PATCH 01/10] feat: DataFrame and LazyFrame explode --- docs/api-reference/dataframe.md | 1 + docs/api-reference/lazyframe.md | 1 + narwhals/_arrow/dataframe.py | 69 ++++++++++++++++++++++ narwhals/_pandas_like/dataframe.py | 38 ++++++++++++ narwhals/dataframe.py | 14 +++++ narwhals/exceptions.py | 4 ++ tests/frame/explode_test.py | 95 ++++++++++++++++++++++++++++++ 7 files changed, 222 insertions(+) create mode 100644 tests/frame/explode_test.py diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md index 883fb7897..3b70b0bd8 100644 --- a/docs/api-reference/dataframe.md +++ b/docs/api-reference/dataframe.md @@ -11,6 +11,7 @@ - columns - drop - drop_nulls + - explode - filter - gather_every - get_column diff --git a/docs/api-reference/lazyframe.md b/docs/api-reference/lazyframe.md index 515069d1c..07667ab04 100644 --- a/docs/api-reference/lazyframe.md +++ b/docs/api-reference/lazyframe.md @@ -10,6 +10,7 @@ - columns - drop - drop_nulls + - explode - filter - gather_every - group_by diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 68f7ab534..8daa96d68 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -743,3 +743,72 @@ def unpivot( ) # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not # upcast numeric to non-numeric (e.g. string) datatypes + + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + import pyarrow as pa + import pyarrow.compute as pc + + to_explode = ( + [columns, *more_columns] + if isinstance(columns, str) + else [*columns, *more_columns] + ) + native_frame = self._native_frame + counts = pc.list_value_length(native_frame[to_explode[0]]) + + if not all( + pc.all(pc.equal(pc.list_value_length(native_frame[col_name]), counts)).as_py() + for col_name in to_explode[1:] + ): + from narwhals.exceptions import ShapeError + + msg = "exploded columns must have matching element counts" + raise ShapeError(msg) + + original_columns = self.columns + other_columns = [c for c in original_columns if c not in to_explode] + fast_path = pc.all(pc.greater_equal(counts, 1)).as_py() + + if fast_path: + indices = pc.list_parent_indices(native_frame[to_explode[0]]) + + exploded_frame = native_frame.take(indices=indices) + exploded_series = [ + pc.list_flatten(native_frame[col_name]) for col_name in to_explode + ] + return self._from_native_frame( + pa.Table.from_arrays( + [*[exploded_frame[c] for c in other_columns], *exploded_series], + names=[*other_columns, *to_explode], + ) + ).select(*original_columns) + + else: + + def explode_null_array(array: pa.ChunkedArray) -> pa.ChunkedArray: + exploded_values = [] # type: ignore[var-annotated] + for lst_element in array.to_pylist(): + if lst_element is None or len(lst_element) == 0: + exploded_values.append(None) + else: # Non-empty list) + exploded_values.extend(lst_element) + return pa.chunked_array([exploded_values]) + + indices = pa.array( + [ + i + for i, count in enumerate(counts.to_pylist()) + for _ in range(max(count or 1, 1)) + ] + ) + exploded_frame = native_frame.take(indices=indices) + exploded_series = [ + explode_null_array(native_frame[col_name]) for col_name in to_explode + ] + + return self._from_native_frame( + pa.Table.from_arrays( + [*[exploded_frame[c] for c in other_columns], *exploded_series], + names=[*other_columns, *to_explode], + ) + ).select(*original_columns) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index a897548bf..8428916ef 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -937,3 +937,41 @@ def unpivot( value_name=value_name if value_name is not None else "value", ) ) + + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + to_explode = ( + [columns, *more_columns] + if isinstance(columns, str) + else [*columns, *more_columns] + ) + + if len(to_explode) == 1: + return self._from_native_frame(self._native_frame.explode(to_explode[0])) + else: + native_frame = self._native_frame + anchor_series = native_frame[to_explode[0]].list.len() + + if not all( + (native_frame[col_name].list.len() == anchor_series).all() + for col_name in to_explode[1:] + ): + from narwhals.exceptions import ShapeError + + msg = "exploded columns must have matching element counts" + raise ShapeError(msg) + + original_columns = self.columns + other_columns = [c for c in original_columns if c not in to_explode] + + exploded_frame = native_frame[[*other_columns, to_explode[0]]].explode( + to_explode[0] + ) + exploded_series = [ + native_frame[col_name].explode().to_frame() for col_name in to_explode[1:] + ] + + plx = self.__native_namespace__() + + return self._from_native_frame( + plx.concat([exploded_frame, *exploded_series], axis=1)[original_columns] + ) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index c057b7227..87c5ec34b 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -334,6 +334,14 @@ def __eq__(self, other: object) -> NoReturn: ) raise NotImplementedError(msg) + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + return self._from_compliant_dataframe( + self._compliant_frame.explode( + columns, + *more_columns, + ) + ) + class DataFrame(BaseFrame[DataFrameT]): """Narwhals DataFrame, backed by a native dataframe. @@ -2925,6 +2933,9 @@ def unpivot( on=on, index=index, variable_name=variable_name, value_name=value_name ) + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + return super().explode(columns, *more_columns) + class LazyFrame(BaseFrame[FrameT]): """Narwhals DataFrame, backed by a native dataframe. @@ -4643,3 +4654,6 @@ def unpivot( return super().unpivot( on=on, index=index, variable_name=variable_name, value_name=value_name ) + + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + return super().explode(columns, *more_columns) diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index 12f85d1ad..ee4b79b6a 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -35,6 +35,10 @@ def from_missing_and_available_column_names( return ColumnNotFoundError(message) +class ShapeError(Exception): + """Exception raised when trying to perform operations on data structures with incompatible shapes.""" + + class InvalidOperationError(Exception): """Exception raised during invalid operations.""" diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py new file mode 100644 index 000000000..c4e17bd08 --- /dev/null +++ b/tests/frame/explode_test.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import Sequence + +import pytest +from polars.exceptions import ShapeError as PlShapeError + +import narwhals.stable.v1 as nw +from narwhals.exceptions import ShapeError +from tests.utils import Constructor +from tests.utils import assert_equal_data + +# For context, polars allows to explode multiple columns only if the columns +# have matching element counts, therefore, l1 and l2 but not l1 and l3 together. +data = { + "a": ["x", "y", "z", "w"], + "l1": [[1, 2], None, [None], []], + "l2": [[3, None], None, [42], []], + "l3": [[1, 2], [3], [None], [1]], +} + + +@pytest.mark.parametrize( + ("columns", "expected_values"), + [ + ("l2", [3, None, None, 42, None]), + ("l3", [1, 2, 3, None, 1]), # fast path for arrow + ], +) +def test_explode_single_col( + request: pytest.FixtureRequest, + constructor: Constructor, + columns: str, + expected_values: list[int | None], +) -> None: + if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): + request.applymarker(pytest.mark.xfail) + + result = ( + nw.from_native(constructor(data)) + .with_columns(nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))) + .explode(columns) + .select("a", columns) + ) + expected = {"a": ["x", "x", "y", "z", "w"], columns: expected_values} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("columns", "more_columns"), + [ + ("l1", ["l2"]), + (["l1", "l2"], []), + ], +) +def test_explode_multiple_cols( + request: pytest.FixtureRequest, + constructor: Constructor, + columns: str | Sequence[str], + more_columns: Sequence[str], +) -> None: + if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): + request.applymarker(pytest.mark.xfail) + + result = ( + nw.from_native(constructor(data)) + .with_columns(nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))) + .explode(columns, *more_columns) + .select("a", "l1", "l2") + ) + expected = { + "a": ["x", "x", "y", "z", "w"], + "l1": [1, 2, None, None, None], + "l2": [3, None, None, 42, None], + } + assert_equal_data(result, expected) + + +def test_explode_exception( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): + request.applymarker(pytest.mark.xfail) + + with pytest.raises( + (ShapeError, PlShapeError), + match="exploded columns must have matching element counts", + ): + _ = ( + nw.from_native(constructor(data)) + .lazy() + .with_columns(nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))) + .explode("l1", "l3") + .collect() + ) From 2326b08d12ec8512fe4a94891ef6a9a20b32b313 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 9 Dec 2024 10:03:56 +0100 Subject: [PATCH 02/10] arrow refactor --- narwhals/_arrow/dataframe.py | 51 ++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 8daa96d68..8a37da4c4 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -771,19 +771,15 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel if fast_path: indices = pc.list_parent_indices(native_frame[to_explode[0]]) - - exploded_frame = native_frame.take(indices=indices) - exploded_series = [ - pc.list_flatten(native_frame[col_name]) for col_name in to_explode - ] - return self._from_native_frame( - pa.Table.from_arrays( - [*[exploded_frame[c] for c in other_columns], *exploded_series], - names=[*other_columns, *to_explode], - ) - ).select(*original_columns) - + flatten_func = pc.list_flatten else: + indices = pa.array( + [ + i + for i, count in enumerate(counts.to_pylist()) + for _ in range(max(count or 1, 1)) + ] + ) def explode_null_array(array: pa.ChunkedArray) -> pa.ChunkedArray: exploded_values = [] # type: ignore[var-annotated] @@ -794,21 +790,18 @@ def explode_null_array(array: pa.ChunkedArray) -> pa.ChunkedArray: exploded_values.extend(lst_element) return pa.chunked_array([exploded_values]) - indices = pa.array( - [ - i - for i, count in enumerate(counts.to_pylist()) - for _ in range(max(count or 1, 1)) - ] - ) - exploded_frame = native_frame.take(indices=indices) - exploded_series = [ - explode_null_array(native_frame[col_name]) for col_name in to_explode - ] + flatten_func = explode_null_array - return self._from_native_frame( - pa.Table.from_arrays( - [*[exploded_frame[c] for c in other_columns], *exploded_series], - names=[*other_columns, *to_explode], - ) - ).select(*original_columns) + arrays = [ + native_frame[col_name].take(indices=indices) + if col_name in other_columns + else flatten_func(native_frame[col_name]) + for col_name in original_columns + ] + + return self._from_native_frame( + pa.Table.from_arrays( + arrays=arrays, + names=original_columns, + ) + ) From 32af22e2deb1978f6088b8374d36e9c65fc0a965 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 9 Dec 2024 10:37:24 +0100 Subject: [PATCH 03/10] raise for invalid type and docstrings --- narwhals/_arrow/dataframe.py | 14 ++++ narwhals/_pandas_like/dataframe.py | 11 +++ narwhals/dataframe.py | 120 ++++++++++++++++++++++++++++- tests/frame/explode_test.py | 17 +++- 4 files changed, 159 insertions(+), 3 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 8a37da4c4..ea1e31261 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -19,6 +19,7 @@ from narwhals.utils import Implementation from narwhals.utils import flatten from narwhals.utils import generate_temporary_column_name +from narwhals.utils import import_dtypes_module from narwhals.utils import is_sequence_but_not_str from narwhals.utils import parse_columns_to_drop @@ -748,11 +749,24 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel import pyarrow as pa import pyarrow.compute as pc + from narwhals.exceptions import InvalidOperationError + + dtypes = import_dtypes_module(self._version) + to_explode = ( [columns, *more_columns] if isinstance(columns, str) else [*columns, *more_columns] ) + + schema = self.collect_schema() + for col_to_explode in to_explode: + dtype = schema[col_to_explode] + + if dtype != dtypes.List: + msg = f"`explode` operation not supported for dtype `{dtype}`" + raise InvalidOperationError(msg) + native_frame = self._native_frame counts = pc.list_value_length(native_frame[to_explode[0]]) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 8428916ef..816617e8d 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -939,11 +939,22 @@ def unpivot( ) def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + from narwhals.exceptions import InvalidOperationError + + dtypes = import_dtypes_module(self._version) + to_explode = ( [columns, *more_columns] if isinstance(columns, str) else [*columns, *more_columns] ) + schema = self.collect_schema() + for col_to_explode in to_explode: + dtype = schema[col_to_explode] + + if dtype != dtypes.List: + msg = f"`explode` operation not supported for dtype `{dtype}`" + raise InvalidOperationError(msg) if len(to_explode) == 1: return self._from_native_frame(self._native_frame.explode(to_explode[0])) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 87c5ec34b..664b5ee44 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -580,8 +580,6 @@ def to_pandas(self) -> pd.DataFrame: 0 1 6.0 a 1 2 7.0 b 2 3 8.0 c - - """ return self._compliant_frame.to_pandas() @@ -2934,6 +2932,74 @@ def unpivot( ) def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + """Explode the dataframe to long format by exploding the given columns. + + Notes: + It is possible to explode multiple columns only if these columns must have + matching element counts. + + Arguments: + columns: Column names. The underlying columns being exploded must be of the `List` data type. + *more_columns: Additional names of columns to explode, specified as positional arguments. + + Returns: + New DataFrame + + Examples: + >>> import narwhals as nw + >>> from narwhals.typing import IntoDataFrameT + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = { + ... "a": ["x", "y", "z", "w"], + ... "lst1": [[1, 2], None, [None], []], + ... "lst2": [[3, None], None, [42], []], + ... } + + We define a library agnostic function: + + >>> def agnostic_explode(df_native: IntoDataFrameT) -> IntoDataFrameT: + ... return ( + ... nw.from_native(df_native) + ... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32()))) + ... .explode("lst1", "lst2") + ... .to_native() + ... ) + + We can then pass any supported library such as pandas, Polars (eager), + or PyArrow to `agnostic_explode`: + + >>> agnostic_explode(pd.DataFrame(data)) + a lst1 lst2 + 0 x 1 3 + 0 x 2 + 1 y + 2 z 42 + 3 w + >>> agnostic_explode(pl.DataFrame(data)) + shape: (5, 3) + ┌─────┬──────┬──────┐ + │ a ┆ lst1 ┆ lst2 │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i32 ┆ i32 │ + ╞═════╪══════╪══════╡ + │ x ┆ 1 ┆ 3 │ + │ x ┆ 2 ┆ null │ + │ y ┆ null ┆ null │ + │ z ┆ null ┆ 42 │ + │ w ┆ null ┆ null │ + └─────┴──────┴──────┘ + >>> agnostic_explode(pa.table(data)) + pyarrow.Table + a: string + lst1: int64 + lst2: int64 + ---- + a: [["x","x","y","z","w"]] + lst1: [[1,2,null,null,null]] + lst2: [[3,null,null,42,null]] + """ return super().explode(columns, *more_columns) @@ -4656,4 +4722,54 @@ def unpivot( ) def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + """Explode the dataframe to long format by exploding the given columns. + + Notes: + It is possible to explode multiple columns only if these columns must have + matching element counts. + + Arguments: + columns: Column names. The underlying columns being exploded must be of the `List` data type. + *more_columns: Additional names of columns to explode, specified as positional arguments. + + Returns: + New DataFrame + + Examples: + >>> import narwhals as nw + >>> from narwhals.typing import IntoFrameT + >>> import polars as pl + >>> data = { + ... "a": ["x", "y", "z", "w"], + ... "lst1": [[1, 2], None, [None], []], + ... "lst2": [[3, None], None, [42], []], + ... } + + We define a library agnostic function: + + >>> def agnostic_explode(df_native: IntoFrameT) -> IntoFrameT: + ... return ( + ... nw.from_native(df_native) + ... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32()))) + ... .explode("lst1", "lst2") + ... .to_native() + ... ) + + We can then pass any supported library such as pandas, Polars (eager), + or PyArrow to `agnostic_explode`: + + >>> agnostic_explode(pl.LazyFrame(data)).collect() + shape: (5, 3) + ┌─────┬──────┬──────┐ + │ a ┆ lst1 ┆ lst2 │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i32 ┆ i32 │ + ╞═════╪══════╪══════╡ + │ x ┆ 1 ┆ 3 │ + │ x ┆ 2 ┆ null │ + │ y ┆ null ┆ null │ + │ z ┆ null ┆ 42 │ + │ w ┆ null ┆ null │ + └─────┴──────┴──────┘ + """ return super().explode(columns, *more_columns) diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py index c4e17bd08..87dfca5c8 100644 --- a/tests/frame/explode_test.py +++ b/tests/frame/explode_test.py @@ -3,9 +3,11 @@ from typing import Sequence import pytest +from polars.exceptions import InvalidOperationError as PlInvalidOperationError from polars.exceptions import ShapeError as PlShapeError import narwhals.stable.v1 as nw +from narwhals.exceptions import InvalidOperationError from narwhals.exceptions import ShapeError from tests.utils import Constructor from tests.utils import assert_equal_data @@ -76,7 +78,7 @@ def test_explode_multiple_cols( assert_equal_data(result, expected) -def test_explode_exception( +def test_explode_shape_error( request: pytest.FixtureRequest, constructor: Constructor ) -> None: if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): @@ -93,3 +95,16 @@ def test_explode_exception( .explode("l1", "l3") .collect() ) + + +def test_explode_invalid_operation_error( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + + with pytest.raises( + (InvalidOperationError, PlInvalidOperationError), + match="`explode` operation not supported for dtype", + ): + _ = nw.from_native(constructor(data)).lazy().explode("a").collect() From 3b52ab583c13e10d17b16f3337695a1452fbc185 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Mon, 9 Dec 2024 10:45:22 +0100 Subject: [PATCH 04/10] Update narwhals/dataframe.py --- narwhals/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 664b5ee44..ecfec64a0 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -4733,7 +4733,7 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel *more_columns: Additional names of columns to explode, specified as positional arguments. Returns: - New DataFrame + New LazyFrame Examples: >>> import narwhals as nw From c3bf0096a8a5eec5acfda94c69390f3e347af08c Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 9 Dec 2024 11:12:43 +0100 Subject: [PATCH 05/10] old versions --- narwhals/_arrow/dataframe.py | 2 +- tests/frame/explode_test.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index ea1e31261..f3a28acbe 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -807,7 +807,7 @@ def explode_null_array(array: pa.ChunkedArray) -> pa.ChunkedArray: flatten_func = explode_null_array arrays = [ - native_frame[col_name].take(indices=indices) + native_frame[col_name].take(indices) if col_name in other_columns else flatten_func(native_frame[col_name]) for col_name in original_columns diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py index 87dfca5c8..24e458c5e 100644 --- a/tests/frame/explode_test.py +++ b/tests/frame/explode_test.py @@ -9,6 +9,8 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import InvalidOperationError from narwhals.exceptions import ShapeError +from tests.utils import PANDAS_VERSION +from tests.utils import POLARS_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data @@ -38,6 +40,9 @@ def test_explode_single_col( if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): + request.applymarker(pytest.mark.xfail) + result = ( nw.from_native(constructor(data)) .with_columns(nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))) @@ -64,6 +69,9 @@ def test_explode_multiple_cols( if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): + request.applymarker(pytest.mark.xfail) + result = ( nw.from_native(constructor(data)) .with_columns(nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))) @@ -84,6 +92,9 @@ def test_explode_shape_error( if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): + request.applymarker(pytest.mark.xfail) + with pytest.raises( (ShapeError, PlShapeError), match="exploded columns must have matching element counts", @@ -103,6 +114,9 @@ def test_explode_invalid_operation_error( if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6): + request.applymarker(pytest.mark.xfail) + with pytest.raises( (InvalidOperationError, PlInvalidOperationError), match="`explode` operation not supported for dtype", From 72314a21676a25e6ca615fc5537a6296404def3f Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Tue, 17 Dec 2024 14:11:06 +0100 Subject: [PATCH 06/10] almost all native --- narwhals/_arrow/dataframe.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index bc3a0a820..92ce984bb 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -795,6 +795,7 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel if fast_path: indices = pc.list_parent_indices(native_frame[to_explode[0]]) flatten_func = pc.list_flatten + else: indices = pa.array( [ @@ -803,17 +804,18 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel for _ in range(max(count or 1, 1)) ] ) + parent_indices = pc.list_parent_indices(native_frame[to_explode[0]]) + is_valid_index = pc.is_in(indices, value_set=parent_indices) + exploded_size = len(is_valid_index) - def explode_null_array(array: pa.ChunkedArray) -> pa.ChunkedArray: - exploded_values = [] # type: ignore[var-annotated] - for lst_element in array.to_pylist(): - if lst_element is None or len(lst_element) == 0: - exploded_values.append(None) - else: # Non-empty list) - exploded_values.extend(lst_element) - return pa.chunked_array([exploded_values]) + def flatten_func(array: pa.ChunkedArray) -> pa.ChunkedArray: + dtype = array.type.value_type - flatten_func = explode_null_array + return pc.replace_with_mask( + pa.array([None] * exploded_size, type=dtype), + is_valid_index, + pc.list_flatten(array).combine_chunks(), + ) arrays = [ native_frame[col_name].take(indices) From 7f04579bc79f33965400006229354f0346317483 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Tue, 17 Dec 2024 14:16:31 +0100 Subject: [PATCH 07/10] doctest --- narwhals/dataframe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index f7e5063e9..4295c27c2 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3070,8 +3070,8 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel >>> agnostic_explode(pa.table(data)) pyarrow.Table a: string - lst1: int64 - lst2: int64 + lst1: int32 + lst2: int32 ---- a: [["x","x","y","z","w"]] lst1: [[1,2,null,null,null]] From 864e9328caf6d33f25b4f5550b0fd7c73fd0a2e7 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 21 Dec 2024 13:01:07 +0100 Subject: [PATCH 08/10] better error message, fail for arrow with nulls --- narwhals/_arrow/dataframe.py | 28 ++++++---------- narwhals/_pandas_like/dataframe.py | 5 ++- tests/frame/explode_test.py | 51 ++++++++++++++++++++---------- 3 files changed, 48 insertions(+), 36 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 92ce984bb..68237eb12 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -773,7 +773,11 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel dtype = schema[col_to_explode] if dtype != dtypes.List: - msg = f"`explode` operation not supported for dtype `{dtype}`" + msg = ( + f"`explode` operation not supported for dtype `{dtype}`, " + "expected List type" + ) + raise InvalidOperationError(msg) native_frame = self._native_frame @@ -797,25 +801,11 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel flatten_func = pc.list_flatten else: - indices = pa.array( - [ - i - for i, count in enumerate(counts.to_pylist()) - for _ in range(max(count or 1, 1)) - ] + msg = ( + "`DataFrame.explode` is not supported for pyarrow backend and column" + "containing null's or empty list elements" ) - parent_indices = pc.list_parent_indices(native_frame[to_explode[0]]) - is_valid_index = pc.is_in(indices, value_set=parent_indices) - exploded_size = len(is_valid_index) - - def flatten_func(array: pa.ChunkedArray) -> pa.ChunkedArray: - dtype = array.type.value_type - - return pc.replace_with_mask( - pa.array([None] * exploded_size, type=dtype), - is_valid_index, - pc.list_flatten(array).combine_chunks(), - ) + raise NotImplementedError(msg) arrays = [ native_frame[col_name].take(indices) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 53661104b..cdbfd034e 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -965,7 +965,10 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel dtype = schema[col_to_explode] if dtype != dtypes.List: - msg = f"`explode` operation not supported for dtype `{dtype}`" + msg = ( + f"`explode` operation not supported for dtype `{dtype}`, " + "expected List type" + ) raise InvalidOperationError(msg) if len(to_explode) == 1: diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py index 24e458c5e..42f2716db 100644 --- a/tests/frame/explode_test.py +++ b/tests/frame/explode_test.py @@ -21,11 +21,12 @@ "l1": [[1, 2], None, [None], []], "l2": [[3, None], None, [42], []], "l3": [[1, 2], [3], [None], [1]], + "l4": [[1, 2], [3], [123], [456]], } @pytest.mark.parametrize( - ("columns", "expected_values"), + ("column", "expected_values"), [ ("l2", [3, None, None, 42, None]), ("l3", [1, 2, 3, None, 1]), # fast path for arrow @@ -34,7 +35,7 @@ def test_explode_single_col( request: pytest.FixtureRequest, constructor: Constructor, - columns: str, + column: str, expected_values: list[int | None], ) -> None: if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): @@ -43,21 +44,40 @@ def test_explode_single_col( if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): request.applymarker(pytest.mark.xfail) + if "pyarrow_table" in str(constructor) and column == "l2": + request.applymarker(pytest.mark.xfail) + result = ( nw.from_native(constructor(data)) - .with_columns(nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))) - .explode(columns) - .select("a", columns) + .with_columns(nw.col(column).cast(nw.List(nw.Int32()))) + .explode(column) + .select("a", column) ) - expected = {"a": ["x", "x", "y", "z", "w"], columns: expected_values} + expected = {"a": ["x", "x", "y", "z", "w"], column: expected_values} assert_equal_data(result, expected) @pytest.mark.parametrize( - ("columns", "more_columns"), + ("columns", "more_columns", "expected"), [ - ("l1", ["l2"]), - (["l1", "l2"], []), + ( + "l1", + ["l2"], + { + "a": ["x", "x", "y", "z", "w"], + "l1": [1, 2, None, None, None], + "l2": [3, None, None, 42, None], + }, + ), + ( + "l3", + ["l4"], + { + "a": ["x", "x", "y", "z", "w"], + "l3": [1, 2, 3, None, 1], + "l4": [1, 2, 3, 123, 456], + }, + ), ], ) def test_explode_multiple_cols( @@ -65,6 +85,7 @@ def test_explode_multiple_cols( constructor: Constructor, columns: str | Sequence[str], more_columns: Sequence[str], + expected: dict[str, list[str | int | None]], ) -> None: if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): request.applymarker(pytest.mark.xfail) @@ -72,17 +93,15 @@ def test_explode_multiple_cols( if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): request.applymarker(pytest.mark.xfail) + if "pyarrow_table" in str(constructor) and columns == "l1": + request.applymarker(pytest.mark.xfail) + result = ( nw.from_native(constructor(data)) - .with_columns(nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))) + .with_columns(nw.col(columns, *more_columns).cast(nw.List(nw.Int32()))) .explode(columns, *more_columns) - .select("a", "l1", "l2") + .select("a", columns, *more_columns) ) - expected = { - "a": ["x", "x", "y", "z", "w"], - "l1": [1, 2, None, None, None], - "l2": [3, None, None, 42, None], - } assert_equal_data(result, expected) From cc72f6b94a827b4edb45224ff7df696357de26fa Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 21 Dec 2024 13:13:52 +0100 Subject: [PATCH 09/10] doctest-modules --- narwhals/dataframe.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 8d4f0816c..540a8f34f 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3152,9 +3152,9 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel >>> import polars as pl >>> import pyarrow as pa >>> data = { - ... "a": ["x", "y", "z", "w"], - ... "lst1": [[1, 2], None, [None], []], - ... "lst2": [[3, None], None, [42], []], + ... "a": ["x", "y", "z"], + ... "lst1": [[1, 2], [None, 3], [None]], + ... "lst2": [["foo", None], ["bar", None], ["baz"]], ... } We define a library agnostic function: @@ -3162,7 +3162,10 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel >>> def agnostic_explode(df_native: IntoDataFrameT) -> IntoDataFrameT: ... return ( ... nw.from_native(df_native) - ... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32()))) + ... .with_columns( + ... nw.col("lst1").cast(nw.List(nw.Int32())), + ... nw.col("lst2").cast(nw.List(nw.String())), + ... ) ... .explode("lst1", "lst2") ... .to_native() ... ) @@ -3172,33 +3175,33 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel >>> agnostic_explode(pd.DataFrame(data)) a lst1 lst2 - 0 x 1 3 + 0 x 1 foo 0 x 2 - 1 y - 2 z 42 - 3 w + 1 y bar + 1 y 3 + 2 z baz >>> agnostic_explode(pl.DataFrame(data)) shape: (5, 3) ┌─────┬──────┬──────┐ │ a ┆ lst1 ┆ lst2 │ │ --- ┆ --- ┆ --- │ - │ str ┆ i32 ┆ i32 │ + │ str ┆ i32 ┆ str │ ╞═════╪══════╪══════╡ - │ x ┆ 1 ┆ 3 │ + │ x ┆ 1 ┆ foo │ │ x ┆ 2 ┆ null │ - │ y ┆ null ┆ null │ - │ z ┆ null ┆ 42 │ - │ w ┆ null ┆ null │ + │ y ┆ null ┆ bar │ + │ y ┆ 3 ┆ null │ + │ z ┆ null ┆ baz │ └─────┴──────┴──────┘ >>> agnostic_explode(pa.table(data)) pyarrow.Table a: string lst1: int32 - lst2: int32 + lst2: string ---- - a: [["x","x","y","z","w"]] - lst1: [[1,2,null,null,null]] - lst2: [[3,null,null,42,null]] + a: [["x","x","y","y","z"]] + lst1: [[1,2,null,3,null]] + lst2: [["foo",null,"bar",null,"baz"]] """ return super().explode(columns, *more_columns) From 1156beb9c3f2e4abc3b413ade9832128feff45ed Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 21 Dec 2024 13:23:17 +0100 Subject: [PATCH 10/10] completely remove pyarrow implementation --- narwhals/_arrow/dataframe.py | 68 ------------------------------------ narwhals/dataframe.py | 38 +++++++------------- tests/frame/explode_test.py | 23 ++++++------ 3 files changed, 26 insertions(+), 103 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 68237eb12..34758bd82 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -19,7 +19,6 @@ from narwhals.utils import Implementation from narwhals.utils import flatten from narwhals.utils import generate_temporary_column_name -from narwhals.utils import import_dtypes_module from narwhals.utils import is_sequence_but_not_str from narwhals.utils import parse_columns_to_drop from narwhals.utils import scale_bytes @@ -753,70 +752,3 @@ def unpivot( ) # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not # upcast numeric to non-numeric (e.g. string) datatypes - - def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: - import pyarrow as pa - import pyarrow.compute as pc - - from narwhals.exceptions import InvalidOperationError - - dtypes = import_dtypes_module(self._version) - - to_explode = ( - [columns, *more_columns] - if isinstance(columns, str) - else [*columns, *more_columns] - ) - - schema = self.collect_schema() - for col_to_explode in to_explode: - dtype = schema[col_to_explode] - - if dtype != dtypes.List: - msg = ( - f"`explode` operation not supported for dtype `{dtype}`, " - "expected List type" - ) - - raise InvalidOperationError(msg) - - native_frame = self._native_frame - counts = pc.list_value_length(native_frame[to_explode[0]]) - - if not all( - pc.all(pc.equal(pc.list_value_length(native_frame[col_name]), counts)).as_py() - for col_name in to_explode[1:] - ): - from narwhals.exceptions import ShapeError - - msg = "exploded columns must have matching element counts" - raise ShapeError(msg) - - original_columns = self.columns - other_columns = [c for c in original_columns if c not in to_explode] - fast_path = pc.all(pc.greater_equal(counts, 1)).as_py() - - if fast_path: - indices = pc.list_parent_indices(native_frame[to_explode[0]]) - flatten_func = pc.list_flatten - - else: - msg = ( - "`DataFrame.explode` is not supported for pyarrow backend and column" - "containing null's or empty list elements" - ) - raise NotImplementedError(msg) - - arrays = [ - native_frame[col_name].take(indices) - if col_name in other_columns - else flatten_func(native_frame[col_name]) - for col_name in original_columns - ] - - return self._from_native_frame( - pa.Table.from_arrays( - arrays=arrays, - names=original_columns, - ) - ) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 540a8f34f..8f120532c 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3152,9 +3152,9 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel >>> import polars as pl >>> import pyarrow as pa >>> data = { - ... "a": ["x", "y", "z"], - ... "lst1": [[1, 2], [None, 3], [None]], - ... "lst2": [["foo", None], ["bar", None], ["baz"]], + ... "a": ["x", "y", "z", "w"], + ... "lst1": [[1, 2], None, [None], []], + ... "lst2": [[3, None], None, [42], []], ... } We define a library agnostic function: @@ -3162,10 +3162,7 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel >>> def agnostic_explode(df_native: IntoDataFrameT) -> IntoDataFrameT: ... return ( ... nw.from_native(df_native) - ... .with_columns( - ... nw.col("lst1").cast(nw.List(nw.Int32())), - ... nw.col("lst2").cast(nw.List(nw.String())), - ... ) + ... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32()))) ... .explode("lst1", "lst2") ... .to_native() ... ) @@ -3175,33 +3172,24 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel >>> agnostic_explode(pd.DataFrame(data)) a lst1 lst2 - 0 x 1 foo + 0 x 1 3 0 x 2 - 1 y bar - 1 y 3 - 2 z baz + 1 y + 2 z 42 + 3 w >>> agnostic_explode(pl.DataFrame(data)) shape: (5, 3) ┌─────┬──────┬──────┐ │ a ┆ lst1 ┆ lst2 │ │ --- ┆ --- ┆ --- │ - │ str ┆ i32 ┆ str │ + │ str ┆ i32 ┆ i32 │ ╞═════╪══════╪══════╡ - │ x ┆ 1 ┆ foo │ + │ x ┆ 1 ┆ 3 │ │ x ┆ 2 ┆ null │ - │ y ┆ null ┆ bar │ - │ y ┆ 3 ┆ null │ - │ z ┆ null ┆ baz │ + │ y ┆ null ┆ null │ + │ z ┆ null ┆ 42 │ + │ w ┆ null ┆ null │ └─────┴──────┴──────┘ - >>> agnostic_explode(pa.table(data)) - pyarrow.Table - a: string - lst1: int32 - lst2: string - ---- - a: [["x","x","y","y","z"]] - lst1: [[1,2,null,3,null]] - lst2: [["foo",null,"bar",null,"baz"]] """ return super().explode(columns, *more_columns) diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py index 42f2716db..631da0255 100644 --- a/tests/frame/explode_test.py +++ b/tests/frame/explode_test.py @@ -38,15 +38,15 @@ def test_explode_single_col( column: str, expected_values: list[int | None], ) -> None: - if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): + if any( + backend in str(constructor) + for backend in ("dask", "modin", "cudf", "pyarrow_table") + ): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): request.applymarker(pytest.mark.xfail) - if "pyarrow_table" in str(constructor) and column == "l2": - request.applymarker(pytest.mark.xfail) - result = ( nw.from_native(constructor(data)) .with_columns(nw.col(column).cast(nw.List(nw.Int32()))) @@ -87,15 +87,15 @@ def test_explode_multiple_cols( more_columns: Sequence[str], expected: dict[str, list[str | int | None]], ) -> None: - if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): + if any( + backend in str(constructor) + for backend in ("dask", "modin", "cudf", "pyarrow_table") + ): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): request.applymarker(pytest.mark.xfail) - if "pyarrow_table" in str(constructor) and columns == "l1": - request.applymarker(pytest.mark.xfail) - result = ( nw.from_native(constructor(data)) .with_columns(nw.col(columns, *more_columns).cast(nw.List(nw.Int32()))) @@ -108,7 +108,10 @@ def test_explode_multiple_cols( def test_explode_shape_error( request: pytest.FixtureRequest, constructor: Constructor ) -> None: - if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): + if any( + backend in str(constructor) + for backend in ("dask", "modin", "cudf", "pyarrow_table") + ): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): @@ -130,7 +133,7 @@ def test_explode_shape_error( def test_explode_invalid_operation_error( request: pytest.FixtureRequest, constructor: Constructor ) -> None: - if "dask" in str(constructor): + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6):