Skip to content

Commit

Permalink
add nulls_last kw in dataframe sort (#1124)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Oct 3, 2024
1 parent 3f8ba38 commit bfd42e5
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 23 deletions.
8 changes: 6 additions & 2 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ def sort(
self,
by: str | Iterable[str],
*more_by: str,
descending: bool | Sequence[bool] = False,
descending: bool | Sequence[bool],
nulls_last: bool,
) -> Self:
flat_keys = flatten([*flatten([by]), *more_by])
df = self._native_frame
Expand All @@ -408,7 +409,10 @@ def sort(
(key, "descending" if is_descending else "ascending")
for key, is_descending in zip(flat_keys, descending)
]
return self._from_native_frame(df.sort_by(sorting=sorting))

null_placement = "at_end" if nulls_last else "at_start"

return self._from_native_frame(df.sort_by(sorting, null_placement=null_placement))

def to_pandas(self) -> Any:
return self._native_frame.to_pandas()
Expand Down
8 changes: 6 additions & 2 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,19 @@ def sort(
self: Self,
by: str | Iterable[str],
*more_by: str,
descending: bool | Sequence[bool] = False,
descending: bool | Sequence[bool],
nulls_last: bool,
) -> Self:
flat_keys = flatten([*flatten([by]), *more_by])
df = self._native_frame
if isinstance(descending, bool):
ascending: bool | list[bool] = not descending
else:
ascending = [not d for d in descending]
return self._from_native_frame(df.sort_values(flat_keys, ascending=ascending))
na_position = "last" if nulls_last else "first"
return self._from_native_frame(
df.sort_values(flat_keys, ascending=ascending, na_position=na_position)
)

def join(
self: Self,
Expand Down
8 changes: 6 additions & 2 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,15 +413,19 @@ def sort(
self,
by: str | Iterable[str],
*more_by: str,
descending: bool | Sequence[bool] = False,
descending: bool | Sequence[bool],
nulls_last: bool,
) -> Self:
flat_keys = flatten([*flatten([by]), *more_by])
df = self._native_frame
if isinstance(descending, bool):
ascending: bool | list[bool] = not descending
else:
ascending = [not d for d in descending]
return self._from_native_frame(df.sort_values(flat_keys, ascending=ascending))
na_position = "last" if nulls_last else "first"
return self._from_native_frame(
df.sort_values(flat_keys, ascending=ascending, na_position=na_position)
)

# --- convert ---
def collect(self) -> PandasLikeDataFrame:
Expand Down
43 changes: 26 additions & 17 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,12 @@ def sort(
by: str | Iterable[str],
*more_by: str,
descending: bool | Sequence[bool] = False,
nulls_last: bool = False,
) -> Self:
return self._from_compliant_dataframe(
self._compliant_frame.sort(by, *more_by, descending=descending)
self._compliant_frame.sort(
by, *more_by, descending=descending, nulls_last=nulls_last
)
)

def join(
Expand Down Expand Up @@ -1944,19 +1947,22 @@ def sort(
by: str | Iterable[str],
*more_by: str,
descending: bool | Sequence[bool] = False,
nulls_last: bool = False,
) -> Self:
r"""
Sort the dataframe by the given columns.
Arguments:
by: Column(s) names to sort by.
*more_by: Additional columns to sort by, specified as positional arguments.
descending: Sort in descending order. When sorting by multiple columns, can be
specified per column by passing a sequence of booleans.
nulls_last: Place null values last.
*more_by: Additional columns to sort by, specified as positional
arguments.
descending: Sort in descending order. When sorting by multiple
columns, can be specified per column by passing a
sequence of booleans.
Warning:
Unlike Polars, it is not possible to specify a sequence of booleans for
`nulls_last` in order to control per-column behaviour. Instead a single
boolean is applied for all `by` columns.
Examples:
>>> import narwhals as nw
Expand Down Expand Up @@ -1996,7 +2002,7 @@ def sort(
│ 2 ┆ 5.0 ┆ c │
└──────┴─────┴─────┘
"""
return super().sort(by, *more_by, descending=descending)
return super().sort(by, *more_by, descending=descending, nulls_last=nulls_last)

def join(
self,
Expand Down Expand Up @@ -3858,20 +3864,23 @@ def sort(
by: str | Iterable[str],
*more_by: str,
descending: bool | Sequence[bool] = False,
nulls_last: bool = False,
) -> Self:
r"""
Sort the LazyFrame by the given columns.
Arguments:
by: Column(s) to sort by. Accepts expression input. Strings are
parsed as column names.
*more_by: Additional columns to sort by, specified as positional
arguments.
by: Column(s) names to sort by.
*more_by: Additional columns to sort by, specified as positional arguments.
descending: Sort in descending order. When sorting by multiple columns, can be
specified per column by passing a sequence of booleans.
nulls_last: Place null values last; can specify a single boolean applying to
all columns or a sequence of booleans for per-column control.
descending: Sort in descending order. When sorting by multiple
columns, can be specified per column by passing a
sequence of booleans.
Warning:
Unlike Polars, it is not possible to specify a sequence of booleans for
`nulls_last` in order to control per-column behaviour. Instead a single
boolean is applied for all `by` columns.
Examples:
>>> import narwhals as nw
Expand Down Expand Up @@ -3911,7 +3920,7 @@ def sort(
│ 2 ┆ 5.0 ┆ c │
└──────┴─────┴─────┘
"""
return super().sort(by, *more_by, descending=descending)
return super().sort(by, *more_by, descending=descending, nulls_last=nulls_last)

def join(
self,
Expand Down
20 changes: 20 additions & 0 deletions tests/frame/sort_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import compare_dicts
Expand All @@ -20,3 +24,19 @@ def test_sort(constructor: Constructor) -> None:
"z": [8.0, 9.0, 7.0],
}
compare_dicts(result, expected)


@pytest.mark.parametrize(
("nulls_last", "expected"),
[
(True, {"a": [0, 2, 0, -1], "b": [3, 2, 1, float("nan")]}),
(False, {"a": [-1, 0, 2, 0], "b": [float("nan"), 3, 2, 1]}),
],
)
def test_sort_nulls(
constructor: Constructor, *, nulls_last: bool, expected: dict[str, float]
) -> None:
data = {"a": [0, 0, 2, -1], "b": [1, 3, 2, None]}
df = nw.from_native(constructor(data))
result = df.sort("b", descending=True, nulls_last=nulls_last)
compare_dicts(result, expected)

0 comments on commit bfd42e5

Please sign in to comment.