diff --git a/Cargo.toml b/Cargo.toml
index 629992177913..b8bf83a5ab53 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -137,7 +137,7 @@ prost = "0.13.1"
 prost-derive = "0.13.1"
 rand = "0.8"
 regex = "1.8"
-rstest = "0.22.0"
+rstest = "0.23.0"
 serde_json = "1"
 sqlparser = { version = "0.51.0", features = ["visitor"] }
 tempfile = "3"
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index fbe7d5c04b9b..a1157cbffbd6 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -424,9 +424,9 @@ dependencies = [
 [[package]]
 name = "async-trait"
-version = "0.1.82"
+version = "0.1.83"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1"
+checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -450,15 +450,15 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
 [[package]]
 name = "autocfg"
-version = "1.3.0"
+version = "1.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
+checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
 [[package]]
 name = "aws-config"
-version = "1.5.6"
+version = "1.5.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "848d7b9b605720989929279fa644ce8f244d0ce3146fcca5b70e4eb7b3c020fc"
+checksum = "8191fb3091fa0561d1379ef80333c3c7191c6f0435d986e85821bcf7acbd1126"
 dependencies = [
  "aws-credential-types",
  "aws-runtime",
@@ -523,9 +523,9 @@ dependencies = [
 [[package]]
 name = "aws-sdk-sso"
-version = "1.43.0"
+version = "1.44.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "70a9d27ed1c12b1140c47daf1bc541606c43fdafd918c4797d520db0043ceef2"
+checksum = "0b90cfe6504115e13c41d3ea90286ede5aa14da294f3fe077027a6e83850843c"
 dependencies = [
  "aws-credential-types",
  "aws-runtime",
@@ -545,9 +545,9 @@ dependencies = [
 [[package]]
 name = "aws-sdk-ssooidc"
-version = "1.44.0"
+version = "1.45.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "44514a6ca967686cde1e2a1b81df6ef1883d0e3e570da8d8bc5c491dcb6fc29b"
+checksum = "167c0fad1f212952084137308359e8e4c4724d1c643038ce163f06de9662c1d0"
 dependencies = [
  "aws-credential-types",
  "aws-runtime",
@@ -567,9 +567,9 @@ dependencies = [
 [[package]]
 name = "aws-sdk-sts"
-version = "1.43.0"
+version = "1.44.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cd7a4d279762a35b9df97209f6808b95d4fe78547fe2316b4d200a0283960c5a"
+checksum = "2cb5f98188ec1435b68097daa2a37d74b9d17c9caa799466338a8d1544e71b9d"
 dependencies = [
  "aws-credential-types",
  "aws-runtime",
@@ -707,9 +707,9 @@ dependencies = [
 [[package]]
 name = "aws-smithy-types"
-version = "1.2.6"
+version = "1.2.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "03701449087215b5369c7ea17fef0dd5d24cb93439ec5af0c7615f58c3f22605"
+checksum = "147100a7bea70fa20ef224a6bad700358305f5dc0f84649c53769761395b355b"
 dependencies = [
  "base64-simd",
  "bytes",
@@ -917,9 +917,9 @@ dependencies = [
 [[package]]
 name = "cc"
-version = "1.1.21"
+version = "1.1.22"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0"
+checksum = "9540e661f81799159abee814118cc139a2004b3a3aa3ea37724a1b66530b90e0"
 dependencies = [
  "jobserver",
  "libc",
@@ -975,9 +975,9 @@ dependencies = [
 [[package]]
 name = "clap"
-version = "4.5.17"
+version = "4.5.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac"
+checksum = "b0956a43b323ac1afaffc053ed5c4b7c1f1800bacd1683c353aabbb752515dd3"
 dependencies = [
  "clap_builder",
  "clap_derive",
@@ -985,9 +985,9 @@ dependencies = [
 [[package]]
 name = "clap_builder"
-version = "4.5.17"
+version = "4.5.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73"
+checksum = "4d72166dd41634086d5803a47eb71ae740e61d84709c36f3c34110173db3961b"
 dependencies = [
  "anstream",
  "anstyle",
@@ -997,9 +997,9 @@ dependencies = [
 [[package]]
 name = "clap_derive"
-version = "4.5.13"
+version = "4.5.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0"
+checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
 dependencies = [
  "heck 0.5.0",
  "proc-macro2",
@@ -1345,6 +1345,7 @@ dependencies = [
  "datafusion-functions-aggregate-common",
  "datafusion-functions-window-common",
  "datafusion-physical-expr-common",
+ "indexmap",
  "paste",
  "serde_json",
  "sqlparser",
@@ -1402,7 +1403,6 @@ dependencies = [
  "half",
  "log",
  "paste",
- "sqlparser",
 ]
 [[package]]
@@ -1447,6 +1447,7 @@ dependencies = [
  "datafusion-functions-window-common",
  "datafusion-physical-expr-common",
  "log",
+ "paste",
 ]
 [[package]]
@@ -1722,9 +1723,9 @@ dependencies = [
 [[package]]
 name = "flate2"
-version = "1.0.33"
+version = "1.0.34"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253"
+checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
 dependencies = [
  "crc32fast",
  "miniz_oxide",
@@ -2137,9 +2138,9 @@ dependencies = [
 [[package]]
 name = "hyper-util"
-version = "0.1.8"
+version = "0.1.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba"
+checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b"
 dependencies = [
  "bytes",
  "futures-channel",
@@ -2150,7 +2151,6 @@ dependencies = [
  "pin-project-lite",
  "socket2",
  "tokio",
- "tower",
  "tower-service",
  "tracing",
 ]
@@ -2333,9 +2333,9 @@ dependencies = [
 [[package]]
 name = "libc"
-version = "0.2.158"
+version = "0.2.159"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
+checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5"
 [[package]]
 name = "libflate"
@@ -2799,26 +2799,6 @@ dependencies = [
  "siphasher",
 ]
-[[package]]
-name = "pin-project"
-version = "1.1.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3"
-dependencies = [
- "pin-project-internal",
-]
-
-[[package]]
-name = "pin-project-internal"
-version = "1.1.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn",
-]
-
 [[package]]
 name = "pin-project-lite"
 version = "0.2.14"
@@ -2833,9 +2813,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
 [[package]]
 name = "pkg-config"
-version = "0.3.30"
+version = "0.3.31"
= "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" [[package]] name = "powerfmt" @@ -2908,9 +2888,9 @@ checksum = "b76f1009795ca44bb5aaae8fd3f18953e209259c33d9b059b1f53d58ab7511db" [[package]] name = "quick-xml" -version = "0.36.1" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96a05e2e8efddfa51a84ca47cec303fac86c8541b686d37cac5efc0e094417bc" +checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe" dependencies = [ "memchr", "serde", @@ -3015,9 +2995,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.4" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0884ad60e090bf1345b93da0a5de8923c93884cd03f40dfcfddd3b4bee661853" +checksum = "355ae415ccd3a04315d3f8246e86d67689ea74d88d915576e1589a351062a13b" dependencies = [ "bitflags 2.6.0", ] @@ -3289,9 +3269,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" +checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" [[package]] name = "rustls-webpki" @@ -3397,9 +3377,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.1" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" dependencies = [ "core-foundation-sys", "libc", @@ -3510,18 +3490,18 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "snafu" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b835cb902660db3415a672d862905e791e54d306c6e8189168c7f3d9ae1c79d" +checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" dependencies = [ "snafu-derive", ] [[package]] name = "snafu-derive" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d1e02fca405f6280643174a50c942219f0bbf4dbf7d480f1dd864d6f211ae5" +checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -3633,9 +3613,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.77" +version = "2.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" dependencies = [ "proc-macro2", "quote", @@ -3653,9 +3633,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.12.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" +checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" dependencies = [ "cfg-if", "fastrand", @@ -3672,18 +3652,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "thiserror" -version = "1.0.63" 
+version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", @@ -3826,36 +3806,15 @@ checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" [[package]] name = "toml_edit" -version = "0.22.21" +version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b072cee73c449a636ffd6f32bd8de3a9f7119139aff882f44943ce2986dc5cf" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ "indexmap", "toml_datetime", "winnow", ] -[[package]] -name = "tower" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "pin-project", - "pin-project-lite", - "tokio", - "tower-layer", - "tower-service", -] - -[[package]] -name = "tower-layer" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" - [[package]] name = "tower-service" version = "0.3.3" @@ -3964,9 +3923,9 @@ checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "untrusted" @@ -4122,9 +4081,9 @@ checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "wasm-streams" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" dependencies = [ "futures-util", "js-sys", @@ -4341,9 +4300,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.18" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" dependencies = [ "memchr", ] diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 6c36d907acc3..ca3a2bef882e 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -25,6 +25,7 @@ use arrow_schema::SchemaRef; use async_trait::async_trait; use datafusion_common::Result; use datafusion_common::{not_impl_err, Constraints, Statistics}; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{ CreateExternalTable, Expr, LogicalPlan, TableProviderFilterPushDown, TableType, }; @@ -274,7 +275,7 @@ pub trait TableProvider: Debug + Sync + Send { &self, _state: &dyn 
-        _overwrite: bool,
+        _insert_op: InsertOp,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         not_impl_err!("Insert into not implemented for this table")
     }
diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index 0dec14e9178a..69cdf866cf98 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -226,7 +226,12 @@ impl DFSchema {
         for (field, qualifier) in self.inner.fields().iter().zip(&self.field_qualifiers) {
             if let Some(qualifier) = qualifier {
-                qualified_names.insert((qualifier, field.name()));
+                if !qualified_names.insert((qualifier, field.name())) {
+                    return _schema_err!(SchemaError::DuplicateQualifiedField {
+                        qualifier: Box::new(qualifier.clone()),
+                        name: field.name().to_string(),
+                    });
+                }
             } else if !unqualified_names.insert(field.name()) {
                 return _schema_err!(SchemaError::DuplicateUnqualifiedField {
                     name: field.name().to_string()
@@ -1165,7 +1170,10 @@ mod tests {
         let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
         let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
         let join = left.join(&right);
-        assert!(join.err().is_none());
+        assert_eq!(
+            join.unwrap_err().strip_backtrace(),
+            "Schema error: Schema contains duplicate qualified field name t1.c0",
+        );
         Ok(())
     }
diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs
index 116dab316bf5..5bf0f08b092a 100644
--- a/datafusion/common/src/utils/mod.rs
+++ b/datafusion/common/src/utils/mod.rs
@@ -291,6 +291,9 @@ pub(crate) fn parse_identifiers(s: &str) -> Result<Vec<Ident>> {
 }
 /// Construct a new [`Vec`] of [`ArrayRef`] from the rows of the `arrays` at the `indices`.
+///
+/// TODO: use implementation in arrow-rs when available:
+///
 pub fn take_arrays(arrays: &[ArrayRef], indices: &dyn Array) -> Result<Vec<ArrayRef>> {
     arrays
         .iter()
diff --git a/datafusion/core/src/catalog_common/mod.rs b/datafusion/core/src/catalog_common/mod.rs
index b8414378862e..85207845a005 100644
--- a/datafusion/core/src/catalog_common/mod.rs
+++ b/datafusion/core/src/catalog_common/mod.rs
@@ -185,9 +185,7 @@ pub fn resolve_table_references(
             let _ = s.as_ref().visit(visitor);
         }
         DFStatement::CreateExternalTable(table) => {
-            visitor
-                .relations
-                .insert(ObjectName(vec![Ident::from(table.name.as_str())]));
+            visitor.relations.insert(table.name.clone());
         }
         DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
             CopyToSource::Relation(table_name) => {
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 72b763ce0f2b..f5867881da13 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -52,6 +52,7 @@ use datafusion_common::config::{CsvOptions, JsonOptions};
 use datafusion_common::{
     plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions,
 };
+use datafusion_expr::dml::InsertOp;
 use datafusion_expr::{case, is_null, lit, SortExpr};
 use datafusion_expr::{
     utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
 };
 use datafusion_catalog::Session;
 /// Contains options that control how data is
 /// written out from a DataFrame
 pub struct DataFrameWriteOptions {
-    /// Controls if existing data should be overwritten
-    overwrite: bool,
+    /// Controls how new data should be written to the table, determining whether
+    /// to append, overwrite, or replace existing data.
+    insert_op: InsertOp,
     /// Controls if all partitions should be coalesced into a single output file
     /// Generally will have slower performance when set to true.
     single_file_output: bool,
@@ -80,14 +82,15 @@ impl DataFrameWriteOptions {
     /// Create a new DataFrameWriteOptions with default values
     pub fn new() -> Self {
         DataFrameWriteOptions {
-            overwrite: false,
+            insert_op: InsertOp::Append,
             single_file_output: false,
             partition_by: vec![],
         }
     }
-    /// Set the overwrite option to true or false
-    pub fn with_overwrite(mut self, overwrite: bool) -> Self {
-        self.overwrite = overwrite;
+
+    /// Set the insert operation
+    pub fn with_insert_operation(mut self, insert_op: InsertOp) -> Self {
+        self.insert_op = insert_op;
         self
     }
@@ -1525,7 +1528,7 @@ impl DataFrame {
             self.plan,
             table_name.to_owned(),
             &arrow_schema,
-            write_options.overwrite,
+            write_options.insert_op,
         )?
         .build()?;
@@ -1566,10 +1569,11 @@ impl DataFrame {
         options: DataFrameWriteOptions,
         writer_options: Option<CsvOptions>,
     ) -> Result<Vec<RecordBatch>, DataFusionError> {
-        if options.overwrite {
-            return Err(DataFusionError::NotImplemented(
-                "Overwrites are not implemented for DataFrame::write_csv.".to_owned(),
-            ));
+        if options.insert_op != InsertOp::Append {
+            return Err(DataFusionError::NotImplemented(format!(
+                "{} is not implemented for DataFrame::write_csv.",
+                options.insert_op
+            )));
         }
         let format = if let Some(csv_opts) = writer_options {
@@ -1626,10 +1630,11 @@ impl DataFrame {
         options: DataFrameWriteOptions,
         writer_options: Option<JsonOptions>,
     ) -> Result<Vec<RecordBatch>, DataFusionError> {
-        if options.overwrite {
-            return Err(DataFusionError::NotImplemented(
-                "Overwrites are not implemented for DataFrame::write_json.".to_owned(),
-            ));
+        if options.insert_op != InsertOp::Append {
+            return Err(DataFusionError::NotImplemented(format!(
+                "{} is not implemented for DataFrame::write_json.",
+                options.insert_op
+            )));
         }
         let format = if let Some(json_opts) = writer_options {
@@ -3375,52 +3380,6 @@ mod tests {
         Ok(())
     }
-    // Table 't1' self join
-    // Supplementary test of issue: https://github.com/apache/datafusion/issues/7790
-    #[tokio::test]
-    async fn with_column_self_join() -> Result<()> {
-        let df = test_table().await?.select_columns(&["c1"])?;
-        let ctx = SessionContext::new();
-
-        ctx.register_table("t1", df.into_view())?;
-
-        let df = ctx
-            .table("t1")
-            .await?
-            .join(
-                ctx.table("t1").await?,
-                JoinType::Inner,
-                &["c1"],
-                &["c1"],
-                None,
-            )?
-            .sort(vec![
-                // make the test deterministic
-                col("t1.c1").sort(true, true),
-            ])?
-            .limit(0, Some(1))?;
-
-        let df_results = df.clone().collect().await?;
-        assert_batches_sorted_eq!(
-            [
-                "+----+----+",
-                "| c1 | c1 |",
-                "+----+----+",
-                "| a  | a  |",
-                "+----+----+",
-            ],
-            &df_results
-        );
-
-        let actual_err = df.clone().with_column("new_column", lit(true)).unwrap_err();
-        let expected_err = "Error during planning: Projections require unique expression names \
-        but the expression \"t1.c1\" at position 0 and \"t1.c1\" at position 1 have the same name. \
-        Consider aliasing (\"AS\") one of them.";
-        assert_eq!(actual_err.strip_backtrace(), expected_err);
-
-        Ok(())
-    }
-
     #[tokio::test]
     async fn with_column_renamed() -> Result<()> {
         let df = test_table()
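The equivalent user-facing call after this change, as a minimal hypothetical sketch (not part of the patch; the `df` DataFrame and the registered table name `my_table` are assumed):

    use datafusion::dataframe::DataFrameWriteOptions;
    use datafusion_expr::dml::InsertOp;

    // Where callers previously wrote `.with_overwrite(true)`:
    let opts = DataFrameWriteOptions::new().with_insert_operation(InsertOp::Overwrite);
    let _ = df.write_table("my_table", opts).await?;

`InsertOp::Append` remains the default, so existing `DataFrameWriteOptions::new()` call sites keep their current behavior.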
diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs
index 66974e37f453..f90b35fde6ba 100644
--- a/datafusion/core/src/dataframe/parquet.rs
+++ b/datafusion/core/src/dataframe/parquet.rs
@@ -26,6 +26,7 @@ use super::{
 };
 use datafusion_common::config::TableParquetOptions;
+use datafusion_expr::dml::InsertOp;
 impl DataFrame {
     /// Execute the `DataFrame` and write the results to Parquet file(s).
@@ -57,10 +58,11 @@ impl DataFrame {
         options: DataFrameWriteOptions,
         writer_options: Option<TableParquetOptions>,
     ) -> Result<Vec<RecordBatch>, DataFusionError> {
-        if options.overwrite {
-            return Err(DataFusionError::NotImplemented(
-                "Overwrites are not implemented for DataFrame::write_parquet.".to_owned(),
-            ));
+        if options.insert_op != InsertOp::Append {
+            return Err(DataFusionError::NotImplemented(format!(
+                "{} is not implemented for DataFrame::write_parquet.",
+                options.insert_op
+            )));
         }
         let format = if let Some(parquet_opts) = writer_options {
diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs
index 3a5d50bba07f..98b6702bc383 100644
--- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs
+++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs
@@ -573,7 +573,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
         // extract list values, with non-lists converted to Value::Null
         let array_item_count = rows
             .iter()
-            .map(|row| match row {
+            .map(|row| match maybe_resolve_union(row) {
                 Value::Array(values) => values.len(),
                 _ => 1,
             })
@@ -1643,6 +1643,93 @@ mod test {
         assert_batches_eq!(expected, &[batch]);
     }
+    #[test]
+    fn test_avro_nullable_struct_array() {
+        let schema = apache_avro::Schema::parse_str(
+            r#"
+            {
+              "type": "record",
+              "name": "r1",
+              "fields": [
+                {
+                  "name": "col1",
+                  "type": [
+                    "null",
+                    {
+                      "type": "array",
+                      "items": {
+                        "type": [
+                          "null",
+                          {
+                            "type": "record",
+                            "name": "Item",
+                            "fields": [
+                              {
+                                "name": "id",
+                                "type": "long"
+                              }
+                            ]
+                          }
+                        ]
+                      }
+                    }
+                  ],
+                  "default": null
+                }
+              ]
+            }"#,
+        )
+        .unwrap();
+        let jv1 = serde_json::json!({
+            "col1": [
+                {
+                    "id": 234
+                },
+                {
+                    "id": 345
+                }
+            ]
+        });
+        let r1 = apache_avro::to_value(jv1)
+            .unwrap()
+            .resolve(&schema)
+            .unwrap();
+        let r2 = apache_avro::to_value(serde_json::json!({ "col1": null }))
+            .unwrap()
+            .resolve(&schema)
+            .unwrap();
+
+        let mut w = apache_avro::Writer::new(&schema, vec![]);
+        for _i in 0..5 {
+            w.append(r1.clone()).unwrap();
+        }
+        w.append(r2).unwrap();
+        let bytes = w.into_inner().unwrap();
+
+        let mut reader = ReaderBuilder::new()
+            .read_schema()
+            .with_batch_size(20)
+            .build(std::io::Cursor::new(bytes))
+            .unwrap();
+        let batch = reader.next().unwrap().unwrap();
+        assert_eq!(batch.num_rows(), 6);
+        assert_eq!(batch.num_columns(), 1);
+
+        let expected = [
+            "+------------------------+",
+            "| col1                   |",
+            "+------------------------+",
+            "| [{id: 234}, {id: 345}] |",
+            "| [{id: 234}, {id: 345}] |",
+            "| [{id: 234}, {id: 345}] |",
+            "| [{id: 234}, {id: 345}] |",
+            "| [{id: 234}, {id: 345}] |",
+            "|                        |",
+            "+------------------------+",
+        ];
+        assert_batches_eq!(expected, &[batch]);
+    }
+
     #[test]
     fn test_avro_iterator() {
         let reader = build_reader("alltypes_plain.avro", 5);
diff --git a/datafusion/core/src/datasource/dynamic_file.rs b/datafusion/core/src/datasource/dynamic_file.rs
index 3c409af29703..6654d0871c3f 100644
--- a/datafusion/core/src/datasource/dynamic_file.rs
+++ b/datafusion/core/src/datasource/dynamic_file.rs
@@ -69,11 +69,18 @@ impl UrlTableFactory for DynamicListTableFactory {
             .ok_or_else(|| plan_datafusion_err!("get current SessionStore error"))?;
         match ListingTableConfig::new(table_url.clone())
-            .infer(state)
+            .infer_options(state)
             .await
         {
-            Ok(cfg) => ListingTable::try_new(cfg)
-                .map(|table| Some(Arc::new(table) as Arc<dyn TableProvider>)),
+            Ok(cfg) => {
+                let cfg = cfg
+                    .infer_partitions_from_path(state)
+                    .await?
+                    .infer_schema(state)
+                    .await?;
+                ListingTable::try_new(cfg)
+                    .map(|table| Some(Arc::new(table) as Arc<dyn TableProvider>))
+            }
             Err(_) => Ok(None),
         }
     }
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs
index 6ee4280956e8..c10ebbd6c9ea 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -47,6 +47,7 @@ use datafusion_common::{
     not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION,
 };
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_expr::dml::InsertOp;
 use datafusion_physical_expr::PhysicalExpr;
 use datafusion_physical_plan::insert::{DataSink, DataSinkExec};
 use datafusion_physical_plan::metrics::MetricsSet;
@@ -181,7 +182,7 @@ impl FileFormat for ArrowFormat {
         conf: FileSinkConfig,
         order_requirements: Option<LexRequirement>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        if conf.overwrite {
+        if conf.insert_op != InsertOp::Append {
             return not_impl_err!("Overwrites are not implemented yet for Arrow format");
         }
diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs
index 99e8f13776fc..e821fa806fce 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -46,6 +46,7 @@ use datafusion_common::{
     exec_err, not_impl_err, DataFusionError, GetExt, DEFAULT_CSV_EXTENSION,
 };
 use datafusion_execution::TaskContext;
+use datafusion_expr::dml::InsertOp;
 use datafusion_physical_expr::PhysicalExpr;
 use datafusion_physical_plan::metrics::MetricsSet;
@@ -382,7 +383,7 @@ impl FileFormat for CsvFormat {
         conf: FileSinkConfig,
         order_requirements: Option<LexRequirement>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        if conf.overwrite {
+        if conf.insert_op != InsertOp::Append {
             return not_impl_err!("Overwrites are not implemented yet for CSV");
         }
diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs
index 4471d7d6cb31..c9ed0c0d2805 100644
--- a/datafusion/core/src/datasource/file_format/json.rs
+++ b/datafusion/core/src/datasource/file_format/json.rs
@@ -46,6 +46,7 @@ use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions};
 use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::{not_impl_err, GetExt, DEFAULT_JSON_EXTENSION};
 use datafusion_execution::TaskContext;
+use datafusion_expr::dml::InsertOp;
 use datafusion_physical_expr::PhysicalExpr;
 use datafusion_physical_plan::metrics::MetricsSet;
 use datafusion_physical_plan::ExecutionPlan;
@@ -252,7 +253,7 @@ impl FileFormat for JsonFormat {
         conf: FileSinkConfig,
         order_requirements: Option<LexRequirement>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        if conf.overwrite {
+        if conf.insert_op != InsertOp::Append {
             return not_impl_err!("Overwrites are not implemented yet for Json");
         }
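The same `InsertOp::Append` guard appears in each format sink (Arrow, CSV, and JSON above; Parquet below); third-party `FileFormat` implementations will want the analogous check. A sketch of the pattern (illustrative only; `MyFormat` is hypothetical):

    // Inside a custom `FileFormat::create_writer_physical_plan`:
    if conf.insert_op != InsertOp::Append {
        return not_impl_err!("{} is not implemented yet for MyFormat", conf.insert_op);
    }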
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 35296b0d7907..98ae0ce14bd7 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -53,6 +53,7 @@ use datafusion_common::{
 use datafusion_common_runtime::SpawnedTask;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation};
 use datafusion_execution::TaskContext;
+use datafusion_expr::dml::InsertOp;
 use datafusion_expr::Expr;
 use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
 use datafusion_physical_expr::PhysicalExpr;
@@ -403,7 +404,7 @@ impl FileFormat for ParquetFormat {
         conf: FileSinkConfig,
         order_requirements: Option<LexRequirement>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        if conf.overwrite {
+        if conf.insert_op != InsertOp::Append {
             return not_impl_err!("Overwrites are not implemented yet for Parquet");
         }
@@ -2269,7 +2270,7 @@ mod tests {
             table_paths: vec![ListingTableUrl::parse("file:///")?],
             output_schema: schema.clone(),
             table_partition_cols: vec![],
-            overwrite: true,
+            insert_op: InsertOp::Overwrite,
             keep_partition_by_columns: false,
         };
         let parquet_sink = Arc::new(ParquetSink::new(
@@ -2364,7 +2365,7 @@ mod tests {
             table_paths: vec![ListingTableUrl::parse("file:///")?],
             output_schema: schema.clone(),
             table_partition_cols: vec![("a".to_string(), DataType::Utf8)], // add partitioning
-            overwrite: true,
+            insert_op: InsertOp::Overwrite,
             keep_partition_by_columns: false,
         };
         let parquet_sink = Arc::new(ParquetSink::new(
@@ -2447,7 +2448,7 @@ mod tests {
             table_paths: vec![ListingTableUrl::parse("file:///")?],
             output_schema: schema.clone(),
             table_partition_cols: vec![],
-            overwrite: true,
+            insert_op: InsertOp::Overwrite,
             keep_partition_by_columns: false,
         };
         let parquet_sink = Arc::new(ParquetSink::new(
diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs
index 2a35fddeb033..a9c6aec17537 100644
--- a/datafusion/core/src/datasource/listing/table.rs
+++ b/datafusion/core/src/datasource/listing/table.rs
@@ -33,7 +33,8 @@ use crate::datasource::{
 };
 use crate::execution::context::SessionState;
 use datafusion_catalog::TableProvider;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{config_err, DataFusionError, Result};
+use datafusion_expr::dml::InsertOp;
 use datafusion_expr::{utils::conjunction, Expr, TableProviderFilterPushDown};
 use datafusion_expr::{SortExpr, TableType};
 use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics};
@@ -191,6 +192,38 @@ impl ListingTableConfig {
     pub async fn infer(self, state: &SessionState) -> Result<Self> {
         self.infer_options(state).await?.infer_schema(state).await
     }
+
+    /// Infer the partition columns from the path. Requires `self.options` to be set prior to using.
+    pub async fn infer_partitions_from_path(self, state: &SessionState) -> Result<Self> {
+        match self.options {
+            Some(options) => {
+                let Some(url) = self.table_paths.first() else {
+                    return config_err!("No table path found");
+                };
+                let partitions = options
+                    .infer_partitions(state, url)
+                    .await?
+                    .into_iter()
+                    .map(|col_name| {
+                        (
+                            col_name,
+                            DataType::Dictionary(
+                                Box::new(DataType::UInt16),
+                                Box::new(DataType::Utf8),
+                            ),
+                        )
+                    })
+                    .collect::<Vec<_>>();
+                let options = options.with_table_partition_cols(partitions);
+                Ok(Self {
+                    table_paths: self.table_paths,
+                    file_schema: self.file_schema,
+                    options: Some(options),
+                })
+            }
+            None => config_err!("No `ListingOptions` set for inferring schema"),
+        }
+    }
 }
 /// Options for creating a [`ListingTable`]
@@ -504,7 +537,7 @@ impl ListingOptions {
     /// Infer the partitioning at the given path on the provided object store.
     /// For performance reasons, it doesn't read all the files on disk
     /// and therefore may fail to detect invalid partitioning.
-    async fn infer_partitions(
+    pub(crate) async fn infer_partitions(
         &self,
         state: &SessionState,
         table_path: &ListingTableUrl,
@@ -916,7 +949,7 @@ impl TableProvider for ListingTable {
         &self,
         state: &dyn Session,
         input: Arc<dyn ExecutionPlan>,
-        overwrite: bool,
+        insert_op: InsertOp,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         // Check that the schema of the plan matches the schema of this table.
         if !self
@@ -975,7 +1008,7 @@ impl TableProvider for ListingTable {
             file_groups,
             output_schema: self.schema(),
             table_partition_cols: self.options.table_partition_cols.clone(),
-            overwrite,
+            insert_op,
             keep_partition_by_columns,
         };
@@ -1990,7 +2023,8 @@ mod tests {
         // Therefore, we will have 8 partitions in the final plan.
         // Create an insert plan to insert the source data into the initial table
         let insert_into_table =
-            LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?;
+            LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)?
+                .build()?;
         // Create a physical plan from the insert plan
         let plan = session_ctx
             .state()
diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index 70f3c36b81e1..24a4938e7b2b 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -39,6 +39,7 @@ use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt};
 use datafusion_execution::TaskContext;
+use datafusion_expr::dml::InsertOp;
 use datafusion_physical_plan::metrics::MetricsSet;
 use async_trait::async_trait;
@@ -262,7 +263,7 @@ impl TableProvider for MemTable {
         &self,
         _state: &dyn Session,
         input: Arc<dyn ExecutionPlan>,
-        overwrite: bool,
+        insert_op: InsertOp,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         // If we are inserting into the table, any sort order may be messed up so reset it here
         *self.sort_order.lock() = vec![];
@@ -289,8 +290,8 @@ impl TableProvider for MemTable {
                     .collect::<Vec<_>>()
             );
         }
-        if overwrite {
-            return not_impl_err!("Overwrite not implemented for MemoryTable yet");
+        if insert_op != InsertOp::Append {
+            return not_impl_err!("{insert_op} not implemented for MemoryTable yet");
         }
         let sink = Arc::new(MemSink::new(self.batches.clone()));
         Ok(Arc::new(DataSinkExec::new(
@@ -638,7 +639,8 @@ mod tests {
         let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?;
         // Create an insert plan to insert the source data into the initial table
         let insert_into_table =
-            LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?;
+            LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)?
+                .build()?;
         // Create a physical plan from the insert plan
         let plan = session_ctx
             .state()
diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs
index 4018b3bb2920..6e8752ccfbf4 100644
--- a/datafusion/core/src/datasource/physical_plan/mod.rs
+++ b/datafusion/core/src/datasource/physical_plan/mod.rs
@@ -36,6 +36,7 @@ pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactor
 pub use arrow_file::ArrowExec;
 pub use avro::AvroExec;
 pub use csv::{CsvConfig, CsvExec, CsvExecBuilder, CsvOpener};
+use datafusion_expr::dml::InsertOp;
 pub use file_groups::FileGroupPartitioner;
 pub use file_scan_config::{
     wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig,
@@ -83,8 +84,9 @@ pub struct FileSinkConfig {
     /// A vector of column names and their corresponding data types,
     /// representing the partitioning columns for the file
     pub table_partition_cols: Vec<(String, DataType)>,
-    /// Controls whether existing data should be overwritten by this sink
-    pub overwrite: bool,
+    /// Controls how new data should be written to the file, determining whether
+    /// to append to, overwrite, or replace records in existing files.
+    pub insert_op: InsertOp,
     /// Controls whether partition columns are kept for the file
     pub keep_partition_by_columns: bool,
 }
diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs
index d30247e2c67a..34023fbbb620 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -33,6 +33,7 @@ use arrow_schema::SchemaRef;
 use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result};
 use datafusion_common_runtime::SpawnedTask;
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_expr::dml::InsertOp;
 use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType};
 use datafusion_physical_plan::insert::{DataSink, DataSinkExec};
 use datafusion_physical_plan::metrics::MetricsSet;
@@ -350,7 +351,7 @@ impl TableProvider for StreamTable {
         &self,
         _state: &dyn Session,
         input: Arc<dyn ExecutionPlan>,
-        _overwrite: bool,
+        _insert_op: InsertOp,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let ordering = match self.0.order.first() {
             Some(x) => {
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index cffb63f52047..4953eecd66e3 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -174,27 +174,30 @@ pub struct SessionState {
 }
 impl Debug for SessionState {
+    /// Prefer having short fields at the top and long vector fields near the end,
+    /// and group related fields together.
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("SessionState")
             .field("session_id", &self.session_id)
             .field("config", &self.config)
             .field("runtime_env", &self.runtime_env)
-            .field("catalog_list", &"...")
-            .field("serializer_registry", &"...")
+            .field("catalog_list", &self.catalog_list)
+            .field("serializer_registry", &self.serializer_registry)
+            .field("file_formats", &self.file_formats)
             .field("execution_props", &self.execution_props)
             .field("table_options", &self.table_options)
-            .field("table_factories", &"...")
-            .field("function_factory", &"...")
-            .field("expr_planners", &"...")
-            .field("query_planner", &"...")
-            .field("analyzer", &"...")
-            .field("optimizer", &"...")
-            .field("physical_optimizers", &"...")
-            .field("table_functions", &"...")
+            .field("table_factories", &self.table_factories)
+            .field("function_factory", &self.function_factory)
+            .field("expr_planners", &self.expr_planners)
+            .field("query_planner", &self.query_planner)
+            .field("analyzer", &self.analyzer)
+            .field("optimizer", &self.optimizer)
+            .field("physical_optimizers", &self.physical_optimizers)
+            .field("table_functions", &self.table_functions)
             .field("scalar_functions", &self.scalar_functions)
             .field("aggregate_functions", &self.aggregate_functions)
             .field("window_functions", &self.window_functions)
-            .finish_non_exhaustive()
+            .finish()
     }
 }
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index b2b912d8add2..78c70606bf68 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -71,7 +71,7 @@ use datafusion_common::{
     exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
     ScalarValue,
 };
-use datafusion_expr::dml::CopyTo;
+use datafusion_expr::dml::{CopyTo, InsertOp};
 use datafusion_expr::expr::{
     physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction,
 };
@@ -529,7 +529,7 @@ impl DefaultPhysicalPlanner {
                     file_groups: vec![],
                     output_schema: Arc::new(schema),
                     table_partition_cols,
-                    overwrite: false,
+                    insert_op: InsertOp::Append,
                     keep_partition_by_columns,
                 };
@@ -542,7 +542,7 @@ impl DefaultPhysicalPlanner {
             }
             LogicalPlan::Dml(DmlStatement {
                 table_name,
-                op: WriteOp::InsertInto,
+                op: WriteOp::Insert(insert_op),
                 ..
             }) => {
                 let name = table_name.table();
@@ -550,23 +550,7 @@ impl DefaultPhysicalPlanner {
                 if let Some(provider) = schema.table(name).await? {
                     let input_exec = children.one()?;
                     provider
-                        .insert_into(session_state, input_exec, false)
-                        .await?
-                } else {
-                    return exec_err!("Table '{table_name}' does not exist");
-                }
-            }
-            LogicalPlan::Dml(DmlStatement {
-                table_name,
-                op: WriteOp::InsertOverwrite,
-                ..
-            }) => {
-                let name = table_name.table();
-                let schema = session_state.schema_for_ref(table_name.clone())?;
-                if let Some(provider) = schema.table(name).await? {
-                    let input_exec = children.one()?;
-                    provider
-                        .insert_into(session_state, input_exec, true)
+                        .insert_into(session_state, input_exec, *insert_op)
                         .await?
                 } else {
                     return exec_err!("Table '{table_name}' does not exist");
@@ -2573,6 +2557,10 @@ mod tests {
     ) -> Result<Self> {
         unimplemented!("NoOp");
     }
+
+    fn supports_limit_pushdown(&self) -> bool {
+        false // Disallow limit push-down by default
+    }
 }
 #[derive(Debug)]
diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs
new file mode 100644
index 000000000000..ff14fa0be3fb
--- /dev/null
+++ b/datafusion/core/tests/user_defined/insert_operation.rs
@@ -0,0 +1,188 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{any::Any, sync::Arc};
+
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use async_trait::async_trait;
+use datafusion::{
+    error::Result,
+    prelude::{SessionConfig, SessionContext},
+};
+use datafusion_catalog::{Session, TableProvider};
+use datafusion_expr::{dml::InsertOp, Expr, TableType};
+use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
+use datafusion_physical_plan::{DisplayAs, ExecutionMode, ExecutionPlan, PlanProperties};
+
+#[tokio::test]
+async fn insert_operation_is_passed_correctly_to_table_provider() {
+    // Use the SQLite syntax so we can test the "INSERT OR REPLACE INTO" syntax
+    let ctx = session_ctx_with_dialect("SQLite");
+    let table_provider = Arc::new(TestInsertTableProvider::new());
+    ctx.register_table("testing", table_provider.clone())
+        .unwrap();
+
+    let sql = "INSERT INTO testing (column) VALUES (1)";
+    assert_insert_op(&ctx, sql, InsertOp::Append).await;
+
+    let sql = "INSERT OVERWRITE testing (column) VALUES (1)";
+    assert_insert_op(&ctx, sql, InsertOp::Overwrite).await;
+
+    let sql = "REPLACE INTO testing (column) VALUES (1)";
+    assert_insert_op(&ctx, sql, InsertOp::Replace).await;
+
+    let sql = "INSERT OR REPLACE INTO testing (column) VALUES (1)";
+    assert_insert_op(&ctx, sql, InsertOp::Replace).await;
+}
+
+async fn assert_insert_op(ctx: &SessionContext, sql: &str, insert_op: InsertOp) {
+    let df = ctx.sql(sql).await.unwrap();
+    let plan = df.create_physical_plan().await.unwrap();
+    let exec = plan.as_any().downcast_ref::<TestInsertExec>().unwrap();
+    assert_eq!(exec.op, insert_op);
+}
+
+fn session_ctx_with_dialect(dialect: impl Into<String>) -> SessionContext {
+    let mut config = SessionConfig::new();
+    let options = config.options_mut();
+    options.sql_parser.dialect = dialect.into();
+    SessionContext::new_with_config(config)
+}
+
+#[derive(Debug)]
+struct TestInsertTableProvider {
+    schema: SchemaRef,
+}
+
+impl TestInsertTableProvider {
+    fn new() -> Self {
+        Self {
+            schema: SchemaRef::new(Schema::new(vec![Field::new(
+                "column",
+                DataType::Int64,
+                false,
+            )])),
+        }
+    }
+}
+
+#[async_trait]
+impl TableProvider for TestInsertTableProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    async fn scan(
+        &self,
+        _state: &dyn Session,
+        _projection: Option<&Vec<usize>>,
+        _filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        unimplemented!("TestInsertTableProvider is a stub for testing.")
+    }
+
+    async fn insert_into(
+        &self,
+        _state: &dyn Session,
+        _input: Arc<dyn ExecutionPlan>,
+        insert_op: InsertOp,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(TestInsertExec::new(insert_op)))
+    }
+}
+
+#[derive(Debug)]
+struct TestInsertExec {
+    op: InsertOp,
+    plan_properties: PlanProperties,
+}
+
+impl TestInsertExec {
+    fn new(op: InsertOp) -> Self {
+        let eq_properties = EquivalenceProperties::new(make_count_schema());
+        let plan_properties = PlanProperties::new(
+            eq_properties,
+            Partitioning::UnknownPartitioning(1),
+            ExecutionMode::Bounded,
+        );
+        Self {
+            op,
+            plan_properties,
+        }
+    }
+}
+
+impl DisplayAs for TestInsertExec {
+    fn fmt_as(
+        &self,
+        _t: datafusion_physical_plan::DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        write!(f, "TestInsertExec")
+    }
+}
+
+impl ExecutionPlan for TestInsertExec {
+    fn name(&self) -> &str {
+        "TestInsertExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.plan_properties
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        assert!(children.is_empty());
+        Ok(self)
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        unimplemented!("TestInsertExec is a stub for testing.")
+    }
+}
+
+fn make_count_schema() -> SchemaRef {
+    Arc::new(Schema::new(vec![Field::new(
+        "count",
+        DataType::UInt64,
+        false,
+    )]))
+}
diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs
index 56cec8df468b..5d84cdb69283 100644
--- a/datafusion/core/tests/user_defined/mod.rs
+++ b/datafusion/core/tests/user_defined/mod.rs
@@ -32,3 +32,6 @@ mod user_defined_table_functions;
 /// Tests for Expression Planner
 mod expr_planner;
+
+/// Tests for insert operations
+mod insert_operation;
diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs
index e51adbc4ddc1..2b45d0ed600b 100644
--- a/datafusion/core/tests/user_defined/user_defined_plan.rs
+++ b/datafusion/core/tests/user_defined/user_defined_plan.rs
@@ -443,6 +443,10 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
             expr: replace_sort_expression(self.expr.clone(), exprs.swap_remove(0)),
         })
     }
+
+    fn supports_limit_pushdown(&self) -> bool {
+        false // Disallow limit push-down by default
+    }
 }
 /// Physical planner for TopK nodes
diff --git a/datafusion/execution/src/stream.rs b/datafusion/execution/src/stream.rs
index 7fc5e458b86b..f3eb7b77e03c 100644
--- a/datafusion/execution/src/stream.rs
+++ b/datafusion/execution/src/stream.rs
@@ -20,7 +20,9 @@ use datafusion_common::Result;
 use futures::Stream;
 use std::pin::Pin;
-/// Trait for types that stream [arrow::record_batch::RecordBatch]
+/// Trait for types that stream [RecordBatch]
+///
+/// See [`SendableRecordBatchStream`] for more details.
 pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
     /// Returns the schema of this `RecordBatchStream`.
     ///
@@ -29,5 +31,23 @@ pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
     fn schema(&self) -> SchemaRef;
 }
-/// Trait for a [`Stream`] of [`RecordBatch`]es
+/// Trait for a [`Stream`] of [`RecordBatch`]es that can be passed between threads
+///
+/// This trait is used to retrieve the results of DataFusion execution plan nodes.
+///
+/// The trait is a specialized Rust Async [`Stream`] that also knows the schema
+/// of the data it will return (even if the stream has no data). Every
+/// `RecordBatch` returned by the stream should have the same schema as returned
+/// by [`schema`](`RecordBatchStream::schema`).
+///
+/// # Error Handling
+///
+/// Once a stream returns an error, it should not be polled again: the caller
+/// should stop calling `next` and handle the error.
+///
+/// Returning `Ready(None)` (end of stream) is likely the safest behavior after
+/// an error. Like [`Stream`]s, `RecordBatchStream`s should not be polled after
+/// end of stream or after returning an error, but also like [`Stream`]s there
+/// is no mechanism to prevent callers from polling, so returning `Ready(None)`
+/// is recommended.
 pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
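A minimal consumer sketch for the contract documented above (illustrative, not part of the patch; the `plan` and `task_ctx` values are assumed to exist):

    use futures::StreamExt;

    // `ExecutionPlan::execute` returns a `SendableRecordBatchStream`.
    let mut stream = plan.execute(0, task_ctx)?;
    while let Some(item) = stream.next().await {
        match item {
            Ok(batch) => println!("got {} rows", batch.num_rows()),
            Err(e) => {
                // Per the contract above, stop polling after the first error.
                eprintln!("stream error: {e}");
                break;
            }
        }
    }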
diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml
index 55387fea22ee..d7dc1afe4d50 100644
--- a/datafusion/expr/Cargo.toml
+++ b/datafusion/expr/Cargo.toml
@@ -48,6 +48,7 @@ datafusion-expr-common = { workspace = true }
 datafusion-functions-aggregate-common = { workspace = true }
 datafusion-functions-window-common = { workspace = true }
 datafusion-physical-expr-common = { workspace = true }
+indexmap = { workspace = true }
 paste = "^1.0"
 serde_json = { workspace = true }
 sqlparser = { workspace = true }
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 260065f69af9..7d94a3b93eab 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -90,9 +90,9 @@ pub use logical_plan::*;
 pub use partition_evaluator::PartitionEvaluator;
 pub use sqlparser;
 pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
-pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
+pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs};
 pub use udf::{ScalarUDF, ScalarUDFImpl};
-pub use udwf::{WindowUDF, WindowUDFImpl};
+pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl};
 pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
 #[cfg(test)]
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index ad96f6a85d0e..cc8ddf8ec8e8 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -54,6 +54,7 @@ use datafusion_common::{
     TableReference, ToDFSchema, UnnestOptions,
 };
+use super::dml::InsertOp;
 use super::plan::{ColumnUnnestList, ColumnUnnestType};
 /// Default table name for unnamed table
@@ -307,20 +308,14 @@ impl LogicalPlanBuilder {
         input: LogicalPlan,
         table_name: impl Into<TableReference>,
         table_schema: &Schema,
-        overwrite: bool,
+        insert_op: InsertOp,
     ) -> Result<Self> {
         let table_schema = table_schema.clone().to_dfschema_ref()?;
-        let op = if overwrite {
-            WriteOp::InsertOverwrite
-        } else {
-            WriteOp::InsertInto
-        };
-
         Ok(Self::new(LogicalPlan::Dml(DmlStatement::new(
             table_name.into(),
             table_schema,
-            op,
+            WriteOp::Insert(insert_op),
             Arc::new(input),
         ))))
     }
diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs
index c2ed9dc0781c..68b3ac41fa08 100644
--- a/datafusion/expr/src/logical_plan/dml.rs
+++ b/datafusion/expr/src/logical_plan/dml.rs
@@ -146,8 +146,7 @@ impl PartialOrd for DmlStatement {
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
 pub enum WriteOp {
-    InsertOverwrite,
-    InsertInto,
+    Insert(InsertOp),
     Delete,
     Update,
     Ctas,
@@ -157,8 +156,7 @@ impl WriteOp {
     /// Return a descriptive name of this [`WriteOp`]
     pub fn name(&self) -> &str {
         match self {
-            WriteOp::InsertOverwrite => "Insert Overwrite",
-            WriteOp::InsertInto => "Insert Into",
+            WriteOp::Insert(insert) => insert.name(),
             WriteOp::Delete => "Delete",
             WriteOp::Update => "Update",
             WriteOp::Ctas => "Ctas",
@@ -172,6 +170,37 @@ impl Display for WriteOp {
 }
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
+pub enum InsertOp {
+    /// Appends new rows to the existing table without modifying any
+    /// existing rows. This corresponds to the SQL `INSERT INTO` query.
+    Append,
+    /// Overwrites all existing rows in the table with the new rows.
+    /// This corresponds to the SQL `INSERT OVERWRITE` query.
+    Overwrite,
+    /// If any existing rows collide with the inserted rows (typically based
+    /// on a unique key or primary key), those existing rows are replaced.
+    /// This corresponds to the SQL `REPLACE INTO` query and its equivalents.
+    Replace,
+}
+
+impl InsertOp {
+    /// Return a descriptive name of this [`InsertOp`]
+    pub fn name(&self) -> &str {
+        match self {
+            InsertOp::Append => "Insert Into",
+            InsertOp::Overwrite => "Insert Overwrite",
+            InsertOp::Replace => "Replace Into",
+        }
+    }
+}
+
+impl Display for InsertOp {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.name())
+    }
+}
+
 fn make_count_schema() -> DFSchemaRef {
     Arc::new(
         Schema::new(vec![Field::new("count", DataType::UInt64, false)])
diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs
index d49c85fb6fd6..19d4cb3db9ce 100644
--- a/datafusion/expr/src/logical_plan/extension.rs
+++ b/datafusion/expr/src/logical_plan/extension.rs
@@ -195,6 +195,16 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
     /// directly because it must remain object safe.
     fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool;
     fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>;
+
+    /// Returns `true` if a limit can be safely pushed down through this
+    /// `UserDefinedLogicalNode` node.
+    ///
+    /// If this method returns `true`, and the query plan contains a limit at
+    /// the output of this node, DataFusion will push the limit to the input
+    /// of this node.
+    fn supports_limit_pushdown(&self) -> bool {
+        false
+    }
 }
 impl Hash for dyn UserDefinedLogicalNode {
@@ -295,6 +305,16 @@ pub trait UserDefinedLogicalNodeCore:
     ) -> Option<Vec<Vec<usize>>> {
         None
     }
+
+    /// Returns `true` if a limit can be safely pushed down through this
+    /// `UserDefinedLogicalNode` node.
+    ///
+    /// If this method returns `true`, and the query plan contains a limit at
+    /// the output of this node, DataFusion will push the limit to the input
+    /// of this node.
+    fn supports_limit_pushdown(&self) -> bool {
+        false // Disallow limit push-down by default
+    }
 }
 /// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode`
@@ -361,6 +381,10 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
             .downcast_ref::<Self>()
             .and_then(|other| self.partial_cmp(other))
     }
+
+    fn supports_limit_pushdown(&self) -> bool {
+        self.supports_limit_pushdown()
+    }
 }
 fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> {
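A self-contained sketch of a node opting in to the new hook (illustrative only; `MyPassThroughNode` and its 1:1 row semantics are hypothetical, and the method set assumes the trait as extended above):

    use std::fmt;

    use datafusion_common::{DFSchemaRef, Result};
    use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore};

    // A node that emits exactly the rows of its input.
    #[derive(Debug, Hash, PartialEq, Eq, PartialOrd)]
    struct MyPassThroughNode {
        input: LogicalPlan,
    }

    impl UserDefinedLogicalNodeCore for MyPassThroughNode {
        fn name(&self) -> &str {
            "MyPassThroughNode"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            vec![&self.input]
        }

        fn schema(&self) -> &DFSchemaRef {
            self.input.schema()
        }

        fn expressions(&self) -> Vec<Expr> {
            vec![]
        }

        fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "MyPassThroughNode")
        }

        fn with_exprs_and_inputs(
            &self,
            _exprs: Vec<Expr>,
            mut inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            Ok(Self { input: inputs.swap_remove(0) })
        }

        // Row-preserving nodes can safely let a limit pass through them.
        fn supports_limit_pushdown(&self) -> bool {
            true
        }
    }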
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 443d23804adb..19e73140b75c 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -51,6 +51,7 @@ use datafusion_common::{
     DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence,
     FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions,
 };
+use indexmap::IndexSet;
 // backwards compatibility
 use crate::display::PgJsonVisitor;
@@ -3071,6 +3072,8 @@ fn calc_func_dependencies_for_aggregate(
         let group_by_expr_names = group_expr
             .iter()
             .map(|item| item.schema_name().to_string())
+            .collect::<IndexSet<_>>()
+            .into_iter()
             .collect::<Vec<_>>();
         let aggregate_func_dependencies = aggregate_functional_dependencies(
             input.schema(),
diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs
index d30d202df050..8ac6ad372482 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -602,89 +602,48 @@ fn coerced_from<'a>(
             Some(type_into.clone())
         }
         // coerced into type_into
-        (Int8, _) if matches!(type_from, Null | Int8) => Some(type_into.clone()),
-        (Int16, _) if matches!(type_from, Null | Int8 | Int16 | UInt8) => {
-            Some(type_into.clone())
-        }
-        (Int32, _)
-            if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) =>
-        {
-            Some(type_into.clone())
-        }
-        (Int64, _)
-            if matches!(
-                type_from,
-                Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
-            ) =>
-        {
-            Some(type_into.clone())
-        }
-        (UInt8, _) if matches!(type_from, Null | UInt8) => Some(type_into.clone()),
-        (UInt16, _) if matches!(type_from, Null | UInt8 | UInt16) => {
-            Some(type_into.clone())
-        }
-        (UInt32, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => {
-            Some(type_into.clone())
-        }
-        (UInt64, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => {
-            Some(type_into.clone())
-        }
-        (Float32, _)
-            if matches!(
-                type_from,
-                Null | Int8
-                    | Int16
-                    | Int32
-                    | Int64
-                    | UInt8
-                    | UInt16
-                    | UInt32
-                    | UInt64
-                    | Float32
-            ) =>
-        {
-            Some(type_into.clone())
-        }
-        (Float64, _)
-            if matches!(
-                type_from,
-                Null | Int8
-                    | Int16
-                    | Int32
-                    | Int64
-                    | UInt8
-                    | UInt16
-                    | UInt32
-                    | UInt64
-                    | Float32
-                    | Float64
-                    | Decimal128(_, _)
-            ) =>
-        {
-            Some(type_into.clone())
-        }
-        (Timestamp(TimeUnit::Nanosecond, None), _)
-            if matches!(
-                type_from,
-                Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8
-            ) =>
-        {
-            Some(type_into.clone())
-        }
-        (Interval(_), _) if matches!(type_from, Utf8 | LargeUtf8) => {
+        (Int8, Null | Int8) => Some(type_into.clone()),
+        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
+        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
+        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
             Some(type_into.clone())
         }
+        (UInt8, Null | UInt8) => Some(type_into.clone()),
+        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
+        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
+        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
+        (
+            Float32,
+            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
+            | Float32,
+        ) => Some(type_into.clone()),
+        (
+            Float64,
+            Null
+            | Int8
+            | Int16
+            | Int32
+            | Int64
+            | UInt8
+            | UInt16
+            | UInt32
+            | UInt64
+            | Float32
+            | Float64
+            | Decimal128(_, _),
+        ) => Some(type_into.clone()),
+        (
+            Timestamp(TimeUnit::Nanosecond, None),
+            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
+        ) => Some(type_into.clone()),
+        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
         // We can go into a Utf8View from a Utf8 or LargeUtf8
-        (Utf8View, _) if matches!(type_from, Utf8 | LargeUtf8 | Null) => {
-            Some(type_into.clone())
-        }
+        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
         // Any type can be coerced into strings
         (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
         (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
-        (List(_), _) if matches!(type_from, FixedSizeList(_, _)) => {
-            Some(type_into.clone())
-        }
+        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
         // Only accept list and largelist with the same number of dimensions unless the type is Null.
         // List or LargeList with different dimensions should be handled in TypeSignature or other places before this
@@ -695,18 +654,16 @@
             Some(type_into.clone())
         }
         // should be able to coerce wildcard fixed size list to non wildcard fixed size list
-        (FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD), _) => match type_from {
-            FixedSizeList(f_from, size_from) => {
-                match coerced_from(f_into.data_type(), f_from.data_type()) {
-                    Some(data_type) if &data_type != f_into.data_type() => {
-                        let new_field =
-                            Arc::new(f_into.as_ref().clone().with_data_type(data_type));
-                        Some(FixedSizeList(new_field, *size_from))
-                    }
-                    Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
-                    _ => None,
-                }
+        (
+            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
+            FixedSizeList(f_from, size_from),
+        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
+            Some(data_type) if &data_type != f_into.data_type() => {
+                let new_field =
+                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
+                Some(FixedSizeList(new_field, *size_from))
             }
+            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
             _ => None,
         },
         (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
@@ -721,12 +678,7 @@ fn coerced_from<'a>(
                 _ => None,
             }
         }
-        (Timestamp(_, Some(_)), _)
-            if matches!(
-                type_from,
-                Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8
-            ) =>
-        {
+        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
             Some(type_into.clone())
         }
         _ => None,
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index e3ef672daf5f..780ea36910a4 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -26,7 +26,8 @@ use std::vec;
 use arrow::datatypes::{DataType, Field};
-use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
+use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use crate::expr::AggregateFunction;
 use crate::function::{
@@ -94,6 +95,22 @@ impl fmt::Display for AggregateUDF {
 }
+/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
+pub struct StatisticsArgs<'a> {
+    /// The statistics of the aggregate input
+    pub statistics: &'a Statistics,
+    /// The resolved return type of the aggregate function
function
+    pub return_type: &'a DataType,
+    /// Whether the aggregate function is distinct.
+    ///
+    /// ```sql
+    /// SELECT COUNT(DISTINCT column1) FROM t;
+    /// ```
+    pub is_distinct: bool,
+    /// The physical expressions of the arguments the aggregate function takes.
+    pub exprs: &'a [Arc<dyn PhysicalExpr>],
+}
+
 impl AggregateUDF {
     /// Create a new `AggregateUDF` from an [`AggregateUDFImpl`] trait object
     ///
@@ -237,13 +254,23 @@ impl AggregateUDF {
     }
 
     /// Returns true if the function is max, false if the function is min
-    /// None in all other cases, used in certain optimizations or
+    /// None in all other cases, used in certain optimizations for
     /// or aggregate
-    ///
    pub fn is_descending(&self) -> Option<bool> {
        self.inner.is_descending()
    }
 
+    /// Return the value of this aggregate function if it can be determined
+    /// entirely from statistics and arguments.
+    ///
+    /// See [`AggregateUDFImpl::value_from_stats`] for more details.
+    pub fn value_from_stats(
+        &self,
+        statistics_args: &StatisticsArgs,
+    ) -> Option<ScalarValue> {
+        self.inner.value_from_stats(statistics_args)
+    }
+
     /// See [`AggregateUDFImpl::default_value`] for more details.
     pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
         self.inner.default_value(data_type)
@@ -557,6 +584,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
         None
     }
 
+    /// Return the value of this aggregate function if it can be determined
+    /// entirely from statistics and arguments.
+    ///
+    /// Using a [`ScalarValue`] rather than a runtime computation can significantly
+    /// improve query performance.
+    ///
+    /// For example, if the minimum value of column `x` is known to be `42` from
+    /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`.
+    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
+        None
+    }
+
     /// Returns default value of the function given the input is all `null`.
     ///
     /// Most of the aggregate functions return Null if input is Null,
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 7cc57523a14d..678a0b62cd9a 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -172,6 +172,14 @@ impl WindowUDF {
     pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
         self.inner.coerce_types(arg_types)
     }
+
+    /// Returns the reversed user-defined window function when the
+    /// order of evaluation is reversed.
+    ///
+    /// See [`WindowUDFImpl::reverse_expr`] for more details.
+    pub fn reverse_expr(&self) -> ReversedUDWF {
+        self.inner.reverse_expr()
+    }
 }
 
 impl<F> From<F> for WindowUDF
@@ -351,6 +359,24 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
     fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
         not_impl_err!("Function {} does not implement coerce_types", self.name())
     }
+
+    /// Allows customizing the behavior of the user-defined window
+    /// function when it is evaluated in reverse order.
+    fn reverse_expr(&self) -> ReversedUDWF {
+        ReversedUDWF::NotSupported
+    }
+}
+
+pub enum ReversedUDWF {
+    /// The result of evaluating the user-defined window function
+    /// remains identical when reversed.
+    Identical,
+    /// A window function which does not support evaluating the result
+    /// in reverse order.
+    NotSupported,
+    /// Customize the user-defined window function for evaluating the
+    /// result in reverse order.
+    Reversed(Arc<dyn WindowUDFImpl>),
+}
 
 impl PartialEq for dyn WindowUDFImpl {
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 1d8eb9445eda..9bb53a1d04a0 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -38,6 +38,7 @@ use datafusion_common::{
     DataFusionError, Result, TableReference,
 };
+use indexmap::IndexSet;
 use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
 
 pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
@@ -59,16 +60,7 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
 /// Count the number of distinct exprs in a list of group by expressions. If the
 /// first element is a `GroupingSet` expression then it must be the only expr.
 pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
-    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
-        if group_expr.len() > 1 {
-            return plan_err!(
-                "Invalid group by expressions, GroupingSet must be the only expression"
-            );
-        }
-        Ok(grouping_set.distinct_expr().len())
-    } else {
-        Ok(group_expr.len())
-    }
+    grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
 }
 
 /// The [power set] (or powerset) of a set S is the set of all subsets of S, \
@@ -260,7 +252,11 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
         }
         Ok(grouping_set.distinct_expr())
     } else {
-        Ok(group_expr.iter().collect())
+        Ok(group_expr
+            .iter()
+            .collect::<IndexSet<_>>()
+            .into_iter()
+            .collect())
     }
 }
diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml
index d78f68a2604e..33a52afbe21a 100644
--- a/datafusion/functions-aggregate/Cargo.toml
+++ b/datafusion/functions-aggregate/Cargo.toml
@@ -50,7 +50,6 @@ datafusion-physical-expr-common = { workspace = true }
 half = { workspace = true }
 log = { workspace = true }
 paste = "1.0.14"
-sqlparser = { workspace = true }
 
 [dev-dependencies]
 arrow = { workspace = true, features = ["test_utils"] }
diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index 417e28e72a71..cc245b3572ec 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -16,7 +16,9 @@
 // under the License.
use ahash::RandomState; +use datafusion_common::stats::Precision; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; +use datafusion_physical_expr::expressions; use std::collections::HashSet; use std::ops::BitAnd; use std::{fmt::Debug, sync::Arc}; @@ -46,7 +48,7 @@ use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, Volatility, }; -use datafusion_expr::{Expr, ReversedUDAF, TypeSignature}; +use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature}; use datafusion_functions_aggregate_common::aggregate::count_distinct::{ BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, @@ -54,6 +56,7 @@ use datafusion_functions_aggregate_common::aggregate::count_distinct::{ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; use datafusion_physical_expr_common::binary_map::OutputType; +use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; make_udaf_expr_and_func!( Count, count, @@ -291,6 +294,36 @@ impl AggregateUDFImpl for Count { fn default_value(&self, _data_type: &DataType) -> Result { Ok(ScalarValue::Int64(Some(0))) } + + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + if statistics_args.is_distinct { + return None; + } + if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows { + if statistics_args.exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + let current_val = &statistics_args.statistics.column_statistics + [col_expr.index()] + .null_count; + if let &Precision::Exact(val) = current_val { + return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); + } + } else if let Some(lit_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(ScalarValue::Int64(Some(num_rows as i64))); + } + } + } + } + None + } } #[derive(Debug)] diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 573b9fd5bdb2..ffb5183278e6 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -15,23 +15,6 @@ // specific language governing permissions and limitations // under the License. -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - macro_rules! 
make_udaf_expr { ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 961e8639604c..1ce1abe09ea8 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -15,7 +15,7 @@ // under the License. //! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function -//! [`Min`] and [`MinAccumulator`] accumulator for the `max` function +//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file @@ -49,10 +49,12 @@ use arrow::datatypes::{ UInt8Type, }; use arrow_schema::IntervalUnit; +use datafusion_common::stats::Precision; use datafusion_common::{ - downcast_value, exec_err, internal_err, DataFusionError, Result, + downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_physical_expr::expressions; use std::fmt::Debug; use arrow::datatypes::i256; @@ -63,10 +65,10 @@ use arrow::datatypes::{ }; use datafusion_common::ScalarValue; -use datafusion_expr::GroupsAccumulator; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, }; +use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; use half::f16; use std::ops::Deref; @@ -147,6 +149,54 @@ macro_rules! instantiate_min_accumulator { }}; } +trait FromColumnStatistics { + fn value_from_column_statistics( + &self, + stats: &ColumnStatistics, + ) -> Option; + + fn value_from_statistics( + &self, + statistics_args: &StatisticsArgs, + ) -> Option { + if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows { + match *num_rows { + 0 => return ScalarValue::try_from(statistics_args.return_type).ok(), + value if value > 0 => { + let col_stats = &statistics_args.statistics.column_statistics; + if statistics_args.exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + return self.value_from_column_statistics( + &col_stats[col_expr.index()], + ); + } + } + } + _ => {} + } + } + None + } +} + +impl FromColumnStatistics for Max { + fn value_from_column_statistics( + &self, + col_stats: &ColumnStatistics, + ) -> Option { + if let Precision::Exact(ref val) = col_stats.max_value { + if !val.is_null() { + return Some(val.clone()); + } + } + None + } +} + impl AggregateUDFImpl for Max { fn as_any(&self) -> &dyn std::any::Any { self @@ -272,6 +322,7 @@ impl AggregateUDFImpl for Max { fn is_descending(&self) -> Option { Some(true) } + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } @@ -282,6 +333,9 @@ impl AggregateUDFImpl for Max { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Identical } + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + self.value_from_statistics(statistics_args) + } } // Statically-typed version of min/max(array) -> ScalarValue for string types @@ -926,6 +980,20 @@ impl Default for Min { } } +impl FromColumnStatistics for Min { + fn value_from_column_statistics( + &self, + col_stats: &ColumnStatistics, + ) -> Option { 
+ if let Precision::Exact(ref val) = col_stats.min_value { + if !val.is_null() { + return Some(val.clone()); + } + } + None + } +} + impl AggregateUDFImpl for Min { fn as_any(&self) -> &dyn std::any::Any { self @@ -1052,6 +1120,9 @@ impl AggregateUDFImpl for Min { Some(false) } + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + self.value_from_statistics(statistics_args) + } fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index 8dcec6bc964b..952e5720c77c 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -43,6 +43,7 @@ datafusion-expr = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } +paste = "1.0.15" [dev-dependencies] arrow = { workspace = true } diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index 790a500f1f3f..6e98bb091446 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -29,6 +29,8 @@ use log::debug; use datafusion_expr::registry::FunctionRegistry; use datafusion_expr::WindowUDF; +#[macro_use] +pub mod macros; pub mod row_number; /// Fluent-style API for creating `Expr`s diff --git a/datafusion/functions-window/src/macros.rs b/datafusion/functions-window/src/macros.rs new file mode 100644 index 000000000000..843d8ecb38cc --- /dev/null +++ b/datafusion/functions-window/src/macros.rs @@ -0,0 +1,674 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Convenience macros for defining a user-defined window function +//! and associated expression API (fluent style). +//! +//! See [`define_udwf_and_expr!`] for usage examples. +//! +//! [`define_udwf_and_expr!`]: crate::define_udwf_and_expr! + +/// Lazily initializes a user-defined window function exactly once +/// when called concurrently. Repeated calls return a reference to the +/// same instance. +/// +/// # Parameters +/// +/// * `$UDWF`: The struct which defines the [`Signature`](datafusion_expr::Signature) +/// of the user-defined window function. +/// * `$OUT_FN_NAME`: The basename to generate a unique function name like +/// `$OUT_FN_NAME_udwf`. +/// * `$DOC`: Doc comments for UDWF. +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDWF::default()`. 
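Independent of the macro's exact expansion, the lazy-initialization guarantee documented above rests on `std::sync::OnceLock`. A minimal, self-contained sketch of that pattern (`MyUdwf` and the function name are illustrative stand-ins, not the macro's real output):

```rust
use std::sync::{Arc, OnceLock};

/// Stand-in for a type implementing `WindowUDFImpl` (hypothetical).
#[derive(Debug)]
struct MyUdwf;

static INSTANCE: OnceLock<Arc<MyUdwf>> = OnceLock::new();

/// `get_or_init` runs the constructor closure at most once, even when
/// several threads race on the first call; later calls clone the `Arc`.
fn my_udwf() -> Arc<MyUdwf> {
    Arc::clone(INSTANCE.get_or_init(|| Arc::new(MyUdwf)))
}

fn main() {
    let a = my_udwf();
    let b = my_udwf();
    assert!(Arc::ptr_eq(&a, &b)); // both handles share one instance
}
```

Cloning an `Arc` only bumps a reference count, so handing out the singleton through the generated `*_udwf()` accessor stays cheap.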
+/// +/// # Example +/// +/// ``` +/// # use std::any::Any; +/// # use datafusion_common::arrow::datatypes::{DataType, Field}; +/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +/// # +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window::get_or_init_udwf; +/// # +/// /// Defines the `simple_udwf()` user-defined window function. +/// get_or_init_udwf!( +/// SimpleUDWF, +/// simple, +/// "Simple user-defined window function doc comment." +/// ); +/// # +/// # assert_eq!(simple_udwf().name(), "simple_user_defined_window_function"); +/// # +/// # #[derive(Debug)] +/// # struct SimpleUDWF { +/// # signature: Signature, +/// # } +/// # +/// # impl Default for SimpleUDWF { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::any(0, Volatility::Immutable), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for SimpleUDWF { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "simple_user_defined_window_function" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # } +/// # } +/// # +/// ``` +#[macro_export] +macro_rules! get_or_init_udwf { + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $UDWF::default); + }; + + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr, $CTOR:path) => { + paste::paste! { + #[doc = concat!(" Singleton instance of [`", stringify!($OUT_FN_NAME), "`], ensures the user-defined")] + #[doc = concat!(" window function is only created once.")] + #[allow(non_upper_case_globals)] + static []: std::sync::OnceLock> = + std::sync::OnceLock::new(); + + #[doc = concat!(" Returns a [`WindowUDF`](datafusion_expr::WindowUDF) for [`", stringify!($OUT_FN_NAME), "`].")] + #[doc = ""] + #[doc = concat!(" ", $DOC)] + pub fn [<$OUT_FN_NAME _udwf>]() -> std::sync::Arc { + [] + .get_or_init(|| { + std::sync::Arc::new(datafusion_expr::WindowUDF::from($CTOR())) + }) + .clone() + } + } + }; +} + +/// Create a [`WindowFunction`] expression that exposes a fluent API +/// which you can use to build more complex expressions. +/// +/// [`WindowFunction`]: datafusion_expr::Expr::WindowFunction +/// +/// # Parameters +/// +/// * `$UDWF`: The struct which defines the [`Signature`] of the +/// user-defined window function. +/// * `$OUT_FN_NAME`: The basename to generate a unique function name like +/// `$OUT_FN_NAME_udwf`. +/// * `$DOC`: Doc comments for UDWF. +/// * (optional) `[$($PARAM:ident),+]`: An array of 1 or more parameters +/// for the generated function. The type of parameters is [`Expr`]. +/// When omitted this creates a function with zero parameters. +/// +/// [`Signature`]: datafusion_expr::Signature +/// [`Expr`]: datafusion_expr::Expr +/// +/// # Example +/// +/// 1. 
With Zero Parameters +/// ``` +/// # use std::any::Any; +/// # use datafusion_common::arrow::datatypes::{DataType, Field}; +/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +/// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # get_or_init_udwf!( +/// # RowNumber, +/// # row_number, +/// # "Returns a unique row number for each row in window partition beginning at 1." +/// # ); +/// /// Creates `row_number()` API which has zero parameters: +/// /// +/// /// ``` +/// /// /// Returns a unique row number for each row in window partition +/// /// /// beginning at 1. +/// /// pub fn row_number() -> datafusion_expr::Expr { +/// /// row_number_udwf().call(vec![]) +/// /// } +/// /// ``` +/// create_udwf_expr!( +/// RowNumber, +/// row_number, +/// "Returns a unique row number for each row in window partition beginning at 1." +/// ); +/// # +/// # assert_eq!( +/// # row_number().name_for_alias().unwrap(), +/// # "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct RowNumber { +/// # signature: Signature, +/// # } +/// # impl Default for RowNumber { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::any(0, Volatility::Immutable), +/// # } +/// # } +/// # } +/// # impl WindowUDFImpl for RowNumber { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "row_number" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # } +/// # } +/// ``` +/// +/// 2. With Multiple Parameters +/// ``` +/// # use std::any::Any; +/// # +/// # use datafusion_expr::{ +/// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, +/// # }; +/// # +/// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # +/// # use datafusion_common::arrow::datatypes::Field; +/// # use datafusion_common::ScalarValue; +/// # use datafusion_expr::{col, lit}; +/// # +/// # get_or_init_udwf!(Lead, lead, "user-defined window function"); +/// # +/// /// Creates `lead(expr, offset, default)` with 3 parameters: +/// /// +/// /// ``` +/// /// /// Returns a value evaluated at the row that is offset rows +/// /// /// after the current row within the partition. +/// /// pub fn lead( +/// /// expr: datafusion_expr::Expr, +/// /// offset: datafusion_expr::Expr, +/// /// default: datafusion_expr::Expr, +/// /// ) -> datafusion_expr::Expr { +/// /// lead_udwf().call(vec![expr, offset, default]) +/// /// } +/// /// ``` +/// create_udwf_expr!( +/// Lead, +/// lead, +/// [expr, offset, default], +/// "Returns a value evaluated at the row that is offset rows after the current row within the partition." 
+/// ); +/// # +/// # assert_eq!( +/// # lead(col("a"), lit(1i64), lit(ScalarValue::Null)) +/// # .name_for_alias() +/// # .unwrap(), +/// # "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct Lead { +/// # signature: Signature, +/// # } +/// # +/// # impl Default for Lead { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::one_of( +/// # vec![ +/// # TypeSignature::Any(1), +/// # TypeSignature::Any(2), +/// # TypeSignature::Any(3), +/// # ], +/// # Volatility::Immutable, +/// # ), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for Lead { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "lead" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new( +/// # field_args.name(), +/// # field_args.get_input_type(0).unwrap(), +/// # false, +/// # )) +/// # } +/// # } +/// ``` +#[macro_export] +macro_rules! create_udwf_expr { + // zero arguments + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => { + paste::paste! { + #[doc = " Create a [`WindowFunction`](datafusion_expr::Expr::WindowFunction) expression for"] + #[doc = concat!(" [`", stringify!($UDWF), "`] user-defined window function.")] + #[doc = ""] + #[doc = concat!(" ", $DOC)] + pub fn $OUT_FN_NAME() -> datafusion_expr::Expr { + [<$OUT_FN_NAME _udwf>]().call(vec![]) + } + } + }; + + // 1 or more arguments + ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr) => { + paste::paste! { + #[doc = " Create a [`WindowFunction`](datafusion_expr::Expr::WindowFunction) expression for"] + #[doc = concat!(" [`", stringify!($UDWF), "`] user-defined window function.")] + #[doc = ""] + #[doc = concat!(" ", $DOC)] + pub fn $OUT_FN_NAME( + $($PARAM: datafusion_expr::Expr),+ + ) -> datafusion_expr::Expr { + [<$OUT_FN_NAME _udwf>]() + .call(vec![$($PARAM),+]) + } + } + }; +} + +/// Defines a user-defined window function. +/// +/// Combines [`get_or_init_udwf!`] and [`create_udwf_expr!`] into a +/// single macro for convenience. +/// +/// # Arguments +/// +/// * `$UDWF`: The struct which defines the [`Signature`] of the +/// user-defined window function. +/// * `$OUT_FN_NAME`: The basename to generate a unique function name like +/// `$OUT_FN_NAME_udwf`. +/// * (optional) `[$($PARAM:ident),+]`: An array of 1 or more parameters +/// for the generated function. The type of parameters is [`Expr`]. +/// When omitted this creates a function with zero parameters. +/// * `$DOC`: Doc comments for UDWF. +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDWF::default()`. +/// +/// [`Signature`]: datafusion_expr::Signature +/// [`Expr`]: datafusion_expr::Expr +/// +/// # Usage +/// +/// ## Expression API With Zero parameters +/// 1. Uses default constructor for UDWF. +/// +/// ``` +/// # use std::any::Any; +/// # use datafusion_common::arrow::datatypes::{DataType, Field}; +/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +/// # +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window::{define_udwf_and_expr, get_or_init_udwf, create_udwf_expr}; +/// # +/// /// 1. 
Defines the `simple_udwf()` user-defined window function. +/// /// +/// /// 2. Defines the expression API: +/// /// ``` +/// /// pub fn simple() -> datafusion_expr::Expr { +/// /// simple_udwf().call(vec![]) +/// /// } +/// /// ``` +/// define_udwf_and_expr!( +/// SimpleUDWF, +/// simple, +/// "a simple user-defined window function" +/// ); +/// # +/// # assert_eq!(simple_udwf().name(), "simple_user_defined_window_function"); +/// # +/// # #[derive(Debug)] +/// # struct SimpleUDWF { +/// # signature: Signature, +/// # } +/// # +/// # impl Default for SimpleUDWF { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::any(0, Volatility::Immutable), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for SimpleUDWF { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "simple_user_defined_window_function" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # } +/// # } +/// # +/// ``` +/// +/// 2. Uses a custom constructor for UDWF. +/// +/// ``` +/// # use std::any::Any; +/// # use datafusion_common::arrow::datatypes::{DataType, Field}; +/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +/// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # +/// /// 1. Defines the `row_number_udwf()` user-defined window function. +/// /// +/// /// 2. Defines the expression API: +/// /// ``` +/// /// pub fn row_number() -> datafusion_expr::Expr { +/// /// row_number_udwf().call(vec![]) +/// /// } +/// /// ``` +/// define_udwf_and_expr!( +/// RowNumber, +/// row_number, +/// "Returns a unique row number for each row in window partition beginning at 1.", +/// RowNumber::new // <-- custom constructor +/// ); +/// # +/// # assert_eq!( +/// # row_number().name_for_alias().unwrap(), +/// # "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct RowNumber { +/// # signature: Signature, +/// # } +/// # impl RowNumber { +/// # fn new() -> Self { +/// # Self { +/// # signature: Signature::any(0, Volatility::Immutable), +/// # } +/// # } +/// # } +/// # impl WindowUDFImpl for RowNumber { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "row_number" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # } +/// # } +/// ``` +/// +/// ## Expression API With Multiple Parameters +/// 3. 
Uses default constructor for UDWF +/// +/// ``` +/// # use std::any::Any; +/// # +/// # use datafusion_expr::{ +/// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, +/// # }; +/// # +/// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # +/// # use datafusion_common::arrow::datatypes::Field; +/// # use datafusion_common::ScalarValue; +/// # use datafusion_expr::{col, lit}; +/// # +/// /// 1. Defines the `lead_udwf()` user-defined window function. +/// /// +/// /// 2. Defines the expression API: +/// /// ``` +/// /// pub fn lead( +/// /// expr: datafusion_expr::Expr, +/// /// offset: datafusion_expr::Expr, +/// /// default: datafusion_expr::Expr, +/// /// ) -> datafusion_expr::Expr { +/// /// lead_udwf().call(vec![expr, offset, default]) +/// /// } +/// /// ``` +/// define_udwf_and_expr!( +/// Lead, +/// lead, +/// [expr, offset, default], // <- 3 parameters +/// "user-defined window function" +/// ); +/// # +/// # assert_eq!( +/// # lead(col("a"), lit(1i64), lit(ScalarValue::Null)) +/// # .name_for_alias() +/// # .unwrap(), +/// # "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct Lead { +/// # signature: Signature, +/// # } +/// # +/// # impl Default for Lead { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::one_of( +/// # vec![ +/// # TypeSignature::Any(1), +/// # TypeSignature::Any(2), +/// # TypeSignature::Any(3), +/// # ], +/// # Volatility::Immutable, +/// # ), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for Lead { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "lead" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new( +/// # field_args.name(), +/// # field_args.get_input_type(0).unwrap(), +/// # false, +/// # )) +/// # } +/// # } +/// ``` +/// 4. Uses custom constructor for UDWF +/// +/// ``` +/// # use std::any::Any; +/// # +/// # use datafusion_expr::{ +/// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, +/// # }; +/// # +/// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # +/// # use datafusion_common::arrow::datatypes::Field; +/// # use datafusion_common::ScalarValue; +/// # use datafusion_expr::{col, lit}; +/// # +/// /// 1. Defines the `lead_udwf()` user-defined window function. +/// /// +/// /// 2. 
Defines the expression API: +/// /// ``` +/// /// pub fn lead( +/// /// expr: datafusion_expr::Expr, +/// /// offset: datafusion_expr::Expr, +/// /// default: datafusion_expr::Expr, +/// /// ) -> datafusion_expr::Expr { +/// /// lead_udwf().call(vec![expr, offset, default]) +/// /// } +/// /// ``` +/// define_udwf_and_expr!( +/// Lead, +/// lead, +/// [expr, offset, default], // <- 3 parameters +/// "user-defined window function", +/// Lead::new // <- Custom constructor +/// ); +/// # +/// # assert_eq!( +/// # lead(col("a"), lit(1i64), lit(ScalarValue::Null)) +/// # .name_for_alias() +/// # .unwrap(), +/// # "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct Lead { +/// # signature: Signature, +/// # } +/// # +/// # impl Lead { +/// # fn new() -> Self { +/// # Self { +/// # signature: Signature::one_of( +/// # vec![ +/// # TypeSignature::Any(1), +/// # TypeSignature::Any(2), +/// # TypeSignature::Any(3), +/// # ], +/// # Volatility::Immutable, +/// # ), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for Lead { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "lead" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new( +/// # field_args.name(), +/// # field_args.get_input_type(0).unwrap(), +/// # false, +/// # )) +/// # } +/// # } +/// ``` +#[macro_export] +macro_rules! define_udwf_and_expr { + // Defines UDWF with default constructor + // Defines expression API with zero parameters + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC); + create_udwf_expr!($UDWF, $OUT_FN_NAME, $DOC); + }; + + // Defines UDWF by passing a custom constructor + // Defines expression API with zero parameters + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr, $CTOR:path) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $CTOR); + create_udwf_expr!($UDWF, $OUT_FN_NAME, $DOC); + }; + + // Defines UDWF with default constructor + // Defines expression API with multiple parameters + ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC); + create_udwf_expr!($UDWF, $OUT_FN_NAME, [$($PARAM),+], $DOC); + }; + + // Defines UDWF by passing a custom constructor + // Defines expression API with multiple parameters + ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr, $CTOR:path) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $CTOR); + create_udwf_expr!($UDWF, $OUT_FN_NAME, [$($PARAM),+], $DOC); + }; +} diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index 7f348bf9d2a0..a2e1b2222bb7 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -27,31 +27,15 @@ use datafusion_common::arrow::compute::SortOptions; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::expr::WindowFunction; -use datafusion_expr::{Expr, PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; use 
datafusion_functions_window_common::field; use field::WindowUDFFieldArgs; -/// Create a [`WindowFunction`](Expr::WindowFunction) expression for -/// `row_number` user-defined window function. -pub fn row_number() -> Expr { - Expr::WindowFunction(WindowFunction::new(row_number_udwf(), vec![])) -} - -/// Singleton instance of `row_number`, ensures the UDWF is only created once. -#[allow(non_upper_case_globals)] -static STATIC_RowNumber: std::sync::OnceLock> = - std::sync::OnceLock::new(); - -/// Returns a [`WindowUDF`](datafusion_expr::WindowUDF) for `row_number` -/// user-defined window function. -pub fn row_number_udwf() -> std::sync::Arc { - STATIC_RowNumber - .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::WindowUDF::from(RowNumber::default())) - }) - .clone() -} +define_udwf_and_expr!( + RowNumber, + row_number, + "Returns a unique row number for each row in window partition beginning at 1." +); /// row_number expression #[derive(Debug)] diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index ff1b926a9b82..a3d114221d3f 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -102,6 +102,11 @@ harness = false name = "to_timestamp" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "encoding" +required-features = ["encoding_expressions"] + [[bench]] harness = false name = "regx" diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs new file mode 100644 index 000000000000..d49235aac938 --- /dev/null +++ b/datafusion/functions/benches/encoding.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
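The `row_number.rs` refactor above shows the macros' payoff: roughly twenty lines of hand-written singleton boilerplate collapse into one `define_udwf_and_expr!` invocation while callers keep the same API. A hedged sketch of downstream usage (the import path is inferred from this diff's crate layout):

```rust
use datafusion_expr::Expr;
// Module path assumed from `pub mod row_number;` in this PR.
use datafusion_functions_window::row_number::row_number_udwf;

/// Builds the same expression the removed hand-written helper produced:
/// `WindowUDF::call` wraps the UDWF and its arguments (none here) in an
/// `Expr::WindowFunction`.
fn row_number_expr() -> Expr {
    row_number_udwf().call(vec![])
}
```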
+ +extern crate criterion; + +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::encoding; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let decode = encoding::decode(); + for size in [1024, 4096, 8192] { + let str_array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); + c.bench_function(&format!("base64_decode/{size}"), |b| { + let method = ColumnarValue::Scalar("base64".into()); + let encoded = encoding::encode() + .invoke(&[ColumnarValue::Array(str_array.clone()), method.clone()]) + .unwrap(); + + let args = vec![encoded, method]; + b.iter(|| black_box(decode.invoke(&args).unwrap())) + }); + + c.bench_function(&format!("hex_decode/{size}"), |b| { + let method = ColumnarValue::Scalar("hex".into()); + let encoded = encoding::encode() + .invoke(&[ColumnarValue::Array(str_array.clone()), method.clone()]) + .unwrap(); + + let args = vec![encoded, method]; + b.iter(|| black_box(decode.invoke(&args).unwrap())) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 5b80c908cfc3..2a22e572614b 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -18,9 +18,12 @@ //! Encoding expressions use arrow::{ - array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait, StringArray}, - datatypes::DataType, + array::{ + Array, ArrayRef, BinaryArray, GenericByteArray, OffsetSizeTrait, StringArray, + }, + datatypes::{ByteArrayType, DataType}, }; +use arrow_buffer::{Buffer, OffsetBufferBuilder}; use base64::{engine::general_purpose, Engine as _}; use datafusion_common::{ cast::{as_generic_binary_array, as_generic_string_array}, @@ -245,16 +248,22 @@ fn base64_encode(input: &[u8]) -> String { general_purpose::STANDARD_NO_PAD.encode(input) } -fn hex_decode(input: &[u8]) -> Result> { - hex::decode(input).map_err(|e| { +fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { + // only write input / 2 bytes to buf + let out_len = input.len() / 2; + let buf = &mut buf[..out_len]; + hex::decode_to_slice(input, buf).map_err(|e| { DataFusionError::Internal(format!("Failed to decode from hex: {}", e)) - }) + })?; + Ok(out_len) } -fn base64_decode(input: &[u8]) -> Result> { - general_purpose::STANDARD_NO_PAD.decode(input).map_err(|e| { - DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) - }) +fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { + general_purpose::STANDARD_NO_PAD + .decode_slice(input, buf) + .map_err(|e| { + DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) + }) } macro_rules! encode_to_array { @@ -267,14 +276,35 @@ macro_rules! encode_to_array { }}; } -macro_rules! 
decode_to_array { - ($METHOD: ident, $INPUT:expr) => {{ - let binary_array: BinaryArray = $INPUT - .iter() - .map(|x| x.map(|x| $METHOD(x.as_ref())).transpose()) - .collect::>()?; - Arc::new(binary_array) - }}; +fn decode_to_array( + method: F, + input: &GenericByteArray, + conservative_upper_bound_size: usize, +) -> Result +where + F: Fn(&[u8], &mut [u8]) -> Result, +{ + let mut values = vec![0; conservative_upper_bound_size]; + let mut offsets = OffsetBufferBuilder::new(input.len()); + let mut total_bytes_decoded = 0; + for v in input { + if let Some(v) = v { + let cursor = &mut values[total_bytes_decoded..]; + let decoded = method(v.as_ref(), cursor)?; + total_bytes_decoded += decoded; + offsets.push_length(decoded); + } else { + offsets.push_length(0); + } + } + // We reserved an upper bound size for the values buffer, but we only use the actual size + values.truncate(total_bytes_decoded); + let binary_array = BinaryArray::try_new( + offsets.finish(), + Buffer::from_vec(values), + input.nulls().cloned(), + )?; + Ok(Arc::new(binary_array)) } impl Encoding { @@ -381,10 +411,7 @@ impl Encoding { T: OffsetSizeTrait, { let input_value = as_generic_binary_array::(value)?; - let array: ArrayRef = match self { - Self::Base64 => decode_to_array!(base64_decode, input_value), - Self::Hex => decode_to_array!(hex_decode, input_value), - }; + let array = self.decode_byte_array(input_value)?; Ok(ColumnarValue::Array(array)) } @@ -393,12 +420,29 @@ impl Encoding { T: OffsetSizeTrait, { let input_value = as_generic_string_array::(value)?; - let array: ArrayRef = match self { - Self::Base64 => decode_to_array!(base64_decode, input_value), - Self::Hex => decode_to_array!(hex_decode, input_value), - }; + let array = self.decode_byte_array(input_value)?; Ok(ColumnarValue::Array(array)) } + + fn decode_byte_array( + &self, + input_value: &GenericByteArray, + ) -> Result { + match self { + Self::Base64 => { + let upper_bound = + base64::decoded_len_estimate(input_value.values().len()); + decode_to_array(base64_decode, input_value, upper_bound) + } + Self::Hex => { + // Calculate the upper bound for decoded byte size + // For hex encoding, each pair of hex characters (2 bytes) represents 1 byte when decoded + // So the upper bound is half the length of the input values. + let upper_bound = input_value.values().len() / 2; + decode_to_array(hex_decode, input_value, upper_bound) + } + } + } } impl fmt::Display for Encoding { diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 20029ba005c4..8cd26a824acc 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -48,9 +48,9 @@ impl RegexpLikeFunc { signature: Signature::one_of( vec![ Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), ], Volatility::Immutable, ), diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index bf40eff11d30..498b591620ee 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -54,9 +54,9 @@ impl RegexpMatchFunc { // If that fails, it proceeds to `(LargeUtf8, Utf8)`. // TODO: Native support Utf8View for regexp_match. 
Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), ], Volatility::Immutable, ), @@ -131,7 +131,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { let flags = as_generic_string_array::(&args[2])?; if flags.iter().any(|s| s == Some("g")) { - return plan_err!("regexp_match() does not support the \"global\" option") + return plan_err!("regexp_match() does not support the \"global\" option"); } regexp::regexp_match(values, regex, Some(flags)) diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 86520b3587cd..b3b24724552a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -48,13 +48,7 @@ impl AnalyzerRule for CountWildcardRule { } fn is_wildcard(expr: &Expr) -> bool { - matches!( - expr, - Expr::Wildcard { - qualifier: None, - .. - } - ) + matches!(expr, Expr::Wildcard { .. }) } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index c771f31a58b2..aabc549de583 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -385,6 +385,10 @@ mod test { empty_schema: Arc::clone(&self.empty_schema), }) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[test] diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 5ab427a31699..b5d581f3919f 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -895,6 +895,10 @@ mod tests { // Since schema is same. Output columns requires their corresponding version in the input columns. 
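The `supports_limit_pushdown` overrides threaded through these test fixtures all follow the same opt-in contract introduced in this PR: the trait method defaults to `false`, so a `Limit` is only pushed below an extension node that explicitly allows it. A self-contained sketch of that contract with a simplified stand-in trait (not DataFusion's actual `UserDefinedLogicalNodeCore`):

```rust
/// Simplified stand-in for the limit push-down hook on
/// `UserDefinedLogicalNodeCore` (illustrative only).
trait LimitPushdown {
    fn supports_limit_pushdown(&self) -> bool {
        false // conservative default: never push a Limit below the node
    }
}

/// A pass-through node: emitting only the first N rows is safe.
struct PassThroughNode;
/// A node that must observe every input row (e.g. an aggregation).
struct AggregatingNode;

impl LimitPushdown for PassThroughNode {
    fn supports_limit_pushdown(&self) -> bool {
        true // explicitly opt in
    }
}
impl LimitPushdown for AggregatingNode {} // keeps the `false` default

fn main() {
    assert!(PassThroughNode.supports_limit_pushdown());
    assert!(!AggregatingNode.supports_limit_pushdown());
}
```

Defaulting to `false` keeps existing extension nodes correct: a node that must see all of its input silently retains the safe behavior.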
Some(vec![output_columns.to_vec()]) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[derive(Debug, Hash, PartialEq, Eq)] @@ -991,6 +995,10 @@ mod tests { } Some(vec![left_reqs, right_reqs]) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[test] diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4e36cc62588e..6e2cc0cbdbcb 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1499,6 +1499,10 @@ mod tests { schema: Arc::clone(&self.schema), }) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[test] diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 158c7592df51..8b5e483001b3 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -153,6 +153,29 @@ impl OptimizerRule for PushDownLimit { subquery_alias.input = Arc::new(new_limit); Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias))) } + LogicalPlan::Extension(extension_plan) + if extension_plan.node.supports_limit_pushdown() => + { + let new_children = extension_plan + .node + .inputs() + .into_iter() + .map(|child| { + LogicalPlan::Limit(Limit { + skip: 0, + fetch: Some(fetch + skip), + input: Arc::new(child.clone()), + }) + }) + .collect::>(); + + // Create a new extension node with updated inputs + let child_plan = LogicalPlan::Extension(extension_plan); + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), new_children)?; + + transformed_limit(skip, fetch, new_extension) + } input => original_limit(skip, fetch, input), } } @@ -258,17 +281,241 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed { #[cfg(test)] mod test { + use std::cmp::Ordering; + use std::fmt::{Debug, Formatter}; use std::vec; use super::*; use crate::test::*; - use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder}; + + use datafusion_common::DFSchemaRef; + use datafusion_expr::{ + col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension, + UserDefinedLogicalNodeCore, + }; use datafusion_functions_aggregate::expr_fn::max; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) } + #[derive(Debug, PartialEq, Eq, Hash)] + pub struct NoopPlan { + input: Vec, + schema: DFSchemaRef, + } + + // Manual implementation needed because of `schema` field. Comparison excludes this field. 
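The `LogicalPlan::Extension` arm added above wraps every child in `Limit { skip: 0, fetch: Some(fetch + skip) }`: the parent limit still applies its own skip afterwards, so each child must also produce the rows that will be skipped. A tiny self-contained sketch of that arithmetic, matching the expected plans in the tests that follow:

```rust
/// Fetch value pushed into each child when `LIMIT fetch OFFSET skip`
/// sits above a push-down-friendly node (mirrors the rule above).
fn child_fetch(skip: usize, fetch: usize) -> usize {
    skip + fetch
}

fn main() {
    // `LIMIT 1000 OFFSET 10` pushes `fetch=1010` into the child scan.
    assert_eq!(child_fetch(10, 1000), 1010);
    // Stacked limits first combine to `skip=30, fetch=500`, then push 530.
    assert_eq!(child_fetch(30, 500), 530);
}
```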
+ impl PartialOrd for NoopPlan { + fn partial_cmp(&self, other: &Self) -> Option { + self.input.partial_cmp(&other.input) + } + } + + impl UserDefinedLogicalNodeCore for NoopPlan { + fn name(&self) -> &str { + "NoopPlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + self.input.iter().collect() + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.input + .iter() + .flat_map(|child| child.expressions()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "NoopPlan") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + Ok(Self { + input: inputs, + schema: Arc::clone(&self.schema), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + true // Allow limit push-down + } + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct NoLimitNoopPlan { + input: Vec, + schema: DFSchemaRef, + } + + // Manual implementation needed because of `schema` field. Comparison excludes this field. + impl PartialOrd for NoLimitNoopPlan { + fn partial_cmp(&self, other: &Self) -> Option { + self.input.partial_cmp(&other.input) + } + } + + impl UserDefinedLogicalNodeCore for NoLimitNoopPlan { + fn name(&self) -> &str { + "NoLimitNoopPlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + self.input.iter().collect() + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.input + .iter() + .flat_map(|child| child.expressions()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "NoLimitNoopPlan") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + Ok(Self { + input: inputs, + schema: Arc::clone(&self.schema), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } + } + #[test] + fn limit_pushdown_basic() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n NoopPlan\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_with_skip() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(10, Some(1000))? + .build()?; + + let expected = "Limit: skip=10, fetch=1000\ + \n NoopPlan\ + \n Limit: skip=0, fetch=1010\ + \n TableScan: test, fetch=1010"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_multiple_limits() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(10, Some(1000))? + .limit(20, Some(500))? 
+ .build()?; + + let expected = "Limit: skip=30, fetch=500\ + \n NoopPlan\ + \n Limit: skip=0, fetch=530\ + \n TableScan: test, fetch=530"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_multiple_inputs() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone(), table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n NoopPlan\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_disallowed_noop_plan() -> Result<()> { + let table_scan = test_table_scan()?; + let no_limit_noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoLimitNoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(no_limit_noop_plan) + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n NoLimitNoopPlan\ + \n TableScan: test"; + + assert_optimized_plan_equal(plan, expected) + } + #[test] fn limit_pushdown_projection_table_provider() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/optimizer/src/test/user_defined.rs b/datafusion/optimizer/src/test/user_defined.rs index 814cd0c0cd0a..a39f90b5da5d 100644 --- a/datafusion/optimizer/src/test/user_defined.rs +++ b/datafusion/optimizer/src/test/user_defined.rs @@ -76,4 +76,8 @@ impl UserDefinedLogicalNodeCore for TestUserDefinedPlanNode { input: inputs.swap_remove(0), }) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 470bd947c7fb..236167985790 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -345,7 +345,7 @@ fn select_wildcard_with_repeated_column() { let sql = "SELECT *, col_int32 FROM test"; let err = test_sql(sql).expect_err("query should have failed"); assert_eq!( - "expand_wildcard_rule\ncaused by\nError during planning: Projections require unique expression names but the expression \"test.col_int32\" at position 0 and \"test.col_int32\" at position 7 have the same name. 
Consider aliasing (\"AS\") one of them.", + "Schema error: Schema contains duplicate qualified field name test.col_int32", err.strip_backtrace() ); } @@ -396,7 +396,7 @@ fn test_sql(sql: &str) -> Result { .with_udaf(count_udaf()) .with_udaf(avg_udaf()); let sql_to_rel = SqlToRel::new(&context_provider); - let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); + let plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; let config = OptimizerContext::new().with_skip_failing_rules(false); let analyzer = Analyzer::new(); diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 65423033d5e0..bb3e9218bc41 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -18,6 +18,7 @@ use std::fmt::Display; use std::hash::Hash; use std::sync::Arc; +use std::vec::IntoIter; use crate::equivalence::add_offset_to_expr; use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; @@ -36,7 +37,7 @@ use arrow_schema::SortOptions; /// /// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table /// ordering. In this case, we say that these orderings are equivalent. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] pub struct OrderingEquivalenceClass { pub orderings: Vec, } @@ -44,7 +45,7 @@ pub struct OrderingEquivalenceClass { impl OrderingEquivalenceClass { /// Creates new empty ordering equivalence class. pub fn empty() -> Self { - Self { orderings: vec![] } + Default::default() } /// Clears (empties) this ordering equivalence class. @@ -197,6 +198,15 @@ impl OrderingEquivalenceClass { } } +impl IntoIterator for OrderingEquivalenceClass { + type Item = LexOrdering; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.orderings.into_iter() + } +} + /// This function constructs a duplicate-free `LexOrdering` by filtering out /// duplicate entries that have same physical expression inside. For example, /// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. 
@@ -229,10 +239,10 @@ impl Display for OrderingEquivalenceClass { write!(f, "[")?; let mut iter = self.orderings.iter(); if let Some(ordering) = iter.next() { - write!(f, "{}", PhysicalSortExpr::format_list(ordering))?; + write!(f, "[{}]", PhysicalSortExpr::format_list(ordering))?; } for ordering in iter { - write!(f, "{}", PhysicalSortExpr::format_list(ordering))?; + write!(f, ", [{}]", PhysicalSortExpr::format_list(ordering))?; } write!(f, "]")?; Ok(()) diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index dc59a1eb835b..8137b4f9da13 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -118,7 +118,7 @@ use itertools::Itertools; /// PhysicalSortExpr::new_default(col_c).desc(), /// ]); /// -/// assert_eq!(eq_properties.to_string(), "order: [a@0 ASC,c@2 DESC], const: [b@1]") +/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC,c@2 DESC]], const: [b@1]") /// ``` #[derive(Debug, Clone)] pub struct EquivalenceProperties { @@ -2708,379 +2708,428 @@ mod tests { )) } - #[tokio::test] - async fn test_union_equivalence_properties_multi_children() -> Result<()> { - let schema = create_test_schema()?; + #[test] + fn test_union_equivalence_properties_multi_children_1() { + let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); - let test_cases = vec![ - // --------- TEST CASE 1 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b", "c"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1", "c1"]], - Arc::clone(&schema2), - ), - // Children 3 - ( - // Orderings - vec![vec!["a2", "b2"]], - Arc::clone(&schema3), - ), - ], - // Expected - vec![vec!["a", "b"]], - ), - // --------- TEST CASE 2 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b", "c"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1", "c1"]], - Arc::clone(&schema2), - ), - // Children 3 - ( - // Orderings - vec![vec!["a2", "b2", "c2"]], - Arc::clone(&schema3), - ), - ], - // Expected - vec![vec!["a", "b", "c"]], - ), - // --------- TEST CASE 3 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1", "c1"]], - Arc::clone(&schema2), - ), - // Children 3 - ( - // Orderings - vec![vec!["a2", "b2", "c2"]], - Arc::clone(&schema3), - ), - ], - // Expected - vec![vec!["a", "b"]], - ), - // --------- TEST CASE 4 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1"]], - Arc::clone(&schema2), - ), - // Children 3 - ( - // Orderings - vec![vec!["b2", "c2"]], - Arc::clone(&schema3), - ), - ], - // Expected - vec![], - ), - // --------- TEST CASE 5 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b"], vec!["c"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1"], vec!["c1"]], - Arc::clone(&schema2), - ), - ], - // Expected - vec![vec!["a", "b"], vec!["c"]], - ), - ]; - for (children, expected) in test_cases { - let children_eqs = children - .iter() - .map(|(orderings, schema)| { - let orderings = orderings - .iter() - .map(|ordering| { - ordering - .iter() - .map(|name| PhysicalSortExpr { - expr: col(name, 
schema).unwrap(), - options: SortOptions::default(), - }) - .collect::>() - }) - .collect::>(); - EquivalenceProperties::new_with_orderings( - Arc::clone(schema), - &orderings, - ) - }) - .collect::>(); - let actual = calculate_union(children_eqs, Arc::clone(&schema))?; + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["a2", "b2"]], &schema3) + .with_expected_sort(vec![vec!["a", "b"]]) + .run() + } - let expected_ordering = expected - .into_iter() - .map(|ordering| { - ordering - .into_iter() - .map(|name| PhysicalSortExpr { - expr: col(name, &schema).unwrap(), - options: SortOptions::default(), - }) - .collect::>() - }) - .collect::>(); - let expected = EquivalenceProperties::new_with_orderings( - Arc::clone(&schema), - &expected_ordering, - ); - assert_eq_properties_same( - &actual, - &expected, - format!("expected: {:?}, actual: {:?}", expected, actual), - ); - } - Ok(()) + #[test] + fn test_union_equivalence_properties_multi_children_2() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) + .with_expected_sort(vec![vec!["a", "b", "c"]]) + .run() } - #[tokio::test] - async fn test_union_equivalence_properties_binary() -> Result<()> { - let schema = create_test_schema()?; + #[test] + fn test_union_equivalence_properties_multi_children_3() { + let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_a1 = &col("a1", &schema2)?; - let col_b1 = &col("b1", &schema2)?; - let options = SortOptions::default(); - let options_desc = !SortOptions::default(); - let test_cases = [ - //-----------TEST CASE 1----------// - ( - ( - // First child orderings - vec![ - // [a ASC] - (vec![(col_a, options)]), - ], - // First child constants - vec![col_b, col_c], - Arc::clone(&schema), - ), - ( - // Second child orderings - vec![ - // [b ASC] - (vec![(col_b, options)]), - ], - // Second child constants - vec![col_a, col_c], - Arc::clone(&schema), - ), - ( - // Union expected orderings - vec![ - // [a ASC] - vec![(col_a, options)], - // [b ASC] - vec![(col_b, options)], - ], - // Union - vec![col_c], - ), - ), - //-----------TEST CASE 2----------// + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) + .with_expected_sort(vec![vec!["a", "b"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_4() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["b2", "c2"]], &schema3) + 
.with_expected_sort(vec![]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_5() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2) + .with_expected_sort(vec![vec!["a", "b"], vec!["c"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_1() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC], const [b, c] + vec![vec!["a"]], + vec!["b", "c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [b ASC], const [a, c] + vec![vec!["b"]], + vec!["a", "c"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union expected orderings: [[a ASC], [b ASC]], const [c] + vec![vec!["a"], vec!["b"]], + vec!["c"], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_2() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) // Meet ordering between [a ASC], [a ASC, b ASC] should be [a ASC] - ( - ( - // First child orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Second child orderings - vec![ - // [a ASC, b ASC] - vec![(col_a, options), (col_b, options)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Union orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - ), - ), - //-----------TEST CASE 3----------// + .with_child_sort_and_const_exprs( + // First child: [a ASC], const [] + vec![vec!["a"]], + vec![], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a ASC, b ASC], const [] + vec![vec!["a", "b"]], + vec![], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [a ASC], const [] + vec![vec!["a"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_3() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) // Meet ordering between [a ASC], [a DESC] should be [] - ( - ( - // First child orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Second child orderings - vec![ - // [a DESC] - vec![(col_a, options_desc)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Union doesn't have any ordering - vec![], - // No constant - vec![], - ), - ), - //-----------TEST CASE 4----------// - // Meet ordering between [a ASC], [a1 ASC, b1 ASC] should be [a ASC] - // Where a, and a1 ath the same index for their corresponding schemas. 
- ( - ( - // First child orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Second child orderings - vec![ - // [a1 ASC, b1 ASC] - vec![(col_a1, options), (col_b1, options)], - ], - // No constant - vec![], - Arc::clone(&schema2), - ), - ( - // Union orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - ), - ), - ]; + .with_child_sort_and_const_exprs( + // First child: [a ASC], const [] + vec![vec!["a"]], + vec![], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [a DESC], const [] + vec![vec!["a DESC"]], + vec![], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union doesn't have any ordering or constant + vec![], + vec![], + ) + .run() + } - for ( - test_idx, - ( - (first_child_orderings, first_child_constants, first_schema), - (second_child_orderings, second_child_constants, second_schema), - (union_orderings, union_constants), - ), - ) in test_cases.iter().enumerate() - { - let first_orderings = first_child_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::<Vec<_>>(); - let first_constants = first_child_constants - .iter() - .map(|expr| ConstExpr::new(Arc::clone(expr))) - .collect::<Vec<_>>(); - let mut lhs = EquivalenceProperties::new(Arc::clone(first_schema)); - lhs = lhs.with_constants(first_constants); - lhs.add_new_orderings(first_orderings); + #[test] + fn test_union_equivalence_properties_constants_4() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [a ASC], const [] + vec![vec!["a"]], + vec![], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [a1 ASC, b1 ASC], const [] + vec![vec!["a1", "b1"]], + vec![], + &schema2, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // should be [a ASC] + // + // Where a and a1 are at the same index for their corresponding + // schemas.
+ vec![vec!["a"]], + vec![], + ) + .run() + } - let second_orderings = second_child_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); - let second_constants = second_child_constants - .iter() - .map(|expr| ConstExpr::new(Arc::clone(expr))) - .collect::>(); - let mut rhs = EquivalenceProperties::new(Arc::clone(second_schema)); - rhs = rhs.with_constants(second_constants); - rhs.add_new_orderings(second_orderings); + #[test] + #[ignore] + // ignored due to https://github.com/apache/datafusion/issues/12446 + fn test_union_equivalence_properties_constants() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [a ASC, c ASC], const [b] + vec![vec!["a", "c"]], + vec!["b"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [b ASC, c ASC], const [a] + vec![vec!["b", "c"]], + vec!["a"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [ + // [a ASC, b ASC, c ASC], + // [b ASC, a ASC, c ASC] + // ], const [] + vec![vec!["a", "b", "c"], vec!["b", "a", "c"]], + vec![], + ) + .run() + } + + #[test] + #[ignore] + // ignored due to https://github.com/apache/datafusion/issues/12446 + fn test_union_equivalence_properties_constants_desc() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // NB `b DESC` in the second child + // First child orderings: [a ASC, c ASC], const [b] + vec![vec!["a", "c"]], + vec!["b"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [b ASC, c ASC], const [a] + vec![vec!["b DESC", "c"]], + vec!["a"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [ + // [a ASC, b ASC, c ASC], + // [b ASC, a ASC, c ASC] + // ], const [] + vec![vec!["a", "b DESC", "c"], vec!["b DESC", "a", "c"]], + vec![], + ) + .run() + } + + #[test] + #[ignore] + // ignored due to https://github.com/apache/datafusion/issues/12446 + fn test_union_equivalence_properties_constants_middle() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC, b ASC, d ASC], const [c] + vec![vec!["a", "b", "d"]], + vec!["c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a ASC, c ASC, d ASC], const [b] + vec![vec!["a", "c", "d"]], + vec!["b"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // [a, b, d] (c constant) + // [a, c, d] (b constant) + vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]], + vec![], + ) + .run() + } + + #[test] + #[ignore] + // ignored due to https://github.com/apache/datafusion/issues/12446 + fn test_union_equivalence_properties_constants_middle_desc() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // NB `b DESC` in the first child + // + // First child: [a ASC, b DESC, d ASC], const [c] + vec![vec!["a", "b DESC", "d"]], + vec!["c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a ASC, c ASC, d ASC], const [b] + vec![vec!["a", "c", "d"]], + vec!["b"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // [a, b, d] (c constant) + // [a, c, d] (b constant) + vec![vec!["a", "c", "b DESC", "d"], vec!["a", "b DESC", "c", "d"]], + vec![], + ) + .run() + } + + // TODO tests with multiple constants - let 
union_expected_orderings = union_orderings + #[derive(Debug)] + struct UnionEquivalenceTest { + /// The schema of the output of the Union + output_schema: SchemaRef, + /// The equivalence properties of each child to the union + child_properties: Vec<EquivalenceProperties>, + /// The expected output properties of the union. Must be set before + /// calling `run` + expected_properties: Option<EquivalenceProperties>, + } + + impl UnionEquivalenceTest { + fn new(output_schema: &SchemaRef) -> Self { + Self { + output_schema: Arc::clone(output_schema), + child_properties: vec![], + expected_properties: None, + } + } + + /// Add a union input with the specified orderings + /// + /// See [`Self::make_props`] for the format of the strings in `orderings` + fn with_child_sort( + mut self, + orderings: Vec<Vec<&str>>, + schema: &SchemaRef, + ) -> Self { + let properties = self.make_props(orderings, vec![], schema); + self.child_properties.push(properties); + self + } + + /// Add a union input with the specified orderings and constant + /// equivalences + /// + /// See [`Self::make_props`] for the format of the strings in + /// `orderings` and `constants` + fn with_child_sort_and_const_exprs( + mut self, + orderings: Vec<Vec<&str>>, + constants: Vec<&str>, + schema: &SchemaRef, + ) -> Self { + let properties = self.make_props(orderings, constants, schema); + self.child_properties.push(properties); + self + } + + /// Set the expected output sort order for the union of the children + /// + /// See [`Self::make_props`] for the format of the strings in `orderings` + fn with_expected_sort(mut self, orderings: Vec<Vec<&str>>) -> Self { + let properties = self.make_props(orderings, vec![], &self.output_schema); + self.expected_properties = Some(properties); + self + } + + /// Set the expected output sort order and constant expressions for the + /// union of the children + /// + /// See [`Self::make_props`] for the format of the strings in + /// `orderings` and `constants`. + fn with_expected_sort_and_const_exprs( + mut self, + orderings: Vec<Vec<&str>>, + constants: Vec<&str>, + ) -> Self { + let properties = self.make_props(orderings, constants, &self.output_schema); + self.expected_properties = Some(properties); + self + } + + /// compute the union's output equivalence properties from the child + /// properties, and compare them to the expected properties + fn run(self) { + let Self { + output_schema, + child_properties, + expected_properties, + } = self; + let expected_properties = + expected_properties.expect("expected_properties not set"); + let actual_properties = + calculate_union(child_properties, Arc::clone(&output_schema)) + .expect("failed to calculate union equivalence properties"); + assert_eq_properties_same( + &actual_properties, + &expected_properties, + format!( + "expected: {expected_properties:?}\nactual: {actual_properties:?}" + ), + ); + } + + /// Make equivalence properties for the specified columns named in orderings and constants + /// + /// orderings: strings formatted like `"a"` or `"a DESC"`. See [`parse_sort_expr`] + /// constants: strings formatted like `"a"`.
+ fn make_props( + &self, + orderings: Vec<Vec<&str>>, + constants: Vec<&str>, + schema: &SchemaRef, + ) -> EquivalenceProperties { + let orderings = orderings .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) + .map(|ordering| { + ordering + .iter() + .map(|name| parse_sort_expr(name, schema)) + .collect::<Vec<_>>() + }) .collect::<Vec<_>>(); + + let constants = constants .iter() - .map(|expr| ConstExpr::new(Arc::clone(expr))) + .map(|col_name| ConstExpr::new(col(col_name, schema).unwrap())) .collect::<Vec<_>>(); - let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); - union_expected_eq = union_expected_eq.with_constants(union_constants); - union_expected_eq.add_new_orderings(union_expected_orderings); - let actual_union_eq = calculate_union_binary(lhs, rhs)?; - let err_msg = format!( - "Error in test id: {:?}, test case: {:?}", - test_idx, test_cases[test_idx] - ); - assert_eq_properties_same(&actual_union_eq, &union_expected_eq, err_msg); + EquivalenceProperties::new_with_orderings(Arc::clone(schema), &orderings) + .with_constants(constants) } - Ok(()) } fn assert_eq_properties_same( @@ -3091,21 +3140,63 @@ mod tests { // Check whether constants are same let lhs_constants = lhs.constants(); let rhs_constants = rhs.constants(); - assert_eq!(lhs_constants.len(), rhs_constants.len(), "{}", err_msg); for rhs_constant in rhs_constants { assert!( const_exprs_contains(lhs_constants, rhs_constant.expr()), - "{}", - err_msg + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" ); } + assert_eq!( + lhs_constants.len(), + rhs_constants.len(), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" + ); // Check whether orderings are same. let lhs_orderings = lhs.oeq_class(); let rhs_orderings = &rhs.oeq_class.orderings; - assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); for rhs_ordering in rhs_orderings { - assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); + assert!( + lhs_orderings.contains(rhs_ordering), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" + ); } + assert_eq!( + lhs_orderings.len(), + rhs_orderings.len(), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" + ); + } + + /// Converts a string to a physical sort expression + /// + /// # Example + /// * `"a"` -> (`"a"`, `SortOptions::default()`) + /// * `"a ASC"` -> (`"a"`, `SortOptions { descending: false, nulls_first: false }`) + fn parse_sort_expr(name: &str, schema: &SchemaRef) -> PhysicalSortExpr { + let mut parts = name.split_whitespace(); + let name = parts.next().expect("empty sort expression"); + let mut sort_expr = PhysicalSortExpr::new( + col(name, schema).expect("invalid column name"), + SortOptions::default(), + ); + + if let Some(options) = parts.next() { + sort_expr = match options { + "ASC" => sort_expr.asc(), + "DESC" => sort_expr.desc(), + _ => panic!( + "unknown sort options. Expected 'ASC' or 'DESC', got {}", + options + ), + } + } + + assert!( + parts.next().is_none(), + "unexpected tokens in column name. Expected 'name' / 'name ASC' / 'name DESC' but got '{name}'" + ); + + sort_expr } } diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index cd1597217c83..fbb59cc92fa0 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -93,18 +93,18 @@ impl LiteralGuarantee { /// Create a new instance of the guarantee if the provided operator is /// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to /// create these structures from a predicate (boolean expression).
- fn try_new<'a>( + fn new<'a>( column_name: impl Into<String>, guarantee: Guarantee, literals: impl IntoIterator<Item = &'a ScalarValue>, - ) -> Option<Self> { + ) -> Self { let literals: HashSet<_> = literals.into_iter().cloned().collect(); - Some(Self { + Self { column: Column::from_name(column_name), guarantee, literals, - }) + } } /// Return a list of [`LiteralGuarantee`]s that must be satisfied for `expr` @@ -338,13 +338,10 @@ impl<'a> GuaranteeBuilder<'a> { // This is a new guarantee let new_values: HashSet<_> = new_values.into_iter().collect(); - if let Some(guarantee) = - LiteralGuarantee::try_new(col.name(), guarantee, new_values) - { - // add it to the list of guarantees - self.guarantees.push(Some(guarantee)); - self.map.insert(key, self.guarantees.len() - 1); - } + let guarantee = LiteralGuarantee::new(col.name(), guarantee, new_values); + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); } self @@ -851,7 +848,7 @@ mod test { S: Into<ScalarValue> + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Guarantee::In, literals.iter()).unwrap() + LiteralGuarantee::new(column, Guarantee::In, literals.iter()) } /// Guarantee that the expression is true if the column is NOT any of the specified values @@ -861,7 +858,7 @@ mod test { S: Into<ScalarValue> + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Guarantee::NotIn, literals.iter()).unwrap() + LiteralGuarantee::new(column, Guarantee::NotIn, literals.iter()) } // Schema for testing diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 87c74579c639..d94983c5adf7 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -30,7 +30,7 @@ use crate::PhysicalExpr; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::ScalarValue; use datafusion_expr::window_state::WindowAggState; use datafusion_expr::PartitionEvaluator; @@ -86,16 +86,13 @@ impl NthValue { n: i64, ignore_nulls: bool, ) -> Result<Self> { - match n { - 0 => exec_err!("NTH_VALUE expects n to be non-zero"), - _ => Ok(Self { - name: name.into(), - expr, - data_type, - kind: NthValueKind::Nth(n), - ignore_nulls, - }), - } + Ok(Self { + name: name.into(), + expr, + data_type, + kind: NthValueKind::Nth(n), + ignore_nulls, + }) } /// Get the NTH_VALUE kind @@ -188,10 +185,7 @@ impl PartitionEvaluator for NthValueEvaluator { // Negative index represents reverse direction. (n_range >= reverse_index, true) } - Ordering::Equal => { - // The case n = 0 is not valid for the NTH_VALUE function.
- unreachable!(); - } + Ordering::Equal => ScalarValue::try_from(arr.data_type()), } } } diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index 71f129be984d..a11b498b955c 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -23,14 +23,12 @@ use datafusion_common::scalar::ScalarValue; use datafusion_common::Result; use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::{expressions, ExecutionPlan, Statistics}; +use datafusion_physical_plan::{expressions, ExecutionPlan}; use crate::PhysicalOptimizerRule; -use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::udaf::AggregateFunctionExpr; +use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; /// Optimizer that uses available statistics for aggregate functions #[derive(Default, Debug)] @@ -57,14 +55,19 @@ impl PhysicalOptimizerRule for AggregateStatistics { let stats = partial_agg_exec.input().statistics()?; let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { - if let Some((non_null_rows, name)) = - take_optimizable_column_and_table_count(expr, &stats) + let field = expr.field(); + let args = expr.expressions(); + let statistics_args = StatisticsArgs { + statistics: &stats, + return_type: field.data_type(), + is_distinct: expr.is_distinct(), + exprs: args.as_slice(), + }; + if let Some((optimizable_statistic, name)) = + take_optimizable_value_from_statistics(&statistics_args, expr) { - projections.push((expressions::lit(non_null_rows), name.to_owned())); - } else if let Some((min, name)) = take_optimizable_min(expr, &stats) { - projections.push((expressions::lit(min), name.to_owned())); - } else if let Some((max, name)) = take_optimizable_max(expr, &stats) { - projections.push((expressions::lit(max), name.to_owned())); + projections + .push((expressions::lit(optimizable_statistic), name.to_owned())); } else { // TODO: we need all aggr_expr to be resolved (cf TODO fullres) break; @@ -135,160 +138,11 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> None } -/// If this agg_expr is a count that can be exactly derived from the statistics, return it. 
-fn take_optimizable_column_and_table_count( - agg_expr: &AggregateFunctionExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - let col_stats = &stats.column_statistics; - if is_non_distinct_count(agg_expr) { - if let Precision::Exact(num_rows) = stats.num_rows { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::() - { - let current_val = &col_stats[col_expr.index()].null_count; - if let &Precision::Exact(val) = current_val { - return Some(( - ScalarValue::Int64(Some((num_rows - val) as i64)), - agg_expr.name().to_string(), - )); - } - } else if let Some(lit_expr) = - exprs[0].as_any().downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - agg_expr.name().to_string(), - )); - } - } - } - } - } - None -} - -/// If this agg_expr is a min that is exactly defined in the statistics, return it. -fn take_optimizable_min( - agg_expr: &AggregateFunctionExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if is_min(agg_expr) { - if let Ok(min_data_type) = - ScalarValue::try_from(agg_expr.field().data_type()) - { - return Some((min_data_type, agg_expr.name().to_string())); - } - } - } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if is_min(agg_expr) { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].min_value - { - if !val.is_null() { - return Some(( - val.clone(), - agg_expr.name().to_string(), - )); - } - } - } - } - } - } - _ => {} - } - } - None -} - /// If this agg_expr is a max that is exactly defined in the statistics, return it. 
-fn take_optimizable_max( +fn take_optimizable_value_from_statistics( + statistics_args: &StatisticsArgs, agg_expr: &AggregateFunctionExpr, - stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if is_max(agg_expr) { - if let Ok(max_data_type) = - ScalarValue::try_from(agg_expr.field().data_type()) - { - return Some((max_data_type, agg_expr.name().to_string())); - } - } - } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if is_max(agg_expr) { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::<expressions::Column>() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].max_value - { - if !val.is_null() { - return Some(( - val.clone(), - agg_expr.name().to_string(), - )); - } - } - } - } - } - } - _ => {} - } - } - None -} - -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_non_distinct_count(agg_expr: &AggregateFunctionExpr) -> bool { - if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() { - return true; - } - false + let value = agg_expr.fun().value_from_stats(statistics_args); + value.map(|val| (val, agg_expr.name().to_string())) } - -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_min(agg_expr: &AggregateFunctionExpr) -> bool { - if agg_expr.fun().name().to_lowercase() == "min" { - return true; - } - false -} - -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_max(agg_expr: &AggregateFunctionExpr) -> bool { - if agg_expr.fun().name().to_lowercase() == "max" { - return true; - } - false -} - -// See tests in datafusion/core/tests/physical_optimizer diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/column.rs index 977b40922f7c..28f35b2bded2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs @@ -36,7 +36,9 @@ use datafusion_physical_expr::binary_map::OutputType; use hashbrown::raw::RawTable; -/// Compare GroupValue Rows column by column +/// A [`GroupValues`] that stores multiple columns of group values. +/// +/// pub struct GroupValuesColumn { /// The output schema schema: SchemaRef, @@ -55,8 +57,13 @@ pub struct GroupValuesColumn { map_size: usize, /// The actual group by values, stored column-wise. Compare from - /// the left to right, each column is stored as `ArrayRowEq`. - /// This is shown faster than the row format + /// the left to right, each column is stored as [`GroupColumn`]. + /// + /// Performance tests showed that this design is faster than using the + /// more general purpose [`GroupValuesRows`]. See the ticket for details: + /// + /// + /// [`GroupValuesRows`]: crate::aggregates::group_values::row::GroupValuesRows group_values: Vec<Box<dyn GroupColumn>>, /// reused buffer to store hashes @@ -116,6 +123,25 @@ impl GroupValuesColumn { } } +/// instantiates a [`PrimitiveGroupValueBuilder`] and pushes it into $v +/// +/// Arguments: +/// `$v`: the vector to push the new builder into +/// `$nullable`: whether the input can contain nulls +/// `$t`: the primitive type of the builder +/// +macro_rules!
instantiate_primitive { + ($v:expr, $nullable:expr, $t:ty) => { + if $nullable { + let b = PrimitiveGroupValueBuilder::<$t, true>::new(); + $v.push(Box::new(b) as _) + } else { + let b = PrimitiveGroupValueBuilder::<$t, false>::new(); + $v.push(Box::new(b) as _) + } + }; +} + impl GroupValues for GroupValuesColumn { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { let n_rows = cols[0].len(); @@ -126,54 +152,22 @@ impl GroupValues for GroupValuesColumn { for f in self.schema.fields().iter() { let nullable = f.is_nullable(); match f.data_type() { - &DataType::Int8 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Int16 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Int32 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Int64 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::UInt8 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::UInt16 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::UInt32 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::UInt64 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } + &DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type), + &DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type), + &DataType::Int32 => instantiate_primitive!(v, nullable, Int32Type), + &DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type), + &DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type), + &DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type), + &DataType::UInt32 => instantiate_primitive!(v, nullable, UInt32Type), + &DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type), &DataType::Float32 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) + instantiate_primitive!(v, nullable, Float32Type) } &DataType::Float64 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Date32 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Date64 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) + instantiate_primitive!(v, nullable, Float64Type) } + &DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type), + &DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type), &DataType::Utf8 => { let b = ByteGroupValueBuilder::::new(OutputType::Utf8); v.push(Box::new(b) as _) diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index a82e6d856c70..15c93262968e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -15,37 +15,37 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::BooleanBufferBuilder; use arrow::array::BufferBuilder; use arrow::array::GenericBinaryArray; use arrow::array::GenericStringArray; use arrow::array::OffsetSizeTrait; use arrow::array::PrimitiveArray; use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray}; -use arrow::buffer::NullBuffer; use arrow::buffer::OffsetBuffer; use arrow::buffer::ScalarBuffer; -use arrow::datatypes::ArrowNativeType; use arrow::datatypes::ByteArrayType; use arrow::datatypes::DataType; use arrow::datatypes::GenericBinaryType; -use arrow::datatypes::GenericStringType; use datafusion_common::utils::proxy::VecAllocExt; +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow_array::types::GenericStringType; +use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; use std::sync::Arc; use std::vec; -use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; - -/// Trait for group values column-wise row comparison +/// Trait for storing a single column of group values in [`GroupValuesColumn`] /// -/// Implementations of this trait store a in-progress collection of group values +/// Implementations of this trait store an in-progress collection of group values /// (similar to various builders in Arrow-rs) that allow for quick comparison to /// incoming rows. /// +/// [`GroupValuesColumn`]: crate::aggregates::group_values::GroupValuesColumn pub trait GroupColumn: Send + Sync { /// Returns true if the row stored in this builder at `lhs_row` is equal to /// the row in `array` at `rhs_row` + /// + /// Note that this comparison returns true if both elements are NULL fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool; /// Appends the row at `row` in `array` to this builder fn append_val(&mut self, array: &ArrayRef, row: usize); @@ -60,59 +60,60 @@ pub trait GroupColumn: Send + Sync { fn take_n(&mut self, n: usize) -> ArrayRef; } -pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType> { +/// An implementation of [`GroupColumn`] for primitive values +/// +/// Optimized to skip null buffer construction if the input is known to be non nullable +/// +/// # Template parameters +/// +/// `T`: the native Rust type that stores the data +/// `NULLABLE`: if the data can contain any nulls +#[derive(Debug)] +pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType, const NULLABLE: bool> { group_values: Vec<T::Native>, - nulls: Vec<bool>, - // whether the array contains at least one null, for fast non-null path - has_null: bool, - nullable: bool, + nulls: MaybeNullBufferBuilder, } -impl<T> PrimitiveGroupValueBuilder<T> +impl<T, const NULLABLE: bool> PrimitiveGroupValueBuilder<T, NULLABLE> where T: ArrowPrimitiveType, { - pub fn new(nullable: bool) -> Self { + /// Create a new `PrimitiveGroupValueBuilder` + pub fn new() -> Self { Self { group_values: vec![], - nulls: vec![], - has_null: false, - nullable, + nulls: MaybeNullBufferBuilder::new(), } } } -impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> { +impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn + for PrimitiveGroupValueBuilder<T, NULLABLE> +{ fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { - // non-null fast path - // both non-null - if !self.nullable { - return self.group_values[lhs_row] - == array.as_primitive::<T>().value(rhs_row); - } - - // lhs is non-null - if self.nulls[lhs_row] { - if array.is_null(rhs_row) { - return false; - } - - return self.group_values[lhs_row] - == array.as_primitive::<T>().value(rhs_row); - } + // Perf: skip null check (by short circuit) if input is not nullable + let null_match = if NULLABLE { + self.nulls.is_null(lhs_row) == array.is_null(rhs_row) + } else { + true + }; - array.is_null(rhs_row) + null_match + && self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row) } fn append_val(&mut self, array: &ArrayRef, row: usize) { - if self.nullable && array.is_null(row) { - self.group_values.push(T::default_value()); - self.nulls.push(false); - self.has_null = true; + // Perf: skip null check if input can't have nulls + if NULLABLE { + if array.is_null(row) { + self.nulls.append(true); + self.group_values.push(T::default_value()); + } else { + self.nulls.append(false); + self.group_values.push(array.as_primitive::<T>().value(row)); + } } else { - let elem = array.as_primitive::<T>().value(row); - self.group_values.push(elem); - self.nulls.push(true); + self.group_values.push(array.as_primitive::<T>().value(row)); } } @@ -125,48 +126,54 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> { } fn build(self: Box) -> ArrayRef { - if self.has_null { - Arc::new(PrimitiveArray::<T>::new( - ScalarBuffer::from(self.group_values), - Some(NullBuffer::from(self.nulls)), - )) - } else { - Arc::new(PrimitiveArray::<T>::new( - ScalarBuffer::from(self.group_values), - None, - )) + let Self { + group_values, + nulls, + } = *self; + + let nulls = nulls.build(); + if !NULLABLE { + assert!(nulls.is_none(), "unexpected nulls in non nullable input"); } + + Arc::new(PrimitiveArray::<T>::new( + ScalarBuffer::from(group_values), + nulls, + )) } fn take_n(&mut self, n: usize) -> ArrayRef { - if self.has_null { - let first_n = self.group_values.drain(0..n).collect::<Vec<_>>(); - let first_n_nulls = self.nulls.drain(0..n).collect::<Vec<_>>(); - Arc::new(PrimitiveArray::<T>::new( - ScalarBuffer::from(first_n), - Some(NullBuffer::from(first_n_nulls)), - )) - } else { - let first_n = self.group_values.drain(0..n).collect::<Vec<_>>(); - self.nulls.truncate(self.nulls.len() - n); - Arc::new(PrimitiveArray::<T>::new(ScalarBuffer::from(first_n), None)) - } + let first_n = self.group_values.drain(0..n).collect::<Vec<_>>(); + + let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; + + Arc::new(PrimitiveArray::<T>::new( + ScalarBuffer::from(first_n), + first_n_nulls, + )) } } +/// An implementation of [`GroupColumn`] for binary and utf8 types. +/// +/// Stores a collection of binary or utf8 group values in a single buffer +/// in a way that allows: +/// +/// 1. Efficient comparison of incoming rows to existing rows +/// 2. Efficient construction of the final output array pub struct ByteGroupValueBuilder<O> where O: OffsetSizeTrait, { output_type: OutputType, buffer: BufferBuilder<u8>, /// Offsets into `buffer` for each distinct value. These offsets are used /// directly to create the final `GenericBinaryArray`. The `i`th string is /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values /// are stored as a zero length string.
offsets: Vec, - /// Null indexes in offsets, if `i` is in nulls, `offsets[i]` should be equals to `offsets[i+1]` - nulls: Vec, + /// Nulls + nulls: MaybeNullBufferBuilder, } impl ByteGroupValueBuilder @@ -178,7 +185,7 @@ where output_type, buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], - nulls: vec![], + nulls: MaybeNullBufferBuilder::new(), } } @@ -188,40 +195,33 @@ where { let arr = array.as_bytes::(); if arr.is_null(row) { - self.nulls.push(self.len()); + self.nulls.append(true); // nulls need a zero length in the offset buffer let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - return; + } else { + self.nulls.append(false); + let value: &[u8] = arr.value(row).as_ref(); + self.buffer.append_slice(value); + self.offsets.push(O::usize_as(self.buffer.len())); } - - let value: &[u8] = arr.value(row).as_ref(); - self.buffer.append_slice(value); - self.offsets.push(O::usize_as(self.buffer.len())); } fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool where B: ByteArrayType, { - // Handle nulls - let is_lhs_null = self.nulls.iter().any(|null_idx| *null_idx == lhs_row); let arr = array.as_bytes::(); - if is_lhs_null { - return arr.is_null(rhs_row); - } else if arr.is_null(rhs_row) { - return false; - } + self.nulls.is_null(lhs_row) == arr.is_null(rhs_row) + && self.value(lhs_row) == (arr.value(rhs_row).as_ref() as &[u8]) + } - let arr = array.as_bytes::(); - let rhs_elem: &[u8] = arr.value(rhs_row).as_ref(); - let rhs_elem_len = arr.value_length(rhs_row).as_usize(); - debug_assert_eq!(rhs_elem_len, rhs_elem.len()); - let l = self.offsets[lhs_row].as_usize(); - let r = self.offsets[lhs_row + 1].as_usize(); - let existing_elem = unsafe { self.buffer.as_slice().get_unchecked(l..r) }; - rhs_elem == existing_elem + /// return the current value of the specified row irrespective of null + pub fn value(&self, row: usize) -> &[u8] { + let l = self.offsets[row].as_usize(); + let r = self.offsets[row + 1].as_usize(); + // Safety: the offsets are constructed correctly and never decrease + unsafe { self.buffer.as_slice().get_unchecked(l..r) } } } @@ -289,18 +289,7 @@ where nulls, } = *self; - let null_buffer = if nulls.is_empty() { - None - } else { - // Only make a `NullBuffer` if there was a null value - let num_values = offsets.len() - 1; - let mut bool_builder = BooleanBufferBuilder::new(num_values); - bool_builder.append_n(num_values, true); - nulls.into_iter().for_each(|null_index| { - bool_builder.set_bit(null_index, false); - }); - Some(NullBuffer::from(bool_builder.finish())) - }; + let null_buffer = nulls.build(); // SAFETY: the offsets were constructed correctly in `insert_if_new` -- // monotonically increasing, overflows were checked. @@ -317,9 +306,9 @@ where // SAFETY: // 1. the offsets were constructed safely // - // 2. we asserted the input arrays were all the correct type and - // thus since all the values that went in were valid (e.g. utf8) - // so are all the values that come out + // 2. the input arrays were all the correct type and thus since + // all the values that went in were valid (e.g. 
utf8) so are all + // the values that come out Arc::new(unsafe { GenericStringArray::new_unchecked(offsets, values, null_buffer) }) @@ -330,27 +319,7 @@ where fn take_n(&mut self, n: usize) -> ArrayRef { debug_assert!(self.len() >= n); - - let null_buffer = if self.nulls.is_empty() { - None - } else { - // Only make a `NullBuffer` if there was a null value - let mut bool_builder = BooleanBufferBuilder::new(n); - bool_builder.append_n(n, true); - - let mut new_nulls = vec![]; - self.nulls.iter().for_each(|null_index| { - if *null_index < n { - bool_builder.set_bit(*null_index, false); - } else { - new_nulls.push(null_index - n); - } - }); - - self.nulls = new_nulls; - Some(NullBuffer::from(bool_builder.finish())) - }; - + let null_buffer = self.nulls.take_n(n); let first_remaining_offset = O::as_usize(self.offsets[n]); // Given offests like [0, 2, 4, 5] and n = 1, we expect to get diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 9256631fa578..fb7b66775092 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`GroupValues`] trait for storing and interning group keys + use arrow::record_batch::RecordBatch; use arrow_array::{downcast_primitive, ArrayRef}; use arrow_schema::{DataType, SchemaRef}; @@ -36,19 +38,63 @@ use bytes::GroupValuesByes; use datafusion_physical_expr::binary_map::OutputType; mod group_column; - -/// An interning store for group keys +mod null_builder; + +/// Stores the group values during hash aggregation. +/// +/// # Background +/// +/// In a query such as `SELECT a, b, count(*) FROM t GROUP BY a, b`, the group values +/// identify each group, and correspond to all the distinct values of `(a,b)`. +/// +/// ```sql +/// -- Input has 4 rows with 3 distinct combinations of (a,b) ("groups") +/// create table t(a int, b varchar) +/// as values (1, 'a'), (2, 'b'), (1, 'a'), (3, 'c'); +/// +/// select a, b, count(*) from t group by a, b; +/// ---- +/// 1 a 2 +/// 2 b 1 +/// 3 c 1 +/// ``` +/// +/// # Design +/// +/// Managing group values is a performance critical operation in hash +/// aggregation. The major operations are: +/// +/// 1. Intern: Quickly finding existing and adding new group values +/// 2. Emit: Returning the group values as an array +/// +/// There are multiple specialized implementations of this trait optimized for +/// different data types and number of columns, optimized for these operations. +/// See [`new_group_values`] for details. +/// +/// # Group Ids +/// +/// Each distinct group in a hash aggregation is identified by a unique group id +/// (usize) which is assigned by instances of this trait. Group ids are +/// continuous without gaps, starting from 0. pub trait GroupValues: Send { - /// Calculates the `groups` for each input row of `cols` + /// Calculates the group id for each input row of `cols`, assigning new + /// group ids as necessary. + /// + /// When the function returns, `groups` must contain the group id for each + /// row in `cols`. + /// + /// If a row has the same value as a previous row, the same group id is + /// assigned. If a row has a new value, the next available group id is + /// assigned. 
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()>; - /// Returns the number of bytes used by this [`GroupValues`] + /// Returns the number of bytes of memory used by this [`GroupValues`] fn size(&self) -> usize; /// Returns true if this [`GroupValues`] is empty fn is_empty(&self) -> bool; - /// The number of values stored in this [`GroupValues`] + /// The number of values (distinct group values) stored in this [`GroupValues`] fn len(&self) -> usize; /// Emits the group values @@ -58,6 +104,7 @@ pub trait GroupValues: Send { fn clear_shrink(&mut self, batch: &RecordBatch); } +/// Return a specialized implementation of [`GroupValues`] for the given schema. pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> { if schema.fields.len() == 1 { let d = schema.fields[0].data_type(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs new file mode 100644 index 000000000000..0249390f38cd --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + +/// Builder for an (optional) null mask +/// +/// Optimized to avoid creating the bitmask when all values are non-null +#[derive(Debug)] +pub(crate) enum MaybeNullBufferBuilder { + /// seen `row_count` rows but no nulls yet + NoNulls { row_count: usize }, + /// have at least one null value + /// + /// Note this is an Arrow *VALIDITY* buffer (so it is false for nulls, true + /// for non-nulls) + Nulls(BooleanBufferBuilder), } + +impl MaybeNullBufferBuilder { + /// Create a new builder + pub fn new() -> Self { + Self::NoNulls { row_count: 0 } + } + + /// Return true if the row at index `row` is null + pub fn is_null(&self, row: usize) -> bool { + match self { + Self::NoNulls { .. } => false, + // validity mask means an unset bit is NULL + Self::Nulls(builder) => !builder.get_bit(row), + } + } + + /// Set the nullness of the next row to `is_null` + /// + /// num_values is the current length of the rows being tracked + /// + /// If `value` is true, the row is null.
+ /// If `value` is false, the row is non null + pub fn append(&mut self, is_null: bool) { + match self { + Self::NoNulls { row_count } if is_null => { + // have seen no nulls so far, this is the first null, + // need to create the nulls buffer for all currently valid values + // alloc 2x the need given we push a new but immediately + let mut nulls = BooleanBufferBuilder::new(*row_count * 2); + nulls.append_n(*row_count, true); + nulls.append(false); + *self = Self::Nulls(nulls); + } + Self::NoNulls { row_count } => { + *row_count += 1; + } + Self::Nulls(builder) => builder.append(!is_null), + } + } + + /// return the number of heap allocated bytes used by this structure to store boolean values + pub fn allocated_size(&self) -> usize { + match self { + Self::NoNulls { .. } => 0, + // BooleanBufferBuilder builder::capacity returns capacity in bits (not bytes) + Self::Nulls(builder) => builder.capacity() / 8, + } + } + + /// Return a NullBuffer representing the accumulated nulls so far + pub fn build(self) -> Option { + match self { + Self::NoNulls { .. } => None, + Self::Nulls(mut builder) => Some(NullBuffer::from(builder.finish())), + } + } + + /// Returns a NullBuffer representing the first `n` rows accumulated so far + /// shifting any remaining down by `n` + pub fn take_n(&mut self, n: usize) -> Option { + match self { + Self::NoNulls { row_count } => { + *row_count -= n; + None + } + Self::Nulls(builder) => { + // Copy over the values at n..len-1 values to the start of a + // new builder and leave it in self + // + // TODO: it would be great to use something like `set_bits` from arrow here. + let mut new_builder = BooleanBufferBuilder::new(builder.len()); + for i in n..builder.len() { + new_builder.append(builder.get_bit(i)); + } + std::mem::swap(&mut new_builder, builder); + + // take only first n values from the original builder + new_builder.truncate(n); + Some(NullBuffer::from(new_builder.finish())) + } + } + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index b252d0008784..8ca88257bf1a 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -30,6 +30,13 @@ use hashbrown::raw::RawTable; use std::sync::Arc; /// A [`GroupValues`] making use of [`Rows`] +/// +/// This is a general implementation of [`GroupValues`] that works for any +/// combination of data types and number of columns, including nested types such as +/// structs and lists. +/// +/// It uses the arrow-rs [`Rows`] to store the group values, which is a row-wise +/// representation. 
pub struct GroupValuesRows { /// The output schema schema: SchemaRef, @@ -220,7 +227,8 @@ impl GroupValues for GroupValuesRows { } }; - // TODO: Materialize dictionaries in group keys (#7647) + // TODO: Materialize dictionaries in group keys + // https://github.com/apache/datafusion/issues/7647 for (field, array) in self.schema.fields.iter().zip(&mut output) { let expected = field.data_type(); *array = dictionary_encode_if_necessary( diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 2bdaed479655..9466ff6dd459 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -26,6 +26,7 @@ use crate::aggregates::{ topk_stream::GroupedTopKAggregateStream, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, @@ -795,14 +796,17 @@ fn create_schema( ) -> Result { let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); for (index, (expr, name)) in group_expr.iter().enumerate() { - fields.push(Field::new( - name, - expr.data_type(input_schema)?, - // In cases where we have multiple grouping sets, we will use NULL expressions in - // order to align the grouping sets. So the field must be nullable even if the underlying - // schema field is not. - group_expr_nullable[index] || expr.nullable(input_schema)?, - )) + fields.push( + Field::new( + name, + expr.data_type(input_schema)?, + // In cases where we have multiple grouping sets, we will use NULL expressions in + // order to align the grouping sets. So the field must be nullable even if the underlying + // schema field is not. 
+ group_expr_nullable[index] || expr.nullable(input_schema)?, + ) + .with_metadata(get_field_metadata(expr, input_schema).unwrap_or_default()), + ) } match mode { @@ -823,7 +827,10 @@ fn create_schema( } } - Ok(Schema::new(fields)) + Ok(Schema::new_with_metadata( + fields, + input_schema.metadata().clone(), + )) } fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index d4dbdf0f029d..998f6184f321 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -38,7 +38,7 @@ use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; use arrow::datatypes::SchemaRef; use arrow_schema::SortOptions; -use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -609,14 +609,11 @@ impl Stream for GroupedHashAggregateStream { match &self.exec_state { ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { - // new batch to aggregate - Some(Ok(batch)) => { + // New batch to aggregate in partial aggregation operator + Some(Ok(batch)) if self.mode == AggregateMode::Partial => { let timer = elapsed_compute.timer(); let input_rows = batch.num_rows(); - // Make sure we have enough capacity for `batch`, otherwise spill - extract_ok!(self.spill_previous_if_necessary(&batch)); - // Do the grouping extract_ok!(self.group_aggregate_batch(batch)); @@ -649,10 +646,49 @@ impl Stream for GroupedHashAggregateStream { timer.done(); } + + // New batch to aggregate in terminal aggregation operator + // (Final/FinalPartitioned/Single/SinglePartitioned) + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + + // Make sure we have enough capacity for `batch`, otherwise spill + extract_ok!(self.spill_previous_if_necessary(&batch)); + + // Do the grouping + extract_ok!(self.group_aggregate_batch(batch)); + + // If we can begin emitting rows, do so, + // otherwise keep consuming input + assert!(!self.input_done); + + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + if let Some(to_emit) = self.group_ordering.emit_to() { + let batch = extract_ok!(self.emit(to_emit, false)); + self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + timer.done(); + } + + // Found error from input stream Some(Err(e)) => { // inner had error, return to caller return Poll::Ready(Some(Err(e))); } + + // Found end from input stream None => { // inner is done, emit all rows and switch to producing output extract_ok!(self.set_input_done_and_produce_output()); @@ -691,7 +727,12 @@ impl Stream for GroupedHashAggregateStream { ( if self.input_done { ExecutionState::Done - } else if self.should_skip_aggregation() { + } + // In Partial aggregation, we also need to check + // if we should trigger partial skipping + else if self.mode == 
AggregateMode::Partial + && self.should_skip_aggregation() + { ExecutionState::SkippingAggregation } else { ExecutionState::ReadingInput @@ -879,10 +920,10 @@ impl GroupedHashAggregateStream { if self.group_values.len() > 0 && batch.num_rows() > 0 && matches!(self.group_ordering, GroupOrdering::None) - && !matches!(self.mode, AggregateMode::Partial) && !self.spill_state.is_stream_merging && self.update_memory_reservation().is_err() { + assert_ne!(self.mode, AggregateMode::Partial); // Use input batch (Partial mode) schema for spilling because // the spilled data will be merged and re-evaluated later. self.spill_state.spill_schema = batch.schema(); @@ -927,9 +968,9 @@ impl GroupedHashAggregateStream { fn emit_early_if_necessary(&mut self) -> Result<()> { if self.group_values.len() >= self.batch_size && matches!(self.group_ordering, GroupOrdering::None) - && matches!(self.mode, AggregateMode::Partial) && self.update_memory_reservation().is_err() { + assert_eq!(self.mode, AggregateMode::Partial); let n = self.group_values.len() / self.batch_size * self.batch_size; let batch = self.emit(EmitTo::First(n), false)?; self.exec_state = ExecutionState::ProducingOutput(batch); @@ -1002,6 +1043,8 @@ impl GroupedHashAggregateStream { } /// Updates skip aggregation probe state. + /// + /// Notice: It should only be called in Partial aggregation fn update_skip_aggregation_probe(&mut self, input_rows: usize) { if let Some(probe) = self.skip_aggregation_probe.as_mut() { // Skip aggregation probe is not supported if stream has any spills, @@ -1013,6 +1056,8 @@ impl GroupedHashAggregateStream { /// In case the probe indicates that aggregation may be /// skipped, forces stream to produce currently accumulated output. + /// + /// Notice: It should only be called in Partial aggregation fn switch_to_skip_aggregation(&mut self) -> Result<()> { if let Some(probe) = self.skip_aggregation_probe.as_mut() { if probe.should_skip() { @@ -1026,6 +1071,8 @@ impl GroupedHashAggregateStream { /// Returns true if the aggregation probe indicates that aggregation /// should be skipped. 
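The `row_hash.rs` changes above split batch handling by aggregation mode: a Partial operator emits accumulated groups early under memory pressure (and may skip aggregation entirely), while terminal operators spill so no group is lost. A control-flow sketch with stand-in types, not DataFusion's actual stream:

```rust
// Illustrative sketch only; these are stand-in types, not DataFusion's.
enum AggregateMode {
    Partial,
    Terminal, // Final / FinalPartitioned / Single / SinglePartitioned
}

struct GroupStream {
    mode: AggregateMode,
    groups: usize,
    batch_size: usize,
}

impl GroupStream {
    fn on_batch(&mut self, rows: usize) {
        self.groups += rows;
        match self.mode {
            // Partial never spills: under memory pressure it emits the
            // accumulated groups early, and the skip-aggregation probe may
            // later turn grouping off entirely.
            AggregateMode::Partial => {
                if self.groups >= self.batch_size {
                    println!("emit {} groups early", self.groups);
                    self.groups = 0;
                }
            }
            // Terminal modes must retain every group, so under memory
            // pressure they spill to disk instead of emitting early.
            AggregateMode::Terminal => {
                println!("would spill rather than drop {} groups", self.groups);
            }
        }
    }
}

fn main() {
    let mut partial = GroupStream { mode: AggregateMode::Partial, groups: 0, batch_size: 8 };
    partial.on_batch(10); // emits early
    let mut terminal = GroupStream { mode: AggregateMode::Terminal, groups: 0, batch_size: 8 };
    terminal.on_batch(10); // spills instead
}
```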
+ /// + /// Notice: It should only be called in Partial aggregation fn should_skip_aggregation(&self) -> bool { self.skip_aggregation_probe .as_ref() @@ -1034,13 +1081,14 @@ impl GroupedHashAggregateStream { /// Transforms input batch to intermediate aggregate state, without grouping it fn transform_to_states(&self, batch: RecordBatch) -> Result { - let group_values = evaluate_group_by(&self.group_by, &batch)?; + let mut group_values = evaluate_group_by(&self.group_by, &batch)?; let input_values = evaluate_many(&self.aggregate_arguments, &batch)?; let filter_values = evaluate_optional(&self.filter_expressions, &batch)?; - let mut output = group_values.first().cloned().ok_or_else(|| { - internal_datafusion_err!("group_values expected to have at least one element") - })?; + if group_values.len() != 1 { + return internal_err!("group_values expected to have single element"); + } + let mut output = group_values.swap_remove(0); let iter = self .accumulators diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 542861688dfe..b14021f4a99b 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -228,6 +228,16 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// [`TryStreamExt`]: futures::stream::TryStreamExt /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter /// + /// # Error handling + /// + /// Any error that occurs during execution is sent as an `Err` in the output + /// stream. + /// + /// `ExecutionPlan` implementations in DataFusion cancel additional work + /// immediately once an error occurs. The rationale is that if the overall + /// query will return an error, any additional work such as continued + /// polling of inputs will be wasted as it will be thrown away. + /// /// # Cancellation / Aborting Execution /// /// The [`Stream`] that is returned must ensure that any allocated resources diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 7cbfd49afb86..845a74eaea48 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -82,6 +82,7 @@ pub mod windows; pub mod work_table; pub mod udaf { + pub use datafusion_expr::StatisticsArgs; pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr; } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index f1b9cdaf728f..4c889d1fc88c 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -237,7 +237,7 @@ impl ExecutionPlan for ProjectionExec { /// If e is a direct column reference, returns the field level /// metadata for that field, if any. Otherwise returns None -fn get_field_metadata( +pub(crate) fn get_field_metadata( e: &Arc, input_schema: &Schema, ) -> Option> { diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 5b25d582d20c..4fd364cca4d0 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -377,6 +377,11 @@ impl BatchPartitioner { /// `───────' `───────' ///``` /// +/// # Error Handling +/// +/// If any of the input partitions return an error, the error is propagated to +/// all output partitions and inputs are not polled again. 
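The new "Error Handling" docs promise that once an input yields an error, no further work is done. A hedged sketch of that observable contract using the `futures` crate directly (not DataFusion's stream types): the consumer sees at most one `Err`, after which the stream terminates.

```rust
use futures::stream::{self, StreamExt};

fn main() {
    futures::executor::block_on(async {
        let input = stream::iter(vec![Ok(1), Err("boom"), Ok(2)]);
        let mut output = input.scan(false, |errored, item| {
            if *errored {
                // A previous item was an error: end the stream here.
                return futures::future::ready(None);
            }
            *errored = item.is_err();
            futures::future::ready(Some(item))
        });
        while let Some(item) = output.next().await {
            println!("{item:?}"); // Ok(1), Err("boom"), then the stream ends
        }
    });
}
```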
+/// /// # Output Ordering /// /// If more than one stream is being repartitioned, the output will be some diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 875922ac34b5..e0644e3d99e5 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -39,6 +39,7 @@ use futures::Stream; /// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] type CursorStream = Box>>; +/// Merges a stream of sorted cursors and record batches into a single sorted stream #[derive(Debug)] pub(crate) struct SortPreservingMergeStream { in_progress: BatchBuilder, diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index f83bb58d08dd..b00a11a5355f 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -65,6 +65,11 @@ use log::{debug, trace}; /// Input Streams Output stream /// (sorted) (sorted) /// ``` +/// +/// # Error Handling +/// +/// If any of the input partitions return an error, the error is propagated to +/// the output and inputs are not polled again. #[derive(Debug)] pub struct SortPreservingMergeExec { /// Input plan diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 9510baab51fb..4a4c940b22e2 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -257,17 +257,11 @@ impl ExecutionPlan for BoundedWindowAggExec { fn required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); - if self.input_order_mode != InputOrderMode::Sorted - || self.ordered_partition_by_indices.len() >= partition_bys.len() - { - let partition_bys = self - .ordered_partition_by_indices - .iter() - .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys)] - } else { - vec![calc_requirements(partition_bys, order_keys)] - } + let partition_bys = self + .ordered_partition_by_indices + .iter() + .map(|idx| &partition_bys[*idx]); + vec![calc_requirements(partition_bys, order_keys)] } fn required_input_distribution(&self) -> Vec { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 6e1cb8db5f09..6aafaad0ad77 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -34,8 +34,8 @@ use datafusion_common::{ exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, - WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, ReversedUDWF, WindowFrame, + WindowFunctionDefinition, WindowUDF, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::equivalence::collapse_lex_req; @@ -130,7 +130,7 @@ pub fn create_window_expr( } // TODO: Ordering not supported for Window UDFs yet WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( - create_udwf_window_expr(fun, args, input_schema, name)?, + create_udwf_window_expr(fun, args, input_schema, name, ignore_nulls)?, partition_by, order_by, window_frame, @@ -185,20 +185,26 @@ fn get_scalar_value_from_args( } fn 
get_signed_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } + if !value.data_type().is_integer() { - return Err(DataFusionError::Execution( - "Expected an integer value".to_string(), - )); + return exec_err!("Expected an integer value"); } + value.cast_to(&DataType::Int64)?.try_into() } fn get_unsigned_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } + if !value.data_type().is_integer() { - return Err(DataFusionError::Execution( - "Expected an integer value".to_string(), - )); + return exec_err!("Expected an integer value"); } + value.cast_to(&DataType::UInt64)?.try_into() } @@ -329,6 +335,7 @@ fn create_udwf_window_expr( args: &[Arc], input_schema: &Schema, name: String, + ignore_nulls: bool, ) -> Result> { // need to get the types into an owned vec for some reason let input_types: Vec<_> = args @@ -341,6 +348,8 @@ fn create_udwf_window_expr( args: args.to_vec(), input_types, name, + is_reversed: false, + ignore_nulls, })) } @@ -353,6 +362,12 @@ struct WindowUDFExpr { name: String, /// Types of input expressions input_types: Vec, + /// This is set to `true` only if the user-defined window function + /// expression supports evaluation in reverse order, and the + /// evaluation order is reversed. + is_reversed: bool, + /// Set to `true` if `IGNORE NULLS` is defined, `false` otherwise. + ignore_nulls: bool, } impl BuiltInWindowFunctionExpr for WindowUDFExpr { @@ -378,7 +393,18 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn reverse_expr(&self) -> Option> { - None + match self.fun.reverse_expr() { + ReversedUDWF::Identical => Some(Arc::new(self.clone())), + ReversedUDWF::NotSupported => None, + ReversedUDWF::Reversed(fun) => Some(Arc::new(WindowUDFExpr { + fun, + args: self.args.clone(), + name: self.name.clone(), + input_types: self.input_types.clone(), + is_reversed: !self.is_reversed, + ignore_nulls: self.ignore_nulls, + })), + } } fn get_result_ordering(&self, schema: &SchemaRef) -> Option { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1204c843fdb1..e36c91e7d004 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -731,14 +731,21 @@ message PartitionColumn { message FileSinkConfig { reserved 6; // writer_mode + reserved 8; // was `overwrite` which has been superseded by `insert_op` string object_store_url = 1; repeated PartitionedFile file_groups = 2; repeated string table_paths = 3; datafusion_common.Schema output_schema = 4; repeated PartitionColumn table_partition_cols = 5; - bool overwrite = 8; bool keep_partition_by_columns = 9; + InsertOp insert_op = 10; +} + +enum InsertOp { + Append = 0; + Overwrite = 1; + Replace = 2; } message JsonSink { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0614e33b7a4b..004798b3ba93 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5832,10 +5832,10 @@ impl serde::Serialize for FileSinkConfig { if !self.table_partition_cols.is_empty() { len += 1; } - if self.overwrite { + if self.keep_partition_by_columns { len += 1; } - if self.keep_partition_by_columns { + if self.insert_op != 0 { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; @@ -5854,12 +5854,14 @@ impl serde::Serialize for FileSinkConfig { if !self.table_partition_cols.is_empty() { struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; 
} - if self.overwrite { - struct_ser.serialize_field("overwrite", &self.overwrite)?; - } if self.keep_partition_by_columns { struct_ser.serialize_field("keepPartitionByColumns", &self.keep_partition_by_columns)?; } + if self.insert_op != 0 { + let v = InsertOp::try_from(self.insert_op) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.insert_op)))?; + struct_ser.serialize_field("insertOp", &v)?; + } struct_ser.end() } } @@ -5880,9 +5882,10 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "outputSchema", "table_partition_cols", "tablePartitionCols", - "overwrite", "keep_partition_by_columns", "keepPartitionByColumns", + "insert_op", + "insertOp", ]; #[allow(clippy::enum_variant_names)] @@ -5892,8 +5895,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { TablePaths, OutputSchema, TablePartitionCols, - Overwrite, KeepPartitionByColumns, + InsertOp, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5920,8 +5923,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "overwrite" => Ok(GeneratedField::Overwrite), "keepPartitionByColumns" | "keep_partition_by_columns" => Ok(GeneratedField::KeepPartitionByColumns), + "insertOp" | "insert_op" => Ok(GeneratedField::InsertOp), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5946,8 +5949,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { let mut table_paths__ = None; let mut output_schema__ = None; let mut table_partition_cols__ = None; - let mut overwrite__ = None; let mut keep_partition_by_columns__ = None; + let mut insert_op__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::ObjectStoreUrl => { @@ -5980,18 +5983,18 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { } table_partition_cols__ = Some(map_.next_value()?); } - GeneratedField::Overwrite => { - if overwrite__.is_some() { - return Err(serde::de::Error::duplicate_field("overwrite")); - } - overwrite__ = Some(map_.next_value()?); - } GeneratedField::KeepPartitionByColumns => { if keep_partition_by_columns__.is_some() { return Err(serde::de::Error::duplicate_field("keepPartitionByColumns")); } keep_partition_by_columns__ = Some(map_.next_value()?); } + GeneratedField::InsertOp => { + if insert_op__.is_some() { + return Err(serde::de::Error::duplicate_field("insertOp")); + } + insert_op__ = Some(map_.next_value::()? 
as i32); + } } } Ok(FileSinkConfig { @@ -6000,8 +6003,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { table_paths: table_paths__.unwrap_or_default(), output_schema: output_schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), - overwrite: overwrite__.unwrap_or_default(), keep_partition_by_columns: keep_partition_by_columns__.unwrap_or_default(), + insert_op: insert_op__.unwrap_or_default(), }) } } @@ -7198,6 +7201,80 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for InsertOp { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Append => "Append", + Self::Overwrite => "Overwrite", + Self::Replace => "Replace", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for InsertOp { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "Append", + "Overwrite", + "Replace", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InsertOp; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "Append" => Ok(InsertOp::Append), + "Overwrite" => Ok(InsertOp::Overwrite), + "Replace" => Ok(InsertOp::Replace), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for InterleaveExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 21d88e565e80..436347330d92 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1067,10 +1067,10 @@ pub struct FileSinkConfig { pub output_schema: ::core::option::Option, #[prost(message, repeated, tag = "5")] pub table_partition_cols: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "8")] - pub overwrite: bool, #[prost(bool, tag = "9")] pub keep_partition_by_columns: bool, + #[prost(enumeration = "InsertOp", tag = "10")] + pub insert_op: i32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct JsonSink { @@ -1954,6 +1954,35 @@ impl DateUnit { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum InsertOp { + Append = 0, + Overwrite = 1, + Replace = 2, +} +impl InsertOp { + /// String value of the enum field names used in the ProtoBuf definition. 
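The hand-written conversions shown later in `from_proto.rs` and `to_proto.rs` reduce to a one-to-one mapping between this generated enum and the logical `InsertOp`. A self-contained sketch of that round trip, with stand-in enums rather than the real generated and `datafusion_expr::dml` types:

```rust
// Stand-in enums; in the real code these are the prost-generated
// protobuf::InsertOp and datafusion_expr::dml::InsertOp.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ProtoInsertOp {
    Append = 0,
    Overwrite = 1,
    Replace = 2,
}

#[derive(Debug, PartialEq)]
enum InsertOp {
    Append,
    Overwrite,
    Replace,
}

fn from_proto(op: ProtoInsertOp) -> InsertOp {
    match op {
        ProtoInsertOp::Append => InsertOp::Append,
        ProtoInsertOp::Overwrite => InsertOp::Overwrite,
        ProtoInsertOp::Replace => InsertOp::Replace,
    }
}

// The generated FileSinkConfig stores the enum as a raw i32 field.
fn to_proto(op: &InsertOp) -> i32 {
    match op {
        InsertOp::Append => ProtoInsertOp::Append as i32,
        InsertOp::Overwrite => ProtoInsertOp::Overwrite as i32,
        InsertOp::Replace => ProtoInsertOp::Replace as i32,
    }
}

fn main() {
    assert_eq!(from_proto(ProtoInsertOp::Overwrite), InsertOp::Overwrite);
    assert_eq!(to_proto(&InsertOp::Replace), 2);
}
```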
+ /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Append => "Append", + Self::Overwrite => "Overwrite", + Self::Replace => "Replace", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "Append" => Some(Self::Append), + "Overwrite" => Some(Self::Overwrite), + "Replace" => Some(Self::Replace), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum PartitionMode { CollectLeft = 0, Partitioned = 1, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index b2f92f4b2ee4..20ec5eeaeaf8 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::compute::SortOptions; use chrono::{TimeZone, Utc}; +use datafusion_expr::dml::InsertOp; use object_store::path::Path; use object_store::ObjectMeta; @@ -640,13 +641,18 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { Ok((name.clone(), data_type)) }) .collect::>>()?; + let insert_op = match conf.insert_op() { + protobuf::InsertOp::Append => InsertOp::Append, + protobuf::InsertOp::Overwrite => InsertOp::Overwrite, + protobuf::InsertOp::Replace => InsertOp::Replace, + }; Ok(Self { object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, file_groups, table_paths, output_schema: Arc::new(convert_required!(conf.output_schema)?), table_partition_cols, - overwrite: conf.overwrite, + insert_op, keep_partition_by_columns: conf.keep_partition_by_columns, }) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 6981c77228a8..6f6065a1c284 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -642,8 +642,8 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { table_paths, output_schema: Some(conf.output_schema.as_ref().try_into()?), table_partition_cols, - overwrite: conf.overwrite, keep_partition_by_columns: conf.keep_partition_by_columns, + insert_op: conf.insert_op as i32, }) } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 8a94f905812c..cd789e06dc3b 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1060,6 +1060,10 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { expr: exprs.swap_remove(0), }) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[derive(Debug)] diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index db84a08e5b40..025676f790a8 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -27,6 +27,7 @@ use arrow::csv::WriterBuilder; use arrow::datatypes::{Fields, TimeUnit}; use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_expr::dml::InsertOp; use 
datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; @@ -1143,7 +1144,7 @@ fn roundtrip_json_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, }; let data_sink = Arc::new(JsonSink::new( @@ -1179,7 +1180,7 @@ fn roundtrip_csv_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, }; let data_sink = Arc::new(CsvSink::new( @@ -1238,7 +1239,7 @@ fn roundtrip_parquet_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, }; let data_sink = Arc::new(ParquetSink::new( diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index ddafc4e3a03a..20a772cdd088 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -432,6 +432,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { qualifier: None, options: WildcardOptions::default(), }), + FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(object_name)) => { + let qualifier = self.object_name_to_table_reference(object_name)?; + // sanity check on qualifier with schema + let qualified_indices = schema.fields_indices_with_qualified(&qualifier); + if qualified_indices.is_empty() { + return plan_err!("Invalid qualifier {qualifier}"); + } + Ok(Expr::Wildcard { + qualifier: Some(qualifier), + options: WildcardOptions::default(), + }) + } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } } diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 2df8d89c59bc..6d130647a49f 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -181,7 +181,7 @@ pub(crate) type LexOrdering = Vec; #[derive(Debug, Clone, PartialEq, Eq)] pub struct CreateExternalTable { /// Table name - pub name: String, + pub name: ObjectName, /// Optional schema pub columns: Vec, /// File type (Parquet, NDJSON, CSV, etc) @@ -813,7 +813,7 @@ impl<'a> DFParser<'a> { } let create = CreateExternalTable { - name: table_name.to_string(), + name: table_name, columns, file_type: builder.file_type.unwrap(), location: builder.location.unwrap(), @@ -915,8 +915,9 @@ mod tests { // positive case let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'"; let display = None; + let name = ObjectName(vec![Ident::from("t")]); let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), location: "foo.csv".into(), @@ -932,7 +933,7 @@ mod tests { // positive case: leading space let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' "; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(None))], file_type: 
"CSV".to_string(), location: "foo.csv".into(), @@ -949,7 +950,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' ;"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(None))], file_type: "CSV".to_string(), location: "foo.csv".into(), @@ -966,7 +967,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS (format.delimiter '|')"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), location: "foo.csv".into(), @@ -986,7 +987,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1, p2) LOCATION 'foo.csv'"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), location: "foo.csv".into(), @@ -1013,7 +1014,7 @@ mod tests { ]; for (sql, compression) in sqls { let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), location: "foo.csv".into(), @@ -1033,7 +1034,7 @@ mod tests { // positive case: it is ok for parquet files not to have columns specified let sql = "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "PARQUET".to_string(), location: "foo.parquet".into(), @@ -1049,7 +1050,7 @@ mod tests { // positive case: it is ok for parquet files to be other than upper case let sql = "CREATE EXTERNAL TABLE t STORED AS parqueT LOCATION 'foo.parquet'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "PARQUET".to_string(), location: "foo.parquet".into(), @@ -1065,7 +1066,7 @@ mod tests { // positive case: it is ok for avro files not to have columns specified let sql = "CREATE EXTERNAL TABLE t STORED AS AVRO LOCATION 'foo.avro'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "AVRO".to_string(), location: "foo.avro".into(), @@ -1082,7 +1083,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE IF NOT EXISTS t STORED AS PARQUET LOCATION 'foo.parquet'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "PARQUET".to_string(), location: "foo.parquet".into(), @@ -1099,7 +1100,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int) LOCATION 'foo.csv'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(None)), make_column_def("p1", DataType::Int(None)), @@ -1132,7 +1133,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t STORED AS x OPTIONS ('k1' 'v1') LOCATION 'blahblah'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "X".to_string(), location: 
"blahblah".into(), @@ -1149,7 +1150,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t STORED AS x OPTIONS ('k1' 'v1', k2 v2) LOCATION 'blahblah'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "X".to_string(), location: "blahblah".into(), @@ -1188,7 +1189,7 @@ mod tests { ]; for (sql, (asc, nulls_first)) in sqls.iter().zip(expected.into_iter()) { let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(None))], file_type: "CSV".to_string(), location: "foo.csv".into(), @@ -1214,7 +1215,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV WITH ORDER (c1 ASC, c2 DESC NULLS FIRST) LOCATION 'foo.csv'"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(display)), make_column_def("c2", DataType::Int(display)), @@ -1253,7 +1254,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV WITH ORDER (c1 - c2 ASC) LOCATION 'foo.csv'"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(display)), make_column_def("c2", DataType::Int(display)), @@ -1297,7 +1298,7 @@ mod tests { 'TRUNCATE' 'NO', 'format.has_header' 'true')"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(None)), make_column_def("c2", DataType::Float(None)), diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 5cbe1d7c014a..e8defedddf2c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -197,9 +197,9 @@ impl PlannerContext { /// extends the FROM schema, returning the existing one, if any pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> { - self.outer_from_schema = match self.outer_from_schema.as_ref() { - Some(from_schema) => Some(Arc::new(from_schema.join(schema)?)), - None => Some(Arc::clone(schema)), + match self.outer_from_schema.as_mut() { + Some(from_schema) => Arc::make_mut(from_schema).merge(schema), + None => self.outer_from_schema = Some(Arc::clone(schema)), }; Ok(()) } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 29dfe25993f1..656d72d07ba2 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -37,7 +37,7 @@ use datafusion_common::{ DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, ToDFSchema, }; -use datafusion_expr::dml::CopyTo; +use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; use datafusion_expr::logical_plan::builder::project; use datafusion_expr::logical_plan::DdlStatement; @@ -53,7 +53,7 @@ use datafusion_expr::{ TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, }; -use sqlparser::ast; +use sqlparser::ast::{self, SqliteOnConflict}; use sqlparser::ast::{ Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, FromTable, Ident, Insert, @@ -665,12 +665,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { returning, 
ignore, table_alias, - replace_into, + mut replace_into, priority, insert_alias, }) => { - if or.is_some() { - plan_err!("Inserts with or clauses not supported")?; + if let Some(or) = or { + match or { + SqliteOnConflict::Replace => replace_into = true, + _ => plan_err!("Inserts with {or} clause is not supported")?, + } } if partitioned.is_some() { plan_err!("Partitioned inserts not yet supported")?; @@ -698,9 +701,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "Inserts with a table alias not supported: {table_alias:?}" )? }; - if replace_into { - plan_err!("Inserts with a `REPLACE INTO` clause not supported")? - }; if let Some(priority) = priority { plan_err!( "Inserts with a `PRIORITY` clause not supported: {priority:?}" @@ -710,7 +710,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Inserts with an alias not supported")?; } let _ = into; // optional keyword doesn't change behavior - self.insert_to_plan(table_name, columns, source, overwrite) + self.insert_to_plan(table_name, columns, source, overwrite, replace_into) } Statement::Update { table, @@ -1239,8 +1239,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let ordered_exprs = self.build_order_by(order_exprs, &df_schema, &mut planner_context)?; - // External tables do not support schemas at the moment, so the name is just a table name - let name = TableReference::bare(name); + let name = self.object_name_to_table_reference(name)?; let constraints = Constraints::new_from_table_constraints(&all_constraints, &df_schema)?; Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( @@ -1605,6 +1604,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { columns: Vec, source: Box, overwrite: bool, + replace_into: bool, ) -> Result { // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; @@ -1707,16 +1707,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>()?; let source = project(source, exprs)?; - let op = if overwrite { - WriteOp::InsertOverwrite - } else { - WriteOp::InsertInto + let insert_op = match (overwrite, replace_into) { + (false, false) => InsertOp::Append, + (true, false) => InsertOp::Overwrite, + (false, true) => InsertOp::Replace, + (true, true) => plan_err!("Conflicting insert operations: `overwrite` and `replace_into` cannot both be true")?, }; let plan = LogicalPlan::Dml(DmlStatement::new( table_name, Arc::new(table_schema), - op, + WriteOp::Insert(insert_op), Arc::new(source), )); Ok(plan) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 5c9655a55606..44b591fedef8 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1913,6 +1913,13 @@ fn create_external_table_with_pk() { quick_test(sql, expected); } +#[test] +fn create_external_table_wih_schema() { + let sql = "CREATE EXTERNAL TABLE staging.foo STORED AS CSV LOCATION 'foo.csv'"; + let expected = "CreateExternalTable: Partial { schema: \"staging\", table: \"foo\" }"; + quick_test(sql, expected); +} + #[test] fn create_schema_with_quoted_name() { let sql = "CREATE SCHEMA \"quoted_schema_name\""; diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 46327534e7de..a78ade81eeba 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1124,6 +1124,14 @@ SELECT COUNT(*) FROM aggregate_test_100 ---- 100 +query I +SELECT COUNT(aggregate_test_100.*) FROM 
aggregate_test_100 +---- +100 + +query error Error during planning: Invalid qualifier foo +SELECT COUNT(foo.*) FROM aggregate_test_100 + # csv_query_count_literal query I SELECT COUNT(2) FROM aggregate_test_100 diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index 12b097c3d5d1..9ac2ecdce7cc 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -275,3 +275,15 @@ DROP TABLE t; # query should fail with bad column statement error DataFusion error: Error during planning: Column foo is not in schema CREATE EXTERNAL TABLE t STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet' WITH ORDER (foo); + +# Create external table with qualified name should belong to the schema +statement ok +CREATE SCHEMA staging; + +statement ok +CREATE EXTERNAL TABLE staging.foo STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; + +# Create external table with qualified name, but no schema should error +statement error DataFusion error: Error during planning: failed to resolve schema: release +CREATE EXTERNAL TABLE release.bar STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; + diff --git a/datafusion/sqllogictest/test_files/dynamic_file.slt b/datafusion/sqllogictest/test_files/dynamic_file.slt index e177fd3de243..69f9a43ad407 100644 --- a/datafusion/sqllogictest/test_files/dynamic_file.slt +++ b/datafusion/sqllogictest/test_files/dynamic_file.slt @@ -25,9 +25,170 @@ SELECT * FROM '../core/tests/data/partitioned_table_arrow/part=123' ORDER BY f0; 1 foo true 2 bar false -# dynamic file query doesn't support partitioned table -statement error DataFusion error: Error during planning: table 'datafusion.public.../core/tests/data/partitioned_table_arrow' not found -SELECT * FROM '../core/tests/data/partitioned_table_arrow' ORDER BY f0; +# Read partitioned file +statement ok +CREATE TABLE src_table_1 ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + partition_col INT +) AS VALUES +(1, 'aaa', 100, 1), +(2, 'bbb', 200, 1), +(3, 'ccc', 300, 1), +(4, 'ddd', 400, 1); + +statement ok +CREATE TABLE src_table_2 ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + partition_col INT +) AS VALUES +(5, 'eee', 500, 2), +(6, 'fff', 600, 2), +(7, 'ggg', 700, 2), +(8, 'hhh', 800, 2); + +# Read partitioned csv file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/csv_partitions' +STORED AS CSV +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/csv_partitions' +STORED AS CSV +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +SELECT int_col, string_col, bigint_col, partition_col FROM 'test_files/scratch/dynamic_file/csv_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned json file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/json_partitions' +STORED AS JSON +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/json_partitions' +STORED AS JSON +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +SELECT int_col, string_col, bigint_col, partition_col FROM 'test_files/scratch/dynamic_file/json_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + 
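The cases that follow repeat the same round trip for Arrow and Parquet, including multiple partition columns. For reference, a plain-Rust sketch (illustrative names, no DataFusion APIs) of the hive-style layout these `COPY ... PARTITIONED BY` statements write and the path-based reads consume:

```rust
// Hive-style partition path construction; column order follows the
// PARTITIONED BY list, and partition values live in the path, not the file.
fn partition_path(base: &str, cols: &[(&str, &str)], file: &str) -> String {
    let mut path = String::from(base);
    for (name, value) in cols {
        path.push_str(&format!("/{name}={value}"));
    }
    format!("{path}/{file}")
}

fn main() {
    // Mirrors the multi-column case further down:
    // PARTITIONED BY (partition_col, string_col).
    let p = partition_path(
        "test_files/scratch/dynamic_file/nested_partition",
        &[("partition_col", "1"), ("string_col", "aaa")],
        "data.parquet", // illustrative file name
    );
    assert_eq!(
        p,
        "test_files/scratch/dynamic_file/nested_partition/partition_col=1/string_col=aaa/data.parquet"
    );
    // Because partition values are stored only in the path, SELECT * appends
    // them after the file's own columns, as the nested_partition results show.
}
```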
+# Read partitioned arrow file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/arrow_partitions' +STORED AS ARROW +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/arrow_partitions' +STORED AS ARROW +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +SELECT int_col, string_col, bigint_col, partition_col FROM 'test_files/scratch/dynamic_file/arrow_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned parquet file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/parquet_partitions' +STORED AS PARQUET +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/parquet_partitions' +STORED AS PARQUET +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +select * from 'test_files/scratch/dynamic_file/parquet_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned parquet file with multiple partition columns + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/nested_partition' +STORED AS PARQUET +PARTITIONED BY (partition_col, string_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/nested_partition' +STORED AS PARQUET +PARTITIONED BY (partition_col, string_col); +---- +4 + +query IITT rowsort +select * from 'test_files/scratch/dynamic_file/nested_partition'; +---- +1 100 1 aaa +2 200 1 bbb +3 300 1 ccc +4 400 1 ddd +5 500 2 eee +6 600 2 fff +7 700 2 ggg +8 800 2 hhh # read avro file query IT diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index e887b1934e04..7d41c26ba012 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -18,46 +18,6 @@ # unicode expressions -query I -SELECT char_length('') ----- -0 - -query I -SELECT char_length('chars') ----- -5 - -query I -SELECT char_length('josé') ----- -4 - -query I -SELECT char_length(NULL) ----- -NULL - -query I -SELECT character_length('') ----- -0 - -query I -SELECT character_length('chars') ----- -5 - -query I -SELECT character_length('josé') ----- -4 - -query I -SELECT character_length(NULL) ----- -NULL - query T SELECT left('abcde', -2) ---- @@ -133,152 +93,6 @@ SELECT length(NULL) ---- NULL -query T -SELECT lpad('hi', -1, 'xy') ----- -(empty) - -query T -SELECT lpad('hi', 5, 'xy') ----- -xyxhi - -query T -SELECT lpad('hi', -1) ----- -(empty) - -query T -SELECT lpad('hi', 0) ----- -(empty) - -query T -SELECT lpad('hi', 21, 'abcdef') ----- -abcdefabcdefabcdefahi - -query T -SELECT lpad('hi', 5, 'xy') ----- -xyxhi - -query T -SELECT lpad('hi', 5, NULL) ----- -NULL - -query T -SELECT lpad('hi', 5) ----- - hi - -query T -SELECT lpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5) ----- - hi - -query T -SELECT lpad('hi', CAST(NULL AS INT), 'xy') ----- -NULL - -query T -SELECT lpad('hi', CAST(NULL AS INT)) ----- -NULL - -query T -SELECT lpad('xyxhi', 3) ----- -xyx - -query T -SELECT lpad(NULL, 0) ----- -NULL - -query T -SELECT lpad(NULL, 5, 'xy') ----- -NULL - -# test largeutf8, utf8view for lpad -query T -SELECT lpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') ----- -xyxhi - -query T -SELECT lpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') ----- -xyxhi - -query T -SELECT lpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) ----- -xyxhi - 
-query T -SELECT lpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) ----- -xyxhi - -query T -SELECT lpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') ----- -NULL - -query T -SELECT reverse('abcde') ----- -edcba - -query T -SELECT reverse(arrow_cast('abcde', 'LargeUtf8')) ----- -edcba - -query T -SELECT reverse(arrow_cast('abcde', 'Utf8View')) ----- -edcba - -query T -SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)')) ----- -edcba - -query T -SELECT reverse('loẅks') ----- -sk̈wol - -query T -SELECT reverse(arrow_cast('loẅks', 'LargeUtf8')) ----- -sk̈wol - -query T -SELECT reverse(arrow_cast('loẅks', 'Utf8View')) ----- -sk̈wol - -query T -SELECT reverse(NULL) ----- -NULL - -query T -SELECT reverse(arrow_cast(NULL, 'LargeUtf8')) ----- -NULL - -query T -SELECT reverse(arrow_cast(NULL, 'Utf8View')) ----- -NULL - query T SELECT right('abcde', -2) ---- @@ -324,124 +138,6 @@ SELECT right(NULL, CAST(NULL AS INT)) ---- NULL - -query T -SELECT rpad('hi', -1, 'xy') ----- -(empty) - -query T -SELECT rpad('hi', 5, 'xy') ----- -hixyx - -query T -SELECT rpad('hi', -1) ----- -(empty) - -query T -SELECT rpad('hi', 0) ----- -(empty) - -query T -SELECT rpad('hi', 21, 'abcdef') ----- -hiabcdefabcdefabcdefa - -query T -SELECT rpad('hi', 5, 'xy') ----- -hixyx - -query T -SELECT rpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5, 'xy') ----- -hixyx - -query T -SELECT rpad('hi', 5, NULL) ----- -NULL - -query T -SELECT rpad('hi', 5) ----- -hi - -query T -SELECT rpad('hi', CAST(NULL AS INT), 'xy') ----- -NULL - -query T -SELECT rpad('hi', CAST(NULL AS INT)) ----- -NULL - -query T -SELECT rpad('xyxhi', 3) ----- -xyx - -# test for rpad with largeutf8 and utf8View - -query T -SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') ----- -hixyx - -query T -SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') ----- -hixyx - -query T -SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) ----- -hixyx - -query T -SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) ----- -hixyx - -query T -SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') ----- -NULL - -query I -SELECT strpos('abc', 'c') ----- -3 - -query I -SELECT strpos('josé', 'é') ----- -4 - -query I -SELECT strpos('joséésoj', 'so') ----- -6 - -query I -SELECT strpos('joséésoj', 'abc') ----- -0 - -query I -SELECT strpos(NULL, 'abc') ----- -NULL - -query I -SELECT strpos('joséésoj', NULL) ----- -NULL - query T SELECT substr('alphabet', -3) ---- @@ -796,45 +492,6 @@ SELECT md5(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) ---- acbd18db4cc2f85cedef654fccc4a4d8 -query T -SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') ----- -fooxx - -query T -SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') ----- -fooxx - -query T -SELECT repeat('foo', 3) ----- -foofoofoo - -query T -SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3) ----- -foofoofoo - -query T -SELECT replace('foobar', 'bar', 'hello') ----- -foohello - -query T -SELECT replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'hello') ----- -foohello - -query T -SELECT replace(arrow_cast('foobar', 'Utf8View'), arrow_cast('bar', 'Utf8View'), arrow_cast('hello', 'Utf8View')) ----- -foohello - -query T -SELECT replace(arrow_cast('foobar', 'LargeUtf8'), arrow_cast('bar', 'LargeUtf8'), arrow_cast('hello', 'LargeUtf8')) ----- -foohello query T SELECT rtrim(' foo ') @@ -846,68 +503,6 @@ SELECT rtrim(arrow_cast(' foo ', 'Dictionary(Int32, Utf8)')) ---- foo -query T -SELECT split_part('foo_bar', '_', 2) 
----- -bar - -query T -SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) ----- -bar - -# test largeutf8, utf8view for split_part -query T -SELECT split_part(arrow_cast('large_apple_large_orange_large_banana', 'LargeUtf8'), '_', 3) ----- -large - -query T -SELECT split_part(arrow_cast('view_apple_view_orange_view_banana', 'Utf8View'), '_', 3); ----- -view - -query T -SELECT split_part('test_large_split_large_case', arrow_cast('_large', 'LargeUtf8'), 2) ----- -_split - -query T -SELECT split_part(arrow_cast('huge_large_apple_large_orange_large_banana', 'LargeUtf8'), arrow_cast('_', 'Utf8View'), 2) ----- -large - -query T -SELECT split_part(arrow_cast('view_apple_view_large_banana', 'Utf8View'), arrow_cast('_large', 'LargeUtf8'), 2) ----- -_banana - -query T -SELECT split_part(NULL, '_', 2) ----- -NULL - - -query B -SELECT starts_with('foobar', 'foo') ----- -true - -query B -SELECT starts_with('foobar', 'bar') ----- -false - -query B -SELECT ends_with('foobar', 'bar') ----- -true - -query B -SELECT ends_with('foobar', 'foo') ----- -false - query T SELECT trim(' foo ') ---- @@ -1064,279 +659,6 @@ NULL Thomxas NULL -query I -SELECT levenshtein('kitten', 'sitting') ----- -3 - -query I -SELECT levenshtein('kitten', NULL) ----- -NULL - -query I -SELECT levenshtein(NULL, 'sitting') ----- -NULL - -query I -SELECT levenshtein(NULL, NULL) ----- -NULL - -# Test substring_index using '.' as delimiter -# This query is compatible with MySQL(8.0.19 or later), convenient for comparing results -query TIT -SELECT str, n, substring_index(str, '.', n) AS c FROM - (VALUES - ROW('arrow.apache.org'), - ROW('.'), - ROW('...'), - ROW(NULL) - ) AS strings(str), - (VALUES - ROW(1), - ROW(2), - ROW(3), - ROW(100), - ROW(-1), - ROW(-2), - ROW(-3), - ROW(-100) - ) AS occurrences(n) -ORDER BY str DESC, n; ----- -NULL -100 NULL -NULL -3 NULL -NULL -2 NULL -NULL -1 NULL -NULL 1 NULL -NULL 2 NULL -NULL 3 NULL -NULL 100 NULL -arrow.apache.org -100 arrow.apache.org -arrow.apache.org -3 arrow.apache.org -arrow.apache.org -2 apache.org -arrow.apache.org -1 org -arrow.apache.org 1 arrow -arrow.apache.org 2 arrow.apache -arrow.apache.org 3 arrow.apache.org -arrow.apache.org 100 arrow.apache.org -... -100 ... -... -3 .. -... -2 . -... -1 (empty) -... 1 (empty) -... 2 . -... 3 .. -... 100 ... -. -100 . -. -3 . -. -2 . -. -1 (empty) -. 1 (empty) -. 2 . -. 3 . -. 100 . - -query I -SELECT levenshtein(NULL, NULL) ----- -NULL - -# Test substring_index using '.' as delimiter with utf8view -query TIT -SELECT str, n, substring_index(arrow_cast(str, 'Utf8View'), '.', n) AS c FROM - (VALUES - ROW('arrow.apache.org'), - ROW('.'), - ROW('...'), - ROW(NULL) - ) AS strings(str), - (VALUES - ROW(1), - ROW(2), - ROW(3), - ROW(100), - ROW(-1), - ROW(-2), - ROW(-3), - ROW(-100) - ) AS occurrences(n) -ORDER BY str DESC, n; ----- -NULL -100 NULL -NULL -3 NULL -NULL -2 NULL -NULL -1 NULL -NULL 1 NULL -NULL 2 NULL -NULL 3 NULL -NULL 100 NULL -arrow.apache.org -100 arrow.apache.org -arrow.apache.org -3 arrow.apache.org -arrow.apache.org -2 apache.org -arrow.apache.org -1 org -arrow.apache.org 1 arrow -arrow.apache.org 2 arrow.apache -arrow.apache.org 3 arrow.apache.org -arrow.apache.org 100 arrow.apache.org -... -100 ... -... -3 .. -... -2 . -... -1 (empty) -... 1 (empty) -... 2 . -... 3 .. -... 100 ... -. -100 . -. -3 . -. -2 . -. -1 (empty) -. 1 (empty) -. 2 . -. 3 . -. 100 . 
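As a reference for the cases above, the MySQL-compatible `substring_index` semantics fit in a few lines; this sketch is illustrative, not DataFusion's implementation:

```rust
// Positive n keeps everything before the n-th delimiter occurrence;
// negative n keeps everything after the n-th occurrence from the right.
fn substring_index(s: &str, delim: &str, n: i64) -> String {
    if n == 0 || delim.is_empty() {
        return String::new();
    }
    let pos: Vec<usize> = s.match_indices(delim).map(|(i, _)| i).collect();
    if n > 0 {
        match pos.get(n as usize - 1) {
            Some(&i) => s[..i].to_string(),
            None => s.to_string(), // fewer than n occurrences: whole string
        }
    } else {
        let k = n.unsigned_abs() as usize; // also handles i64::MIN safely
        if k > pos.len() {
            return s.to_string();
        }
        let i = pos[pos.len() - k];
        s[i + delim.len()..].to_string()
    }
}

fn main() {
    assert_eq!(substring_index("arrow.apache.org", ".", 1), "arrow");
    assert_eq!(substring_index("arrow.apache.org", ".", -2), "apache.org");
    assert_eq!(substring_index("...", ".", 2), ".");
    assert_eq!(substring_index("arrow.apache.org", ".", 100), "arrow.apache.org");
}
```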
- -# Test substring_index using 'ac' as delimiter -query TIT -SELECT str, n, substring_index(str, 'ac', n) AS c FROM - (VALUES - -- input string does not contain the delimiter - ROW('arrow'), - -- input string contains the delimiter - ROW('arrow.apache.org') - ) AS strings(str), - (VALUES - ROW(1), - ROW(2), - ROW(-1), - ROW(-2) - ) AS occurrences(n) -ORDER BY str DESC, n; ----- -arrow.apache.org -2 arrow.apache.org -arrow.apache.org -1 he.org -arrow.apache.org 1 arrow.ap -arrow.apache.org 2 arrow.apache.org -arrow -2 arrow -arrow -1 arrow -arrow 1 arrow -arrow 2 arrow - -# Test substring_index with NULL values -query TTTT -SELECT - substring_index(NULL, '.', 1), - substring_index('arrow.apache.org', NULL, 1), - substring_index('arrow.apache.org', '.', NULL), - substring_index(NULL, NULL, NULL) ----- -NULL NULL NULL NULL - -# Test substring_index with empty strings -query TT -SELECT - -- input string is empty - substring_index('', '.', 1), - -- delimiter is empty - substring_index('arrow.apache.org', '', 1) ----- -(empty) (empty) - -# Test substring_index with 0 occurrence -query T -SELECT substring_index('arrow.apache.org', 'ac', 0) ----- -(empty) - -# Test substring_index with large occurrences -query TT -SELECT - -- i64::MIN - substring_index('arrow.apache.org', '.', -9223372036854775808) as c1, - -- i64::MAX - substring_index('arrow.apache.org', '.', 9223372036854775807) as c2; ----- -arrow.apache.org arrow.apache.org - -# Test substring_index issue https://github.com/apache/datafusion/issues/9472 -query TTT -SELECT - url, - substring_index(url, '.', 1) AS subdomain, - substring_index(url, '.', -1) AS tld -FROM - (VALUES ROW('docs.apache.com'), - ROW('community.influxdata.com'), - ROW('arrow.apache.org') - ) data(url) ----- -docs.apache.com docs com -community.influxdata.com community com -arrow.apache.org arrow org - -# find_in_set tests -query I -SELECT find_in_set('b', 'a,b,c,d') ----- -2 - - -query I -SELECT find_in_set('a', 'a,b,c,d,a') ----- -1 - -query I -SELECT find_in_set('', 'a,b,c,d,a') ----- -0 - -query I -SELECT find_in_set('a', '') ----- -0 - - -query I -SELECT find_in_set('', '') ----- -1 - -query I -SELECT find_in_set(NULL, 'a,b,c,d') ----- -NULL - -query I -SELECT find_in_set('a', NULL) ----- -NULL - - -query I -SELECT find_in_set(NULL, NULL) ----- -NULL - -# find_in_set tests with utf8view -query I -SELECT find_in_set(arrow_cast('b', 'Utf8View'), 'a,b,c,d') ----- -2 - - -query I -SELECT find_in_set('a', arrow_cast('a,b,c,d,a', 'Utf8View')) ----- -1 - -query I -SELECT find_in_set(arrow_cast('', 'Utf8View'), arrow_cast('a,b,c,d,a', 'Utf8View')) ----- -0 - # Verify that multiple calls to volatile functions like `random()` are not combined / optimized away query B SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0) diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 8d801b92c393..519fbb887c7e 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -1215,14 +1215,14 @@ statement ok create table t1(v1 int) as values(100); ## Query with Ambiguous column reference -query error DataFusion error: Schema error: Ambiguous reference to unqualified field v1 +query error DataFusion error: Schema error: Schema contains duplicate qualified field name t1\.v1 select count(*) from t1 right outer join t1 on t1.v1 > 0; -query error DataFusion error: Schema error: Ambiguous reference to unqualified field v1 +query error 
DataFusion error: Schema error: Schema contains duplicate qualified field name t1\.v1 select t1.v1 from t1 join t1 using(v1) cross join (select struct('foo' as v1) as t1); statement ok -drop table t1; \ No newline at end of file +drop table t1; diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index 3b2b219244f5..f38281abc5ab 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -58,5 +58,43 @@ WHERE "data"."id" = "samples"."id"; 1 3 + + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +query I +select count(distinct name) from table_with_metadata; +---- +2 + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +query I +select approx_median(distinct id) from table_with_metadata; +---- +2 + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +statement ok +select array_agg(distinct id) from table_with_metadata; + +query I +select distinct id from table_with_metadata order by id; +---- +1 +3 +NULL + +query I +select count(id) from table_with_metadata; +---- +2 + +query I +select count(id) cnt from table_with_metadata group by name order by cnt; +---- +0 +1 +1 + + statement ok drop table table_with_metadata; diff --git a/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt b/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt index 9d2460816709..c181f613ee9a 100644 --- a/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt +++ b/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt @@ -37,6 +37,22 @@ select arrow_cast(col1, 'Dictionary(Int32, Utf8)') as c1 from test_substr_base; statement ok drop table test_source +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An%' as ascii_like, + unicode_1 like '%ion数据%' as unicode_like, + ascii_1 ilike 'An%' as ascii_ilike, + unicode_1 ilike '%ion数据%' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false true false +Xiangpeng datafusion数据融合 false true false true +Raphael datafusionДатаФусион false false false false +NULL NULL NULL NULL NULL NULL + # # common test for string-like functions and operators # diff --git a/datafusion/sqllogictest/test_files/string/large_string.slt b/datafusion/sqllogictest/test_files/string/large_string.slt index a2e570073ff6..169c658e5ac1 100644 --- a/datafusion/sqllogictest/test_files/string/large_string.slt +++ b/datafusion/sqllogictest/test_files/string/large_string.slt @@ -43,6 +43,22 @@ Xiangpeng Xiangpeng datafusion数据融合 datafusion数据融合 Raphael R datafusionДатаФусион аФус NULL R NULL 🔥 +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An%' as ascii_like, + unicode_1 like '%ion数据%' as unicode_like, + ascii_1 ilike 'An%' as ascii_ilike, + unicode_1 ilike '%ion数据%' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false true false +Xiangpeng datafusion数据融合 false true false true +Raphael datafusionДатаФусион false false false false +NULL NULL NULL NULL NULL NULL + # TODO: move it back to 
`string_query.slt.part` after fixing the issue # https://github.com/apache/datafusion/issues/12618 query BB @@ -56,6 +72,23 @@ false false false true NULL NULL +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12670 +query IIIIII +SELECT + STRPOS(ascii_1, 'e'), + STRPOS(ascii_1, 'ang'), + STRPOS(ascii_1, NULL), + STRPOS(unicode_1, 'и'), + STRPOS(unicode_1, 'ион'), + STRPOS(unicode_1, NULL) +FROM test_basic_operator; +---- +5 0 NULL 0 0 NULL +7 3 NULL 0 0 NULL +6 0 NULL 18 18 NULL +NULL NULL NULL NULL NULL NULL + # # common test for string-like functions and operators # diff --git a/datafusion/sqllogictest/test_files/string/string.slt b/datafusion/sqllogictest/test_files/string/string.slt index bc923d5e12c3..f4e83966f78f 100644 --- a/datafusion/sqllogictest/test_files/string/string.slt +++ b/datafusion/sqllogictest/test_files/string/string.slt @@ -47,6 +47,39 @@ false false false true NULL NULL +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An%' as ascii_like, + unicode_1 like '%ion数据%' as unicode_like, + ascii_1 ilike 'An%' as ascii_ilike, + unicode_1 ilike '%ion数据%' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false true false +Xiangpeng datafusion数据融合 false true false true +Raphael datafusionДатаФусион false false false false +NULL NULL NULL NULL NULL NULL + +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12670 +query IIIIII +SELECT + STRPOS(ascii_1, 'e'), + STRPOS(ascii_1, 'ang'), + STRPOS(ascii_1, NULL), + STRPOS(unicode_1, 'и'), + STRPOS(unicode_1, 'ион'), + STRPOS(unicode_1, NULL) +FROM test_basic_operator; +---- +5 0 NULL 0 0 NULL +7 3 NULL 0 0 NULL +6 0 NULL 18 18 NULL +NULL NULL NULL NULL NULL NULL + # # common test for string-like functions and operators # diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index 24e03fdb7184..5d847747693d 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -167,3 +167,652 @@ query D select make_date(arrow_cast('2024', 'Utf8View'), arrow_cast('01', 'Utf8View'), arrow_cast('23', 'Utf8View')) ---- 2024-01-23 + +query I +SELECT character_length('') +---- +0 + +query I +SELECT character_length('chars') +---- +5 + +query I +SELECT character_length('josé') +---- +4 + +query I +SELECT character_length(NULL) +---- +NULL + +query B +SELECT ends_with('foobar', 'bar') +---- +true + +query B +SELECT ends_with('foobar', 'foo') +---- +false + +query I +SELECT levenshtein('kitten', 'sitting') +---- +3 + +query I +SELECT levenshtein('kitten', NULL) +---- +NULL + +query I +SELECT levenshtein(NULL, 'sitting') +---- +NULL + +query I +SELECT levenshtein(NULL, NULL) +---- +NULL + + +query T +SELECT lpad('hi', -1, 'xy') +---- +(empty) + +query T +SELECT lpad('hi', 5, 'xy') +---- +xyxhi + +query T +SELECT lpad('hi', -1) +---- +(empty) + +query T +SELECT lpad('hi', 0) +---- +(empty) + +query T +SELECT lpad('hi', 21, 'abcdef') +---- +abcdefabcdefabcdefahi + +query T +SELECT lpad('hi', 5, 'xy') +---- +xyxhi + +query T +SELECT lpad('hi', 5, NULL) +---- +NULL + +query T +SELECT lpad('hi', 5) +---- + hi + +query T +SELECT 
lpad('hi', CAST(NULL AS INT), 'xy') +---- +NULL + +query T +SELECT lpad('hi', CAST(NULL AS INT)) +---- +NULL + +query T +SELECT lpad('xyxhi', 3) +---- +xyx + +query T +SELECT lpad(NULL, 0) +---- +NULL + +query T +SELECT lpad(NULL, 5, 'xy') +---- +NULL + +query T +SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT repeat('foo', 3) +---- +foofoofoo + +query T +SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3) +---- +foofoofoo + + +query T +SELECT replace('foobar', 'bar', 'hello') +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'hello') +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'Utf8View'), arrow_cast('bar', 'Utf8View'), arrow_cast('hello', 'Utf8View')) +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'LargeUtf8'), arrow_cast('bar', 'LargeUtf8'), arrow_cast('hello', 'LargeUtf8')) +---- +foohello + + +query T +SELECT reverse('abcde') +---- +edcba + +query T +SELECT reverse(arrow_cast('abcde', 'LargeUtf8')) +---- +edcba + +query T +SELECT reverse(arrow_cast('abcde', 'Utf8View')) +---- +edcba + +query T +SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)')) +---- +edcba + +query T +SELECT reverse('loẅks') +---- +sk̈wol + +query T +SELECT reverse(arrow_cast('loẅks', 'LargeUtf8')) +---- +sk̈wol + +query T +SELECT reverse(arrow_cast('loẅks', 'Utf8View')) +---- +sk̈wol + +query T +SELECT reverse(NULL) +---- +NULL + +query T +SELECT reverse(arrow_cast(NULL, 'LargeUtf8')) +---- +NULL + +query T +SELECT reverse(arrow_cast(NULL, 'Utf8View')) +---- +NULL + + +query I +SELECT strpos('abc', 'c') +---- +3 + +query I +SELECT strpos('josé', 'é') +---- +4 + +query I +SELECT strpos('joséésoj', 'so') +---- +6 + +query I +SELECT strpos('joséésoj', 'abc') +---- +0 + +query I +SELECT strpos(NULL, 'abc') +---- +NULL + +query I +SELECT strpos('joséésoj', NULL) +---- +NULL + + + +query T +SELECT rpad('hi', -1, 'xy') +---- +(empty) + +query T +SELECT rpad('hi', 5, 'xy') +---- +hixyx + +query T +SELECT rpad('hi', -1) +---- +(empty) + +query T +SELECT rpad('hi', 0) +---- +(empty) + +query T +SELECT rpad('hi', 21, 'abcdef') +---- +hiabcdefabcdefabcdefa + +query T +SELECT rpad('hi', 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad('hi', 5, NULL) +---- +NULL + +query T +SELECT rpad('hi', 5) +---- +hi + +query T +SELECT rpad('hi', CAST(NULL AS INT), 'xy') +---- +NULL + +query T +SELECT rpad('hi', CAST(NULL AS INT)) +---- +NULL + +query T +SELECT rpad('xyxhi', 3) +---- +xyx + +# test for rpad with largeutf8 and utf8View + +query T +SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) +---- +hixyx + +query T +SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') +---- +NULL + +query I +SELECT char_length('') +---- +0 + +query I +SELECT char_length('chars') +---- +5 + +query I +SELECT char_length('josé') +---- +4 + +query I +SELECT char_length(NULL) +---- +NULL + +# Test substring_index using '.' 
as delimiter
+# This query is compatible with MySQL (8.0.19 or later), which makes it convenient to compare results
+query TIT
+SELECT str, n, substring_index(str, '.', n) AS c FROM
+  (VALUES
+    ROW('arrow.apache.org'),
+    ROW('.'),
+    ROW('...'),
+    ROW(NULL)
+  ) AS strings(str),
+  (VALUES
+    ROW(1),
+    ROW(2),
+    ROW(3),
+    ROW(100),
+    ROW(-1),
+    ROW(-2),
+    ROW(-3),
+    ROW(-100)
+  ) AS occurrences(n)
+ORDER BY str DESC, n;
+----
+NULL -100 NULL
+NULL -3 NULL
+NULL -2 NULL
+NULL -1 NULL
+NULL 1 NULL
+NULL 2 NULL
+NULL 3 NULL
+NULL 100 NULL
+arrow.apache.org -100 arrow.apache.org
+arrow.apache.org -3 arrow.apache.org
+arrow.apache.org -2 apache.org
+arrow.apache.org -1 org
+arrow.apache.org 1 arrow
+arrow.apache.org 2 arrow.apache
+arrow.apache.org 3 arrow.apache.org
+arrow.apache.org 100 arrow.apache.org
+... -100 ...
+... -3 ..
+... -2 .
+... -1 (empty)
+... 1 (empty)
+... 2 .
+... 3 ..
+... 100 ...
+. -100 .
+. -3 .
+. -2 .
+. -1 (empty)
+. 1 (empty)
+. 2 .
+. 3 .
+. 100 .
+
+# Test substring_index using '.' as delimiter with utf8view
+query TIT
+SELECT str, n, substring_index(arrow_cast(str, 'Utf8View'), '.', n) AS c FROM
+  (VALUES
+    ROW('arrow.apache.org'),
+    ROW('.'),
+    ROW('...'),
+    ROW(NULL)
+  ) AS strings(str),
+  (VALUES
+    ROW(1),
+    ROW(2),
+    ROW(3),
+    ROW(100),
+    ROW(-1),
+    ROW(-2),
+    ROW(-3),
+    ROW(-100)
+  ) AS occurrences(n)
+ORDER BY str DESC, n;
+----
+NULL -100 NULL
+NULL -3 NULL
+NULL -2 NULL
+NULL -1 NULL
+NULL 1 NULL
+NULL 2 NULL
+NULL 3 NULL
+NULL 100 NULL
+arrow.apache.org -100 arrow.apache.org
+arrow.apache.org -3 arrow.apache.org
+arrow.apache.org -2 apache.org
+arrow.apache.org -1 org
+arrow.apache.org 1 arrow
+arrow.apache.org 2 arrow.apache
+arrow.apache.org 3 arrow.apache.org
+arrow.apache.org 100 arrow.apache.org
+... -100 ...
+... -3 ..
+... -2 .
+... -1 (empty)
+... 1 (empty)
+... 2 .
+... 3 ..
+... 100 ...
+. -100 .
+. -3 .
+. -2 .
+. -1 (empty)
+. 1 (empty)
+. 2 .
+. 3 .
+. 100 .
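+
+# Editor's illustrative sketch (not part of the upstream change): with a
+# negative n, substring_index counts occurrences of the delimiter from the
+# right and returns everything after that occurrence, mirroring MySQL's
+# SUBSTRING_INDEX semantics shown above, so -2 keeps the last two
+# dot-separated labels.
+query T
+SELECT substring_index('a.b.c.d', '.', -2)
+----
+c.d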
+ +# Test substring_index using 'ac' as delimiter +query TIT +SELECT str, n, substring_index(str, 'ac', n) AS c FROM + (VALUES + -- input string does not contain the delimiter + ROW('arrow'), + -- input string contains the delimiter + ROW('arrow.apache.org') + ) AS strings(str), + (VALUES + ROW(1), + ROW(2), + ROW(-1), + ROW(-2) + ) AS occurrences(n) +ORDER BY str DESC, n; +---- +arrow.apache.org -2 arrow.apache.org +arrow.apache.org -1 he.org +arrow.apache.org 1 arrow.ap +arrow.apache.org 2 arrow.apache.org +arrow -2 arrow +arrow -1 arrow +arrow 1 arrow +arrow 2 arrow + +# Test substring_index with NULL values +query TTTT +SELECT + substring_index(NULL, '.', 1), + substring_index('arrow.apache.org', NULL, 1), + substring_index('arrow.apache.org', '.', NULL), + substring_index(NULL, NULL, NULL) +---- +NULL NULL NULL NULL + +# Test substring_index with empty strings +query TT +SELECT + -- input string is empty + substring_index('', '.', 1), + -- delimiter is empty + substring_index('arrow.apache.org', '', 1) +---- +(empty) (empty) + +# Test substring_index with 0 occurrence +query T +SELECT substring_index('arrow.apache.org', 'ac', 0) +---- +(empty) + +# Test substring_index with large occurrences +query TT +SELECT + -- i64::MIN + substring_index('arrow.apache.org', '.', -9223372036854775808) as c1, + -- i64::MAX + substring_index('arrow.apache.org', '.', 9223372036854775807) as c2; +---- +arrow.apache.org arrow.apache.org + +# Test substring_index issue https://github.com/apache/datafusion/issues/9472 +query TTT +SELECT + url, + substring_index(url, '.', 1) AS subdomain, + substring_index(url, '.', -1) AS tld +FROM + (VALUES ROW('docs.apache.com'), + ROW('community.influxdata.com'), + ROW('arrow.apache.org') + ) data(url) +---- +docs.apache.com docs com +community.influxdata.com community com +arrow.apache.org arrow org + + +# find_in_set tests +query I +SELECT find_in_set('b', 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', 'a,b,c,d,a') +---- +1 + +query I +SELECT find_in_set('', 'a,b,c,d,a') +---- +0 + +query I +SELECT find_in_set('a', '') +---- +0 + + +query I +SELECT find_in_set('', '') +---- +1 + +query I +SELECT find_in_set(NULL, 'a,b,c,d') +---- +NULL + +query I +SELECT find_in_set('a', NULL) +---- +NULL + + +query I +SELECT find_in_set(NULL, NULL) +---- +NULL + +# find_in_set tests with utf8view +query I +SELECT find_in_set(arrow_cast('b', 'Utf8View'), 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', arrow_cast('a,b,c,d,a', 'Utf8View')) +---- +1 + +query I +SELECT find_in_set(arrow_cast('', 'Utf8View'), arrow_cast('a,b,c,d,a', 'Utf8View')) +---- +0 + + +query T +SELECT split_part('foo_bar', '_', 2) +---- +bar + +query T +SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) +---- +bar + +# test largeutf8, utf8view for split_part +query T +SELECT split_part(arrow_cast('large_apple_large_orange_large_banana', 'LargeUtf8'), '_', 3) +---- +large + +query T +SELECT split_part(arrow_cast('view_apple_view_orange_view_banana', 'Utf8View'), '_', 3); +---- +view + +query T +SELECT split_part('test_large_split_large_case', arrow_cast('_large', 'LargeUtf8'), 2) +---- +_split + +query T +SELECT split_part(arrow_cast('huge_large_apple_large_orange_large_banana', 'LargeUtf8'), arrow_cast('_', 'Utf8View'), 2) +---- +large + +query T +SELECT split_part(arrow_cast('view_apple_view_large_banana', 'Utf8View'), arrow_cast('_large', 'LargeUtf8'), 2) +---- +_banana + +query T +SELECT split_part(NULL, '_', 2) +---- +NULL + +query B +SELECT starts_with('foobar', 
'foo') +---- +true + +query B +SELECT starts_with('foobar', 'bar') +---- +false diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index 96d5ddbd992c..3ba2b31bbab2 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -694,3 +694,295 @@ Andrew nice Andrew and X datafusion📊🔥 cool datafusion📊🔥 and 🔥 And Xiangpeng nice Xiangpeng and Xiangpeng datafusion数据融合 cool datafusion数据融合 and datafusion数据融合 Xiangpeng 🔥 datafusion数据融合 Raphael nice Raphael and R datafusionДатаФусион cool datafusionДатаФусион and аФус Raphael 🔥 datafusionДатаФусион NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test LIKE / ILIKE +# -------------------------------------- + +# TODO: StringView has wrong behavior for LIKE/ILIKE. Enable this after fixing the issue +# see issue: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +#query TTBBBB +#select ascii_1, unicode_1, +# ascii_1 like 'An%' as ascii_like, +# unicode_1 like '%ion数据%' as unicode_like, +# ascii_1 ilike 'An%' as ascii_ilike, +# unicode_1 ilike '%ion数据%' as unicode_ilik +#from test_basic_operator; +#---- +#Andrew datafusion📊🔥 true false true false +#Xiangpeng datafusion数据融合 false true false true +#Raphael datafusionДатаФусион false false false false +#NULL NULL NULL NULL NULL NULL + +# Test pattern without wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An' as ascii_like, + unicode_1 like 'ion数据' as unicode_like, + ascii_1 ilike 'An' as ascii_ilike, + unicode_1 ilike 'ion数据' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 false false false false +Xiangpeng datafusion数据融合 false false false false +Raphael datafusionДатаФусион false false false false +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test CHARACTER_LENGTH +# -------------------------------------- + +query II +SELECT + CHARACTER_LENGTH(ascii_1), + CHARACTER_LENGTH(unicode_1) +FROM + test_basic_operator +---- +6 12 +9 14 +7 20 +NULL NULL + +# -------------------------------------- +# Test Start_With +# -------------------------------------- + +query BBBB +SELECT + STARTS_WITH(ascii_1, 'And'), + STARTS_WITH(unicode_1, 'data'), + STARTS_WITH(ascii_1, NULL), + STARTS_WITH(unicode_1, NULL) +FROM test_basic_operator; +---- +true true NULL NULL +false true NULL NULL +false true NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test ENDS_WITH +# -------------------------------------- + +query BBBB +SELECT + ENDS_WITH(ascii_1, 'w'), + ENDS_WITH(unicode_1, 'ион'), + ENDS_WITH(ascii_1, NULL), + ENDS_WITH(unicode_1, NULL) +FROM test_basic_operator; +---- +true false NULL NULL +false false NULL NULL +false true NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test LEVENSHTEIN +# -------------------------------------- + +query IIII +SELECT + LEVENSHTEIN(ascii_1, 'Andrew'), + LEVENSHTEIN(unicode_1, 'datafusion数据融合'), + LEVENSHTEIN(ascii_1, NULL), + LEVENSHTEIN(unicode_1, NULL) +FROM test_basic_operator; +---- +0 4 NULL NULL +7 0 NULL NULL +6 10 NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test LPAD +# -------------------------------------- + +query TTTT +SELECT + LPAD(ascii_1, 20, 'x'), + LPAD(ascii_1, 20, NULL), + LPAD(unicode_1, 20, '🔥'), + LPAD(unicode_1, 20, NULL) +FROM test_basic_operator; +---- 
+xxxxxxxxxxxxxxAndrew NULL 🔥🔥🔥🔥🔥🔥🔥🔥datafusion📊🔥 NULL +xxxxxxxxxxxXiangpeng NULL 🔥🔥🔥🔥🔥🔥datafusion数据融合 NULL +xxxxxxxxxxxxxRaphael NULL datafusionДатаФусион NULL +NULL NULL NULL NULL + +query TT +SELECT + LPAD(ascii_1, 20), + LPAD(unicode_1, 20) +FROM test_basic_operator; +---- + Andrew datafusion📊🔥 + Xiangpeng datafusion数据融合 + Raphael datafusionДатаФусион +NULL NULL + +# -------------------------------------- +# Test RPAD +# -------------------------------------- + +query TTTT +SELECT + RPAD(ascii_1, 20, 'x'), + RPAD(ascii_1, 20, NULL), + RPAD(unicode_1, 20, '🔥'), + RPAD(unicode_1, 20, NULL) +FROM test_basic_operator; +---- +Andrewxxxxxxxxxxxxxx NULL datafusion📊🔥🔥🔥🔥🔥🔥🔥🔥🔥 NULL +Xiangpengxxxxxxxxxxx NULL datafusion数据融合🔥🔥🔥🔥🔥🔥 NULL +Raphaelxxxxxxxxxxxxx NULL datafusionДатаФусион NULL +NULL NULL NULL NULL + +query TT +SELECT + RPAD(ascii_1, 20), + RPAD(unicode_1, 20) +FROM test_basic_operator; +---- +Andrew datafusion📊🔥 +Xiangpeng datafusion数据融合 +Raphael datafusionДатаФусион +NULL NULL + +# -------------------------------------- +# Test REGEXP_LIKE +# -------------------------------------- + +query BBBBBBBB +SELECT + -- without flags + REGEXP_LIKE(ascii_1, 'an'), + REGEXP_LIKE(unicode_1, 'таФ'), + REGEXP_LIKE(ascii_1, NULL), + REGEXP_LIKE(unicode_1, NULL), + -- with flags + REGEXP_LIKE(ascii_1, 'AN', 'i'), + REGEXP_LIKE(unicode_1, 'ТаФ', 'i'), + REGEXP_LIKE(ascii_1, NULL, 'i'), + REGEXP_LIKE(unicode_1, NULL, 'i') + FROM test_basic_operator; +---- +false false NULL NULL true false NULL NULL +true false NULL NULL true false NULL NULL +false true NULL NULL false true NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REGEXP_MATCH +# -------------------------------------- + +query ???????? +SELECT + -- without flags + REGEXP_MATCH(ascii_1, 'an'), + REGEXP_MATCH(unicode_1, 'ТаФ'), + REGEXP_MATCH(ascii_1, NULL), + REGEXP_MATCH(unicode_1, NULL), + -- with flags + REGEXP_MATCH(ascii_1, 'AN', 'i'), + REGEXP_MATCH(unicode_1, 'таФ', 'i'), + REGEXP_MATCH(ascii_1, NULL, 'i'), + REGEXP_MATCH(unicode_1, NULL, 'i') +FROM test_basic_operator; +---- +NULL NULL NULL NULL [An] NULL NULL NULL +[an] NULL NULL NULL [an] NULL NULL NULL +NULL NULL NULL NULL NULL [таФ] NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REPEAT +# -------------------------------------- + +query TT +SELECT + REPEAT(ascii_1, 3), + REPEAT(unicode_1, 3) +FROM test_basic_operator; +---- +AndrewAndrewAndrew datafusion📊🔥datafusion📊🔥datafusion📊🔥 +XiangpengXiangpengXiangpeng datafusion数据融合datafusion数据融合datafusion数据融合 +RaphaelRaphaelRaphael datafusionДатаФусионdatafusionДатаФусионdatafusionДатаФусион +NULL NULL + +# -------------------------------------- +# Test SPLIT_PART +# -------------------------------------- + +query TTTTTT +SELECT + SPLIT_PART(ascii_1, 'e', 1), + SPLIT_PART(ascii_1, 'e', 2), + SPLIT_PART(ascii_1, NULL, 1), + SPLIT_PART(unicode_1, 'и', 1), + SPLIT_PART(unicode_1, 'и', 2), + SPLIT_PART(unicode_1, NULL, 1) +FROM test_basic_operator; +---- +Andr w NULL datafusion📊🔥 (empty) NULL +Xiangp ng NULL datafusion数据融合 (empty) NULL +Rapha l NULL datafusionДатаФус он NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REVERSE +# -------------------------------------- + +query TT +SELECT + REVERSE(ascii_1), + REVERSE(unicode_1) +FROM test_basic_operator; +---- +werdnA 🔥📊noisufatad +gnepgnaiX 合融据数noisufatad +leahpaR ноисуФатаДnoisufatad +NULL NULL + +# -------------------------------------- +# Test STRPOS 
+# -------------------------------------- + +# TODO: DictionaryString does not support STRPOS. Enable this after fixing the issue +# see issue: https://github.com/apache/datafusion/issues/12670 +#query IIIIII +#SELECT +# STRPOS(ascii_1, 'e'), +# STRPOS(ascii_1, 'ang'), +# STRPOS(ascii_1, NULL), +# STRPOS(unicode_1, 'и'), +# STRPOS(unicode_1, 'ион'), +# STRPOS(unicode_1, NULL) +#FROM test_basic_operator; +#---- +#5 0 NULL 0 0 NULL +#7 3 NULL 0 0 NULL +#6 0 NULL 18 18 NULL +#NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test SUBSTR_INDEX +# -------------------------------------- + +query TTTTTT +SELECT + SUBSTR_INDEX(ascii_1, 'e', 1), + SUBSTR_INDEX(ascii_1, 'ang', 1), + SUBSTR_INDEX(ascii_1, NULL, 1), + SUBSTR_INDEX(unicode_1, 'и', 1), + SUBSTR_INDEX(unicode_1, '据融', 1), + SUBSTR_INDEX(unicode_1, NULL, 1) +FROM test_basic_operator; +---- +Andr Andrew NULL datafusion📊🔥 datafusion📊🔥 NULL +Xiangp Xi NULL datafusion数据融合 datafusion数 NULL +Rapha Raphael NULL datafusionДатаФус datafusionДатаФусион NULL +NULL NULL NULL NULL NULL NULL diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index e7b55c9c1c8c..4e7857ad804b 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -50,6 +50,23 @@ false false false true NULL NULL +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12670 +query IIIIII +SELECT + STRPOS(ascii_1, 'e'), + STRPOS(ascii_1, 'ang'), + STRPOS(ascii_1, NULL), + STRPOS(unicode_1, 'и'), + STRPOS(unicode_1, 'ион'), + STRPOS(unicode_1, NULL) +FROM test_basic_operator; +---- +5 0 NULL 0 0 NULL +7 3 NULL 0 0 NULL +6 0 NULL 18 18 NULL +NULL NULL NULL NULL NULL NULL + # # common test for string-like functions and operators # diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 7fee84f9bcd9..cb6c6a5ace76 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4894,3 +4894,42 @@ NULL a4 5 statement ok drop table t + +## test handle NULL and 0 value of nth_value +statement ok +CREATE TABLE t(v1 int, v2 int); + +statement ok +INSERT INTO t VALUES (1,1), (1,2),(1,3),(2,1),(2,2); + +query II +SELECT v1, NTH_VALUE(v2, null) OVER (PARTITION BY v1 ORDER BY v2) FROM t; +---- +1 NULL +1 NULL +1 NULL +2 NULL +2 NULL + +query II +SELECT v1, NTH_VALUE(v2, v2*null) OVER (PARTITION BY v1 ORDER BY v2) FROM t; +---- +1 NULL +1 NULL +1 NULL +2 NULL +2 NULL + +query II +SELECT v1, NTH_VALUE(v2, 0) OVER (PARTITION BY v1 ORDER BY v2) FROM t; +---- +1 NULL +1 NULL +1 NULL +2 NULL +2 NULL + +statement ok +DROP TABLE t; + +## end test handle NULL and 0 of NTH_VALUE \ No newline at end of file diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index f7686bec5435..3b7d0fd29610 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -149,6 +149,10 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { fn dyn_ord(&self, _: &dyn UserDefinedLogicalNode) -> Option { unimplemented!() } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } impl MockUserDefinedLogicalPlan { diff --git a/docs/source/user-guide/introduction.md 
b/docs/source/user-guide/introduction.md
index 8f8983061eb6..7c975055d152 100644
--- a/docs/source/user-guide/introduction.md
+++ b/docs/source/user-guide/introduction.md
@@ -96,6 +96,7 @@ Here are some active projects using DataFusion:
 - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust
 - [Ballista](https://github.com/apache/datafusion-ballista) Distributed SQL Query Engine
+- [Blaze](https://github.com/kwai/blaze) Accelerator for Apache Spark that uses native vectorized execution to speed up query processing
 - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database
 - [Comet](https://github.com/apache/datafusion-comet) Apache Spark native query execution plugin
 - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust)
@@ -124,7 +125,6 @@ Here are some active projects using DataFusion:
 Here are some less active projects that used DataFusion:
 
 - [bdt](https://github.com/datafusion-contrib/bdt) Boring Data Tool
-- [Blaze](https://github.com/blaze-init/blaze) Spark accelerator with DataFusion at its core
 - [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust)
 - [datafusion-tui](https://github.com/datafusion-contrib/datafusion-tui) Text UI for DataFusion
 - [Flock](https://github.com/flock-lab/flock)
diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md
index 0e974550a84d..18c95cdea70e 100644
--- a/docs/source/user-guide/sql/data_types.md
+++ b/docs/source/user-guide/sql/data_types.md
@@ -97,7 +97,7 @@ select arrow_cast(now(), 'Timestamp(Second, None)');
 | `BYTEA` | `Binary` |
 
 You can create binary literals using a hex string literal such as
-`X'1234` to create a `Binary` value of two bytes, `0x12` and `0x34`.
+`X'1234'` to create a `Binary` value of two bytes, `0x12` and `0x34`.
 
 ## Unsupported SQL Types