feat: add LazyFrame.explode for duckdb (#1891)
FBruzzesi authored Feb 3, 2025
1 parent 22d6df3 commit 6df6afe
Showing 6 changed files with 79 additions and 37 deletions.
49 changes: 49 additions & 0 deletions narwhals/_duckdb/dataframe.py
@@ -8,16 +8,20 @@

import duckdb
from duckdb import ColumnExpression
from duckdb import ConstantExpression
from duckdb import FunctionExpression

from narwhals._duckdb.utils import ExprKind
from narwhals._duckdb.utils import native_to_narwhals_dtype
from narwhals._duckdb.utils import parse_exprs_and_named_exprs
from narwhals.dependencies import get_duckdb
from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantDataFrame
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import import_dtypes_module
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
from narwhals.utils import validate_backend_version
@@ -427,6 +431,51 @@ def drop_nulls(self: Self, subset: list[str] | None) -> Self:
query = f"select * from rel where {keep_condition}" # noqa: S608
return self._from_native_frame(duckdb.sql(query))

def explode(self: Self, columns: list[str]) -> Self:
dtypes = import_dtypes_module(self._version)
schema = self.collect_schema()
for col in columns:
dtype = schema[col]

if dtype != dtypes.List:
msg = (
f"`explode` operation not supported for dtype `{dtype}`, "
"expected List type"
)
raise InvalidOperationError(msg)

if len(columns) != 1:
msg = (
"Exploding on multiple columns is not supported with DuckDB backend since "
"we cannot guarantee that the exploded columns have matching element counts."
)
raise NotImplementedError(msg)

col_to_explode = ColumnExpression(columns[0])
rel = self._native_frame
original_columns = self.columns

not_null_condition = (
col_to_explode.isnotnull() & FunctionExpression("len", col_to_explode) > 0
)
non_null_rel = rel.filter(not_null_condition).select(
*(
FunctionExpression("unnest", col_to_explode).alias(col)
if col in columns
else col
for col in original_columns
)
)

null_rel = rel.filter(~not_null_condition).select(
*(
ConstantExpression(None).alias(col) if col in columns else col
for col in original_columns
)
)

return self._from_native_frame(non_null_rel.union(null_rel))

def unpivot(
self: Self,
on: str | list[str] | None,
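The new DuckDB path splits the relation in two: rows whose list is non-empty are unnested, the remaining rows (NULL or empty list) are re-emitted once with a NULL value, and the two branches are unioned back together. Below is a minimal standalone sketch of that strategy, not part of the diff, using an illustrative relation and with the filter condition parenthesised explicitly:

import duckdb
from duckdb import ColumnExpression, ConstantExpression, FunctionExpression

# Illustrative data: one non-empty list, one empty list, one NULL.
rel = duckdb.sql(
    "select * from (values ('x', [1, 2]), ('y', []), ('z', NULL::INTEGER[])) df(a, l)"
)

col = ColumnExpression("l")
not_null = col.isnotnull() & (FunctionExpression("len", col) > 0)

# Branch 1: unnest the list column, producing one row per element.
non_null_rel = rel.filter(not_null).select(
    ColumnExpression("a"), FunctionExpression("unnest", col).alias("l")
)
# Branch 2: keep NULL/empty-list rows as single rows with a NULL value.
null_rel = rel.filter(~not_null).select(
    ColumnExpression("a"), ConstantExpression(None).alias("l")
)

print(non_null_rel.union(null_rel))
# 'x' expands to two rows (1 and 2); 'y' and 'z' each remain one row with NULL.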
28 changes: 11 additions & 17 deletions narwhals/_pandas_like/dataframe.py
@@ -19,6 +19,7 @@
from narwhals._pandas_like.utils import rename
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import generate_temporary_column_name
@@ -1071,18 +1072,11 @@ def unpivot(
)
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
from narwhals.exceptions import InvalidOperationError

def explode(self: Self, columns: list[str]) -> Self:
dtypes = import_dtypes_module(self._version)

to_explode = (
[columns, *more_columns]
if isinstance(columns, str)
else [*columns, *more_columns]
)
schema = self.collect_schema()
for col_to_explode in to_explode:
for col_to_explode in columns:
dtype = schema[col_to_explode]

if dtype != dtypes.List:
@@ -1092,29 +1086,29 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel
)
raise InvalidOperationError(msg)

if len(to_explode) == 1:
return self._from_native_frame(self._native_frame.explode(to_explode[0]))
if len(columns) == 1:
return self._from_native_frame(self._native_frame.explode(columns[0]))
else:
native_frame = self._native_frame
anchor_series = native_frame[to_explode[0]].list.len()
anchor_series = native_frame[columns[0]].list.len()

if not all(
(native_frame[col_name].list.len() == anchor_series).all()
for col_name in to_explode[1:]
for col_name in columns[1:]
):
from narwhals.exceptions import ShapeError

msg = "exploded columns must have matching element counts"
raise ShapeError(msg)

original_columns = self.columns
other_columns = [c for c in original_columns if c not in to_explode]
other_columns = [c for c in original_columns if c not in columns]

exploded_frame = native_frame[[*other_columns, to_explode[0]]].explode(
to_explode[0]
exploded_frame = native_frame[[*other_columns, columns[0]]].explode(
columns[0]
)
exploded_series = [
native_frame[col_name].explode().to_frame() for col_name in to_explode[1:]
native_frame[col_name].explode().to_frame() for col_name in columns[1:]
]

plx = self.__native_namespace__()
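For comparison, the multi-column branch kept above checks that every exploded column has matching per-row lengths, explodes the frame on the first list column, explodes the remaining list columns separately, and reattaches them. A rough pandas-only sketch of the same idea follows, with made-up data and plain object-dtype lists instead of the Arrow-backed .list.len() accessor the real code uses:

import pandas as pd

df = pd.DataFrame({"a": ["x", "y"], "l1": [[1, 2], [3]], "l2": [[10, 20], [30]]})
to_explode = ["l1", "l2"]

# Every exploded column must have the same element count per row.
anchor = df[to_explode[0]].map(len)
assert all(df[c].map(len).eq(anchor).all() for c in to_explode[1:])

other = [c for c in df.columns if c not in to_explode]
# Explode the frame on the first list column, the rest as standalone series.
exploded = df[[*other, to_explode[0]]].explode(to_explode[0]).reset_index(drop=True)
rest = [df[c].explode().reset_index(drop=True) for c in to_explode[1:]]
result = pd.concat([exploded, *rest], axis=1)[list(df.columns)]
print(result)  # three rows: ('x', 1, 10), ('x', 2, 20), ('y', 3, 30)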
13 changes: 4 additions & 9 deletions narwhals/_spark_like/dataframe.py
@@ -370,16 +370,11 @@ def join(
self_native.join(other, on=left_on, how=how).select(col_order)
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
def explode(self: Self, columns: list[str]) -> Self:
dtypes = import_dtypes_module(self._version)

to_explode = (
[columns, *more_columns]
if isinstance(columns, str)
else [*columns, *more_columns]
)
schema = self.collect_schema()
for col_to_explode in to_explode:
for col_to_explode in columns:
dtype = schema[col_to_explode]

if dtype != dtypes.List:
@@ -392,7 +387,7 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel
native_frame = self._native_frame
column_names = self.columns

if len(to_explode) != 1:
if len(columns) != 1:
msg = (
"Exploding on multiple columns is not supported with SparkLike backend since "
"we cannot guarantee that the exploded columns have matching element counts."
@@ -403,7 +398,7 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel
native_frame.select(
*[
self._F.col(col_name).alias(col_name)
if col_name != to_explode[0]
if col_name != columns[0]
else self._F.explode_outer(col_name).alias(col_name)
for col_name in column_names
]
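The Spark backend only supports a single exploded column and relies on explode_outer, which keeps rows whose list is NULL or empty by emitting a NULL element for them. A small sketch with an illustrative local session and data:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("x", [1, 2]), ("y", []), ("z", None)], schema="a string, l array<int>"
)
# explode_outer: one row per element, plus a NULL row for empty/NULL lists.
df.select(F.col("a"), F.explode_outer("l").alias("l")).show()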
11 changes: 7 additions & 4 deletions narwhals/dataframe.py
@@ -365,11 +365,14 @@ def __eq__(self: Self, other: object) -> NoReturn:
raise NotImplementedError(msg)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
to_explode = (
[columns, *more_columns]
if isinstance(columns, str)
else [*columns, *more_columns]
)

return self._from_compliant_dataframe(
self._compliant_frame.explode(
columns,
*more_columns,
)
self._compliant_frame.explode(columns=to_explode)
)


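At the user-facing layer, columns and *more_columns are now flattened into a single list before being dispatched, so every backend receives a plain columns: list[str]. A hedged end-to-end usage sketch, assuming a narwhals version in which DuckDB relations are accepted as lazy input (the relation and column names are illustrative):

import duckdb
import narwhals as nw

rel = duckdb.sql("select * from (values ('x', [1, 2]), ('y', [3])) df(a, l)")

# explode accepts a single name, a sequence, or several positional names;
# all forms are normalised to one list internally.
result = nw.from_native(rel).explode("l").to_native()
print(result)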
13 changes: 7 additions & 6 deletions tests/frame/explode_test.py
@@ -28,8 +28,8 @@
@pytest.mark.parametrize(
("column", "expected_values"),
[
("l2", [3, None, None, 42, None]),
("l3", [1, 2, 3, None, 1]), # fast path for arrow
("l2", [None, 3, None, None, 42]),
("l3", [1, 1, 2, 3, None]), # fast path for arrow
],
)
def test_explode_single_col(
Expand All @@ -40,7 +40,7 @@ def test_explode_single_col(
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb")
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
request.applymarker(pytest.mark.xfail)

@@ -52,8 +52,9 @@
.with_columns(nw.col(column).cast(nw.List(nw.Int32())))
.explode(column)
.select("a", column)
.sort("a")
)
expected = {"a": ["x", "x", "y", "z", "w"], column: expected_values}
expected = {"a": ["w", "x", "x", "y", "z"], column: expected_values}
assert_equal_data(result, expected)


@@ -110,7 +111,7 @@ def test_explode_shape_error(
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb")
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
request.applymarker(pytest.mark.xfail)

@@ -133,7 +134,7 @@
def test_explode_invalid_operation_error(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb")):
if any(x in str(constructor) for x in ("pyarrow_table", "dask")):
request.applymarker(pytest.mark.xfail)

if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6):
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -99,7 +99,7 @@ def assert_equal_data(result: Any, expected: dict[str, Any]) -> None:
for idx, (col, key) in enumerate(zip(result.columns, expected.keys())):
assert col == key, f"Expected column name {key} at index {idx}, found {col}"
result = {key: _to_comparable_list(result[key]) for key in expected}
if is_pyspark and expected: # pragma: no cover
if (is_pyspark or is_duckdb) and expected: # pragma: no cover
sort_key = next(iter(expected.keys()))
expected = _sort_dict_by_key(expected, sort_key)
result = _sort_dict_by_key(result, sort_key)
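Because neither DuckDB nor PySpark guarantees row order, the comparison helper now sorts both the expected and the actual column dictionaries by the first expected key before asserting. A hedged reimplementation of that idea follows; _sort_dict_by_key is defined elsewhere in tests/utils.py, and this version is illustrative only:

def _sort_dict_by_key(data: dict[str, list], key: str) -> dict[str, list]:
    # Reorder every column by the values of the key column (None values last).
    order = sorted(
        range(len(data[key])), key=lambda i: (data[key][i] is None, data[key][i])
    )
    return {col: [values[i] for i in order] for col, values in data.items()}

result = {"a": ["y", "x"], "l": [3, 1]}
expected = {"a": ["x", "y"], "l": [1, 3]}
assert _sort_dict_by_key(result, "a") == _sort_dict_by_key(expected, "a")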
