From 77352b2411b5d9340374c30e21b861b0d0d46f82 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 29 May 2024 16:46:18 -0400 Subject: [PATCH 01/35] Add `ParquetExec::builder()`, deprecate `ParquetExec::new` (#10636) * Add `ParquetExec::builder()`, deprecate `ParquetExec::new` * Add a #[must_use] --- .../src/datasource/file_format/parquet.rs | 22 +- .../datasource/physical_plan/parquet/mod.rs | 239 ++++++++++++++---- .../core/src/datasource/schema_adapter.rs | 6 +- .../combine_partial_final_agg.rs | 8 +- .../enforce_distribution.rs | 16 +- .../core/src/physical_optimizer/test_utils.rs | 16 +- datafusion/core/src/test_util/parquet.rs | 22 +- .../core/tests/parquet/custom_reader.rs | 6 +- datafusion/core/tests/parquet/page_pruning.rs | 11 +- .../core/tests/parquet/schema_coercion.rs | 16 +- datafusion/proto/src/physical_plan/mod.rs | 11 +- .../tests/cases/roundtrip_physical_plan.rs | 18 +- .../substrait/src/physical_plan/consumer.rs | 8 +- .../tests/cases/roundtrip_physical_plan.rs | 8 +- 14 files changed, 265 insertions(+), 142 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index e102cfc372dd..39e6900ed53a 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -30,7 +30,7 @@ use crate::arrow::array::{ }; use crate::arrow::datatypes::{DataType, Fields, Schema, SchemaRef}; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::physical_plan::{FileGroupDisplay, FileSinkConfig, ParquetExec}; +use crate::datasource::physical_plan::{FileGroupDisplay, FileSinkConfig}; use crate::datasource::schema_adapter::{ DefaultSchemaAdapterFactory, SchemaAdapterFactory, }; @@ -75,6 +75,7 @@ use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::task::JoinSet; +use crate::datasource::physical_plan::parquet::ParquetExecBuilder; use futures::{StreamExt, TryStreamExt}; use hashbrown::HashMap; use object_store::path::Path; @@ -253,17 +254,22 @@ impl FileFormat for ParquetFormat { conf: FileScanConfig, filters: Option<&Arc>, ) -> Result> { + let mut builder = + ParquetExecBuilder::new_with_options(conf, self.options.clone()); + // If enable pruning then combine the filters to build the predicate. // If disable pruning then set the predicate to None, thus readers // will not prune data based on the statistics. 
- let predicate = self.enable_pruning().then(|| filters.cloned()).flatten(); + if self.enable_pruning() { + if let Some(predicate) = filters.cloned() { + builder = builder.with_predicate(predicate); + } + } + if let Some(metadata_size_hint) = self.metadata_size_hint() { + builder = builder.with_metadata_size_hint(metadata_size_hint); + } - Ok(Arc::new(ParquetExec::new( - conf, - predicate, - self.metadata_size_hint(), - self.options.clone(), - ))) + Ok(builder.build_arc()) } async fn create_writer_physical_plan( diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 6655125ea876..ac7c39bbdb94 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -104,6 +104,27 @@ pub use statistics::{RequestedStatistics, StatisticsConverter}; /// `───────────────────' /// /// ``` +/// +/// # Example: Create a `ParquetExec` +/// ``` +/// # use std::sync::Arc; +/// # use arrow::datatypes::Schema; +/// # use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +/// # use datafusion::datasource::listing::PartitionedFile; +/// # let file_schema = Arc::new(Schema::empty()); +/// # let object_store_url = ObjectStoreUrl::local_filesystem(); +/// # use datafusion_execution::object_store::ObjectStoreUrl; +/// # use datafusion_physical_expr::expressions::lit; +/// # let predicate = lit(true); +/// // Create a ParquetExec for reading `file1.parquet` with a file size of 100MB +/// let file_scan_config = FileScanConfig::new(object_store_url, file_schema) +/// .with_file(PartitionedFile::new("file1.parquet", 100*1024*1024)); +/// let exec = ParquetExec::builder(file_scan_config) +/// // Provide a predicate for filtering row groups/pages +/// .with_predicate(predicate) +/// .build(); +/// ``` +/// /// # Features /// /// Supports the following optimizations: @@ -131,7 +152,7 @@ pub use statistics::{RequestedStatistics, StatisticsConverter}; /// * metadata_size_hint: controls the number of bytes read from the end of the /// file in the initial I/O when the default [`ParquetFileReaderFactory`]. If a /// custom reader is used, it supplies the metadata directly and this parameter -/// is ignored. See [`Self::with_parquet_file_reader_factory`] for more details. +/// is ignored. [`ParquetExecBuilder::with_metadata_size_hint`] for more details. /// /// # Execution Overview /// @@ -141,9 +162,9 @@ pub use statistics::{RequestedStatistics, StatisticsConverter}; /// * Step 2: When the stream is polled, the [`ParquetOpener`] is called to open /// the file. /// -/// * Step 3: The `ParquetOpener` gets the file metadata via -/// [`ParquetFileReaderFactory`] and applies any predicates -/// and projections to determine what pages must be read. +/// * Step 3: The `ParquetOpener` gets the [`ParquetMetaData`] (file metadata) +/// via [`ParquetFileReaderFactory`] and applies any predicates and projections +/// to determine what pages must be read. /// /// * Step 4: The stream begins reading data, fetching the required pages /// and incrementally decoding them. 
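// A minimal sketch of the builder API introduced above (illustrative, not part
// of the committed diff). It assumes `file_scan_config`, `table_parquet_options`
// and `predicate` are already in scope; `new_with_options`, `with_predicate`,
// `with_metadata_size_hint` and `build_arc` are the builder methods this patch adds.
let exec = ParquetExecBuilder::new_with_options(file_scan_config, table_parquet_options)
    // used to prune row groups and data pages via statistics and bloom filters
    .with_predicate(predicate)
    // bytes fetched from the end of the file in the initial IO (value is only an example)
    .with_metadata_size_hint(512 * 1024)
    // convenience wrapper producing an Arc<ParquetExec>
    .build_arc();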
@@ -154,6 +175,7 @@ pub use statistics::{RequestedStatistics, StatisticsConverter}; /// /// [`RecordBatch`]: arrow::record_batch::RecordBatch /// [`SchemaAdapter`]: crate::datasource::schema_adapter::SchemaAdapter +/// [`ParquetMetadata`]: parquet::file::metadata::ParquetMetaData #[derive(Debug, Clone)] pub struct ParquetExec { /// Base configuration for this scan @@ -179,14 +201,125 @@ pub struct ParquetExec { schema_adapter_factory: Option>, } -impl ParquetExec { - /// Create a new Parquet reader execution plan provided file list and schema. - pub fn new( - base_config: FileScanConfig, - predicate: Option>, - metadata_size_hint: Option, +/// [`ParquetExecBuilder`]`, builder for [`ParquetExec`]. +/// +/// See example on [`ParquetExec`]. +pub struct ParquetExecBuilder { + file_scan_config: FileScanConfig, + predicate: Option>, + metadata_size_hint: Option, + table_parquet_options: TableParquetOptions, + parquet_file_reader_factory: Option>, + schema_adapter_factory: Option>, +} + +impl ParquetExecBuilder { + /// Create a new builder to read the provided file scan configuration + pub fn new(file_scan_config: FileScanConfig) -> Self { + Self::new_with_options(file_scan_config, TableParquetOptions::default()) + } + + /// Create a new builder to read the data specified in the file scan + /// configuration with the provided `TableParquetOptions`. + pub fn new_with_options( + file_scan_config: FileScanConfig, table_parquet_options: TableParquetOptions, ) -> Self { + Self { + file_scan_config, + predicate: None, + metadata_size_hint: None, + table_parquet_options, + parquet_file_reader_factory: None, + schema_adapter_factory: None, + } + } + + /// Set the predicate for the scan. + /// + /// The ParquetExec uses this predicate to filter row groups and data pages + /// using the Parquet statistics and bloom filters. + /// + /// If the predicate can not be used to prune the scan, it is ignored (no + /// error is raised). + pub fn with_predicate(mut self, predicate: Arc) -> Self { + self.predicate = Some(predicate); + self + } + + /// Set the metadata size hint + /// + /// This value determines how many bytes at the end of the file the default + /// [`ParquetFileReaderFactory`] will request in the initial IO. If this is + /// too small, the ParquetExec will need to make additional IO requests to + /// read the footer. + pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { + self.metadata_size_hint = Some(metadata_size_hint); + self + } + + /// Set the table parquet options that control how the ParquetExec reads. + /// + /// See also [`Self::new_with_options`] + pub fn with_table_parquet_options( + mut self, + table_parquet_options: TableParquetOptions, + ) -> Self { + self.table_parquet_options = table_parquet_options; + self + } + + /// Set optional user defined parquet file reader factory. + /// + /// You can use [`ParquetFileReaderFactory`] to more precisely control how + /// data is read from parquet files (e.g. skip re-reading metadata, coalesce + /// I/O operations, etc). + /// + /// The default reader factory reads directly from an [`ObjectStore`] + /// instance using individual I/O operations for the footer and each page. + /// + /// If a custom `ParquetFileReaderFactory` is provided, then data access + /// operations will be routed to this factory instead of `ObjectStore`. 
+ pub fn with_parquet_file_reader_factory( + mut self, + parquet_file_reader_factory: Arc, + ) -> Self { + self.parquet_file_reader_factory = Some(parquet_file_reader_factory); + self + } + + /// Set optional schema adapter factory. + /// + /// [`SchemaAdapterFactory`] allows user to specify how fields from the + /// parquet file get mapped to that of the table schema. The default schema + /// adapter uses arrow's cast library to map the parquet fields to the table + /// schema. + pub fn with_schema_adapter_factory( + mut self, + schema_adapter_factory: Arc, + ) -> Self { + self.schema_adapter_factory = Some(schema_adapter_factory); + self + } + + /// Convenience: build an `Arc`d `ParquetExec` from this builder + pub fn build_arc(self) -> Arc { + Arc::new(self.build()) + } + + /// Build a [`ParquetExec`] + #[must_use] + pub fn build(self) -> ParquetExec { + let Self { + file_scan_config, + predicate, + metadata_size_hint, + table_parquet_options, + parquet_file_reader_factory, + schema_adapter_factory, + } = self; + + let base_config = file_scan_config; debug!("Creating ParquetExec, files: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", base_config.file_groups, base_config.projection, predicate, base_config.limit); @@ -225,12 +358,12 @@ impl ParquetExec { let (projected_schema, projected_statistics, projected_output_ordering) = base_config.project(); - let cache = Self::compute_properties( + let cache = ParquetExec::compute_properties( projected_schema, &projected_output_ordering, &base_config, ); - Self { + ParquetExec { base_config, projected_statistics, metrics, @@ -238,12 +371,44 @@ impl ParquetExec { pruning_predicate, page_pruning_predicate, metadata_size_hint, - parquet_file_reader_factory: None, + parquet_file_reader_factory, cache, table_parquet_options, - schema_adapter_factory: None, + schema_adapter_factory, } } +} + +impl ParquetExec { + /// Create a new Parquet reader execution plan provided file list and schema. + #[deprecated( + since = "39.0.0", + note = "use `ParquetExec::builder` or `ParquetExecBuilder`" + )] + pub fn new( + base_config: FileScanConfig, + predicate: Option>, + metadata_size_hint: Option, + table_parquet_options: TableParquetOptions, + ) -> Self { + let mut builder = + ParquetExecBuilder::new_with_options(base_config, table_parquet_options); + if let Some(predicate) = predicate { + builder = builder.with_predicate(predicate); + } + if let Some(metadata_size_hint) = metadata_size_hint { + builder = builder.with_metadata_size_hint(metadata_size_hint); + } + builder.build() + } + + /// Return a [`ParquetExecBuilder`]. + /// + /// See example on [`ParquetExec`] and [`ParquetExecBuilder`] for specifying + /// parquet table options. + pub fn builder(file_scan_config: FileScanConfig) -> ParquetExecBuilder { + ParquetExecBuilder::new(file_scan_config) + } /// [`FileScanConfig`] that controls this scan (such as which files to read) pub fn base_config(&self) -> &FileScanConfig { @@ -267,13 +432,7 @@ impl ParquetExec { /// Optional user defined parquet file reader factory. /// - /// You can use [`ParquetFileReaderFactory`] to more precisely control how - /// data is read from parquet files (e.g. skip re-reading metadata, coalesce - /// I/O operations, etc). - /// - /// The default reader factory reads directly from an [`ObjectStore`] - /// instance using individual I/O operations for the footer and then for - /// each page. 
+ /// See documentation on [`ParquetExecBuilder::with_parquet_file_reader_factory`] pub fn with_parquet_file_reader_factory( mut self, parquet_file_reader_factory: Arc, @@ -284,9 +443,7 @@ impl ParquetExec { /// Optional schema adapter factory. /// - /// `SchemaAdapterFactory` allows user to specify how fields from the parquet file get mapped to - /// that of the table schema. The default schema adapter uses arrow's cast library to map - /// the parquet fields to the table schema. + /// See documentation on [`ParquetExecBuilder::with_schema_adapter_factory`] pub fn with_schema_adapter_factory( mut self, schema_adapter_factory: Arc, @@ -1033,15 +1190,17 @@ mod tests { let predicate = predicate.map(|p| logical2physical(&p, &file_schema)); // prepare the scan - let mut parquet_exec = ParquetExec::new( + let mut builder = ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) .with_file_group(file_group) .with_projection(projection), - predicate, - None, - Default::default(), ); + if let Some(predicate) = predicate { + builder = builder.with_predicate(predicate); + } + let mut parquet_exec = builder.build(); + if pushdown_predicate { parquet_exec = parquet_exec .with_pushdown_filters(true) @@ -1684,13 +1843,11 @@ mod tests { expected_row_num: Option, file_schema: SchemaRef, ) -> Result<()> { - let parquet_exec = ParquetExec::new( + let parquet_exec = ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) .with_file_groups(file_groups), - None, - None, - Default::default(), - ); + ) + .build(); assert_eq!( parquet_exec .properties() @@ -1786,7 +1943,7 @@ mod tests { ), ]); - let parquet_exec = ParquetExec::new( + let parquet_exec = ParquetExec::builder( FileScanConfig::new(object_store_url, schema.clone()) .with_file(partitioned_file) // file has 10 cols so index 12 should be month and 13 should be day @@ -1803,10 +1960,8 @@ mod tests { false, ), ]), - None, - None, - Default::default(), - ); + ) + .build(); assert_eq!( parquet_exec.cache.output_partitioning().partition_count(), 1 @@ -1861,13 +2016,11 @@ mod tests { }; let file_schema = Arc::new(Schema::empty()); - let parquet_exec = ParquetExec::new( + let parquet_exec = ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) .with_file(partitioned_file), - None, - None, - Default::default(), - ); + ) + .build(); let mut results = parquet_exec.execute(0, state.task_ctx())?; let batch = results.next().await.unwrap(); diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index 1838a3354b9c..77fde608fd05 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -258,13 +258,11 @@ mod tests { let schema = Arc::new(Schema::new(vec![f1.clone(), f2.clone()])); // prepare the scan - let parquet_exec = ParquetExec::new( + let parquet_exec = ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema) .with_file(partitioned_file), - None, - None, - Default::default(), ) + .build() .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})); let session_ctx = SessionContext::new(); diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index b93f4012b093..909c8acdb816 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ 
b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -245,16 +245,14 @@ mod tests { } fn parquet_exec(schema: &SchemaRef) -> Arc { - Arc::new(ParquetExec::new( + ParquetExec::builder( FileScanConfig::new( ObjectStoreUrl::parse("test:///").unwrap(), schema.clone(), ) .with_file(PartitionedFile::new("x".to_string(), 100)), - None, - None, - Default::default(), - )) + ) + .build_arc() } fn partial_aggregate_exec( diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 9eb5aafd81a2..88fa3a978af7 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1431,14 +1431,12 @@ pub(crate) mod tests { pub(crate) fn parquet_exec_with_sort( output_ordering: Vec>, ) -> Arc { - Arc::new(ParquetExec::new( + ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(output_ordering), - None, - None, - Default::default(), - )) + ) + .build_arc() } fn parquet_exec_multiple() -> Arc { @@ -1449,17 +1447,15 @@ pub(crate) mod tests { fn parquet_exec_multiple_sorted( output_ordering: Vec>, ) -> Arc { - Arc::new(ParquetExec::new( + ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) .with_file_groups(vec![ vec![PartitionedFile::new("x".to_string(), 100)], vec![PartitionedFile::new("y".to_string(), 100)], ]) .with_output_ordering(output_ordering), - None, - None, - Default::default(), - )) + ) + .build_arc() } fn csv_exec() -> Arc { diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 4d926847e465..cfd0312f813d 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -274,13 +274,11 @@ pub fn sort_preserving_merge_exec( /// Create a non sorted parquet exec pub fn parquet_exec(schema: &SchemaRef) -> Arc { - Arc::new(ParquetExec::new( + ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema.clone()) .with_file(PartitionedFile::new("x".to_string(), 100)), - None, - None, - Default::default(), - )) + ) + .build_arc() } // Created a sorted parquet exec @@ -290,14 +288,12 @@ pub fn parquet_exec_sorted( ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(ParquetExec::new( + ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema.clone()) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(vec![sort_exprs]), - None, - None, - Default::default(), - )) + ) + .build_arc() } pub fn union_exec(input: Vec>) -> Arc { diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index ed539d29bd26..9f06ad9308ab 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -37,6 +37,7 @@ use crate::physical_plan::metrics::MetricsSet; use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig, SessionContext}; +use crate::datasource::physical_plan::parquet::ParquetExecBuilder; use object_store::path::Path; use object_store::ObjectMeta; use parquet::arrow::ArrowWriter; @@ -163,22 +164,19 @@ impl TestParquetFile { let filter = simplifier.coerce(filter, &df_schema).unwrap(); let physical_filter_expr = 
create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?; - let parquet_exec = Arc::new(ParquetExec::new( - scan_config, - Some(physical_filter_expr.clone()), - None, - parquet_options, - )); + + let parquet_exec = + ParquetExecBuilder::new_with_options(scan_config, parquet_options) + .with_predicate(physical_filter_expr.clone()) + .build_arc(); let exec = Arc::new(FilterExec::try_new(physical_filter_expr, parquet_exec)?); Ok(exec) } else { - Ok(Arc::new(ParquetExec::new( - scan_config, - None, - None, - parquet_options, - ))) + Ok( + ParquetExecBuilder::new_with_options(scan_config, parquet_options) + .build_arc(), + ) } } diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 4f50c55c627c..0e515fd4647b 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -75,17 +75,15 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { .collect(); // prepare the scan - let parquet_exec = ParquetExec::new( + let parquet_exec = ParquetExec::builder( FileScanConfig::new( // just any url that doesn't point to in memory object store ObjectStoreUrl::local_filesystem(), file_schema, ) .with_file_group(file_group), - None, - None, - Default::default(), ) + .build() .with_parquet_file_reader_factory(Arc::new(InMemoryParquetFileReaderFactory( Arc::clone(&in_memory_object_store), ))); diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 2e9cda40c330..15efd4bcd9dd 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -70,13 +70,12 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { let execution_props = ExecutionProps::new(); let predicate = create_physical_expr(&filter, &df_schema, &execution_props).unwrap(); - let parquet_exec = ParquetExec::new( + ParquetExec::builder( FileScanConfig::new(object_store_url, schema).with_file(partitioned_file), - Some(predicate), - None, - Default::default(), - ); - parquet_exec.with_enable_page_index(true) + ) + .with_predicate(predicate) + .build() + .with_enable_page_index(true) } #[tokio::test] diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index ac51b4f71201..af9411f40ecb 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -59,13 +59,11 @@ async fn multi_parquet_coercion() { Field::new("c2", DataType::Int32, true), Field::new("c3", DataType::Float64, true), ])); - let parquet_exec = ParquetExec::new( + let parquet_exec = ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) .with_file_group(file_group), - None, - None, - Default::default(), - ); + ) + .build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -115,14 +113,12 @@ async fn multi_parquet_coercion_projection() { Field::new("c2", DataType::Int32, true), Field::new("c3", DataType::Float64, true), ])); - let parquet_exec = ParquetExec::new( + let parquet_exec = ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) .with_file_group(file_group) .with_projection(Some(vec![1, 0, 2])), - None, - None, - Default::default(), - ); + ) + .build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs index a9965e1c8151..550176a42e66 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -224,12 +224,11 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ) }) .transpose()?; - Ok(Arc::new(ParquetExec::new( - base_config, - predicate, - None, - Default::default(), - ))) + let mut builder = ParquetExec::builder(base_config); + if let Some(predicate) = predicate { + builder = builder.with_predicate(predicate) + } + Ok(builder.build_arc()) } PhysicalPlanType::AvroScan(scan) => { Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config( diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 55b346a482d3..df1995f46533 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -582,12 +582,11 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { Operator::Eq, lit("1"), )); - roundtrip_test(Arc::new(ParquetExec::new( - scan_config, - Some(predicate), - None, - Default::default(), - ))) + roundtrip_test( + ParquetExec::builder(scan_config) + .with_predicate(predicate) + .build_arc(), + ) } #[tokio::test] @@ -613,12 +612,7 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { output_ordering: vec![], }; - roundtrip_test(Arc::new(ParquetExec::new( - scan_config, - None, - None, - Default::default(), - ))) + roundtrip_test(ParquetExec::builder(scan_config).build_arc()) } #[test] diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 68f8b02b0f09..39b38c94ec18 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -121,12 +121,8 @@ pub async fn from_substrait_rel( } } - Ok(Arc::new(ParquetExec::new( - base_config, - None, - None, - Default::default(), - )) as Arc) + Ok(ParquetExec::builder(base_config).build_arc() + as Arc) } _ => not_impl_err!( "Only LocalFile reads are supported when parsing physical" diff --git a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs index aca044319406..4014670a7cbc 100644 --- a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs @@ -45,12 +45,8 @@ async fn parquet_exec() -> Result<()> { 123, )], ]); - let parquet_exec: Arc = Arc::new(ParquetExec::new( - scan_config, - None, - None, - Default::default(), - )); + let parquet_exec: Arc = + ParquetExec::builder(scan_config).build_arc(); let mut extension_info: ( Vec, From 3d007608535cb138ae4473ce6305bd4ec8481627 Mon Sep 17 00:00:00 2001 From: junxiangMu <63799833+guojidan@users.noreply.github.com> Date: Wed, 29 May 2024 20:18:02 -0400 Subject: [PATCH 02/35] feature: Add a WindowUDFImpl::simplify() API (#9906) * feature: Add a WindowUDFImpl::simplfy() API Signed-off-by: guojidan <1948535941@qq.com> * fix doc Signed-off-by: guojidan <1948535941@qq.com> * fix fmt Signed-off-by: guojidan <1948535941@qq.com> --------- Signed-off-by: guojidan <1948535941@qq.com> --- .../examples/simplify_udwf_expression.rs | 142 ++++++++++++++++++ datafusion/expr/src/function.rs | 13 ++ datafusion/expr/src/udwf.rs | 34 ++++- .../simplify_expressions/expr_simplifier.rs | 103 ++++++++++++- 4 files changed, 288 insertions(+), 4 deletions(-) create mode 100644 
datafusion-examples/examples/simplify_udwf_expression.rs diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs new file mode 100644 index 000000000000..2824d03761ab --- /dev/null +++ b/datafusion-examples/examples/simplify_udwf_expression.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow_schema::DataType; +use datafusion::execution::context::SessionContext; +use datafusion::{error::Result, execution::options::CsvReadOptions}; +use datafusion_expr::function::WindowFunctionSimplification; +use datafusion_expr::{ + expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, +}; + +/// This UDWF will show how to use the WindowUDFImpl::simplify() API +#[derive(Debug, Clone)] +struct SimplifySmoothItUdf { + signature: Signature, +} + +impl SimplifySmoothItUdf { + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one arguments of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} +impl WindowUDFImpl for SimplifySmoothItUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "simplify_smooth_it" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn partition_evaluator(&self) -> Result> { + todo!() + } + + /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`. + fn simplify(&self) -> Option { + // Ok(ExprSimplifyResult::Simplified(Expr::WindowFunction( + // WindowFunction { + // fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction( + // AggregateFunction::Avg, + // ), + // args, + // partition_by: partition_by.to_vec(), + // order_by: order_by.to_vec(), + // window_frame: window_frame.clone(), + // null_treatment: *null_treatment, + // }, + // ))) + let simplify = |window_function: datafusion_expr::expr::WindowFunction, + _: &dyn SimplifyInfo| { + Ok(Expr::WindowFunction(WindowFunction { + fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction( + AggregateFunction::Avg, + ), + args: window_function.args, + partition_by: window_function.partition_by, + order_by: window_function.order_by, + window_frame: window_function.window_frame, + null_treatment: window_function.null_treatment, + })) + }; + + Some(Box::new(simplify)) + } +} + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. 
In spark API, this corresponds to a new spark SQL session + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + let simplify_smooth_it = WindowUDF::from(SimplifySmoothItUdf::new()); + ctx.register_udwf(simplify_smooth_it.clone()); + + // Use SQL to run the new window function + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + simplify_smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index eb748ed2711a..7f49b03bb2ce 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -134,3 +134,16 @@ pub type AggregateFunctionSimplification = Box< &dyn crate::simplify::SimplifyInfo, ) -> Result, >; + +/// [crate::udwf::WindowUDFImpl::simplify] simplifier closure +/// A closure with two arguments: +/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked +/// * 'info': [crate::simplify::SimplifyInfo] +/// +/// closure returns simplified [Expr] or an error. +pub type WindowFunctionSimplification = Box< + dyn Fn( + crate::expr::WindowFunction, + &dyn crate::simplify::SimplifyInfo, + ) -> Result, +>; diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 5a8373509a40..ce28b444adbc 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -18,8 +18,8 @@ //! [`WindowUDF`]: User Defined Window Functions use crate::{ - Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, - WindowFrame, + function::WindowFunctionSimplification, Expr, PartitionEvaluator, + PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, }; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -170,6 +170,13 @@ impl WindowUDF { self.inner.return_type(args) } + /// Do the function rewrite + /// + /// See [`WindowUDFImpl::simplify`] for more details. + pub fn simplify(&self) -> Option { + self.inner.simplify() + } + /// Return a `PartitionEvaluator` for evaluating this window function pub fn partition_evaluator_factory(&self) -> Result> { self.inner.partition_evaluator() @@ -266,6 +273,29 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Optionally apply per-UDWF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization. The default implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. 
+ /// + /// Example: + /// [`simplify_udwf_expression.rs`]: + /// + /// # Returns + /// [None] if simplify is not defined or, + /// + /// Or, a closure with two arguments: + /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked + /// * 'info': [crate::simplify::SimplifyInfo] + fn simplify(&self) -> Option { + None + } } /// WindowUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 25504e5c78e7..c87654292a01 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,10 +32,13 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery}; +use datafusion_expr::expr::{ + AggregateFunctionDefinition, InList, InSubquery, WindowFunction, +}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, + WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; @@ -1391,6 +1394,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { (_, expr) => Transformed::no(expr), }, + Expr::WindowFunction(WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(ref udwf), + .. + }) => match (udwf.simplify(), expr) { + (Some(simplify_function), Expr::WindowFunction(wf)) => { + Transformed::yes(simplify_function(wf, info)?) 
+ } + (_, expr) => Transformed::no(expr), + }, + // // Rules for Between // @@ -1758,7 +1771,10 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { mod tests { use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ - function::{AccumulatorArgs, AggregateFunctionSimplification}, + function::{ + AccumulatorArgs, AggregateFunctionSimplification, + WindowFunctionSimplification, + }, interval_arithmetic::Interval, *, }; @@ -3800,4 +3816,87 @@ mod tests { } } } + + #[test] + fn test_simplify_udwf() { + let udwf = WindowFunctionDefinition::WindowUDF( + WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), + ); + let window_function_expr = + Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( + udwf, + vec![], + vec![], + vec![], + WindowFrame::new(None), + None, + )); + + let expected = col("result_column"); + assert_eq!(simplify(window_function_expr), expected); + + let udwf = WindowFunctionDefinition::WindowUDF( + WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), + ); + let window_function_expr = + Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( + udwf, + vec![], + vec![], + vec![], + WindowFrame::new(None), + None, + )); + + let expected = window_function_expr.clone(); + assert_eq!(simplify(window_function_expr), expected); + } + + /// A Mock UDWF which defines `simplify` to be used in tests + /// related to UDWF simplification + #[derive(Debug, Clone)] + struct SimplifyMockUdwf { + simplify: bool, + } + + impl SimplifyMockUdwf { + /// make simplify method return new expression + fn new_with_simplify() -> Self { + Self { simplify: true } + } + /// make simplify method return no change + fn new_without_simplify() -> Self { + Self { simplify: false } + } + } + + impl WindowUDFImpl for SimplifyMockUdwf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mock_simplify" + } + + fn signature(&self) -> &Signature { + unimplemented!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("not needed for tests") + } + + fn simplify(&self) -> Option { + if self.simplify { + Some(Box::new(|_, _| Ok(col("result_column")))) + } else { + None + } + } + + fn partition_evaluator(&self) -> Result> { + unimplemented!("not needed for tests") + } + } } From 088ad010a6ceaa6a2e810d418a2370e45acf3d54 Mon Sep 17 00:00:00 2001 From: junxiangMu <63799833+guojidan@users.noreply.github.com> Date: Thu, 30 May 2024 01:08:14 -0400 Subject: [PATCH 03/35] Chore: clean up udwf example && remove edundant import (#10718) Signed-off-by: guojidan <1948535941@qq.com> --- .../examples/simplify_udwf_expression.rs | 12 ------------ datafusion/proto-common/src/to_proto/mod.rs | 2 -- 2 files changed, 14 deletions(-) diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs index 2824d03761ab..4e8d03c38e00 100644 --- a/datafusion-examples/examples/simplify_udwf_expression.rs +++ b/datafusion-examples/examples/simplify_udwf_expression.rs @@ -68,18 +68,6 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`. 
fn simplify(&self) -> Option { - // Ok(ExprSimplifyResult::Simplified(Expr::WindowFunction( - // WindowFunction { - // fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction( - // AggregateFunction::Avg, - // ), - // args, - // partition_by: partition_by.to_vec(), - // order_by: order_by.to_vec(), - // window_frame: window_frame.clone(), - // null_treatment: *null_treatment, - // }, - // ))) let simplify = |window_function: datafusion_expr::expr::WindowFunction, _: &dyn SimplifyInfo| { Ok(Expr::WindowFunction(WindowFunction { diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index e53604fc748c..f160bc40af39 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -289,8 +289,6 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { type Error = Error; fn try_from(val: &ScalarValue) -> Result { - use protobuf::scalar_value::Value; - let data_type = val.data_type(); match val { ScalarValue::Boolean(val) => { From c775e4d6ea6dfe9c26a772b676552b9711004a3d Mon Sep 17 00:00:00 2001 From: QP Hou Date: Thu, 30 May 2024 04:34:18 -0700 Subject: [PATCH 04/35] push down filter to partition listing (#10693) --- .../core/src/datasource/listing/helpers.rs | 205 +++++++++++++++++- 1 file changed, 202 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 5b8709009665..b531cf8369cf 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -17,11 +17,13 @@ //! Helper functions for the table implementation +use std::collections::HashMap; use std::sync::Arc; use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; +use crate::logical_expr::{BinaryExpr, Operator}; use crate::{error::Result, scalar::ScalarValue}; use arrow::{ @@ -169,9 +171,17 @@ async fn list_partitions( store: &dyn ObjectStore, table_path: &ListingTableUrl, max_depth: usize, + partition_prefix: Option, ) -> Result> { let partition = Partition { - path: table_path.prefix().clone(), + path: match partition_prefix { + Some(prefix) => Path::from_iter( + Path::from(table_path.prefix().as_ref()) + .parts() + .chain(Path::from(prefix.as_ref()).parts()), + ), + None => table_path.prefix().clone(), + }, depth: 0, files: None, }; @@ -305,6 +315,80 @@ async fn prune_partitions( Ok(filtered) } +#[derive(Debug)] +enum PartitionValue { + Single(String), + Multi, +} + +fn populate_partition_values<'a>( + partition_values: &mut HashMap<&'a str, PartitionValue>, + filter: &'a Expr, +) { + if let Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) = filter + { + match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { + (Expr::Column(Column { ref name, .. }), Expr::Literal(val)) + | (Expr::Literal(val), Expr::Column(Column { ref name, .. 
})) => { + if partition_values + .insert(name, PartitionValue::Single(val.to_string())) + .is_some() + { + partition_values.insert(name, PartitionValue::Multi); + } + } + _ => {} + }, + Operator::And => { + populate_partition_values(partition_values, left); + populate_partition_values(partition_values, right); + } + _ => {} + } + } +} + +fn evaluate_partition_prefix<'a>( + partition_cols: &'a [(String, DataType)], + filters: &'a [Expr], +) -> Option { + let mut partition_values = HashMap::new(); + for filter in filters { + populate_partition_values(&mut partition_values, filter); + } + + if partition_values.is_empty() { + return None; + } + + let mut parts = vec![]; + for (p, _) in partition_cols { + match partition_values.get(p.as_str()) { + Some(PartitionValue::Single(val)) => { + // if a partition only has a single literal value, then it can be added to the + // prefix + parts.push(format!("{p}={val}")); + } + _ => { + // break on the first unconstrainted partition to create a common prefix + // for all covered partitions. + break; + } + } + } + + if parts.is_empty() { + None + } else { + Some(Path::from_iter(parts)) + } +} + /// Discover the partitions on the given path and prune out files /// that belong to irrelevant partitions using `filters` expressions. /// `filters` might contain expressions that can be resolved only at the @@ -327,7 +411,10 @@ pub async fn pruned_partition_list<'a>( )); } - let partitions = list_partitions(store, table_path, partition_cols.len()).await?; + let partition_prefix = evaluate_partition_prefix(partition_cols, filters); + let partitions = + list_partitions(store, table_path, partition_cols.len(), partition_prefix) + .await?; debug!("Listed {} partitions", partitions.len()); let pruned = @@ -416,7 +503,9 @@ where mod tests { use std::ops::Not; - use crate::logical_expr::{case, col, lit}; + use futures::StreamExt; + + use crate::logical_expr::{case, col, lit, Expr}; use crate::test::object_store::make_test_store_and_state; use super::*; @@ -675,4 +764,114 @@ mod tests { // this helper function assert!(expr_applicable_for_cols(&[], &lit(true))); } + + #[test] + fn test_evaluate_partition_prefix() { + let partitions = &[ + ("a".to_string(), DataType::Utf8), + ("b".to_string(), DataType::Int16), + ("c".to_string(), DataType::Boolean), + ]; + + assert_eq!( + evaluate_partition_prefix(partitions, &[col("a").eq(lit("foo"))]), + Some(Path::from("a=foo")), + ); + + assert_eq!( + evaluate_partition_prefix(partitions, &[lit("foo").eq(col("a"))]), + Some(Path::from("a=foo")), + ); + + assert_eq!( + evaluate_partition_prefix( + partitions, + &[col("a").eq(lit("foo")).and((col("b").eq(lit("bar"))))], + ), + Some(Path::from("a=foo/b=bar")), + ); + + assert_eq!( + evaluate_partition_prefix( + partitions, + // list of filters should be evaluated as AND + &[col("a").eq(lit("foo")), col("b").eq(lit("bar")),], + ), + Some(Path::from("a=foo/b=bar")), + ); + + assert_eq!( + evaluate_partition_prefix( + partitions, + &[col("a") + .eq(lit("foo")) + .and(col("b").eq(lit("1"))) + .and(col("c").eq(lit("true")))], + ), + Some(Path::from("a=foo/b=1/c=true")), + ); + + // no prefix when filter is empty + assert_eq!(evaluate_partition_prefix(partitions, &[]), None); + + // b=foo results in no prefix because a is not restricted + assert_eq!( + evaluate_partition_prefix(partitions, &[Expr::eq(col("b"), lit("foo"))]), + None, + ); + + // a=foo and c=baz only results in preifx a=foo because b is not restricted + assert_eq!( + evaluate_partition_prefix( + partitions, + 
&[col("a").eq(lit("foo")).and(col("c").eq(lit("baz")))], + ), + Some(Path::from("a=foo")), + ); + + // partition with multiple values results in no prefix + assert_eq!( + evaluate_partition_prefix( + partitions, + &[Expr::and(col("a").eq(lit("foo")), col("a").eq(lit("bar")))], + ), + None, + ); + + // no prefix because partition a is not restricted to a single literal + assert_eq!( + evaluate_partition_prefix( + partitions, + &[Expr::or(col("a").eq(lit("foo")), col("a").eq(lit("bar")))], + ), + None, + ); + assert_eq!( + evaluate_partition_prefix(partitions, &[col("b").lt(lit(5))],), + None, + ); + } + + #[test] + fn test_evaluate_date_partition_prefix() { + let partitions = &[("a".to_string(), DataType::Date32)]; + assert_eq!( + evaluate_partition_prefix( + partitions, + &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))], + ), + Some(Path::from("a=1970-01-04")), + ); + + let partitions = &[("a".to_string(), DataType::Date64)]; + assert_eq!( + evaluate_partition_prefix( + partitions, + &[col("a").eq(Expr::Literal(ScalarValue::Date64(Some( + 4 * 24 * 60 * 60 * 1000 + )))),], + ), + Some(Path::from("a=1970-01-05")), + ); + } } From ad2b1dcac8168906e4444527320d3139a1a2ea5b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 May 2024 07:48:33 -0700 Subject: [PATCH 05/35] Make swap_hash_join public API (#10702) --- datafusion/core/src/physical_optimizer/join_selection.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 135a59aa0353..1613e5089860 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -157,7 +157,9 @@ fn swap_join_projection( } /// This function swaps the inputs of the given join operator. -fn swap_hash_join( +/// This function is public so other downstream projects can use it +/// to construct `HashJoinExec` with right side as the build side. 
+pub fn swap_hash_join( hash_join: &HashJoinExec, partition_mode: PartitionMode, ) -> Result> { From 313f47f30c96d60cd187bbdcd26999ca8d48c609 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Thu, 30 May 2024 23:56:29 +0800 Subject: [PATCH 06/35] ci: fix clippy error (#10723) --- datafusion/core/src/datasource/listing/helpers.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index b531cf8369cf..822a66783819 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -786,7 +786,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(lit("foo")).and((col("b").eq(lit("bar"))))], + &[col("a").eq(lit("foo")).and(col("b").eq(lit("bar")))], ), Some(Path::from("a=foo/b=bar")), ); From 904f0db73cf2c0049822c7045e09824a7453a02d Mon Sep 17 00:00:00 2001 From: Oleks V Date: Thu, 30 May 2024 09:30:30 -0700 Subject: [PATCH 07/35] Fix clippy (#10725) From 79481839da993623e899f8835a145ddd8bfc210e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Thu, 30 May 2024 23:41:43 +0300 Subject: [PATCH 08/35] Remove Eager Trait for Joins (#10721) * Remove eager trait * Update helpers.rs --- .../src/joins/stream_join_utils.rs | 354 +----------------- .../src/joins/symmetric_hash_join.rs | 268 ++++++++++++- datafusion/physical-plan/src/joins/utils.rs | 28 +- 3 files changed, 263 insertions(+), 387 deletions(-) diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index f82eb31f9699..0a01d84141e7 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -20,12 +20,11 @@ use std::collections::{HashMap, VecDeque}; use std::sync::Arc; -use std::task::{Context, Poll}; use std::usize; -use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult}; +use crate::joins::utils::{JoinFilter, JoinHashMapType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{handle_async_state, handle_state, metrics, ExecutionPlan}; +use crate::{metrics, ExecutionPlan}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; @@ -36,15 +35,12 @@ use datafusion_common::{ arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, }; -use datafusion_execution::SendableRecordBatchStream; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use async_trait::async_trait; -use futures::{ready, FutureExt, StreamExt}; use hashbrown::raw::RawTable; use hashbrown::HashSet; @@ -629,352 +625,6 @@ pub fn record_visited_indices( } } -/// Represents the various states of an eager join stream operation. -/// -/// This enum is used to track the current state of streaming during a join -/// operation. It provides indicators as to which side of the join needs to be -/// pulled next or if one (or both) sides have been exhausted. This allows -/// for efficient management of resources and optimal performance during the -/// join process. 
-#[derive(Clone, Debug)] -pub enum EagerJoinStreamState { - /// Indicates that the next step should pull from the right side of the join. - PullRight, - - /// Indicates that the next step should pull from the left side of the join. - PullLeft, - - /// State representing that the right side of the join has been fully processed. - RightExhausted, - - /// State representing that the left side of the join has been fully processed. - LeftExhausted, - - /// Represents a state where both sides of the join are exhausted. - /// - /// The `final_result` field indicates whether the join operation has - /// produced a final result or not. - BothExhausted { final_result: bool }, -} - -/// `EagerJoinStream` is an asynchronous trait designed for managing incremental -/// join operations between two streams, such as those used in `SymmetricHashJoinExec` -/// and `SortMergeJoinExec`. Unlike traditional join approaches that need to scan -/// one side of the join fully before proceeding, `EagerJoinStream` facilitates -/// more dynamic join operations by working with streams as they emit data. This -/// approach allows for more efficient processing, particularly in scenarios -/// where waiting for complete data materialization is not feasible or optimal. -/// The trait provides a framework for handling various states of such a join -/// process, ensuring that join logic is efficiently executed as data becomes -/// available from either stream. -/// -/// Implementors of this trait can perform eager joins of data from two different -/// asynchronous streams, typically referred to as left and right streams. The -/// trait provides a comprehensive set of methods to control and execute the join -/// process, leveraging the states defined in `EagerJoinStreamState`. Methods are -/// primarily focused on asynchronously fetching data batches from each stream, -/// processing them, and managing transitions between various states of the join. -/// -/// This trait's default implementations use a state machine approach to navigate -/// different stages of the join operation, handling data from both streams and -/// determining when the join completes. -/// -/// State Transitions: -/// - From `PullLeft` to `PullRight` or `LeftExhausted`: -/// - In `fetch_next_from_left_stream`, when fetching a batch from the left stream: -/// - On success (`Some(Ok(batch))`), state transitions to `PullRight` for -/// processing the batch. -/// - On error (`Some(Err(e))`), the error is returned, and the state remains -/// unchanged. -/// - On no data (`None`), state changes to `LeftExhausted`, returning `Continue` -/// to proceed with the join process. -/// - From `PullRight` to `PullLeft` or `RightExhausted`: -/// - In `fetch_next_from_right_stream`, when fetching from the right stream: -/// - If a batch is available, state changes to `PullLeft` for processing. -/// - On error, the error is returned without changing the state. -/// - If right stream is exhausted (`None`), state transitions to `RightExhausted`, -/// with a `Continue` result. -/// - Handling `RightExhausted` and `LeftExhausted`: -/// - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios -/// when streams are exhausted: -/// - They attempt to continue processing with the other stream. -/// - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`. 
-/// - Transition to `BothExhausted { final_result: true }`: -/// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are -/// exhausted, indicating completion of processing and availability of final results. -#[async_trait] -pub trait EagerJoinStream { - /// Implements the main polling logic for the join stream. - /// - /// This method continuously checks the state of the join stream and - /// acts accordingly by delegating the handling to appropriate sub-methods - /// depending on the current state. - /// - /// # Arguments - /// - /// * `cx` - A context that facilitates cooperative non-blocking execution within a task. - /// - /// # Returns - /// - /// * `Poll>>` - A polled result, either a `RecordBatch` or None. - fn poll_next_impl( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> - where - Self: Send, - { - loop { - return match self.state() { - EagerJoinStreamState::PullRight => { - handle_async_state!(self.fetch_next_from_right_stream(), cx) - } - EagerJoinStreamState::PullLeft => { - handle_async_state!(self.fetch_next_from_left_stream(), cx) - } - EagerJoinStreamState::RightExhausted => { - handle_async_state!(self.handle_right_stream_end(), cx) - } - EagerJoinStreamState::LeftExhausted => { - handle_async_state!(self.handle_left_stream_end(), cx) - } - EagerJoinStreamState::BothExhausted { - final_result: false, - } => { - handle_state!(self.prepare_for_final_results_after_exhaustion()) - } - EagerJoinStreamState::BothExhausted { final_result: true } => { - Poll::Ready(None) - } - }; - } - } - /// Asynchronously pulls the next batch from the right stream. - /// - /// This default implementation checks for the next value in the right stream. - /// If a batch is found, the state is switched to `PullLeft`, and the batch handling - /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`. - /// - /// # Returns - /// - /// * `Result>>` - The state result after pulling the batch. - async fn fetch_next_from_right_stream( - &mut self, - ) -> Result>> { - match self.right_stream().next().await { - Some(Ok(batch)) => { - if batch.num_rows() == 0 { - return Ok(StatefulStreamResult::Continue); - } - self.set_state(EagerJoinStreamState::PullLeft); - self.process_batch_from_right(batch) - } - Some(Err(e)) => Err(e), - None => { - self.set_state(EagerJoinStreamState::RightExhausted); - Ok(StatefulStreamResult::Continue) - } - } - } - - /// Asynchronously pulls the next batch from the left stream. - /// - /// This default implementation checks for the next value in the left stream. - /// If a batch is found, the state is switched to `PullRight`, and the batch handling - /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`. - /// - /// # Returns - /// - /// * `Result>>` - The state result after pulling the batch. - async fn fetch_next_from_left_stream( - &mut self, - ) -> Result>> { - match self.left_stream().next().await { - Some(Ok(batch)) => { - if batch.num_rows() == 0 { - return Ok(StatefulStreamResult::Continue); - } - self.set_state(EagerJoinStreamState::PullRight); - self.process_batch_from_left(batch) - } - Some(Err(e)) => Err(e), - None => { - self.set_state(EagerJoinStreamState::LeftExhausted); - Ok(StatefulStreamResult::Continue) - } - } - } - - /// Asynchronously handles the scenario when the right stream is exhausted. - /// - /// In this default implementation, when the right stream is exhausted, it attempts - /// to pull from the left stream. 
If a batch is found in the left stream, it delegates - /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set - /// to indicate both streams are exhausted without final results yet. - /// - /// # Returns - /// - /// * `Result>>` - The state result after checking the exhaustion state. - async fn handle_right_stream_end( - &mut self, - ) -> Result>> { - match self.left_stream().next().await { - Some(Ok(batch)) => { - if batch.num_rows() == 0 { - return Ok(StatefulStreamResult::Continue); - } - self.process_batch_after_right_end(batch) - } - Some(Err(e)) => Err(e), - None => { - self.set_state(EagerJoinStreamState::BothExhausted { - final_result: false, - }); - Ok(StatefulStreamResult::Continue) - } - } - } - - /// Asynchronously handles the scenario when the left stream is exhausted. - /// - /// When the left stream is exhausted, this default - /// implementation tries to pull from the right stream and delegates the batch - /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state - /// is updated to indicate so. - /// - /// # Returns - /// - /// * `Result>>` - The state result after checking the exhaustion state. - async fn handle_left_stream_end( - &mut self, - ) -> Result>> { - match self.right_stream().next().await { - Some(Ok(batch)) => { - if batch.num_rows() == 0 { - return Ok(StatefulStreamResult::Continue); - } - self.process_batch_after_left_end(batch) - } - Some(Err(e)) => Err(e), - None => { - self.set_state(EagerJoinStreamState::BothExhausted { - final_result: false, - }); - Ok(StatefulStreamResult::Continue) - } - } - } - - /// Handles the state when both streams are exhausted and final results are yet to be produced. - /// - /// This default implementation switches the state to indicate both streams are - /// exhausted with final results and then invokes the handling for this specific - /// scenario via `process_batches_before_finalization`. - /// - /// # Returns - /// - /// * `Result>>` - The state result after both streams are exhausted. - fn prepare_for_final_results_after_exhaustion( - &mut self, - ) -> Result>> { - self.set_state(EagerJoinStreamState::BothExhausted { final_result: true }); - self.process_batches_before_finalization() - } - - /// Handles a pulled batch from the right stream. - /// - /// # Arguments - /// - /// * `batch` - The pulled `RecordBatch` from the right stream. - /// - /// # Returns - /// - /// * `Result>>` - The state result after processing the batch. - fn process_batch_from_right( - &mut self, - batch: RecordBatch, - ) -> Result>>; - - /// Handles a pulled batch from the left stream. - /// - /// # Arguments - /// - /// * `batch` - The pulled `RecordBatch` from the left stream. - /// - /// # Returns - /// - /// * `Result>>` - The state result after processing the batch. - fn process_batch_from_left( - &mut self, - batch: RecordBatch, - ) -> Result>>; - - /// Handles the situation when only the left stream is exhausted. - /// - /// # Arguments - /// - /// * `right_batch` - The `RecordBatch` from the right stream. - /// - /// # Returns - /// - /// * `Result>>` - The state result after the left stream is exhausted. - fn process_batch_after_left_end( - &mut self, - right_batch: RecordBatch, - ) -> Result>>; - - /// Handles the situation when only the right stream is exhausted. - /// - /// # Arguments - /// - /// * `left_batch` - The `RecordBatch` from the left stream. - /// - /// # Returns - /// - /// * `Result>>` - The state result after the right stream is exhausted. 
- fn process_batch_after_right_end( - &mut self, - left_batch: RecordBatch, - ) -> Result>>; - - /// Handles the final state after both streams are exhausted. - /// - /// # Returns - /// - /// * `Result>>` - The final state result after processing. - fn process_batches_before_finalization( - &mut self, - ) -> Result>>; - - /// Provides mutable access to the right stream. - /// - /// # Returns - /// - /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the right stream. - fn right_stream(&mut self) -> &mut SendableRecordBatchStream; - - /// Provides mutable access to the left stream. - /// - /// # Returns - /// - /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the left stream. - fn left_stream(&mut self) -> &mut SendableRecordBatchStream; - - /// Sets the current state of the join stream. - /// - /// # Arguments - /// - /// * `state` - The new state to be set. - fn set_state(&mut self, state: EagerJoinStreamState); - - /// Fetches the current state of the join stream. - /// - /// # Returns - /// - /// * `EagerJoinStreamState` - The current state of the join stream. - fn state(&mut self) -> EagerJoinStreamState; -} - #[derive(Debug)] pub struct StreamJoinSideMetrics { /// Number of batches consumed by this operator diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 0d902af9c6cc..449c42d69797 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -28,17 +28,17 @@ use std::any::Any; use std::fmt::{self, Debug}; use std::sync::Arc; -use std::task::Poll; +use std::task::{Context, Poll}; use std::{usize, vec}; use crate::common::SharedMemoryReservation; +use crate::handle_state; use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices, - EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, - StreamJoinMetrics, + PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, @@ -72,7 +72,7 @@ use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; use ahash::RandomState; -use futures::Stream; +use futures::{ready, Stream, StreamExt}; use hashbrown::HashSet; use parking_lot::Mutex; @@ -522,7 +522,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { left_sorted_filter_expr, right_sorted_filter_expr, null_equals_null: self.null_equals_null, - state: EagerJoinStreamState::PullRight, + state: SHJStreamState::PullRight, reservation, })) } @@ -560,7 +560,7 @@ struct SymmetricHashJoinStream { /// Memory reservation reservation: SharedMemoryReservation, /// State machine for input execution - state: EagerJoinStreamState, + state: SHJStreamState, } impl RecordBatchStream for SymmetricHashJoinStream { @@ -1103,7 +1103,227 @@ impl OneSideHashJoiner { } } -impl EagerJoinStream for SymmetricHashJoinStream { +/// `SymmetricHashJoinStream` manages incremental join operations between two +/// streams. 
Unlike traditional join approaches that need to scan one side of +/// the join fully before proceeding, `SymmetricHashJoinStream` facilitates +/// more dynamic join operations by working with streams as they emit data. This +/// approach allows for more efficient processing, particularly in scenarios +/// where waiting for complete data materialization is not feasible or optimal. +/// The trait provides a framework for handling various states of such a join +/// process, ensuring that join logic is efficiently executed as data becomes +/// available from either stream. +/// +/// This implementation performs eager joins of data from two different asynchronous +/// streams, typically referred to as left and right streams. The implementation +/// provides a comprehensive set of methods to control and execute the join +/// process, leveraging the states defined in `SHJStreamState`. Methods are +/// primarily focused on asynchronously fetching data batches from each stream, +/// processing them, and managing transitions between various states of the join. +/// +/// This implementations use a state machine approach to navigate different +/// stages of the join operation, handling data from both streams and determining +/// when the join completes. +/// +/// State Transitions: +/// - From `PullLeft` to `PullRight` or `LeftExhausted`: +/// - In `fetch_next_from_left_stream`, when fetching a batch from the left stream: +/// - On success (`Some(Ok(batch))`), state transitions to `PullRight` for +/// processing the batch. +/// - On error (`Some(Err(e))`), the error is returned, and the state remains +/// unchanged. +/// - On no data (`None`), state changes to `LeftExhausted`, returning `Continue` +/// to proceed with the join process. +/// - From `PullRight` to `PullLeft` or `RightExhausted`: +/// - In `fetch_next_from_right_stream`, when fetching from the right stream: +/// - If a batch is available, state changes to `PullLeft` for processing. +/// - On error, the error is returned without changing the state. +/// - If right stream is exhausted (`None`), state transitions to `RightExhausted`, +/// with a `Continue` result. +/// - Handling `RightExhausted` and `LeftExhausted`: +/// - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios +/// when streams are exhausted: +/// - They attempt to continue processing with the other stream. +/// - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`. +/// - Transition to `BothExhausted { final_result: true }`: +/// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are +/// exhausted, indicating completion of processing and availability of final results. +impl SymmetricHashJoinStream { + /// Implements the main polling logic for the join stream. + /// + /// This method continuously checks the state of the join stream and + /// acts accordingly by delegating the handling to appropriate sub-methods + /// depending on the current state. + /// + /// # Arguments + /// + /// * `cx` - A context that facilitates cooperative non-blocking execution within a task. + /// + /// # Returns + /// + /// * `Poll>>` - A polled result, either a `RecordBatch` or None. 
+ fn poll_next_impl( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + return match self.state() { + SHJStreamState::PullRight => { + handle_state!(ready!(self.fetch_next_from_right_stream(cx))) + } + SHJStreamState::PullLeft => { + handle_state!(ready!(self.fetch_next_from_left_stream(cx))) + } + SHJStreamState::RightExhausted => { + handle_state!(ready!(self.handle_right_stream_end(cx))) + } + SHJStreamState::LeftExhausted => { + handle_state!(ready!(self.handle_left_stream_end(cx))) + } + SHJStreamState::BothExhausted { + final_result: false, + } => { + handle_state!(self.prepare_for_final_results_after_exhaustion()) + } + SHJStreamState::BothExhausted { final_result: true } => Poll::Ready(None), + }; + } + } + /// Asynchronously pulls the next batch from the right stream. + /// + /// This default implementation checks for the next value in the right stream. + /// If a batch is found, the state is switched to `PullLeft`, and the batch handling + /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + fn fetch_next_from_right_stream( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + match ready!(self.right_stream().poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + self.set_state(SHJStreamState::PullLeft); + Poll::Ready(self.process_batch_from_right(batch)) + } + Some(Err(e)) => Poll::Ready(Err(e)), + None => { + self.set_state(SHJStreamState::RightExhausted); + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + } + } + + /// Asynchronously pulls the next batch from the left stream. + /// + /// This default implementation checks for the next value in the left stream. + /// If a batch is found, the state is switched to `PullRight`, and the batch handling + /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + fn fetch_next_from_left_stream( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + match ready!(self.left_stream().poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + self.set_state(SHJStreamState::PullRight); + Poll::Ready(self.process_batch_from_left(batch)) + } + Some(Err(e)) => Poll::Ready(Err(e)), + None => { + self.set_state(SHJStreamState::LeftExhausted); + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + } + } + + /// Asynchronously handles the scenario when the right stream is exhausted. + /// + /// In this default implementation, when the right stream is exhausted, it attempts + /// to pull from the left stream. If a batch is found in the left stream, it delegates + /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set + /// to indicate both streams are exhausted without final results yet. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. 
+ fn handle_right_stream_end( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + match ready!(self.left_stream().poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + Poll::Ready(self.process_batch_after_right_end(batch)) + } + Some(Err(e)) => Poll::Ready(Err(e)), + None => { + self.set_state(SHJStreamState::BothExhausted { + final_result: false, + }); + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + } + } + + /// Asynchronously handles the scenario when the left stream is exhausted. + /// + /// When the left stream is exhausted, this default + /// implementation tries to pull from the right stream and delegates the batch + /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state + /// is updated to indicate so. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + fn handle_left_stream_end( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + match ready!(self.right_stream().poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + Poll::Ready(self.process_batch_after_left_end(batch)) + } + Some(Err(e)) => Poll::Ready(Err(e)), + None => { + self.set_state(SHJStreamState::BothExhausted { + final_result: false, + }); + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + } + } + + /// Handles the state when both streams are exhausted and final results are yet to be produced. + /// + /// This default implementation switches the state to indicate both streams are + /// exhausted with final results and then invokes the handling for this specific + /// scenario via `process_batches_before_finalization`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after both streams are exhausted. + fn prepare_for_final_results_after_exhaustion( + &mut self, + ) -> Result>> { + self.set_state(SHJStreamState::BothExhausted { final_result: true }); + self.process_batches_before_finalization() + } + fn process_batch_from_right( &mut self, batch: RecordBatch, @@ -1189,16 +1409,14 @@ impl EagerJoinStream for SymmetricHashJoinStream { &mut self.left_stream } - fn set_state(&mut self, state: EagerJoinStreamState) { + fn set_state(&mut self, state: SHJStreamState) { self.state = state; } - fn state(&mut self) -> EagerJoinStreamState { + fn state(&mut self) -> SHJStreamState { self.state.clone() } -} -impl SymmetricHashJoinStream { fn size(&self) -> usize { let mut size = 0; size += std::mem::size_of_val(&self.schema); @@ -1321,6 +1539,34 @@ impl SymmetricHashJoinStream { } } +/// Represents the various states of an symmetric hash join stream operation. +/// +/// This enum is used to track the current state of streaming during a join +/// operation. It provides indicators as to which side of the join needs to be +/// pulled next or if one (or both) sides have been exhausted. This allows +/// for efficient management of resources and optimal performance during the +/// join process. +#[derive(Clone, Debug)] +pub enum SHJStreamState { + /// Indicates that the next step should pull from the right side of the join. + PullRight, + + /// Indicates that the next step should pull from the left side of the join. + PullLeft, + + /// State representing that the right side of the join has been fully processed. + RightExhausted, + + /// State representing that the left side of the join has been fully processed. 
+ LeftExhausted, + + /// Represents a state where both sides of the join are exhausted. + /// + /// The `final_result` field indicates whether the join operation has + /// produced a final result or not. + BothExhausted { final_result: bool }, +} + #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index acf9ed4d7ec8..0d99d7a16356 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1494,10 +1494,9 @@ impl BuildProbeJoinMetrics { } /// The `handle_state` macro is designed to process the result of a state-changing -/// operation, encountered e.g. in implementations of `EagerJoinStream`. It -/// operates on a `StatefulStreamResult` by matching its variants and executing -/// corresponding actions. This macro is used to streamline code that deals with -/// state transitions, reducing boilerplate and improving readability. +/// operation. It operates on a `StatefulStreamResult` by matching its variants and +/// executing corresponding actions. This macro is used to streamline code that deals +/// with state transitions, reducing boilerplate and improving readability. /// /// # Cases /// @@ -1525,26 +1524,7 @@ macro_rules! handle_state { }; } -/// The `handle_async_state` macro adapts the `handle_state` macro for use in -/// asynchronous operations, particularly when dealing with `Poll` results within -/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing -/// function using `poll_unpin` and then passes the result to `handle_state` for -/// further processing. -/// -/// # Arguments -/// -/// * `$state_func`: An async function or future that returns a -/// `Result>`. -/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. -/// -#[macro_export] -macro_rules! handle_async_state { - ($state_func:expr, $cx:expr) => { - $crate::handle_state!(ready!($state_func.poll_unpin($cx))) - }; -} - -/// Represents the result of an operation on stateful join stream. +/// Represents the result of a stateful operation. /// /// This enumueration indicates whether the state produced a result that is /// ready for use (`Ready`) or if the operation requires continuation (`Continue`). 
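The refactor above removes the `async_trait`-based `EagerJoinStream` (and its `handle_async_state!` macro) in favor of inherent methods on `SymmetricHashJoinStream` that poll the two input streams directly with `ready!` and `handle_state!`. The following is a minimal, hypothetical sketch of that same alternating-pull state machine over two `futures` streams. It is illustrative only: the `Interleave` and `Phase` names are made up, and the `Result`/`StatefulStreamResult` layers, batch buffering, and join logic of the real operator are intentionally omitted.

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{ready, Stream, StreamExt};

/// Hypothetical state enum mirroring the shape of `SHJStreamState`.
enum Phase {
    PullRight,
    PullLeft,
    RightExhausted,
    LeftExhausted,
    BothExhausted,
}

/// Alternates between two input streams, draining whichever side remains
/// once the other side is exhausted.
struct Interleave<L, R> {
    left: L,
    right: R,
    phase: Phase,
}

impl<L, R, T> Stream for Interleave<L, R>
where
    L: Stream<Item = T> + Unpin,
    R: Stream<Item = T> + Unpin,
{
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let this = self.get_mut();
        loop {
            match this.phase {
                // Pull one item from the right, then switch to the left.
                Phase::PullRight => match ready!(this.right.poll_next_unpin(cx)) {
                    Some(item) => {
                        this.phase = Phase::PullLeft;
                        return Poll::Ready(Some(item));
                    }
                    None => this.phase = Phase::RightExhausted,
                },
                // Pull one item from the left, then switch to the right.
                Phase::PullLeft => match ready!(this.left.poll_next_unpin(cx)) {
                    Some(item) => {
                        this.phase = Phase::PullRight;
                        return Poll::Ready(Some(item));
                    }
                    None => this.phase = Phase::LeftExhausted,
                },
                // Right side finished: keep draining the left side.
                Phase::RightExhausted => match ready!(this.left.poll_next_unpin(cx)) {
                    Some(item) => return Poll::Ready(Some(item)),
                    None => this.phase = Phase::BothExhausted,
                },
                // Left side finished: keep draining the right side.
                Phase::LeftExhausted => match ready!(this.right.poll_next_unpin(cx)) {
                    Some(item) => return Poll::Ready(Some(item)),
                    None => this.phase = Phase::BothExhausted,
                },
                Phase::BothExhausted => return Poll::Ready(None),
            }
        }
    }
}

fn main() {
    use futures::{executor::block_on, stream};
    let merged = Interleave {
        left: stream::iter([1, 3, 5]),
        right: stream::iter([2, 4]),
        phase: Phase::PullRight,
    };
    // Items alternate right/left until one side runs dry: [2, 1, 4, 3, 5]
    let out: Vec<i32> = block_on(merged.collect());
    println!("{out:?}");
}
```

As in `poll_next_impl` above, `ready!` propagates `Poll::Pending` whenever the chosen input is not ready, so the state only advances once a side actually yields an item or finishes.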
From 100b30e13583badc5aa9e88861d63feb80876c5e Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Fri, 31 May 2024 00:57:33 +0200 Subject: [PATCH 09/35] Minor: fix signature `fn octect_length()` (#10726) * fix: signature fn octect_length * chore: add test --- datafusion/core/tests/expr_api/mod.rs | 17 +++++++++++++++++ datafusion/functions/src/string/mod.rs | 4 ++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index a69f7bd48437..1db5aa9f235a 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -28,6 +28,23 @@ use std::sync::{Arc, OnceLock}; mod simplification; +#[test] +fn test_octet_length() { + #[rustfmt::skip] + evaluate_expr_test( + octet_length(col("list")), + vec![ + "+------+", + "| expr |", + "+------+", + "| 5 |", + "| 18 |", + "| 6 |", + "+------+", + ], + ); +} + #[test] fn test_eq() { // id = '2' diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 52411142cb8d..e931c4998115 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -128,8 +128,8 @@ pub mod expr_fn { } #[doc = "returns the number of bytes of a string"] - pub fn octet_length(args: Vec) -> Expr { - super::octet_length().call(args) + pub fn octet_length(args: Expr) -> Expr { + super::octet_length().call(vec![args]) } #[doc = "replace the substring of string that starts at the start'th character and extends for count characters with new substring"] From 76f50b0ce913bdbea07bd4ca07c1bafc93c49e61 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 31 May 2024 17:46:21 +0800 Subject: [PATCH 10/35] docs: add documents to substrait type variation consts (#10719) * docs: add documents to substrait type variation consts Signed-off-by: Ruihang Xia * rename and add todo Signed-off-by: Ruihang Xia * fix link style Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- .../substrait/src/logical_plan/consumer.rs | 162 +++++++------- .../substrait/src/logical_plan/producer.rs | 208 ++++++++++-------- datafusion/substrait/src/variation_const.rs | 39 ++-- 3 files changed, 229 insertions(+), 180 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index eb819e2c87df..597f34e89a02 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -73,11 +73,14 @@ use std::str::FromStr; use std::sync::Arc; use crate::variation_const::{ - DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF, - DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, - LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, - TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF, + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, + INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, + UNSIGNED_INTEGER_TYPE_VARIATION_REF, 
}; enum ScalarFunctionType { @@ -1130,29 +1133,29 @@ fn from_substrait_type( Some(s_kind) => match s_kind { r#type::Kind::Bool(_) => Ok(DataType::Boolean), r#type::Kind::I8(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(DataType::Int8), - UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt8), + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int8), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt8), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::I16(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(DataType::Int16), - UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt16), + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int16), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt16), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::I32(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(DataType::Int32), - UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt32), + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt32), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::I64(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(DataType::Int64), - UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt64), + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt64), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), @@ -1160,16 +1163,16 @@ fn from_substrait_type( r#type::Kind::Fp32(_) => Ok(DataType::Float32), r#type::Kind::Fp64(_) => Ok(DataType::Float64), r#type::Kind::Timestamp(ts) => match ts.type_variation_reference { - TIMESTAMP_SECOND_TYPE_REF => { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { Ok(DataType::Timestamp(TimeUnit::Second, None)) } - TIMESTAMP_MILLI_TYPE_REF => { + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) } - TIMESTAMP_MICRO_TYPE_REF => { + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) } - TIMESTAMP_NANO_TYPE_REF => { + TIMESTAMP_NANO_TYPE_VARIATION_REF => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } v => not_impl_err!( @@ -1177,15 +1180,15 @@ fn from_substrait_type( ), }, r#type::Kind::Date(date) => match date.type_variation_reference { - DATE_32_TYPE_REF => Ok(DataType::Date32), - DATE_64_TYPE_REF => Ok(DataType::Date64), + DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32), + DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::Binary(binary) => match binary.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Binary), - LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeBinary), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), @@ -1194,8 +1197,8 @@ fn from_substrait_type( Ok(DataType::FixedSizeBinary(fixed.length)) } r#type::Kind::String(string) => match string.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Utf8), - LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeUtf8), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), + LARGE_CONTAINER_TYPE_VARIATION_REF => 
Ok(DataType::LargeUtf8), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), @@ -1209,18 +1212,18 @@ fn from_substrait_type( is_substrait_type_nullable(inner_type)?, )); match list.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)), - LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeList(field)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" )?, } } r#type::Kind::Decimal(d) => match d.type_variation_reference { - DECIMAL_128_TYPE_REF => { + DECIMAL_128_TYPE_VARIATION_REF => { Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) } - DECIMAL_256_TYPE_REF => { + DECIMAL_256_TYPE_VARIATION_REF => { Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) } v => not_impl_err!( @@ -1397,29 +1400,29 @@ fn from_substrait_literal( let scalar_value = match &lit.literal_type { Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), Some(LiteralType::I8(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_REF => ScalarValue::Int8(Some(*n as i8)), - UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt8(Some(*n as u8)), + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int8(Some(*n as i8)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt8(Some(*n as u8)), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I16(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_REF => ScalarValue::Int16(Some(*n as i16)), - UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt16(Some(*n as u16)), + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int16(Some(*n as i16)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt16(Some(*n as u16)), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I32(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)), - UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(*n as u32)), + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int32(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt32(Some(*n as u32)), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I64(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)), - UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(*n as u64)), + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int64(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt64(Some(*n as u64)), others => { return substrait_err!("Unknown type variation reference {others}"); } @@ -1427,25 +1430,35 @@ fn from_substrait_literal( Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), Some(LiteralType::Timestamp(t)) => match lit.type_variation_reference { - TIMESTAMP_SECOND_TYPE_REF => ScalarValue::TimestampSecond(Some(*t), None), - TIMESTAMP_MILLI_TYPE_REF => ScalarValue::TimestampMillisecond(Some(*t), None), - TIMESTAMP_MICRO_TYPE_REF => ScalarValue::TimestampMicrosecond(Some(*t), None), - TIMESTAMP_NANO_TYPE_REF => ScalarValue::TimestampNanosecond(Some(*t), None), + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + ScalarValue::TimestampSecond(Some(*t), None) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + ScalarValue::TimestampMillisecond(Some(*t), None) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF 
=> { + ScalarValue::TimestampMicrosecond(Some(*t), None) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + ScalarValue::TimestampNanosecond(Some(*t), None) + } others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), Some(LiteralType::String(s)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Utf8(Some(s.clone())), - LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeUtf8(Some(s.clone())), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Binary(b)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Binary(Some(b.clone())), - LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeBinary(Some(b.clone())), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => { + ScalarValue::LargeBinary(Some(b.clone())) + } others => { return substrait_err!("Unknown type variation reference {others}"); } @@ -1484,11 +1497,10 @@ fn from_substrait_literal( } let element_type = elements[0].data_type(); match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => ScalarValue::List(ScalarValue::new_list( - elements.as_slice(), - &element_type, - )), - LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList( + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( + ScalarValue::new_list(elements.as_slice(), &element_type), + ), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( ScalarValue::new_large_list(elements.as_slice(), &element_type), ), others => { @@ -1503,10 +1515,10 @@ fn from_substrait_literal( name_idx, )?; match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => { ScalarValue::List(ScalarValue::new_list(&[], &element_type)) } - LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList( + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( ScalarValue::new_large_list(&[], &element_type), ), others => { @@ -1590,29 +1602,29 @@ fn from_substrait_null( match kind { r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)), r#type::Kind::I8(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(ScalarValue::Int8(None)), - UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt8(None)), + DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int8(None)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt8(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, r#type::Kind::I16(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(ScalarValue::Int16(None)), - UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt16(None)), + DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int16(None)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt16(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, r#type::Kind::I32(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(ScalarValue::Int32(None)), - UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt32(None)), + DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int32(None)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt32(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, 
r#type::Kind::I64(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(ScalarValue::Int64(None)), - UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt64(None)), + DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int64(None)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt64(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), @@ -1620,14 +1632,16 @@ fn from_substrait_null( r#type::Kind::Fp32(_) => Ok(ScalarValue::Float32(None)), r#type::Kind::Fp64(_) => Ok(ScalarValue::Float64(None)), r#type::Kind::Timestamp(ts) => match ts.type_variation_reference { - TIMESTAMP_SECOND_TYPE_REF => Ok(ScalarValue::TimestampSecond(None, None)), - TIMESTAMP_MILLI_TYPE_REF => { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + Ok(ScalarValue::TimestampSecond(None, None)) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { Ok(ScalarValue::TimestampMillisecond(None, None)) } - TIMESTAMP_MICRO_TYPE_REF => { + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { Ok(ScalarValue::TimestampMicrosecond(None, None)) } - TIMESTAMP_NANO_TYPE_REF => { + TIMESTAMP_NANO_TYPE_VARIATION_REF => { Ok(ScalarValue::TimestampNanosecond(None, None)) } v => not_impl_err!( @@ -1635,23 +1649,23 @@ fn from_substrait_null( ), }, r#type::Kind::Date(date) => match date.type_variation_reference { - DATE_32_TYPE_REF => Ok(ScalarValue::Date32(None)), - DATE_64_TYPE_REF => Ok(ScalarValue::Date64(None)), + DATE_32_TYPE_VARIATION_REF => Ok(ScalarValue::Date32(None)), + DATE_64_TYPE_VARIATION_REF => Ok(ScalarValue::Date64(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, r#type::Kind::Binary(binary) => match binary.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Binary(None)), - LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeBinary(None)), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Binary(None)), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeBinary(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, // FixedBinary is not supported because `None` doesn't have length r#type::Kind::String(string) => match string.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Utf8(None)), - LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeUtf8(None)), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Utf8(None)), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeUtf8(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), @@ -1671,12 +1685,12 @@ fn from_substrait_null( true, ); match l.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::List(Arc::new( - GenericListArray::new_null(field.into(), 1), - ))), - LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeList(Arc::new( - GenericListArray::new_null(field.into(), 1), - ))), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::List( + Arc::new(GenericListArray::new_null(field.into(), 1)), + )), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeList( + Arc::new(GenericListArray::new_null(field.into(), 1)), + )), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 010386bf97ce..0208b010c856 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -91,13 +91,15 @@ use substrait::{ }; use 
crate::variation_const::{ - DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF, - DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF, - INTERVAL_DAY_TIME_TYPE_URL, INTERVAL_MONTH_DAY_NANO_TYPE_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_URL, INTERVAL_YEAR_MONTH_TYPE_REF, - INTERVAL_YEAR_MONTH_TYPE_URL, LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF, - TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF, - UNSIGNED_INTEGER_TYPE_REF, + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_DAY_TIME_TYPE_URL, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_URL, + INTERVAL_YEAR_MONTH_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_URL, + LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, + TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, + TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; /// Convert DataFusion LogicalPlan to Substrait Plan @@ -626,7 +628,7 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { .iter() .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) .collect::>()?, - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Unspecified as i32, }; @@ -1430,78 +1432,78 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result internal_err!("Null cast is not valid"), DataType::Boolean => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Bool(r#type::Boolean { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, })), }), DataType::Int8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, })), }), DataType::UInt8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, nullability, })), }), DataType::Int16 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, })), }), DataType::UInt16 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, nullability, })), }), DataType::Int32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, })), }), DataType::UInt32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, nullability, })), }), DataType::Int64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, })), }), DataType::UInt64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, + 
type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, nullability, })), }), // Float16 is not supported in Substrait DataType::Float32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Fp32(r#type::Fp32 { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, })), }), DataType::Float64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Fp64(r#type::Fp64 { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, })), }), // Timezone is ignored. DataType::Timestamp(unit, _) => { let type_variation_reference = match unit { - TimeUnit::Second => TIMESTAMP_SECOND_TYPE_REF, - TimeUnit::Millisecond => TIMESTAMP_MILLI_TYPE_REF, - TimeUnit::Microsecond => TIMESTAMP_MICRO_TYPE_REF, - TimeUnit::Nanosecond => TIMESTAMP_NANO_TYPE_REF, + TimeUnit::Second => TIMESTAMP_SECOND_TYPE_VARIATION_REF, + TimeUnit::Millisecond => TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TimeUnit::Microsecond => TIMESTAMP_MICRO_TYPE_VARIATION_REF, + TimeUnit::Nanosecond => TIMESTAMP_NANO_TYPE_VARIATION_REF, }; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { @@ -1512,13 +1514,13 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Ok(substrait::proto::Type { kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_32_TYPE_REF, + type_variation_reference: DATE_32_TYPE_VARIATION_REF, nullability, })), }), DataType::Date64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_64_TYPE_REF, + type_variation_reference: DATE_64_TYPE_VARIATION_REF, nullability, })), }), @@ -1527,7 +1529,7 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Result Result Result Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, nullability, })), }), DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { length: *length, - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, })), }), DataType::LargeBinary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, nullability, })), }), DataType::Utf8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, nullability, })), }), DataType::LargeUtf8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, nullability, })), }), @@ -1601,7 +1603,7 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Result Result Ok(substrait::proto::Type { kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: DECIMAL_128_TYPE_REF, + type_variation_reference: DECIMAL_128_TYPE_VARIATION_REF, nullability, scale: *s as i32, precision: *p as i32, @@ -1639,7 +1641,7 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Ok(substrait::proto::Type { kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: 
DECIMAL_256_TYPE_REF, + type_variation_reference: DECIMAL_256_TYPE_VARIATION_REF, nullability, scale: *s as i32, precision: *p as i32, @@ -1861,7 +1863,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { if value.is_null() { return Ok(Literal { nullable: true, - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, literal_type: Some(LiteralType::Null(to_substrait_type( &value.data_type(), true, @@ -1869,38 +1871,58 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }); } let (literal_type, type_variation_reference) = match value { - ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF), - ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF), - ScalarValue::UInt8(Some(n)) => { - (LiteralType::I8(*n as i32), UNSIGNED_INTEGER_TYPE_REF) + ScalarValue::Boolean(Some(b)) => { + (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::Int16(Some(n)) => (LiteralType::I16(*n as i32), DEFAULT_TYPE_REF), - ScalarValue::UInt16(Some(n)) => { - (LiteralType::I16(*n as i32), UNSIGNED_INTEGER_TYPE_REF) + ScalarValue::Int8(Some(n)) => { + (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_REF), - ScalarValue::UInt32(Some(n)) => { - (LiteralType::I32(*n as i32), UNSIGNED_INTEGER_TYPE_REF) - } - ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_REF), - ScalarValue::UInt64(Some(n)) => { - (LiteralType::I64(*n as i64), UNSIGNED_INTEGER_TYPE_REF) - } - ScalarValue::Float32(Some(f)) => (LiteralType::Fp32(*f), DEFAULT_TYPE_REF), - ScalarValue::Float64(Some(f)) => (LiteralType::Fp64(*f), DEFAULT_TYPE_REF), - ScalarValue::TimestampSecond(Some(t), _) => { - (LiteralType::Timestamp(*t), TIMESTAMP_SECOND_TYPE_REF) + ScalarValue::UInt8(Some(n)) => ( + LiteralType::I8(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int16(Some(n)) => { + (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::TimestampMillisecond(Some(t), _) => { - (LiteralType::Timestamp(*t), TIMESTAMP_MILLI_TYPE_REF) + ScalarValue::UInt16(Some(n)) => ( + LiteralType::I16(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt32(Some(n)) => ( + LiteralType::I32(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt64(Some(n)) => ( + LiteralType::I64(*n as i64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Float32(Some(f)) => { + (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::TimestampMicrosecond(Some(t), _) => { - (LiteralType::Timestamp(*t), TIMESTAMP_MICRO_TYPE_REF) + ScalarValue::Float64(Some(f)) => { + (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::TimestampNanosecond(Some(t), _) => { - (LiteralType::Timestamp(*t), TIMESTAMP_NANO_TYPE_REF) + ScalarValue::TimestampSecond(Some(t), _) => ( + LiteralType::Timestamp(*t), + TIMESTAMP_SECOND_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), _) => ( + LiteralType::Timestamp(*t), + TIMESTAMP_MILLI_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), _) => ( + LiteralType::Timestamp(*t), + TIMESTAMP_MICRO_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), _) => ( + LiteralType::Timestamp(*t), + TIMESTAMP_NANO_TYPE_VARIATION_REF, + 
), + ScalarValue::Date32(Some(d)) => { + (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) } - ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_REF), // Date64 literal is not supported in Substrait ScalarValue::IntervalYearMonth(Some(i)) => { let bytes = i.to_le_bytes(); @@ -1911,7 +1933,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { parameter: Some(parameter::Parameter::DataType( substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Required as i32, })), }, @@ -1931,7 +1953,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { let i64_param = Parameter { parameter: Some(parameter::Parameter::DataType(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Required as i32, })), })), @@ -1957,7 +1979,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { parameter: Some(parameter::Parameter::DataType( substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Required as i32, })), }, @@ -1971,36 +1993,42 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { INTERVAL_DAY_TIME_TYPE_REF, ) } - ScalarValue::Binary(Some(b)) => { - (LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF) - } - ScalarValue::LargeBinary(Some(b)) => { - (LiteralType::Binary(b.clone()), LARGE_CONTAINER_TYPE_REF) - } - ScalarValue::FixedSizeBinary(_, Some(b)) => { - (LiteralType::FixedBinary(b.clone()), DEFAULT_TYPE_REF) - } - ScalarValue::Utf8(Some(s)) => { - (LiteralType::String(s.clone()), DEFAULT_CONTAINER_TYPE_REF) - } - ScalarValue::LargeUtf8(Some(s)) => { - (LiteralType::String(s.clone()), LARGE_CONTAINER_TYPE_REF) - } + ScalarValue::Binary(Some(b)) => ( + LiteralType::Binary(b.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeBinary(Some(b)) => ( + LiteralType::Binary(b.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::FixedSizeBinary(_, Some(b)) => ( + LiteralType::FixedBinary(b.clone()), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Utf8(Some(s)) => ( + LiteralType::String(s.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeUtf8(Some(s)) => ( + LiteralType::String(s.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), ScalarValue::Decimal128(v, p, s) if v.is_some() => ( LiteralType::Decimal(Decimal { value: v.unwrap().to_le_bytes().to_vec(), precision: *p as i32, scale: *s as i32, }), - DECIMAL_128_TYPE_REF, + DECIMAL_128_TYPE_VARIATION_REF, ), ScalarValue::List(l) => ( convert_array_to_literal_list(l)?, - DEFAULT_CONTAINER_TYPE_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeList(l) => ( + convert_array_to_literal_list(l)?, + LARGE_CONTAINER_TYPE_VARIATION_REF, ), - ScalarValue::LargeList(l) => { - (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF) - } ScalarValue::Struct(s) => ( LiteralType::Struct(Struct { fields: s @@ -2011,11 +2039,11 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }) .collect::>>()?, }), - DEFAULT_TYPE_REF, + DEFAULT_TYPE_VARIATION_REF, ), _ => ( not_impl_err!("Unsupported literal: {value:?}")?, - DEFAULT_TYPE_REF, + DEFAULT_TYPE_VARIATION_REF, ), }; diff --git 
a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index 51c0d3b0211e..27f4b3ea228a 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -18,28 +18,35 @@ //! Type variation constants //! //! To add support for types not in the [core specification](https://substrait.io/types/type_classes/), -//! we make use of the [simple extensions](https://substrait.io/extensions/#simple-extensions) of substrait -//! type. This module contains the constants used to identify the type variation. +//! we make use of the [simple extensions] of substrait type. This module contains the constants used +//! to identify the type variation. //! //! The rules of type variations here are: //! - Default type reference is 0. It is used when the actual type is the same with the original type. //! - Extended variant type references start from 1, and ususlly increase by 1. +//! +//! Definitions here are not the final form. All the non-system-preferred variations will be defined +//! using [simple extensions] as per the [spec of type_variations](https://substrait.io/types/type_variations/) +//! +//! [simple extensions]: (https://substrait.io/extensions/#simple-extensions) -// For type variations -pub const DEFAULT_TYPE_REF: u32 = 0; -pub const UNSIGNED_INTEGER_TYPE_REF: u32 = 1; -pub const TIMESTAMP_SECOND_TYPE_REF: u32 = 0; -pub const TIMESTAMP_MILLI_TYPE_REF: u32 = 1; -pub const TIMESTAMP_MICRO_TYPE_REF: u32 = 2; -pub const TIMESTAMP_NANO_TYPE_REF: u32 = 3; -pub const DATE_32_TYPE_REF: u32 = 0; -pub const DATE_64_TYPE_REF: u32 = 1; -pub const DEFAULT_CONTAINER_TYPE_REF: u32 = 0; -pub const LARGE_CONTAINER_TYPE_REF: u32 = 1; -pub const DECIMAL_128_TYPE_REF: u32 = 0; -pub const DECIMAL_256_TYPE_REF: u32 = 1; +// For [type variations](https://substrait.io/types/type_variations/#type-variations) in substrait. +// Type variations are used to represent different types based on one type class. +/// The "system-preferred" variation (i.e., no variation). +pub const DEFAULT_TYPE_VARIATION_REF: u32 = 0; +pub const UNSIGNED_INTEGER_TYPE_VARIATION_REF: u32 = 1; +pub const TIMESTAMP_SECOND_TYPE_VARIATION_REF: u32 = 0; +pub const TIMESTAMP_MILLI_TYPE_VARIATION_REF: u32 = 1; +pub const TIMESTAMP_MICRO_TYPE_VARIATION_REF: u32 = 2; +pub const TIMESTAMP_NANO_TYPE_VARIATION_REF: u32 = 3; +pub const DATE_32_TYPE_VARIATION_REF: u32 = 0; +pub const DATE_64_TYPE_VARIATION_REF: u32 = 1; +pub const DEFAULT_CONTAINER_TYPE_VARIATION_REF: u32 = 0; +pub const LARGE_CONTAINER_TYPE_VARIATION_REF: u32 = 1; +pub const DECIMAL_128_TYPE_VARIATION_REF: u32 = 0; +pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1; -// For custom types +// For [user-defined types](https://substrait.io/types/type_classes/#user-defined-types). /// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`]. /// /// An `i32` for elapsed whole months. See also [`ScalarValue::IntervalYearMonth`] From 075ed331bc7721159b3d9cfd1a05cb6a7dba1a1e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 31 May 2024 06:16:13 -0400 Subject: [PATCH 11/35] Update rstest requirement from 0.19.0 to 0.20.0 (#10734) Updates the requirements on [rstest](https://github.com/la10736/rstest) to permit the latest version. 
- [Release notes](https://github.com/la10736/rstest/releases) - [Changelog](https://github.com/la10736/rstest/blob/master/CHANGELOG.md) - [Commits](https://github.com/la10736/rstest/compare/v0.19.0...v0.19.0) --- updated-dependencies: - dependency-name: rstest dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 6174840e745a..45504be3f1ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -109,7 +109,7 @@ parking_lot = "0.12" parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" regex = "1.8" -rstest = "0.19.0" +rstest = "0.20.0" serde_json = "1" sqlparser = { version = "0.45.0", features = ["visitor"] } tempfile = "3" From 02a76daa72ec4c98ed63bd4eb4523e011cc7e610 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 31 May 2024 06:56:39 -0400 Subject: [PATCH 12/35] Update rstest_reuse requirement from 0.6.0 to 0.7.0 (#10733) * Update rstest_reuse requirement from 0.6.0 to 0.7.0 Updates the requirements on [rstest_reuse](https://github.com/la10736/rstest) to permit the latest version. - [Release notes](https://github.com/la10736/rstest/releases) - [Changelog](https://github.com/la10736/rstest/blob/master/CHANGELOG.md) - [Commits](https://github.com/la10736/rstest/compare/0.6.0...0.7.0) --- updated-dependencies: - dependency-name: rstest_reuse dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Remove unused reuse --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb --- datafusion/physical-plan/Cargo.toml | 2 +- datafusion/physical-plan/src/lib.rs | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index dac2a24d359d..4292f95fe406 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -65,7 +65,7 @@ tokio = { workspace = true } [dev-dependencies] rstest = { workspace = true } -rstest_reuse = "0.6.0" +rstest_reuse = "0.7.0" termtree = "0.4.1" tokio = { workspace = true, features = [ "rt-multi-thread", diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 739bff2cfa23..bd77814bbbe4 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -798,10 +798,6 @@ pub fn get_plan_string(plan: &Arc) -> Vec { actual.iter().map(|elem| elem.to_string()).collect() } -#[cfg(test)] -#[allow(clippy::single_component_path_imports)] -use rstest_reuse; - #[cfg(test)] mod tests { use std::any::Any; From 09dde27be39ad054f85dfb5c37b7468a3f68d652 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 31 May 2024 06:56:51 -0400 Subject: [PATCH 13/35] Add example for building an external secondary index for parquet files (#10549) * Add example for building an external index for parquet filtes * Use register_object_store api * use FileScanConfig API * Udpate to use new API * Collapose `use` statements * fix typo --- datafusion-examples/README.md | 1 + datafusion-examples/examples/parquet_index.rs | 705 ++++++++++++++++++ 2 files changed, 706 insertions(+) create mode 100644 datafusion-examples/examples/parquet_index.rs diff --git a/datafusion-examples/README.md 
b/datafusion-examples/README.md
index 778950cbf926..a5395ea7aab3 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -60,6 +60,7 @@ cargo run --example csv_sql
 - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros
 - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function
 - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es
+- [`parquet_index.rs`](examples/parquet_index.rs): Create a secondary index over several parquet files and use it to speed up queries
 - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file
 - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files
 - ['parquet_exec_visitor.rs'](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution
diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/parquet_index.rs
new file mode 100644
index 000000000000..625133ae7cbd
--- /dev/null
+++ b/datafusion-examples/examples/parquet_index.rs
@@ -0,0 +1,705 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, Int32Array, RecordBatch, StringArray, + UInt64Array, +}; +use arrow::datatypes::Int32Type; +use arrow::util::pretty::pretty_format_batches; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::{ + parquet::{RequestedStatistics, StatisticsConverter}, + {FileScanConfig, ParquetExec}, +}; +use datafusion::datasource::TableProvider; +use datafusion::execution::context::SessionState; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::{ + arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter, +}; +use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::*; +use datafusion_common::{ + internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::{utils::conjunction, TableProviderFilterPushDown, TableType}; +use datafusion_physical_expr::PhysicalExpr; +use std::any::Any; +use std::collections::HashSet; +use std::fmt::Display; +use std::fs::{self, DirEntry, File}; +use std::ops::Range; +use std::path::{Path, PathBuf}; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use tempfile::TempDir; +use url::Url; + +/// This example demonstrates building a secondary index over multiple Parquet +/// files and using that index during query to skip ("prune") files that do not +/// contain relevant data. +/// +/// This example rules out relevant data using min/max values of a column +/// extracted from the Parquet metadata. In a real system, the index could be +/// more sophisticated, e.g. using inverted indices, bloom filters or other +/// techniques. +/// +/// Note this is a low level example for people who want to build their own +/// custom indexes. To read a directory of parquet files as a table, you can use +/// a higher level API such as [`SessionContext::read_parquet`] or +/// [`ListingTable`], which also do file pruning based on parquet statistics +/// (using the same underlying APIs) +/// +/// For a more advanced example of using an index to prune row groups within a +/// file, see the (forthcoming) `advanced_parquet_index` example. +/// +/// # Diagram +/// +/// ```text +/// ┏━━━━━━━━━━━━━━━━━━━━━━━━┓ +/// ┃ Index ┃ +/// ┃ ┃ +/// step 1: predicate is ┌ ─ ─ ─ ─▶┃ (sometimes referred to ┃ +/// evaluated against ┃ as a "catalog" or ┃ +/// data in the index │ ┃ "metastore") ┃ +/// (using ┗━━━━━━━━━━━━━━━━━━━━━━━━┛ +/// PruningPredicate) │ │ +/// +/// │ │ +/// ┌──────────────┐ +/// │ value = 150 │─ ─ ─ ─ ┘ │ +/// └──────────────┘ ┌─────────────┐ +/// Predicate from query │ │ │ +/// └─────────────┘ +/// │ ┌─────────────┐ +/// step 2: Index returns only ─ ▶│ │ +/// parquet files that might have └─────────────┘ +/// matching data. ... +/// ┌─────────────┐ +/// Thus some parquet files are │ │ +/// "pruned" and thus are not └─────────────┘ +/// scanned at all Parquet Files +/// +/// ``` +/// +/// [`ListingTable`]: datafusion::datasource::listing::ListingTable +#[tokio::main] +async fn main() -> Result<()> { + // Demo data has three files, each with schema + // * file_name (string) + // * value (int32) + // + // The files are as follows: + // * file1.parquet (value: 0..100) + // * file2.parquet (value: 100..200) + // * file3.parquet (value: 200..3000) + let data = DemoData::try_new()?; + + // Create a table provider with and our special index. 
+ let provider = Arc::new(IndexTableProvider::try_new(data.path())?); + println!("** Table Provider:"); + println!("{provider}\n"); + + // Create a SessionContext for running queries that has the table provider + // registered as "index_table" + let ctx = SessionContext::new(); + ctx.register_table("index_table", Arc::clone(&provider) as _)?; + + // register object store provider for urls like `file://` work + let url = Url::try_from("file://").unwrap(); + let object_store = object_store::local::LocalFileSystem::new(); + ctx.register_object_store(&url, Arc::new(object_store)); + + // Select data from the table without any predicates (and thus no pruning) + println!("** Select data, no predicates:"); + ctx.sql("SELECT file_name, value FROM index_table LIMIT 10") + .await? + .show() + .await?; + println!("Files pruned: {}\n", provider.index().last_num_pruned()); + + // Run a query that uses the index to prune files. + // + // Using the predicate "value = 150", the IndexTable can skip reading file 1 + // (max value 100) and file 3 (min value of 200) + println!("** Select data, predicate `value = 150`"); + ctx.sql("SELECT file_name, value FROM index_table WHERE value = 150") + .await? + .show() + .await?; + println!("Files pruned: {}\n", provider.index().last_num_pruned()); + + // likewise, we can use a more complicated predicate like + // "value < 20 OR value > 500" to read only file 1 and file 3 + println!("** Select data, predicate `value < 20 OR value > 500`"); + ctx.sql( + "SELECT file_name, count(value) FROM index_table \ + WHERE value < 20 OR value > 500 GROUP BY file_name", + ) + .await? + .show() + .await?; + println!("Files pruned: {}\n", provider.index().last_num_pruned()); + + Ok(()) +} + +/// DataFusion `TableProvider` that uses [`IndexTableProvider`], a secondary +/// index to decide which Parquet files to read. +#[derive(Debug)] +pub struct IndexTableProvider { + /// The index of the parquet files in the directory + index: ParquetMetadataIndex, + /// the directory in which the files are stored + dir: PathBuf, +} + +impl Display for IndexTableProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "IndexTableProvider")?; + writeln!(f, "---- Index ----")?; + write!(f, "{}", self.index) + } +} + +impl IndexTableProvider { + /// Create a new IndexTableProvider + pub fn try_new(dir: impl Into) -> Result { + let dir = dir.into(); + + // Create an index of the parquet files in the directory as we see them. + let mut index_builder = ParquetMetadataIndexBuilder::new(); + + let files = read_dir(&dir)?; + for file in &files { + index_builder.add_file(&file.path())?; + } + + let index = index_builder.build()?; + + Ok(Self { index, dir }) + } + + /// return a reference to the underlying index + fn index(&self) -> &ParquetMetadataIndex { + &self.index + } +} + +#[async_trait] +impl TableProvider for IndexTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.index.schema().clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &SessionState, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let df_schema = DFSchema::try_from(self.schema())?; + // convert filters like [`a = 1`, `b = 2`] to a single filter like `a = 1 AND b = 2` + let predicate = conjunction(filters.to_vec()); + let predicate = predicate + .map(|predicate| state.create_physical_expr(predicate, &df_schema)) + .transpose()? 
+ // if there are no filters, use a literal true to have a predicate + // that always evaluates to true we can pass to the index + .unwrap_or_else(|| datafusion_physical_expr::expressions::lit(true)); + + // Use the index to find the files that might have data that matches the + // predicate. Any file that can not have data that matches the predicate + // will not be returned. + let files = self.index.get_files(predicate.clone())?; + + let object_store_url = ObjectStoreUrl::parse("file://")?; + let mut file_scan_config = FileScanConfig::new(object_store_url, self.schema()) + .with_projection(projection.cloned()) + .with_limit(limit); + + // Transform to the format needed to pass to ParquetExec + // Create one file group per file (default to scanning them all in parallel) + for (file_name, file_size) in files { + let path = self.dir.join(file_name); + let canonical_path = fs::canonicalize(path)?; + file_scan_config = file_scan_config.with_file(PartitionedFile::new( + canonical_path.display().to_string(), + file_size, + )); + } + let exec = ParquetExec::builder(file_scan_config) + .with_predicate(predicate) + .build_arc(); + + Ok(exec) + } + + /// Tell DataFusion to push filters down to the scan method + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + // Inexact because the pruning can't handle all expressions and pruning + // is not done at the row level -- there may be rows in returned files + // that do not pass the filter + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } +} + +/// Simple in memory secondary index for a set of parquet files +/// +/// The index is represented as an arrow [`RecordBatch`] that can be passed +/// directly by the DataFusion [`PruningPredicate`] API +/// +/// The `RecordBatch` looks as follows. +/// +/// ```text +/// +---------------+-----------+-----------+------------------+------------------+ +/// | file_name | file_size | row_count | value_column_min | value_column_max | +/// +---------------+-----------+-----------+------------------+------------------+ +/// | file1.parquet | 6062 | 100 | 0 | 99 | +/// | file2.parquet | 6062 | 100 | 100 | 199 | +/// | file3.parquet | 163310 | 2800 | 200 | 2999 | +/// +---------------+-----------+-----------+------------------+------------------+ +/// ``` +/// +/// It must store file_name and file_size to construct `PartitionedFile`. +/// +/// Note a more advanced index might store finer grained information, such as information +/// about each row group within a file +#[derive(Debug)] +struct ParquetMetadataIndex { + file_schema: SchemaRef, + /// The index of the parquet files. See the struct level documentation for + /// the schema of this index. 
+ index: RecordBatch, + /// The number of files that were pruned in the last query + last_num_pruned: AtomicUsize, +} + +impl Display for ParquetMetadataIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "ParquetMetadataIndex(last_num_pruned: {})", + self.last_num_pruned() + )?; + let batches = pretty_format_batches(&[self.index.clone()]).unwrap(); + write!(f, "{batches}",) + } +} + +impl ParquetMetadataIndex { + /// the schema of the *files* in the index (not the index's schema) + fn schema(&self) -> &SchemaRef { + &self.file_schema + } + + /// number of files in the index + fn len(&self) -> usize { + self.index.num_rows() + } + + /// Return a [`PartitionedFile`] for the specified file offset + /// + /// For example, if the index batch contained data like + /// + /// ```text + /// fileA + /// fileB + /// fileC + /// ``` + /// + /// `get_file(1)` would return `(fileB, size)` + fn get_file(&self, file_offset: usize) -> (&str, u64) { + // Filenames and sizes are always non null, so we don't have to check is_valid + let file_name = self.file_names().value(file_offset); + let file_size = self.file_size().value(file_offset); + (file_name, file_size) + } + + /// Return the number of files that were pruned in the last query + pub fn last_num_pruned(&self) -> usize { + self.last_num_pruned.load(Ordering::SeqCst) + } + + /// Set the number of files that were pruned in the last query + fn set_last_num_pruned(&self, num_pruned: usize) { + self.last_num_pruned.store(num_pruned, Ordering::SeqCst); + } + + /// Return all the files matching the predicate + /// + /// Returns a tuple `(file_name, file_size)` + pub fn get_files( + &self, + predicate: Arc, + ) -> Result> { + // Use the PruningPredicate API to determine which files can not + // possibly have any relevant data. + let pruning_predicate = + PruningPredicate::try_new(predicate, self.schema().clone())?; + + // Now evaluate the pruning predicate into a boolean mask, one element per + // file in the index. If the mask is true, the file may have rows that + // match the predicate. If the mask is false, we know the file can not have *any* + // rows that match the predicate and thus can be skipped. 
+ let file_mask = pruning_predicate.prune(self)?; + + let num_left = file_mask.iter().filter(|x| **x).count(); + self.set_last_num_pruned(self.len() - num_left); + + // Return only files that match the predicate from the index + let files_and_sizes: Vec<_> = file_mask + .into_iter() + .enumerate() + .filter_map(|(file, keep)| { + if keep { + Some(self.get_file(file)) + } else { + None + } + }) + .collect(); + Ok(files_and_sizes) + } + + /// Return the file_names column of this index + fn file_names(&self) -> &StringArray { + self.index + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + } + + /// Return the file_size column of this index + fn file_size(&self) -> &UInt64Array { + self.index + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + } + + /// Reference to the row count column + fn row_counts_ref(&self) -> &ArrayRef { + self.index.column(2) + } + + /// Reference to the column minimum values + fn value_column_mins(&self) -> &ArrayRef { + self.index.column(3) + } + + /// Reference to the column maximum values + fn value_column_maxes(&self) -> &ArrayRef { + self.index.column(4) + } +} + +/// In order to use the PruningPredicate API, we need to provide DataFusion +/// the required statistics via the [`PruningStatistics`] trait +impl PruningStatistics for ParquetMetadataIndex { + /// return the minimum values for the value column + fn min_values(&self, column: &Column) -> Option { + if column.name.eq("value") { + Some(self.value_column_mins().clone()) + } else { + None + } + } + + /// return the maximum values for the value column + fn max_values(&self, column: &Column) -> Option { + if column.name.eq("value") { + Some(self.value_column_maxes().clone()) + } else { + None + } + } + + /// return the number of "containers". In this example, each "container" is + /// a file (aka a row in the index) + fn num_containers(&self) -> usize { + self.len() + } + + /// Return `None` to signal we don't have any information about null + /// counts in the index, + fn null_counts(&self, _column: &Column) -> Option { + None + } + + /// return the row counts for each file + fn row_counts(&self, _column: &Column) -> Option { + Some(self.row_counts_ref().clone()) + } + + /// The `contained` API can be used with structures such as Bloom filters, + /// but is not used in this example, so return `None` + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } +} + +/// Builds a [`ParquetMetadataIndex`] from a set of parquet files +#[derive(Debug, Default)] +struct ParquetMetadataIndexBuilder { + file_schema: Option, + filenames: Vec, + file_sizes: Vec, + row_counts: Vec, + /// Holds the min/max value of the value column for each file + value_column_mins: Vec, + value_column_maxs: Vec, +} + +impl ParquetMetadataIndexBuilder { + fn new() -> Self { + Self::default() + } + + /// Add a new file to the index + fn add_file(&mut self, file: &Path) -> Result<()> { + let file_name = file + .file_name() + .ok_or_else(|| internal_datafusion_err!("No filename"))? + .to_str() + .ok_or_else(|| internal_datafusion_err!("Invalid filename"))?; + let file_size = file.metadata()?.len(); + + let file = File::open(file).map_err(|e| { + DataFusionError::from(e).context(format!("Error opening file {file:?}")) + })?; + + let reader = ParquetRecordBatchReaderBuilder::try_new(file)?; + + // Get the schema of the file. A real system might have to handle the + // case where the schema of the file is not the same as the schema of + // the other files e.g. using SchemaAdapter. 
+ if self.file_schema.is_none() { + self.file_schema = Some(reader.schema().clone()); + } + + // extract the parquet statistics from the file's footer + let metadata = reader.metadata(); + + // Extract the min/max values for each row group from the statistics + let row_counts = StatisticsConverter::row_counts(reader.metadata())?; + let value_column_mins = StatisticsConverter::try_new( + "value", + RequestedStatistics::Min, + reader.schema(), + )? + .extract(reader.metadata())?; + let value_column_maxes = StatisticsConverter::try_new( + "value", + RequestedStatistics::Max, + reader.schema(), + )? + .extract(reader.metadata())?; + + // In a real system you would have to handle nulls, which represent + // unknown statistics. All statistics are known in this example + assert_eq!(row_counts.null_count(), 0); + assert_eq!(value_column_mins.null_count(), 0); + assert_eq!(value_column_maxes.null_count(), 0); + + // The statistics gathered above are for each row group. We need to + // aggregate them together to compute the overall file row count, + // min and max. + let row_count = row_counts + .iter() + .flatten() // skip nulls (should be none) + .sum::(); + let value_column_min = value_column_mins + .as_primitive::() + .iter() + .flatten() // skip nulls (i.e. min is unknown) + .min() + .unwrap_or_default(); + let value_column_max = value_column_maxes + .as_primitive::() + .iter() + .flatten() // skip nulls (i.e. max is unknown) + .max() + .unwrap_or_default(); + + // sanity check the statistics + assert_eq!(row_count, metadata.file_metadata().num_rows() as u64); + + self.add_row( + file_name, + file_size, + row_count, + value_column_min, + value_column_max, + ); + Ok(()) + } + + /// Add an entry for a single new file to the in progress index + fn add_row( + &mut self, + file_name: impl Into, + file_size: u64, + row_count: u64, + value_column_min: i32, + value_column_max: i32, + ) { + self.filenames.push(file_name.into()); + self.file_sizes.push(file_size); + self.row_counts.push(row_count); + self.value_column_mins.push(value_column_min); + self.value_column_maxs.push(value_column_max); + } + + /// Build the index from the files added + fn build(self) -> Result { + let Some(file_schema) = self.file_schema else { + return Err(internal_datafusion_err!("No files added to index")); + }; + + let file_name: ArrayRef = Arc::new(StringArray::from(self.filenames)); + let file_size: ArrayRef = Arc::new(UInt64Array::from(self.file_sizes)); + let row_count: ArrayRef = Arc::new(UInt64Array::from(self.row_counts)); + let value_column_min: ArrayRef = + Arc::new(Int32Array::from(self.value_column_mins)); + let value_column_max: ArrayRef = + Arc::new(Int32Array::from(self.value_column_maxs)); + + let index = RecordBatch::try_from_iter(vec![ + ("file_name", file_name), + ("file_size", file_size), + ("row_count", row_count), + ("value_column_min", value_column_min), + ("value_column_max", value_column_max), + ])?; + + Ok(ParquetMetadataIndex { + file_schema, + index, + last_num_pruned: AtomicUsize::new(0), + }) + } +} + +/// Return a list of the directory entries in the given directory, sorted by name +fn read_dir(dir: &Path) -> Result> { + let mut files = dir + .read_dir() + .map_err(|e| { + DataFusionError::from(e).context(format!("Error reading directory {dir:?}")) + })? 
+ .map(|entry| { + entry.map_err(|e| { + DataFusionError::from(e) + .context(format!("Error reading directory entry in {dir:?}")) + }) + }) + .collect::>>()?; + files.sort_by_key(|entry| entry.file_name()); + Ok(files) +} + +/// Demonstration Data +/// +/// Makes a directory with three parquet files +/// +/// The schema of the files is +/// * file_name (string) +/// * value (int32) +/// +/// The files are as follows: +/// * file1.parquet (values 0..100) +/// * file2.parquet (values 100..200) +/// * file3.parquet (values 200..3000) +struct DemoData { + tmpdir: TempDir, +} + +impl DemoData { + fn try_new() -> Result { + let tmpdir = TempDir::new()?; + make_demo_file(tmpdir.path().join("file1.parquet"), 0..100)?; + make_demo_file(tmpdir.path().join("file2.parquet"), 100..200)?; + make_demo_file(tmpdir.path().join("file3.parquet"), 200..3000)?; + + Ok(Self { tmpdir }) + } + + fn path(&self) -> PathBuf { + self.tmpdir.path().into() + } +} + +/// Creates a new parquet file at the specified path. +/// +/// The `value` column increases sequentially from `min_value` to `max_value` +/// with the following schema: +/// +/// * file_name: Utf8 +/// * value: Int32 +fn make_demo_file(path: impl AsRef, value_range: Range) -> Result<()> { + let path = path.as_ref(); + let file = File::create(path)?; + let filename = path + .file_name() + .ok_or_else(|| internal_datafusion_err!("No filename"))? + .to_str() + .ok_or_else(|| internal_datafusion_err!("Invalid filename"))?; + + let num_values = value_range.len(); + let file_names = + StringArray::from_iter_values(std::iter::repeat(&filename).take(num_values)); + let values = Int32Array::from_iter_values(value_range); + let batch = RecordBatch::try_from_iter(vec![ + ("file_name", Arc::new(file_names) as ArrayRef), + ("value", Arc::new(values) as ArrayRef), + ])?; + + let schema = batch.schema(); + + // write the actual values to the file + let props = None; + let mut writer = ArrowWriter::try_new(file, schema, props)?; + writer.write(&batch)?; + writer.finish()?; + + Ok(()) +} From a803214015ff8705c6d0bc5e4beff507a4c5876c Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Fri, 31 May 2024 23:28:19 +0200 Subject: [PATCH 14/35] chore: move stddev test to slt (#10741) --- .../physical-expr/src/aggregate/stddev.rs | 127 ------------------ .../sqllogictest/test_files/aggregate.slt | 122 +++++++++++++++++ 2 files changed, 122 insertions(+), 127 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index e5ce1b9230db..ec8d8cea67c4 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -245,135 +245,8 @@ mod tests { use super::*; use crate::aggregate::utils::get_accum_scalar_values_as_arrays; use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; use arrow::{array::*, datatypes::*}; - #[test] - fn stddev_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64])); - generic_test_op!(a, DataType::Float64, StddevPop, ScalarValue::from(0.5_f64)) - } - - #[test] - fn stddev_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - generic_test_op!( - a, - DataType::Float64, - StddevPop, - ScalarValue::from(0.7760297817881877_f64) - ) - } - - #[test] - fn stddev_f64_3() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 
4_f64, 5_f64])); - generic_test_op!( - a, - DataType::Float64, - StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2) - ) - } - - #[test] - fn stddev_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - generic_test_op!( - a, - DataType::Float64, - Stddev, - ScalarValue::from(0.9504384952922168_f64) - ) - } - - #[test] - fn stddev_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2) - ) - } - - #[test] - fn stddev_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2) - ) - } - - #[test] - fn stddev_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!( - a, - DataType::Float32, - StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2) - ) - } - - #[test] - fn test_stddev_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - let agg = Arc::new(Stddev::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); - - Ok(()) - } - - #[test] - fn stddev_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!( - a, - DataType::Int32, - StddevPop, - ScalarValue::from(1.479019945774904_f64) - ) - } - - #[test] - fn stddev_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - let agg = Arc::new(Stddev::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); - Ok(()) - } - #[test] fn stddev_f64_merge_1() -> Result<()> { let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 256fddd9f254..03e8fad8a7f8 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2338,6 +2338,128 @@ select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; statement ok drop table t; +# aggregate stddev f64_1 +statement ok +create table t (c1 double) as values (1), (2); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +0.5 Float64 + +statement ok +drop table t; + +# aggregate stddev f64_2 +statement ok +create table t (c1 double) as values (1.1), (2), (3); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +0.776029781788 Float64 + +statement ok +drop table t; + +# aggregate stddev f64_3 +statement ok +create table t (c1 double) as values (1), (2), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.414213562373 Float64 + +statement ok +drop table t; + +# aggregate stddev f64_4 +statement 
ok +create table t (c1 double) as values (1.1), (2), (3); + +query RT +select stddev(c1), arrow_typeof(stddev(c1)) from t; +---- +0.950438495292 Float64 + +statement ok +drop table t; + +# aggregate stddev i32 +statement ok +create table t (c1 int) as values (1), (2), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.414213562373 Float64 + +statement ok +drop table t; + +# aggregate stddev u32 +statement ok +create table t (c1 int unsigned) as values (1), (2), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.414213562373 Float64 + +statement ok +drop table t; + +# aggregate stddev f32 +statement ok +create table t (c1 float) as values (1), (2), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.414213562373 Float64 + +statement ok +drop table t; + +# aggregate stddev single_input +statement ok +create table t (c1 double) as values (1); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +0 Float64 + +statement ok +drop table t; + +# aggregate stddev with_nulls +statement ok +create table t (c1 int) as values (1), (null), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.479019945775 Float64 + +statement ok +drop table t; + +# aggregate stddev all_nulls +statement ok +create table t (c1 int) as values (null), (null); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +NULL Float64 + +statement ok +drop table t; + + + # simple_mean query R select mean(c1) from test From 7fd286b1eac8c38477b48e9e3a155bd0d0f63094 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sat, 1 Jun 2024 05:30:16 +0800 Subject: [PATCH 15/35] fix(CLI): can not create external tables with format options (#10739) --- datafusion-cli/src/exec.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index ffe447e79fd7..855d6a7cbbc9 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -21,6 +21,7 @@ use std::collections::HashMap; use std::fs::File; use std::io::prelude::*; use std::io::BufReader; +use std::str::FromStr; use crate::helper::split_from_semicolon; use crate::print_format::PrintFormat; @@ -300,11 +301,13 @@ async fn create_plan( // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion // will raise Configuration errors. 
if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + // To support custom formats, treat error as None + let format = FileType::from_str(&cmd.file_type).ok(); register_object_store_and_config_extensions( ctx, &cmd.location, &cmd.options, - None, + format, ) .await?; } @@ -398,11 +401,12 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + let format = FileType::from_str(&cmd.file_type).ok(); register_object_store_and_config_extensions( &ctx, &cmd.location, &cmd.options, - None, + format, ) .await?; } else { @@ -601,4 +605,16 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn create_external_table_format_option() -> Result<()> { + let location = "path/to/file.cvs"; + + // Test with format options + let sql = + format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')"); + create_external_table_test(location, &sql).await.unwrap(); + + Ok(()) + } } From 68f84761d2aa8608b34981f5279a7685bf896dba Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Sat, 1 Jun 2024 00:32:31 +0300 Subject: [PATCH 16/35] Add support for `AggregateExpr`, `WindowExpr` rewrite. (#10742) * Initial commit * Minor changes * Minor changes * Update comments --- .../physical-expr-common/src/aggregate/mod.rs | 34 ++++++++++++++++ .../physical-expr/src/aggregate/count.rs | 15 +++++++ datafusion/physical-expr/src/lib.rs | 4 +- .../src/window/sliding_aggregate.rs | 25 ++++++++++++ .../physical-expr/src/window/window_expr.rs | 39 +++++++++++++++++++ 5 files changed, 116 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 503e2d8f9758..78c7d40b87f5 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -185,6 +185,40 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { fn create_sliding_accumulator(&self) -> Result> { not_impl_err!("Retractable Accumulator hasn't been implemented for {self:?} yet") } + + /// Returns all expressions used in the [`AggregateExpr`]. + /// These expressions are (1)function arguments, (2) order by expressions. + fn all_expressions(&self) -> AggregatePhysicalExpressions { + let args = self.expressions(); + let order_bys = self.order_bys().unwrap_or(&[]); + let order_by_exprs = order_bys + .iter() + .map(|sort_expr| sort_expr.expr.clone()) + .collect::>(); + AggregatePhysicalExpressions { + args, + order_by_exprs, + } + } + + /// Rewrites [`AggregateExpr`], with new expressions given. The argument should be consistent + /// with the return value of the [`AggregateExpr::all_expressions`] method. + /// Returns `Some(Arc)` if re-write is supported, otherwise returns `None`. + fn with_new_expressions( + &self, + _args: Vec>, + _order_by_exprs: Vec>, + ) -> Option> { + None + } +} + +/// Stores the physical expressions used inside the `AggregateExpr`. +pub struct AggregatePhysicalExpressions { + /// Aggregate function arguments + pub args: Vec>, + /// Order by expressions + pub order_by_exprs: Vec>, } /// Physical aggregate expression of a UDAF. 
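To illustrate how the `all_expressions` / `with_new_expressions` pair introduced above is meant to be used together, here is a minimal sketch of an optimizer-style rewrite. It is not part of this patch: the helper name `rewrite_aggregate` and its `transform` closure are hypothetical, and aggregates that do not override `with_new_expressions` simply return `None`, in which case a caller keeps the original expression (the `Count` and `SlidingAggregateWindowExpr` changes later in this patch show concrete implementations).

```rust
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};

/// Hypothetical helper: run `transform` over every expression inside an
/// `AggregateExpr` (its arguments and its ORDER BY expressions) and ask the
/// aggregate to rebuild itself with the rewritten expressions.
fn rewrite_aggregate(
    agg: &Arc<dyn AggregateExpr>,
    transform: impl Fn(Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>,
) -> Result<Option<Arc<dyn AggregateExpr>>> {
    // Collect (1) function arguments and (2) ORDER BY expressions
    let exprs = agg.all_expressions();

    let new_args = exprs
        .args
        .into_iter()
        .map(&transform)
        .collect::<Result<Vec<_>>>()?;
    let new_order_by = exprs
        .order_by_exprs
        .into_iter()
        .map(&transform)
        .collect::<Result<Vec<_>>>()?;

    // `with_new_expressions` returns `None` for aggregates that do not (yet)
    // support rewriting; callers are expected to keep the original in that case.
    Ok(agg.with_new_expressions(new_args, new_order_by))
}
```
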
diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index e3660221e61a..aad18a82ab87 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -260,6 +260,21 @@ impl AggregateExpr for Count { // instantiate specialized accumulator Ok(Box::new(CountGroupsAccumulator::new())) } + + fn with_new_expressions( + &self, + args: Vec>, + order_by_exprs: Vec>, + ) -> Option> { + debug_assert_eq!(self.exprs.len(), args.len()); + debug_assert!(order_by_exprs.is_empty()); + Some(Arc::new(Count { + name: self.name.clone(), + data_type: self.data_type.clone(), + nullable: self.nullable, + exprs: args, + })) + } } impl PartialEq for Count { diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 1bdf082b2eaf..72f5f2d50cb8 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -41,7 +41,9 @@ pub mod execution_props { pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState}; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; -pub use datafusion_physical_expr_common::aggregate::AggregateExpr; +pub use datafusion_physical_expr_common::aggregate::{ + AggregateExpr, AggregatePhysicalExpressions, +}; pub use equivalence::EquivalenceProperties; pub use partitioning::{Distribution, Partitioning}; pub use physical_expr::{ diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 1494129cf897..961f0884dd87 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -141,6 +141,31 @@ impl WindowExpr for SlidingAggregateWindowExpr { fn uses_bounded_memory(&self) -> bool { !self.window_frame.end_bound.is_unbounded() } + + fn with_new_expressions( + &self, + args: Vec>, + partition_bys: Vec>, + order_by_exprs: Vec>, + ) -> Option> { + debug_assert_eq!(self.order_by.len(), order_by_exprs.len()); + + let new_order_by = self + .order_by + .iter() + .zip(order_by_exprs) + .map(|(req, new_expr)| PhysicalSortExpr { + expr: new_expr, + options: req.options, + }) + .collect::>(); + Some(Arc::new(SlidingAggregateWindowExpr { + aggregate: self.aggregate.with_new_expressions(args, vec![])?, + partition_by: partition_bys, + order_by: new_order_by, + window_frame: self.window_frame.clone(), + })) + } } impl AggregateWindowExpr for SlidingAggregateWindowExpr { diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index dd9514c69a45..065371d9e43e 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -128,6 +128,45 @@ pub trait WindowExpr: Send + Sync + Debug { /// Get the reverse expression of this [WindowExpr]. fn get_reverse_expr(&self) -> Option>; + + /// Returns all expressions used in the [`WindowExpr`]. + /// These expressions are (1) function arguments, (2) partition by expressions, (3) order by expressions. + fn all_expressions(&self) -> WindowPhysicalExpressions { + let args = self.expressions(); + let partition_by_exprs = self.partition_by().to_vec(); + let order_by_exprs = self + .order_by() + .iter() + .map(|sort_expr| sort_expr.expr.clone()) + .collect::>(); + WindowPhysicalExpressions { + args, + partition_by_exprs, + order_by_exprs, + } + } + + /// Rewrites [`WindowExpr`], with new expressions given. 
The argument should be consistent + /// with the return value of the [`WindowExpr::all_expressions`] method. + /// Returns `Some(Arc)` if re-write is supported, otherwise returns `None`. + fn with_new_expressions( + &self, + _args: Vec>, + _partition_bys: Vec>, + _order_by_exprs: Vec>, + ) -> Option> { + None + } +} + +/// Stores the physical expressions used inside the `WindowExpr`. +pub struct WindowPhysicalExpressions { + /// Window function arguments + pub args: Vec>, + /// PARTITION BY expressions + pub partition_by_exprs: Vec>, + /// ORDER BY expressions + pub order_by_exprs: Vec>, } /// Extension trait that adds common functionality to [`AggregateWindowExpr`]s From d6ddd23795222672055e0b737c20bc1fc19e7dd3 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Fri, 31 May 2024 14:32:56 -0700 Subject: [PATCH 17/35] Fix SMJ Left Anti Join when the join filter is set (#10724) * Fix: Sort Merge Join crashes on TPCH Q21 * Fix LeftAnti SMJ join when the join filter is set * rm dbg --- .../src/joins/sort_merge_join.rs | 249 ++++++++++++++---- .../test_files/sort_merge_join.slt | 121 +++++++-- 2 files changed, 306 insertions(+), 64 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index ec83fe3f2af8..143a726d31b1 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -487,7 +487,6 @@ struct StreamedBatch { /// The join key arrays of streamed batch which are used to compare with buffered batches /// and to produce output. They are produced by evaluating `on` expressions. pub join_arrays: Vec, - /// Chunks of indices from buffered side (may be nulls) joined to streamed pub output_indices: Vec, /// Index of currently scanned batch from buffered data @@ -1021,6 +1020,15 @@ impl SMJStream { join_streamed = true; join_buffered = true; }; + + if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { + join_streamed = !self + .streamed_batch + .join_filter_matched_idxs + .contains(&(self.streamed_batch.idx as u64)) + && !self.streamed_joined; + join_buffered = join_streamed; + } } Ordering::Greater => { if matches!(self.join_type, JoinType::Full) { @@ -1181,7 +1189,10 @@ impl SMJStream { let filter_columns = if chunk.buffered_batch_idx.is_some() { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) - } else if matches!(self.join_type, JoinType::LeftSemi) { + } else if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti + ) { // unwrap is safe here as we check is_some on top of if statement let buffered_columns = get_buffered_columns( &self.buffered_data, @@ -1228,7 +1239,15 @@ impl SMJStream { datafusion_common::cast::as_boolean_array(&filter_result)?; let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = - get_filtered_join_mask(self.join_type, streamed_indices, mask); + get_filtered_join_mask( + self.join_type, + streamed_indices, + mask, + &self.streamed_batch.join_filter_matched_idxs, + &self.buffered_data.scanning_batch_idx, + &self.buffered_data.batches.len(), + ); + if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { mask = &filtered_join_mask.0; self.streamed_batch @@ -1419,51 +1438,87 @@ fn get_buffered_columns( .collect::, ArrowError>>() } -// Calculate join filter bit mask considering join type specifics -// `streamed_indices` - array of streamed datasource JOINED row indices -// `mask` - array booleans representing 
computed join filter expression eval result: -// true = the row index matches the join filter -// false = the row index doesn't match the join filter -// `streamed_indices` have the same length as `mask` +/// Calculate join filter bit mask considering join type specifics +/// `streamed_indices` - array of streamed datasource JOINED row indices +/// `mask` - array booleans representing computed join filter expression eval result: +/// true = the row index matches the join filter +/// false = the row index doesn't match the join filter +/// `streamed_indices` have the same length as `mask` +/// `matched_indices` array of streaming indices that already has a join filter match +/// `scanning_batch_idx` current buffered batch +/// `buffered_batches_len` how many batches are in buffered data fn get_filtered_join_mask( join_type: JoinType, streamed_indices: UInt64Array, mask: &BooleanArray, + matched_indices: &HashSet, + scanning_buffered_batch_idx: &usize, + buffered_batches_len: &usize, ) -> Option<(BooleanArray, Vec)> { - // for LeftSemi Join the filter mask should be calculated in its own way: - // if we find at least one matching row for specific streaming index - // we don't need to check any others for the same index - if matches!(join_type, JoinType::LeftSemi) { - // have we seen a filter match for a streaming index before - let mut seen_as_true: bool = false; - let streamed_indices_length = streamed_indices.len(); - let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(streamed_indices_length); - - let mut filter_matched_indices: Vec = vec![]; - - #[allow(clippy::needless_range_loop)] - for i in 0..streamed_indices_length { - // LeftSemi respects only first true values for specific streaming index, - // others true values for the same index must be false - if mask.value(i) && !seen_as_true { - seen_as_true = true; - corrected_mask.append_value(true); - filter_matched_indices.push(streamed_indices.value(i)); - } else { - corrected_mask.append_value(false); + let mut seen_as_true: bool = false; + let streamed_indices_length = streamed_indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(streamed_indices_length); + + let mut filter_matched_indices: Vec = vec![]; + + #[allow(clippy::needless_range_loop)] + match join_type { + // for LeftSemi Join the filter mask should be calculated in its own way: + // if we find at least one matching row for specific streaming index + // we don't need to check any others for the same index + JoinType::LeftSemi => { + // have we seen a filter match for a streaming index before + for i in 0..streamed_indices_length { + // LeftSemi respects only first true values for specific streaming index, + // others true values for the same index must be false + if mask.value(i) && !seen_as_true { + seen_as_true = true; + corrected_mask.append_value(true); + filter_matched_indices.push(streamed_indices.value(i)); + } else { + corrected_mask.append_value(false); + } + + // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag + if i < streamed_indices_length - 1 + && streamed_indices.value(i) != streamed_indices.value(i + 1) + { + seen_as_true = false; + } } + Some((corrected_mask.finish(), filter_matched_indices)) + } + // LeftAnti semantics: return true if for every x in the collection, p(x) is false. 
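+        // For example, with streamed_indices = [0, 0, 1] and mask = [false, true, false],
+        // streamed row 0 has at least one filter match and so must not be emitted,
+        // while streamed row 1 never matches and is emitted exactly once.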
+ // the true(if any) flag needs to be set only once per streaming index + // to prevent duplicates in the output + JoinType::LeftAnti => { + // have we seen a filter match for a streaming index before + for i in 0..streamed_indices_length { + if mask.value(i) && !seen_as_true { + seen_as_true = true; + filter_matched_indices.push(streamed_indices.value(i)); + } - // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag - if i < streamed_indices_length - 1 - && streamed_indices.value(i) != streamed_indices.value(i + 1) - { - seen_as_true = false; + // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag + if (i < streamed_indices_length - 1 + && streamed_indices.value(i) != streamed_indices.value(i + 1)) + || (i == streamed_indices_length - 1 + && *scanning_buffered_batch_idx == buffered_batches_len - 1) + { + corrected_mask.append_value( + !matched_indices.contains(&streamed_indices.value(i)) + && !seen_as_true, + ); + seen_as_true = false; + } else { + corrected_mask.append_value(false); + } } + + Some((corrected_mask.finish(), filter_matched_indices)) } - Some((corrected_mask.finish(), filter_matched_indices)) - } else { - None + _ => None, } } @@ -1711,8 +1766,9 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow_array::{BooleanArray, UInt64Array}; + use hashbrown::HashSet; - use datafusion_common::JoinType::LeftSemi; + use datafusion_common::JoinType::{LeftAnti, LeftSemi}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; @@ -2754,7 +2810,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]) + &BooleanArray::from(vec![true, true, false, false]), + &HashSet::new(), + &0, + &0 ), Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) ); @@ -2763,7 +2822,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, true]) + &BooleanArray::from(vec![true, true]), + &HashSet::new(), + &0, + &0 ), Some((BooleanArray::from(vec![true, true]), vec![0, 1])) ); @@ -2772,7 +2834,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]) + &BooleanArray::from(vec![false, true]), + &HashSet::new(), + &0, + &0 ), Some((BooleanArray::from(vec![false, true]), vec![1])) ); @@ -2781,7 +2846,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]) + &BooleanArray::from(vec![true, false]), + &HashSet::new(), + &0, + &0 ), Some((BooleanArray::from(vec![true, false]), vec![0])) ); @@ -2790,7 +2858,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]) + &BooleanArray::from(vec![false, true, true, true, true, true]), + &HashSet::new(), + &0, + &0 ), Some(( BooleanArray::from(vec![false, true, false, true, false, false]), @@ -2802,7 +2873,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]) + &BooleanArray::from(vec![false, false, false, false, false, true]), + &HashSet::new(), + &0, + &0 ), Some(( BooleanArray::from(vec![false, false, false, false, false, true]), @@ -2813,6 +2887,89 @@ mod tests { Ok(()) } + 
#[tokio::test] + async fn left_anti_join_filtered_mask() -> Result<()> { + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 0, 1, 1]), + &BooleanArray::from(vec![true, true, false, false]), + &HashSet::new(), + &0, + &1 + ), + Some((BooleanArray::from(vec![false, false, false, true]), vec![0])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![true, true]), + &HashSet::new(), + &0, + &1 + ), + Some((BooleanArray::from(vec![false, false]), vec![0, 1])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![false, true]), + &HashSet::new(), + &0, + &1 + ), + Some((BooleanArray::from(vec![true, false]), vec![1])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![true, false]), + &HashSet::new(), + &0, + &1 + ), + Some((BooleanArray::from(vec![false, true]), vec![0])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &BooleanArray::from(vec![false, true, true, true, true, true]), + &HashSet::new(), + &0, + &1 + ), + Some(( + BooleanArray::from(vec![false, false, false, false, false, false]), + vec![0, 1] + )) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &BooleanArray::from(vec![false, false, false, false, false, true]), + &HashSet::new(), + &0, + &1 + ), + Some(( + BooleanArray::from(vec![false, false, true, false, false, false]), + vec![1] + )) + ); + + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 3a27d9693d00..babb7dc8fd6b 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -378,24 +378,6 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != 11 12 11 13 -#LEFTANTI tests -# returns no rows instead of correct result -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c union all -# select 11 a, 14 b, 4 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - # Set batch size to 1 for sort merge join to test scenario when data spread across multiple batches statement ok set datafusion.execution.batch_size = 1; @@ -431,5 +413,108 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != 11 12 11 13 +#LEFTANTI tests +statement ok +set datafusion.execution.batch_size = 10; + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query III +select * from ( +with +t1 as ( + 
select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +# Test LEFT ANTI with cross batch data distribution +statement ok +set datafusion.execution.batch_size = 1; + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +# return sql params back to default values statement ok set datafusion.optimizer.prefer_hash_join = true; + +statement ok +set datafusion.execution.batch_size = 8192; + From 7638a26979382f98fe9725a75424acf0788ff26a Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 1 Jun 2024 09:49:10 +0800 Subject: [PATCH 18/35] Introduce FunctionRegistry dependency to optimize and rewrite rule (#10714) * mv function registry to expr Signed-off-by: jayzhan211 * registry move to config trait Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * rm dependency Signed-off-by: jayzhan211 * fix cli cargo lock Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 160 +++++++++--------- datafusion/core/src/execution/context/mod.rs | 4 + datafusion/execution/src/lib.rs | 7 +- datafusion/expr/src/lib.rs | 1 + .../{execution => expr}/src/registry.rs | 4 +- datafusion/optimizer/Cargo.toml | 1 - datafusion/optimizer/src/optimizer.rs | 5 + .../src/replace_distinct_aggregate.rs | 69 ++------ .../sqllogictest/test_files/distinct_on.slt | 36 ++++ 9 files changed, 147 insertions(+), 140 deletions(-) rename datafusion/{execution => expr}/src/registry.rs (98%) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 62b6ad287aa6..6a1ba8aba005 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.21.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" dependencies = [ "gimli", ] @@ -363,9 +363,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.9" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e9eabd7a98fe442131a17c316bd9349c43695e49e730c3c8e12cfb5f4da2693" +checksum = "cd066d0b4ef8ecb03a55319dc13aa6910616d0f44008a045bb1835af830abff5" dependencies = [ "bzip2", "flate2", @@ -387,7 +387,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 
2.0.66", ] [[package]] @@ -708,9 +708,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.71" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +checksum = "17c6a35df3749d2e8bb1b7b21a976d82b15548788d2735b9d82f329268f71a11" dependencies = [ "addr2line", "cc", @@ -869,9 +869,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.97" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4" +checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" dependencies = [ "jobserver", "libc", @@ -1042,9 +1042,9 @@ dependencies = [ [[package]] name = "crc32fast" -version = "1.4.0" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" dependencies = [ "cfg-if", ] @@ -1093,7 +1093,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] @@ -1325,7 +1325,6 @@ dependencies = [ "chrono", "datafusion-common", "datafusion-expr", - "datafusion-functions-aggregate", "datafusion-physical-expr", "hashbrown 0.14.5", "indexmap 2.2.6", @@ -1495,9 +1494,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" [[package]] name = "encoding_rs" @@ -1535,9 +1534,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" dependencies = [ "libc", "windows-sys 0.52.0", @@ -1685,7 +1684,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] @@ -1747,9 +1746,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.1" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "glob" @@ -1987,9 +1986,9 @@ dependencies = [ [[package]] name = "instant" -version = "0.1.12" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" dependencies = [ "cfg-if", "js-sys", @@ -2114,9 +2113,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.154" +version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" [[package]] name = "libflate" @@ -2150,9 +2149,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "libmimalloc-sys" -version = "0.1.37" +version = "0.1.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81eb4061c0582dedea1cbc7aff2240300dd6982e0239d1c99e65c1dbf4a30ba7" +checksum = "0e7bb23d733dfcc8af652a78b7bf232f0e967710d044732185e561e47c0336b6" dependencies = [ "cc", "libc", @@ -2170,9 +2169,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "lock_api" @@ -2228,9 +2227,9 @@ checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "mimalloc" -version = "0.1.41" +version = "0.1.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f41a2280ded0da56c8cf898babb86e8f10651a34adcfff190ae9a1159c6908d" +checksum = "e9186d86b79b52f4a77af65604b51225e8db1d6ee7e3f41aec1e40829c71a176" dependencies = [ "libmimalloc-sys", ] @@ -2243,9 +2242,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" dependencies = [ "adler", ] @@ -2289,9 +2288,9 @@ checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" [[package]] name = "num" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135b08af27d103b0a51f2ae0f8632117b7b185ccf931445affa8df530576a41" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" dependencies = [ "num-bigint", "num-complex", @@ -2313,9 +2312,9 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ "num-traits", ] @@ -2348,11 +2347,10 @@ dependencies = [ [[package]] name = "num-rational" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" dependencies = [ - "autocfg", "num-bigint", "num-integer", "num-traits", @@ -2380,9 +2378,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.2" +version = "0.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +checksum = "b8ec7ab813848ba4522158d5517a6093db1ded27575b070f4177b8d12b41db5e" dependencies = [ "memchr", ] @@ -2453,9 +2451,9 @@ checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a" [[package]] name = "parking_lot" -version = "0.12.2" +version 
= "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" dependencies = [ "lock_api", "parking_lot_core", @@ -2532,9 +2530,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.6.4" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", "indexmap 2.2.6", @@ -2595,7 +2593,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] @@ -2684,9 +2682,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.82" +version = "1.0.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b" +checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" dependencies = [ "unicode-ident", ] @@ -3001,9 +2999,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51f344d206c5e1b010eec27349b815a4805f70a778895959d70b74b9b529b30a" +checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" @@ -3017,9 +3015,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "092474d1a01ea8278f69e6a358998405fae5b8b963ddaeb2b0b04a128bf1dfb0" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" [[package]] name = "rustyline" @@ -3121,29 +3119,29 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.200" +version = "1.0.203" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.200" +version = "1.0.203" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] name = "serde_json" -version = "1.0.116" +version = "1.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" +checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" dependencies = [ "itoa", "ryu", @@ -3271,7 +3269,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] @@ -3317,7 +3315,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] @@ -3330,7 +3328,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 
2.0.61", + "syn 2.0.66", ] [[package]] @@ -3352,9 +3350,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.61" +version = "2.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c993ed8ccba56ae856363b1845da7266a7cb78e1d146c8a32d54b45a8b831fc9" +checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" dependencies = [ "proc-macro2", "quote", @@ -3423,22 +3421,22 @@ checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" [[package]] name = "thiserror" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] @@ -3508,9 +3506,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.37.0" +version = "1.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" dependencies = [ "backtrace", "bytes", @@ -3527,13 +3525,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] @@ -3629,7 +3627,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] @@ -3674,7 +3672,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] @@ -3828,7 +3826,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", "wasm-bindgen-shared", ] @@ -3862,7 +3860,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4127,14 +4125,14 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.66", ] [[package]] name = "zeroize" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" [[package]] name = "zstd" diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index cb0dfd079169..745eff550fae 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ 
-2350,6 +2350,10 @@ impl OptimizerConfig for SessionState { fn options(&self) -> &ConfigOptions { self.config_options() } + + fn function_registry(&self) -> Option<&dyn FunctionRegistry> { + Some(self) + } } /// Create a new task context instance from SessionContext diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index a1a1551c2ca6..2fe0d83b1d1c 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -22,11 +22,16 @@ pub mod config; pub mod disk_manager; pub mod memory_pool; pub mod object_store; -pub mod registry; pub mod runtime_env; mod stream; mod task; +pub mod registry { + pub use datafusion_expr::registry::{ + FunctionRegistry, MemoryFunctionRegistry, SerializerRegistry, + }; +} + pub use disk_manager::DiskManager; pub use registry::FunctionRegistry; pub use stream::{RecordBatchStream, SendableRecordBatchStream}; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index d0114a472541..74d6b4149dbe 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -48,6 +48,7 @@ pub mod function; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; +pub mod registry; pub mod simplify; pub mod sort_properties; pub mod tree_node; diff --git a/datafusion/execution/src/registry.rs b/datafusion/expr/src/registry.rs similarity index 98% rename from datafusion/execution/src/registry.rs rename to datafusion/expr/src/registry.rs index f3714a11c239..70d0a21a870e 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -17,9 +17,9 @@ //! FunctionRegistry trait +use crate::expr_rewriter::FunctionRewrite; +use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; -use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use std::collections::HashMap; use std::{collections::HashSet, sync::Arc}; diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 59c0b476c7cf..67d5c9b23b74 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,7 +45,6 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } -datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 0501f5b8a40a..998eeb7167ee 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -21,6 +21,7 @@ use std::collections::HashSet; use std::sync::Arc; use chrono::{DateTime, Utc}; +use datafusion_expr::registry::FunctionRegistry; use log::{debug, warn}; use datafusion_common::alias::AliasGenerator; @@ -122,6 +123,10 @@ pub trait OptimizerConfig { fn alias_generator(&self) -> Arc; fn options(&self) -> &ConfigOptions; + + fn function_registry(&self) -> Option<&dyn FunctionRegistry> { + None + } } /// A standalone [`OptimizerConfig`] that can be used independently diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index dcd13c58b919..752e2b200741 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ 
b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -21,11 +21,11 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{internal_err, Column, Result}; +use datafusion_expr::expr::AggregateFunction; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{col, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; -use datafusion_functions_aggregate::first_last::first_value; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -73,7 +73,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { fn rewrite( &self, plan: LogicalPlan, - _config: &dyn OptimizerConfig, + config: &dyn OptimizerConfig, ) -> Result> { match plan { LogicalPlan::Distinct(Distinct::All(input)) => { @@ -95,9 +95,18 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { let expr_cnt = on_expr.len(); // Construct the aggregation expression to be used to fetch the selected expressions. - let aggr_expr = select_expr - .into_iter() - .map(|e| first_value(vec![e], false, None, sort_expr.clone(), None)); + let first_value_udaf = + config.function_registry().unwrap().udaf("first_value")?; + let aggr_expr = select_expr.into_iter().map(|e| { + Expr::AggregateFunction(AggregateFunction::new_udf( + first_value_udaf.clone(), + vec![e], + false, + None, + sort_expr.clone(), + None, + )) + }); let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; let group_expr = normalize_cols(on_expr, input.as_ref())?; @@ -163,53 +172,3 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { Some(BottomUp) } } - -#[cfg(test)] -mod tests { - use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; - use crate::test::{assert_optimized_plan_eq, test_table_scan}; - use datafusion_expr::{col, LogicalPlanBuilder}; - use std::sync::Arc; - - #[test] - fn replace_distinct() -> datafusion_common::Result<()> { - let table_scan = test_table_scan().unwrap(); - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("b")])? - .distinct()? - .build()?; - - let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq( - Arc::new(ReplaceDistinctWithAggregate::new()), - plan, - expected, - ) - } - - #[test] - fn replace_distinct_on() -> datafusion_common::Result<()> { - let table_scan = test_table_scan().unwrap(); - let plan = LogicalPlanBuilder::from(table_scan) - .distinct_on( - vec![col("a")], - vec![col("b")], - Some(vec![col("a").sort(false, true), col("c").sort(true, false)]), - )? 
- .build()?; - - let expected = "Projection: first_value(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\ - \n Sort: test.a DESC NULLS FIRST\ - \n Aggregate: groupBy=[[test.a]], aggr=[[first_value(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\ - \n TableScan: test"; - - assert_optimized_plan_eq( - Arc::new(ReplaceDistinctWithAggregate::new()), - plan, - expected, - ) - } -} diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt index beef865bcada..99639d78c309 100644 --- a/datafusion/sqllogictest/test_files/distinct_on.slt +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -143,3 +143,39 @@ LIMIT 3; -25 15295 45 15673 -72 -11122 + +# test distinct on +statement ok +create table t(a int, b int, c int) as values (1, 2, 3); + +statement ok +set datafusion.explain.logical_plan_only = true; + +query TT +explain select distinct on (a) b from t order by a desc, c; +---- +logical_plan +01)Projection: first_value(t.b) ORDER BY [t.a DESC NULLS FIRST, t.c ASC NULLS LAST] AS b +02)--Sort: t.a DESC NULLS FIRST +03)----Aggregate: groupBy=[[t.a]], aggr=[[first_value(t.b) ORDER BY [t.a DESC NULLS FIRST, t.c ASC NULLS LAST]]] +04)------TableScan: t projection=[a, b, c] + +statement ok +drop table t; + +# test distinct +statement ok +create table t(a int, b int) as values (1, 2); + +statement ok +set datafusion.explain.logical_plan_only = true; + +query TT +explain select distinct a, b from t; +---- +logical_plan +01)Aggregate: groupBy=[[t.a, t.b]], aggr=[[]] +02)--TableScan: t projection=[a, b] + +statement ok +drop table t; From 3777114192de43a5b7b4149f843e802a15f15e13 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Sat, 1 Jun 2024 03:57:30 -0700 Subject: [PATCH 19/35] Minor: Add SMJ to TPCH benchmark usage (#10747) * Fix: Sort Merge Join crashes on TPCH Q21 * Fix LeftAnti SMJ join when the join filter is set * rm dbg * Add SMJ to TPCH benchmark usage --- benchmarks/bench.sh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 49e65eafac9a..87d0720ccb63 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -66,9 +66,11 @@ compare: Compares results from benchmark runs * Benchmarks ********** all(default): Data/Run/Compare for all benchmarks -tpch: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table +tpch: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table, hash join +tpch_smj: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table, sort merge join tpch_mem: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), query from memory -tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single parquet file per table +tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single parquet file per table, hash join +tpch_smj10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single parquet file per table, sort merge join tpch_mem10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory parquet: Benchmark of parquet reader's filtering speed sort: Benchmark of sorting speed From acd7106fa40fad58f50ae06227971c51073d8f48 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 1 Jun 2024 06:57:40 -0400 Subject: [PATCH 20/35] Minor: Split physical_plan/parquet/mod.rs into smaller modules (#10727) * Minor: Split physical_plan/parquet/mod.rs into smaller modules * doc tweaks * Add object store 
docs * Apply suggestions from code review Co-authored-by: Ruihang Xia --------- Co-authored-by: Ruihang Xia --- .../datasource/physical_plan/parquet/mod.rs | 376 +----------------- .../physical_plan/parquet/opener.rs | 204 ++++++++++ .../physical_plan/parquet/reader.rs | 140 +++++++ .../physical_plan/parquet/row_groups.rs | 2 +- .../physical_plan/parquet/writer.rs | 80 ++++ 5 files changed, 445 insertions(+), 357 deletions(-) create mode 100644 datafusion/core/src/datasource/physical_plan/parquet/opener.rs create mode 100644 datafusion/core/src/datasource/physical_plan/parquet/reader.rs create mode 100644 datafusion/core/src/datasource/physical_plan/parquet/writer.rs diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index ac7c39bbdb94..f0328098b406 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -15,66 +15,55 @@ // specific language governing permissions and limitations // under the License. -//! Execution plan for reading Parquet files +//! [`ParquetExec`] Execution plan for reading Parquet files use std::any::Any; use std::fmt::Debug; -use std::ops::Range; use std::sync::Arc; use crate::datasource::listing::PartitionedFile; -use crate::datasource::physical_plan::file_stream::{ - FileOpenFuture, FileOpener, FileStream, -}; +use crate::datasource::physical_plan::file_stream::FileStream; use crate::datasource::physical_plan::{ parquet::page_filter::PagePruningPredicate, DisplayAs, FileGroupPartitioner, - FileMeta, FileScanConfig, + FileScanConfig, }; use crate::{ config::{ConfigOptions, TableParquetOptions}, - datasource::listing::ListingTableUrl, - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, physical_optimizer::pruning::PruningPredicate, physical_plan::{ metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, - DisplayFormatType, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, - Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, + DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, }, }; use arrow::datatypes::{DataType, SchemaRef}; -use arrow::error::ArrowError; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalExpr}; -use bytes::Bytes; -use futures::future::BoxFuture; -use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; use log::debug; -use object_store::buffered::BufWriter; -use object_store::path::Path; -use object_store::ObjectStore; -use parquet::arrow::arrow_reader::ArrowReaderOptions; -use parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; -use parquet::arrow::{AsyncArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMask}; use parquet::basic::{ConvertedType, LogicalType}; -use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties}; use parquet::schema::types::ColumnDescriptor; -use tokio::task::JoinSet; mod metrics; +mod opener; mod page_filter; +mod reader; mod row_filter; mod row_groups; mod statistics; +mod writer; -use crate::datasource::physical_plan::parquet::row_groups::RowGroupSet; use crate::datasource::schema_adapter::{ DefaultSchemaAdapterFactory, SchemaAdapterFactory, }; pub use metrics::ParquetFileMetrics; +use opener::ParquetOpener; +pub use reader::{DefaultParquetFileReaderFactory, ParquetFileReaderFactory}; pub use statistics::{RequestedStatistics, StatisticsConverter}; +pub 
use writer::plan_to_parquet; /// Execution plan for reading one or more Parquet files. /// @@ -201,7 +190,7 @@ pub struct ParquetExec { schema_adapter_factory: Option>, } -/// [`ParquetExecBuilder`]`, builder for [`ParquetExec`]. +/// [`ParquetExecBuilder`], builder for [`ParquetExec`]. /// /// See example on [`ParquetExec`]. pub struct ParquetExecBuilder { @@ -279,7 +268,9 @@ impl ParquetExecBuilder { /// instance using individual I/O operations for the footer and each page. /// /// If a custom `ParquetFileReaderFactory` is provided, then data access - /// operations will be routed to this factory instead of `ObjectStore`. + /// operations will be routed to this factory instead of [`ObjectStore`]. + /// + /// [`ObjectStore`]: object_store::ObjectStore pub fn with_parquet_file_reader_factory( mut self, parquet_file_reader_factory: Arc, @@ -698,175 +689,6 @@ impl ExecutionPlan for ParquetExec { } } -/// Implements [`FileOpener`] for a parquet file -struct ParquetOpener { - partition_index: usize, - projection: Arc<[usize]>, - batch_size: usize, - limit: Option, - predicate: Option>, - pruning_predicate: Option>, - page_pruning_predicate: Option>, - table_schema: SchemaRef, - metadata_size_hint: Option, - metrics: ExecutionPlanMetricsSet, - parquet_file_reader_factory: Arc, - pushdown_filters: bool, - reorder_filters: bool, - enable_page_index: bool, - enable_bloom_filter: bool, - schema_adapter_factory: Arc, -} - -impl FileOpener for ParquetOpener { - fn open(&self, file_meta: FileMeta) -> Result { - let file_range = file_meta.range.clone(); - let file_metrics = ParquetFileMetrics::new( - self.partition_index, - file_meta.location().as_ref(), - &self.metrics, - ); - - let reader: Box = - self.parquet_file_reader_factory.create_reader( - self.partition_index, - file_meta, - self.metadata_size_hint, - &self.metrics, - )?; - - let batch_size = self.batch_size; - let projection = self.projection.clone(); - let projected_schema = SchemaRef::from(self.table_schema.project(&projection)?); - let schema_adapter = self.schema_adapter_factory.create(projected_schema); - let predicate = self.predicate.clone(); - let pruning_predicate = self.pruning_predicate.clone(); - let page_pruning_predicate = self.page_pruning_predicate.clone(); - let table_schema = self.table_schema.clone(); - let reorder_predicates = self.reorder_filters; - let pushdown_filters = self.pushdown_filters; - let enable_page_index = should_enable_page_index( - self.enable_page_index, - &self.page_pruning_predicate, - ); - let enable_bloom_filter = self.enable_bloom_filter; - let limit = self.limit; - - Ok(Box::pin(async move { - let options = ArrowReaderOptions::new().with_page_index(enable_page_index); - let mut builder = - ParquetRecordBatchStreamBuilder::new_with_options(reader, options) - .await?; - - let file_schema = builder.schema().clone(); - - let (schema_mapping, adapted_projections) = - schema_adapter.map_schema(&file_schema)?; - // let predicate = predicate.map(|p| reassign_predicate_columns(p, builder.schema(), true)).transpose()?; - - let mask = ProjectionMask::roots( - builder.parquet_schema(), - adapted_projections.iter().cloned(), - ); - - // Filter pushdown: evaluate predicates during scan - if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { - let row_filter = row_filter::build_row_filter( - &predicate, - &file_schema, - &table_schema, - builder.metadata(), - reorder_predicates, - &file_metrics, - ); - - match row_filter { - Ok(Some(filter)) => { - builder = builder.with_row_filter(filter); - 
} - Ok(None) => {} - Err(e) => { - debug!( - "Ignoring error building row filter for '{:?}': {}", - predicate, e - ); - } - }; - }; - - // Determine which row groups to actually read. The idea is to skip - // as many row groups as possible based on the metadata and query - let file_metadata = builder.metadata().clone(); - let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); - let rg_metadata = file_metadata.row_groups(); - // track which row groups to actually read - let mut row_groups = RowGroupSet::new(rg_metadata.len()); - // if there is a range restricting what parts of the file to read - if let Some(range) = file_range.as_ref() { - row_groups.prune_by_range(rg_metadata, range); - } - // If there is a predicate that can be evaluated against the metadata - if let Some(predicate) = predicate.as_ref() { - row_groups.prune_by_statistics( - &file_schema, - builder.parquet_schema(), - rg_metadata, - predicate, - &file_metrics, - ); - - if enable_bloom_filter && !row_groups.is_empty() { - row_groups - .prune_by_bloom_filters( - &file_schema, - &mut builder, - predicate, - &file_metrics, - ) - .await; - } - } - - // page index pruning: if all data on individual pages can - // be ruled using page metadata, rows from other columns - // with that range can be skipped as well - if enable_page_index && !row_groups.is_empty() { - if let Some(p) = page_pruning_predicate { - let pruned = p.prune( - &file_schema, - builder.parquet_schema(), - &row_groups, - file_metadata.as_ref(), - &file_metrics, - )?; - if let Some(row_selection) = pruned { - builder = builder.with_row_selection(row_selection); - } - } - } - - if let Some(limit) = limit { - builder = builder.with_limit(limit) - } - - let stream = builder - .with_projection(mask) - .with_batch_size(batch_size) - .with_row_groups(row_groups.indexes()) - .build()?; - - let adapted = stream - .map_err(|e| ArrowError::ExternalError(Box::new(e))) - .map(move |maybe_batch| { - maybe_batch - .and_then(|b| schema_mapping.map_batch(b).map_err(Into::into)) - }); - - Ok(adapted.boxed()) - })) - } -} - fn should_enable_page_index( enable_page_index: bool, page_pruning_predicate: &Option>, @@ -879,168 +701,6 @@ fn should_enable_page_index( .unwrap_or(false) } -/// Interface for reading parquet files. -/// -/// The combined implementations of [`ParquetFileReaderFactory`] and -/// [`AsyncFileReader`] can be used to provide custom data access operations -/// such as pre-cached data, I/O coalescing, etc. -/// -/// See [`DefaultParquetFileReaderFactory`] for a simple implementation. -pub trait ParquetFileReaderFactory: Debug + Send + Sync + 'static { - /// Provides an `AsyncFileReader` for reading data from a parquet file specified - /// - /// # Arguments - /// * partition_index - Index of the partition (for reporting metrics) - /// * file_meta - The file to be read - /// * metadata_size_hint - If specified, the first IO reads this many bytes from the footer - /// * metrics - Execution metrics - fn create_reader( - &self, - partition_index: usize, - file_meta: FileMeta, - metadata_size_hint: Option, - metrics: &ExecutionPlanMetricsSet, - ) -> Result>; -} - -/// Default implementation of [`ParquetFileReaderFactory`] -/// -/// This implementation: -/// 1. Reads parquet directly from an underlying [`ObjectStore`] instance. -/// 2. Reads the footer and page metadata on demand. -/// 3. Does not cache metadata or coalesce I/O operations. 
-#[derive(Debug)] -pub struct DefaultParquetFileReaderFactory { - store: Arc, -} - -impl DefaultParquetFileReaderFactory { - /// Create a new `DefaultParquetFileReaderFactory`. - pub fn new(store: Arc) -> Self { - Self { store } - } -} - -/// Implements [`AsyncFileReader`] for a parquet file in object storage. -/// -/// This implementation uses the [`ParquetObjectReader`] to read data from the -/// object store on demand, as required, tracking the number of bytes read. -/// -/// This implementation does not coalesce I/O operations or cache bytes. Such -/// optimizations can be done either at the object store level or by providing a -/// custom implementation of [`ParquetFileReaderFactory`]. -pub(crate) struct ParquetFileReader { - file_metrics: ParquetFileMetrics, - inner: ParquetObjectReader, -} - -impl AsyncFileReader for ParquetFileReader { - fn get_bytes( - &mut self, - range: Range, - ) -> BoxFuture<'_, parquet::errors::Result> { - self.file_metrics.bytes_scanned.add(range.end - range.start); - self.inner.get_bytes(range) - } - - fn get_byte_ranges( - &mut self, - ranges: Vec>, - ) -> BoxFuture<'_, parquet::errors::Result>> - where - Self: Send, - { - let total = ranges.iter().map(|r| r.end - r.start).sum(); - self.file_metrics.bytes_scanned.add(total); - self.inner.get_byte_ranges(ranges) - } - - fn get_metadata( - &mut self, - ) -> BoxFuture<'_, parquet::errors::Result>> { - self.inner.get_metadata() - } -} - -impl ParquetFileReaderFactory for DefaultParquetFileReaderFactory { - fn create_reader( - &self, - partition_index: usize, - file_meta: FileMeta, - metadata_size_hint: Option, - metrics: &ExecutionPlanMetricsSet, - ) -> Result> { - let file_metrics = ParquetFileMetrics::new( - partition_index, - file_meta.location().as_ref(), - metrics, - ); - let store = Arc::clone(&self.store); - let mut inner = ParquetObjectReader::new(store, file_meta.object_meta); - - if let Some(hint) = metadata_size_hint { - inner = inner.with_footer_size_hint(hint) - }; - - Ok(Box::new(ParquetFileReader { - inner, - file_metrics, - })) - } -} - -/// Executes a query and writes the results to a partitioned Parquet file. -pub async fn plan_to_parquet( - task_ctx: Arc, - plan: Arc, - path: impl AsRef, - writer_properties: Option, -) -> Result<()> { - let path = path.as_ref(); - let parsed = ListingTableUrl::parse(path)?; - let object_store_url = parsed.object_store(); - let store = task_ctx.runtime_env().object_store(&object_store_url)?; - let mut join_set = JoinSet::new(); - for i in 0..plan.output_partitioning().partition_count() { - let plan: Arc = plan.clone(); - let filename = format!("{}/part-{i}.parquet", parsed.prefix()); - let file = Path::parse(filename)?; - let propclone = writer_properties.clone(); - - let storeref = store.clone(); - let buf_writer = BufWriter::new(storeref, file.clone()); - let mut stream = plan.execute(i, task_ctx.clone())?; - join_set.spawn(async move { - let mut writer = - AsyncArrowWriter::try_new(buf_writer, plan.schema(), propclone)?; - while let Some(next_batch) = stream.next().await { - let batch = next_batch?; - writer.write(&batch).await?; - } - writer - .close() - .await - .map_err(DataFusionError::from) - .map(|_| ()) - }); - } - - while let Some(result) = join_set.join_next().await { - match result { - Ok(res) => res?, - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } - } - - Ok(()) -} - // Convert parquet column schema to arrow data type, and just consider the // decimal data type. 
pub(crate) fn parquet_to_arrow_decimal_type( @@ -1098,9 +758,13 @@ mod tests { use datafusion_physical_expr::create_physical_expr; use chrono::{TimeZone, Utc}; + use datafusion_physical_plan::ExecutionPlanProperties; + use futures::StreamExt; use object_store::local::LocalFileSystem; + use object_store::path::Path; use object_store::ObjectMeta; use parquet::arrow::ArrowWriter; + use parquet::file::properties::WriterProperties; use tempfile::TempDir; use url::Url; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs new file mode 100644 index 000000000000..3aec1e1d2037 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ParquetOpener`] for opening Parquet files + +use crate::datasource::physical_plan::parquet::page_filter::PagePruningPredicate; +use crate::datasource::physical_plan::parquet::row_groups::RowGroupSet; +use crate::datasource::physical_plan::parquet::{row_filter, should_enable_page_index}; +use crate::datasource::physical_plan::{ + FileMeta, FileOpenFuture, FileOpener, ParquetFileMetrics, ParquetFileReaderFactory, +}; +use crate::datasource::schema_adapter::SchemaAdapterFactory; +use crate::physical_optimizer::pruning::PruningPredicate; +use arrow_schema::{ArrowError, SchemaRef}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use futures::{StreamExt, TryStreamExt}; +use log::debug; +use parquet::arrow::arrow_reader::ArrowReaderOptions; +use parquet::arrow::async_reader::AsyncFileReader; +use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +use std::sync::Arc; + +/// Implements [`FileOpener`] for a parquet file +pub(super) struct ParquetOpener { + pub partition_index: usize, + pub projection: Arc<[usize]>, + pub batch_size: usize, + pub limit: Option, + pub predicate: Option>, + pub pruning_predicate: Option>, + pub page_pruning_predicate: Option>, + pub table_schema: SchemaRef, + pub metadata_size_hint: Option, + pub metrics: ExecutionPlanMetricsSet, + pub parquet_file_reader_factory: Arc, + pub pushdown_filters: bool, + pub reorder_filters: bool, + pub enable_page_index: bool, + pub enable_bloom_filter: bool, + pub schema_adapter_factory: Arc, +} + +impl FileOpener for ParquetOpener { + fn open(&self, file_meta: FileMeta) -> datafusion_common::Result { + let file_range = file_meta.range.clone(); + let file_metrics = ParquetFileMetrics::new( + self.partition_index, + file_meta.location().as_ref(), + &self.metrics, + ); + + let reader: Box = + self.parquet_file_reader_factory.create_reader( + self.partition_index, + file_meta, + 
self.metadata_size_hint, + &self.metrics, + )?; + + let batch_size = self.batch_size; + let projection = self.projection.clone(); + let projected_schema = SchemaRef::from(self.table_schema.project(&projection)?); + let schema_adapter = self.schema_adapter_factory.create(projected_schema); + let predicate = self.predicate.clone(); + let pruning_predicate = self.pruning_predicate.clone(); + let page_pruning_predicate = self.page_pruning_predicate.clone(); + let table_schema = self.table_schema.clone(); + let reorder_predicates = self.reorder_filters; + let pushdown_filters = self.pushdown_filters; + let enable_page_index = should_enable_page_index( + self.enable_page_index, + &self.page_pruning_predicate, + ); + let enable_bloom_filter = self.enable_bloom_filter; + let limit = self.limit; + + Ok(Box::pin(async move { + let options = ArrowReaderOptions::new().with_page_index(enable_page_index); + let mut builder = + ParquetRecordBatchStreamBuilder::new_with_options(reader, options) + .await?; + + let file_schema = builder.schema().clone(); + + let (schema_mapping, adapted_projections) = + schema_adapter.map_schema(&file_schema)?; + + let mask = ProjectionMask::roots( + builder.parquet_schema(), + adapted_projections.iter().cloned(), + ); + + // Filter pushdown: evaluate predicates during scan + if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { + let row_filter = row_filter::build_row_filter( + &predicate, + &file_schema, + &table_schema, + builder.metadata(), + reorder_predicates, + &file_metrics, + ); + + match row_filter { + Ok(Some(filter)) => { + builder = builder.with_row_filter(filter); + } + Ok(None) => {} + Err(e) => { + debug!( + "Ignoring error building row filter for '{:?}': {}", + predicate, e + ); + } + }; + }; + + // Determine which row groups to actually read. 
The idea is to skip + // as many row groups as possible based on the metadata and query + let file_metadata = builder.metadata().clone(); + let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); + let rg_metadata = file_metadata.row_groups(); + // track which row groups to actually read + let mut row_groups = RowGroupSet::new(rg_metadata.len()); + // if there is a range restricting what parts of the file to read + if let Some(range) = file_range.as_ref() { + row_groups.prune_by_range(rg_metadata, range); + } + // If there is a predicate that can be evaluated against the metadata + if let Some(predicate) = predicate.as_ref() { + row_groups.prune_by_statistics( + &file_schema, + builder.parquet_schema(), + rg_metadata, + predicate, + &file_metrics, + ); + + if enable_bloom_filter && !row_groups.is_empty() { + row_groups + .prune_by_bloom_filters( + &file_schema, + &mut builder, + predicate, + &file_metrics, + ) + .await; + } + } + + // page index pruning: if all data on individual pages can + // be ruled using page metadata, rows from other columns + // with that range can be skipped as well + if enable_page_index && !row_groups.is_empty() { + if let Some(p) = page_pruning_predicate { + let pruned = p.prune( + &file_schema, + builder.parquet_schema(), + &row_groups, + file_metadata.as_ref(), + &file_metrics, + )?; + if let Some(row_selection) = pruned { + builder = builder.with_row_selection(row_selection); + } + } + } + + if let Some(limit) = limit { + builder = builder.with_limit(limit) + } + + let stream = builder + .with_projection(mask) + .with_batch_size(batch_size) + .with_row_groups(row_groups.indexes()) + .build()?; + + let adapted = stream + .map_err(|e| ArrowError::ExternalError(Box::new(e))) + .map(move |maybe_batch| { + maybe_batch + .and_then(|b| schema_mapping.map_batch(b).map_err(Into::into)) + }); + + Ok(adapted.boxed()) + })) + } +} diff --git a/datafusion/core/src/datasource/physical_plan/parquet/reader.rs b/datafusion/core/src/datasource/physical_plan/parquet/reader.rs new file mode 100644 index 000000000000..265fb9d570cc --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/reader.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ParquetFileReaderFactory`] and [`DefaultParquetFileReaderFactory`] for +//! 
creating parquet file readers + +use crate::datasource::physical_plan::{FileMeta, ParquetFileMetrics}; +use bytes::Bytes; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use futures::future::BoxFuture; +use object_store::ObjectStore; +use parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; +use parquet::file::metadata::ParquetMetaData; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::Arc; + +/// Interface for reading parquet files. +/// +/// The combined implementations of [`ParquetFileReaderFactory`] and +/// [`AsyncFileReader`] can be used to provide custom data access operations +/// such as pre-cached data, I/O coalescing, etc. +/// +/// See [`DefaultParquetFileReaderFactory`] for a simple implementation. +pub trait ParquetFileReaderFactory: Debug + Send + Sync + 'static { + /// Provides an `AsyncFileReader` for reading data from a parquet file specified + /// + /// # Arguments + /// * partition_index - Index of the partition (for reporting metrics) + /// * file_meta - The file to be read + /// * metadata_size_hint - If specified, the first IO reads this many bytes from the footer + /// * metrics - Execution metrics + fn create_reader( + &self, + partition_index: usize, + file_meta: FileMeta, + metadata_size_hint: Option, + metrics: &ExecutionPlanMetricsSet, + ) -> datafusion_common::Result>; +} + +/// Default implementation of [`ParquetFileReaderFactory`] +/// +/// This implementation: +/// 1. Reads parquet directly from an underlying [`ObjectStore`] instance. +/// 2. Reads the footer and page metadata on demand. +/// 3. Does not cache metadata or coalesce I/O operations. +#[derive(Debug)] +pub struct DefaultParquetFileReaderFactory { + store: Arc, +} + +impl DefaultParquetFileReaderFactory { + /// Create a new `DefaultParquetFileReaderFactory`. + pub fn new(store: Arc) -> Self { + Self { store } + } +} + +/// Implements [`AsyncFileReader`] for a parquet file in object storage. +/// +/// This implementation uses the [`ParquetObjectReader`] to read data from the +/// object store on demand, as required, tracking the number of bytes read. +/// +/// This implementation does not coalesce I/O operations or cache bytes. Such +/// optimizations can be done either at the object store level or by providing a +/// custom implementation of [`ParquetFileReaderFactory`]. 
+pub(crate) struct ParquetFileReader { + pub file_metrics: ParquetFileMetrics, + pub inner: ParquetObjectReader, +} + +impl AsyncFileReader for ParquetFileReader { + fn get_bytes( + &mut self, + range: Range, + ) -> BoxFuture<'_, parquet::errors::Result> { + self.file_metrics.bytes_scanned.add(range.end - range.start); + self.inner.get_bytes(range) + } + + fn get_byte_ranges( + &mut self, + ranges: Vec>, + ) -> BoxFuture<'_, parquet::errors::Result>> + where + Self: Send, + { + let total = ranges.iter().map(|r| r.end - r.start).sum(); + self.file_metrics.bytes_scanned.add(total); + self.inner.get_byte_ranges(ranges) + } + + fn get_metadata( + &mut self, + ) -> BoxFuture<'_, parquet::errors::Result>> { + self.inner.get_metadata() + } +} + +impl ParquetFileReaderFactory for DefaultParquetFileReaderFactory { + fn create_reader( + &self, + partition_index: usize, + file_meta: FileMeta, + metadata_size_hint: Option, + metrics: &ExecutionPlanMetricsSet, + ) -> datafusion_common::Result> { + let file_metrics = ParquetFileMetrics::new( + partition_index, + file_meta.location().as_ref(), + metrics, + ); + let store = Arc::clone(&self.store); + let mut inner = ParquetObjectReader::new(store, file_meta.object_meta); + + if let Some(hint) = metadata_size_hint { + inner = inner.with_footer_size_hint(hint) + }; + + Ok(Box::new(ParquetFileReader { + inner, + file_metrics, + })) + } +} diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 0a0ca4369d27..7dd91d3d4e4b 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -417,7 +417,7 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { #[cfg(test)] mod tests { use super::*; - use crate::datasource::physical_plan::parquet::ParquetFileReader; + use crate::datasource::physical_plan::parquet::reader::ParquetFileReader; use crate::physical_plan::metrics::ExecutionPlanMetricsSet; use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::{DataType, Field}; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/writer.rs b/datafusion/core/src/datasource/physical_plan/parquet/writer.rs new file mode 100644 index 000000000000..0c0c54691068 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/writer.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::datasource::listing::ListingTableUrl; +use datafusion_common::DataFusionError; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use futures::StreamExt; +use object_store::buffered::BufWriter; +use object_store::path::Path; +use parquet::arrow::AsyncArrowWriter; +use parquet::file::properties::WriterProperties; +use std::sync::Arc; +use tokio::task::JoinSet; + +/// Executes a query and writes the results to a partitioned Parquet file. +pub async fn plan_to_parquet( + task_ctx: Arc, + plan: Arc, + path: impl AsRef, + writer_properties: Option, +) -> datafusion_common::Result<()> { + let path = path.as_ref(); + let parsed = ListingTableUrl::parse(path)?; + let object_store_url = parsed.object_store(); + let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let mut join_set = JoinSet::new(); + for i in 0..plan.output_partitioning().partition_count() { + let plan: Arc = plan.clone(); + let filename = format!("{}/part-{i}.parquet", parsed.prefix()); + let file = Path::parse(filename)?; + let propclone = writer_properties.clone(); + + let storeref = store.clone(); + let buf_writer = BufWriter::new(storeref, file.clone()); + let mut stream = plan.execute(i, task_ctx.clone())?; + join_set.spawn(async move { + let mut writer = + AsyncArrowWriter::try_new(buf_writer, plan.schema(), propclone)?; + while let Some(next_batch) = stream.next().await { + let batch = next_batch?; + writer.write(&batch).await?; + } + writer + .close() + .await + .map_err(DataFusionError::from) + .map(|_| ()) + }); + } + + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => res?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + Ok(()) +} From 71a99b84627de49033037021cfeea1f2cd29db84 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Sat, 1 Jun 2024 16:31:46 -0400 Subject: [PATCH 21/35] minor: consolidate unparser integration tests (#10736) * consolidate unparser integration tests * add license to new files * surpress dead code warnings * run as one integration test binary * add license --- datafusion/sql/tests/cases/mod.rs | 18 + datafusion/sql/tests/cases/plan_to_sql.rs | 290 +++++++++++++ datafusion/sql/tests/common/mod.rs | 227 ++++++++++ datafusion/sql/tests/sql_integration.rs | 477 +--------------------- 4 files changed, 543 insertions(+), 469 deletions(-) create mode 100644 datafusion/sql/tests/cases/mod.rs create mode 100644 datafusion/sql/tests/cases/plan_to_sql.rs create mode 100644 datafusion/sql/tests/common/mod.rs diff --git a/datafusion/sql/tests/cases/mod.rs b/datafusion/sql/tests/cases/mod.rs new file mode 100644 index 000000000000..fc4c59cc88d8 --- /dev/null +++ b/datafusion/sql/tests/cases/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +mod plan_to_sql; diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs new file mode 100644 index 000000000000..1bf441351a97 --- /dev/null +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -0,0 +1,290 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::vec; + +use arrow_schema::*; +use datafusion_common::{DFSchema, Result, TableReference}; +use datafusion_expr::{col, table_scan}; +use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_sql::unparser::dialect::{ + DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect, + MySqlDialect as UnparserMySqlDialect, +}; +use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; + +use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; +use sqlparser::parser::Parser; + +use crate::common::MockContextProvider; + +#[test] +fn roundtrip_expr() { + let tests: Vec<(TableReference, &str, &str)> = vec![ + (TableReference::bare("person"), "age > 35", r#"(age > 35)"#), + ( + TableReference::bare("person"), + "id = '10'", + r#"(id = '10')"#, + ), + ( + TableReference::bare("person"), + "CAST(id AS VARCHAR)", + r#"CAST(id AS VARCHAR)"#, + ), + ( + TableReference::bare("person"), + "SUM((age * 2))", + r#"SUM((age * 2))"#, + ), + ]; + + let roundtrip = |table, sql: &str| -> Result { + let dialect = GenericDialect {}; + let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?; + + let context = MockContextProvider::default(); + let schema = context.get_table_source(table)?.schema(); + let df_schema = DFSchema::try_from(schema.as_ref().clone())?; + let sql_to_rel = SqlToRel::new(&context); + let expr = + sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; + + let ast = expr_to_sql(&expr)?; + + Ok(format!("{}", ast)) + }; + + for (table, query, expected) in tests { + let actual = roundtrip(table, query).unwrap(); + assert_eq!(actual, expected); + } +} + +#[test] +fn roundtrip_statement() -> Result<()> { + let tests: Vec<&str> = vec![ + "select ta.j1_id from j1 ta;", + "select ta.j1_id from j1 ta order by ta.j1_id;", + "select * from j1 ta order by ta.j1_id, ta.j1_string desc;", + "select * from j1 limit 10;", + "select ta.j1_id from j1 ta where ta.j1_id > 1;", + "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);", + "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);", + "select * from (select id, first_name from person)", + "select * from (select id, first_name from (select * from person))", + "select id, count(*) as cnt from (select id from 
person) group by id", + "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id", + "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id))", + r#"select "First Name" from person_quoted_cols"#, + "select DISTINCT id FROM person", + "select DISTINCT on (id) id, first_name from person", + "select DISTINCT on (id) id, first_name from person order by id", + r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#, + "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", + "select id, count(*), first_name from person group by first_name, id", + "select id, sum(age), first_name from person group by first_name, id", + "select id, count(*), first_name + from person + where id!=3 and first_name=='test' + group by first_name, id + having count(*)>5 and count(*)<10 + order by count(*)", + r#"select id, count("First Name") as count_first_name, "Last Name" + from person_quoted_cols + where id!=3 and "First Name"=='test' + group by "Last Name", id + having count_first_name>5 and count_first_name<10 + order by count_first_name, "Last Name""#, + r#"select p.id, count("First Name") as count_first_name, + "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) + from (select id, "First Name", "Last Name" from person_quoted_cols) qp + inner join (select * from person) p + on p.id = qp.id + where p.id!=3 and "First Name"=='test' and qp.id in + (select id from (select id, count(*) from person group by id having count(*) > 0)) + group by "Last Name", p.id + having count_first_name>5 and count_first_name<10 + order by count_first_name, "Last Name""#, + r#"SELECT j1_string as string FROM j1 + UNION ALL + SELECT j2_string as string FROM j2"#, + r#"SELECT j1_string as string FROM j1 + UNION ALL + SELECT j2_string as string FROM j2 + ORDER BY string DESC + LIMIT 10"# + ]; + + // For each test sql string, we transform as follows: + // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) -> LogicalPlan (p2) + // We test not that s1==s2, but rather p1==p2. This ensures that unparser preserves the logical + // query information of the original sql string and disreguards other differences in syntax or + // quoting. + for query in tests { + let dialect = GenericDialect {}; + let statement = Parser::new(&dialect) + .try_with_sql(query)? + .parse_statement()?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + + let roundtrip_statement = plan_to_sql(&plan)?; + + let actual = format!("{}", &roundtrip_statement); + println!("roundtrip sql: {actual}"); + println!("plan {}", plan.display_indent()); + + let plan_roundtrip = sql_to_rel + .sql_statement_to_plan(roundtrip_statement.clone()) + .unwrap(); + + assert_eq!(plan, plan_roundtrip); + } + + Ok(()) +} + +#[test] +fn roundtrip_crossjoin() -> Result<()> { + let query = "select j1.j1_id, j2.j2_string from j1, j2"; + + let dialect = GenericDialect {}; + let statement = Parser::new(&dialect) + .try_with_sql(query)? 
+ .parse_statement()?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + + let roundtrip_statement = plan_to_sql(&plan)?; + + let actual = format!("{}", &roundtrip_statement); + println!("roundtrip sql: {actual}"); + println!("plan {}", plan.display_indent()); + + let plan_roundtrip = sql_to_rel + .sql_statement_to_plan(roundtrip_statement.clone()) + .unwrap(); + + let expected = "Projection: j1.j1_id, j2.j2_string\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: j1\ + \n TableScan: j2"; + + assert_eq!(format!("{plan_roundtrip:?}"), expected); + + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect() -> Result<()> { + struct TestStatementWithDialect { + sql: &'static str, + expected: &'static str, + parser_dialect: Box, + unparser_dialect: Box, + } + let tests: Vec = vec![ + TestStatementWithDialect { + sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", + expected: + "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", + expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST LIMIT 10"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + ]; + + for query in tests { + let statement = Parser::new(&*query.parser_dialect) + .try_with_sql(query.sql)? + .parse_statement()?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + + let unparser = Unparser::new(&*query.unparser_dialect); + let roundtrip_statement = unparser.plan_to_sql(&plan)?; + + let actual = format!("{}", &roundtrip_statement); + println!("roundtrip sql: {actual}"); + println!("plan {}", plan.display_indent()); + + assert_eq!(query.expected, actual); + } + + Ok(()) +} + +#[test] +fn test_unnest_logical_plan() -> Result<()> { + let query = "select unnest(struct_col), unnest(array_col), struct_col, array_col from unnest_table"; + + let dialect = GenericDialect {}; + let statement = Parser::new(&dialect) + .try_with_sql(query)? 
+ .parse_statement()?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + + let expected = "Projection: unnest(unnest_table.struct_col).field1, unnest(unnest_table.struct_col).field2, unnest(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col\ + \n Unnest: lists[unnest(unnest_table.array_col)] structs[unnest(unnest_table.struct_col)]\ + \n Projection: unnest_table.struct_col AS unnest(unnest_table.struct_col), unnest_table.array_col AS unnest(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col\ + \n TableScan: unnest_table"; + + assert_eq!(format!("{plan:?}"), expected); + + Ok(()) +} + +#[test] +fn test_table_references_in_plan_to_sql() { + fn test(table_name: &str, expected_sql: &str) { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ]); + let plan = table_scan(Some(table_name), &schema, None) + .unwrap() + .project(vec![col("id"), col("value")]) + .unwrap() + .build() + .unwrap(); + let sql = plan_to_sql(&plan).unwrap(); + + assert_eq!(format!("{}", sql), expected_sql) + } + + test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id, catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\""); + test("schema.table", "SELECT \"schema\".\"table\".id, \"schema\".\"table\".\"value\" FROM \"schema\".\"table\""); + test( + "table", + "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"", + ); +} diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs new file mode 100644 index 000000000000..79de4bc82691 --- /dev/null +++ b/datafusion/sql/tests/common/mod.rs @@ -0,0 +1,227 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
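The shared test helper module that follows provides the `MockContextProvider` used both by the `cases/plan_to_sql.rs` roundtrip tests above and by `sql_integration.rs`. As a minimal sketch of how those pieces fit together (using only APIs exercised elsewhere in this patch; the function name `example_roundtrip` is illustrative and not part of the patch itself):

fn example_roundtrip(sql: &str) -> datafusion_common::Result<()> {
    use datafusion_sql::planner::SqlToRel;
    use datafusion_sql::unparser::plan_to_sql;
    use sqlparser::{dialect::GenericDialect, parser::Parser};

    // SQL text -> ast::Statement -> LogicalPlan (p1)
    let dialect = GenericDialect {};
    let statement = Parser::new(&dialect).try_with_sql(sql)?.parse_statement()?;
    let context = MockContextProvider::default(); // the helper defined below
    let sql_to_rel = SqlToRel::new(&context);
    let p1 = sql_to_rel.sql_statement_to_plan(statement)?;

    // LogicalPlan -> ast::Statement (s2) -> LogicalPlan (p2)
    let s2 = plan_to_sql(&p1)?;
    let p2 = sql_to_rel.sql_statement_to_plan(s2)?;

    // The tests compare plans, not SQL strings, so the unparser is free to
    // change quoting or syntax as long as the logical query is preserved.
    assert_eq!(p1, p2);
    Ok(())
}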
+ +#[cfg(test)] +use std::collections::HashMap; +use std::{sync::Arc, vec}; + +use arrow_schema::*; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{plan_err, Result, TableReference}; +use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; +use datafusion_sql::planner::ContextProvider; + +#[derive(Default)] +pub(crate) struct MockContextProvider { + options: ConfigOptions, + udfs: HashMap>, + udafs: HashMap>, +} + +impl MockContextProvider { + // Surpressing dead code warning, as this is used in integration test crates + #[allow(dead_code)] + pub(crate) fn options_mut(&mut self) -> &mut ConfigOptions { + &mut self.options + } + + #[allow(dead_code)] + pub(crate) fn with_udf(mut self, udf: ScalarUDF) -> Self { + self.udfs.insert(udf.name().to_string(), Arc::new(udf)); + self + } +} + +impl ContextProvider for MockContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + let schema = match name.table() { + "test" => Ok(Schema::new(vec![ + Field::new("t_date32", DataType::Date32, false), + Field::new("t_date64", DataType::Date64, false), + ])), + "j1" => Ok(Schema::new(vec![ + Field::new("j1_id", DataType::Int32, false), + Field::new("j1_string", DataType::Utf8, false), + ])), + "j2" => Ok(Schema::new(vec![ + Field::new("j2_id", DataType::Int32, false), + Field::new("j2_string", DataType::Utf8, false), + ])), + "j3" => Ok(Schema::new(vec![ + Field::new("j3_id", DataType::Int32, false), + Field::new("j3_string", DataType::Utf8, false), + ])), + "test_decimal" => Ok(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("price", DataType::Decimal128(10, 2), false), + ])), + "person" => Ok(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("first_name", DataType::Utf8, false), + Field::new("last_name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + Field::new("state", DataType::Utf8, false), + Field::new("salary", DataType::Float64, false), + Field::new( + "birth_date", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("😀", DataType::Int32, false), + ])), + "person_quoted_cols" => Ok(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("First Name", DataType::Utf8, false), + Field::new("Last Name", DataType::Utf8, false), + Field::new("Age", DataType::Int32, false), + Field::new("State", DataType::Utf8, false), + Field::new("Salary", DataType::Float64, false), + Field::new( + "Birth Date", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("😀", DataType::Int32, false), + ])), + "orders" => Ok(Schema::new(vec![ + Field::new("order_id", DataType::UInt32, false), + Field::new("customer_id", DataType::UInt32, false), + Field::new("o_item_id", DataType::Utf8, false), + Field::new("qty", DataType::Int32, false), + Field::new("price", DataType::Float64, false), + Field::new("delivered", DataType::Boolean, false), + ])), + "array" => Ok(Schema::new(vec![ + Field::new( + "left", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + Field::new( + "right", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + ])), + "lineitem" => Ok(Schema::new(vec![ + Field::new("l_item_id", DataType::UInt32, false), + Field::new("l_description", DataType::Utf8, false), + Field::new("price", DataType::Float64, false), + ])), + "aggregate_test_100" => Ok(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", 
DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c4", DataType::Int16, false), + Field::new("c5", DataType::Int32, false), + Field::new("c6", DataType::Int64, false), + Field::new("c7", DataType::UInt8, false), + Field::new("c8", DataType::UInt16, false), + Field::new("c9", DataType::UInt32, false), + Field::new("c10", DataType::UInt64, false), + Field::new("c11", DataType::Float32, false), + Field::new("c12", DataType::Float64, false), + Field::new("c13", DataType::Utf8, false), + ])), + "UPPERCASE_test" => Ok(Schema::new(vec![ + Field::new("Id", DataType::UInt32, false), + Field::new("lower", DataType::UInt32, false), + ])), + "unnest_table" => Ok(Schema::new(vec![ + Field::new( + "array_col", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + Field::new( + "struct_col", + DataType::Struct(Fields::from(vec![ + Field::new("field1", DataType::Int64, true), + Field::new("field2", DataType::Utf8, true), + ])), + false, + ), + ])), + _ => plan_err!("No table named: {} found", name.table()), + }; + + match schema { + Ok(t) => Ok(Arc::new(EmptyTable::new(Arc::new(t)))), + Err(e) => Err(e), + } + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.udfs.get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.udafs.get(name).cloned() + } + + fn get_variable_type(&self, _: &[String]) -> Option { + unimplemented!() + } + + fn get_window_meta(&self, _name: &str) -> Option> { + None + } + + fn options(&self) -> &ConfigOptions { + &self.options + } + + fn create_cte_work_table( + &self, + _name: &str, + schema: SchemaRef, + ) -> Result> { + Ok(Arc::new(EmptyTable::new(schema))) + } + + fn udf_names(&self) -> Vec { + self.udfs.keys().cloned().collect() + } + + fn udaf_names(&self) -> Vec { + self.udafs.keys().cloned().collect() + } + + fn udwf_names(&self) -> Vec { + Vec::new() + } +} + +struct EmptyTable { + table_schema: SchemaRef, +} + +impl EmptyTable { + fn new(table_schema: SchemaRef) -> Self { + Self { table_schema } + } +} + +impl TableSource for EmptyTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.table_schema.clone() + } +} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index a7224805f3dd..1f064ea0f543 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -18,35 +18,29 @@ use std::any::Any; #[cfg(test)] use std::collections::HashMap; -use std::{sync::Arc, vec}; +use std::vec; use arrow_schema::TimeUnit::Nanosecond; use arrow_schema::*; -use datafusion_common::config::ConfigOptions; +use common::MockContextProvider; use datafusion_common::{ - assert_contains, plan_err, DFSchema, DataFusionError, ParamValues, Result, - ScalarValue, TableReference, + assert_contains, DataFusionError, ParamValues, Result, ScalarValue, }; -use datafusion_expr::{col, table_scan}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, - AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, TableSource, - Volatility, WindowUDF, + ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::{string, unicode}; -use datafusion_sql::unparser::dialect::{ - DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect, - MySqlDialect as UnparserMySqlDialect, -}; -use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; use datafusion_sql::{ parser::DFParser, - planner::{ContextProvider, 
ParserOptions, PlannerContext, SqlToRel}, + planner::{ParserOptions, SqlToRel}, }; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; -use sqlparser::parser::Parser; + +mod cases; +mod common; #[test] fn test_schema_support() { @@ -2797,184 +2791,6 @@ fn prepare_stmt_replace_params_quick_test( plan } -#[derive(Default)] -struct MockContextProvider { - options: ConfigOptions, - udfs: HashMap>, - udafs: HashMap>, -} - -impl MockContextProvider { - fn options_mut(&mut self) -> &mut ConfigOptions { - &mut self.options - } - - fn with_udf(mut self, udf: ScalarUDF) -> Self { - self.udfs.insert(udf.name().to_string(), Arc::new(udf)); - self - } -} - -impl ContextProvider for MockContextProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - let schema = match name.table() { - "test" => Ok(Schema::new(vec![ - Field::new("t_date32", DataType::Date32, false), - Field::new("t_date64", DataType::Date64, false), - ])), - "j1" => Ok(Schema::new(vec![ - Field::new("j1_id", DataType::Int32, false), - Field::new("j1_string", DataType::Utf8, false), - ])), - "j2" => Ok(Schema::new(vec![ - Field::new("j2_id", DataType::Int32, false), - Field::new("j2_string", DataType::Utf8, false), - ])), - "j3" => Ok(Schema::new(vec![ - Field::new("j3_id", DataType::Int32, false), - Field::new("j3_string", DataType::Utf8, false), - ])), - "test_decimal" => Ok(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("price", DataType::Decimal128(10, 2), false), - ])), - "person" => Ok(Schema::new(vec![ - Field::new("id", DataType::UInt32, false), - Field::new("first_name", DataType::Utf8, false), - Field::new("last_name", DataType::Utf8, false), - Field::new("age", DataType::Int32, false), - Field::new("state", DataType::Utf8, false), - Field::new("salary", DataType::Float64, false), - Field::new( - "birth_date", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - Field::new("😀", DataType::Int32, false), - ])), - "person_quoted_cols" => Ok(Schema::new(vec![ - Field::new("id", DataType::UInt32, false), - Field::new("First Name", DataType::Utf8, false), - Field::new("Last Name", DataType::Utf8, false), - Field::new("Age", DataType::Int32, false), - Field::new("State", DataType::Utf8, false), - Field::new("Salary", DataType::Float64, false), - Field::new( - "Birth Date", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - Field::new("😀", DataType::Int32, false), - ])), - "orders" => Ok(Schema::new(vec![ - Field::new("order_id", DataType::UInt32, false), - Field::new("customer_id", DataType::UInt32, false), - Field::new("o_item_id", DataType::Utf8, false), - Field::new("qty", DataType::Int32, false), - Field::new("price", DataType::Float64, false), - Field::new("delivered", DataType::Boolean, false), - ])), - "array" => Ok(Schema::new(vec![ - Field::new( - "left", - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - false, - ), - Field::new( - "right", - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - false, - ), - ])), - "lineitem" => Ok(Schema::new(vec![ - Field::new("l_item_id", DataType::UInt32, false), - Field::new("l_description", DataType::Utf8, false), - Field::new("price", DataType::Float64, false), - ])), - "aggregate_test_100" => Ok(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::UInt32, false), - Field::new("c3", DataType::Int8, false), - Field::new("c4", DataType::Int16, false), - Field::new("c5", DataType::Int32, 
false), - Field::new("c6", DataType::Int64, false), - Field::new("c7", DataType::UInt8, false), - Field::new("c8", DataType::UInt16, false), - Field::new("c9", DataType::UInt32, false), - Field::new("c10", DataType::UInt64, false), - Field::new("c11", DataType::Float32, false), - Field::new("c12", DataType::Float64, false), - Field::new("c13", DataType::Utf8, false), - ])), - "UPPERCASE_test" => Ok(Schema::new(vec![ - Field::new("Id", DataType::UInt32, false), - Field::new("lower", DataType::UInt32, false), - ])), - "unnest_table" => Ok(Schema::new(vec![ - Field::new( - "array_col", - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - false, - ), - Field::new( - "struct_col", - DataType::Struct(Fields::from(vec![ - Field::new("field1", DataType::Int64, true), - Field::new("field2", DataType::Utf8, true), - ])), - false, - ), - ])), - _ => plan_err!("No table named: {} found", name.table()), - }; - - match schema { - Ok(t) => Ok(Arc::new(EmptyTable::new(Arc::new(t)))), - Err(e) => Err(e), - } - } - - fn get_function_meta(&self, name: &str) -> Option> { - self.udfs.get(name).cloned() - } - - fn get_aggregate_meta(&self, name: &str) -> Option> { - self.udafs.get(name).cloned() - } - - fn get_variable_type(&self, _: &[String]) -> Option { - unimplemented!() - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn create_cte_work_table( - &self, - _name: &str, - schema: SchemaRef, - ) -> Result> { - Ok(Arc::new(EmptyTable::new(schema))) - } - - fn udf_names(&self) -> Vec { - self.udfs.keys().cloned().collect() - } - - fn udaf_names(&self) -> Vec { - self.udafs.keys().cloned().collect() - } - - fn udwf_names(&self) -> Vec { - Vec::new() - } -} - #[test] fn select_partially_qualified_column() { let sql = r#"SELECT person.first_name FROM public.person"#; @@ -4552,283 +4368,6 @@ fn assert_field_not_found(err: DataFusionError, name: &str) { } } -struct EmptyTable { - table_schema: SchemaRef, -} - -impl EmptyTable { - fn new(table_schema: SchemaRef) -> Self { - Self { table_schema } - } -} - -impl TableSource for EmptyTable { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn schema(&self) -> SchemaRef { - self.table_schema.clone() - } -} - -#[test] -fn roundtrip_expr() { - let tests: Vec<(TableReference, &str, &str)> = vec![ - (TableReference::bare("person"), "age > 35", r#"(age > 35)"#), - ( - TableReference::bare("person"), - "id = '10'", - r#"(id = '10')"#, - ), - ( - TableReference::bare("person"), - "CAST(id AS VARCHAR)", - r#"CAST(id AS VARCHAR)"#, - ), - ( - TableReference::bare("person"), - "SUM((age * 2))", - r#"SUM((age * 2))"#, - ), - ]; - - let roundtrip = |table, sql: &str| -> Result { - let dialect = GenericDialect {}; - let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?; - - let context = MockContextProvider::default(); - let schema = context.get_table_source(table)?.schema(); - let df_schema = DFSchema::try_from(schema.as_ref().clone())?; - let sql_to_rel = SqlToRel::new(&context); - let expr = - sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; - - let ast = expr_to_sql(&expr)?; - - Ok(format!("{}", ast)) - }; - - for (table, query, expected) in tests { - let actual = roundtrip(table, query).unwrap(); - assert_eq!(actual, expected); - } -} - -#[test] -fn roundtrip_statement() -> Result<()> { - let tests: Vec<&str> = vec![ - "select ta.j1_id from j1 ta;", - "select ta.j1_id from j1 ta order by ta.j1_id;", - "select * from j1 ta 
order by ta.j1_id, ta.j1_string desc;", - "select * from j1 limit 10;", - "select ta.j1_id from j1 ta where ta.j1_id > 1;", - "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);", - "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);", - "select * from (select id, first_name from person)", - "select * from (select id, first_name from (select * from person))", - "select id, count(*) as cnt from (select id from person) group by id", - "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id", - "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id))", - r#"select "First Name" from person_quoted_cols"#, - "select DISTINCT id FROM person", - "select DISTINCT on (id) id, first_name from person", - "select DISTINCT on (id) id, first_name from person order by id", - r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#, - "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", - "select id, count(*), first_name from person group by first_name, id", - "select id, sum(age), first_name from person group by first_name, id", - "select id, count(*), first_name - from person - where id!=3 and first_name=='test' - group by first_name, id - having count(*)>5 and count(*)<10 - order by count(*)", - r#"select id, count("First Name") as count_first_name, "Last Name" - from person_quoted_cols - where id!=3 and "First Name"=='test' - group by "Last Name", id - having count_first_name>5 and count_first_name<10 - order by count_first_name, "Last Name""#, - r#"select p.id, count("First Name") as count_first_name, - "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) - from (select id, "First Name", "Last Name" from person_quoted_cols) qp - inner join (select * from person) p - on p.id = qp.id - where p.id!=3 and "First Name"=='test' and qp.id in - (select id from (select id, count(*) from person group by id having count(*) > 0)) - group by "Last Name", p.id - having count_first_name>5 and count_first_name<10 - order by count_first_name, "Last Name""#, - r#"SELECT j1_string as string FROM j1 - UNION ALL - SELECT j2_string as string FROM j2"#, - r#"SELECT j1_string as string FROM j1 - UNION ALL - SELECT j2_string as string FROM j2 - ORDER BY string DESC - LIMIT 10"# - ]; - - // For each test sql string, we transform as follows: - // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) -> LogicalPlan (p2) - // We test not that s1==s2, but rather p1==p2. This ensures that unparser preserves the logical - // query information of the original sql string and disreguards other differences in syntax or - // quoting. - for query in tests { - let dialect = GenericDialect {}; - let statement = Parser::new(&dialect) - .try_with_sql(query)? 
- .parse_statement()?; - - let context = MockContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context); - let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); - - let roundtrip_statement = plan_to_sql(&plan)?; - - let actual = format!("{}", &roundtrip_statement); - println!("roundtrip sql: {actual}"); - println!("plan {}", plan.display_indent()); - - let plan_roundtrip = sql_to_rel - .sql_statement_to_plan(roundtrip_statement.clone()) - .unwrap(); - - assert_eq!(plan, plan_roundtrip); - } - - Ok(()) -} - -#[test] -fn roundtrip_crossjoin() -> Result<()> { - let query = "select j1.j1_id, j2.j2_string from j1, j2"; - - let dialect = GenericDialect {}; - let statement = Parser::new(&dialect) - .try_with_sql(query)? - .parse_statement()?; - - let context = MockContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context); - let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); - - let roundtrip_statement = plan_to_sql(&plan)?; - - let actual = format!("{}", &roundtrip_statement); - println!("roundtrip sql: {actual}"); - println!("plan {}", plan.display_indent()); - - let plan_roundtrip = sql_to_rel - .sql_statement_to_plan(roundtrip_statement.clone()) - .unwrap(); - - let expected = "Projection: j1.j1_id, j2.j2_string\ - \n Inner Join: Filter: Boolean(true)\ - \n TableScan: j1\ - \n TableScan: j2"; - - assert_eq!(format!("{plan_roundtrip:?}"), expected); - - Ok(()) -} - -#[test] -fn roundtrip_statement_with_dialect() -> Result<()> { - struct TestStatementWithDialect { - sql: &'static str, - expected: &'static str, - parser_dialect: Box, - unparser_dialect: Box, - } - let tests: Vec = vec![ - TestStatementWithDialect { - sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", - expected: - "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10", - parser_dialect: Box::new(MySqlDialect {}), - unparser_dialect: Box::new(UnparserMySqlDialect {}), - }, - TestStatementWithDialect { - sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", - expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST LIMIT 10"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - ]; - - for query in tests { - let statement = Parser::new(&*query.parser_dialect) - .try_with_sql(query.sql)? - .parse_statement()?; - - let context = MockContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context); - let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); - - let unparser = Unparser::new(&*query.unparser_dialect); - let roundtrip_statement = unparser.plan_to_sql(&plan)?; - - let actual = format!("{}", &roundtrip_statement); - println!("roundtrip sql: {actual}"); - println!("plan {}", plan.display_indent()); - - assert_eq!(query.expected, actual); - } - - Ok(()) -} - -#[test] -fn test_unnest_logical_plan() -> Result<()> { - let query = "select unnest(struct_col), unnest(array_col), struct_col, array_col from unnest_table"; - - let dialect = GenericDialect {}; - let statement = Parser::new(&dialect) - .try_with_sql(query)? 
- .parse_statement()?; - - let context = MockContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context); - let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); - - let expected = "Projection: unnest(unnest_table.struct_col).field1, unnest(unnest_table.struct_col).field2, unnest(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col\ - \n Unnest: lists[unnest(unnest_table.array_col)] structs[unnest(unnest_table.struct_col)]\ - \n Projection: unnest_table.struct_col AS unnest(unnest_table.struct_col), unnest_table.array_col AS unnest(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col\ - \n TableScan: unnest_table"; - - assert_eq!(format!("{plan:?}"), expected); - - Ok(()) -} - -#[test] -fn test_table_references_in_plan_to_sql() { - fn test(table_name: &str, expected_sql: &str) { - let schema = Schema::new(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("value", DataType::Utf8, false), - ]); - let plan = table_scan(Some(table_name), &schema, None) - .unwrap() - .project(vec![col("id"), col("value")]) - .unwrap() - .build() - .unwrap(); - let sql = plan_to_sql(&plan).unwrap(); - - assert_eq!(format!("{}", sql), expected_sql) - } - - test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id, catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\""); - test("schema.table", "SELECT \"schema\".\"table\".id, \"schema\".\"table\".\"value\" FROM \"schema\".\"table\""); - test( - "table", - "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"", - ); -} - #[cfg(test)] #[ctor::ctor] fn init() { From 1db3263497532fda5167386781755463f53c00a4 Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Sun, 2 Jun 2024 02:02:04 +0200 Subject: [PATCH 22/35] chore: move aggregate_var to slt (#10750) --- .../physical-expr/src/aggregate/variance.rs | 109 ---------------- .../sqllogictest/test_files/aggregate.slt | 118 ++++++++++++++++++ 2 files changed, 118 insertions(+), 109 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index 989041097730..7ae917409a21 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -337,117 +337,8 @@ mod tests { use super::*; use crate::aggregate::utils::get_accum_scalar_values_as_arrays; use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; use arrow::{array::*, datatypes::*}; - #[test] - fn variance_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64])); - generic_test_op!( - a, - DataType::Float64, - VariancePop, - ScalarValue::from(0.25_f64) - ) - } - - #[test] - fn variance_f64_2() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, VariancePop, ScalarValue::from(2_f64)) - } - - #[test] - fn variance_f64_3() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Variance, ScalarValue::from(2.5_f64)) - } - - #[test] - fn variance_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - generic_test_op!( - a, - DataType::Float64, - Variance, - ScalarValue::from(0.9033333333333333_f64) - ) - } - - #[test] - fn variance_i32() -> Result<()> { - let a: ArrayRef = 
Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, VariancePop, ScalarValue::from(2_f64)) - } - - #[test] - fn variance_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, VariancePop, ScalarValue::from(2.0f64)) - } - - #[test] - fn variance_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, VariancePop, ScalarValue::from(2_f64)) - } - - #[test] - fn test_variance_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - let agg = Arc::new(Variance::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); - - Ok(()) - } - - #[test] - fn variance_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!( - a, - DataType::Int32, - VariancePop, - ScalarValue::from(2.1875_f64) - ) - } - - #[test] - fn variance_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - let agg = Arc::new(Variance::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); - - Ok(()) - } - #[test] fn variance_f64_merge_1() -> Result<()> { let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 03e8fad8a7f8..df6a37644838 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2458,7 +2458,125 @@ NULL Float64 statement ok drop table t; +# aggregate variance f64_1 +statement ok +create table t (c1 double) as values (1), (2); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +0.25 Float64 + +statement ok +drop table t; + +# aggregate variance f64_2 +statement ok +create table t (c1 double) as values (1), (2), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2 Float64 + +statement ok +drop table t; + +# aggregate variance f64_3 +statement ok +create table t (c1 double) as values (1), (2), (3), (4), (5); + +query RT +select var(c1), arrow_typeof(var(c1)) from t; +---- +2.5 Float64 +statement ok +drop table t; + +# aggregate variance f64_4 +statement ok +create table t (c1 double) as values (1.1), (2), (3); + +query RT +select var(c1), arrow_typeof(var(c1)) from t; +---- +0.903333333333 Float64 + +statement ok +drop table t; + +# aggregate variance i32 +statement ok +create table t (c1 int) as values (1), (2), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2 Float64 + +statement ok +drop table t; + +# aggregate variance u32 +statement ok +create table t (c1 int unsigned) as values (1), (2), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2 Float64 + 
+statement ok +drop table t; + +# aggregate variance f32 +statement ok +create table t (c1 float) as values (1), (2), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2 Float64 + +statement ok +drop table t; + +# aggregate single input +statement ok +create table t (c1 double) as values (1); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +0 Float64 + +statement ok +drop table t; + +# aggregate i32 with nulls +statement ok +create table t (c1 int) as values (1), (null), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2.1875 Float64 + +statement ok +drop table t; + +# aggregate i32 all nulls +statement ok +create table t (c1 int) as values (null), (null); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +NULL Float64 + +statement ok +drop table t; # simple_mean query R From 59bfe773b61c8bcc83a89502221f849fc12def23 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Sun, 2 Jun 2024 20:33:56 -0700 Subject: [PATCH 23/35] fix: fix string repeat for negative numbers (#10760) * fix: fix string repeat for negative numbers * style: run cargo fmt --- datafusion/functions/src/string/repeat.rs | 5 ++++- datafusion/sqllogictest/test_files/expr.slt | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index a70d0a162562..9d122f6101a7 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -88,7 +88,10 @@ fn repeat(args: &[ArrayRef]) -> Result { .iter() .zip(number_array.iter()) .map(|(string, number)| match (string, number) { - (Some(string), Some(number)) => Some(string.repeat(number as usize)), + (Some(string), Some(number)) if number >= 0 => { + Some(string.repeat(number as usize)) + } + (Some(_), Some(_)) => Some("".to_string()), _ => None, }) .collect::>(); diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 2dc00cbc5001..b6477f0b57d0 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -541,6 +541,11 @@ SELECT repeat('Pg', 4) ---- PgPgPgPg +query T +SELECT repeat('Pg', -1) +---- +(empty) + query T SELECT repeat('Pg', CAST(NULL AS INT)) ---- From 888504a8da6d20f9caf3ecb6cd1a6b7d1956e23e Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 3 Jun 2024 19:43:30 +0800 Subject: [PATCH 24/35] Introduce Sum UDAF (#10651) * move accumulate Signed-off-by: jayzhan211 * move prim_op Signed-off-by: jayzhan211 * move test to slt Signed-off-by: jayzhan211 * remove sum distinct Signed-off-by: jayzhan211 * move sum aggregate Signed-off-by: jayzhan211 * fix args Signed-off-by: jayzhan211 * add sum Signed-off-by: jayzhan211 * merge fix Signed-off-by: jayzhan211 * fix sum sig Signed-off-by: jayzhan211 * todo: wait ahash merge Signed-off-by: jayzhan211 * rebase Signed-off-by: jayzhan211 * disable ordering req by default Signed-off-by: jayzhan211 * check arg count Signed-off-by: jayzhan211 * rm old workflow Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix failed test Signed-off-by: jayzhan211 * doc and fmt Signed-off-by: jayzhan211 * check udaf first Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fix err msg AGAIN Signed-off-by: jayzhan211 * rm sum in builtin test which covered in sql Signed-off-by: 
jayzhan211 * proto for window with udaf Signed-off-by: jayzhan211 * fix slt Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix err msg Signed-off-by: jayzhan211 * fix exprfn Signed-off-by: jayzhan211 * fix ciy Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * rename first/last to lowercase Signed-off-by: jayzhan211 * skip sum Signed-off-by: jayzhan211 * fix firstvalue Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * add doc Signed-off-by: jayzhan211 * rm has_ordering_req Signed-off-by: jayzhan211 * default hard req Signed-off-by: jayzhan211 * insensitive for sum Signed-off-by: jayzhan211 * cleanup duplicate code Signed-off-by: jayzhan211 * Re-introduce check --------- Signed-off-by: jayzhan211 Co-authored-by: Mustafa Akur --- datafusion-cli/Cargo.lock | 1 + datafusion-examples/examples/advanced_udaf.rs | 5 +- .../examples/simplify_udaf_expression.rs | 6 +- datafusion/core/src/dataframe/mod.rs | 8 +- datafusion/core/src/physical_planner.rs | 5 +- datafusion/core/src/prelude.rs | 1 - datafusion/core/tests/dataframe/mod.rs | 6 +- .../core/tests/fuzz_cases/window_fuzz.rs | 12 +- .../user_defined/user_defined_aggregates.rs | 7 +- .../user_defined_scalar_functions.rs | 4 +- .../expr/src/built_in_window_function.rs | 4 +- datafusion/expr/src/expr.rs | 18 +- datafusion/expr/src/expr_fn.rs | 2 + datafusion/expr/src/expr_schema.rs | 39 +- datafusion/expr/src/function.rs | 6 +- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udaf.rs | 82 +++- datafusion/functions-aggregate/Cargo.toml | 1 + .../functions-aggregate/src/first_last.rs | 7 +- datafusion/functions-aggregate/src/lib.rs | 8 + datafusion/functions-aggregate/src/sum.rs | 436 ++++++++++++++++++ .../optimizer/src/analyzer/type_coercion.rs | 18 +- .../simplify_expressions/expr_simplifier.rs | 5 +- .../src/single_distinct_to_groupby.rs | 5 +- .../physical-expr-common/src/aggregate/mod.rs | 120 +++-- .../src/aggregate/utils.rs | 3 +- .../physical-expr/src/aggregate/build_in.rs | 102 +--- datafusion/physical-plan/src/windows/mod.rs | 46 +- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 13 + datafusion/proto/src/generated/prost.rs | 5 +- .../proto/src/physical_plan/from_proto.rs | 62 ++- .../proto/src/physical_plan/to_proto.rs | 32 +- .../tests/cases/roundtrip_logical_plan.rs | 4 +- .../tests/cases/roundtrip_physical_plan.rs | 21 +- datafusion/sql/src/expr/function.rs | 34 +- .../sqllogictest/test_files/aggregate.slt | 12 +- datafusion/sqllogictest/test_files/order.slt | 2 +- .../test_files/sort_merge_join.slt | 1 + datafusion/sqllogictest/test_files/unnest.slt | 4 +- datafusion/sqllogictest/test_files/window.slt | 36 +- 41 files changed, 888 insertions(+), 299 deletions(-) create mode 100644 datafusion/functions-aggregate/src/sum.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 6a1ba8aba005..304058650164 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1287,6 +1287,7 @@ dependencies = [ name = "datafusion-functions-aggregate" version = "38.0.0" dependencies = [ + "ahash", "arrow", "arrow-schema", "datafusion-common", diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index cf284472212f..2c672a18a738 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -105,7 +105,10 @@ impl AggregateUDFImpl for GeoMeanUdaf { true } - fn create_groups_accumulator(&self) -> Result> { + fn 
create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { Ok(Box::new(GeometricMeanGroupsAccumulator::new())) } } diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs index 08b6bcab0190..d2c8c6a86c7c 100644 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -78,9 +78,13 @@ impl AggregateUDFImpl for BetterAvgUdaf { true } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { unimplemented!("should not get here"); } + // we override method, to return new expression which would substitute // user defined function call fn simplify(&self) -> Option { diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index aac506d48ba9..5b1aef5d2b20 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -53,8 +53,9 @@ use datafusion_expr::{ avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; -use datafusion_expr::{case, is_null, sum}; +use datafusion_expr::{case, is_null}; use datafusion_functions_aggregate::expr_fn::median; +use datafusion_functions_aggregate::expr_fn::sum; use async_trait::async_trait; @@ -1593,9 +1594,8 @@ mod tests { use datafusion_common::{Constraint, Constraints}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, count_distinct, create_udf, expr, lit, sum, - BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunctionDefinition, + array_agg, cast, count_distinct, create_udf, expr, lit, BuiltInWindowFunction, + ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 5e2e546a86f6..3bc898353224 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2257,9 +2257,8 @@ mod tests { use datafusion_common::{assert_contains, DFSchemaRef, TableReference}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; - use datafusion_expr::{ - col, lit, sum, LogicalPlanBuilder, UserDefinedLogicalNodeCore, - }; + use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; + use datafusion_functions_aggregate::expr_fn::sum; use datafusion_physical_expr::EquivalenceProperties; fn make_session_state() -> SessionState { diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index 0d8d06f49bc3..d82a5a2cc1a1 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -39,7 +39,6 @@ pub use datafusion_expr::{ Expr, }; pub use datafusion_functions::expr_fn::*; -pub use datafusion_functions_aggregate::expr_fn::*; #[cfg(feature = "array_expressions")] pub use datafusion_functions_array::expr_fn::*; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 60e60bb1e3b1..befd98d04302 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -52,10 +52,10 @@ use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ 
array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, sum, when, wildcard, AggregateFunction, Expr, - ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + placeholder, scalar_subquery, when, wildcard, AggregateFunction, Expr, ExprSchemable, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_aggregate::expr_fn::sum; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index fe0c408dc114..b85f6376c3f2 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -33,10 +33,12 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::type_coercion::aggregates::coerce_types; +use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; @@ -341,7 +343,7 @@ fn get_random_function( window_fn_map.insert( "sum", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![arg.clone()], ), ); @@ -468,6 +470,14 @@ fn get_random_function( let coerced = coerce_types(f, &[dt], &sig).unwrap(); args[0] = cast(a, schema, coerced[0].clone()).unwrap(); } + } else if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn { + if !args.is_empty() { + // Do type coercion first argument + let a = args[0].clone(); + let dt = a.data_type(schema.as_ref()).unwrap(); + let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap(); + args[0] = cast(a, schema, coerced[0].clone()).unwrap(); + } } (window_fn.clone(), args, fn_name.to_string()) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index d199f04ba781..071db5adf06a 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -142,7 +142,7 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() { let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; // Note if this query ever does start working let err = execute(&ctx, sql).await.unwrap_err(); - assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { inner: AggregateUDF { name: \"time_sum\", signature: Signature { type_signature: Exact([Timestamp(Nanosecond, None)]), volatility: Immutable }, fun: \"\" } }(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING"); + assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING"); } /// Basic query for 
with a udaf returning a structure @@ -729,7 +729,10 @@ impl AggregateUDFImpl for TestGroupsAccumulator { true } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { Ok(Box::new(self.clone())) } } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index df41cab7bf02..2d98b7f80fc5 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -378,8 +378,8 @@ async fn udaf_as_window_func() -> Result<()> { context.register_udaf(my_acc); let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table"; - let expected = r#"Projection: my_table.a, AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - WindowAggr: windowExpr=[[AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + let expected = r#"Projection: my_table.a, my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] TableScan: my_table"#; let dataframe = context.sql(sql).await.unwrap(); diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index 18a888ae8b2a..3885d70049f3 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -82,8 +82,8 @@ impl BuiltInWindowFunction { Ntile => "NTILE", Lag => "LAG", Lead => "LEAD", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", + FirstValue => "first_value", + LastValue => "last_value", NthValue => "NTH_VALUE", } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 71cf3adddffa..14c64ef8f89d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -754,10 +754,14 @@ impl WindowFunctionDefinition { impl fmt::Display for WindowFunctionDefinition { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f), - WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f), - WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), - WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f), + WindowFunctionDefinition::AggregateFunction(fun) => { + std::fmt::Display::fmt(fun, f) + } + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + std::fmt::Display::fmt(fun, f) + } + WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Display::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => std::fmt::Display::fmt(fun, f), } } } @@ -2263,7 +2267,11 @@ mod test { let fun = find_df_window_func(name).unwrap(); let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); + if fun.to_string() == "first_value" || fun.to_string() == "last_value" { + assert_eq!(fun.to_string(), name); + } else { + assert_eq!(fun.to_string(), 
name.to_uppercase()); + } } Ok(()) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8c9d3c7885b0..694911592b5d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -169,6 +169,8 @@ pub fn max(expr: Expr) -> Expr { } /// Create an expression to represent the sum() aggregate function +/// +/// TODO: Remove this function and use `sum` from `datafusion_functions_aggregate::expr_fn` instead pub fn sum(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Sum, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 01c9edff306e..57470db2e0d9 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,8 +21,10 @@ use crate::expr::{ InSubquery, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; use crate::type_coercion::binary::get_result_type; -use crate::type_coercion::functions::data_types_with_scalar_udf; -use crate::{utils, LogicalPlan, Projection, Subquery}; +use crate::type_coercion::functions::{ + data_types_with_aggregate_udf, data_types_with_scalar_udf, +}; +use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ @@ -158,7 +160,25 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - fun.return_type(&data_types) + match fun { + WindowFunctionDefinition::AggregateUDF(udf) => { + let new_types = data_types_with_aggregate_udf(&data_types, udf).map_err(|err| { + plan_datafusion_err!( + "{} and {}", + err, + utils::generate_signature_error_msg( + fun.name(), + fun.signature().clone(), + &data_types + ) + ) + })?; + Ok(fun.return_type(&new_types)?) + } + _ => { + fun.return_type(&data_types) + } + } } Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { let data_types = args @@ -170,7 +190,18 @@ impl ExprSchemable for Expr { fun.return_type(&data_types) } AggregateFunctionDefinition::UDF(fun) => { - Ok(fun.return_type(&data_types)?) + let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| { + plan_datafusion_err!( + "{} and {}", + err, + utils::generate_signature_error_msg( + fun.name(), + fun.signature().clone(), + &data_types + ) + ) + })?; + Ok(fun.return_type(&new_types)?) } } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 7f49b03bb2ce..c06f177510e7 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -70,6 +70,9 @@ pub struct AccumulatorArgs<'a> { /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. pub sort_exprs: &'a [Expr], + /// The name of the aggregate expression + pub name: &'a str, + /// Whether the aggregate function is distinct. /// /// ```sql @@ -82,9 +85,6 @@ pub struct AccumulatorArgs<'a> { /// The number of arguments the aggregate function takes. 
pub args_num: usize, - - /// The name of the expression - pub name: &'a str, } /// [`StateFieldsArgs`] contains information about the fields that an diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 74d6b4149dbe..bbd1d6f654f1 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -64,7 +64,7 @@ pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, TryCast, WindowFunctionDefinition, + Like, Sort as SortExpr, TryCast, WindowFunctionDefinition, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 0274038a36bf..d778203207c9 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -83,6 +83,12 @@ impl std::hash::Hash for AggregateUDF { } } +impl std::fmt::Display for AggregateUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + impl AggregateUDF { /// Create a new AggregateUDF /// @@ -190,8 +196,22 @@ impl AggregateUDF { } /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details. - pub fn create_groups_accumulator(&self) -> Result> { - self.inner.create_groups_accumulator() + pub fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.inner.create_groups_accumulator(args) + } + + pub fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.inner.create_sliding_accumulator(args) + } + + pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) } /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details. @@ -213,16 +233,8 @@ impl AggregateUDF { /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will /// generate same result with this `AggregateUDF` when iterated in reverse /// order, and `None` if there is no such `AggregateUDF`). - pub fn reverse_udf(&self) -> Option { - match self.inner.reverse_expr() { - ReversedUDAF::NotSupported => None, - ReversedUDAF::Identical => Some(self.clone()), - ReversedUDAF::Reversed(reverse) => Some(Self { inner: reverse }), - } - } - - pub fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { - not_impl_err!("coerce_types not implemented for {:?} yet", self.name()) + pub fn reverse_udf(&self) -> ReversedUDAF { + self.inner.reverse_expr() } /// Do the function rewrite @@ -327,7 +339,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// /// # Arguments: /// 1. `name`: the name of the expression (e.g. AVG, SUM, etc) - /// 2. `value_type`: Aggregate's aggregate's output (returned by [`Self::return_type`]) + /// 2. `value_type`: Aggregate function output returned by [`Self::return_type`] if defined, otherwise + /// it is equivalent to the data type of the first arguments /// 3. `ordering_fields`: the fields used to order the input arguments, if any. /// Empty if no ordering expression is provided. /// @@ -377,7 +390,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// /// For maximum performance, a [`GroupsAccumulator`] should be /// implemented in addition to [`Accumulator`]. 
- fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } @@ -389,6 +405,19 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { &[] } + /// Sliding accumulator is an alternative accumulator that can be used for + /// window functions. It has retract method to revert the previous update. + /// + /// See [retract_batch] for more details. + /// + /// [retract_batch]: crate::accumulator::Accumulator::retract_batch + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.accumulator(args) + } + /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is /// satisfied by its input. If this is not the case, UDFs with order /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce @@ -451,6 +480,29 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::NotSupported } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most + /// UDAFs should return one of the other variants of `TypeSignature` which handle common + /// cases + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// For example, if your function requires a floating point arguments, but the user calls + /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]` + /// to ensure the argument was cast to `1::double` + /// + /// # Parameters + /// * `arg_types`: The argument types of the arguments this function with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } } pub enum ReversedUDAF { @@ -459,7 +511,7 @@ pub enum ReversedUDAF { /// The expression does not support reverse calculation, like ArrayAgg NotSupported, /// The expression is different from the original expression - Reversed(Arc), + Reversed(Arc), } /// AggregateUDF that adds an alias to the underlying function. 
It is better to diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 696bbaece9e6..26630a0352d5 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -38,6 +38,7 @@ path = "src/lib.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +ahash = { workspace = true } arrow = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index f1cb92045f59..fe4501c14948 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -75,7 +75,8 @@ impl FirstValue { vec![ // TODO: we can introduce more strict signature that only numeric of array types are allowed TypeSignature::ArraySignature(ArrayFunctionSignature::Array), - TypeSignature::Uniform(1, NUMERICS.to_vec()), + TypeSignature::Numeric(1), + TypeSignature::Uniform(1, vec![DataType::Utf8]), ], Volatility::Immutable, ), @@ -159,7 +160,7 @@ impl AggregateUDFImpl for FirstValue { } fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { - datafusion_expr::ReversedUDAF::Reversed(last_value_udaf().inner()) + datafusion_expr::ReversedUDAF::Reversed(last_value_udaf()) } } @@ -462,7 +463,7 @@ impl AggregateUDFImpl for LastValue { } fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { - datafusion_expr::ReversedUDAF::Reversed(first_value_udaf().inner()) + datafusion_expr::ReversedUDAF::Reversed(first_value_udaf()) } } diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index e82897e92693..cb8ef65420c2 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -58,6 +58,7 @@ pub mod macros; pub mod covariance; pub mod first_last; pub mod median; +pub mod sum; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; @@ -72,6 +73,7 @@ pub mod expr_fn { pub use super::first_last::first_value; pub use super::first_last::last_value; pub use super::median::median; + pub use super::sum::sum; } /// Returns all default aggregate functions @@ -80,6 +82,7 @@ pub fn all_default_aggregate_functions() -> Vec> { first_last::first_value_udaf(), first_last::last_value_udaf(), covariance::covar_samp_udaf(), + sum::sum_udaf(), covariance::covar_pop_udaf(), median::median_udaf(), ] @@ -110,6 +113,11 @@ mod tests { fn test_no_duplicate_name() -> Result<()> { let mut names = HashSet::new(); for func in all_default_aggregate_functions() { + // TODO: remove this + // sum is in intermidiate migration state, skip this + if func.name().to_lowercase() == "sum" { + continue; + } assert!( names.insert(func.name().to_string().to_lowercase()), "duplicate function name: {}", diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs new file mode 100644 index 000000000000..b3127726cbbf --- /dev/null +++ b/datafusion/functions-aggregate/src/sum.rs @@ -0,0 +1,436 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators + +use ahash::RandomState; +use datafusion_expr::utils::AggregateOrderSensitivity; +use std::any::Any; +use std::collections::HashSet; + +use arrow::array::Array; +use arrow::array::ArrowNativeTypeOp; +use arrow::array::{ArrowNumericType, AsArray}; +use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::ArrowPrimitiveType; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, +}; +use arrow::{array::ArrayRef, datatypes::Field}; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, +}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_physical_expr_common::aggregate::utils::Hashable; + +make_udaf_expr_and_func!( + Sum, + sum, + expression, + "Returns the first value in a group of values.", + sum_udaf +); + +/// Sum only supports a subset of numeric types, instead relying on type coercion +/// +/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive) +/// +/// `args` is [AccumulatorArgs] +/// `helper` is a macro accepting (ArrowPrimitiveType, DataType) +macro_rules! downcast_sum { + ($args:ident, $helper:ident) => { + match $args.data_type { + DataType::UInt64 => $helper!(UInt64Type, $args.data_type), + DataType::Int64 => $helper!(Int64Type, $args.data_type), + DataType::Float64 => $helper!(Float64Type, $args.data_type), + DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.data_type), + DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.data_type), + _ => { + not_impl_err!("Sum not supported for {}: {}", $args.name, $args.data_type) + } + } + }; +} + +#[derive(Debug)] +pub struct Sum { + signature: Signature, + aliases: Vec, +} + +impl Sum { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["sum".to_string()], + } + } +} + +impl Default for Sum { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Sum { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "SUM" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return exec_err!("SUM expects exactly one argument"); + } + + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval. 
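        // Illustration only (not part of this patch): the mapping implemented by
        // `coerced_type` just below, e.g. `SUM(tinyint_col)` is planned as
        // `SUM(CAST(tinyint_col AS BIGINT))`:
        //   signed integers      -> Int64
        //   unsigned integers    -> UInt64
        //   floating point       -> Float64
        //   Decimal128/256(p, s) -> unchanged here; precision is widened in `return_type`
        //   dictionaries         -> coerced according to their value type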
+ + fn coerced_type(data_type: &DataType) -> Result { + match data_type { + DataType::Dictionary(_, v) => coerced_type(v), + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { + Ok(data_type.clone()) + } + dt if dt.is_signed_integer() => Ok(DataType::Int64), + dt if dt.is_unsigned_integer() => Ok(DataType::UInt64), + dt if dt.is_floating() => Ok(DataType::Float64), + _ => exec_err!("Sum not supported for {}", data_type), + } + } + + Ok(vec![coerced_type(&arg_types[0])?]) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Int64 => Ok(DataType::Int64), + DataType::UInt64 => Ok(DataType::UInt64), + DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal128(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal128(new_precision, *scale)) + } + DataType::Decimal256(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } + other => { + exec_err!("[return_type] SUM not supported for {}", other) + } + } + } + + fn accumulator(&self, args: AccumulatorArgs) -> Result> { + if args.is_distinct { + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?)) + }; + } + downcast_sum!(args, helper) + } else { + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(args, helper) + } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + Ok(vec![Field::new_list( + format_state_name(args.name, "sum distinct"), + Field::new("item", args.return_type.clone(), true), + false, + )]) + } else { + Ok(vec![Field::new( + format_state_name(args.name, "sum"), + args.return_type.clone(), + true, + )]) + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + !args.is_distinct + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( + &$dt, + |x, y| *x = x.add_wrapping(y), + ))) + }; + } + downcast_sum!(args, helper) + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + macro_rules! 
helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(args, helper) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } +} + +/// This accumulator computes SUM incrementally +struct SumAccumulator { + sum: Option, + data_type: DataType, +} + +impl std::fmt::Debug for SumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SumAccumulator({})", self.data_type) + } +} + +impl SumAccumulator { + fn new(data_type: DataType) -> Self { + Self { + sum: None, + data_type, + } + } +} + +impl Accumulator for SumAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + if let Some(x) = arrow::compute::sum(values) { + let v = self.sum.get_or_insert(T::Native::usize_as(0)); + *v = v.add_wrapping(x); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + ScalarValue::new_primitive::(self.sum, &self.data_type) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + +/// This accumulator incrementally computes sums over a sliding window +/// +/// This is separate from [`SumAccumulator`] as requires additional state +struct SlidingSumAccumulator { + sum: T::Native, + count: u64, + data_type: DataType, +} + +impl std::fmt::Debug for SlidingSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SlidingSumAccumulator({})", self.data_type) + } +} + +impl SlidingSumAccumulator { + fn new(data_type: DataType) -> Self { + Self { + sum: T::Native::usize_as(0), + count: 0, + data_type, + } + } +} + +impl Accumulator for SlidingSumAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?, self.count.into()]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count += (values.len() - values.null_count()) as u64; + if let Some(x) = arrow::compute::sum(values) { + self.sum = self.sum.add_wrapping(x) + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let values = states[0].as_primitive::(); + if let Some(x) = arrow::compute::sum(values) { + self.sum = self.sum.add_wrapping(x) + } + if let Some(x) = arrow::compute::sum(states[1].as_primitive::()) { + self.count += x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let v = (self.count != 0).then_some(self.sum); + ScalarValue::new_primitive::(v, &self.data_type) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + if let Some(x) = arrow::compute::sum(values) { + self.sum = self.sum.sub_wrapping(x) + } + self.count -= (values.len() - values.null_count()) as u64; + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} + +struct DistinctSumAccumulator { + values: HashSet, RandomState>, + data_type: DataType, +} + +impl std::fmt::Debug for DistinctSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctSumAccumulator({})", self.data_type) + } +} + +impl DistinctSumAccumulator { + pub fn try_new(data_type: 
&DataType) -> Result { + Ok(Self { + values: HashSet::default(), + data_type: data_type.clone(), + }) + } +} + +impl Accumulator for DistinctSumAccumulator { + fn state(&mut self) -> Result> { + // 1. Stores aggregate state in `ScalarValue::List` + // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set + let state_out = { + let distinct_values = self + .values + .iter() + .map(|value| { + ScalarValue::new_primitive::(Some(value.0), &self.data_type) + }) + .collect::>>()?; + + vec![ScalarValue::List(ScalarValue::new_list( + &distinct_values, + &self.data_type, + ))] + }; + Ok(state_out) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.values.insert(Hashable(array.value(idx))); + } + } + None => array.values().iter().for_each(|x| { + self.values.insert(Hashable(*x)); + }), + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + for x in states[0].as_list::().iter().flatten() { + self.update_batch(&[x])? + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let mut acc = T::Native::usize_as(0); + for distinct_value in self.values.iter() { + acc = acc.add_wrapping(distinct_value.0) + } + let v = (!self.values.is_empty()).then_some(acc); + ScalarValue::new_primitive::(v, &self.data_type) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.values.capacity() * std::mem::size_of::() + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 081a54ac44f6..31dc9028b915 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -430,6 +430,13 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { &fun.signature(), )? } + expr::WindowFunctionDefinition::AggregateUDF(udf) => { + coerce_arguments_for_signature_with_aggregate_udf( + args, + self.schema, + udf, + )? 
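                // Illustration only (not part of this patch): with this arm a window call such as
                // `sum(int32_col) OVER (PARTITION BY k)` now has its argument coerced through the
                // UDAF's own signature / `coerce_types` (Int32 -> Int64 for `sum`), matching how
                // plain (non-window) aggregate UDF calls are coerced.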
+ } _ => args, }; @@ -985,13 +992,10 @@ mod test { None, None, )); - let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, "") - .err() - .unwrap(); - assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.", - err.strip_backtrace() + + let err = Projection::try_new(vec![udaf], empty).err().unwrap(); + assert!( + err.strip_backtrace().starts_with("Error during planning: Error during planning: Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed") ); Ok(()) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c87654292a01..024cb7440388 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3804,7 +3804,10 @@ mod tests { unimplemented!("not needed for testing") } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { unimplemented!("not needed for testing") } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 27449c8dd5c4..06d0dee27099 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -259,7 +259,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { } Expr::AggregateFunction(AggregateFunction { func_def: AggregateFunctionDefinition::UDF(udf), - args, + mut args, distinct, .. }) => { @@ -267,7 +267,6 @@ impl OptimizerRule for SingleDistinctToGroupBy { if args.len() != 1 { return internal_err!("DISTINCT aggregate should have exactly one argument"); } - let mut args = args; let arg = args.swap_remove(0); if group_fields_set.insert(arg.display_name()?) 
{ @@ -298,7 +297,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .alias(&alias_str), ); Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - udf.clone(), + udf, vec![col(&alias_str)], false, None, diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 78c7d40b87f5..2273418c6096 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -19,6 +19,14 @@ pub mod groups_accumulator; pub mod stats; pub mod utils; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::type_coercion::aggregates::check_arg_count; +use datafusion_expr::ReversedUDAF; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, +}; use std::fmt::Debug; use std::{any::Any, sync::Arc}; @@ -27,14 +35,8 @@ use crate::physical_expr::PhysicalExpr; use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; use crate::utils::reverse_order_bys; -use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::{exec_err, not_impl_err, Result}; -use datafusion_expr::function::StateFieldsArgs; -use datafusion_expr::type_coercion::aggregates::check_arg_count; +use datafusion_common::exec_err; use datafusion_expr::utils::AggregateOrderSensitivity; -use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, -}; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. @@ -50,6 +52,7 @@ pub fn create_aggregate_expr( is_distinct: bool, ) -> Result> { debug_assert_eq!(sort_exprs.len(), ordering_req.len()); + let input_exprs_types = input_phy_exprs .iter() .map(|arg| arg.data_type(schema)) @@ -222,7 +225,7 @@ pub struct AggregatePhysicalExpressions { } /// Physical aggregate expression of a UDAF. 
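As a usage sketch (not part of this patch), such an expression can be built through `create_aggregate_expr` above. The schema, column name, and the `datafusion_physical_expr_common` import path are assumptions; the call mirrors the proto roundtrip test changed later in this patch:

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::physical_plan::expressions::{cast, col};
use datafusion_common::Result;
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;

fn physical_sum_of_a() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    // `SUM` over integers is evaluated as Int64 (see `Sum::coerce_types`), made
    // explicit here with a cast, as in the roundtrip test.
    let args = vec![cast(col("a", &schema)?, &schema, DataType::Int64)?];
    // (udaf, args, sort_exprs, ordering_req, schema, name, ignore_nulls, is_distinct)
    let sum_expr = create_aggregate_expr(
        &sum_udaf(),
        &args,
        &[],
        &[],
        &schema,
        "SUM(a)",
        false,
        false,
    )?;
    let _ = sum_expr;
    Ok(())
}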
-#[derive(Debug)] +#[derive(Debug, Clone)] pub struct AggregateFunctionExpr { fun: AggregateUDF, args: Vec>, @@ -234,7 +237,9 @@ pub struct AggregateFunctionExpr { sort_exprs: Vec, // The physical order by expressions ordering_req: LexOrdering, + // Whether to ignore null values ignore_nulls: bool, + // fields used for order sensitive aggregation functions ordering_fields: Vec, is_distinct: bool, input_type: DataType, @@ -294,7 +299,18 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.create_accumulator()?; + let args = AccumulatorArgs { + data_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + sort_exprs: &self.sort_exprs, + is_distinct: self.is_distinct, + input_type: &self.input_type, + args_num: self.args.len(), + name: &self.name, + }; + + let accumulator = self.fun.create_sliding_accumulator(args)?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to @@ -367,11 +383,29 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_groups_accumulator(&self) -> Result> { - self.fun.create_groups_accumulator() + let args = AccumulatorArgs { + data_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + sort_exprs: &self.sort_exprs, + is_distinct: self.is_distinct, + input_type: &self.input_type, + args_num: self.args.len(), + name: &self.name, + }; + self.fun.create_groups_accumulator(args) } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + if self.ordering_req.is_empty() { + return None; + } + + if !self.order_sensitivity().is_insensitive() { + return Some(&self.ordering_req); + } + + None } fn order_sensitivity(&self) -> AggregateOrderSensitivity { @@ -409,37 +443,41 @@ impl AggregateExpr for AggregateFunctionExpr { } fn reverse_expr(&self) -> Option> { - if let Some(reverse_udf) = self.fun.reverse_udf() { - let reverse_ordering_req = reverse_order_bys(&self.ordering_req); - let reverse_sort_exprs = self - .sort_exprs - .iter() - .map(|e| { - if let Expr::Sort(s) = e { - Expr::Sort(s.reverse()) - } else { - // Expects to receive `Expr::Sort`. - unreachable!() - } - }) - .collect::>(); - let mut name = self.name().to_string(); - replace_order_by_clause(&mut name); - replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); - let reverse_aggr = create_aggregate_expr( - &reverse_udf, - &self.args, - &reverse_sort_exprs, - &reverse_ordering_req, - &self.schema, - name, - self.ignore_nulls, - self.is_distinct, - ) - .unwrap(); - return Some(reverse_aggr); + match self.fun.reverse_udf() { + ReversedUDAF::NotSupported => None, + ReversedUDAF::Identical => Some(Arc::new(self.clone())), + ReversedUDAF::Reversed(reverse_udf) => { + let reverse_ordering_req = reverse_order_bys(&self.ordering_req); + let reverse_sort_exprs = self + .sort_exprs + .iter() + .map(|e| { + if let Expr::Sort(s) = e { + Expr::Sort(s.reverse()) + } else { + // Expects to receive `Expr::Sort`. 
+ unreachable!() + } + }) + .collect::>(); + let mut name = self.name().to_string(); + replace_order_by_clause(&mut name); + replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); + let reverse_aggr = create_aggregate_expr( + &reverse_udf, + &self.args, + &reverse_sort_exprs, + &reverse_ordering_req, + &self.schema, + name, + self.ignore_nulls, + self.is_distinct, + ) + .unwrap(); + + Some(reverse_aggr) + } } - None } } diff --git a/datafusion/physical-expr-common/src/aggregate/utils.rs b/datafusion/physical-expr-common/src/aggregate/utils.rs index c59c29a139d8..bcd0d05be054 100644 --- a/datafusion/physical-expr-common/src/aggregate/utils.rs +++ b/datafusion/physical-expr-common/src/aggregate/utils.rs @@ -17,9 +17,10 @@ use std::{any::Any, sync::Arc}; +use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::ArrowNativeType; use arrow::{ - array::{ArrayRef, ArrowNativeTypeOp, AsArray}, + array::ArrowNativeTypeOp, compute::SortOptions, datatypes::{ DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType, diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index e10008995463..813a394d6943 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -28,14 +28,15 @@ use std::sync::Arc; +use arrow::datatypes::Schema; + +use datafusion_common::{exec_err, internal_err, not_impl_err, Result}; +use datafusion_expr::AggregateFunction; + use crate::aggregate::regr::RegrType; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; -use arrow::datatypes::Schema; -use datafusion_common::{exec_err, not_impl_err, Result}; -use datafusion_expr::AggregateFunction; - /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. 
pub fn create_aggregate_expr( @@ -103,16 +104,9 @@ pub fn create_aggregate_expr( name, data_type, )), - (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - )), - (AggregateFunction::Sum, true) => Arc::new(expressions::DistinctSum::new( - vec![input_phy_exprs[0].clone()], - name, - data_type, - )), + (AggregateFunction::Sum, _) => { + return internal_err!("Builtin Sum will be removed"); + } (AggregateFunction::ApproxDistinct, _) => Arc::new( expressions::ApproxDistinct::new(input_phy_exprs[0].clone(), name, data_type), ), @@ -378,7 +372,7 @@ mod tests { use crate::expressions::{ try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, - Max, Min, Stddev, Sum, Variance, + Max, Min, Stddev, Variance, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; @@ -689,7 +683,7 @@ mod tests { #[test] fn test_sum_avg_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Sum, AggregateFunction::Avg]; + let funcs = vec![AggregateFunction::Avg]; let data_types = vec![ DataType::UInt32, DataType::UInt64, @@ -712,37 +706,13 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Sum => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - let expect_type = match data_type { - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => DataType::UInt64, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => DataType::Int64, - DataType::Float32 | DataType::Float64 => DataType::Float64, - _ => data_type.clone(), - }; - - assert_eq!( - Field::new("c1", expect_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::Avg => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} + if fun == AggregateFunction::Avg { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ); }; } } @@ -976,44 +946,6 @@ mod tests { Ok(()) } - #[test] - fn test_sum_return_type() -> Result<()> { - let observed = AggregateFunction::Sum.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Int64, observed); - - let observed = AggregateFunction::Sum.return_type(&[DataType::UInt8])?; - assert_eq!(DataType::UInt64, observed); - - let observed = AggregateFunction::Sum.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Sum.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = - AggregateFunction::Sum.return_type(&[DataType::Decimal128(10, 5)])?; - assert_eq!(DataType::Decimal128(20, 5), observed); - - let observed = - AggregateFunction::Sum.return_type(&[DataType::Decimal128(35, 5)])?; - assert_eq!(DataType::Decimal128(38, 5), observed); - - Ok(()) - } - - #[test] - fn test_sum_no_utf8() { - let observed = AggregateFunction::Sum.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } - - #[test] - fn test_sum_upcasts() -> Result<()> { - let observed = AggregateFunction::Sum.return_type(&[DataType::UInt32])?; - assert_eq!(DataType::UInt64, observed); - Ok(()) - } - #[test] 
fn test_count_return_type() -> Result<()> { let observed = AggregateFunction::Count.return_type(&[DataType::Utf8])?; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 42c630741cc9..9b392d941ef4 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -31,10 +31,11 @@ use crate::{ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{exec_err, Column, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Expr; use datafusion_expr::{ - BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, - WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, SortExpr, WindowFrame, + WindowFunctionDefinition, WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -70,12 +71,17 @@ pub fn schema_add_window_field( .iter() .map(|f| f.as_ref().clone()) .collect_vec(); - window_fields.extend_from_slice(&[Field::new( - fn_name, - window_expr_return_type, - false, - )]); - Ok(Arc::new(Schema::new(window_fields))) + // Skip extending schema for UDAF + if let WindowFunctionDefinition::AggregateUDF(_) = window_fn { + Ok(Arc::new(Schema::new(window_fields))) + } else { + window_fields.extend_from_slice(&[Field::new( + fn_name, + window_expr_return_type, + false, + )]); + Ok(Arc::new(Schema::new(window_fields))) + } } /// Create a physical expression for window function @@ -118,14 +124,28 @@ pub fn create_window_expr( } WindowFunctionDefinition::AggregateUDF(fun) => { // TODO: Ordering not supported for Window UDFs yet - let sort_exprs = &[]; - let ordering_req = &[]; + // Convert `Vec` into `Vec` + let sort_exprs = order_by + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let field_name = expr.to_string(); + let field_name = field_name.split('@').next().unwrap_or(&field_name); + Expr::Sort(SortExpr { + expr: Box::new(Expr::Column(Column::new( + None::, + field_name, + ))), + asc: !options.descending, + nulls_first: options.nulls_first, + }) + }) + .collect::>(); let aggregate = udaf::create_aggregate_expr( fun.as_ref(), args, - sort_exprs, - ordering_req, + &sort_exprs, + order_by, input_schema, name, ignore_nulls, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c065948d3b17..0408ea91b9fa 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -871,7 +871,7 @@ message PhysicalWindowExprNode { oneof window_function { AggregateFunction aggr_function = 1; BuiltInWindowFunction built_in_function = 2; - // udaf = 3 + string user_defined_aggr_function = 3; } repeated PhysicalExprNode args = 4; repeated PhysicalExprNode partition_by = 5; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7e7a14a5d14d..e07fbba27d3c 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -15965,6 +15965,9 @@ impl serde::Serialize for PhysicalWindowExprNode { .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("builtInFunction", &v)?; } + physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(v) => { + struct_ser.serialize_field("userDefinedAggrFunction", v)?; + } } } struct_ser.end() @@ -15989,6 +15992,8 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalWindowExprNode { "aggrFunction", "built_in_function", "builtInFunction", + "user_defined_aggr_function", + "userDefinedAggrFunction", ]; #[allow(clippy::enum_variant_names)] @@ -16000,6 +16005,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { Name, AggrFunction, BuiltInFunction, + UserDefinedAggrFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16028,6 +16034,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "name" => Ok(GeneratedField::Name), "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), + "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16097,6 +16104,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { } window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::BuiltInFunction(x as i32)); } + GeneratedField::UserDefinedAggrFunction => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); + } + window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_window_expr_node::WindowFunction::UserDefinedAggrFunction); + } } } Ok(PhysicalWindowExprNode { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index f9138da3ab34..c75cb3615832 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1345,7 +1345,7 @@ pub struct PhysicalWindowExprNode { pub window_frame: ::core::option::Option, #[prost(string, tag = "8")] pub name: ::prost::alloc::string::String, - #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2")] + #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2, 3")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, >, @@ -1357,9 +1357,10 @@ pub mod physical_window_expr_node { pub enum WindowFunction { #[prost(enumeration = "super::AggregateFunction", tag = "1")] AggrFunction(i32), - /// udaf = 3 #[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")] BuiltInFunction(i32), + #[prost(string, tag = "3")] + UserDefinedAggrFunction(::prost::alloc::string::String), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index cf935e6b8304..0a91df568a1d 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -145,8 +145,37 @@ pub fn parse_physical_window_expr( ) })?; - let fun: WindowFunctionDefinition = convert_required!(proto.window_function)?; + let fun = if let Some(window_func) = proto.window_function.as_ref() { + match window_func { + protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => { + let f = protobuf::AggregateFunction::try_from(*n).map_err(|_| { + proto_error(format!( + "Received an unknown window aggregate function: {n}" + )) + })?; + + WindowFunctionDefinition::AggregateFunction(f.into()) + } + protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { + let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { + proto_error(format!( + "Received an unknown window builtin function: {n}" + )) + })?; + 
+ WindowFunctionDefinition::BuiltInWindowFunction(f.into()) + } + protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name) => { + let agg_udf = registry.udaf(udaf_name)?; + WindowFunctionDefinition::AggregateUDF(agg_udf) + } + } + } else { + return Err(proto_error("Missing required field in protobuf")); + }; + let name = proto.name.clone(); + // TODO: Remove extended_schema if functions are all UDAF let extended_schema = schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?; create_window_expr( @@ -383,37 +412,6 @@ fn parse_required_physical_expr( }) } -impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> - for WindowFunctionDefinition -{ - type Error = DataFusionError; - - fn try_from( - expr: &protobuf::physical_window_expr_node::WindowFunction, - ) -> Result { - match expr { - protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => { - let f = protobuf::AggregateFunction::try_from(*n).map_err(|_| { - proto_error(format!( - "Received an unknown window aggregate function: {n}" - )) - })?; - - Ok(WindowFunctionDefinition::AggregateFunction(f.into())) - } - protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { - let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { - proto_error(format!( - "Received an unknown window builtin function: {n}" - )) - })?; - - Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into())) - } - } - } -} - pub fn parse_protobuf_hash_partitioning( partitioning: Option<&protobuf::PhysicalHashRepartition>, registry: &dyn FunctionRegistry, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 3135d0959331..071463614165 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -186,21 +186,29 @@ pub fn serialize_physical_window_expr( } else if let Some(sliding_aggr_window_expr) = expr.downcast_ref::() { - let AggrFn { inner, distinct } = - aggr_expr_to_aggr_fn(sliding_aggr_window_expr.get_aggregate_expr().as_ref())?; + let aggr_expr = sliding_aggr_window_expr.get_aggregate_expr(); + if let Some(a) = aggr_expr.as_any().downcast_ref::() { + physical_window_expr_node::WindowFunction::UserDefinedAggrFunction( + a.fun().name().to_string(), + ) + } else { + let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn( + sliding_aggr_window_expr.get_aggregate_expr().as_ref(), + )?; + + if distinct { + // TODO + return not_impl_err!( + "Distinct aggregate functions not supported in window expressions" + ); + } - if distinct { - // TODO - return not_impl_err!( - "Distinct aggregate functions not supported in window expressions" - ); - } + if window_frame.start_bound.is_unbounded() { + return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + } - if window_frame.start_bound.is_unbounded() { - return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) } - - physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b756d4688dc0..14d72274806d 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs 
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -31,7 +31,8 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; -use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; +use datafusion::functions_aggregate::expr_fn::{covar_pop, covar_samp, first_value}; +use datafusion::functions_aggregate::median::median; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -648,6 +649,7 @@ async fn roundtrip_expr_api() -> Result<()> { first_value(vec![lit(1)], false, None, None, None), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), + sum(lit(1)), median(lit(2)), ]; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index df1995f46533..9cf686dbd3d6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use std::vec; use arrow::csv::WriterBuilder; +use datafusion::functions_aggregate::sum::sum_udaf; use prost::Message; use datafusion::arrow::array::ArrayRef; @@ -47,7 +48,7 @@ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount, - NotExpr, NthValue, PhysicalSortExpr, StringAgg, Sum, + NotExpr, NthValue, PhysicalSortExpr, StringAgg, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::DataSinkExec; @@ -296,12 +297,20 @@ fn roundtrip_window() -> Result<()> { WindowFrameBound::Preceding(ScalarValue::Int64(None)), ); + let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?]; + let sum_expr = udaf::create_aggregate_expr( + &sum_udaf(), + &args, + &[], + &[], + &schema, + "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", + false, + false, + )?; + let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( - Arc::new(Sum::new( - cast(col("a", &schema)?, &schema, DataType::Float64)?, - "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", - DataType::Float64, - )), + sum_expr, &[], &[], Arc::new(window_frame), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 1f8492b9ba47..81a9b4b772d0 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -297,22 +297,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, name: &str, ) -> Result { - expr::find_df_window_func(name) - // next check user defined aggregates - .or_else(|| { - self.context_provider - .get_aggregate_meta(name) - .map(WindowFunctionDefinition::AggregateUDF) - }) - // next check user defined window functions - .or_else(|| { - self.context_provider - .get_window_meta(name) - .map(WindowFunctionDefinition::WindowUDF) - }) - .ok_or_else(|| { - plan_datafusion_err!("There is no window function named {name}") - }) + // check udaf first + let udaf = self.context_provider.get_aggregate_meta(name); + // Skip first value and last value, since we expect window builtin first/last value not udaf version + if udaf.as_ref().is_some_and(|udaf| { + udaf.name() != "first_value" && udaf.name() != "last_value" + }) { + 
Ok(WindowFunctionDefinition::AggregateUDF(udaf.unwrap())) + } else { + expr::find_df_window_func(name) + .or_else(|| { + self.context_provider + .get_window_meta(name) + .map(WindowFunctionDefinition::WindowUDF) + }) + .ok_or_else(|| { + plan_datafusion_err!("There is no window function named {name}") + }) + } } fn sql_fn_arg_to_logical_expr( diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index df6a37644838..98e64b025b22 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3559,10 +3559,10 @@ NULL NULL NULL NULL NULL NULL NULL NULL Row 2 Y # aggregate_timestamps_sum -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +query error SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t; -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +query error SELECT tag, sum(nanos), sum(micros), sum(millis), sum(secs) FROM t GROUP BY tag ORDER BY tag; # aggregate_timestamps_count @@ -3670,10 +3670,10 @@ NULL NULL Row 2 Y # aggregate_timestamps_sum -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Date32\)'\. You might need to add explicit type casts\. +query error SELECT sum(date32), sum(date64) FROM t; -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Date32\)'\. You might need to add explicit type casts\. +query error SELECT tag, sum(date32), sum(date64) FROM t GROUP BY tag ORDER BY tag; # aggregate_timestamps_count @@ -3767,10 +3767,10 @@ select * from t; 21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 Row 3 B # aggregate_times_sum -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +query error SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. 
+query error SELECT tag, sum(nanos), sum(micros), sum(millis), sum(secs) FROM t GROUP BY tag ORDER BY tag # aggregate_times_count diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index d7f10537d02a..2678e8cbd1ba 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -1131,4 +1131,4 @@ physical_plan 01)SortPreservingMergeExec: [c@0 ASC NULLS LAST] 02)--ProjectionExec: expr=[CAST(inc_col@0 > desc_col@1 AS Int32) as c] 03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[inc_col, desc_col], output_orderings=[[inc_col@0 ASC NULLS LAST], [desc_col@1 DESC]], has_header=true \ No newline at end of file +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[inc_col, desc_col], output_orderings=[[inc_col@0 ASC NULLS LAST], [desc_col@1 DESC]], has_header=true diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index babb7dc8fd6b..ce738c7a6f3e 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -344,6 +344,7 @@ t1 as ( select 11 a, 13 b) select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) ) order by 1, 2; +---- query II select * from ( diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index bdd7e6631c16..8866cd009c32 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -65,7 +65,7 @@ select * from unnest(struct(1,2,3)); ---- 1 2 3 -## Multiple unnest expression in from clause +## Multiple unnest expression in from clause query IIII select * from unnest(struct(1,2,3)),unnest([4,5,6]); ---- @@ -446,7 +446,7 @@ query error DataFusion error: type_coercion\ncaused by\nThis feature is not impl select sum(unnest(generate_series(1,10))); ## TODO: support unnest as a child expr -query error DataFusion error: Internal error: unnest on struct can ony be applied at the root level of select expression +query error DataFusion error: Internal error: unnest on struct can ony be applied at the root level of select expression select arrow_typeof(unnest(column5)) from unnest_table; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index be1517aa75c1..2d5dd439d76d 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1344,16 +1344,16 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED 
PRECEDING AND CURRENT ROW AS lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 +01)Projection: aggregate_test_100.c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -04)------WindowAggr: windowExpr=[[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +03)----WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)ProjectionExec: expr=[c9@0 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY 
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] +01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC 
NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: 
false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -2634,16 +2634,16 @@ EXPLAIN SELECT logical_plan 01)Limit: skip=0, fetch=5 02)--Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 -03)----Projection: annotated_data_finite.ts, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 
PRECEDING AND 1 FOLLOWING AS lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 -04)------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -05)--------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, 
LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +03)----Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY 
[annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 +04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +05)--------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC 
NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 06)----------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan 01)GlobalLimitExec: skip=0, fetch=5 02)--SortExec: TopK(fetch=5), expr=[ts@0 DESC], preserve_partitioning=[false] -03)----ProjectionExec: expr=[ts@0 as ts, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY 
[annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] -04)------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY 
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] 
RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] +03)----ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 
FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] +04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), 
end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 
PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: 
"first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] 06)----------CsvExec: file_groups={1 
group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIIIIIIIIIIIIIIIIIIIIIII @@ -2761,17 +2761,17 @@ logical_plan 01)Projection: first_value1, first_value2, last_value1, last_value2, nth_value1 02)--Limit: skip=0, fetch=5 03)----Sort: annotated_data_finite.inc_col ASC NULLS LAST, fetch=5 -04)------Projection: FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS first_value1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS first_value2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS last_value1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS nth_value1, annotated_data_finite.inc_col -05)--------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +04)------Projection: first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS first_value1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS first_value2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS last_value1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS nth_value1, annotated_data_finite.inc_col +05)--------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] +06)----------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING 
AND UNBOUNDED FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] 07)------------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan 01)ProjectionExec: expr=[first_value1@0 as first_value1, first_value2@1 as first_value2, last_value1@2 as last_value1, last_value2@3 as last_value2, nth_value1@4 as nth_value1] 02)--GlobalLimitExec: skip=0, fetch=5 03)----SortExec: TopK(fetch=5), expr=[inc_col@5 ASC NULLS LAST], preserve_partitioning=[false] -04)------ProjectionExec: expr=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as first_value1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as first_value2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as last_value1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as nth_value1, inc_col@1 as inc_col] -05)--------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -06)----------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: 
Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +04)------ProjectionExec: expr=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as first_value1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as first_value2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as last_value1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as nth_value1, inc_col@1 as inc_col] +05)--------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: 
WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] 07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIII From 6e5344ae367001dbf70fa2882c2e89eca4a2bbd8 Mon Sep 17 00:00:00 2001 From: Xin Li <33629085+xinlifoobar@users.noreply.github.com> Date: Mon, 3 Jun 2024 23:25:16 +0800 Subject: [PATCH 25/35] Extract parquet statistics from timestamps with timezones (#10766) * Fix incorrect statistics read for timestamp columns in parquet --- .../physical_plan/parquet/statistics.rs | 285 ++++++++++++- .../core/tests/parquet/arrow_statistics.rs | 383 ++++++++++++++---- datafusion/core/tests/parquet/mod.rs | 46 ++- 3 files changed, 615 insertions(+), 99 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index ae8395aef6a4..1c20fa7caa14 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -19,7 +19,7 @@ // TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 -use arrow::{array::ArrayRef, datatypes::DataType}; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::TimeUnit}; use arrow_array::{new_empty_array, new_null_array, UInt64Array}; use arrow_schema::{Field, FieldRef, Schema}; use datafusion_common::{ @@ -112,6 +112,26 @@ macro_rules! get_statistic { Some(DataType::UInt64) => { Some(ScalarValue::UInt64(Some((*s.$func()) as u64))) } + Some(DataType::Timestamp(unit, timezone)) => { + Some(match unit { + TimeUnit::Second => ScalarValue::TimestampSecond( + Some(*s.$func()), + timezone.clone(), + ), + TimeUnit::Millisecond => ScalarValue::TimestampMillisecond( + Some(*s.$func()), + timezone.clone(), + ), + TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond( + Some(*s.$func()), + timezone.clone(), + ), + TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond( + Some(*s.$func()), + timezone.clone(), + ), + }) + } _ => Some(ScalarValue::Int64(Some(*s.$func()))), } } @@ -395,7 +415,8 @@ mod test { use arrow_array::{ new_null_array, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, RecordBatch, StringArray, StructArray, TimestampNanosecondArray, + Int8Array, RecordBatch, StringArray, StructArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }; use arrow_schema::{Field, SchemaRef}; use bytes::Bytes; @@ -536,28 +557,209 @@ mod test { } #[test] - #[should_panic( - expected = "Inconsistent types in ScalarValue::iter_to_array. 
Expected Int64, got TimestampNanosecond(NULL, None)" - )] - // Due to https://github.com/apache/datafusion/issues/8295 fn roundtrip_timestamp() { Test { - input: timestamp_array([ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), + input: timestamp_seconds_array( + [ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ], None, - // row group 3 + ), + expected_min: timestamp_seconds_array([Some(1), Some(5), None], None), + expected_max: timestamp_seconds_array([Some(3), Some(9), None], None), + } + .run(); + + Test { + input: timestamp_milliseconds_array( + [ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ], None, + ), + expected_min: timestamp_milliseconds_array([Some(1), Some(5), None], None), + expected_max: timestamp_milliseconds_array([Some(3), Some(9), None], None), + } + .run(); + + Test { + input: timestamp_microseconds_array( + [ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ], None, + ), + expected_min: timestamp_microseconds_array([Some(1), Some(5), None], None), + expected_max: timestamp_microseconds_array([Some(3), Some(9), None], None), + } + .run(); + + Test { + input: timestamp_nanoseconds_array( + [ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ], None, - ]), - expected_min: timestamp_array([Some(1), Some(5), None]), - expected_max: timestamp_array([Some(3), Some(9), None]), + ), + expected_min: timestamp_nanoseconds_array([Some(1), Some(5), None], None), + expected_max: timestamp_nanoseconds_array([Some(3), Some(9), None], None), + } + .run() + } + + #[test] + fn roundtrip_timestamp_timezoned() { + Test { + input: timestamp_seconds_array( + [ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ], + Some("UTC"), + ), + expected_min: timestamp_seconds_array([Some(1), Some(5), None], Some("UTC")), + expected_max: timestamp_seconds_array([Some(3), Some(9), None], Some("UTC")), + } + .run(); + + Test { + input: timestamp_milliseconds_array( + [ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ], + Some("UTC"), + ), + expected_min: timestamp_milliseconds_array( + [Some(1), Some(5), None], + Some("UTC"), + ), + expected_max: timestamp_milliseconds_array( + [Some(3), Some(9), None], + Some("UTC"), + ), + } + .run(); + + Test { + input: timestamp_microseconds_array( + [ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ], + Some("UTC"), + ), + expected_min: timestamp_microseconds_array( + [Some(1), Some(5), None], + Some("UTC"), + ), + expected_max: timestamp_microseconds_array( + [Some(3), Some(9), None], + Some("UTC"), + ), + } + .run(); + + Test { + input: timestamp_nanoseconds_array( + [ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ], + Some("UTC"), + ), + expected_min: timestamp_nanoseconds_array( + [Some(1), Some(5), None], + Some("UTC"), + ), + expected_max: timestamp_nanoseconds_array( + [Some(3), Some(9), None], + Some("UTC"), + ), } .run() } 
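// Aside (illustrative sketch, not part of the patch itself): with the new
// `get_statistic!` timestamp arms above, a parquet i64 min/max statistic for a
// column whose Arrow type is Timestamp(Nanosecond, Some("UTC")) is surfaced as
// a timezone-aware scalar instead of falling back to `ScalarValue::Int64`.
// Roughly (names below are hypothetical, for illustration only):
//
//     // raw statistic value as read from the row group metadata
//     let raw_min: i64 = 1;
//     let min = ScalarValue::TimestampNanosecond(Some(raw_min), Some("UTC".into()));
//     assert_eq!(
//         min.data_type(),
//         DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
//     );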
@@ -914,8 +1116,8 @@ mod test { // File has no min/max for timestamp_col .with_column(ExpectedColumn { name: "timestamp_col", - expected_min: timestamp_array([None]), - expected_max: timestamp_array([None]), + expected_min: timestamp_nanoseconds_array([None], None), + expected_max: timestamp_nanoseconds_array([None], None), }) .with_column(ExpectedColumn { name: "year", @@ -1135,9 +1337,48 @@ mod test { Arc::new(array) } - fn timestamp_array(input: impl IntoIterator>) -> ArrayRef { + fn timestamp_seconds_array( + input: impl IntoIterator>, + timzezone: Option<&str>, + ) -> ArrayRef { + let array: TimestampSecondArray = input.into_iter().collect(); + match timzezone { + Some(tz) => Arc::new(array.with_timezone(tz)), + None => Arc::new(array), + } + } + + fn timestamp_milliseconds_array( + input: impl IntoIterator>, + timzezone: Option<&str>, + ) -> ArrayRef { + let array: TimestampMillisecondArray = input.into_iter().collect(); + match timzezone { + Some(tz) => Arc::new(array.with_timezone(tz)), + None => Arc::new(array), + } + } + + fn timestamp_microseconds_array( + input: impl IntoIterator>, + timzezone: Option<&str>, + ) -> ArrayRef { + let array: TimestampMicrosecondArray = input.into_iter().collect(); + match timzezone { + Some(tz) => Arc::new(array.with_timezone(tz)), + None => Arc::new(array), + } + } + + fn timestamp_nanoseconds_array( + input: impl IntoIterator>, + timzezone: Option<&str>, + ) -> ArrayRef { let array: TimestampNanosecondArray = input.into_iter().collect(); - Arc::new(array) + match timzezone { + Some(tz) => Arc::new(array.with_timezone(tz)), + None => Arc::new(array), + } } fn utf8_array<'a>(input: impl IntoIterator>) -> ArrayRef { diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index eebf3447cbe9..2836cd2893f3 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -22,12 +22,16 @@ use std::fs::File; use std::sync::Arc; use arrow::compute::kernels::cast_utils::Parser; -use arrow::datatypes::{Date32Type, Date64Type}; +use arrow::datatypes::{ + Date32Type, Date64Type, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, +}; use arrow_array::{ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, RecordBatch, StringArray, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, + Int32Array, Int64Array, Int8Array, RecordBatch, StringArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow_schema::{DataType, Field, Schema}; use datafusion::datasource::physical_plan::parquet::{ @@ -456,36 +460,40 @@ async fn test_int_8() { // timestamp #[tokio::test] async fn test_timestamp() { - // This creates a parquet files of 5 columns named "nanos", "micros", "millis", "seconds", "names" + // This creates a parquet files of 9 columns named "nanos", "nanos_timezoned", "micros", "micros_timezoned", "millis", "millis_timezoned", "seconds", "seconds_timezoned", "names" // "nanos" --> TimestampNanosecondArray + // "nanos_timezoned" --> TimestampNanosecondArray // "micros" --> TimestampMicrosecondArray + // "micros_timezoned" --> TimestampMicrosecondArray // "millis" --> TimestampMillisecondArray + // "millis_timezoned" --> TimestampMillisecondArray // 
"seconds" --> TimestampSecondArray + // "seconds_timezoned" --> TimestampSecondArray // "names" --> StringArray // // The file is created by 4 record batches, each has 5 rowws. // Since the row group isze is set to 5, those 4 batches will go into 4 row groups - // This creates a parquet files of 4 columns named "i8", "i16", "i32", "i64" + // This creates a parquet files of 4 columns named "nanos", "nanos_timezoned", "micros", "micros_timezoned", "millis", "millis_timezoned", "seconds", "seconds_timezoned" let reader = TestReader { scenario: Scenario::Timestamps, row_per_group: 5, }; + let tz = "Pacific/Efate"; + Test { reader: reader.build().await, - // mins are [1577840461000000000, 1577840471000000000, 1577841061000000000, 1578704461000000000,] - expected_min: Arc::new(Int64Array::from(vec![ - 1577840461000000000, - 1577840471000000000, - 1577841061000000000, - 1578704461000000000, + expected_min: Arc::new(TimestampNanosecondArray::from(vec![ + TimestampNanosecondType::parse("2020-01-01T01:01:01"), + TimestampNanosecondType::parse("2020-01-01T01:01:11"), + TimestampNanosecondType::parse("2020-01-01T01:11:01"), + TimestampNanosecondType::parse("2020-01-11T01:01:01"), ])), - // maxes are [1577926861000000000, 1577926871000000000, 1577927461000000000, 1578790861000000000,] - expected_max: Arc::new(Int64Array::from(vec![ - 1577926861000000000, - 1577926871000000000, - 1577927461000000000, - 1578790861000000000, + expected_max: Arc::new(TimestampNanosecondArray::from(vec![ + TimestampNanosecondType::parse("2020-01-02T01:01:01"), + TimestampNanosecondType::parse("2020-01-02T01:01:11"), + TimestampNanosecondType::parse("2020-01-02T01:11:01"), + TimestampNanosecondType::parse("2020-01-12T01:01:01"), ])), // nulls are [1, 1, 1, 1] expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), @@ -495,21 +503,48 @@ async fn test_timestamp() { } .run(); - // micros + Test { + reader: reader.build().await, + expected_min: Arc::new( + TimestampNanosecondArray::from(vec![ + TimestampNanosecondType::parse("2020-01-01T01:01:01"), + TimestampNanosecondType::parse("2020-01-01T01:01:11"), + TimestampNanosecondType::parse("2020-01-01T01:11:01"), + TimestampNanosecondType::parse("2020-01-11T01:01:01"), + ]) + .with_timezone(tz), + ), + expected_max: Arc::new( + TimestampNanosecondArray::from(vec![ + TimestampNanosecondType::parse("2020-01-02T01:01:01"), + TimestampNanosecondType::parse("2020-01-02T01:01:11"), + TimestampNanosecondType::parse("2020-01-02T01:11:01"), + TimestampNanosecondType::parse("2020-01-12T01:01:01"), + ]) + .with_timezone(tz), + ), + // nulls are [1, 1, 1, 1] + expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), + // row counts are [5, 5, 5, 5] + expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), + column_name: "nanos_timezoned", + } + .run(); + // micros Test { reader: reader.build().await, - expected_min: Arc::new(Int64Array::from(vec![ - 1577840461000000, - 1577840471000000, - 1577841061000000, - 1578704461000000, + expected_min: Arc::new(TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2020-01-01T01:01:01"), + TimestampMicrosecondType::parse("2020-01-01T01:01:11"), + TimestampMicrosecondType::parse("2020-01-01T01:11:01"), + TimestampMicrosecondType::parse("2020-01-11T01:01:01"), ])), - expected_max: Arc::new(Int64Array::from(vec![ - 1577926861000000, - 1577926871000000, - 1577927461000000, - 1578790861000000, + expected_max: Arc::new(TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2020-01-02T01:01:01"), + 
TimestampMicrosecondType::parse("2020-01-02T01:01:11"), + TimestampMicrosecondType::parse("2020-01-02T01:11:01"), + TimestampMicrosecondType::parse("2020-01-12T01:01:01"), ])), expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), @@ -517,20 +552,48 @@ async fn test_timestamp() { } .run(); + Test { + reader: reader.build().await, + expected_min: Arc::new( + TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2020-01-01T01:01:01"), + TimestampMicrosecondType::parse("2020-01-01T01:01:11"), + TimestampMicrosecondType::parse("2020-01-01T01:11:01"), + TimestampMicrosecondType::parse("2020-01-11T01:01:01"), + ]) + .with_timezone(tz), + ), + expected_max: Arc::new( + TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2020-01-02T01:01:01"), + TimestampMicrosecondType::parse("2020-01-02T01:01:11"), + TimestampMicrosecondType::parse("2020-01-02T01:11:01"), + TimestampMicrosecondType::parse("2020-01-12T01:01:01"), + ]) + .with_timezone(tz), + ), + // nulls are [1, 1, 1, 1] + expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), + // row counts are [5, 5, 5, 5] + expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), + column_name: "micros_timezoned", + } + .run(); + // millis Test { reader: reader.build().await, - expected_min: Arc::new(Int64Array::from(vec![ - 1577840461000, - 1577840471000, - 1577841061000, - 1578704461000, + expected_min: Arc::new(TimestampMillisecondArray::from(vec![ + TimestampMillisecondType::parse("2020-01-01T01:01:01"), + TimestampMillisecondType::parse("2020-01-01T01:01:11"), + TimestampMillisecondType::parse("2020-01-01T01:11:01"), + TimestampMillisecondType::parse("2020-01-11T01:01:01"), ])), - expected_max: Arc::new(Int64Array::from(vec![ - 1577926861000, - 1577926871000, - 1577927461000, - 1578790861000, + expected_max: Arc::new(TimestampMillisecondArray::from(vec![ + TimestampMillisecondType::parse("2020-01-02T01:01:01"), + TimestampMillisecondType::parse("2020-01-02T01:01:11"), + TimestampMillisecondType::parse("2020-01-02T01:11:01"), + TimestampMillisecondType::parse("2020-01-12T01:01:01"), ])), expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), @@ -538,30 +601,96 @@ async fn test_timestamp() { } .run(); + Test { + reader: reader.build().await, + expected_min: Arc::new( + TimestampMillisecondArray::from(vec![ + TimestampMillisecondType::parse("2020-01-01T01:01:01"), + TimestampMillisecondType::parse("2020-01-01T01:01:11"), + TimestampMillisecondType::parse("2020-01-01T01:11:01"), + TimestampMillisecondType::parse("2020-01-11T01:01:01"), + ]) + .with_timezone(tz), + ), + expected_max: Arc::new( + TimestampMillisecondArray::from(vec![ + TimestampMillisecondType::parse("2020-01-02T01:01:01"), + TimestampMillisecondType::parse("2020-01-02T01:01:11"), + TimestampMillisecondType::parse("2020-01-02T01:11:01"), + TimestampMillisecondType::parse("2020-01-12T01:01:01"), + ]) + .with_timezone(tz), + ), + // nulls are [1, 1, 1, 1] + expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), + // row counts are [5, 5, 5, 5] + expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), + column_name: "millis_timezoned", + } + .run(); + // seconds Test { reader: reader.build().await, - expected_min: Arc::new(Int64Array::from(vec![ - 1577840461, 1577840471, 1577841061, 1578704461, + expected_min: Arc::new(TimestampSecondArray::from(vec![ + TimestampSecondType::parse("2020-01-01T01:01:01"), + 
TimestampSecondType::parse("2020-01-01T01:01:11"), + TimestampSecondType::parse("2020-01-01T01:11:01"), + TimestampSecondType::parse("2020-01-11T01:01:01"), ])), - expected_max: Arc::new(Int64Array::from(vec![ - 1577926861, 1577926871, 1577927461, 1578790861, + expected_max: Arc::new(TimestampSecondArray::from(vec![ + TimestampSecondType::parse("2020-01-02T01:01:01"), + TimestampSecondType::parse("2020-01-02T01:01:11"), + TimestampSecondType::parse("2020-01-02T01:11:01"), + TimestampSecondType::parse("2020-01-12T01:01:01"), ])), expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "seconds", } .run(); + + Test { + reader: reader.build().await, + expected_min: Arc::new( + TimestampSecondArray::from(vec![ + TimestampSecondType::parse("2020-01-01T01:01:01"), + TimestampSecondType::parse("2020-01-01T01:01:11"), + TimestampSecondType::parse("2020-01-01T01:11:01"), + TimestampSecondType::parse("2020-01-11T01:01:01"), + ]) + .with_timezone(tz), + ), + expected_max: Arc::new( + TimestampSecondArray::from(vec![ + TimestampSecondType::parse("2020-01-02T01:01:01"), + TimestampSecondType::parse("2020-01-02T01:01:11"), + TimestampSecondType::parse("2020-01-02T01:11:01"), + TimestampSecondType::parse("2020-01-12T01:01:01"), + ]) + .with_timezone(tz), + ), + // nulls are [1, 1, 1, 1] + expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), + // row counts are [5, 5, 5, 5] + expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), + column_name: "seconds_timezoned", + } + .run(); } // timestamp with different row group sizes #[tokio::test] async fn test_timestamp_diff_rg_sizes() { - // This creates a parquet files of 5 columns named "nanos", "micros", "millis", "seconds", "names" + // This creates a parquet files of 9 columns named "nanos", "nanos_timezoned", "micros", "micros_timezoned", "millis", "millis_timezoned", "seconds", "seconds_timezoned", "names" // "nanos" --> TimestampNanosecondArray + // "nanos_timezoned" --> TimestampNanosecondArray // "micros" --> TimestampMicrosecondArray + // "micros_timezoned" --> TimestampMicrosecondArray // "millis" --> TimestampMillisecondArray + // "millis_timezoned" --> TimestampMillisecondArray // "seconds" --> TimestampSecondArray + // "seconds_timezoned" --> TimestampSecondArray // "names" --> StringArray // // The file is created by 4 record batches (each has a null row), each has 5 rows but then will be split into 3 row groups with size 8, 8, 4 @@ -570,19 +699,19 @@ async fn test_timestamp_diff_rg_sizes() { row_per_group: 8, // note that the row group size is 8 }; + let tz = "Pacific/Efate"; + Test { reader: reader.build().await, - // mins are [1577840461000000000, 1577841061000000000, 1578704521000000000] - expected_min: Arc::new(Int64Array::from(vec![ - 1577840461000000000, - 1577841061000000000, - 1578704521000000000, + expected_min: Arc::new(TimestampNanosecondArray::from(vec![ + TimestampNanosecondType::parse("2020-01-01T01:01:01"), + TimestampNanosecondType::parse("2020-01-01T01:11:01"), + TimestampNanosecondType::parse("2020-01-11T01:02:01"), ])), - // maxes are [1577926861000000000, 1578704461000000000, 157879086100000000] - expected_max: Arc::new(Int64Array::from(vec![ - 1577926861000000000, - 1578704461000000000, - 1578790861000000000, + expected_max: Arc::new(TimestampNanosecondArray::from(vec![ + TimestampNanosecondType::parse("2020-01-02T01:01:01"), + TimestampNanosecondType::parse("2020-01-11T01:01:01"), + TimestampNanosecondType::parse("2020-01-12T01:01:01"), ])), 
// nulls are [1, 2, 1] expected_null_counts: UInt64Array::from(vec![1, 2, 1]), @@ -592,18 +721,44 @@ async fn test_timestamp_diff_rg_sizes() { } .run(); + Test { + reader: reader.build().await, + expected_min: Arc::new( + TimestampNanosecondArray::from(vec![ + TimestampNanosecondType::parse("2020-01-01T01:01:01"), + TimestampNanosecondType::parse("2020-01-01T01:11:01"), + TimestampNanosecondType::parse("2020-01-11T01:02:01"), + ]) + .with_timezone(tz), + ), + expected_max: Arc::new( + TimestampNanosecondArray::from(vec![ + TimestampNanosecondType::parse("2020-01-02T01:01:01"), + TimestampNanosecondType::parse("2020-01-11T01:01:01"), + TimestampNanosecondType::parse("2020-01-12T01:01:01"), + ]) + .with_timezone(tz), + ), + // nulls are [1, 2, 1] + expected_null_counts: UInt64Array::from(vec![1, 2, 1]), + // row counts are [8, 8, 4] + expected_row_counts: UInt64Array::from(vec![8, 8, 4]), + column_name: "nanos_timezoned", + } + .run(); + // micros Test { reader: reader.build().await, - expected_min: Arc::new(Int64Array::from(vec![ - 1577840461000000, - 1577841061000000, - 1578704521000000, + expected_min: Arc::new(TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2020-01-01T01:01:01"), + TimestampMicrosecondType::parse("2020-01-01T01:11:01"), + TimestampMicrosecondType::parse("2020-01-11T01:02:01"), ])), - expected_max: Arc::new(Int64Array::from(vec![ - 1577926861000000, - 1578704461000000, - 1578790861000000, + expected_max: Arc::new(TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2020-01-02T01:01:01"), + TimestampMicrosecondType::parse("2020-01-11T01:01:01"), + TimestampMicrosecondType::parse("2020-01-12T01:01:01"), ])), expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), @@ -611,18 +766,44 @@ async fn test_timestamp_diff_rg_sizes() { } .run(); + Test { + reader: reader.build().await, + expected_min: Arc::new( + TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2020-01-01T01:01:01"), + TimestampMicrosecondType::parse("2020-01-01T01:11:01"), + TimestampMicrosecondType::parse("2020-01-11T01:02:01"), + ]) + .with_timezone(tz), + ), + expected_max: Arc::new( + TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2020-01-02T01:01:01"), + TimestampMicrosecondType::parse("2020-01-11T01:01:01"), + TimestampMicrosecondType::parse("2020-01-12T01:01:01"), + ]) + .with_timezone(tz), + ), + // nulls are [1, 2, 1] + expected_null_counts: UInt64Array::from(vec![1, 2, 1]), + // row counts are [8, 8, 4] + expected_row_counts: UInt64Array::from(vec![8, 8, 4]), + column_name: "micros_timezoned", + } + .run(); + // millis Test { reader: reader.build().await, - expected_min: Arc::new(Int64Array::from(vec![ - 1577840461000, - 1577841061000, - 1578704521000, + expected_min: Arc::new(TimestampMillisecondArray::from(vec![ + TimestampMillisecondType::parse("2020-01-01T01:01:01"), + TimestampMillisecondType::parse("2020-01-01T01:11:01"), + TimestampMillisecondType::parse("2020-01-11T01:02:01"), ])), - expected_max: Arc::new(Int64Array::from(vec![ - 1577926861000, - 1578704461000, - 1578790861000, + expected_max: Arc::new(TimestampMillisecondArray::from(vec![ + TimestampMillisecondType::parse("2020-01-02T01:01:01"), + TimestampMillisecondType::parse("2020-01-11T01:01:01"), + TimestampMillisecondType::parse("2020-01-12T01:01:01"), ])), expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), @@ -630,20 +811,76 @@ 
async fn test_timestamp_diff_rg_sizes() { } .run(); + Test { + reader: reader.build().await, + expected_min: Arc::new( + TimestampMillisecondArray::from(vec![ + TimestampMillisecondType::parse("2020-01-01T01:01:01"), + TimestampMillisecondType::parse("2020-01-01T01:11:01"), + TimestampMillisecondType::parse("2020-01-11T01:02:01"), + ]) + .with_timezone(tz), + ), + expected_max: Arc::new( + TimestampMillisecondArray::from(vec![ + TimestampMillisecondType::parse("2020-01-02T01:01:01"), + TimestampMillisecondType::parse("2020-01-11T01:01:01"), + TimestampMillisecondType::parse("2020-01-12T01:01:01"), + ]) + .with_timezone(tz), + ), + // nulls are [1, 2, 1] + expected_null_counts: UInt64Array::from(vec![1, 2, 1]), + // row counts are [8, 8, 4] + expected_row_counts: UInt64Array::from(vec![8, 8, 4]), + column_name: "millis_timezoned", + } + .run(); + // seconds Test { reader: reader.build().await, - expected_min: Arc::new(Int64Array::from(vec![ - 1577840461, 1577841061, 1578704521, + expected_min: Arc::new(TimestampSecondArray::from(vec![ + TimestampSecondType::parse("2020-01-01T01:01:01"), + TimestampSecondType::parse("2020-01-01T01:11:01"), + TimestampSecondType::parse("2020-01-11T01:02:01"), ])), - expected_max: Arc::new(Int64Array::from(vec![ - 1577926861, 1578704461, 1578790861, + expected_max: Arc::new(TimestampSecondArray::from(vec![ + TimestampSecondType::parse("2020-01-02T01:01:01"), + TimestampSecondType::parse("2020-01-11T01:01:01"), + TimestampSecondType::parse("2020-01-12T01:01:01"), ])), expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "seconds", } .run(); + + Test { + reader: reader.build().await, + expected_min: Arc::new( + TimestampSecondArray::from(vec![ + TimestampSecondType::parse("2020-01-01T01:01:01"), + TimestampSecondType::parse("2020-01-01T01:11:01"), + TimestampSecondType::parse("2020-01-11T01:02:01"), + ]) + .with_timezone(tz), + ), + expected_max: Arc::new( + TimestampSecondArray::from(vec![ + TimestampSecondType::parse("2020-01-02T01:01:01"), + TimestampSecondType::parse("2020-01-11T01:01:01"), + TimestampSecondType::parse("2020-01-12T01:01:01"), + ]) + .with_timezone(tz), + ), + // nulls are [1, 2, 1] + expected_null_counts: UInt64Array::from(vec![1, 2, 1]), + // row counts are [8, 8, 4] + expected_row_counts: UInt64Array::from(vec![8, 8, 4]), + column_name: "seconds_timezoned", + } + .run(); } // date with different row group sizes diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 94ae9ff601ec..41a0a86aa8d3 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -332,9 +332,13 @@ fn make_boolean_batch(v: Vec>) -> RecordBatch { /// /// Columns are named: /// "nanos" --> TimestampNanosecondArray +/// "nanos_timezoned" --> TimestampNanosecondArray with timezone /// "micros" --> TimestampMicrosecondArray +/// "micros_timezoned" --> TimestampMicrosecondArray with timezone /// "millis" --> TimestampMillisecondArray +/// "millis_timezoned" --> TimestampMillisecondArray with timezone /// "seconds" --> TimestampSecondArray +/// "seconds_timezoned" --> TimestampSecondArray with timezone /// "names" --> StringArray fn make_timestamp_batch(offset: Duration) -> RecordBatch { let ts_strings = vec![ @@ -345,6 +349,8 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { Some("2020-01-02T01:01:01.0000000000001"), ]; + let tz_string = "Pacific/Efate"; + let offset_nanos = offset.num_nanoseconds().expect("non overflow 
nanos"); let ts_nanos = ts_strings @@ -382,19 +388,47 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { .map(|(i, _)| format!("Row {i} + {offset}")) .collect::>(); - let arr_nanos = TimestampNanosecondArray::from(ts_nanos); - let arr_micros = TimestampMicrosecondArray::from(ts_micros); - let arr_millis = TimestampMillisecondArray::from(ts_millis); - let arr_seconds = TimestampSecondArray::from(ts_seconds); + let arr_nanos = TimestampNanosecondArray::from(ts_nanos.clone()); + let arr_nanos_timezoned = + TimestampNanosecondArray::from(ts_nanos).with_timezone(tz_string); + let arr_micros = TimestampMicrosecondArray::from(ts_micros.clone()); + let arr_micros_timezoned = + TimestampMicrosecondArray::from(ts_micros).with_timezone(tz_string); + let arr_millis = TimestampMillisecondArray::from(ts_millis.clone()); + let arr_millis_timezoned = + TimestampMillisecondArray::from(ts_millis).with_timezone(tz_string); + let arr_seconds = TimestampSecondArray::from(ts_seconds.clone()); + let arr_seconds_timezoned = + TimestampSecondArray::from(ts_seconds).with_timezone(tz_string); let names = names.iter().map(|s| s.as_str()).collect::>(); let arr_names = StringArray::from(names); let schema = Schema::new(vec![ Field::new("nanos", arr_nanos.data_type().clone(), true), + Field::new( + "nanos_timezoned", + arr_nanos_timezoned.data_type().clone(), + true, + ), Field::new("micros", arr_micros.data_type().clone(), true), + Field::new( + "micros_timezoned", + arr_micros_timezoned.data_type().clone(), + true, + ), Field::new("millis", arr_millis.data_type().clone(), true), + Field::new( + "millis_timezoned", + arr_millis_timezoned.data_type().clone(), + true, + ), Field::new("seconds", arr_seconds.data_type().clone(), true), + Field::new( + "seconds_timezoned", + arr_seconds_timezoned.data_type().clone(), + true, + ), Field::new("name", arr_names.data_type().clone(), true), ]); let schema = Arc::new(schema); @@ -403,9 +437,13 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { schema, vec![ Arc::new(arr_nanos), + Arc::new(arr_nanos_timezoned), Arc::new(arr_micros), + Arc::new(arr_micros_timezoned), Arc::new(arr_millis), + Arc::new(arr_millis_timezoned), Arc::new(arr_seconds), + Arc::new(arr_seconds_timezoned), Arc::new(arr_names), ], ) From fbbab6c7adb0c2c285ff3f9ed25f5bd9796ecb89 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 3 Jun 2024 11:31:06 -0400 Subject: [PATCH 26/35] Minor: Add tests for extracting dictionary parquet statistics (#10729) --- .../core/tests/parquet/arrow_statistics.rs | 40 ++++++++++++++++- datafusion/core/tests/parquet/mod.rs | 44 ++++++++++++++++++- 2 files changed, 81 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 2836cd2893f3..5e0f8b4f5f18 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -1231,8 +1231,44 @@ async fn test_decimal() { .run(); } -// BUG: not convert BinaryArray to StringArray -// https://github.com/apache/datafusion/issues/10605 +#[tokio::test] +async fn test_dictionary() { + let reader = TestReader { + scenario: Scenario::Dictionary, + row_per_group: 5, + }; + + Test { + reader: reader.build().await, + expected_min: Arc::new(StringArray::from(vec!["abc", "aaa"])), + expected_max: Arc::new(StringArray::from(vec!["def", "fffff"])), + expected_null_counts: UInt64Array::from(vec![1, 0]), + expected_row_counts: UInt64Array::from(vec![5, 2]), + column_name: "string_dict_i8", 
+ } + .run(); + + Test { + reader: reader.build().await, + expected_min: Arc::new(StringArray::from(vec!["abc", "aaa"])), + expected_max: Arc::new(StringArray::from(vec!["def", "fffff"])), + expected_null_counts: UInt64Array::from(vec![1, 0]), + expected_row_counts: UInt64Array::from(vec![5, 2]), + column_name: "string_dict_i32", + } + .run(); + + Test { + reader: reader.build().await, + expected_min: Arc::new(Int64Array::from(vec![-100, 0])), + expected_max: Arc::new(Int64Array::from(vec![0, 100])), + expected_null_counts: UInt64Array::from(vec![1, 0]), + expected_row_counts: UInt64Array::from(vec![5, 2]), + column_name: "int_dict_i8", + } + .run(); +} + #[tokio::test] async fn test_byte() { // This creates a parquet file of 4 columns diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 41a0a86aa8d3..f45ff53d3fb8 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -28,7 +28,8 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; -use arrow_array::{make_array, BooleanArray, Float32Array, StructArray}; +use arrow_array::types::{Int32Type, Int8Type}; +use arrow_array::{make_array, BooleanArray, DictionaryArray, Float32Array, StructArray}; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ datasource::{physical_plan::ParquetExec, provider_as_source, TableProvider}, @@ -81,7 +82,10 @@ enum Scenario { DecimalBloomFilterInt64, DecimalLargePrecision, DecimalLargePrecisionBloomFilter, + /// StringArray, BinaryArray, FixedSizeBinaryArray ByteArray, + /// DictionaryArray + Dictionary, PeriodsInColumnNames, WithNullValues, WithNullValuesPageLevel, @@ -783,6 +787,41 @@ fn make_numeric_limit_batch() -> RecordBatch { .unwrap() } +fn make_dict_batch() -> RecordBatch { + let values = [ + Some("abc"), + Some("def"), + None, + Some("def"), + Some("abc"), + Some("fffff"), + Some("aaa"), + ]; + let dict_i8_array = DictionaryArray::::from_iter(values.iter().cloned()); + let dict_i32_array = DictionaryArray::::from_iter(values.iter().cloned()); + + // Dictionary array of integers + let int64_values = Int64Array::from(vec![0, -100, 100]); + let keys = Int8Array::from_iter([ + Some(0), + Some(1), + None, + Some(0), + Some(0), + Some(2), + Some(0), + ]); + let dict_i8_int_array = + DictionaryArray::::try_new(keys, Arc::new(int64_values)).unwrap(); + + RecordBatch::try_from_iter(vec![ + ("string_dict_i8", Arc::new(dict_i8_array) as _), + ("string_dict_i32", Arc::new(dict_i32_array) as _), + ("int_dict_i8", Arc::new(dict_i8_int_array) as _), + ]) + .unwrap() +} + fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Boolean => { @@ -954,6 +993,9 @@ fn create_data_batch(scenario: Scenario) -> Vec { ), ] } + Scenario::Dictionary => { + vec![make_dict_batch()] + } Scenario::PeriodsInColumnNames => { vec![ // all frontend From eabbd28603ebf440506ec118a1bc6ec29e696eac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:31:11 -0400 Subject: [PATCH 27/35] Update rstest requirement from 0.20.0 to 0.21.0 (#10774) Updates the requirements on [rstest](https://github.com/la10736/rstest) to permit the latest version. 
- [Release notes](https://github.com/la10736/rstest/releases) - [Changelog](https://github.com/la10736/rstest/blob/master/CHANGELOG.md) - [Commits](https://github.com/la10736/rstest/compare/v0.20.0...v0.20.0) --- updated-dependencies: - dependency-name: rstest dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 45504be3f1ba..54f2f203fcdc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -109,7 +109,7 @@ parking_lot = "0.12" parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" regex = "1.8" -rstest = "0.20.0" +rstest = "0.21.0" serde_json = "1" sqlparser = { version = "0.45.0", features = ["visitor"] } tempfile = "3" From a92f803298da35776ffc40adce161bac88601938 Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Mon, 3 Jun 2024 20:40:32 +0200 Subject: [PATCH 28/35] Minor: Refactor memory size estimation for HashTable (#10748) * refactor: extract estimate_memory_size * refactor: cap at usize::MAX * refactor: use estimate_memory_size * chore: add examples * refactor: return Result; add testcase * fix: docs * fix: remove unneccessary checked_div * fix: remove additional and_then --- datafusion/common/src/utils/memory.rs | 134 ++++++++++++++++++ datafusion/common/src/utils/mod.rs | 1 + .../src/aggregate/count_distinct/native.rs | 35 ++--- .../physical-plan/src/joins/hash_join.rs | 25 +--- 4 files changed, 153 insertions(+), 42 deletions(-) create mode 100644 datafusion/common/src/utils/memory.rs diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs new file mode 100644 index 000000000000..17668cf93d99 --- /dev/null +++ b/datafusion/common/src/utils/memory.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides a function to estimate the memory size of a HashTable prior to alloaction + +use crate::{DataFusionError, Result}; + +/// Estimates the memory size required for a hash table prior to allocation. +/// +/// # Parameters +/// - `num_elements`: The number of elements expected in the hash table. +/// - `fixed_size`: A fixed overhead size associated with the collection +/// (e.g., HashSet or HashTable). +/// - `T`: The type of elements stored in the hash table. +/// +/// # Details +/// This function calculates the estimated memory size by considering: +/// - An overestimation of buckets to keep approximately 1/8 of them empty. 
+/// - The total memory size is computed as: +/// - The size of each entry (`T`) multiplied by the estimated number of +/// buckets. +/// - One byte overhead for each bucket. +/// - The fixed size overhead of the collection. +/// - If the estimation overflows, we return a [`DataFusionError`] +/// +/// # Examples +/// --- +/// +/// ## From within a struct +/// +/// ```rust +/// # use datafusion_common::utils::memory::estimate_memory_size; +/// # use datafusion_common::Result; +/// +/// struct MyStruct { +/// values: Vec, +/// other_data: usize, +/// } +/// +/// impl MyStruct { +/// fn size(&self) -> Result { +/// let num_elements = self.values.len(); +/// let fixed_size = std::mem::size_of_val(self) + +/// std::mem::size_of_val(&self.values); +/// +/// estimate_memory_size::(num_elements, fixed_size) +/// } +/// } +/// ``` +/// --- +/// ## With a simple collection +/// +/// ```rust +/// # use datafusion_common::utils::memory::estimate_memory_size; +/// # use std::collections::HashMap; +/// +/// let num_rows = 100; +/// let fixed_size = std::mem::size_of::>(); +/// let estimated_hashtable_size = +/// estimate_memory_size::<(u64, u64)>(num_rows,fixed_size) +/// .expect("Size estimation failed"); +/// ``` +pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result { + // For the majority of cases hashbrown overestimates the bucket quantity + // to keep ~1/8 of them empty. We take this factor into account by + // multiplying the number of elements with a fixed ratio of 8/7 (~1.14). + // This formula leads to overallocation for small tables (< 8 elements) + // but should be fine overall. + num_elements + .checked_mul(8) + .and_then(|overestimate| { + let estimated_buckets = (overestimate / 7).next_power_of_two(); + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of collection (HashSet/HashTable) + std::mem::size_of::() + .checked_mul(estimated_buckets)? + .checked_add(estimated_buckets)? + .checked_add(fixed_size) + }) + .ok_or_else(|| { + DataFusionError::Execution( + "usize overflow while estimating the number of buckets".to_string(), + ) + }) +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::estimate_memory_size; + + #[test] + fn test_estimate_memory() { + // size (bytes): 48 + let fixed_size = std::mem::size_of::>(); + + // estimated buckets: 16 = (8 * 8 / 7).next_power_of_two() + let num_elements = 8; + // size (bytes): 128 = 16 * 4 + 16 + 48 + let estimated = estimate_memory_size::(num_elements, fixed_size).unwrap(); + assert_eq!(estimated, 128); + + // estimated buckets: 64 = (40 * 8 / 7).next_power_of_two() + let num_elements = 40; + // size (bytes): 368 = 64 * 4 + 64 + 48 + let estimated = estimate_memory_size::(num_elements, fixed_size).unwrap(); + assert_eq!(estimated, 368); + } + + #[test] + fn test_estimate_memory_overflow() { + let num_elements = usize::MAX; + let fixed_size = std::mem::size_of::>(); + let estimated = estimate_memory_size::(num_elements, fixed_size); + + assert!(estimated.is_err()); + } +} diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 402ec95b33b3..ae444c2cb285 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -17,6 +17,7 @@ //! This module provides the bisect function, which implements binary search. 
+pub mod memory; pub mod proxy; use crate::error::{_internal_datafusion_err, _internal_err}; diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs index 95d8662e0f6e..0e7483d4a1cd 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs @@ -33,6 +33,7 @@ use arrow_schema::DataType; use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; @@ -115,18 +116,11 @@ where } fn size(&self) -> usize { - let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) - / 7) - .next_power_of_two(); - - // Size of accumulator - // + size of entry * number of buckets - // + 1 byte for each bucket - // + fixed size of HashSet - std::mem::size_of_val(self) - + std::mem::size_of::() * estimated_buckets - + estimated_buckets - + std::mem::size_of_val(&self.values) + let num_elements = self.values.len(); + let fixed_size = + std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); + + estimate_memory_size::(num_elements, fixed_size).unwrap() } } @@ -202,17 +196,10 @@ where } fn size(&self) -> usize { - let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) - / 7) - .next_power_of_two(); - - // Size of accumulator - // + size of entry * number of buckets - // + 1 byte for each bucket - // + fixed size of HashSet - std::mem::size_of_val(self) - + std::mem::size_of::() * estimated_buckets - + estimated_buckets - + std::mem::size_of_val(&self.values) + let num_elements = self.values.len(); + let fixed_size = + std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); + + estimate_memory_size::(num_elements, fixed_size).unwrap() } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index e669517be400..784584f03f0f 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -18,7 +18,6 @@ //! [`HashJoinExec`] Partitioned Hash Join Operator use std::fmt; -use std::mem::size_of; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; @@ -59,6 +58,7 @@ use arrow::record_batch::RecordBatch; use arrow::util::bit_util; use arrow_array::cast::downcast_array; use arrow_schema::ArrowError; +use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, JoinSide, JoinType, Result, @@ -875,23 +875,12 @@ async fn collect_left_input( // Estimation of memory size, required for hashtable, prior to allocation. // Final result can be verified using `RawTable.allocation_info()` - // - // For majority of cases hashbrown overestimates buckets qty to keep ~1/8 of them empty. - // This formula leads to overallocation for small tables (< 8 elements) but fine overall. - let estimated_buckets = (num_rows.checked_mul(8).ok_or_else(|| { - DataFusionError::Execution( - "usize overflow while estimating number of hasmap buckets".to_string(), - ) - })? 
/ 7) - .next_power_of_two(); - // 16 bytes per `(u64, u64)` - // + 1 byte for each bucket - // + fixed size of JoinHashMap (RawTable + Vec) - let estimated_hastable_size = - 16 * estimated_buckets + estimated_buckets + size_of::(); - - reservation.try_grow(estimated_hastable_size)?; - metrics.build_mem_used.add(estimated_hastable_size); + let fixed_size = std::mem::size_of::(); + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)?; + + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); let mut hashmap = JoinHashMap::with_capacity(num_rows); let mut hashes_buffer = Vec::new(); From 3aae451d38476510670fff04404418955a4fc83c Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen <83442793+MohamedAbdeen21@users.noreply.github.com> Date: Mon, 3 Jun 2024 21:43:07 +0300 Subject: [PATCH 29/35] Reduce code repetition in `datafusion/functions` mod files (#10700) * initial reduce repetition using macros * formatting and docs * fix docs * refix doc * replace math mod too * fix vec arguments * fix math variadic args * apply to functions * pattern-match hack to avoid second macro * missed a function * fix merge conflict * fix octet_length argument --- datafusion/functions/src/core/mod.rs | 82 ++++---- datafusion/functions/src/crypto/mod.rs | 60 +++--- datafusion/functions/src/datetime/mod.rs | 123 +++++------- datafusion/functions/src/encoding/mod.rs | 22 ++- datafusion/functions/src/macros.rs | 38 ++-- datafusion/functions/src/math/mod.rs | 238 ++++------------------- datafusion/functions/src/string/mod.rs | 181 ++++++++--------- datafusion/functions/src/unicode/mod.rs | 111 +++++------ 8 files changed, 339 insertions(+), 516 deletions(-) diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 349d483a4100..a2742220f3e9 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -42,59 +42,49 @@ make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); -// Export the functions out of this package, both as expr_fn as well as a list of functions pub mod expr_fn { use datafusion_expr::{Expr, Literal}; - /// returns NULL if value1 equals value2; otherwise it returns value1. This - /// can be used to perform the inverse operation of the COALESCE expression - pub fn nullif(arg1: Expr, arg2: Expr) -> Expr { - super::nullif().call(vec![arg1, arg2]) - } - - /// returns value1 cast to the `arrow_type` given the second argument. This - /// can be used to cast to a specific `arrow_type`. - pub fn arrow_cast(arg1: Expr, arg2: Expr) -> Expr { - super::arrow_cast().call(vec![arg1, arg2]) - } - - /// Returns value2 if value1 is NULL; otherwise it returns value1 - pub fn nvl(arg1: Expr, arg2: Expr) -> Expr { - super::nvl().call(vec![arg1, arg2]) - } - - /// Returns value2 if value1 is not NULL; otherwise, it returns value3. - pub fn nvl2(arg1: Expr, arg2: Expr, arg3: Expr) -> Expr { - super::nvl2().call(vec![arg1, arg2, arg3]) - } - - /// Returns the Arrow type of the input expression. 
- pub fn arrow_typeof(arg1: Expr) -> Expr { - super::arrow_typeof().call(vec![arg1]) - } - - /// Returns a struct with the given arguments - pub fn r#struct(args: Vec) -> Expr { - super::r#struct().call(args) - } - - /// Returns a struct with the given names and arguments pairs - pub fn named_struct(args: Vec) -> Expr { - super::named_struct().call(args) - } - - /// Returns the value of the field with the given name from the struct - pub fn get_field(arg1: Expr, field_name: impl Literal) -> Expr { - super::get_field().call(vec![arg1, field_name.lit()]) - } + export_functions!(( + nullif, + "Returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression", + arg1 arg2 + ),( + arrow_cast, + "Returns value2 if value1 is NULL; otherwise it returns value1", + arg1 arg2 + ),( + nvl, + "Returns value2 if value1 is NULL; otherwise it returns value1", + arg1 arg2 + ),( + nvl2, + "Returns value2 if value1 is not NULL; otherwise, it returns value3.", + arg1 arg2 arg3 + ),( + arrow_typeof, + "Returns the Arrow type of the input expression.", + arg1 + ),( + r#struct, + "Returns a struct with the given arguments", + args, + ),( + named_struct, + "Returns a struct with the given names and arguments pairs", + args, + ),( + coalesce, + "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL", + args, + )); - /// Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL - pub fn coalesce(args: Vec) -> Expr { - super::coalesce().call(args) + #[doc = "Returns the value of the field with the given name from the struct"] + pub fn get_field(arg1: Expr, arg2: impl Literal) -> Expr { + super::get_field().call(vec![arg1, arg2.lit()]) } } -/// Return a list of all functions in this package pub fn functions() -> Vec> { vec![ nullif(), diff --git a/datafusion/functions/src/crypto/mod.rs b/datafusion/functions/src/crypto/mod.rs index a879fdb45b35..497c1af62a72 100644 --- a/datafusion/functions/src/crypto/mod.rs +++ b/datafusion/functions/src/crypto/mod.rs @@ -17,6 +17,9 @@ //! "crypto" DataFusion functions +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + pub mod basic; pub mod digest; pub mod md5; @@ -30,28 +33,35 @@ make_udf_function!(sha224::SHA224Func, SHA224, sha224); make_udf_function!(sha256::SHA256Func, SHA256, sha256); make_udf_function!(sha384::SHA384Func, SHA384, sha384); make_udf_function!(sha512::SHA512Func, SHA512, sha512); -export_functions!(( - digest, - input_arg1 input_arg2, - "Computes the binary hash of an expression using the specified algorithm." -),( - md5, - input_arg, - "Computes an MD5 128-bit checksum for a string expression." -),( - sha224, - input_arg1, - "Computes the SHA-224 hash of a binary string." -),( - sha256, - input_arg1, - "Computes the SHA-256 hash of a binary string." -),( - sha384, - input_arg1, - "Computes the SHA-384 hash of a binary string." -),( - sha512, - input_arg1, - "Computes the SHA-512 hash of a binary string." 
-)); + +pub mod expr_fn { + export_functions!(( + digest, + "Computes the binary hash of an expression using the specified algorithm.", + input_arg1 input_arg2 + ),( + md5, + "Computes an MD5 128-bit checksum for a string expression.", + input_arg + ),( + sha224, + "Computes the SHA-224 hash of a binary string.", + input_arg1 + ),( + sha256, + "Computes the SHA-256 hash of a binary string.", + input_arg1 + ),( + sha384, + "Computes the SHA-384 hash of a binary string.", + input_arg1 + ),( + sha512, + "Computes the SHA-512 hash of a binary string.", + input_arg1 + )); +} + +pub fn functions() -> Vec> { + vec![digest(), md5(), sha224(), sha256(), sha384(), sha512()] +} diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index c6939976eb02..9c2f80856bf8 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -79,45 +79,60 @@ make_udf_function!( pub mod expr_fn { use datafusion_expr::Expr; - #[doc = "returns current UTC date as a Date32 value"] - pub fn current_date() -> Expr { - super::current_date().call(vec![]) - } - - #[doc = "returns current UTC time as a Time64 value"] - pub fn current_time() -> Expr { - super::current_time().call(vec![]) - } - - #[doc = "coerces an arbitrary timestamp to the start of the nearest specified interval"] - pub fn date_bin(stride: Expr, source: Expr, origin: Expr) -> Expr { - super::date_bin().call(vec![stride, source, origin]) - } - - #[doc = "extracts a subfield from the date"] - pub fn date_part(part: Expr, date: Expr) -> Expr { - super::date_part().call(vec![part, date]) - } - - #[doc = "truncates the date to a specified level of precision"] - pub fn date_trunc(part: Expr, date: Expr) -> Expr { - super::date_trunc().call(vec![part, date]) - } - - #[doc = "converts an integer to RFC3339 timestamp format string"] - pub fn from_unixtime(unixtime: Expr) -> Expr { - super::from_unixtime().call(vec![unixtime]) - } - - #[doc = "make a date from year, month and day component parts"] - pub fn make_date(year: Expr, month: Expr, day: Expr) -> Expr { - super::make_date().call(vec![year, month, day]) - } - - #[doc = "returns the current timestamp in nanoseconds, using the same value for all instances of now() in same statement"] - pub fn now() -> Expr { - super::now().call(vec![]) - } + export_functions!(( + current_date, + "returns current UTC date as a Date32 value", + ),( + current_time, + "returns current UTC time as a Time64 value", + ),( + from_unixtime, + "converts an integer to RFC3339 timestamp format string", + unixtime + ),( + date_bin, + "coerces an arbitrary timestamp to the start of the nearest specified interval", + stride source origin + ),( + date_part, + "extracts a subfield from the date", + part date + ),( + date_trunc, + "truncates the date to a specified level of precision", + part date + ),( + make_date, + "make a date from year, month and day component parts", + year month day + ),( + now, + "returns the current timestamp in nanoseconds, using the same value for all instances of now() in same statement", + ),( + to_unixtime, + "converts a string and optional formats to a Unixtime", + args, + ),( + to_timestamp, + "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`", + args, + ),( + to_timestamp_seconds, + "converts a string and optional formats to a `Timestamp(Seconds, None)`", + args, + ),( + to_timestamp_millis, + "converts a string and optional formats to a `Timestamp(Milliseconds, None)`", + args, + ),( + to_timestamp_micros, 
+ "converts a string and optional formats to a `Timestamp(Microseconds, None)`", + args, + ),( + to_timestamp_nanos, + "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`", + args, + )); /// Returns a string representation of a date, time, timestamp or duration based /// on a Chrono pattern. @@ -247,36 +262,6 @@ pub mod expr_fn { pub fn to_date(args: Vec) -> Expr { super::to_date().call(args) } - - #[doc = "converts a string and optional formats to a Unixtime"] - pub fn to_unixtime(args: Vec) -> Expr { - super::to_unixtime().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`"] - pub fn to_timestamp(args: Vec) -> Expr { - super::to_timestamp().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Seconds, None)`"] - pub fn to_timestamp_seconds(args: Vec) -> Expr { - super::to_timestamp_seconds().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Milliseconds, None)`"] - pub fn to_timestamp_millis(args: Vec) -> Expr { - super::to_timestamp_millis().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Microseconds, None)`"] - pub fn to_timestamp_micros(args: Vec) -> Expr { - super::to_timestamp_micros().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`"] - pub fn to_timestamp_nanos(args: Vec) -> Expr { - super::to_timestamp_nanos().call(args) - } } /// Return a list of all functions in this package diff --git a/datafusion/functions/src/encoding/mod.rs b/datafusion/functions/src/encoding/mod.rs index 49f914a68774..24e11e5d635f 100644 --- a/datafusion/functions/src/encoding/mod.rs +++ b/datafusion/functions/src/encoding/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + pub mod inner; // create `encode` and `decode` UDFs @@ -22,7 +25,18 @@ make_udf_function!(inner::EncodeFunc, ENCODE, encode); make_udf_function!(inner::DecodeFunc, DECODE, decode); // Export the functions out of this package, both as expr_fn as well as a list of functions -export_functions!( - (encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex"), - (decode, input encoding, "decode the `input`, using the `encoding`. encoding can be base64 or hex") -); +pub mod expr_fn { + export_functions!( ( + encode, + "encode the `input`, using the `encoding`. encoding can be base64 or hex", + input encoding + ),( + decode, + "decode the `input`, using the `encoding`. encoding can be base64 or hex", + input encoding + )); +} + +pub fn functions() -> Vec> { + vec![encode(), decode()] +} diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index dcc37f100c9a..cae689b3e0cb 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -36,25 +36,31 @@ /// ] /// } /// ``` +/// +/// Exported functions accept: +/// - `Vec` argument (single argument followed by a comma) +/// - Variable number of `Expr` arguments (zero or more arguments, must be without commas) macro_rules! 
export_functions { - ($(($FUNC:ident, $($arg:ident)*, $DOC:expr)),*) => { - pub mod expr_fn { - $( - #[doc = $DOC] - /// Return $name(arg) - pub fn $FUNC($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { - super::$FUNC().call(vec![$($arg),*],) - } - )* + ($(($FUNC:ident, $DOC:expr, $($arg:tt)*)),*) => { + $( + // switch to single-function cases below + export_functions!(single $FUNC, $DOC, $($arg)*); + )* + }; + + // single vector argument (a single argument followed by a comma) + (single $FUNC:ident, $DOC:expr, $arg:ident,) => { + #[doc = $DOC] + pub fn $FUNC($arg: Vec) -> datafusion_expr::Expr { + super::$FUNC().call($arg) } + }; - /// Return a list of all functions in this package - pub fn functions() -> Vec> { - vec![ - $( - $FUNC(), - )* - ] + // variadic arguments (zero or more arguments, without commas) + (single $FUNC:ident, $DOC:expr, $($arg:ident)*) => { + #[doc = $DOC] + pub fn $FUNC($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + super::$FUNC().call(vec![$($arg),*]) } }; } diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 387237acb769..9ee173bb6176 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -17,11 +17,9 @@ //! "math" DataFusion functions -use std::sync::Arc; - use crate::math::monotonicity::*; - use datafusion_expr::ScalarUDF; +use std::sync::Arc; pub mod abs; pub mod cot; @@ -92,200 +90,48 @@ make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, super::tanh_order); make_udf_function!(trunc::TruncFunc, TRUNC, trunc); pub mod expr_fn { - use datafusion_expr::Expr; - - #[doc = "returns the absolute value of a given number"] - pub fn abs(num: Expr) -> Expr { - super::abs().call(vec![num]) - } - - #[doc = "returns the arc cosine or inverse cosine of a number"] - pub fn acos(num: Expr) -> Expr { - super::acos().call(vec![num]) - } - - #[doc = "returns inverse hyperbolic cosine"] - pub fn acosh(num: Expr) -> Expr { - super::acosh().call(vec![num]) - } - - #[doc = "returns the arc sine or inverse sine of a number"] - pub fn asin(num: Expr) -> Expr { - super::asin().call(vec![num]) - } - - #[doc = "returns inverse hyperbolic sine"] - pub fn asinh(num: Expr) -> Expr { - super::asinh().call(vec![num]) - } - - #[doc = "returns inverse tangent"] - pub fn atan(num: Expr) -> Expr { - super::atan().call(vec![num]) - } - - #[doc = "returns inverse tangent of a division given in the argument"] - pub fn atan2(y: Expr, x: Expr) -> Expr { - super::atan2().call(vec![y, x]) - } - - #[doc = "returns inverse hyperbolic tangent"] - pub fn atanh(num: Expr) -> Expr { - super::atanh().call(vec![num]) - } - - #[doc = "cube root of a number"] - pub fn cbrt(num: Expr) -> Expr { - super::cbrt().call(vec![num]) - } - - #[doc = "nearest integer greater than or equal to argument"] - pub fn ceil(num: Expr) -> Expr { - super::ceil().call(vec![num]) - } - - #[doc = "cosine"] - pub fn cos(num: Expr) -> Expr { - super::cos().call(vec![num]) - } - - #[doc = "hyperbolic cosine"] - pub fn cosh(num: Expr) -> Expr { - super::cosh().call(vec![num]) - } - - #[doc = "cotangent of a number"] - pub fn cot(num: Expr) -> Expr { - super::cot().call(vec![num]) - } - - #[doc = "converts radians to degrees"] - pub fn degrees(num: Expr) -> Expr { - super::degrees().call(vec![num]) - } - - #[doc = "exponential"] - pub fn exp(num: Expr) -> Expr { - super::exp().call(vec![num]) - } - - #[doc = "factorial"] - pub fn factorial(num: Expr) -> Expr { - super::factorial().call(vec![num]) - } - - #[doc = "nearest integer less 
than or equal to argument"] - pub fn floor(num: Expr) -> Expr { - super::floor().call(vec![num]) - } - - #[doc = "greatest common divisor"] - pub fn gcd(x: Expr, y: Expr) -> Expr { - super::gcd().call(vec![x, y]) - } - - #[doc = "returns true if a given number is +NaN or -NaN otherwise returns false"] - pub fn isnan(num: Expr) -> Expr { - super::isnan().call(vec![num]) - } - - #[doc = "returns true if a given number is +0.0 or -0.0 otherwise returns false"] - pub fn iszero(num: Expr) -> Expr { - super::iszero().call(vec![num]) - } - - #[doc = "least common multiple"] - pub fn lcm(x: Expr, y: Expr) -> Expr { - super::lcm().call(vec![x, y]) - } - - #[doc = "natural logarithm (base e) of a number"] - pub fn ln(num: Expr) -> Expr { - super::ln().call(vec![num]) - } - - #[doc = "logarithm of a number for a particular `base`"] - pub fn log(base: Expr, num: Expr) -> Expr { - super::log().call(vec![base, num]) - } - - #[doc = "base 2 logarithm of a number"] - pub fn log2(num: Expr) -> Expr { - super::log2().call(vec![num]) - } - - #[doc = "base 10 logarithm of a number"] - pub fn log10(num: Expr) -> Expr { - super::log10().call(vec![num]) - } - - #[doc = "returns x if x is not NaN otherwise returns y"] - pub fn nanvl(x: Expr, y: Expr) -> Expr { - super::nanvl().call(vec![x, y]) - } - - #[doc = "Returns an approximate value of π"] - pub fn pi() -> Expr { - super::pi().call(vec![]) - } - - #[doc = "`base` raised to the power of `exponent`"] - pub fn power(base: Expr, exponent: Expr) -> Expr { - super::power().call(vec![base, exponent]) - } - - #[doc = "converts degrees to radians"] - pub fn radians(num: Expr) -> Expr { - super::radians().call(vec![num]) - } - - #[doc = "Returns a random value in the range 0.0 <= x < 1.0"] - pub fn random() -> Expr { - super::random().call(vec![]) - } - - #[doc = "round to nearest integer"] - pub fn round(args: Vec) -> Expr { - super::round().call(args) - } - - #[doc = "sign of the argument (-1, 0, +1)"] - pub fn signum(num: Expr) -> Expr { - super::signum().call(vec![num]) - } - - #[doc = "sine"] - pub fn sin(num: Expr) -> Expr { - super::sin().call(vec![num]) - } - - #[doc = "hyperbolic sine"] - pub fn sinh(num: Expr) -> Expr { - super::sinh().call(vec![num]) - } - - #[doc = "square root of a number"] - pub fn sqrt(num: Expr) -> Expr { - super::sqrt().call(vec![num]) - } - - #[doc = "returns the tangent of a number"] - pub fn tan(num: Expr) -> Expr { - super::tan().call(vec![num]) - } - - #[doc = "returns the hyperbolic tangent of a number"] - pub fn tanh(num: Expr) -> Expr { - super::tanh().call(vec![num]) - } - - #[doc = "truncate toward zero, with optional precision"] - pub fn trunc(args: Vec) -> Expr { - super::trunc().call(args) - } + export_functions!( + (abs, "returns the absolute value of a given number", num), + (acos, "returns the arc cosine or inverse cosine of a number", num), + (acosh, "returns inverse hyperbolic cosine", num), + (asin, "returns the arc sine or inverse sine of a number", num), + (asinh, "returns inverse hyperbolic sine", num), + (atan, "returns inverse tangent", num), + (atan2, "returns inverse tangent of a division given in the argument", y x), + (atanh, "returns inverse hyperbolic tangent", num), + (cbrt, "cube root of a number", num), + (ceil, "nearest integer greater than or equal to argument", num), + (cos, "cosine", num), + (cosh, "hyperbolic cosine", num), + (cot, "cotangent of a number", num), + (degrees, "converts radians to degrees", num), + (exp, "exponential", num), + (factorial, "factorial", num), + (floor, "nearest 
integer less than or equal to argument", num), + (gcd, "greatest common divisor", x y), + (isnan, "returns true if a given number is +NaN or -NaN otherwise returns false", num), + (iszero, "returns true if a given number is +0.0 or -0.0 otherwise returns false", num), + (lcm, "least common multiple", x y), + (ln, "natural logarithm (base e) of a number", num), + (log, "logarithm of a number for a particular `base`", base num), + (log2, "base 2 logarithm of a number", num), + (log10, "base 10 logarithm of a number", num), + (nanvl, "returns x if x is not NaN otherwise returns y", x y), + (pi, "Returns an approximate value of π",), + (power, "`base` raised to the power of `exponent`", base exponent), + (radians, "converts degrees to radians", num), + (random, "Returns a random value in the range 0.0 <= x < 1.0",), + (signum, "sign of the argument (-1, 0, +1)", num), + (sin, "sine", num), + (sinh, "hyperbolic sine", num), + (sqrt, "square root of a number", num), + (tan, "returns the tangent of a number", num), + (tanh, "returns the hyperbolic tangent of a number", num), + (round, "round to nearest integer", args,), + (trunc, "truncate toward zero, with optional precision", args,) + ); } -/// Return a list of all functions in this package pub fn functions() -> Vec> { vec![ abs(), @@ -318,13 +164,13 @@ pub fn functions() -> Vec> { power(), radians(), random(), - round(), signum(), sin(), sinh(), sqrt(), tan(), tanh(), + round(), trunc(), ] } diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index e931c4998115..219ef8b5a50f 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -70,117 +70,98 @@ make_udf_function!(uuid::UuidFunc, UUID, uuid); pub mod expr_fn { use datafusion_expr::Expr; - #[doc = "Returns the numeric code of the first character of the argument."] - pub fn ascii(arg1: Expr) -> Expr { - super::ascii().call(vec![arg1]) - } - - #[doc = "Returns the number of bits in the `string`"] - pub fn bit_length(arg: Expr) -> Expr { - super::bit_length().call(vec![arg]) - } + export_functions!(( + ascii, + "Returns the numeric code of the first character of the argument.", + arg1 + ),( + bit_length, + "Returns the number of bits in the `string`", + arg1 + ),( + btrim, + "Removes all characters, spaces by default, from both sides of a string", + args, + ),( + chr, + "Converts the Unicode code point to a UTF8 character", + arg1 + ),( + concat, + "Concatenates the text representations of all the arguments. 
NULL arguments are ignored", + args, + ),( + ends_with, + "Returns true if the `string` ends with the `suffix`, false otherwise.", + string suffix + ),( + initcap, + "Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase", + string + ),( + levenshtein, + "Returns the Levenshtein distance between the two given strings", + arg1 arg2 + ),( + lower, + "Converts a string to lowercase.", + arg1 + ),( + ltrim, + "Removes all characters, spaces by default, from the beginning of a string", + args, + ),( + octet_length, + "returns the number of bytes of a string", + args + ),( + overlay, + "replace the substring of string that starts at the start'th character and extends for count characters with new substring", + args, + ),( + repeat, + "Repeats the `string` to `n` times", + string n + ),( + replace, + "Replaces all occurrences of `from` with `to` in the `string`", + string from to + ),( + rtrim, + "Removes all characters, spaces by default, from the end of a string", + args, + ),( + split_part, + "Splits a string based on a delimiter and picks out the desired field based on the index.", + string delimiter index + ),( + starts_with, + "Returns true if string starts with prefix.", + arg1 arg2 + ),( + to_hex, + "Converts an integer to a hexadecimal string.", + arg1 + ),( + upper, + "Converts a string to uppercase.", + arg1 + ),( + uuid, + "returns uuid v4 as a string value", + )); #[doc = "Removes all characters, spaces by default, from both sides of a string"] - pub fn btrim(args: Vec) -> Expr { + pub fn trim(args: Vec) -> Expr { super::btrim().call(args) } - #[doc = "Converts the Unicode code point to a UTF8 character"] - pub fn chr(arg: Expr) -> Expr { - super::chr().call(vec![arg]) - } - - #[doc = "Concatenates the text representations of all the arguments. NULL arguments are ignored"] - pub fn concat(args: Vec) -> Expr { - super::concat().call(args) - } - #[doc = "Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. 
Other NULL arguments are ignored."] pub fn concat_ws(delimiter: Expr, args: Vec) -> Expr { let mut args = args; args.insert(0, delimiter); super::concat_ws().call(args) } - - #[doc = "Returns true if the `string` ends with the `suffix`, false otherwise."] - pub fn ends_with(string: Expr, suffix: Expr) -> Expr { - super::ends_with().call(vec![string, suffix]) - } - - #[doc = "Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"] - pub fn initcap(string: Expr) -> Expr { - super::initcap().call(vec![string]) - } - - #[doc = "Returns the Levenshtein distance between the two given strings"] - pub fn levenshtein(arg1: Expr, arg2: Expr) -> Expr { - super::levenshtein().call(vec![arg1, arg2]) - } - - #[doc = "Converts a string to lowercase."] - pub fn lower(arg1: Expr) -> Expr { - super::lower().call(vec![arg1]) - } - - #[doc = "Removes all characters, spaces by default, from the beginning of a string"] - pub fn ltrim(args: Vec) -> Expr { - super::ltrim().call(args) - } - - #[doc = "returns the number of bytes of a string"] - pub fn octet_length(args: Expr) -> Expr { - super::octet_length().call(vec![args]) - } - - #[doc = "replace the substring of string that starts at the start'th character and extends for count characters with new substring"] - pub fn overlay(args: Vec) -> Expr { - super::overlay().call(args) - } - - #[doc = "Repeats the `string` to `n` times"] - pub fn repeat(string: Expr, n: Expr) -> Expr { - super::repeat().call(vec![string, n]) - } - - #[doc = "Replaces all occurrences of `from` with `to` in the `string`"] - pub fn replace(string: Expr, from: Expr, to: Expr) -> Expr { - super::replace().call(vec![string, from, to]) - } - - #[doc = "Removes all characters, spaces by default, from the end of a string"] - pub fn rtrim(args: Vec) -> Expr { - super::rtrim().call(args) - } - - #[doc = "Splits a string based on a delimiter and picks out the desired field based on the index."] - pub fn split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr { - super::split_part().call(vec![string, delimiter, index]) - } - - #[doc = "Returns true if string starts with prefix."] - pub fn starts_with(arg1: Expr, arg2: Expr) -> Expr { - super::starts_with().call(vec![arg1, arg2]) - } - - #[doc = "Converts an integer to a hexadecimal string."] - pub fn to_hex(arg1: Expr) -> Expr { - super::to_hex().call(vec![arg1]) - } - - #[doc = "Removes all characters, spaces by default, from both sides of a string"] - pub fn trim(args: Vec) -> Expr { - super::btrim().call(args) - } - - #[doc = "Converts a string to uppercase."] - pub fn upper(arg1: Expr) -> Expr { - super::upper().call(vec![arg1]) - } - - #[doc = "returns uuid v4 as a string value"] - pub fn uuid() -> Expr { - super::uuid().call(vec![]) - } } /// Return a list of all functions in this package diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 5a8e953bc161..9e8c07cd36ed 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -47,27 +47,68 @@ make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); make_udf_function!(rpad::RPadFunc, RPAD, rpad); make_udf_function!(strpos::StrposFunc, STRPOS, strpos); make_udf_function!(substr::SubstrFunc, SUBSTR, substr); +make_udf_function!(substr::SubstrFunc, SUBSTRING, substring); make_udf_function!(substrindex::SubstrIndexFunc, SUBSTR_INDEX, substr_index); make_udf_function!(translate::TranslateFunc, TRANSLATE, translate); pub mod expr_fn { use 
datafusion_expr::Expr; + export_functions!(( + character_length, + "the number of characters in the `string`", + string + ),( + lpad, + "fill up a string to the length by prepending the characters", + args, + ),( + rpad, + "fill up a string to the length by appending the characters", + args, + ),( + reverse, + "reverses the `string`", + string + ),( + substr, + "substring from the `position` to the end", + string position + ),( + substr_index, + "Returns the substring from str before count occurrences of the delimiter", + string delimiter count + ),( + strpos, + "finds the position from where the `substring` matches the `string`", + string substring + ),( + substring, + "substring from the `position` with `length` characters", + string position length + ),( + translate, + "replaces the characters in `from` with the counterpart in `to`", + string from to + ),( + right, + "returns the last `n` characters in the `string`", + string n + ),( + left, + "returns the first `n` characters in the `string`", + string n + ),( + find_in_set, + "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings", + string strlist + )); + #[doc = "the number of characters in the `string`"] pub fn char_length(string: Expr) -> Expr { character_length(string) } - #[doc = "the number of characters in the `string`"] - pub fn character_length(string: Expr) -> Expr { - super::character_length().call(vec![string]) - } - - #[doc = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"] - pub fn find_in_set(string: Expr, strlist: Expr) -> Expr { - super::find_in_set().call(vec![string, strlist]) - } - #[doc = "finds the position from where the `substring` matches the `string`"] pub fn instr(string: Expr, substring: Expr) -> Expr { strpos(string, substring) @@ -78,60 +119,10 @@ pub mod expr_fn { character_length(string) } - #[doc = "returns the first `n` characters in the `string`"] - pub fn left(string: Expr, n: Expr) -> Expr { - super::left().call(vec![string, n]) - } - - #[doc = "fill up a string to the length by prepending the characters"] - pub fn lpad(args: Vec) -> Expr { - super::lpad().call(args) - } - #[doc = "finds the position from where the `substring` matches the `string`"] pub fn position(string: Expr, substring: Expr) -> Expr { strpos(string, substring) } - - #[doc = "reverses the `string`"] - pub fn reverse(string: Expr) -> Expr { - super::reverse().call(vec![string]) - } - - #[doc = "returns the last `n` characters in the `string`"] - pub fn right(string: Expr, n: Expr) -> Expr { - super::right().call(vec![string, n]) - } - - #[doc = "fill up a string to the length by appending the characters"] - pub fn rpad(args: Vec) -> Expr { - super::rpad().call(args) - } - - #[doc = "finds the position from where the `substring` matches the `string`"] - pub fn strpos(string: Expr, substring: Expr) -> Expr { - super::strpos().call(vec![string, substring]) - } - - #[doc = "substring from the `position` to the end"] - pub fn substr(string: Expr, position: Expr) -> Expr { - super::substr().call(vec![string, position]) - } - - #[doc = "substring from the `position` with `length` characters"] - pub fn substring(string: Expr, position: Expr, length: Expr) -> Expr { - super::substr().call(vec![string, position, length]) - } - - #[doc = "Returns the substring from str before count occurrences of the delimiter"] - pub fn substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr { - 
super::substr_index().call(vec![string, delimiter, count]) - } - - #[doc = "replaces the characters in `from` with the counterpart in `to`"] - pub fn translate(string: Expr, from: Expr, to: Expr) -> Expr { - super::translate().call(vec![string, from, to]) - } } /// Return a list of all functions in this package From ccf395f40c9cf4e59c3c8fbd828a164bfebc3024 Mon Sep 17 00:00:00 2001 From: hsiang-c <137842490+hsiang-c@users.noreply.github.com> Date: Tue, 4 Jun 2024 02:46:45 +0800 Subject: [PATCH 30/35] (Doc) Enable rt-multi-thread feature for sample code (#10770) --- docs/source/user-guide/example-usage.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index ae45c98d7483..71a614313e8a 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -30,7 +30,7 @@ crates.io] page. Add the dependency to your `Cargo.toml` file: ```toml datafusion = "latest_version" -tokio = "1.0" +tokio = { version = "1.0", features = ["rt-multi-thread"] } ``` ## Add latest non published DataFusion dependency From 180f3e8af12d6da8e814a0ee0e5718b48d7d8aee Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Mon, 3 Jun 2024 11:50:36 -0700 Subject: [PATCH 31/35] Support negatives in split part (#10780) * impv: support negative indexes for split_part * tests: update unittests in func * tests: add out of bounds negative test * style: fix clippy --- datafusion/functions/src/string/split_part.rs | 37 +++++++++++++++---- datafusion/sqllogictest/test_files/expr.slt | 13 +++++++ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 517fa93e5284..d6f7bb4a4d4a 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -97,14 +97,21 @@ fn split_part(args: &[ArrayRef]) -> Result { .zip(n_array.iter()) .map(|((string, delimiter), n)| match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { - if n <= 0 { - exec_err!("field position must be greater than zero") - } else { - let split_string: Vec<&str> = string.split(delimiter).collect(); - match split_string.get(n as usize - 1) { - Some(s) => Ok(Some(*s)), - None => Ok(Some("")), + let split_string: Vec<&str> = string.split(delimiter).collect(); + let len = split_string.len(); + + let index = match n.cmp(&0) { + std::cmp::Ordering::Less => len as i64 + n, + std::cmp::Ordering::Equal => { + return exec_err!("field position must not be zero"); } + std::cmp::Ordering::Greater => n - 1, + } as usize; + + if index < len { + Ok(Some(split_string[index])) + } else { + Ok(Some("")) } } _ => Ok(None), @@ -165,7 +172,21 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), ], - exec_err!("field position must be greater than zero"), + Ok(Some("ghi")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + exec_err!("field position must not be zero"), &str, Utf8, StringArray diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index b6477f0b57d0..cb2bb9fad1b7 100644 --- 
a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -626,6 +626,19 @@ SELECT split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT)) ---- NULL +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', -1) +---- +ghi + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', -100) +---- +(empty) + +statement error DataFusion error: Execution error: field position must not be zero +SELECT split_part('abc~@~def~@~ghi', '~@~', 0) + query B SELECT starts_with('alphabet', 'alph') ---- From e4f7b9811f245f0ccf8d0289f7d5edfe1499947a Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Mon, 3 Jun 2024 15:00:32 -0400 Subject: [PATCH 32/35] feat: support unparsing LogicalPlan::Window nodes (#10767) * unparse window plans * new tests + fixes * fmt --- datafusion/sql/src/unparser/expr.rs | 32 ++++++---- datafusion/sql/src/unparser/plan.rs | 71 +++++++++++++++-------- datafusion/sql/src/unparser/utils.rs | 68 ++++++++++++++++++---- datafusion/sql/tests/cases/plan_to_sql.rs | 8 ++- 4 files changed, 132 insertions(+), 47 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index df390ce6eaf8..1ba6638e73d7 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -236,8 +236,8 @@ impl Unparser<'_> { .map(|expr| expr_to_unparsed(expr)?.into_order_by_expr()) .collect::>>()?; - let start_bound = self.convert_bound(&window_frame.start_bound); - let end_bound = self.convert_bound(&window_frame.end_bound); + let start_bound = self.convert_bound(&window_frame.start_bound)?; + let end_bound = self.convert_bound(&window_frame.end_bound)?; let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec { window_name: None, partition_by: partition_by @@ -513,20 +513,30 @@ impl Unparser<'_> { fn convert_bound( &self, bound: &datafusion_expr::window_frame::WindowFrameBound, - ) -> ast::WindowFrameBound { + ) -> Result { match bound { datafusion_expr::window_frame::WindowFrameBound::Preceding(val) => { - ast::WindowFrameBound::Preceding( - self.scalar_to_sql(val).map(Box::new).ok(), - ) + Ok(ast::WindowFrameBound::Preceding({ + let val = self.scalar_to_sql(val)?; + if let ast::Expr::Value(ast::Value::Null) = &val { + None + } else { + Some(Box::new(val)) + } + })) } datafusion_expr::window_frame::WindowFrameBound::Following(val) => { - ast::WindowFrameBound::Following( - self.scalar_to_sql(val).map(Box::new).ok(), - ) + Ok(ast::WindowFrameBound::Following({ + let val = self.scalar_to_sql(val)?; + if let ast::Expr::Value(ast::Value::Null) = &val { + None + } else { + Some(Box::new(val)) + } + })) } datafusion_expr::window_frame::WindowFrameBound::CurrentRow => { - ast::WindowFrameBound::CurrentRow + Ok(ast::WindowFrameBound::CurrentRow) } } } @@ -1148,7 +1158,7 @@ mod tests { window_frame: WindowFrame::new(None), null_treatment: None, }), - r#"ROW_NUMBER(col) OVER (ROWS BETWEEN NULL PRECEDING AND NULL FOLLOWING)"#, + r#"ROW_NUMBER(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, ), ( Expr::WindowFunction(WindowFunction { diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index e7e4d7700ac0..183bb1f7fb49 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -28,7 +28,7 @@ use super::{ BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }, - utils::find_agg_node_within_select, + utils::{find_agg_node_within_select, 
unproject_window_exprs, AggVariant}, Unparser, }; @@ -162,23 +162,42 @@ impl Unparser<'_> { // A second projection implies a derived tablefactor if !select.already_projected() { // Special handling when projecting an agregation plan - if let Some(agg) = find_agg_node_within_select(plan, true) { - let items = p - .expr - .iter() - .map(|proj_expr| { - let unproj = unproject_agg_exprs(proj_expr, agg)?; - self.select_item_to_sql(&unproj) - }) - .collect::>>()?; - - select.projection(items); - select.group_by(ast::GroupByExpr::Expressions( - agg.group_expr - .iter() - .map(|expr| self.expr_to_sql(expr)) - .collect::>>()?, - )); + if let Some(aggvariant) = + find_agg_node_within_select(plan, None, true) + { + match aggvariant { + AggVariant::Aggregate(agg) => { + let items = p + .expr + .iter() + .map(|proj_expr| { + let unproj = unproject_agg_exprs(proj_expr, agg)?; + self.select_item_to_sql(&unproj) + }) + .collect::>>()?; + + select.projection(items); + select.group_by(ast::GroupByExpr::Expressions( + agg.group_expr + .iter() + .map(|expr| self.expr_to_sql(expr)) + .collect::>>()?, + )); + } + AggVariant::Window(window) => { + let items = p + .expr + .iter() + .map(|proj_expr| { + let unproj = + unproject_window_exprs(proj_expr, &window)?; + self.select_item_to_sql(&unproj) + }) + .collect::>>()?; + + select.projection(items); + } + } } else { let items = p .expr @@ -210,8 +229,8 @@ impl Unparser<'_> { } } LogicalPlan::Filter(filter) => { - if let Some(agg) = - find_agg_node_within_select(plan, select.already_projected()) + if let Some(AggVariant::Aggregate(agg)) = + find_agg_node_within_select(plan, None, select.already_projected()) { let unprojected = unproject_agg_exprs(&filter.predicate, agg)?; let filter_expr = self.expr_to_sql(&unprojected)?; @@ -265,7 +284,7 @@ impl Unparser<'_> { ) } LogicalPlan::Aggregate(agg) => { - // Aggregate nodes are handled simulatenously with Projection nodes + // Aggregate nodes are handled simultaneously with Projection nodes self.select_to_sql_recursively( agg.input.as_ref(), query, @@ -441,8 +460,14 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::Window(_window) => { - not_impl_err!("Unsupported operator: {plan:?}") + LogicalPlan::Window(window) => { + // Window nodes are handled simultaneously with Projection nodes + self.select_to_sql_recursively( + window.input.as_ref(), + query, + select, + relation, + ) } LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), _ => not_impl_err!("Unsupported operator: {plan:?}"), diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index c1b02c330fae..326cd15ba140 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -20,16 +20,24 @@ use datafusion_common::{ tree_node::{Transformed, TreeNode}, Result, }; -use datafusion_expr::{Aggregate, Expr, LogicalPlan}; +use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window}; -/// Recursively searches children of [LogicalPlan] to find an Aggregate node if one exists +/// One of the possible aggregation plans which can be found within a single select query. +pub(crate) enum AggVariant<'a> { + Aggregate(&'a Aggregate), + Window(Vec<&'a Window>), +} + +/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). -/// If an Aggregate node is not found prior to this or at all before reaching the end -/// of the tree, None is returned. 
-pub(crate) fn find_agg_node_within_select( - plan: &LogicalPlan, +/// If an Aggregate or window node is not found prior to this or at all before reaching the end +/// of the tree, None is returned. It is assumed that a Window and Aggegate node cannot both +/// be found in a single select query. +pub(crate) fn find_agg_node_within_select<'a>( + plan: &'a LogicalPlan, + mut prev_windows: Option>, already_projected: bool, -) -> Option<&Aggregate> { +) -> Option> { // Note that none of the nodes that have a corresponding agg node can have more // than 1 input node. E.g. Projection / Filter always have 1 input node. let input = plan.inputs(); @@ -38,18 +46,29 @@ pub(crate) fn find_agg_node_within_select( } else { input.first()? }; + // Agg nodes explicitly return immediately with a single node + // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection if let LogicalPlan::Aggregate(agg) = input { - Some(agg) + Some(AggVariant::Aggregate(agg)) + } else if let LogicalPlan::Window(window) = input { + prev_windows = match &mut prev_windows { + Some(AggVariant::Window(windows)) => { + windows.push(window); + prev_windows + } + _ => Some(AggVariant::Window(vec![window])), + }; + find_agg_node_within_select(input, prev_windows, already_projected) } else if let LogicalPlan::TableScan(_) = input { - None + prev_windows } else if let LogicalPlan::Projection(_) = input { if already_projected { - None + prev_windows } else { - find_agg_node_within_select(input, true) + find_agg_node_within_select(input, prev_windows, true) } } else { - find_agg_node_within_select(input, already_projected) + find_agg_node_within_select(input, prev_windows, already_projected) } } @@ -82,3 +101,28 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result }) .map(|e| e.data) } + +/// Recursively identify all Column expressions and transform them into the appropriate +/// window expression contained in window. +/// +/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed +/// into an actual window expression as identified in the window node. 
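Before the helper itself, here is a brief end-to-end sketch of what this window unprojection enables. It is illustrative only and not part of the patch: it assumes the existing `plan_to_sql` entry point in `datafusion_sql::unparser`, a `SessionContext` from the core crate, and made-up table/column names.

```
use datafusion::prelude::*;
use datafusion_sql::unparser::plan_to_sql;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // A small in-memory table; VALUES names its columns column1/column2
    ctx.sql("CREATE TABLE person AS VALUES (1, 'a'), (2, 'b')").await?;

    // A query whose logical plan contains a Window node above the scan
    let df = ctx
        .sql("SELECT column1, count(*) OVER (PARTITION BY column2) FROM person")
        .await?;

    // With this change, plans containing LogicalPlan::Window can be turned
    // back into SQL instead of returning "Unsupported operator"
    let sql = plan_to_sql(df.logical_plan())?;
    println!("{sql}");
    Ok(())
}
```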
+pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result { + expr.clone() + .transform(|sub_expr| { + if let Expr::Column(c) = sub_expr { + if let Some(unproj) = windows + .iter() + .flat_map(|w| w.window_expr.iter()) + .find(|window_expr| window_expr.display_name().unwrap() == c.name) + { + Ok(Transformed::yes(unproj.clone())) + } else { + Ok(Transformed::no(Expr::Column(c))) + } + } else { + Ok(Transformed::no(sub_expr)) + } + }) + .map(|e| e.data) +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 1bf441351a97..4a430bdc8003 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -127,7 +127,13 @@ fn roundtrip_statement() -> Result<()> { UNION ALL SELECT j2_string as string FROM j2 ORDER BY string DESC - LIMIT 10"# + LIMIT 10"#, + "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + first_name from person", + r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#, + "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", ]; // For each test sql string, we transform as follows: From d3fa083acfc558b9fff5c0bb539d6dc972bfbf7f Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 4 Jun 2024 03:55:06 +0800 Subject: [PATCH 33/35] refactor: handle LargeUtf8 statistics and add tests for UTF8 and LargeUTF8 (#10762) Co-authored-by: Andrew Lamb --- .../physical_plan/parquet/statistics.rs | 16 +++++---- .../core/tests/parquet/arrow_statistics.rs | 34 ++++++++++++++++++- datafusion/core/tests/parquet/mod.rs | 28 ++++++++++++++- 3 files changed, 69 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 1c20fa7caa14..e7e6360c2500 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -152,16 +152,18 @@ macro_rules! get_statistic { Some(DataType::Binary) => { Some(ScalarValue::Binary(Some(s.$bytes_func().to_vec()))) } - _ => { - let s = std::str::from_utf8(s.$bytes_func()) + Some(DataType::LargeUtf8) | _ => { + let utf8_value = std::str::from_utf8(s.$bytes_func()) .map(|s| s.to_string()) .ok(); - if s.is_none() { - log::debug!( - "Utf8 statistics is a non-UTF8 value, ignoring it." 
- ); + if utf8_value.is_none() { + log::debug!("Utf8 statistics is a non-UTF8 value, ignoring it."); + } + + match $target_arrow_type { + Some(DataType::LargeUtf8) => Some(ScalarValue::LargeUtf8(utf8_value)), + _ => Some(ScalarValue::Utf8(utf8_value)), } - Some(ScalarValue::Utf8(s)) } } } diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 5e0f8b4f5f18..aa5fc7c34c48 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -29,7 +29,7 @@ use arrow::datatypes::{ use arrow_array::{ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, RecordBatch, StringArray, + Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; @@ -1447,6 +1447,38 @@ async fn test_struct() { } .run(); } + +// UTF8 +#[tokio::test] +async fn test_utf8() { + let reader = TestReader { + scenario: Scenario::UTF8, + row_per_group: 5, + }; + + // test for utf8 + Test { + reader: reader.build().await, + expected_min: Arc::new(StringArray::from(vec!["a", "e"])), + expected_max: Arc::new(StringArray::from(vec!["d", "i"])), + expected_null_counts: UInt64Array::from(vec![1, 0]), + expected_row_counts: UInt64Array::from(vec![5, 5]), + column_name: "utf8", + } + .run(); + + // test for large_utf8 + Test { + reader: reader.build().await, + expected_min: Arc::new(LargeStringArray::from(vec!["a", "e"])), + expected_max: Arc::new(LargeStringArray::from(vec!["d", "i"])), + expected_null_counts: UInt64Array::from(vec![1, 0]), + expected_row_counts: UInt64Array::from(vec![5, 5]), + column_name: "large_utf8", + } + .run(); +} + ////// Files with missing statistics /////// #[tokio::test] diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index f45ff53d3fb8..bfb6e8e555c9 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -29,7 +29,10 @@ use arrow::{ util::pretty::pretty_format_batches, }; use arrow_array::types::{Int32Type, Int8Type}; -use arrow_array::{make_array, BooleanArray, DictionaryArray, Float32Array, StructArray}; +use arrow_array::{ + make_array, BooleanArray, DictionaryArray, Float32Array, LargeStringArray, + StructArray, +}; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ datasource::{physical_plan::ParquetExec, provider_as_source, TableProvider}, @@ -90,6 +93,7 @@ enum Scenario { WithNullValues, WithNullValuesPageLevel, StructArray, + UTF8, } enum Unit { @@ -787,6 +791,16 @@ fn make_numeric_limit_batch() -> RecordBatch { .unwrap() } +fn make_utf8_batch(value: Vec>) -> RecordBatch { + let utf8 = StringArray::from(value.clone()); + let large_utf8 = LargeStringArray::from(value); + RecordBatch::try_from_iter(vec![ + ("utf8", Arc::new(utf8) as _), + ("large_utf8", Arc::new(large_utf8) as _), + ]) + .unwrap() +} + fn make_dict_batch() -> RecordBatch { let values = [ Some("abc"), @@ -1044,6 +1058,18 @@ fn create_data_batch(scenario: Scenario) -> Vec { )])); vec![RecordBatch::try_new(schema, vec![struct_array_data]).unwrap()] } + Scenario::UTF8 => { + vec![ + make_utf8_batch(vec![Some("a"), Some("b"), Some("c"), Some("d"), None]), + make_utf8_batch(vec![ + Some("e"), + Some("f"), + Some("g"), 
+ Some("h"), + Some("i"), + ]), + ] + } } } From 826331e01c95ba843ec1bcd7f077c07a3fa05056 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Tue, 4 Jun 2024 04:11:27 +0800 Subject: [PATCH 34/35] Cleanup GetIndexedField (#10769) * Cleanup GetIndexedField * Generate pb --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/expr.rs | 22 ---- datafusion/expr/src/lib.rs | 4 +- datafusion/proto/proto/datafusion.proto | 9 -- datafusion/proto/src/generated/pbjson.rs | 142 ----------------------- datafusion/proto/src/generated/prost.rs | 21 ---- 5 files changed, 2 insertions(+), 196 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 14c64ef8f89d..1abd8c97ee10 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -541,28 +541,6 @@ pub enum GetFieldAccess { }, } -/// Returns the field of a [`ListArray`] or -/// [`StructArray`] by `key`. -/// -/// See [`GetFieldAccess`] for details. -/// -/// [`ListArray`]: arrow::array::ListArray -/// [`StructArray`]: arrow::array::StructArray -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct GetIndexedField { - /// The expression to take the field from - pub expr: Box, - /// The name of the field to take - pub field: GetFieldAccess, -} - -impl GetIndexedField { - /// Create a new GetIndexedField expression - pub fn new(expr: Box, field: GetFieldAccess) -> Self { - Self { expr, field } - } -} - /// Cast expression #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Cast { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index bbd1d6f654f1..8c9893b8a748 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -63,8 +63,8 @@ pub use aggregate_function::AggregateFunction; pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ - Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, Sort as SortExpr, TryCast, WindowFunctionDefinition, + Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GroupingSet, Like, + Sort as SortExpr, TryCast, WindowFunctionDefinition, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 0408ea91b9fa..fa95194696dd 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -407,15 +407,6 @@ message ListRange { LogicalExprNode stride = 3; } -message GetIndexedField { - LogicalExprNode expr = 1; - oneof field { - NamedStructField named_struct_field = 2; - ListIndex list_index = 3; - ListRange list_range = 4; - } -} - message IsNull { LogicalExprNode expr = 1; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e07fbba27d3c..b0e77eb69eff 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -6159,148 +6159,6 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { deserializer.deserialize_struct("datafusion.FullTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for GetIndexedField { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - if self.field.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.GetIndexedField", len)?; - if let 
Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.field.as_ref() { - match v { - get_indexed_field::Field::NamedStructField(v) => { - struct_ser.serialize_field("namedStructField", v)?; - } - get_indexed_field::Field::ListIndex(v) => { - struct_ser.serialize_field("listIndex", v)?; - } - get_indexed_field::Field::ListRange(v) => { - struct_ser.serialize_field("listRange", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for GetIndexedField { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - "named_struct_field", - "namedStructField", - "list_index", - "listIndex", - "list_range", - "listRange", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - NamedStructField, - ListIndex, - ListRange, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - "namedStructField" | "named_struct_field" => Ok(GeneratedField::NamedStructField), - "listIndex" | "list_index" => Ok(GeneratedField::ListIndex), - "listRange" | "list_range" => Ok(GeneratedField::ListRange), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GetIndexedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.GetIndexedField") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - let mut field__ = None; - while let Some(k) = map_.next_key()? 
{ - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::NamedStructField => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("namedStructField")); - } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::NamedStructField) -; - } - GeneratedField::ListIndex => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listIndex")); - } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListIndex) -; - } - GeneratedField::ListRange => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listRange")); - } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListRange) -; - } - } - } - Ok(GetIndexedField { - expr: expr__, - field: field__, - }) - } - } - deserializer.deserialize_struct("datafusion.GetIndexedField", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for GlobalLimitExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c75cb3615832..6d8a0c305761 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -651,27 +651,6 @@ pub struct ListRange { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct GetIndexedField { - #[prost(message, optional, tag = "1")] - pub expr: ::core::option::Option, - #[prost(oneof = "get_indexed_field::Field", tags = "2, 3, 4")] - pub field: ::core::option::Option, -} -/// Nested message and enum types in `GetIndexedField`. -pub mod get_indexed_field { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Field { - #[prost(message, tag = "2")] - NamedStructField(super::NamedStructField), - #[prost(message, tag = "3")] - ListIndex(super::ListIndex), - #[prost(message, tag = "4")] - ListRange(super::ListRange), - } -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] pub struct IsNull { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, From fe536495de8116d41f62c550fd000df9c3d98aab Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:27:49 -0700 Subject: [PATCH 35/35] Extract parquet statistics from f16 columns, add `ScalarValue::Float16` (#10763) * Extract parquet statistics from f16 columns * Update datafusion/common/src/scalar/mod.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/scalar/mod.rs | 39 +++++++++++-- .../physical_plan/parquet/statistics.rs | 14 ++++- .../core/tests/parquet/arrow_statistics.rs | 43 +++++++++++++-- datafusion/core/tests/parquet/mod.rs | 55 +++++++++++++++---- datafusion/proto-common/src/to_proto/mod.rs | 5 ++ datafusion/sql/src/unparser/expr.rs | 4 ++ 6 files changed, 136 insertions(+), 24 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index d2c6513eef95..ba006247cd70 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -18,13 +18,13 @@ //! 
[`ScalarValue`]: stores single values mod struct_builder; - use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::{HashSet, VecDeque}; use std::convert::Infallible; use std::fmt; use std::hash::Hash; +use std::hash::Hasher; use std::iter::repeat; use std::str::FromStr; use std::sync::Arc; @@ -55,6 +55,7 @@ use arrow::{ use arrow_buffer::Buffer; use arrow_schema::{UnionFields, UnionMode}; +use half::f16; pub use struct_builder::ScalarStructBuilder; /// A dynamically typed, nullable single value. @@ -192,6 +193,8 @@ pub enum ScalarValue { Null, /// true or false value Boolean(Option), + /// 16bit float + Float16(Option), /// 32bit float Float32(Option), /// 64bit float @@ -285,6 +288,12 @@ pub enum ScalarValue { Dictionary(Box, Box), } +impl Hash for Fl { + fn hash(&self, state: &mut H) { + self.0.to_bits().hash(state); + } +} + // manual implementation of `PartialEq` impl PartialEq for ScalarValue { fn eq(&self, other: &Self) -> bool { @@ -307,7 +316,12 @@ impl PartialEq for ScalarValue { (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(), _ => v1.eq(v2), }, + (Float16(v1), Float16(v2)) => match (v1, v2) { + (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(), + _ => v1.eq(v2), + }, (Float32(_), _) => false, + (Float16(_), _) => false, (Float64(v1), Float64(v2)) => match (v1, v2) { (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(), _ => v1.eq(v2), @@ -425,7 +439,12 @@ impl PartialOrd for ScalarValue { (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)), _ => v1.partial_cmp(v2), }, + (Float16(v1), Float16(v2)) => match (v1, v2) { + (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)), + _ => v1.partial_cmp(v2), + }, (Float32(_), _) => None, + (Float16(_), _) => None, (Float64(v1), Float64(v2)) => match (v1, v2) { (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)), _ => v1.partial_cmp(v2), @@ -637,6 +656,7 @@ impl std::hash::Hash for ScalarValue { s.hash(state) } Boolean(v) => v.hash(state), + Float16(v) => v.map(Fl).hash(state), Float32(v) => v.map(Fl).hash(state), Float64(v) => v.map(Fl).hash(state), Int8(v) => v.hash(state), @@ -1082,6 +1102,7 @@ impl ScalarValue { ScalarValue::TimestampNanosecond(_, tz_opt) => { DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) } + ScalarValue::Float16(_) => DataType::Float16, ScalarValue::Float32(_) => DataType::Float32, ScalarValue::Float64(_) => DataType::Float64, ScalarValue::Utf8(_) => DataType::Utf8, @@ -1276,6 +1297,7 @@ impl ScalarValue { match self { ScalarValue::Boolean(v) => v.is_none(), ScalarValue::Null => true, + ScalarValue::Float16(v) => v.is_none(), ScalarValue::Float32(v) => v.is_none(), ScalarValue::Float64(v) => v.is_none(), ScalarValue::Decimal128(v, _, _) => v.is_none(), @@ -1522,6 +1544,7 @@ impl ScalarValue { } DataType::Null => ScalarValue::iter_to_null_array(scalars)?, DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), + DataType::Float16 => build_array_primitive!(Float16Array, Float16), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), DataType::Int8 => build_array_primitive!(Int8Array, Int8), @@ -1682,8 +1705,7 @@ impl ScalarValue { // not supported if the TimeUnit is not valid (Time32 can // only be used with Second and Millisecond, Time64 only // with Microsecond and Nanosecond) - DataType::Float16 - | DataType::Time32(TimeUnit::Microsecond) + DataType::Time32(TimeUnit::Microsecond) | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) @@ 
-1700,7 +1722,6 @@ impl ScalarValue { ); } }; - Ok(array) } @@ -1921,6 +1942,9 @@ impl ScalarValue { ScalarValue::Float32(e) => { build_array_from_option!(Float32, Float32Array, e, size) } + ScalarValue::Float16(e) => { + build_array_from_option!(Float16, Float16Array, e, size) + } ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size), ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size), ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size), @@ -2595,6 +2619,9 @@ impl ScalarValue { ScalarValue::Boolean(val) => { eq_array_primitive!(array, index, BooleanArray, val)? } + ScalarValue::Float16(val) => { + eq_array_primitive!(array, index, Float16Array, val)? + } ScalarValue::Float32(val) => { eq_array_primitive!(array, index, Float32Array, val)? } @@ -2738,6 +2765,7 @@ impl ScalarValue { + match self { ScalarValue::Null | ScalarValue::Boolean(_) + | ScalarValue::Float16(_) | ScalarValue::Float32(_) | ScalarValue::Float64(_) | ScalarValue::Decimal128(_, _, _) @@ -3022,6 +3050,7 @@ impl TryFrom<&DataType> for ScalarValue { fn try_from(data_type: &DataType) -> Result { Ok(match data_type { DataType::Boolean => ScalarValue::Boolean(None), + DataType::Float16 => ScalarValue::Float16(None), DataType::Float64 => ScalarValue::Float64(None), DataType::Float32 => ScalarValue::Float32(None), DataType::Int8 => ScalarValue::Int8(None), @@ -3147,6 +3176,7 @@ impl fmt::Display for ScalarValue { write!(f, "{v:?},{p:?},{s:?}")?; } ScalarValue::Boolean(e) => format_option!(f, e)?, + ScalarValue::Float16(e) => format_option!(f, e)?, ScalarValue::Float32(e) => format_option!(f, e)?, ScalarValue::Float64(e) => format_option!(f, e)?, ScalarValue::Int8(e) => format_option!(f, e)?, @@ -3260,6 +3290,7 @@ impl fmt::Debug for ScalarValue { ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"), ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"), ScalarValue::Boolean(_) => write!(f, "Boolean({self})"), + ScalarValue::Float16(_) => write!(f, "Float16({self})"), ScalarValue::Float32(_) => write!(f, "Float32({self})"), ScalarValue::Float64(_) => write!(f, "Float64({self})"), ScalarValue::Int8(_) => write!(f, "Int8({self})"), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index e7e6360c2500..6c738cfe03a9 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -25,11 +25,11 @@ use arrow_schema::{Field, FieldRef, Schema}; use datafusion_common::{ internal_datafusion_err, internal_err, plan_err, Result, ScalarValue, }; +use half::f16; use parquet::file::metadata::ParquetMetaData; use parquet::file::statistics::Statistics as ParquetStatistics; use parquet::schema::types::SchemaDescriptor; use std::sync::Arc; - // Convert the bytes array to i128. // The endian of the input bytes array must be big-endian. 
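A quick, self-contained sanity check (not part of the patch) of the byte order handled by the new `from_bytes_to_f16` helper added just below: as I understand the Parquet FLOAT16 logical type, min/max statistics arrive as a 2-byte little-endian fixed-length byte array, so reversing the pair before `from_be_bytes` is equivalent to decoding it as little-endian.

```
use half::f16;

fn main() {
    let value = f16::from_f32(3.5);
    // Parquet FLOAT16 statistics arrive as two little-endian bytes
    let bytes = value.to_le_bytes(); // [low, high]
    // The helper swaps the pair and decodes it as big-endian ...
    let via_be = f16::from_be_bytes([bytes[1], bytes[0]]);
    // ... which matches decoding the original pair as little-endian
    assert_eq!(via_be, f16::from_le_bytes(bytes));
    assert_eq!(via_be, value);
}
```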
pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { @@ -39,6 +39,14 @@ pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { i128::from_be_bytes(sign_extend_be(b)) } +// Convert the bytes array to f16 +pub(crate) fn from_bytes_to_f16(b: &[u8]) -> Option { + match b { + [low, high] => Some(f16::from_be_bytes([*high, *low])), + _ => None, + } +} + // Copy from arrow-rs // https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 // Convert the byte slice to fixed length byte array with the length of 16 @@ -196,6 +204,9 @@ macro_rules! get_statistic { value, )) } + Some(DataType::Float16) => { + Some(ScalarValue::Float16(from_bytes_to_f16(s.$bytes_func()))) + } _ => None, } } @@ -344,7 +355,6 @@ impl<'a> StatisticsConverter<'a> { column_name ); }; - Ok(Self { column_name, statistics_type, diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index aa5fc7c34c48..c2bf75c8f089 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -21,6 +21,7 @@ use std::fs::File; use std::sync::Arc; +use crate::parquet::{struct_array, Scenario}; use arrow::compute::kernels::cast_utils::Parser; use arrow::datatypes::{ Date32Type, Date64Type, TimestampMicrosecondType, TimestampMillisecondType, @@ -28,21 +29,21 @@ use arrow::datatypes::{ }; use arrow_array::{ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal128Array, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, + StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }; use arrow_schema::{DataType, Field, Schema}; use datafusion::datasource::physical_plan::parquet::{ RequestedStatistics, StatisticsConverter, }; +use half::f16; use parquet::arrow::arrow_reader::{ArrowReaderBuilder, ParquetRecordBatchReaderBuilder}; use parquet::arrow::ArrowWriter; use parquet::file::properties::{EnabledStatistics, WriterProperties}; -use crate::parquet::{struct_array, Scenario}; - use super::make_test_file_rg; // TEST HELPERS @@ -1203,6 +1204,36 @@ async fn test_float64() { .run(); } +#[tokio::test] +async fn test_float16() { + // This creates a parquet file of 1 column "f" + // file has 4 record batches, each has 5 rows. 
They will be saved into 4 row groups + let reader = TestReader { + scenario: Scenario::Float16, + row_per_group: 5, + }; + + Test { + reader: reader.build().await, + expected_min: Arc::new(Float16Array::from( + vec![-5.0, -4.0, -0.0, 5.0] + .into_iter() + .map(f16::from_f32) + .collect::>(), + )), + expected_max: Arc::new(Float16Array::from( + vec![-1.0, 0.0, 4.0, 9.0] + .into_iter() + .map(f16::from_f32) + .collect::>(), + )), + expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), + expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), + column_name: "f", + } + .run(); +} + #[tokio::test] async fn test_decimal() { // This creates a parquet file of 1 column "decimal_col" with decimal data type and precicion 9, scale 2 diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index bfb6e8e555c9..e951644f2cbf 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -19,20 +19,17 @@ use arrow::array::Decimal128Array; use arrow::{ array::{ - Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, + DictionaryArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray, + StructArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }, - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Field, Int32Type, Int8Type, Schema}, record_batch::RecordBatch, util::pretty::pretty_format_batches, }; -use arrow_array::types::{Int32Type, Int8Type}; -use arrow_array::{ - make_array, BooleanArray, DictionaryArray, Float32Array, LargeStringArray, - StructArray, -}; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ datasource::{physical_plan::ParquetExec, provider_as_source, TableProvider}, @@ -40,11 +37,11 @@ use datafusion::{ prelude::{ParquetReadOptions, SessionConfig, SessionContext}, }; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use half::f16; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; use std::sync::Arc; use tempfile::NamedTempFile; - mod arrow_statistics; mod custom_reader; mod file_statistics; @@ -79,6 +76,7 @@ enum Scenario { /// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64 /// -MIN, -100, -1, 0, 1, 100, MAX NumericLimits, + Float16, Float64, Decimal, DecimalBloomFilterInt32, @@ -542,6 +540,12 @@ fn make_f64_batch(v: Vec) -> RecordBatch { RecordBatch::try_new(schema, vec![array.clone()]).unwrap() } +fn make_f16_batch(v: Vec) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float16, true)])); + let array = Arc::new(Float16Array::from(v)) as ArrayRef; + RecordBatch::try_new(schema, vec![array.clone()]).unwrap() +} + /// Return record batch with decimal vector /// /// Columns are named @@ -897,6 +901,34 @@ fn create_data_batch(scenario: Scenario) -> Vec { Scenario::NumericLimits => { vec![make_numeric_limit_batch()] } + Scenario::Float16 => { + vec![ + make_f16_batch( + vec![-5.0, -4.0, -3.0, -2.0, -1.0] + .into_iter() + .map(f16::from_f32) + .collect(), + ), + 
make_f16_batch( + vec![-4.0, -3.0, -2.0, -1.0, 0.0] + .into_iter() + .map(f16::from_f32) + .collect(), + ), + make_f16_batch( + vec![0.0, 1.0, 2.0, 3.0, 4.0] + .into_iter() + .map(f16::from_f32) + .collect(), + ), + make_f16_batch( + vec![5.0, 6.0, 7.0, 8.0, 9.0] + .into_iter() + .map(f16::from_f32) + .collect(), + ), + ] + } Scenario::Float64 => { vec![ make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]), @@ -1087,7 +1119,6 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem .build(); let batches = create_data_batch(scenario); - let schema = batches[0].schema(); let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index f160bc40af39..a92deaa88b1c 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -294,6 +294,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { ScalarValue::Boolean(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::BoolValue(*s)) } + ScalarValue::Float16(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Float32Value((*s).into()) + }) + } ScalarValue::Float32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float32Value(*s)) } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 1ba6638e73d7..3efbe2ace680 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -643,6 +643,10 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Boolean(b.to_owned()))) } ScalarValue::Boolean(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Float16(Some(f)) => { + Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false))) + } + ScalarValue::Float16(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Float32(Some(f)) => { Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false))) }