diff --git a/datafusion-examples/examples/csv_opener.rs b/datafusion-examples/examples/csv_opener.rs index f2982522a7cd..3b5a68d24628 100644 --- a/datafusion-examples/examples/csv_opener.rs +++ b/datafusion-examples/examples/csv_opener.rs @@ -45,6 +45,7 @@ async fn main() -> Result<()> { Some(vec![12, 0]), true, b',', + b'"', object_store, ); diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index e3b290cc54d7..7aec9698d92f 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -145,7 +145,7 @@ async fn main() -> Result<()> { // the name; used to represent it in plan descriptions and in the registry, to use in SQL. "geo_mean", // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. - DataType::Float64, + vec![DataType::Float64], // the return type; DataFusion expects this to match the type returned by `evaluate`. Arc::new(DataType::Float64), Volatility::Immutable, diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 04ae32ec35aa..4356f36b18d8 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -34,6 +34,7 @@ use arrow::{ }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; +use arrow_array::Decimal256Array; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { @@ -65,6 +66,11 @@ pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> { Ok(downcast_value!(array, Decimal128Array)) } +// Downcast ArrayRef to Decimal256Array +pub fn as_decimal256_array(array: &dyn Array) -> Result<&Decimal256Array> { + Ok(downcast_value!(array, Decimal256Array)) +} + // Downcast ArrayRef to Float32Array pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array> { Ok(downcast_value!(array, Float32Array)) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 99ff5f3384d4..4a7767023fed 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -26,14 +26,14 @@ use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; use crate::cast::{ - as_decimal128_array, as_dictionary_array, as_fixed_size_binary_array, - as_fixed_size_list_array, as_list_array, as_struct_array, + as_decimal128_array, as_decimal256_array, as_dictionary_array, + as_fixed_size_binary_array, as_fixed_size_list_array, as_list_array, as_struct_array, }; use crate::delta::shift_months; use crate::error::{DataFusionError, Result}; use arrow::buffer::NullBuffer; use arrow::compute::nullif; -use arrow::datatypes::{FieldRef, Fields, SchemaBuilder}; +use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, @@ -47,6 +47,7 @@ use arrow::{ }, }; use arrow_array::timezone::Tz; +use arrow_array::ArrowNativeTypeOp; use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; // Constants we use throughout this file: @@ -75,6 +76,8 @@ pub enum ScalarValue { Float64(Option), /// 128bit decimal, using the i128 to represent the decimal, precision scale Decimal128(Option, u8, i8), + /// 256bit decimal, using the i256 to represent the decimal, precision scale + Decimal256(Option, u8, i8), /// signed 8bit int Int8(Option), /// signed 16bit int @@ -160,6 +163,10 @@ impl PartialEq for ScalarValue { v1.eq(v2) && p1.eq(p2) && s1.eq(s2) } (Decimal128(_, _, _), _) => false, + (Decimal256(v1, p1, s1), Decimal256(v2, 
p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal256(_, _, _), _) => false, (Boolean(v1), Boolean(v2)) => v1.eq(v2), (Boolean(_), _) => false, (Float32(v1), Float32(v2)) => match (v1, v2) { @@ -283,6 +290,15 @@ impl PartialOrd for ScalarValue { } } (Decimal128(_, _, _), _) => None, + (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal256(_, _, _), _) => None, (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), (Boolean(_), _) => None, (Float32(v1), Float32(v2)) => match (v1, v2) { @@ -1038,6 +1054,7 @@ macro_rules! impl_op_arithmetic { get_sign!($OPERATION), true, )))), + // todo: Add Decimal256 support _ => Err(DataFusionError::Internal(format!( "Operator {} is not implemented for types {:?} and {:?}", stringify!($OPERATION), @@ -1516,6 +1533,11 @@ impl std::hash::Hash for ScalarValue { p.hash(state); s.hash(state) } + Decimal256(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } Boolean(v) => v.hash(state), Float32(v) => v.map(Fl).hash(state), Float64(v) => v.map(Fl).hash(state), @@ -1994,6 +2016,9 @@ impl ScalarValue { ScalarValue::Decimal128(_, precision, scale) => { DataType::Decimal128(*precision, *scale) } + ScalarValue::Decimal256(_, precision, scale) => { + DataType::Decimal256(*precision, *scale) + } ScalarValue::TimestampSecond(_, tz_opt) => { DataType::Timestamp(TimeUnit::Second, tz_opt.clone()) } @@ -2083,6 +2108,9 @@ impl ScalarValue { ScalarValue::Decimal128(Some(v), precision, scale) => { Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale)) } + ScalarValue::Decimal256(Some(v), precision, scale) => Ok( + ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale), + ), value => Err(DataFusionError::Internal(format!( "Can not run arithmetic negative on scalar value {value:?}" ))), @@ -2154,6 +2182,7 @@ impl ScalarValue { ScalarValue::Float32(v) => v.is_none(), ScalarValue::Float64(v) => v.is_none(), ScalarValue::Decimal128(v, _, _) => v.is_none(), + ScalarValue::Decimal256(v, _, _) => v.is_none(), ScalarValue::Int8(v) => v.is_none(), ScalarValue::Int16(v) => v.is_none(), ScalarValue::Int32(v) => v.is_none(), @@ -2415,10 +2444,10 @@ impl ScalarValue { ScalarValue::iter_to_decimal_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } - DataType::Decimal256(_, _) => { - return Err(DataFusionError::Internal( - "Decimal256 is not supported for ScalarValue".to_string(), - )); + DataType::Decimal256(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal256_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) } DataType::Null => ScalarValue::iter_to_null_array(scalars), DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), @@ -2680,6 +2709,22 @@ impl ScalarValue { Ok(array) } + fn iter_to_decimal256_array( + scalars: impl IntoIterator, + precision: u8, + scale: i8, + ) -> Result { + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal256(v1, _, _) => v1, + _ => unreachable!(), + }) + .collect::() + .with_precision_and_scale(precision, scale)?; + Ok(array) + } + fn iter_to_array_list( scalars: impl IntoIterator, data_type: &DataType, @@ -2764,12 +2809,28 @@ impl ScalarValue { } } + fn build_decimal256_array( + value: Option, + precision: u8, + scale: i8, + size: usize, + ) -> Decimal256Array { + std::iter::repeat(value) + .take(size) + .collect::() + 
.with_precision_and_scale(precision, scale) + .unwrap() + } + /// Converts a scalar value into an array of `size` rows. pub fn to_array_of_size(&self, size: usize) -> ArrayRef { match self { ScalarValue::Decimal128(e, precision, scale) => Arc::new( ScalarValue::build_decimal_array(*e, *precision, *scale, size), ), + ScalarValue::Decimal256(e, precision, scale) => Arc::new( + ScalarValue::build_decimal256_array(*e, *precision, *scale, size), + ), ScalarValue::Boolean(e) => { Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef } @@ -3044,12 +3105,28 @@ impl ScalarValue { precision: u8, scale: i8, ) -> Result { - let array = as_decimal128_array(array)?; - if array.is_null(index) { - Ok(ScalarValue::Decimal128(None, precision, scale)) - } else { - let value = array.value(index); - Ok(ScalarValue::Decimal128(Some(value), precision, scale)) + match array.data_type() { + DataType::Decimal128(_, _) => { + let array = as_decimal128_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal128(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal128(Some(value), precision, scale)) + } + } + DataType::Decimal256(_, _) => { + let array = as_decimal256_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal256(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal256(Some(value), precision, scale)) + } + } + _ => Err(DataFusionError::Internal( + "Unsupported decimal type".to_string(), + )), } } @@ -3067,6 +3144,11 @@ impl ScalarValue { array, index, *precision, *scale, )? } + DataType::Decimal256(precision, scale) => { + ScalarValue::get_decimal_value_from_array( + array, index, *precision, *scale, + )? + } DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), @@ -3265,6 +3347,25 @@ impl ScalarValue { } } + fn eq_array_decimal256( + array: &ArrayRef, + index: usize, + value: Option<&i256>, + precision: u8, + scale: i8, + ) -> Result { + let array = as_decimal256_array(array)?; + if array.precision() != precision || array.scale() != scale { + return Ok(false); + } + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) + } + } + /// Compares a single row of array @ index for equality with self, /// in an optimized fashion. 
/// @@ -3294,6 +3395,16 @@ impl ScalarValue { ) .unwrap() } + ScalarValue::Decimal256(v, precision, scale) => { + ScalarValue::eq_array_decimal256( + array, + index, + v.as_ref(), + *precision, + *scale, + ) + .unwrap() + } ScalarValue::Boolean(val) => { eq_array_primitive!(array, index, BooleanArray, val) } @@ -3416,6 +3527,7 @@ impl ScalarValue { | ScalarValue::Float32(_) | ScalarValue::Float64(_) | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) | ScalarValue::Int8(_) | ScalarValue::Int16(_) | ScalarValue::Int32(_) @@ -3647,6 +3759,22 @@ impl TryFrom for i128 { } } +// special implementation for i256 because of Decimal128 +impl TryFrom for i256 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Decimal256(Some(inner_value), _, _) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + impl_try_from!(UInt8, u8); impl_try_from!(UInt16, u16); impl_try_from!(UInt32, u32); @@ -3684,6 +3812,9 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(None, *precision, *scale) } + DataType::Decimal256(precision, scale) => { + ScalarValue::Decimal256(None, *precision, *scale) + } DataType::Utf8 => ScalarValue::Utf8(None), DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), DataType::Binary => ScalarValue::Binary(None), @@ -3753,6 +3884,9 @@ impl fmt::Display for ScalarValue { ScalarValue::Decimal128(v, p, s) => { write!(f, "{v:?},{p:?},{s:?}")?; } + ScalarValue::Decimal256(v, p, s) => { + write!(f, "{v:?},{p:?},{s:?}")?; + } ScalarValue::Boolean(e) => format_option!(f, e)?, ScalarValue::Float32(e) => format_option!(f, e)?, ScalarValue::Float64(e) => format_option!(f, e)?, @@ -3830,6 +3964,7 @@ impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"), + ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"), ScalarValue::Boolean(_) => write!(f, "Boolean({self})"), ScalarValue::Float32(_) => write!(f, "Float32({self})"), ScalarValue::Float64(_) => write!(f, "Float64({self})"), diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index b284079ec6e4..ee67003d2fa0 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -61,6 +61,8 @@ pub const DEFAULT_CSV_EXTENSION: &str = ".csv"; pub struct CsvFormat { has_header: bool, delimiter: u8, + quote: u8, + escape: Option, schema_infer_max_rec: Option, file_compression_type: FileCompressionType, } @@ -71,6 +73,8 @@ impl Default for CsvFormat { schema_infer_max_rec: Some(DEFAULT_SCHEMA_INFER_MAX_RECORD), has_header: true, delimiter: b',', + quote: b'"', + escape: None, file_compression_type: FileCompressionType::UNCOMPRESSED, } } @@ -159,6 +163,20 @@ impl CsvFormat { self } + /// The quote character in a row. + /// - default to '"' + pub fn with_quote(mut self, quote: u8) -> Self { + self.quote = quote; + self + } + + /// The escape character in a row. 
+ /// - default is None + pub fn with_escape(mut self, escape: Option<u8>) -> Self { + self.escape = escape; + self + } + /// Set a `FileCompressionType` of CSV /// - defaults to `FileCompressionType::UNCOMPRESSED` pub fn with_file_compression_type( @@ -173,6 +191,16 @@ impl CsvFormat { pub fn delimiter(&self) -> u8 { self.delimiter } + + /// The quote character. + pub fn quote(&self) -> u8 { + self.quote + } + + /// The escape character. + pub fn escape(&self) -> Option<u8> { + self.escape + } } #[async_trait] @@ -227,6 +255,8 @@ impl FileFormat for CsvFormat { conf, self.has_header, self.delimiter, + self.quote, + self.escape, self.file_compression_type.to_owned(), ); Ok(Arc::new(exec)) diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 6155dc6640fa..69449e9f8db6 100--- --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -55,6 +55,10 @@ pub struct CsvReadOptions<'a> { pub has_header: bool, /// An optional column delimiter. Defaults to `b','`. pub delimiter: u8, + /// An optional quote character. Defaults to `b'"'`. + pub quote: u8, + /// An optional escape character. Defaults to None. + pub escape: Option<u8>, /// An optional schema representing the CSV files. If None, CSV reader will try to infer it /// based on data in file. pub schema: Option<&'a Schema>, @@ -85,6 +89,8 @@ impl<'a> CsvReadOptions<'a> { schema: None, schema_infer_max_records: DEFAULT_SCHEMA_INFER_MAX_RECORD, delimiter: b',', + quote: b'"', + escape: None, file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, @@ -110,6 +116,18 @@ impl<'a> CsvReadOptions<'a> { self } + /// Specify quote to use for CSV read + pub fn quote(mut self, quote: u8) -> Self { + self.quote = quote; + self + } + + /// Specify escape character to use for CSV read + pub fn escape(mut self, escape: u8) -> Self { + self.escape = Some(escape); + self + } + /// Specify the file extension for CSV file selection pub fn file_extension(mut self, file_extension: &'a str) -> Self { self.file_extension = file_extension; self } @@ -435,6 +453,8 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { let file_format = CsvFormat::default() .with_has_header(self.has_header) .with_delimiter(self.delimiter) + .with_quote(self.quote) + .with_escape(self.escape) .with_schema_infer_max_rec(Some(self.schema_infer_max_records)) .with_file_compression_type(self.file_compression_type.to_owned()); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 9a7602b792fe..6ef92bed4ae8 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -59,6 +59,8 @@ pub struct CsvExec { projected_output_ordering: Vec, has_header: bool, delimiter: u8, + quote: u8, + escape: Option<u8>, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Compression type of the file associated with CsvExec @@ -71,6 +73,8 @@ impl CsvExec { base_config: FileScanConfig, has_header: bool, delimiter: u8, + quote: u8, + escape: Option<u8>, file_compression_type: FileCompressionType, ) -> Self { let (projected_schema, projected_statistics, projected_output_ordering) = @@ -83,6 +87,8 @@ projected_output_ordering, has_header, delimiter, + quote, + escape, metrics: ExecutionPlanMetricsSet::new(), file_compression_type, } @@ -101,6 +107,16 @@ self.delimiter } + /// The
quote character + pub fn quote(&self) -> u8 { + self.quote + } + + /// The escape character + pub fn escape(&self) -> Option { + self.escape + } + /// Redistribute files across partitions according to their size /// See comments on `repartition_file_groups()` for more detail. /// @@ -203,6 +219,8 @@ impl ExecutionPlan for CsvExec { file_projection: self.base_config.file_column_projection_indices(), has_header: self.has_header, delimiter: self.delimiter, + quote: self.quote, + escape: self.escape, object_store, }); @@ -232,6 +250,8 @@ pub struct CsvConfig { file_projection: Option>, has_header: bool, delimiter: u8, + quote: u8, + escape: Option, object_store: Arc, } @@ -243,6 +263,7 @@ impl CsvConfig { file_projection: Option>, has_header: bool, delimiter: u8, + quote: u8, object_store: Arc, ) -> Self { Self { @@ -251,6 +272,8 @@ impl CsvConfig { file_projection, has_header, delimiter, + quote, + escape: None, object_store, } } @@ -261,8 +284,11 @@ impl CsvConfig { let mut builder = csv::ReaderBuilder::new(self.file_schema.clone()) .has_header(self.has_header) .with_delimiter(self.delimiter) + .with_quote(self.quote) .with_batch_size(self.batch_size); - + if let Some(escape) = self.escape { + builder = builder.with_escape(escape); + } if let Some(p) = &self.file_projection { builder = builder.with_projection(p.clone()); } @@ -662,7 +688,14 @@ mod tests { let mut config = partitioned_csv_config(file_schema, file_groups)?; config.projection = Some(vec![0, 2, 4]); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(3, csv.projected_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); @@ -718,7 +751,14 @@ mod tests { let mut config = partitioned_csv_config(file_schema, file_groups)?; config.projection = Some(vec![4, 0, 2]); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(3, csv.projected_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); @@ -774,7 +814,14 @@ mod tests { let mut config = partitioned_csv_config(file_schema, file_groups)?; config.limit = Some(5); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(13, csv.projected_schema.fields().len()); assert_eq!(13, csv.schema().fields().len()); @@ -830,7 +877,14 @@ mod tests { let mut config = partitioned_csv_config(file_schema, file_groups)?; config.limit = Some(5); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); assert_eq!(14, csv.base_config.file_schema.fields().len()); assert_eq!(14, csv.projected_schema.fields().len()); assert_eq!(14, csv.schema().fields().len()); @@ -884,7 +938,14 @@ mod tests { // we don't have `/date=xx/` in the path but that is ok because // partitions are resolved during scan anyway - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + 
file_compression_type.to_owned(), + ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(2, csv.projected_schema.fields().len()); assert_eq!(2, csv.schema().fields().len()); @@ -970,7 +1031,14 @@ mod tests { .unwrap(); let config = partitioned_csv_config(file_schema, file_groups).unwrap(); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); let it = csv.execute(0, task_ctx).unwrap(); let batches: Vec<_> = it.try_collect().await.unwrap(); diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c6f082f1eb9c..811c2ec7656a 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -2248,7 +2248,7 @@ mod tests { // Note capitalization let my_avg = create_udaf( "MY_AVG", - DataType::Float64, + vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, Arc::new(|_| { diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index a9dec73c36f8..9ecfb8993f12 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -15,36 +15,37 @@ // specific language governing permissions and limitations // under the License. -//! Select the proper PartitionMode and build side based on the avaliable statistics for hash join. -use std::sync::Arc; +//! The [`JoinSelection`] rule tries to modify a given plan so that it can +//! accommodate infinite sources and utilize statistical information (if there +//! is any) to obtain more performant plans. To achieve the first goal, it +//! tries to transform a non-runnable query (with the given infinite sources) +//! into a runnable query by replacing pipeline-breaking join operations with +//! pipeline-friendly ones. To achieve the second goal, it selects the proper +//! `PartitionMode` and the build side using the available statistics for hash joins. -use arrow::datatypes::Schema; +use std::sync::Arc; use crate::config::ConfigOptions; -use crate::logical_expr::JoinType; -use crate::physical_plan::expressions::Column; +use crate::error::Result; +use crate::physical_optimizer::pipeline_checker::PipelineStatePropagator; +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use crate::physical_plan::joins::{ - utils::{ColumnIndex, JoinFilter, JoinSide}, - CrossJoinExec, HashJoinExec, PartitionMode, + CrossJoinExec, HashJoinExec, PartitionMode, StreamJoinPartitionMode, + SymmetricHashJoinExec, }; use crate::physical_plan::projection::ProjectionExec; -use crate::physical_plan::{ExecutionPlan, PhysicalExpr}; +use crate::physical_plan::ExecutionPlan; -use super::optimizer::PhysicalOptimizerRule; -use crate::error::Result; +use arrow_schema::Schema; use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{DataFusionError, JoinType}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::PhysicalExpr; -/// For hash join with the partition mode [PartitionMode::Auto], JoinSelection rule will make -/// a cost based decision to select which PartitionMode mode(Partitioned/CollectLeft) is optimal -/// based on the available statistics that the inputs have. 
-/// If the statistics information is not available, the partition mode will fall back to [PartitionMode::Partitioned]. -/// -/// JoinSelection rule will also reorder the build and probe phase of the hash joins -/// based on the avaliable statistics that the inputs have. -/// The rule optimizes the order such that the left (build) side of the join is the smallest. -/// If the statistics information is not available, the order stays the same as the original query. -/// JoinSelection rule will also swap the left and right sides for cross join to keep the left side -/// is the smallest. +/// The [`JoinSelection`] rule tries to modify a given plan so that it can +/// accommodate infinite sources and optimize joins in the plan according to +/// available statistical information, if there is any. #[derive(Default)] pub struct JoinSelection {} @@ -55,8 +56,9 @@ impl JoinSelection { } } -// TODO we need some performance test for Right Semi/Right Join swap to Left Semi/Left Join in case that the right side is smaller but not much smaller. -// TODO In PrestoSQL, the optimizer flips join sides only if one side is much smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times, by default is is 8 times. +// TODO: We need some performance test for Right Semi/Right Join swap to Left Semi/Left Join in case that the right side is smaller but not much smaller. +// TODO: In PrestoSQL, the optimizer flips join sides only if one side is much smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times, by default is is 8 times. +/// Checks statistics for join swap. fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -> bool { // Get the left and right table's total bytes // If both the left and right tables contain total_byte_size statistics, @@ -89,8 +91,9 @@ fn supports_collect_by_size( false } } + /// Predicate that checks whether the given join type supports input swapping. -pub fn supports_swap(join_type: JoinType) -> bool { +fn supports_swap(join_type: JoinType) -> bool { matches!( join_type, JoinType::Inner @@ -103,9 +106,10 @@ pub fn supports_swap(join_type: JoinType) -> bool { | JoinType::RightAnti ) } + /// This function returns the new join type we get after swapping the given /// join's inputs. -pub fn swap_join_type(join_type: JoinType) -> JoinType { +fn swap_join_type(join_type: JoinType) -> JoinType { match join_type { JoinType::Inner => JoinType::Inner, JoinType::Full => JoinType::Full, @@ -119,7 +123,7 @@ pub fn swap_join_type(join_type: JoinType) -> JoinType { } /// This function swaps the inputs of the given join operator. -pub fn swap_hash_join( +fn swap_hash_join( hash_join: &HashJoinExec, partition_mode: PartitionMode, ) -> Result> { @@ -160,7 +164,7 @@ pub fn swap_hash_join( /// the output should not be impacted. This function creates the expressions /// that will allow to swap back the values from the original left as the first /// columns and those on the right next. 
-pub fn swap_reverting_projection( +fn swap_reverting_projection( left_schema: &Schema, right_schema: &Schema, ) -> Vec<(Arc, String)> { @@ -182,30 +186,26 @@ pub fn swap_reverting_projection( } /// Swaps join sides for filter column indices and produces new JoinFilter -fn swap_join_filter(filter: Option<&JoinFilter>) -> Option { - filter.map(|filter| { - let column_indices = filter - .column_indices() - .iter() - .map(|idx| { - let side = if matches!(idx.side, JoinSide::Left) { - JoinSide::Right - } else { - JoinSide::Left - }; - ColumnIndex { - index: idx.index, - side, - } - }) - .collect(); +fn swap_filter(filter: &JoinFilter) -> JoinFilter { + let column_indices = filter + .column_indices() + .iter() + .map(|idx| ColumnIndex { + index: idx.index, + side: idx.side.negate(), + }) + .collect(); - JoinFilter::new( - filter.expression().clone(), - column_indices, - filter.schema().clone(), - ) - }) + JoinFilter::new( + filter.expression().clone(), + column_indices, + filter.schema().clone(), + ) +} + +/// Swaps join sides for filter column indices and produces new `JoinFilter` (if exists). +fn swap_join_filter(filter: Option<&JoinFilter>) -> Option { + filter.map(swap_filter) } impl PhysicalOptimizerRule for JoinSelection { @@ -214,63 +214,32 @@ impl PhysicalOptimizerRule for JoinSelection { plan: Arc, config: &ConfigOptions, ) -> Result> { + let pipeline = PipelineStatePropagator::new(plan); + // First, we make pipeline-fixing modifications to joins so as to accommodate + // unbounded inputs. Each pipeline-fixing subrule, which is a function + // of type `PipelineFixerSubrule`, takes a single [`PipelineStatePropagator`] + // argument storing state variables that indicate the unboundedness status + // of the current [`ExecutionPlan`] as we traverse the plan tree. + let subrules: Vec> = vec![ + Box::new(hash_join_convert_symmetric_subrule), + Box::new(hash_join_swap_subrule), + ]; + let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules, config))?; + // Next, we apply another subrule that tries to optimize joins using any + // statistics their inputs might have. + // - For a hash join with partition mode [`PartitionMode::Auto`], we will + // make a cost-based decision to select which `PartitionMode` mode + // (`Partitioned`/`CollectLeft`) is optimal. If the statistics information + // is not available, we will fall back to [`PartitionMode::Partitioned`]. + // - We optimize/swap join sides so that the left (build) side of the join + // is the small side. If the statistics information is not available, we + // do not modify join sides. + // - We will also swap left and right sides for cross joins so that the left + // side is the small side. let config = &config.optimizer; let collect_left_threshold = config.hash_join_single_partition_threshold; - plan.transform_up(&|plan| { - let transformed = if let Some(hash_join) = - plan.as_any().downcast_ref::() - { - match hash_join.partition_mode() { - PartitionMode::Auto => { - try_collect_left(hash_join, Some(collect_left_threshold))? - .map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), - )? - } - PartitionMode::CollectLeft => try_collect_left(hash_join, None)? - .map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), - )?, - PartitionMode::Partitioned => { - let left = hash_join.left(); - let right = hash_join.right(); - if should_swap_join_order(&**left, &**right) - && supports_swap(*hash_join.join_type()) - { - swap_hash_join(hash_join, PartitionMode::Partitioned) - .map(Some)? 
- } else { - None - } - } - } - } else if let Some(cross_join) = plan.as_any().downcast_ref::() - { - let left = cross_join.left(); - let right = cross_join.right(); - if should_swap_join_order(&**left, &**right) { - let new_join = - CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); - // TODO avoid adding ProjectionExec again and again, only adding Final Projection - let proj: Arc = Arc::new(ProjectionExec::try_new( - swap_reverting_projection(&left.schema(), &right.schema()), - Arc::new(new_join), - )?); - Some(proj) - } else { - None - } - } else { - None - }; - - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(plan) - }) + state.plan.transform_up(&|plan| { + statistical_join_selection_subrule(plan, collect_left_threshold) }) } @@ -283,13 +252,17 @@ impl PhysicalOptimizerRule for JoinSelection { } } -/// Try to create the PartitionMode::CollectLeft HashJoinExec when possible. -/// The method will first consider the current join type and check whether it is applicable to run CollectLeft mode -/// and will try to swap the join if the orignal type is unapplicable to run CollectLeft. -/// When the collect_threshold is provided, the method will also check both the left side and right side sizes +/// Tries to create a [`HashJoinExec`] in [`PartitionMode::CollectLeft`] when possible. /// -/// For [JoinType::Full], it is alway unable to run CollectLeft mode and will return None. -/// For [JoinType::Left] and [JoinType::LeftAnti], can not run CollectLeft mode, should swap join type to [JoinType::Right] and [JoinType::RightAnti] +/// This function will first consider the given join type and check whether the +/// `CollectLeft` mode is applicable. Otherwise, it will try to swap the join sides. +/// When the `collect_threshold` is provided, this function will also check left +/// and right sizes. +/// +/// For [`JoinType::Full`], it can not use `CollectLeft` mode and will return `None`. +/// For [`JoinType::Left`] and [`JoinType::LeftAnti`], it can not run `CollectLeft` +/// mode as is, but it can do so by changing the join type to [`JoinType::Right`] +/// and [`JoinType::RightAnti`], respectively. fn try_collect_left( hash_join: &HashJoinExec, collect_threshold: Option, @@ -375,8 +348,238 @@ fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result, + collect_left_threshold: usize, +) -> Result>> { + let transformed = if let Some(hash_join) = + plan.as_any().downcast_ref::() + { + match hash_join.partition_mode() { + PartitionMode::Auto => { + try_collect_left(hash_join, Some(collect_left_threshold))?.map_or_else( + || partitioned_hash_join(hash_join).map(Some), + |v| Ok(Some(v)), + )? + } + PartitionMode::CollectLeft => try_collect_left(hash_join, None)? + .map_or_else( + || partitioned_hash_join(hash_join).map(Some), + |v| Ok(Some(v)), + )?, + PartitionMode::Partitioned => { + let left = hash_join.left(); + let right = hash_join.right(); + if should_swap_join_order(&**left, &**right) + && supports_swap(*hash_join.join_type()) + { + swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? 
+ } else { + None + } + } + } + } else if let Some(cross_join) = plan.as_any().downcast_ref::() { + let left = cross_join.left(); + let right = cross_join.right(); + if should_swap_join_order(&**left, &**right) { + let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); + // TODO avoid adding ProjectionExec again and again, only adding Final Projection + let proj: Arc = Arc::new(ProjectionExec::try_new( + swap_reverting_projection(&left.schema(), &right.schema()), + Arc::new(new_join), + )?); + Some(proj) + } else { + None + } + } else { + None + }; + + Ok(if let Some(transformed) = transformed { + Transformed::Yes(transformed) + } else { + Transformed::No(plan) + }) +} + +/// Pipeline-fixing join selection subrule. +pub type PipelineFixerSubrule = dyn Fn( + PipelineStatePropagator, + &ConfigOptions, +) -> Option>; + +/// This subrule checks if we can replace a hash join with a symmetric hash +/// join when we are dealing with infinite inputs on both sides. This change +/// avoids pipeline breaking and preserves query runnability. If possible, +/// this subrule makes this replacement; otherwise, it has no effect. +fn hash_join_convert_symmetric_subrule( + mut input: PipelineStatePropagator, + config_options: &ConfigOptions, +) -> Option> { + if let Some(hash_join) = input.plan.as_any().downcast_ref::() { + let ub_flags = &input.children_unbounded; + let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); + input.unbounded = left_unbounded || right_unbounded; + let result = if left_unbounded && right_unbounded { + let mode = if config_options.optimizer.repartition_joins { + StreamJoinPartitionMode::Partitioned + } else { + StreamJoinPartitionMode::SinglePartition + }; + SymmetricHashJoinExec::try_new( + hash_join.left().clone(), + hash_join.right().clone(), + hash_join.on().to_vec(), + hash_join.filter().cloned(), + hash_join.join_type(), + hash_join.null_equals_null(), + mode, + ) + .map(|exec| { + input.plan = Arc::new(exec) as _; + input + }) + } else { + Ok(input) + }; + Some(result) + } else { + None + } +} + +/// This subrule will swap build/probe sides of a hash join depending on whether +/// one of its inputs may produce an infinite stream of records. The rule ensures +/// that the left (build) side of the hash join always operates on an input stream +/// that will produce a finite set of records. If the left side can not be chosen +/// to be "finite", the join sides stay the same as the original query. 
+/// ```text +/// For example, this rule makes the following transformation: +/// +/// +/// +/// +--------------+ +--------------+ +/// | | unbounded | | +/// Left | Infinite | true | Hash |\true +/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ +/// | | | | \ | | | | +/// +--------------+ +--------------+ - | Hash Join |-------| Projection | +/// - | | | | +/// +--------------+ +--------------+ / +--------------+ +--------------+ +/// | | unbounded | | / +/// Right | Finite | false | Hash |/false +/// | Data Source |--------------| Repartition | +/// | | | | +/// +--------------+ +--------------+ +/// +/// +/// +/// +--------------+ +--------------+ +/// | | unbounded | | +/// Left | Finite | false | Hash |\false +/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ +/// | | | | \ | | true | | true +/// +--------------+ +--------------+ - | Hash Join |-------| Projection |----- +/// - | | | | +/// +--------------+ +--------------+ / +--------------+ +--------------+ +/// | | unbounded | | / +/// Right | Infinite | true | Hash |/true +/// | Data Source |--------------| Repartition | +/// | | | | +/// +--------------+ +--------------+ +/// +/// ``` +fn hash_join_swap_subrule( + mut input: PipelineStatePropagator, + _config_options: &ConfigOptions, +) -> Option> { + if let Some(hash_join) = input.plan.as_any().downcast_ref::() { + let ub_flags = &input.children_unbounded; + let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); + input.unbounded = left_unbounded || right_unbounded; + let result = if left_unbounded + && !right_unbounded + && matches!( + *hash_join.join_type(), + JoinType::Inner + | JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + ) { + swap_join_according_to_unboundedness(hash_join).map(|plan| { + input.plan = plan; + input + }) + } else { + Ok(input) + }; + Some(result) + } else { + None + } +} + +/// This function swaps sides of a hash join to make it runnable even if one of +/// its inputs are infinite. Note that this is not always possible; i.e. +/// [`JoinType::Full`], [`JoinType::Right`], [`JoinType::RightAnti`] and +/// [`JoinType::RightSemi`] can not run with an unbounded left side, even if +/// we swap join sides. Therefore, we do not consider them here. +fn swap_join_according_to_unboundedness( + hash_join: &HashJoinExec, +) -> Result> { + let partition_mode = hash_join.partition_mode(); + let join_type = hash_join.join_type(); + match (*partition_mode, *join_type) { + ( + _, + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full, + ) => Err(DataFusionError::Internal(format!( + "{join_type} join cannot be swapped for unbounded input." + ))), + (PartitionMode::Partitioned, _) => { + swap_hash_join(hash_join, PartitionMode::Partitioned) + } + (PartitionMode::CollectLeft, _) => { + swap_hash_join(hash_join, PartitionMode::CollectLeft) + } + (PartitionMode::Auto, _) => Err(DataFusionError::Internal( + "Auto is not acceptable for unbounded input here.".to_string(), + )), + } +} + +/// Apply given `PipelineFixerSubrule`s to a given plan. This plan, along with +/// auxiliary boundedness information, is in the `PipelineStatePropagator` object. +fn apply_subrules( + mut input: PipelineStatePropagator, + subrules: &Vec>, + config_options: &ConfigOptions, +) -> Result> { + for subrule in subrules { + if let Some(value) = subrule(input.clone(), config_options).transpose()? 
{ + input = value; + } + } + let is_unbounded = input + .plan + .unbounded_output(&input.children_unbounded) + // Treat the case where an operator can not run on unbounded data as + // if it can and it outputs unbounded data. Do not raise an error yet. + // Such operators may be fixed, adjusted or replaced later on during + // optimization passes -- sorts may be removed, windows may be adjusted + // etc. If this doesn't happen, the final `PipelineChecker` rule will + // catch this and raise an error anyway. + .unwrap_or(true); + input.unbounded = is_unbounded; + Ok(Transformed::Yes(input)) +} + #[cfg(test)] -mod tests { +mod tests_statistical { use crate::{ physical_plan::{ displayable, joins::PartitionMode, ColumnStatistics, Statistics, @@ -388,7 +591,9 @@ mod tests { use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; + use datafusion_common::{JoinType, ScalarValue}; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr::PhysicalExpr; fn create_big_and_small() -> (Arc, Arc) { let big = Arc::new(StatisticsExec::new( @@ -556,7 +761,6 @@ mod tests { .expect("A proj is required to swap columns back to their original order"); assert_eq!(swapping_projection.expr().len(), 2); - println!("swapping_projection {swapping_projection:?}"); let (col, name) = &swapping_projection.expr()[0]; assert_eq!(name, "small_col"); assert_col_expr(col, "small_col", 1); @@ -693,7 +897,7 @@ mod tests { " StatisticsExec: col_count=1, row_count=Some(1000)", " StatisticsExec: col_count=1, row_count=Some(100000)", " StatisticsExec: col_count=1, row_count=Some(10000)", - "" + "", ]; assert_optimized!(expected, join); } @@ -967,3 +1171,484 @@ mod tests { } } } + +#[cfg(test)] +mod util_tests { + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; + use datafusion_physical_expr::intervals::check_support; + use datafusion_physical_expr::PhysicalExpr; + use std::sync::Arc; + + #[test] + fn check_expr_supported() { + let supported_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )) as Arc; + assert!(check_support(&supported_expr)); + let supported_expr_2 = Arc::new(Column::new("a", 0)) as Arc; + assert!(check_support(&supported_expr_2)); + let unsupported_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(Column::new("a", 0)), + )) as Arc; + assert!(!check_support(&unsupported_expr)); + let unsupported_expr_2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))), + )) as Arc; + assert!(!check_support(&unsupported_expr_2)); + } +} + +#[cfg(test)] +mod hash_join_tests { + use super::*; + use crate::physical_optimizer::join_selection::swap_join_type; + use crate::physical_optimizer::test_utils::SourceType; + use crate::physical_plan::expressions::Column; + use crate::physical_plan::joins::PartitionMode; + use crate::physical_plan::projection::ProjectionExec; + use crate::test_util::UnboundedExec; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::utils::DataPtr; + use datafusion_common::JoinType; + use std::sync::Arc; + + struct TestCase { + case: String, + initial_sources_unbounded: (SourceType, SourceType), + initial_join_type: JoinType, + initial_mode: PartitionMode, + expected_sources_unbounded: (SourceType, SourceType), + 
expected_join_type: JoinType, + expected_mode: PartitionMode, + expecting_swap: bool, + } + + #[tokio::test] + async fn test_join_with_swap_full() -> Result<()> { + // NOTE: Currently, some initial conditions are not viable after join order selection. + // For example, full join always comes in partitioned mode. See the warning in + // function "swap". If this changes in the future, we should update these tests. + let cases = vec![ + TestCase { + case: "Bounded - Unbounded 1".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: JoinType::Full, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: JoinType::Full, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }, + TestCase { + case: "Unbounded - Bounded 2".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: JoinType::Full, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + expected_join_type: JoinType::Full, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }, + TestCase { + case: "Bounded - Bounded 3".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: JoinType::Full, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: JoinType::Full, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }, + TestCase { + case: "Unbounded - Unbounded 4".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: JoinType::Full, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: JoinType::Full, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }, + ]; + for case in cases.into_iter() { + test_join_with_maybe_swap_unbounded_case(case).await? 
+ } + Ok(()) + } + + #[tokio::test] + async fn test_cases_without_collect_left_check() -> Result<()> { + let mut cases = vec![]; + let join_types = vec![JoinType::LeftSemi, JoinType::Inner]; + for join_type in join_types { + cases.push(TestCase { + case: "Unbounded - Bounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: swap_join_type(join_type), + expected_mode: PartitionMode::CollectLeft, + expecting_swap: true, + }); + cases.push(TestCase { + case: "Bounded - Unbounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Bounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Bounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: swap_join_type(join_type), + expected_mode: PartitionMode::Partitioned, + expecting_swap: true, + }); + cases.push(TestCase { + case: "Bounded - Unbounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Bounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: 
PartitionMode::Partitioned, + expecting_swap: false, + }); + } + + for case in cases.into_iter() { + test_join_with_maybe_swap_unbounded_case(case).await? + } + Ok(()) + } + + #[tokio::test] + async fn test_not_support_collect_left() -> Result<()> { + let mut cases = vec![]; + // After [JoinSelection] optimization, these join types cannot run in CollectLeft mode except + // [JoinType::LeftSemi] + let the_ones_not_support_collect_left = vec![JoinType::Left, JoinType::LeftAnti]; + for join_type in the_ones_not_support_collect_left { + cases.push(TestCase { + case: "Unbounded - Bounded".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: swap_join_type(join_type), + expected_mode: PartitionMode::Partitioned, + expecting_swap: true, + }); + cases.push(TestCase { + case: "Bounded - Unbounded".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Bounded".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + } + + for case in cases.into_iter() { + test_join_with_maybe_swap_unbounded_case(case).await? + } + Ok(()) + } + + #[tokio::test] + async fn test_not_supporting_swaps_possible_collect_left() -> Result<()> { + let mut cases = vec![]; + let the_ones_not_support_collect_left = + vec![JoinType::Right, JoinType::RightAnti, JoinType::RightSemi]; + for join_type in the_ones_not_support_collect_left { + // We expect that (SourceType::Unbounded, SourceType::Bounded) will change, regardless of the + // statistics. + cases.push(TestCase { + case: "Unbounded - Bounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + // We expect that (SourceType::Bounded, SourceType::Unbounded) will stay same, regardless of the + // statistics. 
+ cases.push(TestCase { + case: "Bounded - Unbounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + // + cases.push(TestCase { + case: "Bounded - Bounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + // If cases are partitioned, only unbounded & bounded check will affect the order. + cases.push(TestCase { + case: "Unbounded - Bounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Unbounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Bounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + } + + for case in cases.into_iter() { + test_join_with_maybe_swap_unbounded_case(case).await? 
+ } + Ok(()) + } + + async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { + let left_unbounded = t.initial_sources_unbounded.0 == SourceType::Unbounded; + let right_unbounded = t.initial_sources_unbounded.1 == SourceType::Unbounded; + let left_exec = Arc::new(UnboundedExec::new( + (!left_unbounded).then_some(1), + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Int32, + false, + )]))), + 2, + )) as Arc; + let right_exec = Arc::new(UnboundedExec::new( + (!right_unbounded).then_some(1), + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "b", + DataType::Int32, + false, + )]))), + 2, + )) as Arc; + + let join = HashJoinExec::try_new( + Arc::clone(&left_exec), + Arc::clone(&right_exec), + vec![( + Column::new_with_schema("a", &left_exec.schema())?, + Column::new_with_schema("b", &right_exec.schema())?, + )], + None, + &t.initial_join_type, + t.initial_mode, + false, + )?; + + let initial_hash_join_state = PipelineStatePropagator { + plan: Arc::new(join), + unbounded: false, + children_unbounded: vec![left_unbounded, right_unbounded], + }; + + let optimized_hash_join = + hash_join_swap_subrule(initial_hash_join_state, &ConfigOptions::new()) + .unwrap()?; + let optimized_join_plan = optimized_hash_join.plan; + + // If swap did happen + let projection_added = optimized_join_plan.as_any().is::(); + let plan = if projection_added { + let proj = optimized_join_plan + .as_any() + .downcast_ref::() + .expect( + "A proj is required to swap columns back to their original order", + ); + proj.input().clone() + } else { + optimized_join_plan + }; + + if let Some(HashJoinExec { + left, + right, + join_type, + mode, + .. + }) = plan.as_any().downcast_ref::() + { + let left_changed = Arc::data_ptr_eq(left, &right_exec); + let right_changed = Arc::data_ptr_eq(right, &left_exec); + // If this is not equal, we have a bigger problem. + assert_eq!(left_changed, right_changed); + assert_eq!( + ( + t.case.as_str(), + if left.unbounded_output(&[])? { + SourceType::Unbounded + } else { + SourceType::Bounded + }, + if right.unbounded_output(&[])? 
{ + SourceType::Unbounded + } else { + SourceType::Bounded + }, + join_type, + mode, + left_changed && right_changed + ), + ( + t.case.as_str(), + t.expected_sources_unbounded.0, + t.expected_sources_unbounded.1, + &t.expected_join_type, + &t.expected_mode, + t.expecting_swap + ) + ); + }; + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 8ee95ea663f3..f74d4ea0c9a6 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -30,12 +30,11 @@ pub mod optimizer; pub mod pipeline_checker; pub mod pruning; pub mod repartition; -pub mod replace_repartition_execs; +pub mod replace_with_order_preserving_variants; pub mod sort_enforcement; mod sort_pushdown; mod utils; -pub mod pipeline_fixer; #[cfg(test)] pub mod test_utils; diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index d35c82abd28e..3f6698c6cf46 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -26,7 +26,6 @@ use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAgg use crate::physical_optimizer::dist_enforcement::EnforceDistribution; use crate::physical_optimizer::join_selection::JoinSelection; use crate::physical_optimizer::pipeline_checker::PipelineChecker; -use crate::physical_optimizer::pipeline_fixer::PipelineFixer; use crate::physical_optimizer::repartition::Repartition; use crate::physical_optimizer::sort_enforcement::EnforceSorting; use crate::{error::Result, physical_plan::ExecutionPlan}; @@ -76,12 +75,6 @@ impl PhysicalOptimizer { // repartitioning and local sorting steps to meet distribution and ordering requirements. // Therefore, it should run before EnforceDistribution and EnforceSorting. Arc::new(JoinSelection::new()), - // If the query is processing infinite inputs, the PipelineFixer rule applies the - // necessary transformations to make the query runnable (if it is not already runnable). - // If the query can not be made runnable, the rule emits an error with a diagnostic message. - // Since the transformations it applies may alter output partitioning properties of operators - // (e.g. by swapping hash join sides), this rule runs before EnforceDistribution. - Arc::new(PipelineFixer::new()), // In order to increase the parallelism, the Repartition rule will change the // output partitioning of some operators in the plan tree, which will influence // other rules. Therefore, it should run as soon as possible. It is optional because: diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs deleted file mode 100644 index 7db3e99c3920..000000000000 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ /dev/null @@ -1,716 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! The [PipelineFixer] rule tries to modify a given plan so that it can -//! accommodate its infinite sources, if there are any. In other words, -//! it tries to obtain a runnable query (with the given infinite sources) -//! from an non-runnable query by transforming pipeline-breaking operations -//! to pipeline-friendly ones. If this can not be done, the rule emits a -//! diagnostic error message. -//! -use crate::config::ConfigOptions; -use crate::error::Result; -use crate::physical_optimizer::join_selection::swap_hash_join; -use crate::physical_optimizer::pipeline_checker::PipelineStatePropagator; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::{ - HashJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, -}; -use crate::physical_plan::ExecutionPlan; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::DataFusionError; -use datafusion_expr::logical_plan::JoinType; - -use std::sync::Arc; - -/// The [`PipelineFixer`] rule tries to modify a given plan so that it can -/// accommodate its infinite sources, if there are any. If this is not -/// possible, the rule emits a diagnostic error message. -#[derive(Default)] -pub struct PipelineFixer {} - -impl PipelineFixer { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} -/// [`PipelineFixer`] subrules are functions of this type. Such functions take a -/// single [`PipelineStatePropagator`] argument, which stores state variables -/// indicating the unboundedness status of the current [`ExecutionPlan`] as -/// the `PipelineFixer` rule traverses the entire plan tree. -type PipelineFixerSubrule = - dyn Fn(PipelineStatePropagator) -> Option>; - -impl PhysicalOptimizerRule for PipelineFixer { - fn optimize( - &self, - plan: Arc, - _config: &ConfigOptions, - ) -> Result> { - let pipeline = PipelineStatePropagator::new(plan); - let subrules: Vec> = vec![ - Box::new(hash_join_convert_symmetric_subrule), - Box::new(hash_join_swap_subrule), - ]; - let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules))?; - Ok(state.plan) - } - - fn name(&self) -> &str { - "PipelineFixer" - } - - fn schema_check(&self) -> bool { - true - } -} - -/// This subrule checks if one can replace a hash join with a symmetric hash -/// join so that the pipeline does not break due to the join operation in -/// question. If possible, it makes this replacement; otherwise, it has no -/// effect. 
-fn hash_join_convert_symmetric_subrule( - mut input: PipelineStatePropagator, -) -> Option> { - if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; - let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); - input.unbounded = left_unbounded || right_unbounded; - let result = if left_unbounded && right_unbounded { - SymmetricHashJoinExec::try_new( - hash_join.left().clone(), - hash_join.right().clone(), - hash_join - .on() - .iter() - .map(|(l, r)| (l.clone(), r.clone())) - .collect(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.null_equals_null(), - StreamJoinPartitionMode::Partitioned, - ) - .map(|exec| { - input.plan = Arc::new(exec) as _; - input - }) - } else { - Ok(input) - }; - Some(result) - } else { - None - } -} - -/// This subrule will swap build/probe sides of a hash join depending on whether its inputs -/// may produce an infinite stream of records. The rule ensures that the left (build) side -/// of the hash join always operates on an input stream that will produce a finite set of. -/// records If the left side can not be chosen to be "finite", the order stays the -/// same as the original query. -/// ```text -/// For example, this rule makes the following transformation: -/// -/// -/// -/// +--------------+ +--------------+ -/// | | unbounded | | -/// Left | Infinite | true | Hash |\true -/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ -/// | | | | \ | | | | -/// +--------------+ +--------------+ - | Hash Join |-------| Projection | -/// - | | | | -/// +--------------+ +--------------+ / +--------------+ +--------------+ -/// | | unbounded | | / -/// Right | Finite | false | Hash |/false -/// | Data Source |--------------| Repartition | -/// | | | | -/// +--------------+ +--------------+ -/// -/// -/// -/// +--------------+ +--------------+ -/// | | unbounded | | -/// Left | Finite | false | Hash |\false -/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ -/// | | | | \ | | true | | true -/// +--------------+ +--------------+ - | Hash Join |-------| Projection |----- -/// - | | | | -/// +--------------+ +--------------+ / +--------------+ +--------------+ -/// | | unbounded | | / -/// Right | Infinite | true | Hash |/true -/// | Data Source |--------------| Repartition | -/// | | | | -/// +--------------+ +--------------+ -/// -/// ``` -fn hash_join_swap_subrule( - mut input: PipelineStatePropagator, -) -> Option> { - if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; - let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); - input.unbounded = left_unbounded || right_unbounded; - let result = if left_unbounded - && !right_unbounded - && matches!( - *hash_join.join_type(), - JoinType::Inner - | JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - ) { - swap(hash_join).map(|plan| { - input.plan = plan; - input - }) - } else { - Ok(input) - }; - Some(result) - } else { - None - } -} - -/// This function swaps sides of a hash join to make it runnable even if one of its -/// inputs are infinite. Note that this is not always possible; i.e. [JoinType::Full], -/// [JoinType::Right], [JoinType::RightAnti] and [JoinType::RightSemi] can not run with -/// an unbounded left side, even if we swap. Therefore, we do not consider them here. 
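As an aside, the eligibility test behind this swap (here in the deleted `PipelineFixer`, and in the relocated `hash_join_swap_subrule` that the updated tests now call with a `&ConfigOptions` argument) reduces to the predicate below. This is a minimal sketch: the helper name is invented for illustration, while the join-type list mirrors the subrule's own `matches!` check.

```rust
use datafusion_expr::logical_plan::JoinType;

/// Sketch only: swapping build/probe sides can fix the pipeline when the
/// left (build) side is the unbounded one and the join type still yields
/// correct results after the sides are exchanged.
fn swap_can_fix_pipeline(
    left_unbounded: bool,
    right_unbounded: bool,
    join_type: &JoinType,
) -> bool {
    left_unbounded
        && !right_unbounded
        && matches!(
            *join_type,
            JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti
        )
}
```

`JoinType::Full`, `Right`, `RightSemi` and `RightAnti` are excluded because, as the doc comment above notes, they cannot run with an unbounded left side even after a swap.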
-fn swap(hash_join: &HashJoinExec) -> Result> { - let partition_mode = hash_join.partition_mode(); - let join_type = hash_join.join_type(); - match (*partition_mode, *join_type) { - ( - _, - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full, - ) => Err(DataFusionError::Internal(format!( - "{join_type} join cannot be swapped for unbounded input." - ))), - (PartitionMode::Partitioned, _) => { - swap_hash_join(hash_join, PartitionMode::Partitioned) - } - (PartitionMode::CollectLeft, _) => { - swap_hash_join(hash_join, PartitionMode::CollectLeft) - } - (PartitionMode::Auto, _) => Err(DataFusionError::Internal( - "Auto is not acceptable for unbounded input here.".to_string(), - )), - } -} - -fn apply_subrules( - mut input: PipelineStatePropagator, - subrules: &Vec>, -) -> Result> { - for subrule in subrules { - if let Some(value) = subrule(input.clone()).transpose()? { - input = value; - } - } - let is_unbounded = input - .plan - .unbounded_output(&input.children_unbounded) - // Treat the case where an operator can not run on unbounded data as - // if it can and it outputs unbounded data. Do not raise an error yet. - // Such operators may be fixed, adjusted or replaced later on during - // optimization passes -- sorts may be removed, windows may be adjusted - // etc. If this doesn't happen, the final `PipelineChecker` rule will - // catch this and raise an error anyway. - .unwrap_or(true); - input.unbounded = is_unbounded; - Ok(Transformed::Yes(input)) -} - -#[cfg(test)] -mod util_tests { - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; - use datafusion_physical_expr::intervals::check_support; - use datafusion_physical_expr::PhysicalExpr; - use std::sync::Arc; - - #[test] - fn check_expr_supported() { - let supported_expr = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )) as Arc; - assert!(check_support(&supported_expr)); - let supported_expr_2 = Arc::new(Column::new("a", 0)) as Arc; - assert!(check_support(&supported_expr_2)); - let unsupported_expr = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Or, - Arc::new(Column::new("a", 0)), - )) as Arc; - assert!(!check_support(&unsupported_expr)); - let unsupported_expr_2 = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Or, - Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))), - )) as Arc; - assert!(!check_support(&unsupported_expr_2)); - } -} - -#[cfg(test)] -mod hash_join_tests { - use super::*; - use crate::physical_optimizer::join_selection::swap_join_type; - use crate::physical_optimizer::test_utils::SourceType; - use crate::physical_plan::expressions::Column; - use crate::physical_plan::joins::PartitionMode; - use crate::physical_plan::projection::ProjectionExec; - use crate::test_util::UnboundedExec; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::utils::DataPtr; - use std::sync::Arc; - - struct TestCase { - case: String, - initial_sources_unbounded: (SourceType, SourceType), - initial_join_type: JoinType, - initial_mode: PartitionMode, - expected_sources_unbounded: (SourceType, SourceType), - expected_join_type: JoinType, - expected_mode: PartitionMode, - expecting_swap: bool, - } - - #[tokio::test] - async fn test_join_with_swap_full() -> Result<()> { - // NOTE: Currently, some initial conditions are not viable after join order selection. 
- // For example, full join always comes in partitioned mode. See the warning in - // function "swap". If this changes in the future, we should update these tests. - let cases = vec![ - TestCase { - case: "Bounded - Unbounded 1".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: JoinType::Full, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: JoinType::Full, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }, - TestCase { - case: "Unbounded - Bounded 2".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: JoinType::Full, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - expected_join_type: JoinType::Full, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }, - TestCase { - case: "Bounded - Bounded 3".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: JoinType::Full, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: JoinType::Full, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }, - TestCase { - case: "Unbounded - Unbounded 4".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: JoinType::Full, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: JoinType::Full, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }, - ]; - for case in cases.into_iter() { - test_join_with_maybe_swap_unbounded_case(case).await? 
- } - Ok(()) - } - - #[tokio::test] - async fn test_cases_without_collect_left_check() -> Result<()> { - let mut cases = vec![]; - let join_types = vec![JoinType::LeftSemi, JoinType::Inner]; - for join_type in join_types { - cases.push(TestCase { - case: "Unbounded - Bounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), - expected_mode: PartitionMode::CollectLeft, - expecting_swap: true, - }); - cases.push(TestCase { - case: "Bounded - Unbounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Bounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Bounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), - expected_mode: PartitionMode::Partitioned, - expecting_swap: true, - }); - cases.push(TestCase { - case: "Bounded - Unbounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Bounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: 
PartitionMode::Partitioned, - expecting_swap: false, - }); - } - - for case in cases.into_iter() { - test_join_with_maybe_swap_unbounded_case(case).await? - } - Ok(()) - } - - #[tokio::test] - async fn test_not_support_collect_left() -> Result<()> { - let mut cases = vec![]; - // After [JoinSelection] optimization, these join types cannot run in CollectLeft mode except - // [JoinType::LeftSemi] - let the_ones_not_support_collect_left = vec![JoinType::Left, JoinType::LeftAnti]; - for join_type in the_ones_not_support_collect_left { - cases.push(TestCase { - case: "Unbounded - Bounded".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), - expected_mode: PartitionMode::Partitioned, - expecting_swap: true, - }); - cases.push(TestCase { - case: "Bounded - Unbounded".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Bounded".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - } - - for case in cases.into_iter() { - test_join_with_maybe_swap_unbounded_case(case).await? - } - Ok(()) - } - - #[tokio::test] - async fn test_not_supporting_swaps_possible_collect_left() -> Result<()> { - let mut cases = vec![]; - let the_ones_not_support_collect_left = - vec![JoinType::Right, JoinType::RightAnti, JoinType::RightSemi]; - for join_type in the_ones_not_support_collect_left { - // We expect that (SourceType::Unbounded, SourceType::Bounded) will change, regardless of the - // statistics. - cases.push(TestCase { - case: "Unbounded - Bounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - // We expect that (SourceType::Bounded, SourceType::Unbounded) will stay same, regardless of the - // statistics. 
- cases.push(TestCase { - case: "Bounded - Unbounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - // - cases.push(TestCase { - case: "Bounded - Bounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - // If cases are partitioned, only unbounded & bounded check will affect the order. - cases.push(TestCase { - case: "Unbounded - Bounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Unbounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Bounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - } - - for case in cases.into_iter() { - test_join_with_maybe_swap_unbounded_case(case).await? 
- } - Ok(()) - } - - async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { - let left_unbounded = t.initial_sources_unbounded.0 == SourceType::Unbounded; - let right_unbounded = t.initial_sources_unbounded.1 == SourceType::Unbounded; - let left_exec = Arc::new(UnboundedExec::new( - (!left_unbounded).then_some(1), - RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Int32, - false, - )]))), - 2, - )) as Arc; - let right_exec = Arc::new(UnboundedExec::new( - (!right_unbounded).then_some(1), - RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( - "b", - DataType::Int32, - false, - )]))), - 2, - )) as Arc; - - let join = HashJoinExec::try_new( - Arc::clone(&left_exec), - Arc::clone(&right_exec), - vec![( - Column::new_with_schema("a", &left_exec.schema())?, - Column::new_with_schema("b", &right_exec.schema())?, - )], - None, - &t.initial_join_type, - t.initial_mode, - false, - )?; - - let initial_hash_join_state = PipelineStatePropagator { - plan: Arc::new(join), - unbounded: false, - children_unbounded: vec![left_unbounded, right_unbounded], - }; - let optimized_hash_join = - hash_join_swap_subrule(initial_hash_join_state).unwrap()?; - let optimized_join_plan = optimized_hash_join.plan; - - // If swap did happen - let projection_added = optimized_join_plan.as_any().is::(); - let plan = if projection_added { - let proj = optimized_join_plan - .as_any() - .downcast_ref::() - .expect( - "A proj is required to swap columns back to their original order", - ); - proj.input().clone() - } else { - optimized_join_plan - }; - - if let Some(HashJoinExec { - left, - right, - join_type, - mode, - .. - }) = plan.as_any().downcast_ref::() - { - let left_changed = Arc::data_ptr_eq(left, &right_exec); - let right_changed = Arc::data_ptr_eq(right, &left_exec); - // If this is not equal, we have a bigger problem. - assert_eq!(left_changed, right_changed); - assert_eq!( - ( - t.case.as_str(), - if left.unbounded_output(&[])? { - SourceType::Unbounded - } else { - SourceType::Bounded - }, - if right.unbounded_output(&[])? 
{ - SourceType::Unbounded - } else { - SourceType::Bounded - }, - join_type, - mode, - left_changed && right_changed - ), - ( - t.case.as_str(), - t.expected_sources_unbounded.0, - t.expected_sources_unbounded.1, - &t.expected_join_type, - &t.expected_mode, - t.expecting_swap - ) - ); - }; - Ok(()) - } -} diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs index 866686fe1e23..aa48fd77a8b1 100644 --- a/datafusion/core/src/physical_optimizer/repartition.rs +++ b/datafusion/core/src/physical_optimizer/repartition.rs @@ -396,6 +396,8 @@ mod tests { scan_config(false, true), false, b',', + b'"', + None, FileCompressionType::UNCOMPRESSED, )) } @@ -411,6 +413,8 @@ mod tests { scan_config(false, false), false, b',', + b'"', + None, FileCompressionType::UNCOMPRESSED, )) } @@ -426,6 +430,8 @@ mod tests { scan_config(true, true), false, b',', + b'"', + None, FileCompressionType::UNCOMPRESSED, )) } @@ -992,6 +998,8 @@ mod tests { scan_config(false, true), false, b',', + b'"', + None, compression_type, ))); diff --git a/datafusion/core/src/physical_optimizer/replace_repartition_execs.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs similarity index 73% rename from datafusion/core/src/physical_optimizer/replace_repartition_execs.rs rename to datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index faab48215684..8c86906a68f2 100644 --- a/datafusion/core/src/physical_optimizer/replace_repartition_execs.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -15,136 +15,255 @@ // specific language governing permissions and limitations // under the License. -//! Repartition optimizer that replaces `SortExec`s and their suitable `RepartitionExec` children with `SortPreservingRepartitionExec`s. +//! Optimizer rule that replaces executors that lose ordering with their +//! order-preserving variants when it is helpful; either in terms of +//! performance or to accommodate unbounded streams by fixing the pipeline. + use crate::error::Result; -use crate::physical_optimizer::sort_enforcement::unbounded_output; +use crate::physical_optimizer::sort_enforcement::{unbounded_output, ExecTree}; +use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort}; use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::ExecutionPlan; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use super::utils::is_repartition; -use datafusion_common::tree_node::Transformed; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_physical_expr::utils::ordering_satisfy; -use itertools::enumerate; use std::sync::Arc; -/// Creates a `SortPreservingRepartitionExec` from given `RepartitionExec` -fn sort_preserving_repartition( - repartition: &RepartitionExec, -) -> Result> { - Ok(Arc::new( - RepartitionExec::try_new( - repartition.input().clone(), - repartition.partitioning().clone(), - )? - .with_preserve_order(), - )) +/// For a given `plan`, this object carries the information one needs from its +/// descendants to decide whether it is beneficial to replace order-losing (but +/// somewhat faster) variants of certain operators with their order-preserving +/// (but somewhat slower) cousins. 
+#[derive(Debug, Clone)] +pub(crate) struct OrderPreservationContext { + pub(crate) plan: Arc, + ordering_onwards: Vec>, } -fn does_plan_maintain_input_order(plan: &Arc) -> bool { - plan.maintains_input_order().iter().any(|flag| *flag) -} +impl OrderPreservationContext { + /// Creates a "default" order-preservation context. + pub fn new(plan: Arc) -> Self { + let length = plan.children().len(); + OrderPreservationContext { + plan, + ordering_onwards: vec![None; length], + } + } -/// Check the children nodes of a `SortExec` until ordering is lost (e.g. until -/// another `SortExec` or a `CoalescePartitionsExec` which doesn't maintain ordering) -/// and replace `RepartitionExec`s that do not maintain ordering (e.g. those whose -/// input partition counts are larger than unity) with `SortPreservingRepartitionExec`s. -/// Note that doing this may render the `SortExec` in question unneccessary, which will -/// be removed later on. -/// -/// For example, we transform the plan below -/// "FilterExec: c@2 > 3", -/// " RepartitionExec: partitioning=Hash(\[b@0], 16), input_partitions=16", -/// " RepartitionExec: partitioning=Hash(\[a@0], 16), input_partitions=1", -/// " MemoryExec: partitions=1, partition_sizes=\[()], output_ordering: \[PhysicalSortExpr { expr: Column { name: \"a\", index: 0 }, options: SortOptions { descending: false, nulls_first: false } }]", -/// into -/// "FilterExec: c@2 > 3", -/// " SortPreservingRepartitionExec: partitioning=Hash(\[b@0], 16), input_partitions=16", -/// " RepartitionExec: partitioning=Hash(\[a@0], 16), input_partitions=1", -/// " MemoryExec: partitions=1, partition_sizes=\[], output_ordering: \[PhysicalSortExpr { expr: Column { name: \"a\", index: 0 }, options: SortOptions { descending: false, nulls_first: false } }]", -/// where the `FilterExec` in the latter has output ordering `a ASC`. This ordering will -/// potentially remove a `SortExec` at the top of `FilterExec`. If this doesn't help remove -/// a `SortExec`, the old version is used. -fn replace_sort_children( - plan: &Arc, -) -> Result> { - if plan.children().is_empty() { - return Ok(plan.clone()); + /// Creates a new order-preservation context from those of children nodes. + pub fn new_from_children_nodes( + children_nodes: Vec, + parent_plan: Arc, + ) -> Result { + let children_plans = children_nodes + .iter() + .map(|item| item.plan.clone()) + .collect(); + let ordering_onwards = children_nodes + .into_iter() + .enumerate() + .map(|(idx, item)| { + // `ordering_onwards` tree keeps track of executors that maintain + // ordering, (or that can maintain ordering with the replacement of + // its variant) + let plan = item.plan; + let ordering_onwards = item.ordering_onwards; + if plan.children().is_empty() { + // Plan has no children, there is nothing to propagate. 
+ None + } else if ordering_onwards[0].is_none() + && ((is_repartition(&plan) && !plan.maintains_input_order()[0]) + || (is_coalesce_partitions(&plan) + && plan.children()[0].output_ordering().is_some())) + { + Some(ExecTree::new(plan, idx, vec![])) + } else { + let children = ordering_onwards + .into_iter() + .flatten() + .filter(|item| { + // Only consider operators that maintains ordering + plan.maintains_input_order()[item.idx] + || is_coalesce_partitions(&plan) + || is_repartition(&plan) + }) + .collect::>(); + if children.is_empty() { + None + } else { + Some(ExecTree::new(plan, idx, children)) + } + } + }) + .collect(); + let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); + Ok(OrderPreservationContext { + plan, + ordering_onwards, + }) } - let mut children = plan.children(); - for (idx, child) in enumerate(plan.children()) { - if !is_repartition(&child) && !does_plan_maintain_input_order(&child) { - break; - } + /// Computes order-preservation contexts for every child of the plan. + pub fn children(&self) -> Vec { + self.plan + .children() + .into_iter() + .map(|child| OrderPreservationContext::new(child)) + .collect() + } +} - if let Some(repartition) = child.as_any().downcast_ref::() { - // Replace this `RepartitionExec` with a `SortPreservingRepartitionExec` - // if it doesn't preserve ordering and its input is unbounded. Doing - // so avoids breaking the pipeline. - if !repartition.maintains_input_order()[0] && unbounded_output(&child) { - let spr = sort_preserving_repartition(repartition)? - .with_new_children(repartition.children())?; - // Perform the replacement and recurse into this plan's children: - children[idx] = replace_sort_children(&spr)?; - continue; +impl TreeNode for OrderPreservationContext { + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in self.children() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), } } + Ok(VisitRecursion::Continue) + } - children[idx] = replace_sort_children(&child)?; + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if children.is_empty() { + Ok(self) + } else { + let children_nodes = children + .into_iter() + .map(transform) + .collect::>>()?; + OrderPreservationContext::new_from_children_nodes(children_nodes, self.plan) + } } +} - plan.clone().with_new_children(children) +/// Calculates the updated plan by replacing executors that lose ordering +/// inside the `ExecTree` with their order-preserving variants. This will +/// generate an alternative plan, which will be accepted or rejected later on +/// depending on whether it helps us remove a `SortExec`. 
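Taken in isolation, the two local rewrites that `get_updated_plan` (defined just below) applies look roughly like the helpers in this sketch. The helper names are invented; the constructor calls (`RepartitionExec::try_new(..).with_preserve_order()` and `SortPreservingMergeExec::new(..)`) mirror the ones in the patch itself.

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::ExecutionPlan;

/// Sketch: swap a plain `RepartitionExec` for its order-preserving variant,
/// keeping the same partitioning.
fn order_preserving_repartition(
    plan: &Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let child = plan.children()[0].clone();
    Ok(Arc::new(
        RepartitionExec::try_new(child, plan.output_partitioning())?
            .with_preserve_order(),
    ))
}

/// Sketch: swap a `CoalescePartitionsExec` whose input is ordered for a
/// `SortPreservingMergeExec` on that ordering.
fn sort_preserving_merge(plan: &Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    let child = plan.children()[0].clone();
    let ordering = child.output_ordering().unwrap_or(&[]).to_vec();
    Arc::new(SortPreservingMergeExec::new(ordering, child))
}
```

In `EnforceSorting` (later in this diff) the sub-rule is driven bottom-up via `transform_up` with `is_spr_better = false` and `is_spm_better = true`; for unbounded inputs both replacements are attempted regardless of these flags, and the alternative plan is adopted only if it renders the `SortExec` redundant.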
+fn get_updated_plan( + exec_tree: &ExecTree, + // Flag indicating that it is desirable to replace `RepartitionExec`s with + // `SortPreservingRepartitionExec`s: + is_spr_better: bool, + // Flag indicating that it is desirable to replace `CoalescePartitionsExec`s + // with `SortPreservingMergeExec`s: + is_spm_better: bool, +) -> Result> { + let plan = exec_tree.plan.clone(); + + let mut children = plan.children(); + // Update children and their descendants in the given tree: + for item in &exec_tree.children { + children[item.idx] = get_updated_plan(item, is_spr_better, is_spm_better)?; + } + // Construct the plan with updated children: + let mut plan = plan.with_new_children(children)?; + + // When a `RepartitionExec` doesn't preserve ordering, replace it with + // a `SortPreservingRepartitionExec` if appropriate: + if is_repartition(&plan) && !plan.maintains_input_order()[0] && is_spr_better { + let child = plan.children()[0].clone(); + plan = Arc::new( + RepartitionExec::try_new(child, plan.output_partitioning())? + .with_preserve_order(), + ) as _ + } + // When the input of a `CoalescePartitionsExec` has an ordering, replace it + // with a `SortPreservingMergeExec` if appropriate: + if is_coalesce_partitions(&plan) + && plan.children()[0].output_ordering().is_some() + && is_spm_better + { + let child = plan.children()[0].clone(); + plan = Arc::new(SortPreservingMergeExec::new( + child.output_ordering().unwrap_or(&[]).to_vec(), + child, + )) as _ + } + Ok(plan) } -/// The `replace_repartition_execs` optimizer sub-rule searches for `SortExec`s -/// and their `RepartitionExec` children with multiple input partitioning having -/// local (per-partition) ordering, so that it can replace the `RepartitionExec` -/// with a `SortPreservingRepartitionExec` and remove the pipeline-breaking `SortExec` -/// from the physical plan. +/// The `replace_with_order_preserving_variants` optimizer sub-rule tries to +/// remove `SortExec`s from the physical plan by replacing operators that do +/// not preserve ordering with their order-preserving variants; i.e. by replacing +/// `RepartitionExec`s with `SortPreservingRepartitionExec`s or by replacing +/// `CoalescePartitionsExec`s with `SortPreservingMergeExec`s. +/// +/// If this replacement is helpful for removing a `SortExec`, it updates the plan. +/// Otherwise, it leaves the plan unchanged. /// /// The algorithm flow is simply like this: -/// 1. Visit nodes of the physical plan top-down and look for `SortExec` nodes. -/// 2. If a `SortExec` is found, iterate over its children recursively until an -/// executor that doesn't maintain ordering is encountered (or until a leaf node). -/// `RepartitionExec`s with multiple input partitions are considered as if they -/// maintain input ordering because they are potentially replaceable with -/// `SortPreservingRepartitionExec`s which maintain ordering. -/// 3_1. Replace the `RepartitionExec`s with multiple input partitions (which doesn't -/// maintain ordering) with a `SortPreservingRepartitionExec`. -/// 3_2. Otherwise, keep the plan as is. -/// 4. Check if the `SortExec` is still necessary in the updated plan by comparing +/// 1. Visit nodes of the physical plan bottom-up and look for `SortExec` nodes. +/// 1_1. During the traversal, build an `ExecTree` to keep track of operators +/// that maintain ordering (or can maintain ordering when replaced by an +/// order-preserving variant) until a `SortExec` is found. +/// 2. 
When a `SortExec` is found, update the child of the `SortExec` by replacing +/// operators that do not preserve ordering in the `ExecTree` with their order +/// preserving variants. +/// 3. Check if the `SortExec` is still necessary in the updated plan by comparing /// its input ordering with the output ordering it imposes. We do this because -/// replacing `RepartitionExec`s with `SortPreservingRepartitionExec`s enables us -/// to preserve the previously lost ordering during `RepartitionExec`s. -/// 5_1. If the `SortExec` in question turns out to be unnecessary, remove it and use -/// updated plan. Otherwise, use the original plan. -/// 6. Continue the top-down iteration until another `SortExec` is seen, or the iterations finish. -pub fn replace_repartition_execs( - plan: Arc, -) -> Result>> { - if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let changed_plan = replace_sort_children(&plan)?; - // Since we have a `SortExec` here, it's guaranteed that it has a single child. - let input = &changed_plan.children()[0]; - // Check if any child is changed, if so remove the `SortExec`. If the ordering - // is being satisfied with the child, then it means `SortExec` is unnecessary. +/// replacing operators that lose ordering with their order-preserving variants +/// enables us to preserve the previously lost ordering at the input of `SortExec`. +/// 4. If the `SortExec` in question turns out to be unnecessary, remove it and use +/// updated plan. Otherwise, use the original plan. +/// 5. Continue the bottom-up traversal until another `SortExec` is seen, or the traversal +/// is complete. +pub(crate) fn replace_with_order_preserving_variants( + requirements: OrderPreservationContext, + // A flag indicating that replacing `RepartitionExec`s with + // `SortPreservingRepartitionExec`s is desirable when it helps + // to remove a `SortExec` from the plan. If this flag is `false`, + // this replacement should only be made to fix the pipeline (streaming). + is_spr_better: bool, + // A flag indicating that replacing `CoalescePartitionsExec`s with + // `SortPreservingMergeExec`s is desirable when it helps to remove + // a `SortExec` from the plan. If this flag is `false`, this replacement + // should only be made to fix the pipeline (streaming). + is_spm_better: bool, +) -> Result> { + let plan = &requirements.plan; + let ordering_onwards = &requirements.ordering_onwards; + if is_sort(plan) { + let exec_tree = if let Some(exec_tree) = &ordering_onwards[0] { + exec_tree + } else { + return Ok(Transformed::No(requirements)); + }; + // For unbounded cases, replace with the order-preserving variant in + // any case, as doing so helps fix the pipeline. 
+ let is_unbounded = unbounded_output(plan); + let updated_sort_input = get_updated_plan( + exec_tree, + is_spr_better || is_unbounded, + is_spm_better || is_unbounded, + )?; + // If this sort is unnecessary, we should remove it and update the plan: if ordering_satisfy( - input.output_ordering(), - sort_exec.output_ordering(), - || input.equivalence_properties(), - || input.ordering_equivalence_properties(), + updated_sort_input.output_ordering(), + plan.output_ordering(), + || updated_sort_input.equivalence_properties(), + || updated_sort_input.ordering_equivalence_properties(), ) { - Ok(Transformed::Yes(input.clone())) - } else { - Ok(Transformed::No(plan)) + return Ok(Transformed::Yes(OrderPreservationContext { + plan: updated_sort_input, + ordering_onwards: vec![None], + })); } - } else { - // We don't have anything to do until we get to the `SortExec` parent. - Ok(Transformed::No(plan)) } + + Ok(Transformed::No(requirements)) } #[cfg(test)] @@ -174,7 +293,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - /// Runs the `replace_repartition_execs` sub-rule and asserts the plan + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts the plan /// against the original and expected plans. /// /// `$EXPECTED_PLAN_LINES`: input plan @@ -197,7 +316,10 @@ mod tests { let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES.iter().map(|s| *s).collect(); // Run the rule top-down - let optimized_physical_plan = physical_plan.transform_down(&replace_repartition_execs)?; + // let optimized_physical_plan = physical_plan.transform_down(&replace_repartition_execs)?; + let plan_with_pipeline_fixer = OrderPreservationContext::new(physical_plan); + let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false))?; + let optimized_physical_plan = parallel.plan; // Get string representation of the plan let actual = get_plan_string(&optimized_physical_plan); @@ -276,11 +398,10 @@ mod tests { " FilterExec: c@2 > 3", " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + " SortPreservingMergeExec: [a@0 ASC]", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) @@ -516,11 +637,10 @@ mod tests { " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = vec![ - "SortExec: expr=[a@0 ASC NULLS LAST]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, 
d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) @@ -773,6 +893,8 @@ mod tests { }, true, 0, + b'"', + None, FileCompressionType::UNCOMPRESSED, )) } diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index c9da83d86b34..d8ee638e3be4 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -36,7 +36,9 @@ use crate::config::ConfigOptions; use crate::error::Result; -use crate::physical_optimizer::replace_repartition_execs::replace_repartition_execs; +use crate::physical_optimizer::replace_with_order_preserving_variants::{ + replace_with_order_preserving_variants, OrderPreservationContext, +}; use crate::physical_optimizer::sort_pushdown::{pushdown_sorts, SortPushDown}; use crate::physical_optimizer::utils::{ add_sort_above, find_indices, is_coalesce_partitions, is_limit, is_repartition, @@ -78,7 +80,7 @@ impl EnforceSorting { /// This object implements a tree that we use while keeping track of paths /// leading to [`SortExec`]s. #[derive(Debug, Clone)] -struct ExecTree { +pub(crate) struct ExecTree { /// The `ExecutionPlan` associated with this node pub plan: Arc, /// Child index of the plan in its parent @@ -367,11 +369,21 @@ impl PhysicalOptimizerRule for EnforceSorting { } else { adjusted.plan }; + let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); + let updated_plan = + plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { + replace_with_order_preserving_variants( + plan_with_pipeline_fixer, + false, + true, + ) + })?; + // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: - let sort_pushdown = SortPushDown::init(new_plan); + let sort_pushdown = SortPushDown::init(updated_plan.plan); let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; - adjusted.plan.transform_down(&replace_repartition_execs) + Ok(adjusted.plan) } fn name(&self) -> &str { @@ -985,7 +997,7 @@ mod tests { Linear, PartiallySorted, Sorted, }; use crate::physical_plan::{displayable, Partitioning}; - use crate::prelude::SessionContext; + use crate::prelude::{SessionConfig, SessionContext}; use crate::test::csv_exec_sorted; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -1386,10 +1398,12 @@ mod tests { /// `$EXPECTED_PLAN_LINES`: input plan /// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan /// `$PLAN`: the plan to optimized + /// `REPARTITION_SORTS`: Flag to set `config.options.optimizer.repartition_sorts` option. /// macro_rules! 
assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { - let session_ctx = SessionContext::new(); + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $REPARTITION_SORTS: expr) => { + let config = SessionConfig::new().with_repartition_sorts($REPARTITION_SORTS); + let session_ctx = SessionContext::with_config(config); let state = session_ctx.state(); let physical_plan = $PLAN; @@ -1437,7 +1451,7 @@ mod tests { "SortExec: expr=[nullable_col@0 ASC]", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1500,7 +1514,7 @@ mod tests { " SortExec: expr=[non_nullable_col@1 DESC]", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1521,7 +1535,7 @@ mod tests { "SortExec: expr=[nullable_col@0 ASC]", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1547,7 +1561,7 @@ mod tests { "SortExec: expr=[nullable_col@0 ASC]", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1586,7 +1600,7 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1630,7 +1644,7 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1687,7 +1701,7 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1718,7 +1732,7 @@ mod tests { " MemoryExec: partitions=0, partition_sizes=[]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1747,7 +1761,7 @@ mod tests { "SortExec: expr=[nullable_col@0 ASC]", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1793,7 +1807,7 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1816,7 
+1830,7 @@ mod tests { "SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1843,7 +1857,7 @@ mod tests { "SortExec: expr=[non_nullable_col@1 ASC]", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1870,7 +1884,7 @@ mod tests { ]; // should not add a sort at the output of the union, input plan should not be changed let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1901,7 +1915,7 @@ mod tests { ]; // should not add a sort at the output of the union, input plan should not be changed let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1941,7 +1955,7 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1986,7 +2000,7 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2031,7 +2045,7 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2081,7 +2095,7 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2133,7 +2147,7 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2171,7 +2185,7 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, expected_output, physical_plan); + assert_optimized!(expected_input, expected_output, physical_plan, true); Ok(()) } @@ -2223,7 +2237,7 @@ mod tests { " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, 
expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2274,7 +2288,7 @@ mod tests { " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2314,7 +2328,7 @@ mod tests { " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2369,7 +2383,7 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2442,7 +2456,7 @@ mod tests { ] } }; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); } Ok(()) } @@ -2518,7 +2532,7 @@ mod tests { ] } }; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); } Ok(()) } @@ -2562,7 +2576,7 @@ mod tests { " SortExec: expr=[col_a@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); // order by (nullable_col, col_b, col_a) let sort_exprs2 = vec![ @@ -2588,7 +2602,7 @@ mod tests { " SortExec: expr=[col_a@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2628,7 +2642,7 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2666,7 +2680,7 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2749,7 +2763,7 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", " MemoryExec: partitions=0, partition_sizes=[]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2780,7 +2794,7 @@ mod tests { " RepartitionExec: 
partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=false", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2810,7 +2824,37 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) + } + + #[tokio::test] + async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { + let schema = create_test_schema3()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec(source); + let repartition_hash = Arc::new(RepartitionExec::try_new( + repartition_rr, + Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + )?) as _; + let coalesce_partitions = coalesce_partitions_exec(repartition_hash); + let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); + + let expected_input = vec![ + "SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + ]; + let expected_optimized = vec![ + "SortPreservingMergeExec: [a@0 ASC]", + " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) } } diff --git a/datafusion/core/src/physical_plan/aggregates/order/partial.rs b/datafusion/core/src/physical_plan/aggregates/order/partial.rs index ac32c69fd568..019e61ef2688 100644 --- a/datafusion/core/src/physical_plan/aggregates/order/partial.rs +++ b/datafusion/core/src/physical_plan/aggregates/order/partial.rs @@ -108,9 +108,10 @@ impl GroupOrderingPartial { ordering: &[PhysicalSortExpr], ) -> Result { assert!(!order_indices.is_empty()); - assert_eq!(order_indices.len(), ordering.len()); + assert!(order_indices.len() <= ordering.len()); - let fields = ordering + // get only the section of ordering, that consist of group by expressions. 
+ let fields = ordering[0..order_indices.len()] .iter() .map(|sort_expr| { Ok(SortField::new_with_options( diff --git a/datafusion/core/src/physical_plan/projection.rs b/datafusion/core/src/physical_plan/projection.rs index dac5227503d9..5c4b66114328 100644 --- a/datafusion/core/src/physical_plan/projection.rs +++ b/datafusion/core/src/physical_plan/projection.rs @@ -97,7 +97,7 @@ impl ProjectionExec { // construct a map from the input columns to the output columns of the Projection let mut columns_map: HashMap> = HashMap::new(); - for (expression, name) in expr.iter() { + for (expr_idx, (expression, name)) in expr.iter().enumerate() { if let Some(column) = expression.as_any().downcast_ref::() { // For some executors, logical and physical plan schema fields // are not the same. The information in a `Column` comes from @@ -107,11 +107,10 @@ impl ProjectionExec { let idx = column.index(); let matching_input_field = input_schema.field(idx); let matching_input_column = Column::new(matching_input_field.name(), idx); - let new_col_idx = schema.index_of(name)?; let entry = columns_map .entry(matching_input_column) .or_insert_with(Vec::new); - entry.push(Column::new(name, new_col_idx)); + entry.push(Column::new(name, expr_idx)); }; } diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs index 0dad1d30dd18..e8d571631bab 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs @@ -146,6 +146,13 @@ impl ExecutionPlan for SortPreservingMergeExec { Partitioning::UnknownPartitioning(1) } + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns an error to indicate this. 
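+    /// `SortPreservingMergeExec` merges its sorted inputs incrementally, so it simply propagates whether its single child is unbounded.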
+ fn unbounded_output(&self, children: &[bool]) -> Result { + Ok(children[0]) + } + fn required_input_distribution(&self) -> Vec { vec![Distribution::UnspecifiedDistribution] } diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 9b44ac615c33..7d9d70f4771e 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -539,7 +539,7 @@ mod tests { let my_count = create_udaf( "my_count", - DataType::Int64, + vec![DataType::Int64], Arc::new(DataType::Int64), Volatility::Immutable, Arc::new(|_| Ok(Box::new(MyCount(0)))), diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 838c13f96856..6e2bdfeeca89 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -89,6 +89,8 @@ pub fn scan_partitioned_csv(partitions: usize) -> Result> { config, true, b',', + b'"', + None, FileCompressionType::UNCOMPRESSED, ))) } @@ -348,6 +350,8 @@ pub fn csv_exec_sorted( }, false, 0, + 0, + None, FileCompressionType::UNCOMPRESSED, )) } diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index 36872b7361bf..5b18d616b3f9 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -28,7 +28,7 @@ use datafusion::datasource::MemTable; use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::physical_optimizer::pipeline_fixer::PipelineFixer; +use datafusion::physical_optimizer::join_selection::JoinSelection; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::assert_contains; @@ -304,7 +304,7 @@ async fn run_streaming_test_with_config( // Disable all physical optimizer rules except the PipelineFixer rule to avoid sorts or // repartition, as they also have memory budgets that may be hit first let state = SessionState::with_config_rt(config, Arc::new(runtime)) - .with_physical_optimizer_rules(vec![Arc::new(PipelineFixer::new())]); + .with_physical_optimizer_rules(vec![Arc::new(JoinSelection::new())]); // Create a new session context with the session state let ctx = SessionContext::with_state(state); diff --git a/datafusion/core/tests/sql/csv_files.rs b/datafusion/core/tests/sql/csv_files.rs new file mode 100644 index 000000000000..5ed0068d6135 --- /dev/null +++ b/datafusion/core/tests/sql/csv_files.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
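The new `csv_files.rs` tests below exercise the `quote` and `escape` settings that this diff threads through `CsvReadOptions`. For orientation, a minimal sketch of the same options used outside the test harness; the file path is hypothetical and the calls mirror what the tests themselves do:

use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Read a headerless file whose fields are wrapped in `~` instead of `"`,
    // as in the csv_custom_quote test below.
    let df = ctx
        .read_csv(
            "./tilde_quoted.csv", // hypothetical path
            CsvReadOptions::new().has_header(false).quote(b'~'),
        )
        .await?;
    df.show().await?;
    Ok(())
}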
+ +use super::*; + +#[tokio::test] +async fn csv_custom_quote() -> Result<()> { + let tmp_dir = TempDir::new()?; + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Utf8, false), + ])); + let filename = format!("partition.{}", "csv"); + let file_path = tmp_dir.path().join(filename); + let mut file = File::create(file_path)?; + + // generate some data + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value{index:}"); + let data = format!("~{text1}~,~{text2}~\r\n"); + file.write_all(data.as_bytes())?; + } + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new() + .schema(&schema) + .has_header(false) + .quote(b'~'), + ) + .await?; + + let results = plan_and_collect(&ctx, "SELECT * from test").await?; + + let expected = vec![ + "+-----+--------+", + "| c1 | c2 |", + "+-----+--------+", + "| id0 | value0 |", + "| id1 | value1 |", + "| id2 | value2 |", + "| id3 | value3 |", + "| id4 | value4 |", + "| id5 | value5 |", + "| id6 | value6 |", + "| id7 | value7 |", + "| id8 | value8 |", + "| id9 | value9 |", + "+-----+--------+", + ]; + + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + +#[tokio::test] +async fn csv_custom_escape() -> Result<()> { + let tmp_dir = TempDir::new()?; + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Utf8, false), + ])); + let filename = format!("partition.{}", "csv"); + let file_path = tmp_dir.path().join(filename); + let mut file = File::create(file_path)?; + + // generate some data + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value\\\"{index:}"); + let data = format!("\"{text1}\",\"{text2}\"\r\n"); + file.write_all(data.as_bytes())?; + } + + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new() + .schema(&schema) + .has_header(false) + .escape(b'\\'), + ) + .await?; + + let results = plan_and_collect(&ctx, "SELECT * from test").await?; + + let expected = vec![ + "+-----+---------+", + "| c1 | c2 |", + "+-----+---------+", + "| id0 | value\"0 |", + "| id1 | value\"1 |", + "| id2 | value\"2 |", + "| id3 | value\"3 |", + "| id4 | value\"4 |", + "| id5 | value\"5 |", + "| id6 | value\"6 |", + "| id7 | value\"7 |", + "| id8 | value\"8 |", + "| id9 | value\"9 |", + "+-----+---------+", + ]; + + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} diff --git a/datafusion/core/tests/sql/information_schema.rs b/datafusion/core/tests/sql/information_schema.rs deleted file mode 100644 index 1cb518099174..000000000000 --- a/datafusion/core/tests/sql/information_schema.rs +++ /dev/null @@ -1,220 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -use async_trait::async_trait; -use datafusion::execution::context::SessionState; -use datafusion::{ - catalog::{ - schema::{MemorySchemaProvider, SchemaProvider}, - CatalogProvider, MemoryCatalogProvider, - }, - datasource::{TableProvider, TableType}, -}; -use datafusion_expr::Expr; - -use super::*; - -#[tokio::test] -async fn information_schema_tables_tables_with_multiple_catalogs() { - let ctx = - SessionContext::with_config(SessionConfig::new().with_information_schema(true)); - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - schema - .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap()) - .unwrap(); - schema - .register_table("t2".to_owned(), table_with_sequence(1, 1).unwrap()) - .unwrap(); - catalog - .register_schema("my_schema", Arc::new(schema)) - .unwrap(); - ctx.register_catalog("my_catalog", Arc::new(catalog)); - - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - schema - .register_table("t3".to_owned(), table_with_sequence(1, 1).unwrap()) - .unwrap(); - catalog - .register_schema("my_other_schema", Arc::new(schema)) - .unwrap(); - ctx.register_catalog("my_other_catalog", Arc::new(catalog)); - - let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+------------------+--------------------+-------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+------------------+--------------------+-------------+------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | df_settings | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | views | VIEW |", - "| my_catalog | information_schema | columns | VIEW |", - "| my_catalog | information_schema | df_settings | VIEW |", - "| my_catalog | information_schema | tables | VIEW |", - "| my_catalog | information_schema | views | VIEW |", - "| my_catalog | my_schema | t1 | BASE TABLE |", - "| my_catalog | my_schema | t2 | BASE TABLE |", - "| my_other_catalog | information_schema | columns | VIEW |", - "| my_other_catalog | information_schema | df_settings | VIEW |", - "| my_other_catalog | information_schema | tables | VIEW |", - "| my_other_catalog | information_schema | views | VIEW |", - "| my_other_catalog | my_other_schema | t3 | BASE TABLE |", - "+------------------+--------------------+-------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); -} - -#[tokio::test] -async fn information_schema_tables_table_types() { - struct TestTable(TableType); - - #[async_trait] - impl TableProvider for TestTable { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn table_type(&self) -> TableType { - self.0 - } - - fn schema(&self) -> SchemaRef { - unimplemented!() - } - - async fn scan( - &self, - _state: &SessionState, - _: Option<&Vec>, - _: &[Expr], - _: Option, - ) -> Result> { - unimplemented!() - } - } - - let ctx = - SessionContext::with_config(SessionConfig::new().with_information_schema(true)); - - ctx.register_table("physical", Arc::new(TestTable(TableType::Base))) - .unwrap(); - ctx.register_table("query", Arc::new(TestTable(TableType::View))) - .unwrap(); - ctx.register_table("temp", Arc::new(TestTable(TableType::Temporary))) - .unwrap(); - - let result = plan_and_collect(&ctx, "SELECT * from 
information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------------+-------------+-----------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+-------------+-----------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | df_settings | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | views | VIEW |", - "| datafusion | public | physical | BASE TABLE |", - "| datafusion | public | query | VIEW |", - "| datafusion | public | temp | LOCAL TEMPORARY |", - "+---------------+--------------------+-------------+-----------------+", - ]; - assert_batches_sorted_eq!(expected, &result); -} - -fn table_with_many_types() -> Arc { - let schema = Schema::new(vec![ - Field::new("int32_col", DataType::Int32, false), - Field::new("float64_col", DataType::Float64, true), - Field::new("utf8_col", DataType::Utf8, true), - Field::new("large_utf8_col", DataType::LargeUtf8, false), - Field::new("binary_col", DataType::Binary, false), - Field::new("large_binary_col", DataType::LargeBinary, false), - Field::new( - "timestamp_nanos", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Float64Array::from(vec![1.0])), - Arc::new(StringArray::from(vec![Some("foo")])), - Arc::new(LargeStringArray::from(vec![Some("bar")])), - Arc::new(BinaryArray::from(vec![b"foo" as &[u8]])), - Arc::new(LargeBinaryArray::from(vec![b"foo" as &[u8]])), - Arc::new(TimestampNanosecondArray::from(vec![Some(123)])), - ], - ) - .unwrap(); - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); - Arc::new(provider) -} - -#[tokio::test] -async fn information_schema_columns() { - let ctx = - SessionContext::with_config(SessionConfig::new().with_information_schema(true)); - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - - schema - .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap()) - .unwrap(); - - schema - .register_table("t2".to_owned(), table_with_many_types()) - .unwrap(); - catalog - .register_schema("my_schema", Arc::new(schema)) - .unwrap(); - ctx.register_catalog("my_catalog", Arc::new(catalog)); - - let result = plan_and_collect(&ctx, "SELECT * from information_schema.columns") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |", - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| my_catalog | my_schema | t1 | i | 0 | | YES | Int32 | | | 32 | 2 | | | |", - "| my_catalog | my_schema | t2 
| binary_col | 4 | | NO | Binary | | 2147483647 | | | | | |", - "| my_catalog | my_schema | t2 | float64_col | 1 | | YES | Float64 | | | 24 | 2 | | | |", - "| my_catalog | my_schema | t2 | int32_col | 0 | | NO | Int32 | | | 32 | 2 | | | |", - "| my_catalog | my_schema | t2 | large_binary_col | 5 | | NO | LargeBinary | | 9223372036854775807 | | | | | |", - "| my_catalog | my_schema | t2 | large_utf8_col | 3 | | NO | LargeUtf8 | | 9223372036854775807 | | | | | |", - "| my_catalog | my_schema | t2 | timestamp_nanos | 6 | | NO | Timestamp(Nanosecond, None) | | | | | | | |", - "| my_catalog | my_schema | t2 | utf8_col | 2 | | YES | Utf8 | | 2147483647 | | | | | |", - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &result); -} - -/// Execute SQL and return results -async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { - ctx.sql(sql).await?.collect().await -} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index ca0cfc3dbeff..85a806428548 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -80,10 +80,10 @@ pub mod aggregates; pub mod arrow_files; #[cfg(feature = "avro")] pub mod create_drop; +pub mod csv_files; pub mod explain_analyze; pub mod expr; pub mod group_by; -pub mod information_schema; pub mod joins; pub mod limit; pub mod order; diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/sql/udf.rs index 0ecd5d0fde86..2907d468066e 100644 --- a/datafusion/core/tests/sql/udf.rs +++ b/datafusion/core/tests/sql/udf.rs @@ -234,7 +234,7 @@ async fn simple_udaf() -> Result<()> { // define a udaf, using a DataFusion's accumulator let my_avg = create_udaf( "my_avg", - DataType::Float64, + vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, Arc::new(|_| { @@ -291,7 +291,7 @@ async fn udaf_as_window_func() -> Result<()> { let my_acc = create_udaf( "my_acc", - DataType::Int32, + vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, Arc::new(|_| Ok(Box::new(MyAccumulator))), diff --git a/datafusion/core/tests/sqllogictests/src/main.rs b/datafusion/core/tests/sqllogictests/src/main.rs index 58089be24509..c74d1cb11a47 100644 --- a/datafusion/core/tests/sqllogictests/src/main.rs +++ b/datafusion/core/tests/sqllogictests/src/main.rs @@ -271,6 +271,14 @@ async fn context_for_test_file(relative_path: &Path) -> Option { info!("Registering scalar tables"); setup::register_scalar_tables(test_ctx.session_ctx()).await; } + "information_schema_table_types.slt" => { + info!("Registering local temporary table"); + setup::register_temp_table(test_ctx.session_ctx()).await; + } + "information_schema_columns.slt" => { + info!("Registering table with many types"); + setup::register_table_with_many_types(test_ctx.session_ctx()).await; + } "avro.slt" => { #[cfg(feature = "avro")] { diff --git a/datafusion/core/tests/sqllogictests/src/setup.rs b/datafusion/core/tests/sqllogictests/src/setup.rs index 34365f509a53..32569c7575ce 100644 --- a/datafusion/core/tests/sqllogictests/src/setup.rs +++ b/datafusion/core/tests/sqllogictests/src/setup.rs @@ -15,14 +15,25 @@ // specific language governing permissions and limitations // under the License. 
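The `create_udaf` call sites updated above now pass a `Vec<DataType>` of argument types. Below is a self-contained sketch of registering a UDAF against the new signature; `CountingAccumulator`, the query, and the exact import paths are illustrative assumptions rather than part of this PR:

use std::sync::Arc;

use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::datatypes::DataType;
use datafusion::common::cast::as_int64_array;
use datafusion::error::Result;
use datafusion::logical_expr::{create_udaf, Accumulator, Volatility};
use datafusion::prelude::SessionContext;
use datafusion::scalar::ScalarValue;

// Toy accumulator that counts the non-null rows of its single argument
// (built-in COUNT already does this; it only exists to keep the sketch short).
#[derive(Debug, Default)]
struct CountingAccumulator {
    count: i64,
}

impl Accumulator for CountingAccumulator {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::Int64(Some(self.count))])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = &values[0];
        self.count += (arr.len() - arr.null_count()) as i64;
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Each partial state is a single Int64 column of partial counts.
        self.count += as_int64_array(&states[0])?.iter().flatten().sum::<i64>();
        Ok(())
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.count)))
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let my_count = create_udaf(
        "my_count",
        // argument types are now a Vec, matching the signature change in this diff
        vec![DataType::Int64],
        Arc::new(DataType::Int64),
        Volatility::Immutable,
        Arc::new(|_| Ok(Box::new(CountingAccumulator::default()))),
        // one state field per entry returned by `state()`
        Arc::new(vec![DataType::Int64]),
    );
    ctx.register_udaf(my_count);
    let df = ctx
        .sql("SELECT my_count(column1) FROM (VALUES (1), (2), (3)) AS t")
        .await?;
    df.show().await?;
    Ok(())
}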
+use arrow_schema::TimeUnit; +use async_trait::async_trait; +use datafusion::execution::context::SessionState; +use datafusion::physical_plan::ExecutionPlan; use datafusion::{ arrow::{ - array::Float64Array, - datatypes::{DataType, Field, Schema}, + array::{ + BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampNanosecondArray, + }, + datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }, + catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider}, + datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; +use datafusion_common::DataFusionError; +use datafusion_expr::Expr; use std::fs::File; use std::io::Write; use std::sync::Arc; @@ -116,3 +127,84 @@ pub async fn register_partition_table(test_ctx: &mut TestContext) { .await .unwrap(); } + +// registers a LOCAL TEMPORARY table. +pub async fn register_temp_table(ctx: &SessionContext) { + struct TestTable(TableType); + + #[async_trait] + impl TableProvider for TestTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn table_type(&self) -> TableType { + self.0 + } + + fn schema(&self) -> SchemaRef { + unimplemented!() + } + + async fn scan( + &self, + _state: &SessionState, + _: Option<&Vec>, + _: &[Expr], + _: Option, + ) -> Result, DataFusionError> { + unimplemented!() + } + } + + ctx.register_table( + "datafusion.public.temp", + Arc::new(TestTable(TableType::Temporary)), + ) + .unwrap(); +} + +pub async fn register_table_with_many_types(ctx: &SessionContext) { + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + + catalog + .register_schema("my_schema", Arc::new(schema)) + .unwrap(); + ctx.register_catalog("my_catalog", Arc::new(catalog)); + + ctx.register_table("my_catalog.my_schema.t2", table_with_many_types()) + .unwrap(); +} + +fn table_with_many_types() -> Arc { + let schema = Schema::new(vec![ + Field::new("int32_col", DataType::Int32, false), + Field::new("float64_col", DataType::Float64, true), + Field::new("utf8_col", DataType::Utf8, true), + Field::new("large_utf8_col", DataType::LargeUtf8, false), + Field::new("binary_col", DataType::Binary, false), + Field::new("large_binary_col", DataType::LargeBinary, false), + Field::new( + "timestamp_nanos", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Float64Array::from(vec![1.0])), + Arc::new(StringArray::from(vec![Some("foo")])), + Arc::new(LargeStringArray::from(vec![Some("bar")])), + Arc::new(BinaryArray::from(vec![b"foo" as &[u8]])), + Arc::new(LargeBinaryArray::from(vec![b"foo" as &[u8]])), + Arc::new(TimestampNanosecondArray::from(vec![Some(123)])), + ], + ) + .unwrap(); + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); + Arc::new(provider) +} diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt index 4a3d39bdebcf..5c82c7e0091d 100644 --- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt +++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt @@ -180,23 +180,29 @@ drop table foo statement ok create table foo as select - arrow_cast(100, 'Decimal128(5,2)') as col_d128 - -- Can't make a decimal 156: - -- This feature is not implemented: Can't create a scalar from array of type "Decimal256(3, 2)" - 
--arrow_cast(100, 'Decimal256(5,2)') as col_d256 + arrow_cast(100, 'Decimal128(5,2)') as col_d128, + arrow_cast(100, 'Decimal256(5,2)') as col_d256 ; ## Ensure each column in the table has the expected type -query T +query TT SELECT - arrow_typeof(col_d128) - -- arrow_typeof(col_d256), + arrow_typeof(col_d128), + arrow_typeof(col_d256) FROM foo; ---- -Decimal128(5, 2) +Decimal128(5, 2) Decimal256(5, 2) + +query RR +SELECT + col_d128, + col_d256 + FROM foo; +---- +100 100.00 statement ok drop table foo diff --git a/datafusion/core/tests/sqllogictests/test_files/binary.slt b/datafusion/core/tests/sqllogictests/test_files/binary.slt index ca55ff56cb1c..54499e29787b 100644 --- a/datafusion/core/tests/sqllogictests/test_files/binary.slt +++ b/datafusion/core/tests/sqllogictests/test_files/binary.slt @@ -45,6 +45,23 @@ FF01 ff01 Utf8 Binary ABC 0abc Utf8 Binary 000 0000 Utf8 Binary +# comparisons +query ?BBBB +SELECT + column2, + -- binary compare with string + column2 = 'ABC', + column2 <> 'ABC', + -- binary compared with binary + column2 = X'ABC', + column2 <> X'ABC' +FROM t; +---- +ff01 false true false true +0abc false true true false +0000 false true false true + + # predicates query T? SELECT column1, column2 @@ -127,3 +144,9 @@ SELECT column1, column1 = arrow_cast(X'0102', 'FixedSizeBinary(2)') FROM t # Comparison to different sized Binary query error DataFusion error: Error during planning: Cannot infer common argument type for comparison operation FixedSizeBinary\(3\) = Binary SELECT column1, column1 = X'0102' FROM t + +statement ok +drop table t_source + +statement ok +drop table t diff --git a/datafusion/core/tests/sqllogictests/test_files/decimal.slt b/datafusion/core/tests/sqllogictests/test_files/decimal.slt index f41351774172..8fd08f87c849 100644 --- a/datafusion/core/tests/sqllogictests/test_files/decimal.slt +++ b/datafusion/core/tests/sqllogictests/test_files/decimal.slt @@ -612,3 +612,12 @@ insert into foo VALUES (1, 5); query error DataFusion error: Arrow error: Compute error: Overflow happened on: 100000000000000000000 \* 100000000000000000000000000000000000000 select a / b from foo; + +statement ok +create table t as values (arrow_cast(123, 'Decimal256(5,2)')); + +query error DataFusion error: Internal error: Operator \+ is not implemented for types Decimal256\(None,15,2\) and Decimal256\(Some\(12300\),15,2\)\. 
This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +select AVG(column1) from t; + +statement ok +drop table t; diff --git a/datafusion/core/tests/sqllogictests/test_files/explain.slt b/datafusion/core/tests/sqllogictests/test_files/explain.slt index 56f2fdf10a8d..bd3513550a4e 100644 --- a/datafusion/core/tests/sqllogictests/test_files/explain.slt +++ b/datafusion/core/tests/sqllogictests/test_files/explain.slt @@ -240,7 +240,6 @@ logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE -physical_plan after PipelineFixer SAME TEXT AS ABOVE physical_plan after repartition SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE diff --git a/datafusion/core/tests/sqllogictests/test_files/groupby.slt b/datafusion/core/tests/sqllogictests/test_files/groupby.slt index dae48a9464ae..de57956f0ea8 100644 --- a/datafusion/core/tests/sqllogictests/test_files/groupby.slt +++ b/datafusion/core/tests/sqllogictests/test_files/groupby.slt @@ -1922,7 +1922,59 @@ SELECT DISTINCT + col1 FROM tab2 AS cor0 GROUP BY cor0.col1 59 61 - +# the query below should run successfully with multiple partitions. +query II +SELECT l.col0, LAST_VALUE(r.col1 ORDER BY r.col0) as last_col1 +FROM tab0 as l +JOIN tab0 as r +ON l.col0 = r.col0 +GROUP BY l.col0, l.col1, l.col2 +ORDER BY l.col0; +---- +26 0 +43 81 +83 0 + +# assert that the above query indeed runs with multiple partitions: +# physical plan for this query should contain RepartitionExecs. +# Aggregation should be in two stages, Partial + FinalPartitioned stages.
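The comments above describe the intended plan shape. A rough Rust sketch that reproduces the same EXPLAIN outside sqllogictest follows; the in-memory `tab0` rows are made up for illustration (the real fixture lives elsewhere in the test suite), and several target partitions are requested so the Partial/FinalPartitioned split appears:

use datafusion::error::Result;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // Multiple target partitions are what make the optimizer insert
    // RepartitionExec and split the aggregate into two stages.
    let config = SessionConfig::new().with_target_partitions(4);
    let ctx = SessionContext::with_config(config);
    ctx.sql("CREATE TABLE tab0(col0 INT, col1 INT, col2 INT) AS VALUES (26, 0, 1), (43, 81, 24), (83, 0, 38)")
        .await?
        .collect()
        .await?;
    let df = ctx
        .sql(
            "EXPLAIN SELECT l.col0, LAST_VALUE(r.col1 ORDER BY r.col0) AS last_col1 \
             FROM tab0 AS l JOIN tab0 AS r ON l.col0 = r.col0 \
             GROUP BY l.col0, l.col1, l.col2 ORDER BY l.col0",
        )
        .await?;
    df.show().await?;
    Ok(())
}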
+query TT +EXPLAIN SELECT l.col0, LAST_VALUE(r.col1 ORDER BY r.col0) as last_col1 +FROM tab0 as l +JOIN tab0 as r +ON l.col0 = r.col0 +GROUP BY l.col0, l.col1, l.col2 +ORDER BY l.col0; +---- +logical_plan +Sort: l.col0 ASC NULLS LAST +--Projection: l.col0, LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST] AS last_col1 +----Aggregate: groupBy=[[l.col0, l.col1, l.col2]], aggr=[[LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]]] +------Inner Join: l.col0 = r.col0 +--------SubqueryAlias: l +----------TableScan: tab0 projection=[col0, col1, col2] +--------SubqueryAlias: r +----------TableScan: tab0 projection=[col0, col1] +physical_plan +SortPreservingMergeExec: [col0@0 ASC NULLS LAST] +--SortExec: expr=[col0@0 ASC NULLS LAST] +----ProjectionExec: expr=[col0@0 as col0, LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]@3 as last_col1] +------AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] +--------SortExec: expr=[col0@3 ASC NULLS LAST] +----------CoalesceBatchesExec: target_batch_size=8192 +------------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallyOrdered +----------------SortExec: expr=[col0@3 ASC NULLS LAST] +------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] +----------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=4 +--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------------------MemoryExec: partitions=1, partition_sizes=[3] +----------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=4 +--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------------------MemoryExec: partitions=1, partition_sizes=[3] # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. 
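The `ordering_mode=PartiallyOrdered` annotation in the plan above is driven by the `GroupOrderingPartial` change earlier in this diff, which now accepts an input ordering longer than the ordered group-by prefix and converts only that prefix. A stripped-down, hypothetical sketch of the slicing, using plain strings in place of `PhysicalSortExpr`:

// Simplified stand-in for GroupOrderingPartial::try_new's prefix handling:
// `order_indices` marks the group-by columns covered by the input ordering,
// so only that prefix of `ordering` is turned into sort fields.
fn sort_prefix<T: Clone>(ordering: &[T], order_indices: &[usize]) -> Vec<T> {
    assert!(!order_indices.is_empty());
    // The ordering may extend past the ordered group-by prefix.
    assert!(order_indices.len() <= ordering.len());
    ordering[0..order_indices.len()].to_vec()
}

fn main() {
    let ordering = vec!["col0 ASC", "col1 ASC", "col2 ASC"];
    // Only the first group-by expression is known to be ordered.
    let order_indices = vec![0];
    assert_eq!(sort_prefix(&ordering, &order_indices), vec!["col0 ASC"]);
}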
@@ -2568,6 +2620,52 @@ TUR 100 75 175 GRC 80 30 110 FRA 200 50 250 +query TT +EXPLAIN SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate +FROM sales_global AS s +JOIN sales_global AS e + ON s.currency = e.currency AND + s.ts >= e.ts +GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency +ORDER BY s.sn +---- +logical_plan +Sort: s.sn ASC NULLS LAST +--Projection: s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST] AS last_rate +----Aggregate: groupBy=[[s.sn, s.zip_code, s.country, s.ts, s.currency]], aggr=[[LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]]] +------Projection: s.zip_code, s.country, s.sn, s.ts, s.currency, e.sn, e.amount +--------Inner Join: s.currency = e.currency Filter: s.ts >= e.ts +----------SubqueryAlias: s +------------TableScan: sales_global projection=[zip_code, country, sn, ts, currency] +----------SubqueryAlias: e +------------TableScan: sales_global projection=[sn, ts, currency, amount] +physical_plan +SortExec: expr=[sn@2 ASC NULLS LAST] +--ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] +----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[LAST_VALUE(e.amount)] +------SortExec: expr=[sn@5 ASC NULLS LAST] +--------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, sn@5 as sn, amount@8 as amount] +----------CoalesceBatchesExec: target_batch_size=8192 +------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@4, currency@2)], filter=ts@0 >= ts@1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--------------MemoryExec: partitions=1, partition_sizes=[1] + +query ITIPTR +SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate +FROM sales_global AS s +JOIN sales_global AS e + ON s.currency = e.currency AND + s.ts >= e.ts +GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency +ORDER BY s.sn +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 +1 FRA 1 2022-01-01T08:00:00 EUR 50 +1 TUR 2 2022-01-01T11:30:00 TRY 75 +1 FRA 3 2022-01-02T12:00:00 EUR 200 +0 GRC 4 2022-01-03T10:00:00 EUR 80 +1 TUR 4 2022-01-03T10:00:00 TRY 100 + # Run order-sensitive aggregators in multiple partitions statement ok set datafusion.execution.target_partitions = 8; @@ -2847,3 +2945,19 @@ SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, FRA [200.0, 50.0] 50 50 GRC [80.0, 30.0] 30 30 TUR [100.0, 75.0] 75 75 + +query ITIPTR +SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate +FROM sales_global AS s +JOIN sales_global AS e + ON s.currency = e.currency AND + s.ts >= e.ts +GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency +ORDER BY s.sn +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 +1 FRA 1 2022-01-01T08:00:00 EUR 50 +1 TUR 2 2022-01-01T11:30:00 TRY 75 +1 FRA 3 2022-01-02T12:00:00 EUR 200 +0 GRC 4 2022-01-03T10:00:00 EUR 80 +1 TUR 4 2022-01-03T10:00:00 TRY 100 diff --git a/datafusion/core/tests/sqllogictests/test_files/information_schema_columns.slt b/datafusion/core/tests/sqllogictests/test_files/information_schema_columns.slt new file mode 100644 index 000000000000..fcb653cedd16 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/information_schema_columns.slt @@ -0,0 
+1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +set datafusion.catalog.information_schema = true; + +statement ok +set datafusion.catalog.default_catalog = my_catalog; + +statement ok +set datafusion.catalog.default_schema = my_schema; + +########### +# Information schema columns +########### + +statement ok +CREATE TABLE t1 (i int) as values(1); + +# table t2 is created using rust code because it is not possible to set nullable columns with `arrow_cast` syntax + +query TTTTITTTIIIIIIT rowsort +SELECT * from information_schema.columns; +---- +my_catalog my_schema t1 i 0 NULL YES Int32 NULL NULL 32 2 NULL NULL NULL +my_catalog my_schema t2 binary_col 4 NULL NO Binary NULL 2147483647 NULL NULL NULL NULL NULL +my_catalog my_schema t2 float64_col 1 NULL YES Float64 NULL NULL 24 2 NULL NULL NULL +my_catalog my_schema t2 int32_col 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL +my_catalog my_schema t2 large_binary_col 5 NULL NO LargeBinary NULL 9223372036854775807 NULL NULL NULL NULL NULL +my_catalog my_schema t2 large_utf8_col 3 NULL NO LargeUtf8 NULL 9223372036854775807 NULL NULL NULL NULL NULL +my_catalog my_schema t2 timestamp_nanos 6 NULL NO Timestamp(Nanosecond, None) NULL NULL NULL NULL NULL NULL NULL +my_catalog my_schema t2 utf8_col 2 NULL YES Utf8 NULL 2147483647 NULL NULL NULL NULL NULL + +# Cleanup +statement ok +drop table t1 + +statement ok +drop table t2 \ No newline at end of file diff --git a/datafusion/core/tests/sqllogictests/test_files/information_schema_multiple_catalogs.slt b/datafusion/core/tests/sqllogictests/test_files/information_schema_multiple_catalogs.slt new file mode 100644 index 000000000000..c7f4dcfd54d8 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/information_schema_multiple_catalogs.slt @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
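These .slt files drive everything through `SET datafusion.catalog.*` statements; for reference, the same information_schema support can be switched on when building the context in Rust. A minimal sketch, with the `t1` table mirroring the one created in the tests:

use datafusion::error::Result;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // Equivalent to `set datafusion.catalog.information_schema = true;`
    let config = SessionConfig::new().with_information_schema(true);
    let ctx = SessionContext::with_config(config);
    ctx.sql("CREATE TABLE t1(i INT) AS VALUES (1)")
        .await?
        .collect()
        .await?;
    ctx.sql("SELECT table_catalog, table_schema, table_name, table_type FROM information_schema.tables")
        .await?
        .show()
        .await?;
    Ok(())
}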
+ + +# Verify the information schema does not exit by default +statement error DataFusion error: Error during planning: table 'datafusion.information_schema.tables' not found +SELECT * from information_schema.tables + +statement error DataFusion error: Error during planning: SHOW \[VARIABLE\] is not supported unless information_schema is enabled +show all + +# Turn it on + +# expect that the queries now work +statement ok +set datafusion.catalog.information_schema = true; + +# Verify the information schema now does exist and is empty +query TTTT rowsort +SELECT * from information_schema.tables; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + +############ +# Create multiple catalogs +########### +statement ok +create database my_catalog; + +statement ok +create schema my_catalog.my_schema; + +statement ok +set datafusion.catalog.default_catalog = my_catalog; + +statement ok +set datafusion.catalog.default_schema = my_schema; + +statement ok +create table t1 as values(1); + +statement ok +create table t2 as values(1); + +statement ok +create database my_other_catalog; + +statement ok +create schema my_other_catalog.my_other_schema; + +statement ok +set datafusion.catalog.default_catalog = my_other_catalog; + +statement ok +set datafusion.catalog.default_schema = my_other_schema; + +statement ok +create table t3 as values(1); + +query TTTT rowsort +SELECT * from information_schema.tables; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW +my_catalog information_schema columns VIEW +my_catalog information_schema df_settings VIEW +my_catalog information_schema tables VIEW +my_catalog information_schema views VIEW +my_catalog my_schema t1 BASE TABLE +my_catalog my_schema t2 BASE TABLE +my_other_catalog information_schema columns VIEW +my_other_catalog information_schema df_settings VIEW +my_other_catalog information_schema tables VIEW +my_other_catalog information_schema views VIEW +my_other_catalog my_other_schema t3 BASE TABLE + +# Cleanup + +statement ok +drop table t3 + +statement ok +set datafusion.catalog.default_catalog = my_catalog; + +statement ok +set datafusion.catalog.default_schema = my_schema; + +statement ok +drop table t1 + +statement ok +drop table t2 diff --git a/datafusion/core/tests/sqllogictests/test_files/information_schema_table_types.slt b/datafusion/core/tests/sqllogictests/test_files/information_schema_table_types.slt new file mode 100644 index 000000000000..eb72f3399fe7 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/information_schema_table_types.slt @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Turn it on +statement ok +set datafusion.catalog.information_schema = true; + +############ +# Table with many types +############ + +statement ok +create table physical as values(1); + +statement ok +create view query as values(1); + +# Temporary tables cannot be created using SQL syntax so it is done using Rust code. + +query TTTT rowsort +SELECT * from information_schema.tables; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW +datafusion public physical BASE TABLE +datafusion public query VIEW +datafusion public temp LOCAL TEMPORARY + +# Cleanup + +statement ok +drop table physical + +statement ok +drop view query diff --git a/datafusion/core/tests/sqllogictests/test_files/insert.slt b/datafusion/core/tests/sqllogictests/test_files/insert.slt index faa519834c6f..9f4122ac5ba9 100644 --- a/datafusion/core/tests/sqllogictests/test_files/insert.slt +++ b/datafusion/core/tests/sqllogictests/test_files/insert.slt @@ -299,4 +299,4 @@ select * from table_without_values; 2 NULL statement ok -drop table table_without_values; \ No newline at end of file +drop table table_without_values; diff --git a/datafusion/core/tests/sqllogictests/test_files/interval.slt b/datafusion/core/tests/sqllogictests/test_files/interval.slt index 1016cb155e4e..043f63958d1b 100644 --- a/datafusion/core/tests/sqllogictests/test_files/interval.slt +++ b/datafusion/core/tests/sqllogictests/test_files/interval.slt @@ -430,15 +430,11 @@ select '1 month'::interval + '1980-01-01T12:00:00'::timestamp; ---- 1980-02-01T12:00:00 -query D +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Date32 to valid types select '1 month'::interval - '1980-01-01'::date; ----- -1979-12-01 -query P +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types select '1 month'::interval - '1980-01-01T12:00:00'::timestamp; ----- -1979-12-01T12:00:00 # interval (array) + date / timestamp (array) query D @@ -456,19 +452,11 @@ select i + ts from t; 2000-02-01T00:01:00 # expected error interval (array) - date / timestamp (array) -query D +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Date32 to valid types select i - d from t; ----- -1979-12-01 -1990-09-30 -1980-01-02 -query P +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types select i - ts from t; ----- -1999-12-01T00:00:00 -1999-12-31T12:11:10 -2000-01-31T23:59:00 # interval (scalar) + date / timestamp (array) @@ -487,19 +475,11 @@ select '1 month'::interval + ts from t; 2000-03-01T00:00:00 # expected error interval (scalar) - date / timestamp (array) -query D +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Date32 to valid types select '1 month'::interval - d from t; ----- -1979-12-01 -1990-09-01 -1979-12-02 -query P +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types select '1 month'::interval - ts from t; ----- -1999-12-01T00:00:00 
-1999-12-01T12:11:10 -2000-01-01T00:00:00 # interval + date query D diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt b/datafusion/core/tests/sqllogictests/test_files/scalar.slt index 6e563a671d86..d5ce7737fba0 100644 --- a/datafusion/core/tests/sqllogictests/test_files/scalar.slt +++ b/datafusion/core/tests/sqllogictests/test_files/scalar.slt @@ -918,6 +918,25 @@ select trunc(4.267, 3), trunc(1.1234, 2), trunc(-1.1231, 6), trunc(1.2837284, 2) ---- 4.267 1.12 -1.1231 1.28 1 +# trunc with negative precision should truncate digits left of decimal +query R +select trunc(12345.678, -3); +---- +12000 + +# trunc with columns and precision +query RRR rowsort +select + trunc(sqrt(abs(a)), 3) as a3, + trunc(sqrt(abs(a)), 1) as a1, + trunc(arrow_cast(sqrt(abs(a)), 'Float64'), 3) as a3_f64 +from small_floats; +---- +0.447 0.4 0.447 +0.707 0.7 0.707 +0.837 0.8 0.837 +1 1 1 + ## bitwise and # bitwise and with column and scalar @@ -1497,7 +1516,6 @@ true true false true true true # csv query boolean gt gt eq query BBBBBB rowsort SELECT a, b, a > b as gt, b = true as gt_scalar, a >= b as gt_eq, a >= true as gt_eq_scalar FROM t1 ----- ---- NULL NULL NULL NULL NULL NULL NULL false NULL false NULL NULL @@ -1512,10 +1530,10 @@ true true false true true true # csv query boolean distinct from query BBBBBB rowsort SELECT a, b, - a is distinct from b as df, - b is distinct from true as df_scalar, - a is not distinct from b as ndf, - a is not distinct from true as ndf_scalar + a is distinct from b as df, + b is distinct from true as df_scalar, + a is not distinct from b as ndf, + a is not distinct from true as ndf_scalar FROM t1 ---- NULL NULL false true true false diff --git a/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt b/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt index 8b329df0c138..aa1e6826eca5 100644 --- a/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt +++ b/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt @@ -43,13 +43,9 @@ SELECT '2023-05-01 12:30:00'::timestamp - interval '1 month'; 2023-04-01T12:30:00 # interval - date -query D +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Date32 to valid types select interval '1 month' - '2023-05-01'::date; ----- -2023-04-01 # interval - timestamp -query P +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types SELECT interval '1 month' - '2023-05-01 12:30:00'::timestamp; ----- -2023-04-01T12:30:00 diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt b/datafusion/core/tests/sqllogictests/test_files/window.slt index 55c4d0189301..444ba73386a8 100644 --- a/datafusion/core/tests/sqllogictests/test_files/window.slt +++ b/datafusion/core/tests/sqllogictests/test_files/window.slt @@ -1943,6 +1943,9 @@ e 20 e 21 # test_window_agg_global_sort_parallelize_sort_disabled +# even if, parallelize sort is disabled, we should use SortPreservingMergeExec +# instead of CoalescePartitionsExec + SortExec stack. Because at the end +# we already have the desired ordering. 
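The window.slt change below flips `datafusion.optimizer.repartition_sorts` via SET. The same knob can be set when the `SessionConfig` is built; a rough sketch, assuming `SessionConfig::set_bool` accepts the string key used by the test (the toy table exists only to make the snippet runnable):

use datafusion::error::Result;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // Programmatic counterpart of `set datafusion.optimizer.repartition_sorts = false;`
    let config = SessionConfig::new()
        .set_bool("datafusion.optimizer.repartition_sorts", false);
    let ctx = SessionContext::with_config(config);
    ctx.sql("CREATE TABLE t(c1 INT) AS VALUES (1), (2)")
        .await?
        .collect()
        .await?;
    // EXPLAIN shows where the optimizer places sorts and merges under this setting.
    ctx.sql("EXPLAIN SELECT c1 FROM t ORDER BY c1")
        .await?
        .show()
        .await?;
    Ok(())
}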
statement ok set datafusion.optimizer.repartition_sorts = false; @@ -1955,15 +1958,14 @@ Sort: aggregate_test_100.c1 ASC NULLS LAST ----WindowAggr: windowExpr=[[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1] physical_plan -SortExec: expr=[c1@0 ASC NULLS LAST] ---CoalescePartitionsExec -----ProjectionExec: expr=[c1@0 as c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] -------BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] ---------SortExec: expr=[c1@0 ASC NULLS LAST] -----------CoalesceBatchesExec: target_batch_size=4096 -------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true +SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +--ProjectionExec: expr=[c1@0 as c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] +----BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +------SortExec: expr=[c1@0 ASC NULLS LAST] +--------CoalesceBatchesExec: target_batch_size=4096 +----------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true statement ok set datafusion.optimizer.repartition_sorts = true; diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 6914de686aa3..2cb1cf544176 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -436,7 +436,15 @@ impl BuiltinScalarFunction { ) } - /// Returns the dimension [`DataType`] of [`DataType::List`]. + /// Returns the dimension [`DataType`] of [`DataType::List`] if + /// treated as a N-dimensional array. + /// + /// ## Examples: + /// + /// * `Int64` has dimension 1 + /// * `List(Int64)` has dimension 2 + /// * `List(List(Int64))` has dimension 3 + /// * etc. 
fn return_dimension(self, input_expr_type: DataType) -> u64 { let mut res: u64 = 1; let mut current_data_type = input_expr_type; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 418aa8d8f8a9..a48b5e0beeab 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -801,7 +801,7 @@ pub fn create_udf( /// The signature and state type must match the `Accumulator's implementation`. pub fn create_udaf( name: &str, - input_type: DataType, + input_type: Vec, return_type: Arc, volatility: Volatility, accumulator: AccumulatorFactoryFunction, @@ -811,7 +811,7 @@ pub fn create_udaf( let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); AggregateUDF::new( name, - &Signature::exact(vec![input_type], volatility), + &Signature::exact(input_type, volatility), &return_type, &accumulator, &state_type, diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 1fccdcbd2ca1..dec2eb7f1238 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -17,6 +17,7 @@ use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; use datafusion_common::{DataFusionError, Result}; use std::ops::Deref; @@ -360,6 +361,12 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } + DataType::Decimal256(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } DataType::Dictionary(_, dict_value_type) => { sum_return_type(dict_value_type.as_ref()) } @@ -423,6 +430,13 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal128(new_precision, new_scale)) } + DataType::Decimal256(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). 
+ // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal256(new_precision, new_scale)) + } arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_return_type(dict_value_type.as_ref()) @@ -441,6 +455,11 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } + DataType::Decimal256(precision, scale) => { + // in Spark the sum type of avg is DECIMAL(min(38,precision+10), s) + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_sum_type(dict_value_type.as_ref()) diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index b6392e2a6be2..9ebea19a1631 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -120,7 +120,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result Operator::Divide| Operator::Modulo => { // TODO: this logic would be easier to follow if the functions were inlined - if let Some(ret) = mathematics_temporal_result_type(lhs, rhs) { + if let Some(ret) = mathematics_temporal_result_type(lhs, rhs, op) { // Temporal arithmetic, e.g. Date32 + Interval Ok(Signature{ lhs: lhs.clone(), @@ -130,7 +130,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result } else if let Some(coerced) = temporal_coercion(lhs, rhs) { // Temporal arithmetic by first coercing to a common time representation // e.g. 
Date32 - Timestamp - let ret = mathematics_temporal_result_type(&coerced, &coerced).ok_or_else(|| { + let ret = mathematics_temporal_result_type(&coerced, &coerced, op).ok_or_else(|| { DataFusionError::Plan(format!( "Cannot get result type for temporal operation {coerced} {op} {coerced}" )) @@ -169,6 +169,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result fn mathematics_temporal_result_type( lhs_type: &DataType, rhs_type: &DataType, + op: &Operator, ) -> Option { use arrow::datatypes::DataType::*; use arrow::datatypes::IntervalUnit::*; @@ -176,12 +177,14 @@ fn mathematics_temporal_result_type( match (lhs_type, rhs_type) { // datetime +/- interval - (Interval(_), Timestamp(_, _)) => Some(rhs_type.clone()), - (Timestamp(_, _), Interval(_)) => Some(lhs_type.clone()), - (Interval(_), Date32) => Some(rhs_type.clone()), - (Date32, Interval(_)) => Some(lhs_type.clone()), - (Interval(_), Date64) => Some(rhs_type.clone()), - (Date64, Interval(_)) => Some(lhs_type.clone()), + (Timestamp(_, _) | Date32 | Date64, Interval(_)) => Some(lhs_type.clone()), + (Interval(_), Timestamp(_, _) | Date32 | Date64) => { + if matches!(op, Operator::Plus) { + Some(rhs_type.clone()) + } else { + None + } + } // interval +/- (Interval(l), Interval(h)) if l == h => Some(lhs_type.clone()), (Interval(_), Interval(_)) => Some(Interval(MonthDayNano)), @@ -315,6 +318,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option } } +/// Coercion rules for Binaries: the type that both lhs and rhs can be +/// casted to for the purpose of a computation +fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Binary | Utf8, Binary) | (Binary, Utf8) => Some(Binary), + (LargeBinary | Binary | Utf8 | LargeUtf8, LargeBinary) + | (LargeBinary, Binary | Utf8 | LargeUtf8) => Some(LargeBinary), + _ => None, + } +} + /// coercion rules for like operations. 
/// This is a union of string coercion rules and dictionary coercion rules pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -1036,10 +1052,13 @@ mod tests { let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Utf8)); - // Can not coerce values of Binary to int, cannot support this + // Since we can coerce values of Utf8 to Binary can support this let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), None); + assert_eq!( + dictionary_coercion(&lhs_type, &rhs_type, true), + Some(Binary) + ); let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Utf8; @@ -1440,6 +1459,70 @@ mod tests { DataType::Decimal128(15, 3) ); + // Binary + test_coercion_binary_rule!( + DataType::Binary, + DataType::Binary, + Operator::Eq, + DataType::Binary + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Binary, + Operator::Eq, + DataType::Binary + ); + test_coercion_binary_rule!( + DataType::Binary, + DataType::Utf8, + Operator::Eq, + DataType::Binary + ); + + // LargeBinary + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::Binary, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::Binary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::Utf8, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeUtf8, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::LargeUtf8, + Operator::Eq, + DataType::LargeBinary + ); + // TODO add other data type Ok(()) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 8279f68e3ce8..2bd7fd67d8c9 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -890,7 +890,7 @@ mod test { let empty = empty(); let my_avg = create_udaf( "MY_AVG", - DataType::Float64, + vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, Arc::new(|_| { diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index a1d77a2d8849..9c01093edf5f 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -77,12 +77,12 @@ impl Avg { // the internal sum data type of avg just support FLOAT64 and Decimal data type. assert!(matches!( sum_data_type, - DataType::Float64 | DataType::Decimal128(_, _) + DataType::Float64 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) )); // the result of avg just support FLOAT64 and Decimal data type. 
assert!(matches!( rt_data_type, - DataType::Float64 | DataType::Decimal128(_, _) + DataType::Float64 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) )); Self { name: name.into(), diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 45e2be7fb4c6..9ac90cef4bab 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -28,6 +28,7 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::array::Array; use arrow::array::Decimal128Array; +use arrow::array::Decimal256Array; use arrow::compute; use arrow::compute::kernels::cast; use arrow::datatypes::DataType; @@ -39,8 +40,8 @@ use arrow::{ datatypes::Field, }; use arrow_array::types::{ - Decimal128Type, Float32Type, Float64Type, Int32Type, Int64Type, UInt32Type, - UInt64Type, + Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, Int64Type, + UInt32Type, UInt64Type, }; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; @@ -169,6 +170,10 @@ impl AggregateExpr for Sum { instantiate_primitive_accumulator!(self, Decimal128Type, |x, y| x .add_assign(y)) } + DataType::Decimal256(_, _) => { + instantiate_primitive_accumulator!(self, Decimal256Type, |x, y| *x = + *x + y) + } _ => Err(DataFusionError::NotImplemented(format!( "GroupsAccumulator not supported for {}: {}", self.name, self.data_type @@ -250,6 +255,16 @@ fn sum_decimal_batch(values: &ArrayRef, precision: u8, scale: i8) -> Result Result { + let array = downcast_value!(values, Decimal256Array); + let result = compute::sum(array); + Ok(ScalarValue::Decimal256(result, precision, scale)) +} + // sums the array and returns a ScalarValue of its corresponding type. pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result { // TODO refine the cast kernel in arrow-rs @@ -263,6 +278,9 @@ pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result { sum_decimal_batch(values, *precision, *scale)? } + DataType::Decimal256(precision, scale) => { + sum_decimal256_batch(values, *precision, *scale)? 
+ } DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64), DataType::Float32 => typed_sum_delta_batch!(values, Float32Array, Float32), DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 8192a403d33e..f00c18405634 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -88,6 +88,10 @@ message ProjectionColumns { message CsvFormat { bool has_header = 1; string delimiter = 2; + string quote = 3; + oneof optional_escape { + string escape = 4; + } } message ParquetFormat { @@ -908,6 +912,8 @@ message ScalarValue{ //WAS: ScalarType null_list_value = 18; Decimal128 decimal128_value = 20; + Decimal256 decimal256_value = 39; + int64 date_64_value = 21; int32 interval_yearmonth_value = 24; int64 interval_daytime_value = 25; @@ -934,6 +940,12 @@ message Decimal128{ int64 s = 3; } +message Decimal256{ + bytes value = 1; + int64 p = 2; + int64 s = 3; +} + // Serialized data type message ArrowType{ oneof arrow_type_enum { @@ -1261,6 +1273,10 @@ message CsvScanExecNode { FileScanExecConf base_conf = 1; bool has_header = 2; string delimiter = 3; + string quote = 4; + oneof optional_escape { + string escape = 5; + } } message AvroScanExecNode { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 01a192324b28..566dffb5350a 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -124,7 +124,7 @@ impl Serializeable for Expr { fn udaf(&self, name: &str) -> Result> { Ok(Arc::new(create_udaf( name, - arrow::datatypes::DataType::Null, + vec![arrow::datatypes::DataType::Null], Arc::new(arrow::datatypes::DataType::Null), Volatility::Immutable, Arc::new(|_| unimplemented!()), diff --git a/datafusion/proto/src/common.rs b/datafusion/proto/src/common.rs index ed826f587413..cbbb469f0863 100644 --- a/datafusion/proto/src/common.rs +++ b/datafusion/proto/src/common.rs @@ -17,26 +17,22 @@ use datafusion_common::{DataFusionError, Result}; -pub fn csv_delimiter_to_string(b: u8) -> Result { - let b = &[b]; - let b = std::str::from_utf8(b) - .map_err(|_| DataFusionError::Internal("Invalid CSV delimiter".to_owned()))?; - Ok(b.to_owned()) -} - -pub fn str_to_byte(s: &String) -> Result { +pub(crate) fn str_to_byte(s: &String, description: &str) -> Result { if s.len() != 1 { - return Err(DataFusionError::Internal( - "Invalid CSV delimiter".to_owned(), - )); + return Err(DataFusionError::Internal(format!( + "Invalid CSV {description}: expected single character, got {s}" + ))); } Ok(s.as_bytes()[0]) } -pub fn byte_to_string(b: u8) -> Result { +pub(crate) fn byte_to_string(b: u8, description: &str) -> Result { let b = &[b]; - let b = std::str::from_utf8(b) - .map_err(|_| DataFusionError::Internal("Invalid CSV delimiter".to_owned()))?; + let b = std::str::from_utf8(b).map_err(|_| { + DataFusionError::Internal(format!( + "Invalid CSV {description}: can not represent {b:0x?} as utf8" + )) + })?; Ok(b.to_owned()) } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 05bfbd089dfe..2bac658f0496 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4191,6 +4191,12 @@ impl serde::Serialize for CsvFormat { if !self.delimiter.is_empty() { len += 1; } + if !self.quote.is_empty() { + len += 1; + } + if self.optional_escape.is_some() { + len += 1; + } let mut struct_ser = 
serializer.serialize_struct("datafusion.CsvFormat", len)?; if self.has_header { struct_ser.serialize_field("hasHeader", &self.has_header)?; @@ -4198,6 +4204,16 @@ impl serde::Serialize for CsvFormat { if !self.delimiter.is_empty() { struct_ser.serialize_field("delimiter", &self.delimiter)?; } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if let Some(v) = self.optional_escape.as_ref() { + match v { + csv_format::OptionalEscape::Escape(v) => { + struct_ser.serialize_field("escape", v)?; + } + } + } struct_ser.end() } } @@ -4211,12 +4227,16 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { "has_header", "hasHeader", "delimiter", + "quote", + "escape", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { HasHeader, Delimiter, + Quote, + Escape, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4240,6 +4260,8 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { match value { "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), "delimiter" => Ok(GeneratedField::Delimiter), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4261,6 +4283,8 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { { let mut has_header__ = None; let mut delimiter__ = None; + let mut quote__ = None; + let mut optional_escape__ = None; while let Some(k) = map.next_key()? { match k { GeneratedField::HasHeader => { @@ -4275,11 +4299,25 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { } delimiter__ = Some(map.next_value()?); } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map.next_value()?); + } + GeneratedField::Escape => { + if optional_escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + optional_escape__ = map.next_value::<::std::option::Option<_>>()?.map(csv_format::OptionalEscape::Escape); + } } } Ok(CsvFormat { has_header: has_header__.unwrap_or_default(), delimiter: delimiter__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + optional_escape: optional_escape__, }) } } @@ -4303,6 +4341,12 @@ impl serde::Serialize for CsvScanExecNode { if !self.delimiter.is_empty() { len += 1; } + if !self.quote.is_empty() { + len += 1; + } + if self.optional_escape.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CsvScanExecNode", len)?; if let Some(v) = self.base_conf.as_ref() { struct_ser.serialize_field("baseConf", v)?; @@ -4313,6 +4357,16 @@ impl serde::Serialize for CsvScanExecNode { if !self.delimiter.is_empty() { struct_ser.serialize_field("delimiter", &self.delimiter)?; } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if let Some(v) = self.optional_escape.as_ref() { + match v { + csv_scan_exec_node::OptionalEscape::Escape(v) => { + struct_ser.serialize_field("escape", v)?; + } + } + } struct_ser.end() } } @@ -4328,6 +4382,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { "has_header", "hasHeader", "delimiter", + "quote", + "escape", ]; #[allow(clippy::enum_variant_names)] @@ -4335,6 +4391,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { BaseConf, HasHeader, Delimiter, + Quote, + Escape, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4359,6 +4417,8 @@ impl<'de> serde::Deserialize<'de> for 
CsvScanExecNode { "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), "delimiter" => Ok(GeneratedField::Delimiter), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4381,6 +4441,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { let mut base_conf__ = None; let mut has_header__ = None; let mut delimiter__ = None; + let mut quote__ = None; + let mut optional_escape__ = None; while let Some(k) = map.next_key()? { match k { GeneratedField::BaseConf => { @@ -4401,12 +4463,26 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { } delimiter__ = Some(map.next_value()?); } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map.next_value()?); + } + GeneratedField::Escape => { + if optional_escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + optional_escape__ = map.next_value::<::std::option::Option<_>>()?.map(csv_scan_exec_node::OptionalEscape::Escape); + } } } Ok(CsvScanExecNode { base_conf: base_conf__, has_header: has_header__.unwrap_or_default(), delimiter: delimiter__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + optional_escape: optional_escape__, }) } } @@ -4983,6 +5059,137 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { deserializer.deserialize_struct("datafusion.Decimal128", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for Decimal256 { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.value.is_empty() { + len += 1; + } + if self.p != 0 { + len += 1; + } + if self.s != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Decimal256", len)?; + if !self.value.is_empty() { + struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + } + if self.p != 0 { + struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + } + if self.s != 0 { + struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal256 { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value", + "p", + "s", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Value, + P, + S, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "value" => Ok(GeneratedField::Value), + "p" => Ok(GeneratedField::P), + "s" => Ok(GeneratedField::S), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal256; + + fn 
expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.Decimal256") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + let mut p__ = None; + let mut s__ = None; + while let Some(k) = map.next_key()? { + match k { + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = + Some(map.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::P => { + if p__.is_some() { + return Err(serde::de::Error::duplicate_field("p")); + } + p__ = + Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::S => { + if s__.is_some() { + return Err(serde::de::Error::duplicate_field("s")); + } + s__ = + Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal256 { + value: value__.unwrap_or_default(), + p: p__.unwrap_or_default(), + s: s__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.Decimal256", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for DfField { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -19125,6 +19332,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } + scalar_value::Value::Decimal256Value(v) => { + struct_ser.serialize_field("decimal256Value", v)?; + } scalar_value::Value::Date64Value(v) => { struct_ser.serialize_field("date64Value", ToString::to_string(&v).as_str())?; } @@ -19218,6 +19428,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "listValue", "decimal128_value", "decimal128Value", + "decimal256_value", + "decimal256Value", "date_64_value", "date64Value", "interval_yearmonth_value", @@ -19270,6 +19482,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Time32Value, ListValue, Decimal128Value, + Decimal256Value, Date64Value, IntervalYearmonthValue, IntervalDaytimeValue, @@ -19324,6 +19537,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), "listValue" | "list_value" => Ok(GeneratedField::ListValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), + "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), "intervalYearmonthValue" | "interval_yearmonth_value" => Ok(GeneratedField::IntervalYearmonthValue), "intervalDaytimeValue" | "interval_daytime_value" => Ok(GeneratedField::IntervalDaytimeValue), @@ -19471,6 +19685,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("decimal128Value")); } value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal128Value) +; + } + GeneratedField::Decimal256Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("decimal256Value")); + } + value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal256Value) ; } GeneratedField::Date64Value => { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index f50754494d1d..801162ab1910 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -122,6 +122,19 @@ pub struct CsvFormat { pub has_header: bool, 
#[prost(string, tag = "2")] pub delimiter: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub quote: ::prost::alloc::string::String, + #[prost(oneof = "csv_format::OptionalEscape", tags = "4")] + pub optional_escape: ::core::option::Option, +} +/// Nested message and enum types in `CsvFormat`. +pub mod csv_format { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum OptionalEscape { + #[prost(string, tag = "4")] + Escape(::prost::alloc::string::String), + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1097,7 +1110,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" )] pub value: ::core::option::Option, } @@ -1146,6 +1159,8 @@ pub mod scalar_value { ListValue(super::ScalarListValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), + #[prost(message, tag = "39")] + Decimal256Value(super::Decimal256), #[prost(int64, tag = "21")] Date64Value(i64), #[prost(int32, tag = "24")] @@ -1188,6 +1203,16 @@ pub struct Decimal128 { #[prost(int64, tag = "3")] pub s: i64, } +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} /// Serialized data type #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1772,6 +1797,19 @@ pub struct CsvScanExecNode { pub has_header: bool, #[prost(string, tag = "3")] pub delimiter: ::prost::alloc::string::String, + #[prost(string, tag = "4")] + pub quote: ::prost::alloc::string::String, + #[prost(oneof = "csv_scan_exec_node::OptionalEscape", tags = "5")] + pub optional_escape: ::core::option::Option, +} +/// Nested message and enum types in `CsvScanExecNode`. 
+pub mod csv_scan_exec_node { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum OptionalEscape { + #[prost(string, tag = "5")] + Escape(::prost::alloc::string::String), + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetScanExecNode { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 674588692d98..71a1bf87db6e 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -26,7 +26,7 @@ use crate::protobuf::{ OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::datatypes::{ - DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, + i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, }; use datafusion::execution::registry::FunctionRegistry; @@ -648,6 +648,14 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { val.s as i8, ) } + Value::Decimal256Value(val) => { + let array = vec_to_array(val.value.clone()); + Self::Decimal256( + Some(i256::from_be_bytes(array)), + val.p as u8, + val.s as i8, + ) + } Value::Date64Value(v) => Self::Date64(Some(*v)), Value::Time32Value(v) => { let time_value = diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index b25f470f8dec..405c58c20baf 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -348,11 +348,17 @@ impl AsLogicalPlan for LogicalPlanNode { FileFormatType::Csv(protobuf::CsvFormat { has_header, delimiter, - }) => Arc::new( - CsvFormat::default() - .with_has_header(*has_header) - .with_delimiter(str_to_byte(delimiter)?), - ), + quote, + optional_escape + }) => { + let mut csv = CsvFormat::default() + .with_has_header(*has_header) + .with_delimiter(str_to_byte(delimiter, "delimiter")?) + .with_quote(str_to_byte(quote, "quote")?); + if let Some(protobuf::csv_format::OptionalEscape::Escape(escape)) = optional_escape { + csv = csv.with_escape(Some(str_to_byte(escape, "escape")?)); + } + Arc::new(csv)}, FileFormatType::Avro(..) => Arc::new(AvroFormat), }; @@ -844,8 +850,16 @@ impl AsLogicalPlan for LogicalPlanNode { FileFormatType::Parquet(protobuf::ParquetFormat {}) } else if let Some(csv) = any.downcast_ref::<CsvFormat>() { FileFormatType::Csv(protobuf::CsvFormat { - delimiter: byte_to_string(csv.delimiter())?, + delimiter: byte_to_string(csv.delimiter(), "delimiter")?, has_header: csv.has_header(), + quote: byte_to_string(csv.quote(), "quote")?, + optional_escape: if let Some(escape) = csv.escape() { + Some(protobuf::csv_format::OptionalEscape::Escape( + byte_to_string(escape, "escape")?, + )) + } else { + None + }, }) } else if any.is::<AvroFormat>() { FileFormatType::Avro(protobuf::AvroFormat {}) @@ -2674,7 +2688,7 @@ mod roundtrip_tests { // the name; used to represent it in plan descriptions and in the registry, to use in SQL. "dummy_agg", // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. - DataType::Float64, + vec![DataType::Float64], // the return type; DataFusion expects this to match the type returned by `evaluate`. Arc::new(DataType::Float64), Volatility::Immutable, @@ -2860,7 +2874,7 @@ mod roundtrip_tests { // the name; used to represent it in plan descriptions and in the registry, to use in SQL. "dummy_agg", // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
- DataType::Float64, + vec![DataType::Float64], // the return type; DataFusion expects this to match the type returned by `evaluate`. Arc::new(DataType::Float64), Volatility::Immutable, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 072bc84d5452..f1a961576128 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1148,6 +1148,24 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { )), }), }, + ScalarValue::Decimal256(val, p, s) => match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + Ok(protobuf::ScalarValue { + value: Some(Value::Decimal256Value(protobuf::Decimal256 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + }) + } + None => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue( + (&data_type).try_into()?, + )), + }), + }, ScalarValue::Date64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date64Value(*s)) } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 9e9b391a6cae..e97a773d3472 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -51,8 +51,8 @@ use datafusion_common::{DataFusionError, Result}; use prost::bytes::BufMut; use prost::Message; -use crate::common::proto_error; -use crate::common::{csv_delimiter_to_string, str_to_byte}; +use crate::common::str_to_byte; +use crate::common::{byte_to_string, proto_error}; use crate::physical_plan::from_proto::{ parse_physical_expr, parse_physical_sort_expr, parse_protobuf_file_scan_config, }; @@ -155,7 +155,16 @@ impl AsExecutionPlan for PhysicalPlanNode { registry, )?, scan.has_header, - str_to_byte(&scan.delimiter)?, + str_to_byte(&scan.delimiter, "delimiter")?, + str_to_byte(&scan.quote, "quote")?, + if let Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape( + escape, + )) = &scan.optional_escape + { + Some(str_to_byte(escape, "escape")?) + } else { + None + }, FileCompressionType::UNCOMPRESSED, ))), PhysicalPlanType::ParquetScan(scan) => { @@ -1070,7 +1079,15 @@ impl AsExecutionPlan for PhysicalPlanNode { protobuf::CsvScanExecNode { base_conf: Some(exec.base_config().try_into()?), has_header: exec.has_header(), - delimiter: csv_delimiter_to_string(exec.delimiter())?, + delimiter: byte_to_string(exec.delimiter(), "delimiter")?, + quote: byte_to_string(exec.quote(), "quote")?, + optional_escape: if let Some(escape) = exec.escape() { + Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape( + byte_to_string(escape, "escape")?, + )) + } else { + None + }, }, )), }) diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 80195c6a3f2e..79081c8d3ab4 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -90,18 +90,17 @@ parellelized streaming execution plans, file format support, etc. 
## Known Users -Here are some of the projects known to use DataFusion: +Here are some active projects using DataFusion: + + - [Ballista](https://github.com/apache/arrow-ballista) Distributed SQL Query Engine -- [Blaze](https://github.com/blaze-init/blaze) Spark accelerator with DataFusion at its core - [CeresDB](https://github.com/CeresDB/ceresdb) Distributed Time-Series Database -- [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust) - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python -- [datafusion-tui](https://github.com/datafusion-contrib/datafusion-tui) Text UI for DataFusion +- [Exon](https://github.com/wheretrue/exon) Analysis toolkit for life-science applications - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake -- [Flock](https://github.com/flock-lab/flock) - [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database - [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline @@ -111,10 +110,17 @@ Here are some of the projects known to use DataFusion: - [ROAPI](https://github.com/roapi/roapi) - [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database - [Synnada](https://synnada.ai/) Streaming-first framework for data products -- [Tensorbase](https://github.com/tensorbase/tensorbase) - [VegaFusion](https://vegafusion.io/) Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar - [ZincObserve](https://github.com/zinclabs/zincobserve) Distributed cloud native observability platform +Here are some less active projects that used DataFusion: + +- [Blaze](https://github.com/blaze-init/blaze) Spark accelerator with DataFusion at its core +- [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust) +- [datafusion-tui](https://github.com/datafusion-contrib/datafusion-tui) Text UI for DataFusion +- [Flock](https://github.com/flock-lab/flock) +- [Tensorbase](https://github.com/tensorbase/tensorbase) + [ballista]: https://github.com/apache/arrow-ballista [blaze]: https://github.com/blaze-init/blaze [ceresdb]: https://github.com/CeresDB/ceresdb diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index bcdd3832523e..301f57d0310a 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -504,10 +504,10 @@ tanh(numeric_expression) ### `trunc` -Truncates a number toward zero (at the decimal point). +Truncates a number to a whole number or to the specified number of decimal places. ``` -trunc(numeric_expression) +trunc(numeric_expression[, decimal_places]) ``` #### Arguments @@ -515,6 +515,12 @@ trunc(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **decimal_places**: Optional. The number of decimal places to + truncate to. Defaults to 0 (truncate to a whole number). If + `decimal_places` is a positive integer, truncates digits to the + right of the decimal point. If `decimal_places` is a negative + integer, replaces digits to the left of the decimal point with `0`.
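As an illustration of the two-argument form documented above, here is a minimal sketch using DataFusion's `SessionContext` SQL interface. It assumes a build that includes the optional `decimal_places` argument documented in this change; the expected values in the comments simply follow that description and are illustrative rather than authoritative.

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Per the description above:
    //   trunc(3.14159)     -> 3      (default: truncate to a whole number)
    //   trunc(3.14159, 2)  -> 3.14   (keep two digits right of the decimal point)
    //   trunc(1234, -2)    -> 1200   (negative places zero out digits left of the point)
    let df = ctx
        .sql("SELECT trunc(3.14159), trunc(3.14159, 2), trunc(1234, -2)")
        .await?;
    df.show().await?;
    Ok(())
}
```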
+ ## Conditional Functions - [coalesce](#coalesce)
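Separately from the documentation changes, the Decimal256 support added earlier in this diff serializes a scalar as its 32 big-endian bytes plus precision and scale (`to_proto` writes `i256::to_be_bytes`, `from_proto` reads `i256::from_be_bytes`). The following is a minimal sketch of that round trip using only arrow's `i256`; the `try_into` conversion stands in for the PR's `vec_to_array` helper, which is assumed to perform an equivalent Vec-to-fixed-size-array conversion.

```rust
use arrow::datatypes::i256;

fn main() {
    // Encode: the i256 value becomes 32 big-endian bytes; precision and
    // scale travel alongside it in the protobuf message.
    let value = i256::from_i128(1_234_567_890_123_456_789_i128);
    let bytes: Vec<u8> = value.to_be_bytes().to_vec();
    assert_eq!(bytes.len(), 32);

    // Decode: rebuild the fixed-size array and reconstruct the i256,
    // mirroring the `Value::Decimal256Value` branch in from_proto.
    let array: [u8; 32] = bytes.try_into().expect("expected exactly 32 bytes");
    let roundtripped = i256::from_be_bytes(array);
    assert_eq!(roundtripped, value);
}
```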