feat: Pass filter to inner readers in multiscan new streaming #21436

Merged
12 changes: 8 additions & 4 deletions crates/polars-io/src/parquet/read/predicates.rs
@@ -68,8 +68,8 @@ pub fn read_this_row_group(

let mut should_read = true;

if let Some(pred) = predicate {
if let Some(pred) = &pred.skip_batch_predicate {
if let Some(predicate) = predicate {
if let Some(pred) = &predicate.skip_batch_predicate {
if let Some(stats) = collect_statistics(md, schema)? {
let stats = PlIndexMap::from_iter(stats.column_stats().iter().map(|col| {
(
@@ -86,7 +86,11 @@ },
},
)
}));
let pred_result = pred.can_skip_batch(md.num_rows() as IdxSize, stats);
let pred_result = pred.can_skip_batch(
md.num_rows() as IdxSize,
predicate.live_columns.as_ref(),
stats,
);

// a parquet file may not have statistics of all columns
match pred_result {
@@ -97,7 +101,7 @@
_ => {},
}
}
} else if let Some(pred) = pred.predicate.as_stats_evaluator() {
} else if let Some(pred) = predicate.predicate.as_stats_evaluator() {
if let Some(stats) = collect_statistics(md, schema)? {
let pred_result = pred.should_read(&stats);

124 changes: 122 additions & 2 deletions crates/polars-io/src/predicates.rs
@@ -5,6 +5,7 @@ use arrow::bitmap::{Bitmap, MutableBitmap};
use polars_core::prelude::*;
#[cfg(feature = "parquet")]
use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, ParquetScalarRange};
use polars_utils::format_pl_smallstr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -345,14 +346,56 @@ pub struct ColumnStatistics {
}

pub trait SkipBatchPredicate: Send + Sync {
fn schema(&self) -> &SchemaRef;

fn can_skip_batch(
&self,
batch_size: IdxSize,
statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
) -> PolarsResult<bool>;
live_columns: &PlIndexSet<PlSmallStr>,
mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
) -> PolarsResult<bool> {
let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

columns.push(Column::new_scalar(
PlSmallStr::from_static("len"),
Scalar::new(IDX_DTYPE, batch_size.into()),
1,
));

for col in live_columns.iter() {
let dtype = self.schema().get(col).unwrap();
let (min, max, nc) = match statistics.swap_remove(col) {
None => (
Scalar::null(dtype.clone()),
Scalar::null(dtype.clone()),
Scalar::null(IDX_DTYPE),
),
Some(stat) => (
Scalar::new(dtype.clone(), stat.min),
Scalar::new(dtype.clone(), stat.max),
Scalar::new(
IDX_DTYPE,
stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
),
),
};
columns.extend([
Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
]);
}

// SAFETY:
// * Each column is length = 1
// * We have an IndexSet, so each column name is unique
let df = unsafe { DataFrame::new_no_checks(1, columns) };
Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
}
fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}
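
For orientation, a minimal sketch (not part of this diff) of the one-row statistics frame that the default `can_skip_batch` above assembles before handing it to `evaluate_with_stat_df`. The column `a` and its Int64 statistics are invented for illustration; the `len` / `{col}_min` / `{col}_max` / `{col}_nc` naming follows the code above.

```rust
use polars_core::frame::DataFrame;
use polars_core::prelude::{AnyValue, Column, DataType, IDX_DTYPE};
use polars_core::scalar::Scalar;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::IdxSize;

// Illustrative only: the statistics frame for a 1_000-row batch whose single
// live column is `a` (Int64, min 0, max 10, no nulls).
fn example_stat_df() -> DataFrame {
    let columns = vec![
        // Batch length under the reserved `len` name.
        Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, (1_000 as IdxSize).into()),
            1,
        ),
        // Per-column min / max / null-count statistics.
        Column::new_scalar(
            PlSmallStr::from_static("a_min"),
            Scalar::new(DataType::Int64, AnyValue::Int64(0)),
            1,
        ),
        Column::new_scalar(
            PlSmallStr::from_static("a_max"),
            Scalar::new(DataType::Int64, AnyValue::Int64(10)),
            1,
        ),
        Column::new_scalar(
            PlSmallStr::from_static("a_nc"),
            Scalar::new(IDX_DTYPE, (0 as IdxSize).into()),
            1,
        ),
    ];
    // SAFETY: every column has length 1 and a unique name.
    unsafe { DataFrame::new_no_checks(1, columns) }
}
```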

#[derive(Clone)]
pub struct ColumnPredicates {
pub predicates: PlHashMap<
PlSmallStr,
@@ -375,6 +418,44 @@ impl Default for ColumnPredicates {
}
}

pub struct PhysicalExprWithConstCols<T> {
constants: Vec<(PlSmallStr, Scalar)>,
child: T,
}

impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
fn schema(&self) -> &SchemaRef {
self.child.schema()
}

fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
let mut df = df.clone();
for (name, scalar) in self.constants.iter() {
df.with_column(Column::new_scalar(
name.clone(),
scalar.clone(),
df.height(),
))?;
}
self.child.evaluate_with_stat_df(&df)
}
}

impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
let mut df = df.clone();
for (name, scalar) in self.constants.iter() {
df.with_column(Column::new_scalar(
name.clone(),
scalar.clone(),
df.height(),
))?;
}

self.child.evaluate_io(&df)
}
}

#[derive(Clone)]
pub struct ScanIOPredicate {
pub predicate: Arc<dyn PhysicalIoExpr>,
@@ -388,6 +469,45 @@ pub struct ScanIOPredicate {
/// A predicate that gets given statistics and evaluates whether a batch can be skipped.
pub column_predicates: Arc<ColumnPredicates>,
}
impl ScanIOPredicate {
pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
let mut live_columns = self.live_columns.as_ref().clone();
for (c, _) in constant_columns.iter() {
live_columns.swap_remove(c);
}
self.live_columns = Arc::new(live_columns);

if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
for (c, v) in constant_columns.iter() {
sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
let nc = if v.is_null() {
AnyValue::Null
} else {
(0 as IdxSize).into()
};
sbp_constant_columns
.push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
}
self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
constants: sbp_constant_columns,
child: skip_batch_predicate,
}));
}

let mut column_predicates = self.column_predicates.as_ref().clone();
for (c, _) in constant_columns.iter() {
column_predicates.predicates.remove(c);
}
self.column_predicates = Arc::new(column_predicates);

self.predicate = Arc::new(PhysicalExprWithConstCols {
constants: constant_columns,
child: self.predicate.clone(),
});
}
}

impl fmt::Debug for ScanIOPredicate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
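As a usage note (hypothetical, not taken from this PR): `set_external_constant_columns`, added above, lets the multiscan layer bind file-constant values, e.g. a hive-partition-style column, into the predicate before it is passed to an inner reader. A minimal sketch, with the column name and value invented for the example:

```rust
use polars_core::prelude::{AnyValue, DataType};
use polars_core::scalar::Scalar;
use polars_io::predicates::ScanIOPredicate;
use polars_utils::pl_str::PlSmallStr;

// Hypothetical helper: bind `year = 2024` as an external constant so the
// wrapped predicate no longer needs that column from the file itself.
fn bind_year_partition(predicate: &mut ScanIOPredicate) {
    predicate.set_external_constant_columns(vec![(
        PlSmallStr::from_static("year"),
        Scalar::new(DataType::Int32, AnyValue::Int32(2024)),
    )]);
}
```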
9 changes: 6 additions & 3 deletions crates/polars-mem-engine/src/executors/multi_file_scan.rs
@@ -413,10 +413,13 @@ impl MultiScanExec {
let skip_batch_predicate = file_predicate
.as_ref()
.take_if(|_| use_statistics)
.and_then(|p| p.to_dyn_skip_batch_predicate(self.file_info.schema.as_ref()));
.and_then(|p| p.to_dyn_skip_batch_predicate(self.file_info.schema.clone()));
if let Some(skip_batch_predicate) = &skip_batch_predicate {
let can_skip_batch = skip_batch_predicate
.can_skip_batch(exec_source.num_unfiltered_rows()?, PlIndexMap::default())?;
let can_skip_batch = skip_batch_predicate.can_skip_batch(
exec_source.num_unfiltered_rows()?,
file_predicate.as_ref().unwrap().live_columns.as_ref(),
PlIndexMap::default(),
)?;
if can_skip_batch && verbose {
eprintln!(
"File statistics allows skipping of '{}'",
12 changes: 7 additions & 5 deletions crates/polars-mem-engine/src/executors/scan/parquet.rs
@@ -81,10 +81,12 @@ impl ParquetExec {
None
}
};
let predicate = self
.predicate
.as_ref()
.map(|p| p.to_io(self.skip_batch_predicate.as_ref(), &self.file_info.schema));
let predicate = self.predicate.as_ref().map(|p| {
p.to_io(
self.skip_batch_predicate.as_ref(),
self.file_info.schema.clone(),
)
});
let mut base_row_index = self.file_options.row_index.take();

// (offset, end)
@@ -294,7 +296,7 @@ impl ParquetExec {
skip_batch_predicate: self
.skip_batch_predicate
.clone()
.or_else(|| p.to_dyn_skip_batch_predicate(self.file_info.schema.as_ref())),
.or_else(|| p.to_dyn_skip_batch_predicate(self.file_info.schema.clone())),
column_predicates: Arc::new(Default::default()),
});
let mut base_row_index = self.file_options.row_index.take();
102 changes: 10 additions & 92 deletions crates/polars-mem-engine/src/predicate.rs
@@ -3,17 +3,14 @@ use std::sync::Arc;

use arrow::bitmap::Bitmap;
use polars_core::frame::DataFrame;
use polars_core::prelude::{
AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexMap, PlIndexSet, IDX_DTYPE,
};
use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
use polars_core::scalar::Scalar;
use polars_core::schema::{Schema, SchemaRef};
use polars_error::PolarsResult;
use polars_expr::prelude::{phys_expr_to_io_expr, AggregationContext, PhysicalExpr};
use polars_expr::state::ExecutionState;
use polars_io::predicates::{
ColumnPredicates, ColumnStatistics, ScanIOPredicate, SkipBatchPredicate,
SpecializedColumnPredicateExpr,
ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicateExpr,
};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::{format_pl_smallstr, IdxSize};
@@ -59,10 +56,7 @@ pub struct PhysicalColumnPredicates {
/// Helper to implement [`SkipBatchPredicate`].
struct SkipBatchPredicateHelper {
skip_batch_predicate: Arc<dyn PhysicalExpr>,
live_columns: Arc<PlIndexSet<PlSmallStr>>,

/// A cached dataframe that gets used to evaluate all the expressions.
df: DataFrame,
schema: SchemaRef,
}

/// Helper for the [`PhysicalExpr`] trait to include constant columns.
@@ -175,50 +169,19 @@ impl ScanPredicate {
/// Create a predicate to skip batches using statistics.
pub(crate) fn to_dyn_skip_batch_predicate(
&self,
schema: &Schema,
schema: SchemaRef,
) -> Option<Arc<dyn SkipBatchPredicate>> {
let skip_batch_predicate = self.skip_batch_predicate.as_ref()?;

let mut columns = Vec::with_capacity(1 + self.live_columns.len() * 3);

columns.push(Column::new_scalar(
PlSmallStr::from_static("len"),
Scalar::null(IDX_DTYPE),
1,
));
for col in self.live_columns.as_ref() {
let dtype = schema.get(col).unwrap();
columns.extend([
Column::new_scalar(
format_pl_smallstr!("{col}_min"),
Scalar::null(dtype.clone()),
1,
),
Column::new_scalar(
format_pl_smallstr!("{col}_max"),
Scalar::null(dtype.clone()),
1,
),
Column::new_scalar(format_pl_smallstr!("{col}_nc"), Scalar::null(IDX_DTYPE), 1),
]);
}

// SAFETY:
// * Each column is length = 1
// * We have an IndexSet, so each column name is unique
let df = unsafe { DataFrame::new_no_checks(1, columns) };

let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
Some(Arc::new(SkipBatchPredicateHelper {
skip_batch_predicate: skip_batch_predicate.clone(),
live_columns: self.live_columns.clone(),
df,
skip_batch_predicate,
schema,
}))
}

pub fn to_io(
&self,
skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
schema: &SchemaRef,
schema: SchemaRef,
) -> ScanIOPredicate {
ScanIOPredicate {
predicate: phys_expr_to_io_expr(self.predicate.clone()),
@@ -240,53 +203,8 @@ }
}

impl SkipBatchPredicate for SkipBatchPredicateHelper {
fn can_skip_batch(
&self,
batch_size: IdxSize,
statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
) -> PolarsResult<bool> {
// This is the DF with all nulls.
let mut df = self.df.clone();

// SAFETY: We don't update the dtype, name or length of columns.
let columns = unsafe { df.get_columns_mut() };

// Set `len` statistic.
columns[0]
.as_scalar_column_mut()
.unwrap()
.with_value(batch_size.into());

for (col, stat) in statistics {
// Skip all statistics of columns that are not used in the predicate.
let Some(idx) = self.live_columns.get_index_of(col.as_str()) else {
continue;
};

let nc = stat.null_count.map_or(AnyValue::Null, |nc| nc.into());

// Set `min`, `max` and `null_count` statistics.
let col_idx = (idx * 3) + 1;
columns[col_idx]
.as_scalar_column_mut()
.unwrap()
.with_value(stat.min);
columns[col_idx + 1]
.as_scalar_column_mut()
.unwrap()
.with_value(stat.max);
columns[col_idx + 2]
.as_scalar_column_mut()
.unwrap()
.with_value(nc);
}

Ok(self
.skip_batch_predicate
.evaluate(&df, &Default::default())?
.bool()?
.first()
.unwrap())
fn schema(&self) -> &SchemaRef {
&self.schema
}

fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {