From 4773c6ac5889298d4efced7a3584d463b7c459c5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:06:32 -0600 Subject: [PATCH 01/19] update dependency version --- native/Cargo.lock | 133 ++++++++++++++++++++-------------------------- native/Cargo.toml | 16 +++--- 2 files changed, 66 insertions(+), 83 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index a3f6f6d30..44400e24b 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -346,13 +346,13 @@ checksum = "0c24e9d990669fbd16806bff449e4ac644fd9b1fca014760087732fe4102f131" [[package]] name = "async-trait" -version = "0.1.81" +version = "0.1.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" +checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -463,9 +463,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.17.0" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fd4c6dcc3b0aea2f5c0b4b82c2b15fe39ddbc76041a310848f4706edf76bb31" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" [[package]] name = "byteorder" @@ -487,9 +487,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.14" +version = "1.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d2eb3cd3d1bf4529e31c215ee6f93ec5a3d536d9f578f93d9d33ee19562932" +checksum = "e9d013ecb737093c0e86b151a7b837993cf9ec6c502946cfb44bedc392421e0b" dependencies = [ "jobserver", "libc", @@ -659,9 +659,9 @@ dependencies = [ [[package]] name = "constant_time_eq" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" [[package]] name = "core-foundation-sys" @@ -671,9 +671,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpp_demangle" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8227005286ec39567949b33df9896bcadfa6051bccca2488129f108ca23119" +checksum = "96e58d342ad113c2b878f16d5d034c03be492ae460cdbc02b7f0f2284d310c7d" dependencies = [ "cfg-if", ] @@ -811,7 +811,7 @@ dependencies = [ [[package]] name = "datafusion" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "ahash", "arrow", @@ -833,7 +833,6 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", - "datafusion-physical-expr-functions-aggregate", "datafusion-physical-optimizer", "datafusion-physical-plan", "datafusion-sql", @@ -860,7 +859,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "arrow-schema", "async-trait", @@ -956,7 +955,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "ahash", "arrow", @@ -978,7 +977,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "log", "tokio", @@ -987,7 +986,7 @@ dependencies = [ [[package]] name = "datafusion-execution" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "arrow", "chrono", @@ -1007,7 +1006,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "ahash", "arrow", @@ -1028,7 +1027,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "arrow", "datafusion-common", @@ -1038,7 +1037,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "arrow", "arrow-buffer", @@ -1064,7 +1063,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "ahash", "arrow", @@ -1084,7 +1083,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "ahash", "arrow", @@ -1097,7 +1096,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "arrow", "arrow-array", @@ -1119,7 +1118,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "datafusion-common", "datafusion-expr", @@ -1130,7 +1129,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "arrow", "async-trait", @@ -1149,7 +1148,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "ahash", "arrow", @@ -1180,7 +1179,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "ahash", "arrow", @@ -1190,25 +1189,10 @@ dependencies = [ "rand", ] -[[package]] -name = "datafusion-physical-expr-functions-aggregate" -version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" -dependencies = [ - "ahash", - "arrow", - "datafusion-common", - "datafusion-expr", - "datafusion-expr-common", - "datafusion-functions-aggregate-common", - "datafusion-physical-expr-common", - "rand", -] - [[package]] name = "datafusion-physical-optimizer" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "datafusion-common", "datafusion-execution", @@ -1220,7 +1204,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "ahash", "arrow", @@ -1238,7 +1222,6 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-physical-expr", "datafusion-physical-expr-common", - "datafusion-physical-expr-functions-aggregate", "futures", "half", "hashbrown", @@ -1255,7 +1238,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" dependencies = [ "arrow", "arrow-array", @@ -1448,7 +1431,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -1624,9 +1607,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c" +checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", "hashbrown", @@ -2128,9 +2111,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.3" +version = "0.36.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9" +checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" dependencies = [ "memchr", ] @@ -2435,7 +2418,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2556,9 +2539,9 @@ checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "rgb" -version = "0.8.48" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f86ae463694029097b846d8f99fd5536740602ae00022c0c50c5600720b2f71" +checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" dependencies = [ "bytemuck", ] @@ -2571,18 +2554,18 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustc_version" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ "semver", ] [[package]] name = "rustix" -version = "0.38.34" +version = "0.38.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f" dependencies = [ "bitflags 2.6.0", "errno", @@ -2657,7 +2640,7 @@ checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2775,7 +2758,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2815,7 +2798,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2826,9 +2809,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "symbolic-common" -version = "12.10.0" +version = "12.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16629323a4ec5268ad23a575110a724ad4544aae623451de600c747bf87b36cf" +checksum = "9c1db5ac243c7d7f8439eb3b8f0357888b37cf3732957e91383b0ad61756374e" dependencies = [ "debugid", "memmap2", @@ -2838,9 +2821,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.10.0" +version = "12.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c043a45f08f41187414592b3ceb53fb0687da57209cc77401767fb69d5b596" +checksum = "ea26e430c27d4a8a5dea4c4b81440606c7c1a415bd611451ef6af8c81416afc3" dependencies = [ "cpp_demangle", "rustc-demangle", @@ -2860,9 +2843,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.76" +version = "2.0.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" dependencies = [ "proc-macro2", "quote", @@ -2899,7 +2882,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2959,9 +2942,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.3" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", "bytes", @@ -2977,7 +2960,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2999,7 +2982,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -3149,7 +3132,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", "wasm-bindgen-shared", ] @@ -3171,7 +3154,7 @@ checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3410,7 +3393,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] diff --git a/native/Cargo.toml b/native/Cargo.toml index 33711b1df..92707b141 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -39,14 +39,14 @@ arrow-buffer = { version = "52.2.0" } arrow-data = { version = "52.2.0" } arrow-schema = { version = "52.2.0" } parquet = { version = "52.2.0", default-features = false, features = ["experimental"] } -datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "dff590b" } -datafusion = { default-features = false, git = "https://github.com/apache/datafusion.git", rev = "dff590b", features = ["unicode_expressions", "crypto_expressions"] } -datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", features = ["crypto_expressions"] } -datafusion-functions-nested = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } -datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } -datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } -datafusion-physical-plan = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } -datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } +datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5" } +datafusion = { default-features = false, git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", features = ["unicode_expressions", "crypto_expressions"] } +datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", features = ["crypto_expressions"] } +datafusion-functions-nested = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } +datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } +datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } +datafusion-physical-plan = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } +datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } datafusion-comet-spark-expr = { path = "spark-expr", version = "0.3.0" } datafusion-comet-proto = { path = "proto", version = "0.3.0" } chrono = { version = "0.4", default-features = false, features = ["clock"] } From 4098e971b15d7075319550fa1a7c7fe93c2b677a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:16:00 -0600 Subject: [PATCH 02/19] update avg --- .../execution/datafusion/expressions/avg.rs | 39 +++++++++++-------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/avg.rs b/native/core/src/execution/datafusion/expressions/avg.rs index 5e7b555c8..792de4b30 100644 --- a/native/core/src/execution/datafusion/expressions/avg.rs +++ b/native/core/src/execution/datafusion/expressions/avg.rs @@ -25,20 +25,23 @@ use arrow_array::{ }; use arrow_schema::{DataType, Field}; use datafusion::logical_expr::{ - type_coercion::aggregates::avg_return_type, Accumulator, EmitTo, GroupsAccumulator, + type_coercion::aggregates::avg_return_type, Accumulator, EmitTo, GroupsAccumulator, Signature, }; use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; use std::{any::Any, sync::Arc}; use arrow_array::ArrowNativeTypeOp; - +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::AggregateUDFImpl; +use datafusion_expr::Volatility::Immutable; use DataType::*; /// AVG aggregate expression #[derive(Debug, Clone)] pub struct Avg { name: String, + signature: Signature, expr: Arc, input_data_type: DataType, result_data_type: DataType, @@ -51,6 +54,7 @@ impl Avg { Self { name: name.into(), + signature: Signature::user_defined(Immutable), expr, input_data_type: data_type, result_data_type, @@ -58,17 +62,13 @@ impl Avg { } } -impl AggregateExpr for Avg { +impl AggregateUDFImpl for Avg { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.result_data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { // instantiate specialized accumulator based for the type match (&self.input_data_type, &self.result_data_type) { (Float64, Float64) => Ok(Box::::default()), @@ -80,7 +80,7 @@ impl AggregateExpr for Avg { } } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "sum"), @@ -95,19 +95,18 @@ impl AggregateExpr for Avg { ]) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - fn name(&self) -> &str { &self.name } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { // instantiate specialized accumulator based for the type match (&self.input_data_type, &self.result_data_type) { (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::::new( @@ -126,6 +125,14 @@ impl AggregateExpr for Avg { fn default_value(&self, _data_type: &DataType) -> Result { Ok(ScalarValue::Float64(None)) } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) + } } impl PartialEq for Avg { From fbeaf97875c62dbecf6cfb71e05f053d73800fbd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:20:02 -0600 Subject: [PATCH 03/19] update avg_decimal --- .../execution/datafusion/expressions/avg.rs | 6 ++- .../datafusion/expressions/avg_decimal.rs | 44 +++++++++++-------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/avg.rs b/native/core/src/execution/datafusion/expressions/avg.rs index 792de4b30..96332077f 100644 --- a/native/core/src/execution/datafusion/expressions/avg.rs +++ b/native/core/src/execution/datafusion/expressions/avg.rs @@ -33,8 +33,8 @@ use std::{any::Any, sync::Arc}; use arrow_array::ArrowNativeTypeOp; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::AggregateUDFImpl; use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; use DataType::*; /// AVG aggregate expression @@ -99,6 +99,10 @@ impl AggregateUDFImpl for Avg { &self.name } + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs b/native/core/src/execution/datafusion/expressions/avg_decimal.rs index b035d3244..40439c26a 100644 --- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/avg_decimal.rs @@ -24,16 +24,19 @@ use arrow_array::{ Array, ArrayRef, Decimal128Array, Int64Array, PrimitiveArray, }; use arrow_schema::{DataType, Field}; -use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; +use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator, Signature}; use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; use std::{any::Any, sync::Arc}; use arrow_array::ArrowNativeTypeOp; use arrow_data::decimal::{ validate_decimal_precision, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; - +use datafusion::logical_expr::Volatility::Immutable; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::avg_return_type; +use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; use num::{integer::div_ceil, Integer}; use DataType::*; @@ -41,6 +44,7 @@ use DataType::*; #[derive(Debug, Clone)] pub struct AvgDecimal { name: String, + signature: Signature, expr: Arc, sum_data_type: DataType, result_data_type: DataType, @@ -56,6 +60,7 @@ impl AvgDecimal { ) -> Self { Self { name: name.into(), + signature: Signature::user_defined(Immutable), expr, result_data_type: result_type, sum_data_type: sum_type, @@ -63,17 +68,13 @@ impl AvgDecimal { } } -impl AggregateExpr for AvgDecimal { +impl AggregateUDFImpl for AvgDecimal { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.result_data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { match (&self.sum_data_type, &self.result_data_type) { (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => { Ok(Box::new(AvgDecimalAccumulator::new( @@ -91,7 +92,7 @@ impl AggregateExpr for AvgDecimal { } } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "sum"), @@ -106,23 +107,22 @@ impl AggregateExpr for AvgDecimal { ]) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - fn name(&self) -> &str { &self.name } - fn reverse_expr(&self) -> Option> { - None + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { // instantiate specialized accumulator based for the type match (&self.sum_data_type, &self.result_data_type) { (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => { @@ -154,6 +154,14 @@ impl AggregateExpr for AvgDecimal { ), } } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) + } } impl PartialEq for AvgDecimal { From 1fa346df37c4f549b30da3be7af66cecaa265664 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:26:05 -0600 Subject: [PATCH 04/19] update sum_decimal --- .../expressions/bloom_filter_might_contain.rs | 3 +- .../datafusion/expressions/sum_decimal.rs | 48 +++++++++++-------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs index b66ca5b2c..462a22247 100644 --- a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs +++ b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs @@ -21,9 +21,10 @@ use crate::{ use arrow::record_batch::RecordBatch; use arrow_array::cast::as_primitive_array; use arrow_schema::{DataType, Schema}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion::physical_plan::ColumnarValue; use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr}; +use datafusion_physical_expr::PhysicalExpr; use std::{ any::Any, fmt::Display, diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index 37030b67a..ae2d384da 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::unlikely; use arrow::{ array::BooleanBufferBuilder, buffer::{BooleanBuffer, NullBuffer}, @@ -26,14 +27,17 @@ use arrow_data::decimal::validate_decimal_precision; use arrow_schema::{DataType, Field}; use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; use datafusion_common::{Result as DFResult, ScalarValue}; -use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, AggregateExpr, PhysicalExpr}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature}; +use datafusion_physical_expr::PhysicalExpr; use std::{any::Any, ops::BitAnd, sync::Arc}; - -use crate::unlikely; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; #[derive(Debug)] pub struct SumDecimal { name: String, + signature: Signature, expr: Arc, /// The data type of the SUM result @@ -56,6 +60,7 @@ impl SumDecimal { }; Self { name: name.into(), + signature: Signature::user_defined(Immutable), expr, result_type: data_type, precision, @@ -65,27 +70,19 @@ impl SumDecimal { } } -impl AggregateExpr for SumDecimal { +impl AggregateUDFImpl for SumDecimal { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> DFResult { - Ok(Field::new( - &self.name, - self.result_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> DFResult> { + fn accumulator(&self, _args: AccumulatorArgs) -> DFResult> { Ok(Box::new(SumDecimalAccumulator::new( self.precision, self.scale, ))) } - fn state_fields(&self) -> DFResult> { + fn state_fields(&self, args: StateFieldsArgs) -> DFResult> { let fields = vec![ Field::new(&self.name, self.result_type.clone(), self.nullable), Field::new("is_empty", DataType::Boolean, false), @@ -93,19 +90,26 @@ impl AggregateExpr for SumDecimal { Ok(fields) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - fn name(&self) -> &str { &self.name } - fn groups_accumulator_supported(&self) -> bool { + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> DFResult { + Ok(self.result_type.clone()) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } - fn create_groups_accumulator(&self) -> DFResult> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { Ok(Box::new(SumDecimalGroupsAccumulator::new( self.result_type.clone(), self.precision, @@ -119,6 +123,10 @@ impl AggregateExpr for SumDecimal { &DataType::Decimal128(self.precision, self.scale), ) } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } } impl PartialEq for SumDecimal { From a946ce4f0b3c9a8c407c3cafc6127b7fd589d30f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:34:12 -0600 Subject: [PATCH 05/19] variance --- .../datafusion/expressions/negative.rs | 2 +- .../datafusion/expressions/sum_decimal.rs | 2 +- .../datafusion/expressions/variance.rs | 50 +++++++++++++------ 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/negative.rs b/native/core/src/execution/datafusion/expressions/negative.rs index fbcd194f0..8dfe71742 100644 --- a/native/core/src/execution/datafusion/expressions/negative.rs +++ b/native/core/src/execution/datafusion/expressions/negative.rs @@ -21,6 +21,7 @@ use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeT use arrow_array::RecordBatch; use arrow_buffer::IntervalDayTime; use arrow_schema::{DataType, Schema}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion::{ logical_expr::{interval_arithmetic::Interval, ColumnarValue}, physical_expr::PhysicalExpr, @@ -28,7 +29,6 @@ use datafusion::{ use datafusion_comet_spark_expr::SparkError; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::sort_properties::ExprProperties; -use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; use std::{ any::Any, hash::{Hash, Hasher}, diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index ae2d384da..40547a446 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -26,13 +26,13 @@ use arrow_array::{ use arrow_data::decimal::validate_decimal_precision; use arrow_schema::{DataType, Field}; use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature}; use datafusion_physical_expr::PhysicalExpr; use std::{any::Any, ops::BitAnd, sync::Arc}; -use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; #[derive(Debug)] pub struct SumDecimal { diff --git a/native/core/src/execution/datafusion/expressions/variance.rs b/native/core/src/execution/datafusion/expressions/variance.rs index 5cfbf2947..692c5c795 100644 --- a/native/core/src/execution/datafusion/expressions/variance.rs +++ b/native/core/src/execution/datafusion/expressions/variance.rs @@ -24,7 +24,10 @@ use arrow::{ }; use datafusion::logical_expr::Accumulator; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{AggregateUDFImpl, Signature}; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// VAR_SAMP and VAR_POP aggregate expression /// The implementation mostly is the same as the DataFusion's implementation. The reason @@ -34,6 +37,7 @@ use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, Ph #[derive(Debug)] pub struct Variance { name: String, + signature: Signature, expr: Arc, stats_type: StatsType, null_on_divide_by_zero: bool, @@ -52,6 +56,7 @@ impl Variance { assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), + signature: Signature::numeric(1, Immutable), expr, stats_type, null_on_divide_by_zero, @@ -59,31 +64,54 @@ impl Variance { } } -impl AggregateExpr for Variance { +impl AggregateUDFImpl for Variance { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature } - fn create_accumulator(&self) -> Result> { + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(VarianceAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn create_sliding_accumulator(&self) -> Result> { + /* + fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { + !acc_args.is_distinct + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(VarianceGroupsAccumulator::new( + StatsType::Population, + ))) + } + */ + + fn create_sliding_accumulator(&self, _args: AccumulatorArgs) -> Result> { Ok(Box::new(VarianceAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "count"), @@ -99,14 +127,6 @@ impl AggregateExpr for Variance { ]) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } - fn default_value(&self, _data_type: &DataType) -> Result { Ok(ScalarValue::Float64(None)) } From 5c674a641028e6a0496ca1b4f5d1f1cc9461f04e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:37:20 -0600 Subject: [PATCH 06/19] stddev --- .../datafusion/expressions/stddev.rs | 50 +++++++++++++------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/stddev.rs b/native/core/src/execution/datafusion/expressions/stddev.rs index bc96a5680..9ed78645c 100644 --- a/native/core/src/execution/datafusion/expressions/stddev.rs +++ b/native/core/src/execution/datafusion/expressions/stddev.rs @@ -26,7 +26,9 @@ use arrow::{ }; use datafusion::logical_expr::Accumulator; use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_expr::{AggregateUDFImpl, GroupsAccumulator, Signature, Volatility}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression /// The implementation mostly is the same as the DataFusion's implementation. The reason @@ -36,6 +38,7 @@ use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, Ph #[derive(Debug)] pub struct Stddev { name: String, + signature: Signature, expr: Arc, stats_type: StatsType, null_on_divide_by_zero: bool, @@ -54,6 +57,10 @@ impl Stddev { assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), + signature: Signature::coercible( + vec![DataType::Float64], + Volatility::Immutable, + ), expr, stats_type, null_on_divide_by_zero, @@ -61,31 +68,52 @@ impl Stddev { } } -impl AggregateExpr for Stddev { +impl AggregateUDFImpl for Stddev { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature } - fn create_accumulator(&self) -> Result> { + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(StddevAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn create_sliding_accumulator(&self) -> Result> { + /* + fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { + !acc_args.is_distinct + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample))) + } + */ + + fn create_sliding_accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(StddevAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "count"), @@ -101,14 +129,6 @@ impl AggregateExpr for Stddev { ]) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } - fn default_value(&self, _data_type: &DataType) -> Result { Ok(ScalarValue::Float64(None)) } From 9474f2d38b7168304d8da78c694183c93a636208 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:39:52 -0600 Subject: [PATCH 07/19] covariance --- .../datafusion/expressions/covariance.rs | 39 +++++++++++-------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/covariance.rs b/native/core/src/execution/datafusion/expressions/covariance.rs index 11b345a22..0be28def1 100644 --- a/native/core/src/execution/datafusion/expressions/covariance.rs +++ b/native/core/src/execution/datafusion/expressions/covariance.rs @@ -26,11 +26,15 @@ use arrow::{ datatypes::{DataType, Field}, }; use datafusion::logical_expr::Accumulator; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{ downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_physical_expr::{ - aggregate::utils::down_cast_any_ref, expressions::format_state_name, AggregateExpr, + expressions::format_state_name, PhysicalExpr, }; @@ -41,6 +45,7 @@ use datafusion_physical_expr::{ #[derive(Debug, Clone)] pub struct Covariance { name: String, + signature: Signature, expr1: Arc, expr2: Arc, stats_type: StatsType, @@ -61,6 +66,7 @@ impl Covariance { assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), expr1, expr2, stats_type, @@ -69,24 +75,35 @@ impl Covariance { } } -impl AggregateExpr for Covariance { +impl AggregateUDFImpl for Covariance { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(CovarianceAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "count"), @@ -111,17 +128,7 @@ impl AggregateExpr for Covariance { ]) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr1), Arc::clone(&self.expr2)] - } - - fn name(&self) -> &str { - &self.name - } - fn default_value(&self, _data_type: &DataType) -> Result { - Ok(ScalarValue::Float64(None)) - } } impl PartialEq for Covariance { From be6b032c9db12309a59335305f8f71d08b5a12a4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:42:13 -0600 Subject: [PATCH 08/19] correlation --- .../datafusion/expressions/correlation.rs | 40 ++++++++++--------- .../datafusion/expressions/covariance.rs | 9 +---- .../datafusion/expressions/stddev.rs | 12 +++--- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/correlation.rs b/native/core/src/execution/datafusion/expressions/correlation.rs index 3dcf6cca8..e95e664a6 100644 --- a/native/core/src/execution/datafusion/expressions/correlation.rs +++ b/native/core/src/execution/datafusion/expressions/correlation.rs @@ -29,7 +29,10 @@ use arrow::{ }; use datafusion::logical_expr::Accumulator; use datafusion_common::{Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// CORR aggregate expression /// The implementation mostly is the same as the DataFusion's implementation. The reason @@ -39,6 +42,7 @@ use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, Ph #[derive(Debug)] pub struct Correlation { name: String, + signature: Signature, expr1: Arc, expr2: Arc, null_on_divide_by_zero: bool, @@ -56,6 +60,7 @@ impl Correlation { assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), expr1, expr2, null_on_divide_by_zero, @@ -63,23 +68,34 @@ impl Correlation { } } -impl AggregateExpr for Correlation { +impl AggregateUDFImpl for Correlation { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(CorrelationAccumulator::try_new( self.null_on_divide_by_zero, )?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "count"), @@ -113,18 +129,6 @@ impl AggregateExpr for Correlation { ), ]) } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr1), Arc::clone(&self.expr2)] - } - - fn name(&self) -> &str { - &self.name - } - - fn default_value(&self, _data_type: &DataType) -> Result { - Ok(ScalarValue::Float64(None)) - } } impl PartialEq for Correlation { diff --git a/native/core/src/execution/datafusion/expressions/covariance.rs b/native/core/src/execution/datafusion/expressions/covariance.rs index 0be28def1..0b9fbe7b1 100644 --- a/native/core/src/execution/datafusion/expressions/covariance.rs +++ b/native/core/src/execution/datafusion/expressions/covariance.rs @@ -30,13 +30,10 @@ use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{ downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; -use datafusion_physical_expr::{ - expressions::format_state_name, - PhysicalExpr, -}; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// COVAR_SAMP and COVAR_POP aggregate expression /// The implementation mostly is the same as the DataFusion's implementation. The reason @@ -127,8 +124,6 @@ impl AggregateUDFImpl for Covariance { ), ]) } - - } impl PartialEq for Covariance { diff --git a/native/core/src/execution/datafusion/expressions/stddev.rs b/native/core/src/execution/datafusion/expressions/stddev.rs index 9ed78645c..d8a11f78d 100644 --- a/native/core/src/execution/datafusion/expressions/stddev.rs +++ b/native/core/src/execution/datafusion/expressions/stddev.rs @@ -26,8 +26,8 @@ use arrow::{ }; use datafusion::logical_expr::Accumulator; use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::{AggregateUDFImpl, GroupsAccumulator, Signature, Volatility}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::{AggregateUDFImpl, GroupsAccumulator, Signature, Volatility}; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression @@ -57,10 +57,7 @@ impl Stddev { assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), - signature: Signature::coercible( - vec![DataType::Float64], - Volatility::Immutable, - ), + signature: Signature::coercible(vec![DataType::Float64], Volatility::Immutable), expr, stats_type, null_on_divide_by_zero, @@ -106,7 +103,10 @@ impl AggregateUDFImpl for Stddev { } */ - fn create_sliding_accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + fn create_sliding_accumulator( + &self, + _acc_args: AccumulatorArgs, + ) -> Result> { Ok(Box::new(StddevAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, From cb0d86ee408c090a51a457e46d85ddbb7ece8f57 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:54:56 -0600 Subject: [PATCH 09/19] save progress --- .../datafusion/expressions/correlation.rs | 2 +- .../datafusion/expressions/covariance.rs | 2 +- .../datafusion/expressions/stddev.rs | 2 +- .../datafusion/expressions/sum_decimal.rs | 4 +- .../core/src/execution/datafusion/planner.rs | 168 ++++++++++++------ native/core/src/execution/jni_api.rs | 2 +- 6 files changed, 124 insertions(+), 56 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/correlation.rs b/native/core/src/execution/datafusion/expressions/correlation.rs index e95e664a6..1b7eadbf4 100644 --- a/native/core/src/execution/datafusion/expressions/correlation.rs +++ b/native/core/src/execution/datafusion/expressions/correlation.rs @@ -82,7 +82,7 @@ impl AggregateUDFImpl for Correlation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } fn default_value(&self, _data_type: &DataType) -> Result { diff --git a/native/core/src/execution/datafusion/expressions/covariance.rs b/native/core/src/execution/datafusion/expressions/covariance.rs index 0b9fbe7b1..7f13e357a 100644 --- a/native/core/src/execution/datafusion/expressions/covariance.rs +++ b/native/core/src/execution/datafusion/expressions/covariance.rs @@ -86,7 +86,7 @@ impl AggregateUDFImpl for Covariance { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } fn default_value(&self, _data_type: &DataType) -> Result { diff --git a/native/core/src/execution/datafusion/expressions/stddev.rs b/native/core/src/execution/datafusion/expressions/stddev.rs index d8a11f78d..34b86a527 100644 --- a/native/core/src/execution/datafusion/expressions/stddev.rs +++ b/native/core/src/execution/datafusion/expressions/stddev.rs @@ -27,7 +27,7 @@ use arrow::{ use datafusion::logical_expr::Accumulator; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::{AggregateUDFImpl, GroupsAccumulator, Signature, Volatility}; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index 40547a446..e957bd25e 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -82,7 +82,7 @@ impl AggregateUDFImpl for SumDecimal { ))) } - fn state_fields(&self, args: StateFieldsArgs) -> DFResult> { + fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { let fields = vec![ Field::new(&self.name, self.result_type.clone(), self.nullable), Field::new("is_empty", DataType::Boolean, false), @@ -98,7 +98,7 @@ impl AggregateUDFImpl for SumDecimal { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> DFResult { + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { Ok(self.result_type.clone()) } diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index b9b882824..d60c8cc44 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -55,7 +55,6 @@ use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_aggregate::min_max::min_udaf; use datafusion::functions_aggregate::sum::sum_udaf; -use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use datafusion::physical_plan::windows::BoundedWindowAggExec; use datafusion::physical_plan::InputOrderMode; use datafusion::{ @@ -70,7 +69,7 @@ use datafusion::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, Literal as DataFusionLiteral, NotExpr, }, - AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr, + PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr, }, physical_optimizer::join_selection::swap_hash_join, physical_plan::{ @@ -83,6 +82,8 @@ use datafusion::{ }, prelude::SessionContext, }; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; + use datafusion_comet_proto::{ spark_expression::{ self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, @@ -116,7 +117,7 @@ use std::{collections::HashMap, sync::Arc}; // For clippy error on type_complexity. type ExecResult = Result; -type PhyAggResult = Result>, ExecutionError>; +type PhyAggResult = Result>, ExecutionError>; type PhyExprResult = Result, String)>, ExecutionError>; type PartitionPhyExprResult = Result>, ExecutionError>; @@ -1275,7 +1276,7 @@ impl PhysicalPlanner { &self, spark_expr: &AggExpr, schema: SchemaRef, - ) -> Result, ExecutionError> { + ) -> Result, ExecutionError> { match spark_expr.expr_struct.as_ref().unwrap() { AggExprStruct::Count(expr) => { assert!(!expr.children.is_empty()); @@ -1448,22 +1449,42 @@ impl PhysicalPlanner { self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); match expr.stats_type { - 0 => Ok(Arc::new(Covariance::new( - child1, - child2, - "covariance", - datatype, - StatsType::Sample, - expr.null_on_divide_by_zero, - ))), - 1 => Ok(Arc::new(Covariance::new( - child1, - child2, - "covariance_pop", - datatype, - StatsType::Population, - expr.null_on_divide_by_zero, - ))), + 0 => { + let func = datafusion_expr::AggregateUDF::new_from_impl(Covariance::new( + child1, + child2, + "covariance", + datatype, + StatsType::Sample, + expr.null_on_divide_by_zero, + )); + + AggregateExprBuilder::new(Arc::new(func), vec![child1, child2]) + .schema(schema) + .alias("covariance") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) + } + 1 => { + let func = datafusion_expr::AggregateUDF::new_from_impl(Covariance::new( + child1, + child2, + "covariance_pop", + datatype, + StatsType::Population, + expr.null_on_divide_by_zero, + )); + + AggregateExprBuilder::new(Arc::new(func), vec![child1, child2]) + .schema(schema) + .alias("covariance_pop") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) + } stats_type => Err(ExecutionError::GeneralError(format!( "Unknown StatisticsType {:?} for Variance", stats_type @@ -1474,20 +1495,40 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); match expr.stats_type { - 0 => Ok(Arc::new(Variance::new( - child, - "variance", - datatype, - StatsType::Sample, - expr.null_on_divide_by_zero, - ))), - 1 => Ok(Arc::new(Variance::new( - child, - "variance_pop", - datatype, - StatsType::Population, - expr.null_on_divide_by_zero, - ))), + 0 => { + let func = datafusion_expr::AggregateUDF::new_from_impl(Variance::new( + child, + "variance", + datatype, + StatsType::Sample, + expr.null_on_divide_by_zero, + )); + + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("variance") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) + } + 1 => { + let func = datafusion_expr::AggregateUDF::new_from_impl(Variance::new( + child, + "variance_pop", + datatype, + StatsType::Population, + expr.null_on_divide_by_zero, + )); + + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("variance_pop") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) + } stats_type => Err(ExecutionError::GeneralError(format!( "Unknown StatisticsType {:?} for Variance", stats_type @@ -1498,20 +1539,40 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); match expr.stats_type { - 0 => Ok(Arc::new(Stddev::new( - child, - "stddev", - datatype, - StatsType::Sample, - expr.null_on_divide_by_zero, - ))), - 1 => Ok(Arc::new(Stddev::new( - child, - "stddev_pop", - datatype, - StatsType::Population, - expr.null_on_divide_by_zero, - ))), + 0 => { + let func = datafusion_expr::AggregateUDF::new_from_impl(Stddev::new( + child, + "stddev", + datatype, + StatsType::Sample, + expr.null_on_divide_by_zero, + )); + + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("stddev") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) + } + 1 => { + let func = datafusion_expr::AggregateUDF::new_from_impl(Stddev::new( + child, + "stddev_pop", + datatype, + StatsType::Population, + expr.null_on_divide_by_zero, + )); + + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("stddev_pop") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) + } stats_type => Err(ExecutionError::GeneralError(format!( "Unknown StatisticsType {:?} for stddev", stats_type @@ -1524,13 +1585,20 @@ impl PhysicalPlanner { let child2 = self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - Ok(Arc::new(Correlation::new( + let func = datafusion_expr::AggregateUDF::new_from_impl(Correlation::new( child1, child2, "correlation", datatype, expr.null_on_divide_by_zero, - ))) + )); + AggregateExprBuilder::new(Arc::new(func), vec![child1, child2]) + .schema(schema) + .alias("correlation") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) } } } diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 1cb790efc..2d99a854d 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -240,7 +240,7 @@ fn prepare_datafusion_session_context( // this is the threshold of number of groups / number of rows and the // maximum value is 1.0, so we set the threshold a little higher just // to be safe - ScalarValue::Float64(Some(1.1)), + &ScalarValue::Float64(Some(1.1)), ); for (key, value) in conf.iter().filter(|(k, _)| k.starts_with("datafusion.")) { From f2ae56d32d900513ef974ced93270d97afc8292c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 08:59:58 -0600 Subject: [PATCH 10/19] code compiles --- .../core/src/execution/datafusion/planner.rs | 71 +++++++++++++------ 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index d60c8cc44..9515f781a 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1348,12 +1348,23 @@ impl PhysicalPlanner { match datatype { DataType::Decimal128(_, _) => { - Ok(Arc::new(SumDecimal::new("sum", child, datatype))) + let func = datafusion_expr::AggregateUDF::new_from_impl(SumDecimal::new( + "sum", + child.clone(), + datatype, + )); + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("sum") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } _ => { // cast to the result data type of SUM if necessary, we should not expect // a cast failure since it should have already been checked at Spark side - let child = Arc::new(CastExpr::new(child, datatype.clone(), None)); + let child = Arc::new(CastExpr::new(child.clone(), datatype.clone(), None)); AggregateExprBuilder::new(sum_udaf(), vec![child]) .schema(schema) @@ -1370,18 +1381,38 @@ impl PhysicalPlanner { let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap()); match datatype { - DataType::Decimal128(_, _) => Ok(Arc::new(AvgDecimal::new( - child, - "avg", - datatype, - input_datatype, - ))), + DataType::Decimal128(_, _) => { + let func = datafusion_expr::AggregateUDF::new_from_impl(AvgDecimal::new( + child.clone(), + "avg", + datatype, + input_datatype, + )); + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("avg") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } _ => { // cast to the result data type of AVG if the result data type is different // from the input type, e.g. AVG(Int32). We should not expect a cast // failure since it should have already been checked at Spark side. - let child = Arc::new(CastExpr::new(child, datatype.clone(), None)); - Ok(Arc::new(Avg::new(child, "avg", datatype))) + let child = Arc::new(CastExpr::new(child.clone(), datatype.clone(), None)); + let func = datafusion_expr::AggregateUDF::new_from_impl(Avg::new( + child.clone(), + "avg", + datatype, + )); + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("avg") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } } } @@ -1451,8 +1482,8 @@ impl PhysicalPlanner { match expr.stats_type { 0 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Covariance::new( - child1, - child2, + child1.clone(), + child2.clone(), "covariance", datatype, StatsType::Sample, @@ -1469,8 +1500,8 @@ impl PhysicalPlanner { } 1 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Covariance::new( - child1, - child2, + child1.clone(), + child2.clone(), "covariance_pop", datatype, StatsType::Population, @@ -1497,7 +1528,7 @@ impl PhysicalPlanner { match expr.stats_type { 0 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Variance::new( - child, + child.clone(), "variance", datatype, StatsType::Sample, @@ -1514,7 +1545,7 @@ impl PhysicalPlanner { } 1 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Variance::new( - child, + child.clone(), "variance_pop", datatype, StatsType::Population, @@ -1541,7 +1572,7 @@ impl PhysicalPlanner { match expr.stats_type { 0 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Stddev::new( - child, + child.clone(), "stddev", datatype, StatsType::Sample, @@ -1558,7 +1589,7 @@ impl PhysicalPlanner { } 1 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Stddev::new( - child, + child.clone(), "stddev_pop", datatype, StatsType::Population, @@ -1586,8 +1617,8 @@ impl PhysicalPlanner { self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let func = datafusion_expr::AggregateUDF::new_from_impl(Correlation::new( - child1, - child2, + child1.clone(), + child2.clone(), "correlation", datatype, expr.null_on_divide_by_zero, From 942930b40f5f98b2fc912fa79918da721c59c8a3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 09:14:17 -0600 Subject: [PATCH 11/19] clippy --- .../core/src/execution/datafusion/planner.rs | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 9515f781a..f446698f8 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1350,7 +1350,7 @@ impl PhysicalPlanner { DataType::Decimal128(_, _) => { let func = datafusion_expr::AggregateUDF::new_from_impl(SumDecimal::new( "sum", - child.clone(), + Arc::clone(&child), datatype, )); AggregateExprBuilder::new(Arc::new(func), vec![child]) @@ -1364,7 +1364,8 @@ impl PhysicalPlanner { _ => { // cast to the result data type of SUM if necessary, we should not expect // a cast failure since it should have already been checked at Spark side - let child = Arc::new(CastExpr::new(child.clone(), datatype.clone(), None)); + let child = + Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None)); AggregateExprBuilder::new(sum_udaf(), vec![child]) .schema(schema) @@ -1383,7 +1384,7 @@ impl PhysicalPlanner { match datatype { DataType::Decimal128(_, _) => { let func = datafusion_expr::AggregateUDF::new_from_impl(AvgDecimal::new( - child.clone(), + Arc::clone(&child), "avg", datatype, input_datatype, @@ -1400,9 +1401,10 @@ impl PhysicalPlanner { // cast to the result data type of AVG if the result data type is different // from the input type, e.g. AVG(Int32). We should not expect a cast // failure since it should have already been checked at Spark side. - let child = Arc::new(CastExpr::new(child.clone(), datatype.clone(), None)); + let child: Arc = + Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None)); let func = datafusion_expr::AggregateUDF::new_from_impl(Avg::new( - child.clone(), + Arc::clone(&child), "avg", datatype, )); @@ -1482,8 +1484,8 @@ impl PhysicalPlanner { match expr.stats_type { 0 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Covariance::new( - child1.clone(), - child2.clone(), + Arc::clone(&child1), + Arc::clone(&child2), "covariance", datatype, StatsType::Sample, @@ -1500,8 +1502,8 @@ impl PhysicalPlanner { } 1 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Covariance::new( - child1.clone(), - child2.clone(), + Arc::clone(&child1), + Arc::clone(&child2), "covariance_pop", datatype, StatsType::Population, @@ -1528,7 +1530,7 @@ impl PhysicalPlanner { match expr.stats_type { 0 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Variance::new( - child.clone(), + Arc::clone(&child), "variance", datatype, StatsType::Sample, @@ -1545,7 +1547,7 @@ impl PhysicalPlanner { } 1 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Variance::new( - child.clone(), + Arc::clone(&child), "variance_pop", datatype, StatsType::Population, @@ -1572,7 +1574,7 @@ impl PhysicalPlanner { match expr.stats_type { 0 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Stddev::new( - child.clone(), + Arc::clone(&child), "stddev", datatype, StatsType::Sample, @@ -1589,7 +1591,7 @@ impl PhysicalPlanner { } 1 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Stddev::new( - child.clone(), + Arc::clone(&child), "stddev_pop", datatype, StatsType::Population, @@ -1617,8 +1619,8 @@ impl PhysicalPlanner { self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let func = datafusion_expr::AggregateUDF::new_from_impl(Correlation::new( - child1.clone(), - child2.clone(), + Arc::clone(&child1), + Arc::clone(&child2), "correlation", datatype, expr.null_on_divide_by_zero, From 2ace729c0490fa054d29ec871bac2cffd999e719 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 09:29:07 -0600 Subject: [PATCH 12/19] remove duplicate of down_cast_any_ref function --- .../execution/datafusion/expressions/stats.rs | 1 + .../execution/datafusion/expressions/utils.rs | 2 +- native/spark-expr/src/cast.rs | 3 ++- native/spark-expr/src/if_expr.rs | 3 +-- native/spark-expr/src/regexp.rs | 2 +- native/spark-expr/src/structs.rs | 3 +-- native/spark-expr/src/temporal.rs | 3 ++- native/spark-expr/src/utils.rs | 18 ------------------ 8 files changed, 9 insertions(+), 26 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/stats.rs b/native/core/src/execution/datafusion/expressions/stats.rs index 1f4e64d0b..4f405c373 100644 --- a/native/core/src/execution/datafusion/expressions/stats.rs +++ b/native/core/src/execution/datafusion/expressions/stats.rs @@ -17,6 +17,7 @@ * under the License. */ +// TODO remove this copy once https://github.com/apache/datafusion/pull/12327 is resolved /// Enum used for differentiating population and sample for statistical functions #[derive(PartialEq, Eq, Debug, Clone, Copy)] pub enum StatsType { diff --git a/native/core/src/execution/datafusion/expressions/utils.rs b/native/core/src/execution/datafusion/expressions/utils.rs index 540fca86b..6670e332c 100644 --- a/native/core/src/execution/datafusion/expressions/utils.rs +++ b/native/core/src/execution/datafusion/expressions/utils.rs @@ -16,4 +16,4 @@ // under the License. // re-export for legacy reasons -pub use datafusion_comet_spark_expr::utils::down_cast_any_ref; +pub use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs index 12dbfdcdc..6a3974fe1 100644 --- a/native/spark-expr/src/cast.rs +++ b/native/spark-expr/src/cast.rs @@ -50,6 +50,7 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr::PhysicalExpr; use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use num::{ cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive, @@ -57,7 +58,7 @@ use num::{ use regex::Regex; use crate::timezone; -use crate::utils::{array_with_timezone, down_cast_any_ref}; +use crate::utils::array_with_timezone; use crate::{EvalMode, SparkError, SparkResult}; diff --git a/native/spark-expr/src/if_expr.rs b/native/spark-expr/src/if_expr.rs index 9a90b727c..193a90fb5 100644 --- a/native/spark-expr/src/if_expr.rs +++ b/native/spark-expr/src/if_expr.rs @@ -26,11 +26,10 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::Result; use datafusion_physical_expr::{expressions::CaseExpr, PhysicalExpr}; -use crate::utils::down_cast_any_ref; - /// IfExpr is a wrapper around CaseExpr, because `IF(a, b, c)` is semantically equivalent to /// `CASE WHEN a THEN b ELSE c END`. #[derive(Debug, Hash)] diff --git a/native/spark-expr/src/regexp.rs b/native/spark-expr/src/regexp.rs index 221fd1f04..c7626285a 100644 --- a/native/spark-expr/src/regexp.rs +++ b/native/spark-expr/src/regexp.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::down_cast_any_ref; use crate::SparkError; use arrow::compute::take; use arrow_array::builder::BooleanBuilder; use arrow_array::types::Int32Type; use arrow_array::{Array, BooleanArray, DictionaryArray, RecordBatch, StringArray}; use arrow_schema::{DataType, Schema}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr::PhysicalExpr; diff --git a/native/spark-expr/src/structs.rs b/native/spark-expr/src/structs.rs index 49017b671..cda8246d9 100644 --- a/native/spark-expr/src/structs.rs +++ b/native/spark-expr/src/structs.rs @@ -19,6 +19,7 @@ use arrow::record_batch::RecordBatch; use arrow_array::{Array, StructArray}; use arrow_schema::{DataType, Field, Schema}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use std::{ @@ -28,8 +29,6 @@ use std::{ sync::Arc, }; -use crate::utils::down_cast_any_ref; - #[derive(Debug, Hash)] pub struct CreateNamedStruct { values: Vec>, diff --git a/native/spark-expr/src/temporal.rs b/native/spark-expr/src/temporal.rs index 415db6070..91953dd60 100644 --- a/native/spark-expr/src/temporal.rs +++ b/native/spark-expr/src/temporal.rs @@ -28,10 +28,11 @@ use arrow::{ }; use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{DataFusionError, ScalarValue::Utf8}; use datafusion_physical_expr::PhysicalExpr; -use crate::utils::{array_with_timezone, down_cast_any_ref}; +use crate::utils::array_with_timezone; use crate::kernels::temporal::{ date_trunc_array_fmt_dyn, date_trunc_dyn, timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn, diff --git a/native/spark-expr/src/utils.rs b/native/spark-expr/src/utils.rs index 6945e82b3..db4ad1956 100644 --- a/native/spark-expr/src/utils.rs +++ b/native/spark-expr/src/utils.rs @@ -20,7 +20,6 @@ use arrow_array::{ types::{Int32Type, TimestampMicrosecondType}, }; use arrow_schema::{ArrowError, DataType}; -use std::any::Any; use std::sync::Arc; use crate::timezone::Tz; @@ -30,23 +29,6 @@ use arrow::{ }; use chrono::{DateTime, Offset, TimeZone}; -use datafusion_physical_plan::PhysicalExpr; - -/// A utility function from DataFusion. It is not exposed by DataFusion. -pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { - if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else { - any - } -} - /// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or /// to apply timezone offset. // From 81ddd567f141337ef57fefd569591c464428ea79 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 09:34:20 -0600 Subject: [PATCH 13/19] remove duplicate of down_cast_any_ref function --- .../execution/datafusion/expressions/avg.rs | 2 +- .../datafusion/expressions/avg_decimal.rs | 2 +- .../datafusion/expressions/bitwise_not.rs | 3 +-- .../datafusion/expressions/checkoverflow.rs | 3 +-- .../datafusion/expressions/correlation.rs | 2 +- .../execution/datafusion/expressions/mod.rs | 1 - .../datafusion/expressions/normalize_nan.rs | 3 +-- .../datafusion/expressions/stddev.rs | 5 ++--- .../datafusion/expressions/strings.rs | 6 ++---- .../datafusion/expressions/subquery.rs | 10 +++++----- .../datafusion/expressions/unbound.rs | 2 +- .../execution/datafusion/expressions/utils.rs | 19 ------------------- .../datafusion/expressions/variance.rs | 3 ++- 13 files changed, 18 insertions(+), 43 deletions(-) delete mode 100644 native/core/src/execution/datafusion/expressions/utils.rs diff --git a/native/core/src/execution/datafusion/expressions/avg.rs b/native/core/src/execution/datafusion/expressions/avg.rs index 96332077f..7820497d4 100644 --- a/native/core/src/execution/datafusion/expressions/avg.rs +++ b/native/core/src/execution/datafusion/expressions/avg.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; use arrow::compute::sum; use arrow_array::{ builder::PrimitiveBuilder, @@ -32,6 +31,7 @@ use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; use std::{any::Any, sync::Arc}; use arrow_array::ArrowNativeTypeOp; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs b/native/core/src/execution/datafusion/expressions/avg_decimal.rs index 40439c26a..0462f2d3d 100644 --- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/avg_decimal.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer, compute::sum}; use arrow_array::{ builder::PrimitiveBuilder, @@ -34,6 +33,7 @@ use arrow_data::decimal::{ validate_decimal_precision, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use datafusion::logical_expr::Volatility::Immutable; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::avg_return_type; use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; diff --git a/native/core/src/execution/datafusion/expressions/bitwise_not.rs b/native/core/src/execution/datafusion/expressions/bitwise_not.rs index c7b2bc067..a2b9ebe5b 100644 --- a/native/core/src/execution/datafusion/expressions/bitwise_not.rs +++ b/native/core/src/execution/datafusion/expressions/bitwise_not.rs @@ -26,12 +26,11 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; use datafusion_common::{Result, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; - macro_rules! compute_op { ($OPERAND:expr, $DT:ident) => {{ let operand = $OPERAND diff --git a/native/core/src/execution/datafusion/expressions/checkoverflow.rs b/native/core/src/execution/datafusion/expressions/checkoverflow.rs index ff2cffd42..e922171bd 100644 --- a/native/core/src/execution/datafusion/expressions/checkoverflow.rs +++ b/native/core/src/execution/datafusion/expressions/checkoverflow.rs @@ -29,11 +29,10 @@ use arrow::{ }; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; - /// This is from Spark `CheckOverflow` expression. Spark `CheckOverflow` expression rounds decimals /// to given scale and check if the decimals can fit in given precision. As `cast` kernel rounds /// decimals already, Comet `CheckOverflow` expression only checks if the decimals can fit in the diff --git a/native/core/src/execution/datafusion/expressions/correlation.rs b/native/core/src/execution/datafusion/expressions/correlation.rs index 1b7eadbf4..a8a25ddb3 100644 --- a/native/core/src/execution/datafusion/expressions/correlation.rs +++ b/native/core/src/execution/datafusion/expressions/correlation.rs @@ -21,13 +21,13 @@ use std::{any::Any, sync::Arc}; use crate::execution::datafusion::expressions::{ covariance::CovarianceAccumulator, stats::StatsType, stddev::StddevAccumulator, - utils::down_cast_any_ref, }; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; use datafusion::logical_expr::Accumulator; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs index 4435f6b69..2848b7a37 100644 --- a/native/core/src/execution/datafusion/expressions/mod.rs +++ b/native/core/src/execution/datafusion/expressions/mod.rs @@ -36,7 +36,6 @@ pub mod strings; pub mod subquery; pub mod sum_decimal; pub mod unbound; -mod utils; pub mod variance; pub use datafusion_comet_spark_expr::{EvalMode, SparkError}; diff --git a/native/core/src/execution/datafusion/expressions/normalize_nan.rs b/native/core/src/execution/datafusion/expressions/normalize_nan.rs index d2192feef..c5331ad7b 100644 --- a/native/core/src/execution/datafusion/expressions/normalize_nan.rs +++ b/native/core/src/execution/datafusion/expressions/normalize_nan.rs @@ -29,10 +29,9 @@ use arrow::{ }; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_physical_expr::PhysicalExpr; -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; - #[derive(Debug, Hash)] pub struct NormalizeNaNAndZero { pub data_type: DataType, diff --git a/native/core/src/execution/datafusion/expressions/stddev.rs b/native/core/src/execution/datafusion/expressions/stddev.rs index 34b86a527..f417c95d9 100644 --- a/native/core/src/execution/datafusion/expressions/stddev.rs +++ b/native/core/src/execution/datafusion/expressions/stddev.rs @@ -17,14 +17,13 @@ use std::{any::Any, sync::Arc}; -use crate::execution::datafusion::expressions::{ - stats::StatsType, utils::down_cast_any_ref, variance::VarianceAccumulator, -}; +use crate::execution::datafusion::expressions::{stats::StatsType, variance::VarianceAccumulator}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; use datafusion::logical_expr::Accumulator; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; diff --git a/native/core/src/execution/datafusion/expressions/strings.rs b/native/core/src/execution/datafusion/expressions/strings.rs index 96e45eae2..200b4ec5a 100644 --- a/native/core/src/execution/datafusion/expressions/strings.rs +++ b/native/core/src/execution/datafusion/expressions/strings.rs @@ -17,10 +17,7 @@ #![allow(deprecated)] -use crate::execution::{ - datafusion::expressions::utils::down_cast_any_ref, - kernels::strings::{string_space, substring}, -}; +use crate::execution::kernels::strings::{string_space, substring}; use arrow::{ compute::{ contains_dyn, contains_utf8_scalar_dyn, ends_with_dyn, ends_with_utf8_scalar_dyn, like_dyn, @@ -30,6 +27,7 @@ use arrow::{ }; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{DataFusionError, ScalarValue::Utf8}; use datafusion_physical_expr::PhysicalExpr; use std::{ diff --git a/native/core/src/execution/datafusion/expressions/subquery.rs b/native/core/src/execution/datafusion/expressions/subquery.rs index cf6f8d846..3eeb29c16 100644 --- a/native/core/src/execution/datafusion/expressions/subquery.rs +++ b/native/core/src/execution/datafusion/expressions/subquery.rs @@ -15,9 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::{ + execution::utils::bytes_to_i128, + jvm_bridge::{jni_static_call, BinaryWrapper, JVMClasses, StringWrapper}, +}; use arrow_array::RecordBatch; use arrow_schema::{DataType, Schema, TimeUnit}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{internal_err, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use jni::{ @@ -31,11 +36,6 @@ use std::{ sync::Arc, }; -use crate::{ - execution::{datafusion::expressions::utils::down_cast_any_ref, utils::bytes_to_i128}, - jvm_bridge::{jni_static_call, BinaryWrapper, JVMClasses, StringWrapper}, -}; - #[derive(Debug, Hash)] pub struct Subquery { /// The ID of the execution context that owns this subquery. We use this ID to retrieve the diff --git a/native/core/src/execution/datafusion/expressions/unbound.rs b/native/core/src/execution/datafusion/expressions/unbound.rs index 95f9912c9..a6babd0f7 100644 --- a/native/core/src/execution/datafusion/expressions/unbound.rs +++ b/native/core/src/execution/datafusion/expressions/unbound.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; use arrow_array::RecordBatch; use arrow_schema::{DataType, Schema}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion::physical_plan::ColumnarValue; use datafusion_common::{internal_err, Result}; use datafusion_physical_expr::PhysicalExpr; diff --git a/native/core/src/execution/datafusion/expressions/utils.rs b/native/core/src/execution/datafusion/expressions/utils.rs deleted file mode 100644 index 6670e332c..000000000 --- a/native/core/src/execution/datafusion/expressions/utils.rs +++ /dev/null @@ -1,19 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// re-export for legacy reasons -pub use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; diff --git a/native/core/src/execution/datafusion/expressions/variance.rs b/native/core/src/execution/datafusion/expressions/variance.rs index 692c5c795..468234493 100644 --- a/native/core/src/execution/datafusion/expressions/variance.rs +++ b/native/core/src/execution/datafusion/expressions/variance.rs @@ -17,12 +17,13 @@ use std::{any::Any, sync::Arc}; -use crate::execution::datafusion::expressions::{stats::StatsType, utils::down_cast_any_ref}; +use crate::execution::datafusion::expressions::stats::StatsType; use arrow::{ array::{ArrayRef, Float64Array}, datatypes::{DataType, Field}, }; use datafusion::logical_expr::Accumulator; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::Volatility::Immutable; From 7ff01bf99bcc8d31a4cb8a6006869c364ff07970 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 10:18:46 -0600 Subject: [PATCH 14/19] machete --- native/Cargo.lock | 1 - native/spark-expr/Cargo.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index 44400e24b..a53b7c2c9 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -944,7 +944,6 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "datafusion-physical-plan", "num", "rand", "regex", diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 0a371a6e6..a5d156912 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -35,7 +35,6 @@ datafusion = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } -datafusion-physical-plan = { workspace = true } chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } From f0eacda580dffd893c15155c89548e98ee036864 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 11:15:59 -0600 Subject: [PATCH 15/19] bump DF version again and use StatsType from DataFusion --- native/Cargo.lock | 36 +++++++++---------- native/Cargo.toml | 16 ++++----- .../datafusion/expressions/correlation.rs | 3 +- .../datafusion/expressions/covariance.rs | 2 +- .../execution/datafusion/expressions/mod.rs | 1 - .../execution/datafusion/expressions/stats.rs | 28 --------------- .../datafusion/expressions/stddev.rs | 3 +- .../datafusion/expressions/variance.rs | 2 +- .../core/src/execution/datafusion/planner.rs | 3 +- 9 files changed, 33 insertions(+), 61 deletions(-) delete mode 100644 native/core/src/execution/datafusion/expressions/stats.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index a53b7c2c9..3692f0488 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -811,7 +811,7 @@ dependencies = [ [[package]] name = "datafusion" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -859,7 +859,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow-schema", "async-trait", @@ -954,7 +954,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -976,7 +976,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "log", "tokio", @@ -985,7 +985,7 @@ dependencies = [ [[package]] name = "datafusion-execution" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "chrono", @@ -1005,7 +1005,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1026,7 +1026,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "datafusion-common", @@ -1036,7 +1036,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "arrow-buffer", @@ -1062,7 +1062,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1082,7 +1082,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1095,7 +1095,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "arrow-array", @@ -1117,7 +1117,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "datafusion-common", "datafusion-expr", @@ -1128,7 +1128,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "async-trait", @@ -1147,7 +1147,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1178,7 +1178,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1191,7 +1191,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "datafusion-common", "datafusion-execution", @@ -1203,7 +1203,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1237,7 +1237,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=e5a6cd5#e5a6cd5f6784f1d1b3d559d5356a6154a73e077c" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "arrow-array", diff --git a/native/Cargo.toml b/native/Cargo.toml index 92707b141..54e568943 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -39,14 +39,14 @@ arrow-buffer = { version = "52.2.0" } arrow-data = { version = "52.2.0" } arrow-schema = { version = "52.2.0" } parquet = { version = "52.2.0", default-features = false, features = ["experimental"] } -datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5" } -datafusion = { default-features = false, git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", features = ["unicode_expressions", "crypto_expressions"] } -datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", features = ["crypto_expressions"] } -datafusion-functions-nested = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } -datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } -datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } -datafusion-physical-plan = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } -datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "e5a6cd5", default-features = false } +datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b" } +datafusion = { default-features = false, git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", features = ["unicode_expressions", "crypto_expressions"] } +datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", features = ["crypto_expressions"] } +datafusion-functions-nested = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } +datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } +datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } +datafusion-physical-plan = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } +datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } datafusion-comet-spark-expr = { path = "spark-expr", version = "0.3.0" } datafusion-comet-proto = { path = "proto", version = "0.3.0" } chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/native/core/src/execution/datafusion/expressions/correlation.rs b/native/core/src/execution/datafusion/expressions/correlation.rs index a8a25ddb3..6bf35e711 100644 --- a/native/core/src/execution/datafusion/expressions/correlation.rs +++ b/native/core/src/execution/datafusion/expressions/correlation.rs @@ -20,7 +20,7 @@ use arrow::compute::{and, filter, is_not_null}; use std::{any::Any, sync::Arc}; use crate::execution::datafusion::expressions::{ - covariance::CovarianceAccumulator, stats::StatsType, stddev::StddevAccumulator, + covariance::CovarianceAccumulator, stddev::StddevAccumulator, }; use arrow::{ array::ArrayRef, @@ -32,6 +32,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::expressions::StatsType; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// CORR aggregate expression diff --git a/native/core/src/execution/datafusion/expressions/covariance.rs b/native/core/src/execution/datafusion/expressions/covariance.rs index 7f13e357a..9166e3976 100644 --- a/native/core/src/execution/datafusion/expressions/covariance.rs +++ b/native/core/src/execution/datafusion/expressions/covariance.rs @@ -19,7 +19,6 @@ use std::{any::Any, sync::Arc}; -use crate::execution::datafusion::expressions::stats::StatsType; use arrow::{ array::{ArrayRef, Float64Array}, compute::cast, @@ -33,6 +32,7 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::expressions::StatsType; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// COVAR_SAMP and COVAR_POP aggregate expression diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs index 2848b7a37..10c9d3092 100644 --- a/native/core/src/execution/datafusion/expressions/mod.rs +++ b/native/core/src/execution/datafusion/expressions/mod.rs @@ -30,7 +30,6 @@ pub mod comet_scalar_funcs; pub mod correlation; pub mod covariance; pub mod negative; -pub mod stats; pub mod stddev; pub mod strings; pub mod subquery; diff --git a/native/core/src/execution/datafusion/expressions/stats.rs b/native/core/src/execution/datafusion/expressions/stats.rs deleted file mode 100644 index 4f405c373..000000000 --- a/native/core/src/execution/datafusion/expressions/stats.rs +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -// TODO remove this copy once https://github.com/apache/datafusion/pull/12327 is resolved -/// Enum used for differentiating population and sample for statistical functions -#[derive(PartialEq, Eq, Debug, Clone, Copy)] -pub enum StatsType { - /// Population - Population, - /// Sample - Sample, -} diff --git a/native/core/src/execution/datafusion/expressions/stddev.rs b/native/core/src/execution/datafusion/expressions/stddev.rs index f417c95d9..0cb668f3a 100644 --- a/native/core/src/execution/datafusion/expressions/stddev.rs +++ b/native/core/src/execution/datafusion/expressions/stddev.rs @@ -17,7 +17,7 @@ use std::{any::Any, sync::Arc}; -use crate::execution::datafusion::expressions::{stats::StatsType, variance::VarianceAccumulator}; +use crate::execution::datafusion::expressions::variance::VarianceAccumulator; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, @@ -27,6 +27,7 @@ use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::expressions::StatsType; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression diff --git a/native/core/src/execution/datafusion/expressions/variance.rs b/native/core/src/execution/datafusion/expressions/variance.rs index 468234493..3ad776692 100644 --- a/native/core/src/execution/datafusion/expressions/variance.rs +++ b/native/core/src/execution/datafusion/expressions/variance.rs @@ -17,7 +17,6 @@ use std::{any::Any, sync::Arc}; -use crate::execution::datafusion::expressions::stats::StatsType; use arrow::{ array::{ArrayRef, Float64Array}, datatypes::{DataType, Field}, @@ -28,6 +27,7 @@ use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{AggregateUDFImpl, Signature}; +use datafusion_physical_expr::expressions::StatsType; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// VAR_SAMP and VAR_POP aggregate expression diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index f446698f8..d8d4bd349 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -33,7 +33,6 @@ use crate::{ correlation::Correlation, covariance::Covariance, negative, - stats::StatsType, stddev::Stddev, strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExpr, SubstringExpr}, subquery::Subquery, @@ -107,7 +106,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::find_df_window_func; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; -use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::expressions::{Literal, StatsType}; use datafusion_physical_expr::window::WindowExpr; use itertools::Itertools; use jni::objects::GlobalRef; From b1ab6db9722239de74b13cba1fd196ba75997fff Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 11:19:06 -0600 Subject: [PATCH 16/19] implement groups accumulator for stddev and variance --- native/core/src/execution/datafusion/expressions/stddev.rs | 7 ++++--- .../core/src/execution/datafusion/expressions/variance.rs | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/stddev.rs b/native/core/src/execution/datafusion/expressions/stddev.rs index 0cb668f3a..6cfd509de 100644 --- a/native/core/src/execution/datafusion/expressions/stddev.rs +++ b/native/core/src/execution/datafusion/expressions/stddev.rs @@ -22,11 +22,12 @@ use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; +use datafusion::functions_aggregate::stddev::StddevGroupsAccumulator; use datafusion::logical_expr::Accumulator; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{AggregateUDFImpl, GroupsAccumulator, Signature, Volatility}; use datafusion_physical_expr::expressions::StatsType; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; @@ -90,7 +91,6 @@ impl AggregateUDFImpl for Stddev { )?)) } - /* fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { !acc_args.is_distinct } @@ -99,9 +99,10 @@ impl AggregateUDFImpl for Stddev { &self, _args: AccumulatorArgs, ) -> Result> { + // TODO is it safe to use DataFusion's version of StddevGroupsAccumulator + // which uses u64 for counts or do we need to fork this and use f64? Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample))) } - */ fn create_sliding_accumulator( &self, diff --git a/native/core/src/execution/datafusion/expressions/variance.rs b/native/core/src/execution/datafusion/expressions/variance.rs index 3ad776692..523fae21e 100644 --- a/native/core/src/execution/datafusion/expressions/variance.rs +++ b/native/core/src/execution/datafusion/expressions/variance.rs @@ -21,12 +21,13 @@ use arrow::{ array::{ArrayRef, Float64Array}, datatypes::{DataType, Field}, }; +use datafusion::functions_aggregate::variance::VarianceGroupsAccumulator; use datafusion::logical_expr::Accumulator; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::Volatility::Immutable; -use datafusion_expr::{AggregateUDFImpl, Signature}; +use datafusion_expr::{AggregateUDFImpl, GroupsAccumulator, Signature}; use datafusion_physical_expr::expressions::StatsType; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; @@ -90,7 +91,6 @@ impl AggregateUDFImpl for Variance { )?)) } - /* fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { !acc_args.is_distinct } @@ -99,11 +99,12 @@ impl AggregateUDFImpl for Variance { &self, _args: AccumulatorArgs, ) -> Result> { + // it is safe to use DataFusion's implementation of VarianceGroupsAccumulator + // because it already uses f64 for count Ok(Box::new(VarianceGroupsAccumulator::new( StatsType::Population, ))) } - */ fn create_sliding_accumulator(&self, _args: AccumulatorArgs) -> Result> { Ok(Box::new(VarianceAccumulator::try_new( From 0625ad5601193b8745fd3574280ea3c418c118d6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 11:24:54 -0600 Subject: [PATCH 17/19] refactor --- .../core/src/execution/datafusion/planner.rs | 85 ++++++++----------- 1 file changed, 35 insertions(+), 50 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index d8d4bd349..f336b7961 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -105,7 +105,9 @@ use datafusion_common::{ JoinType as DFJoinType, ScalarValue, }; use datafusion_expr::expr::find_df_window_func; -use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; +use datafusion_expr::{ + AggregateUDF, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, +}; use datafusion_physical_expr::expressions::{Literal, StatsType}; use datafusion_physical_expr::window::WindowExpr; use itertools::Itertools; @@ -1491,13 +1493,12 @@ impl PhysicalPlanner { expr.null_on_divide_by_zero, )); - AggregateExprBuilder::new(Arc::new(func), vec![child1, child2]) - .schema(schema) - .alias("covariance") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| e.into()) + Self::create_aggr_func_expr( + "covariance", + schema, + vec![child1, child2], + func, + ) } 1 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Covariance::new( @@ -1509,13 +1510,12 @@ impl PhysicalPlanner { expr.null_on_divide_by_zero, )); - AggregateExprBuilder::new(Arc::new(func), vec![child1, child2]) - .schema(schema) - .alias("covariance_pop") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| e.into()) + Self::create_aggr_func_expr( + "covariance_pop", + schema, + vec![child1, child2], + func, + ) } stats_type => Err(ExecutionError::GeneralError(format!( "Unknown StatisticsType {:?} for Variance", @@ -1536,13 +1536,7 @@ impl PhysicalPlanner { expr.null_on_divide_by_zero, )); - AggregateExprBuilder::new(Arc::new(func), vec![child]) - .schema(schema) - .alias("variance") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| e.into()) + Self::create_aggr_func_expr("variance", schema, vec![child], func) } 1 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Variance::new( @@ -1553,13 +1547,7 @@ impl PhysicalPlanner { expr.null_on_divide_by_zero, )); - AggregateExprBuilder::new(Arc::new(func), vec![child]) - .schema(schema) - .alias("variance_pop") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| e.into()) + Self::create_aggr_func_expr("variance_pop", schema, vec![child], func) } stats_type => Err(ExecutionError::GeneralError(format!( "Unknown StatisticsType {:?} for Variance", @@ -1580,13 +1568,7 @@ impl PhysicalPlanner { expr.null_on_divide_by_zero, )); - AggregateExprBuilder::new(Arc::new(func), vec![child]) - .schema(schema) - .alias("stddev") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| e.into()) + Self::create_aggr_func_expr("stddev", schema, vec![child], func) } 1 => { let func = datafusion_expr::AggregateUDF::new_from_impl(Stddev::new( @@ -1597,13 +1579,7 @@ impl PhysicalPlanner { expr.null_on_divide_by_zero, )); - AggregateExprBuilder::new(Arc::new(func), vec![child]) - .schema(schema) - .alias("stddev_pop") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| e.into()) + Self::create_aggr_func_expr("stddev_pop", schema, vec![child], func) } stats_type => Err(ExecutionError::GeneralError(format!( "Unknown StatisticsType {:?} for stddev", @@ -1624,13 +1600,7 @@ impl PhysicalPlanner { datatype, expr.null_on_divide_by_zero, )); - AggregateExprBuilder::new(Arc::new(func), vec![child1, child2]) - .schema(schema) - .alias("correlation") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| e.into()) + Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func) } } } @@ -1891,6 +1861,21 @@ impl PhysicalPlanner { Ok(scalar_expr) } + + fn create_aggr_func_expr( + name: &str, + schema: SchemaRef, + children: Vec>, + func: AggregateUDF, + ) -> Result, ExecutionError> { + AggregateExprBuilder::new(Arc::new(func), children) + .schema(schema) + .alias(name) + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) + } } impl From for ExecutionError { From 23fc1c31642809477b6621048e4a98d52179fca5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 11:25:36 -0600 Subject: [PATCH 18/19] fmt --- .../core/src/execution/datafusion/planner.rs | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index f336b7961..be27495ab 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1349,7 +1349,7 @@ impl PhysicalPlanner { match datatype { DataType::Decimal128(_, _) => { - let func = datafusion_expr::AggregateUDF::new_from_impl(SumDecimal::new( + let func = AggregateUDF::new_from_impl(SumDecimal::new( "sum", Arc::clone(&child), datatype, @@ -1384,7 +1384,7 @@ impl PhysicalPlanner { let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap()); match datatype { DataType::Decimal128(_, _) => { - let func = datafusion_expr::AggregateUDF::new_from_impl(AvgDecimal::new( + let func = AggregateUDF::new_from_impl(AvgDecimal::new( Arc::clone(&child), "avg", datatype, @@ -1404,7 +1404,7 @@ impl PhysicalPlanner { // failure since it should have already been checked at Spark side. let child: Arc = Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None)); - let func = datafusion_expr::AggregateUDF::new_from_impl(Avg::new( + let func = AggregateUDF::new_from_impl(Avg::new( Arc::clone(&child), "avg", datatype, @@ -1421,7 +1421,7 @@ impl PhysicalPlanner { } AggExprStruct::First(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; - let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new()); + let func = AggregateUDF::new_from_impl(FirstValue::new()); AggregateExprBuilder::new(Arc::new(func), vec![child]) .schema(schema) @@ -1433,7 +1433,7 @@ impl PhysicalPlanner { } AggExprStruct::Last(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; - let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new()); + let func = AggregateUDF::new_from_impl(LastValue::new()); AggregateExprBuilder::new(Arc::new(func), vec![child]) .schema(schema) @@ -1484,7 +1484,7 @@ impl PhysicalPlanner { let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); match expr.stats_type { 0 => { - let func = datafusion_expr::AggregateUDF::new_from_impl(Covariance::new( + let func = AggregateUDF::new_from_impl(Covariance::new( Arc::clone(&child1), Arc::clone(&child2), "covariance", @@ -1501,7 +1501,7 @@ impl PhysicalPlanner { ) } 1 => { - let func = datafusion_expr::AggregateUDF::new_from_impl(Covariance::new( + let func = AggregateUDF::new_from_impl(Covariance::new( Arc::clone(&child1), Arc::clone(&child2), "covariance_pop", @@ -1528,7 +1528,7 @@ impl PhysicalPlanner { let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); match expr.stats_type { 0 => { - let func = datafusion_expr::AggregateUDF::new_from_impl(Variance::new( + let func = AggregateUDF::new_from_impl(Variance::new( Arc::clone(&child), "variance", datatype, @@ -1539,7 +1539,7 @@ impl PhysicalPlanner { Self::create_aggr_func_expr("variance", schema, vec![child], func) } 1 => { - let func = datafusion_expr::AggregateUDF::new_from_impl(Variance::new( + let func = AggregateUDF::new_from_impl(Variance::new( Arc::clone(&child), "variance_pop", datatype, @@ -1560,7 +1560,7 @@ impl PhysicalPlanner { let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); match expr.stats_type { 0 => { - let func = datafusion_expr::AggregateUDF::new_from_impl(Stddev::new( + let func = AggregateUDF::new_from_impl(Stddev::new( Arc::clone(&child), "stddev", datatype, @@ -1571,7 +1571,7 @@ impl PhysicalPlanner { Self::create_aggr_func_expr("stddev", schema, vec![child], func) } 1 => { - let func = datafusion_expr::AggregateUDF::new_from_impl(Stddev::new( + let func = AggregateUDF::new_from_impl(Stddev::new( Arc::clone(&child), "stddev_pop", datatype, @@ -1593,7 +1593,7 @@ impl PhysicalPlanner { let child2 = self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - let func = datafusion_expr::AggregateUDF::new_from_impl(Correlation::new( + let func = AggregateUDF::new_from_impl(Correlation::new( Arc::clone(&child1), Arc::clone(&child2), "correlation", From 11e0938543c51d96fb6b445b83685d8ca11226ca Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Sep 2024 12:55:15 -0600 Subject: [PATCH 19/19] revert group accumulator --- .../execution/datafusion/expressions/stddev.rs | 16 +--------------- .../datafusion/expressions/variance.rs | 18 +----------------- 2 files changed, 2 insertions(+), 32 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/stddev.rs b/native/core/src/execution/datafusion/expressions/stddev.rs index 6cfd509de..1ba495e21 100644 --- a/native/core/src/execution/datafusion/expressions/stddev.rs +++ b/native/core/src/execution/datafusion/expressions/stddev.rs @@ -22,12 +22,11 @@ use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; -use datafusion::functions_aggregate::stddev::StddevGroupsAccumulator; use datafusion::logical_expr::Accumulator; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::{AggregateUDFImpl, GroupsAccumulator, Signature, Volatility}; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; use datafusion_physical_expr::expressions::StatsType; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; @@ -91,19 +90,6 @@ impl AggregateUDFImpl for Stddev { )?)) } - fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { - !acc_args.is_distinct - } - - fn create_groups_accumulator( - &self, - _args: AccumulatorArgs, - ) -> Result> { - // TODO is it safe to use DataFusion's version of StddevGroupsAccumulator - // which uses u64 for counts or do we need to fork this and use f64? - Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample))) - } - fn create_sliding_accumulator( &self, _acc_args: AccumulatorArgs, diff --git a/native/core/src/execution/datafusion/expressions/variance.rs b/native/core/src/execution/datafusion/expressions/variance.rs index 523fae21e..2f4d8091c 100644 --- a/native/core/src/execution/datafusion/expressions/variance.rs +++ b/native/core/src/execution/datafusion/expressions/variance.rs @@ -21,13 +21,12 @@ use arrow::{ array::{ArrayRef, Float64Array}, datatypes::{DataType, Field}, }; -use datafusion::functions_aggregate::variance::VarianceGroupsAccumulator; use datafusion::logical_expr::Accumulator; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::Volatility::Immutable; -use datafusion_expr::{AggregateUDFImpl, GroupsAccumulator, Signature}; +use datafusion_expr::{AggregateUDFImpl, Signature}; use datafusion_physical_expr::expressions::StatsType; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; @@ -91,21 +90,6 @@ impl AggregateUDFImpl for Variance { )?)) } - fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { - !acc_args.is_distinct - } - - fn create_groups_accumulator( - &self, - _args: AccumulatorArgs, - ) -> Result> { - // it is safe to use DataFusion's implementation of VarianceGroupsAccumulator - // because it already uses f64 for count - Ok(Box::new(VarianceGroupsAccumulator::new( - StatsType::Population, - ))) - } - fn create_sliding_accumulator(&self, _args: AccumulatorArgs) -> Result> { Ok(Box::new(VarianceAccumulator::try_new( self.stats_type,