feat: Pass filter to inner readers in multiscan new streaming (#21436)

coastalwhite authored Feb 24, 2025
1 parent 4c08d28 commit c37e482
Showing 12 changed files with 298 additions and 123 deletions.
12 changes: 8 additions & 4 deletions crates/polars-io/src/parquet/read/predicates.rs
@@ -68,8 +68,8 @@ pub fn read_this_row_group(
 
     let mut should_read = true;
 
-    if let Some(pred) = predicate {
-        if let Some(pred) = &pred.skip_batch_predicate {
+    if let Some(predicate) = predicate {
+        if let Some(pred) = &predicate.skip_batch_predicate {
             if let Some(stats) = collect_statistics(md, schema)? {
                 let stats = PlIndexMap::from_iter(stats.column_stats().iter().map(|col| {
                     (
@@ -86,7 +86,11 @@ pub fn read_this_row_group(
                         },
                     )
                 }));
-                let pred_result = pred.can_skip_batch(md.num_rows() as IdxSize, stats);
+                let pred_result = pred.can_skip_batch(
+                    md.num_rows() as IdxSize,
+                    predicate.live_columns.as_ref(),
+                    stats,
+                );
 
                 // a parquet file may not have statistics of all columns
                 match pred_result {
@@ -97,7 +101,7 @@ pub fn read_this_row_group(
                     _ => {},
                 }
             }
-        } else if let Some(pred) = pred.predicate.as_stats_evaluator() {
+        } else if let Some(pred) = predicate.predicate.as_stats_evaluator() {
             if let Some(stats) = collect_statistics(md, schema)? {
                 let pred_result = pred.should_read(&stats);
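A note on the hunk above: the `match pred_result` (and the comment about files lacking statistics) reflects three-valued reasoning — statistics can prove a predicate false for a whole row group, prove it true, or leave it undecided. Below is a minimal, self-contained sketch of that logic for a predicate like `col > threshold`; the names and types are illustrative, not the Polars API:

```rust
/// Three-valued outcome of evaluating a predicate against min/max statistics.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Tri {
    True,    // every row matches
    False,   // no row can match: the row group may be skipped
    Unknown, // statistics are missing or not decisive
}

/// Decide `col > threshold` from optional min/max statistics.
fn eval_gt(min: Option<i64>, max: Option<i64>, threshold: i64) -> Tri {
    match (min, max) {
        (_, Some(max)) if max <= threshold => Tri::False,
        (Some(min), _) if min > threshold => Tri::True,
        _ => Tri::Unknown, // e.g. the file stored no statistics for `col`
    }
}

fn main() {
    // Decisive statistics: the row group can be skipped.
    assert_eq!(eval_gt(Some(0), Some(5), 10), Tri::False);
    // Missing statistics: must read and filter.
    assert_eq!(eval_gt(None, None, 10), Tri::Unknown);
    // Every row matches: read it (no rows are filtered out).
    assert_eq!(eval_gt(Some(11), Some(20), 10), Tri::True);
}
```

Only a definite `False` permits skipping, which is why the new `live_columns` argument matters: the predicate must see a null (unknown) statistic for every live column the file has no data for, rather than silently dropping it.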
124 changes: 122 additions & 2 deletions crates/polars-io/src/predicates.rs
@@ -5,6 +5,7 @@ use arrow::bitmap::{Bitmap, MutableBitmap};
 use polars_core::prelude::*;
 #[cfg(feature = "parquet")]
 use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, ParquetScalarRange};
+use polars_utils::format_pl_smallstr;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
 
@@ -345,14 +346,56 @@ pub struct ColumnStatistics {
 }
 
 pub trait SkipBatchPredicate: Send + Sync {
+    fn schema(&self) -> &SchemaRef;
+
     fn can_skip_batch(
         &self,
         batch_size: IdxSize,
-        statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
-    ) -> PolarsResult<bool>;
+        live_columns: &PlIndexSet<PlSmallStr>,
+        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
+    ) -> PolarsResult<bool> {
+        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);
+
+        columns.push(Column::new_scalar(
+            PlSmallStr::from_static("len"),
+            Scalar::new(IDX_DTYPE, batch_size.into()),
+            1,
+        ));
+
+        for col in live_columns.iter() {
+            let dtype = self.schema().get(col).unwrap();
+            let (min, max, nc) = match statistics.swap_remove(col) {
+                None => (
+                    Scalar::null(dtype.clone()),
+                    Scalar::null(dtype.clone()),
+                    Scalar::null(IDX_DTYPE),
+                ),
+                Some(stat) => (
+                    Scalar::new(dtype.clone(), stat.min),
+                    Scalar::new(dtype.clone(), stat.max),
+                    Scalar::new(
+                        IDX_DTYPE,
+                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
+                    ),
+                ),
+            };
+            columns.extend([
+                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
+                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
+                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
+            ]);
+        }
+
+        // SAFETY:
+        // * Each column is length = 1
+        // * We have an IndexSet, so each column name is unique
+        let df = unsafe { DataFrame::new_no_checks(1, columns) };
+        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
+    }
     fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
 }
 
+#[derive(Clone)]
 pub struct ColumnPredicates {
     pub predicates: PlHashMap<
         PlSmallStr,
@@ -375,6 +418,44 @@ impl Default for ColumnPredicates {
     }
 }
 
+pub struct PhysicalExprWithConstCols<T> {
+    constants: Vec<(PlSmallStr, Scalar)>,
+    child: T,
+}
+
+impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
+    fn schema(&self) -> &SchemaRef {
+        self.child.schema()
+    }
+
+    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
+        let mut df = df.clone();
+        for (name, scalar) in self.constants.iter() {
+            df.with_column(Column::new_scalar(
+                name.clone(),
+                scalar.clone(),
+                df.height(),
+            ))?;
+        }
+        self.child.evaluate_with_stat_df(&df)
+    }
+}
+
+impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
+    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
+        let mut df = df.clone();
+        for (name, scalar) in self.constants.iter() {
+            df.with_column(Column::new_scalar(
+                name.clone(),
+                scalar.clone(),
+                df.height(),
+            ))?;
+        }
+
+        self.child.evaluate_io(&df)
+    }
+}
+
 #[derive(Clone)]
 pub struct ScanIOPredicate {
     pub predicate: Arc<dyn PhysicalIoExpr>,
@@ -388,6 +469,45 @@ pub struct ScanIOPredicate {
     /// A predicate that gets given statistics and evaluates whether a batch can be skipped.
     pub column_predicates: Arc<ColumnPredicates>,
 }
+impl ScanIOPredicate {
+    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
+        let mut live_columns = self.live_columns.as_ref().clone();
+        for (c, _) in constant_columns.iter() {
+            live_columns.swap_remove(c);
+        }
+        self.live_columns = Arc::new(live_columns);
+
+        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
+            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
+            for (c, v) in constant_columns.iter() {
+                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
+                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
+                let nc = if v.is_null() {
+                    AnyValue::Null
+                } else {
+                    (0 as IdxSize).into()
+                };
+                sbp_constant_columns
+                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
+            }
+            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
+                constants: sbp_constant_columns,
+                child: skip_batch_predicate,
+            }));
+        }
+
+        let mut column_predicates = self.column_predicates.as_ref().clone();
+        for (c, _) in constant_columns.iter() {
+            column_predicates.predicates.remove(c);
+        }
+        self.column_predicates = Arc::new(column_predicates);
+
+        self.predicate = Arc::new(PhysicalExprWithConstCols {
+            constants: constant_columns,
+            child: self.predicate.clone(),
+        });
+    }
+}
 
 impl fmt::Debug for ScanIOPredicate {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
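The heart of this file's change: `can_skip_batch` turns into a provided trait method that lowers the statistics into a one-row frame — a `len` column plus a `{col}_min` / `{col}_max` / `{col}_nc` triple per live column, with nulls where statistics are missing — and then defers to `evaluate_with_stat_df`. Below is a condensed, runnable model of that shape, using a `BTreeMap` as a stand-in for the one-row `DataFrame` (assumed simplified types, not the real polars-io trait):

```rust
use std::collections::BTreeMap;

/// Stand-in for the one-row statistics DataFrame: column name -> value.
type StatRow = BTreeMap<String, Option<i64>>;
/// Stand-in for `ColumnStatistics`: (min, max, null_count).
type ColStats = (Option<i64>, Option<i64>, Option<i64>);

trait SkipBatchPredicate {
    /// Required: decide skippability from an already-assembled stat row
    /// (the analogue of `evaluate_with_stat_df`).
    fn evaluate_with_stat_row(&self, row: &StatRow) -> bool;

    /// Provided: build the stat row, then delegate. Missing statistics
    /// lower to `None` (null), exactly one triple per live column.
    fn can_skip_batch(
        &self,
        batch_size: u64,
        live_columns: &[String],
        mut statistics: BTreeMap<String, ColStats>,
    ) -> bool {
        let mut row = StatRow::new();
        row.insert("len".into(), Some(batch_size as i64));
        for col in live_columns {
            let (min, max, nc) = statistics.remove(col).unwrap_or((None, None, None));
            row.insert(format!("{col}_min"), min);
            row.insert(format!("{col}_max"), max);
            row.insert(format!("{col}_nc"), nc);
        }
        self.evaluate_with_stat_row(&row)
    }
}

/// Example implementor: skip a batch when `a > 10` cannot match any row.
struct AGreaterThanTen;

impl SkipBatchPredicate for AGreaterThanTen {
    fn evaluate_with_stat_row(&self, row: &StatRow) -> bool {
        matches!(row.get("a_max"), Some(Some(max)) if *max <= 10)
    }
}

fn main() {
    let live = vec!["a".to_string()];
    let stats = BTreeMap::from([("a".to_string(), (Some(0), Some(5), Some(0)))]);
    assert!(AGreaterThanTen.can_skip_batch(100, &live, stats));
    // Without statistics the stat row holds nulls, so nothing is skipped.
    assert!(!AGreaterThanTen.can_skip_batch(100, &live, BTreeMap::new()));
}
```

In the real trait the schema (now exposed through `schema()`) supplies each column's dtype so the null scalars are typed correctly; the simplified model above elides dtypes.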
9 changes: 6 additions & 3 deletions crates/polars-mem-engine/src/executors/multi_file_scan.rs
@@ -413,10 +413,13 @@ impl MultiScanExec {
         let skip_batch_predicate = file_predicate
             .as_ref()
             .take_if(|_| use_statistics)
-            .and_then(|p| p.to_dyn_skip_batch_predicate(self.file_info.schema.as_ref()));
+            .and_then(|p| p.to_dyn_skip_batch_predicate(self.file_info.schema.clone()));
         if let Some(skip_batch_predicate) = &skip_batch_predicate {
-            let can_skip_batch = skip_batch_predicate
-                .can_skip_batch(exec_source.num_unfiltered_rows()?, PlIndexMap::default())?;
+            let can_skip_batch = skip_batch_predicate.can_skip_batch(
+                exec_source.num_unfiltered_rows()?,
+                file_predicate.as_ref().unwrap().live_columns.as_ref(),
+                PlIndexMap::default(),
+            )?;
             if can_skip_batch && verbose {
                 eprintln!(
                     "File statistics allows skipping of '{}'",
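Note that this file-level check passes `PlIndexMap::default()` — no per-column statistics at all, only `num_unfiltered_rows()`. Under the new default `can_skip_batch`, every live column then lowers to null statistics, so a file is skipped only when the predicate is unsatisfiable regardless of the file's contents. A tiny sketch of the stat row this produces, using the same assumed stand-in types as the sketch above:

```rust
use std::collections::BTreeMap;

/// Build the stat row for a file-level check: row count known, all
/// per-column statistics null (illustrative stand-in types).
fn file_level_stat_row(num_rows: u64, live_columns: &[&str]) -> BTreeMap<String, Option<i64>> {
    let mut row = BTreeMap::new();
    row.insert("len".to_string(), Some(num_rows as i64));
    for col in live_columns {
        row.insert(format!("{col}_min"), None); // unknown
        row.insert(format!("{col}_max"), None); // unknown
        row.insert(format!("{col}_nc"), None); // unknown
    }
    row
}

fn main() {
    let row = file_level_stat_row(1_000, &["a"]);
    assert_eq!(row["len"], Some(1_000));
    // Only the row count is known at file level.
    assert_eq!(row["a_min"], None);
}
```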
12 changes: 7 additions & 5 deletions crates/polars-mem-engine/src/executors/scan/parquet.rs
@@ -81,10 +81,12 @@ impl ParquetExec {
                 None
             }
         };
-        let predicate = self
-            .predicate
-            .as_ref()
-            .map(|p| p.to_io(self.skip_batch_predicate.as_ref(), &self.file_info.schema));
+        let predicate = self.predicate.as_ref().map(|p| {
+            p.to_io(
+                self.skip_batch_predicate.as_ref(),
+                self.file_info.schema.clone(),
+            )
+        });
         let mut base_row_index = self.file_options.row_index.take();
 
         // (offset, end)
@@ -294,7 +296,7 @@ impl ParquetExec {
                 skip_batch_predicate: self
                     .skip_batch_predicate
                     .clone()
-                    .or_else(|| p.to_dyn_skip_batch_predicate(self.file_info.schema.as_ref())),
+                    .or_else(|| p.to_dyn_skip_batch_predicate(self.file_info.schema.clone())),
                 column_predicates: Arc::new(Default::default()),
             });
             let mut base_row_index = self.file_options.row_index.take();
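Both call sites in this file now hand the predicate an owned `SchemaRef` — in polars, an `Arc<Schema>` — instead of a borrow, because the skip-batch predicate stores the schema after this commit (see `SkipBatchPredicateHelper` below). A small sketch of the cost model with a plain `Arc` and a toy `Schema` type (not the polars one): cloning bumps a reference count; the schema itself is never copied.

```rust
use std::sync::Arc;

/// Toy schema: (column name, dtype name) pairs, purely illustrative.
#[derive(Debug)]
struct Schema(Vec<(String, String)>);

type SchemaRef = Arc<Schema>;

/// The predicate helper now owns a handle to the schema.
struct SkipBatchPredicateHelper {
    schema: SchemaRef,
}

fn to_dyn_skip_batch_predicate(schema: SchemaRef) -> SkipBatchPredicateHelper {
    SkipBatchPredicateHelper { schema }
}

fn main() {
    let schema: SchemaRef = Arc::new(Schema(vec![("a".into(), "i64".into())]));
    // `schema.clone()` clones the Arc (a refcount increment), not the Schema.
    let helper = to_dyn_skip_batch_predicate(schema.clone());
    assert_eq!(Arc::strong_count(&schema), 2);
    println!("columns: {:?}", helper.schema.0.len());
}
```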
102 changes: 10 additions & 92 deletions crates/polars-mem-engine/src/predicate.rs
@@ -3,17 +3,14 @@ use std::sync::Arc;
 
 use arrow::bitmap::Bitmap;
 use polars_core::frame::DataFrame;
-use polars_core::prelude::{
-    AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexMap, PlIndexSet, IDX_DTYPE,
-};
+use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
 use polars_core::scalar::Scalar;
 use polars_core::schema::{Schema, SchemaRef};
 use polars_error::PolarsResult;
 use polars_expr::prelude::{phys_expr_to_io_expr, AggregationContext, PhysicalExpr};
 use polars_expr::state::ExecutionState;
 use polars_io::predicates::{
-    ColumnPredicates, ColumnStatistics, ScanIOPredicate, SkipBatchPredicate,
-    SpecializedColumnPredicateExpr,
+    ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicateExpr,
 };
 use polars_utils::pl_str::PlSmallStr;
 use polars_utils::{format_pl_smallstr, IdxSize};
 
@@ -59,10 +56,7 @@ pub struct PhysicalColumnPredicates {
 /// Helper to implement [`SkipBatchPredicate`].
 struct SkipBatchPredicateHelper {
     skip_batch_predicate: Arc<dyn PhysicalExpr>,
-    live_columns: Arc<PlIndexSet<PlSmallStr>>,
-
-    /// A cached dataframe that gets used to evaluate all the expressions.
-    df: DataFrame,
+    schema: SchemaRef,
 }
 
 /// Helper for the [`PhysicalExpr`] trait to include constant columns.
@@ -175,50 +169,19 @@ impl ScanPredicate {
     /// Create a predicate to skip batches using statistics.
     pub(crate) fn to_dyn_skip_batch_predicate(
         &self,
-        schema: &Schema,
+        schema: SchemaRef,
     ) -> Option<Arc<dyn SkipBatchPredicate>> {
-        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?;
-
-        let mut columns = Vec::with_capacity(1 + self.live_columns.len() * 3);
-
-        columns.push(Column::new_scalar(
-            PlSmallStr::from_static("len"),
-            Scalar::null(IDX_DTYPE),
-            1,
-        ));
-        for col in self.live_columns.as_ref() {
-            let dtype = schema.get(col).unwrap();
-            columns.extend([
-                Column::new_scalar(
-                    format_pl_smallstr!("{col}_min"),
-                    Scalar::null(dtype.clone()),
-                    1,
-                ),
-                Column::new_scalar(
-                    format_pl_smallstr!("{col}_max"),
-                    Scalar::null(dtype.clone()),
-                    1,
-                ),
-                Column::new_scalar(format_pl_smallstr!("{col}_nc"), Scalar::null(IDX_DTYPE), 1),
-            ]);
-        }
-
-        // SAFETY:
-        // * Each column is length = 1
-        // * We have an IndexSet, so each column name is unique
-        let df = unsafe { DataFrame::new_no_checks(1, columns) };
-
+        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
         Some(Arc::new(SkipBatchPredicateHelper {
-            skip_batch_predicate: skip_batch_predicate.clone(),
-            live_columns: self.live_columns.clone(),
-            df,
+            skip_batch_predicate,
+            schema,
         }))
     }
 
     pub fn to_io(
         &self,
         skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
-        schema: &SchemaRef,
+        schema: SchemaRef,
    ) -> ScanIOPredicate {
        ScanIOPredicate {
            predicate: phys_expr_to_io_expr(self.predicate.clone()),
@@ -240,53 +203,8 @@ impl ScanPredicate {
 }
 
 impl SkipBatchPredicate for SkipBatchPredicateHelper {
-    fn can_skip_batch(
-        &self,
-        batch_size: IdxSize,
-        statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
-    ) -> PolarsResult<bool> {
-        // This is the DF with all nulls.
-        let mut df = self.df.clone();
-
-        // SAFETY: We don't update the dtype, name or length of columns.
-        let columns = unsafe { df.get_columns_mut() };
-
-        // Set `len` statistic.
-        columns[0]
-            .as_scalar_column_mut()
-            .unwrap()
-            .with_value(batch_size.into());
-
-        for (col, stat) in statistics {
-            // Skip all statistics of columns that are not used in the predicate.
-            let Some(idx) = self.live_columns.get_index_of(col.as_str()) else {
-                continue;
-            };
-
-            let nc = stat.null_count.map_or(AnyValue::Null, |nc| nc.into());
-
-            // Set `min`, `max` and `null_count` statistics.
-            let col_idx = (idx * 3) + 1;
-            columns[col_idx]
-                .as_scalar_column_mut()
-                .unwrap()
-                .with_value(stat.min);
-            columns[col_idx + 1]
-                .as_scalar_column_mut()
-                .unwrap()
-                .with_value(stat.max);
-            columns[col_idx + 2]
-                .as_scalar_column_mut()
-                .unwrap()
-                .with_value(nc);
-        }
-
-        Ok(self
-            .skip_batch_predicate
-            .evaluate(&df, &Default::default())?
-            .bool()?
-            .first()
-            .unwrap())
+    fn schema(&self) -> &SchemaRef {
+        &self.schema
     }
 
     fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
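To close the loop on `PhysicalExprWithConstCols` and `set_external_constant_columns` from crates/polars-io/src/predicates.rs above: columns that are constant for a whole file (hive partition values, for instance) are spliced into whatever frame the wrapped predicate sees, and for the skip statistics each constant `c` contributes `{c}_min == {c}_max == value` with a null count of 0 (or null, if the value itself is null) — exact statistics, so skip predicates can fire on them. A toy model of the wrapper pattern, with a map standing in for the DataFrame (assumed names, not the polars types):

```rust
use std::collections::BTreeMap;

/// Toy "frame": column name -> value.
type Frame = BTreeMap<String, i64>;

trait Predicate {
    fn evaluate(&self, df: &Frame) -> bool;
}

/// Splice constant columns into the frame, then delegate to the child —
/// the pattern behind `evaluate_with_stat_df` / `evaluate_io` above.
struct WithConstCols<T> {
    constants: Vec<(String, i64)>,
    child: T,
}

impl<T: Predicate> Predicate for WithConstCols<T> {
    fn evaluate(&self, df: &Frame) -> bool {
        let mut df = df.clone();
        for (name, value) in &self.constants {
            df.insert(name.clone(), *value);
        }
        self.child.evaluate(&df)
    }
}

/// Child predicate on a column the file never materializes.
struct YearIs2024;

impl Predicate for YearIs2024 {
    fn evaluate(&self, df: &Frame) -> bool {
        df.get("year") == Some(&2024)
    }
}

fn main() {
    let wrapped = WithConstCols {
        constants: vec![("year".to_string(), 2024)],
        child: YearIs2024,
    };
    // The frame itself is empty; the wrapper supplies the constant column.
    assert!(wrapped.evaluate(&Frame::new()));
}
```

The commit applies this wrapper twice: once around the skip-batch predicate (constants as statistics) and once around the row-level IO predicate (constants as values).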
Diffs for the remaining 7 changed files are not rendered here.
