Skip to content

Commit

Permalink
chore: Move rolling to polars-compute (#21503)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Feb 27, 2025
1 parent f2fd6f8 commit b35bc7b
Show file tree
Hide file tree
Showing 48 changed files with 122 additions and 67 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion crates/polars-arrow/src/legacy/kernels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::iter::Enumerate;
use crate::array::BooleanArray;
use crate::bitmap::utils::BitChunks;
pub mod ewm;
pub mod rolling;
pub mod set;
pub mod sort_partition;
#[cfg(feature = "performant")]
Expand Down
7 changes: 0 additions & 7 deletions crates/polars-arrow/src/legacy/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,8 @@ use crate::array::{BinaryArray, ListArray, Utf8Array};
pub use crate::legacy::array::default_arrays::*;
pub use crate::legacy::array::*;
pub use crate::legacy::index::*;
pub use crate::legacy::kernels::rolling::no_nulls::QuantileMethod;
pub use crate::legacy::kernels::rolling::{
RollingFnParams, RollingQuantileParams, RollingVarParams,
};
pub use crate::legacy::kernels::{Ambiguous, NonExistent};

pub type LargeStringArray = Utf8Array<i64>;
pub type LargeBinaryArray = BinaryArray<i64>;
pub type LargeListArray = ListArray<i64>;

#[allow(deprecated)]
pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions;
2 changes: 2 additions & 0 deletions crates/polars-compute/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ num-traits = { workspace = true }
polars-error = { workspace = true }
polars-utils = { workspace = true }
ryu = { workspace = true, optional = true }
serde = { workspace = true, optional = true }
strength_reduce = { workspace = true }
strum_macros = { workspace = true }

[dev-dependencies]
rand = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions crates/polars-compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub mod hyperloglogplus;
pub mod if_then_else;
pub mod min_max;
pub mod propagate_dictionary;
pub mod rolling;
pub mod size;
pub mod sum;
pub mod unique;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,39 @@ mod window;

use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};

use arrow::array::{ArrayRef, PrimitiveArray};
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::types::NativeType;
use num_traits::{Bounded, Float, NumCast, One, Zero};
use polars_utils::float::IsFloat;
use polars_utils::ord::{compare_fn_nan_max, compare_fn_nan_min};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;
use window::*;

use crate::array::{ArrayRef, PrimitiveArray};
use crate::bitmap::{Bitmap, MutableBitmap};
use crate::legacy::prelude::*;
use crate::legacy::utils::CustomIterTools;
use crate::types::NativeType;

// Role-labelling aliases: every one of these is a plain `usize`, so they add
// readability only and provide no type-level distinction between the roles.
type Start = usize;
type End = usize;
type Idx = usize;
type WindowSize = usize;
type Len = usize;

/// Selects which method is used when computing a quantile.
///
/// The default is [`QuantileMethod::Nearest`]. Through `IntoStaticStr` with
/// `serialize_all = "snake_case"`, each variant converts to a `&'static str`
/// holding its snake_case name (e.g. `"nearest"`). `Serialize`/`Deserialize`
/// are derived only when the `serde` feature is enabled.
///
/// The exact numeric behavior of each variant is defined by the quantile
/// kernels that consume this enum, not by this declaration.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum QuantileMethod {
// Value returned by `QuantileMethod::default()`.
#[default]
Nearest,
Lower,
Higher,
Midpoint,
Linear,
Equiprobable,
}

/// Deprecated alias for [`QuantileMethod`], kept so code using the old name
/// keeps compiling (with a deprecation warning).
#[deprecated(note = "use QuantileMethod instead")]
pub type QuantileInterpolOptions = QuantileMethod;

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RollingFnParams {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,19 @@ mod sum;
mod variance;
use std::fmt::Debug;

use arrow::array::PrimitiveArray;
use arrow::datatypes::ArrowDataType;
use arrow::legacy::error::PolarsResult;
use arrow::legacy::utils::CustomIterTools;
use arrow::types::NativeType;
pub use mean::*;
pub use min_max::*;
use num_traits::{Float, Num, NumCast};
pub use quantile::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;
pub use sum::*;
pub use variance::*;

use super::*;
use crate::array::PrimitiveArray;
use crate::datatypes::ArrowDataType;
use crate::legacy::error::PolarsResult;
use crate::types::NativeType;

pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self;
Expand Down Expand Up @@ -70,22 +68,6 @@ where
Ok(Box::new(arr))
}

#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum QuantileMethod {
#[default]
Nearest,
Lower,
Higher,
Midpoint,
Linear,
Equiprobable,
}

#[deprecated(note = "use QuantileMethod instead")]
pub type QuantileInterpolOptions = QuantileMethod;

pub(super) fn rolling_apply_weights<T, Fo, Fa>(
values: &[T],
window_size: usize,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use arrow::legacy::utils::CustomIterTools;
use num_traits::ToPrimitive;
use polars_error::polars_ensure;

use super::QuantileMethod::*;
use super::*;
use crate::rolling::quantile_filter::SealedRolling;

pub struct QuantileWindow<'a, T: NativeType> {
sorted: SortedBuf<'a, T>,
Expand All @@ -21,6 +23,7 @@ impl<
+ NumCast
+ One
+ Zero
+ SealedRolling
+ Sub<Output = T>,
> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
{
Expand Down Expand Up @@ -115,6 +118,7 @@ where
+ NumCast
+ One
+ Zero
+ SealedRolling
+ PartialOrd
+ Sub<Output = T>,
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use arrow::array::iterator::NonNullValuesIter;
use arrow::bitmap::utils::count_zeros;

use super::*;
use crate::array::iterator::NonNullValuesIter;
use crate::bitmap::utils::count_zeros;

pub fn is_reverse_sorted_max_nulls<T: NativeType>(values: &[T], validity: &Bitmap) -> bool {
let mut it = NonNullValuesIter::new(values, Some(validity));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod quantile;
mod sum;
mod variance;

use arrow::legacy::utils::CustomIterTools;
pub use mean::*;
pub use min_max::*;
pub use quantile::*;
Expand Down Expand Up @@ -90,10 +91,11 @@ where

#[cfg(test)]
mod test {
use arrow::array::{Array, Int32Array};
use arrow::buffer::Buffer;
use arrow::datatypes::ArrowDataType;

use super::*;
use crate::array::{Array, Int32Array};
use crate::buffer::Buffer;
use crate::datatypes::ArrowDataType;

fn get_null_arr() -> PrimitiveArray<f64> {
// 1, None, -1, 4
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use arrow::array::MutablePrimitiveArray;

use super::*;
use crate::array::MutablePrimitiveArray;
use crate::rolling::quantile_filter::SealedRolling;

pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
sorted: SortedBufNulls<'a, T>,
Expand All @@ -19,6 +21,7 @@ impl<
+ NumCast
+ One
+ Zero
+ SealedRolling
+ PartialOrd
+ Sub<Output = T>,
> RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T>
Expand Down Expand Up @@ -117,6 +120,7 @@ where
+ NumCast
+ One
+ Zero
+ SealedRolling
+ PartialOrd
+ Sub<Output = T>,
{
Expand Down Expand Up @@ -155,9 +159,10 @@ where

#[cfg(test)]
mod test {
use arrow::buffer::Buffer;
use arrow::datatypes::ArrowDataType;

use super::*;
use crate::buffer::Buffer;
use crate::datatypes::ArrowDataType;

#[test]
fn test_rolling_median_nulls() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::ops::{Add, Div, Mul, Sub};

use arrow::pushable::Pushable;
use arrow::types::NativeType;
use num_traits::NumCast;
use polars_utils::index::{Bounded, Indexable, NullCount};
use polars_utils::nulls::IsNull;
use polars_utils::slice::SliceAble;
use polars_utils::sort::arg_sort_ascending;
use polars_utils::total_ord::TotalOrd;

use crate::legacy::prelude::QuantileMethod;
use crate::pushable::Pushable;
use crate::types::NativeType;
use super::QuantileMethod;

struct Block<'a, A> {
k: usize,
Expand Down Expand Up @@ -527,13 +527,28 @@ pub(super) trait FinishLinear {
fn finish_midpoint(lower: Self, upper: Self) -> Self;
}

/// Marker trait (no methods) implemented below for a fixed list of primitive
/// numeric types. It is used as an additional bound on the blanket
/// `FinishLinear` impl and on the rolling quantile window impls.
///
/// NOTE(review): despite the name, nothing visible here prevents downstream
/// implementations — the trait is `pub` and is imported from other crates via
/// `polars_compute::rolling::quantile_filter::SealedRolling`. Confirm whether
/// it is meant to be truly sealed.
pub trait SealedRolling {}

// Signed and unsigned integers up to 64 bits.
impl SealedRolling for i8 {}
impl SealedRolling for i16 {}
impl SealedRolling for i32 {}
impl SealedRolling for i64 {}
impl SealedRolling for u8 {}
impl SealedRolling for u16 {}
impl SealedRolling for u32 {}
impl SealedRolling for u64 {}
// 128-bit signed integer (there is no matching `u128` impl in this list).
impl SealedRolling for i128 {}
// Floating-point types.
impl SealedRolling for f32 {}
impl SealedRolling for f64 {}

impl<
T: NativeType
+ NumCast
+ Add<Output = T>
+ Sub<Output = T>
+ Div<Output = T>
+ Mul<Output = T>
+ SealedRolling
+ Debug,
> FinishLinear for T
{
Expand Down
File renamed without changes.
3 changes: 3 additions & 0 deletions crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use arrow::types::NativeType;
use num_traits::{Float, One, ToPrimitive, Zero};
use polars_compute::float_sum;
use polars_compute::min_max::MinMaxKernel;
use polars_compute::rolling::QuantileMethod;
use polars_compute::sum::{wrapping_sum_arr, WrappingSum};
use polars_utils::min_max::MinMax;
use polars_utils::sync::SyncPtr;
Expand Down Expand Up @@ -657,6 +658,8 @@ impl<T: PolarsObject> ChunkAggSeries for ObjectChunked<T> {}

#[cfg(test)]
mod test {
use polars_compute::rolling::QuantileMethod;

use crate::prelude::*;

#[test]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use polars_compute::rolling::QuantileMethod;

use super::*;

pub trait QuantileAggSeries {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Traits for miscellaneous operations on ChunkedArray
use arrow::offset::OffsetsBuffer;
use polars_compute::rolling::QuantileMethod;

use crate::prelude::*;

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/ops/rolling_window.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use arrow::legacy::prelude::RollingFnParams;
use polars_compute::rolling::RollingFnParams;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand Down
2 changes: 2 additions & 0 deletions crates/polars-core/src/frame/column/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::borrow::Cow;
use arrow::bitmap::BitmapBuilder;
use arrow::trusted_len::TrustMyLength;
use num_traits::{Num, NumCast};
use polars_compute::rolling::QuantileMethod;
use polars_error::PolarsResult;
use polars_utils::index::check_bounds;
use polars_utils::pl_str::PlSmallStr;
Expand Down Expand Up @@ -753,6 +754,7 @@ impl Column {
method: QuantileMethod,
) -> Self {
// @scalar-opt

unsafe {
self.as_materialized_series()
.agg_quantile(groups, quantile, method)
Expand Down
18 changes: 10 additions & 8 deletions crates/polars-core/src/frame/group_by/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@ use std::cmp::Ordering;

pub use agg_list::*;
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::legacy::kernels::rolling;
use arrow::legacy::kernels::rolling::no_nulls::{
MaxWindow, MeanWindow, MinWindow, QuantileWindow, RollingAggWindowNoNulls, SumWindow, VarWindow,
};
use arrow::legacy::kernels::rolling::nulls::RollingAggWindowNulls;
use arrow::legacy::kernels::take_agg::*;
use arrow::legacy::prelude::QuantileMethod;
use arrow::legacy::trusted_len::TrustedLenPush;
use arrow::types::NativeType;
use num_traits::pow::Pow;
use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero};
use polars_compute::rolling::no_nulls::{
MaxWindow, MeanWindow, MinWindow, QuantileWindow, RollingAggWindowNoNulls, SumWindow, VarWindow,
};
use polars_compute::rolling::nulls::RollingAggWindowNulls;
use polars_compute::rolling::quantile_filter::SealedRolling;
use polars_compute::rolling::{
self, quantile_filter, QuantileMethod, RollingFnParams, RollingQuantileParams, RollingVarParams,
};
use polars_utils::float::IsFloat;
use polars_utils::idx_vec::IdxVec;
use polars_utils::ord::{compare_fn_nan_max, compare_fn_nan_min};
Expand Down Expand Up @@ -343,7 +345,7 @@ where
ChunkedArray<T>: QuantileDispatcher<K::Native>,
ChunkedArray<K>: IntoSeries,
K: PolarsNumericType,
<K as datatypes::PolarsNumericType>::Native: num_traits::Float,
<K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
{
let invalid_quantile = !(0.0..=1.0).contains(&quantile);
if invalid_quantile {
Expand Down Expand Up @@ -423,7 +425,7 @@ where
ChunkedArray<T>: QuantileDispatcher<K::Native>,
ChunkedArray<K>: IntoSeries,
K: PolarsNumericType,
<K as datatypes::PolarsNumericType>::Native: num_traits::Float,
<K as datatypes::PolarsNumericType>::Native: num_traits::Float + SealedRolling,
{
match groups {
GroupsType::Idx(groups) => {
Expand Down
Loading

0 comments on commit b35bc7b

Please sign in to comment.