Skip to content

Commit

Permalink
feat: Allow sorting of lists and arrays (#20169)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Dec 6, 2024
1 parent 4a18809 commit 947bf89
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 16 deletions.
59 changes: 43 additions & 16 deletions crates/polars-core/src/chunked_array/ops/row_encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,29 @@ use crate::utils::_split_offsets;
use crate::POOL;

pub(crate) fn convert_series_for_row_encoding(s: &Series) -> PolarsResult<Series> {
use DataType::*;
use DataType as D;
let out = match s.dtype() {
D::Null
| D::Boolean
| D::UInt8
| D::UInt16
| D::UInt32
| D::UInt64
| D::Int8
| D::Int16
| D::Int32
| D::Int64
| D::Float32
| D::Float64
| D::String
| D::Binary
| D::BinaryOffset => s.clone(),

#[cfg(feature = "dtype-categorical")]
Categorical(_, _) | Enum(_, _) => s.rechunk(),
Binary | Boolean => s.clone(),
BinaryOffset => s.clone(),
String => s.clone(),
D::Categorical(_, _) | D::Enum(_, _) => s.rechunk(),

#[cfg(feature = "dtype-struct")]
Struct(_) => {
D::Struct(_) => {
let ca = s.struct_().unwrap();
let new_fields = ca
.fields_as_series()
Expand All @@ -29,16 +43,29 @@ pub(crate) fn convert_series_for_row_encoding(s: &Series) -> PolarsResult<Series
},
// we could fallback to default branch, but decimal is not numeric dtype for now, so explicit here
#[cfg(feature = "dtype-decimal")]
Decimal(_, _) => s.clone(),
List(inner) if !inner.is_nested() => s.clone(),
Null => s.clone(),
_ => {
let phys = s.to_physical_repr().into_owned();
polars_ensure!(
phys.dtype().is_numeric(),
InvalidOperation: "cannot sort column of dtype `{}`", s.dtype()
);
phys
D::Decimal(_, _) => s.clone(),
#[cfg(feature = "dtype-array")]
D::Array(_, _) => s
.array()
.unwrap()
.apply_to_inner(&|s| convert_series_for_row_encoding(&s))
.unwrap()
.into_series(),
D::List(_) => s
.list()
.unwrap()
.apply_to_inner(&|s| convert_series_for_row_encoding(&s))
.unwrap()
.into_series(),

D::Date | D::Datetime(_, _) | D::Duration(_) | D::Time => s.to_physical_repr().into_owned(),

#[cfg(feature = "object")]
D::Object(_, _) => {
polars_bail!( InvalidOperation: "cannot sort column of dtype `{}`", s.dtype())
},
D::Unknown(_) => {
polars_bail!( InvalidOperation: "cannot sort column of dtype `{}`", s.dtype())
},
};
Ok(out)
Expand Down
22 changes: 22 additions & 0 deletions crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use polars_utils::itertools::Itertools;

use self::row_encode::_get_rows_encoded;
use super::*;

// Reduce monomorphisation.
Expand Down Expand Up @@ -149,3 +152,22 @@ where

ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
}

pub(crate) fn arg_sort_row_fmt(
by: &[Column],
descending: bool,
nulls_last: bool,
parallel: bool,
) -> PolarsResult<IdxCa> {
let rows_encoded = _get_rows_encoded(by, &[descending], &[nulls_last])?;
let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect();

if parallel {
POOL.install(|| items.par_sort_by(|a, b| a.1.cmp(b.1)));
} else {
items.sort_by(|a, b| a.1.cmp(b.1));
}

let ca: NoNull<IdxCa> = items.into_iter().map(|tpl| tpl.0).collect();
Ok(ca.into_inner())
}
1 change: 1 addition & 0 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod categorical;

use std::cmp::Ordering;

pub(crate) use arg_sort::arg_sort_row_fmt;
pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt;
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::buffer::Buffer;
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-core/src/series/implementations/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::any::Any;
use std::borrow::Cow;

use self::sort::arg_sort_row_fmt;
use super::{private, MetadataFlags};
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::comparison::*;
Expand Down Expand Up @@ -89,6 +90,23 @@ impl SeriesTrait for SeriesWrap<ArrayChunked> {
self.0.shrink_to_fit()
}

fn arg_sort(&self, options: SortOptions) -> IdxCa {
let slf = (*self).clone();
let slf = slf.into_column();
arg_sort_row_fmt(
&[slf],
options.descending,
options.nulls_last,
options.multithreaded,
)
.unwrap()
}

fn sort_with(&self, options: SortOptions) -> PolarsResult<Series> {
let idxs = self.arg_sort(options);
Ok(unsafe { self.take_unchecked(&idxs) })
}

fn slice(&self, offset: i64, length: usize) -> Series {
self.0.slice(offset, length).into_series()
}
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-core/src/series/implementations/list.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use self::sort::arg_sort_row_fmt;
use super::*;
use crate::chunked_array::comparison::*;
#[cfg(feature = "algorithm_group_by")]
Expand Down Expand Up @@ -93,6 +94,23 @@ impl SeriesTrait for SeriesWrap<ListChunked> {
);
}

fn arg_sort(&self, options: SortOptions) -> IdxCa {
let slf = (*self).clone();
let slf = slf.into_column();
arg_sort_row_fmt(
&[slf],
options.descending,
options.nulls_last,
options.multithreaded,
)
.unwrap()
}

fn sort_with(&self, options: SortOptions) -> PolarsResult<Series> {
let idxs = self.arg_sort(options);
Ok(unsafe { self.take_unchecked(&idxs) })
}

fn slice(&self, offset: i64, length: usize) -> Series {
self.0.slice(offset, length).into_series()
}
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/datatypes/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,17 @@ def test_zero_width_array(fn: str) -> None:

df = pl.concat([a.to_frame(), b.to_frame()], how="horizontal")
df.select(c=expr_f(pl.col.a, pl.col.b))


def test_sort() -> None:
def tc(a: list[Any], b: list[Any], w: int) -> None:
a_s = pl.Series("l", a, pl.Array(pl.Int64, w))
b_s = pl.Series("l", b, pl.Array(pl.Int64, w))

assert_series_equal(a_s.sort(), b_s)

tc([], [], 1)
tc([[1]], [[1]], 1)
tc([[2], [1]], [[1], [2]], 1)
tc([[2, 1]], [[2, 1]], 2)
tc([[2, 1], [1, 2]], [[1, 2], [2, 1]], 2)
14 changes: 14 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,3 +839,17 @@ def test_null_list_categorical_16405() -> None:

expected = pl.DataFrame([None], schema={"result": pl.List(pl.Categorical)})
assert_frame_equal(df, expected)


def test_sort() -> None:
def tc(a: list[Any], b: list[Any]) -> None:
a_s = pl.Series("l", a, pl.List(pl.Int64))
b_s = pl.Series("l", b, pl.List(pl.Int64))

assert_series_equal(a_s.sort(), b_s)

tc([], [])
tc([[1]], [[1]])
tc([[1], []], [[], [1]])
tc([[2, 1]], [[2, 1]])
tc([[2, 1], [1, 2]], [[1, 2], [2, 1]])

0 comments on commit 947bf89

Please sign in to comment.