diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c023faa9b168..a743d0e8fd07 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -41,7 +41,7 @@ on: jobs: # Check license header license-header-check: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest name: Check License Header steps: - uses: actions/checkout@v4 diff --git a/Cargo.lock b/Cargo.lock index 0a7407b50398..3d95cb688382 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1733,9 +1733,11 @@ dependencies = [ "datafusion-catalog-listing", "datafusion-common", "datafusion-common-runtime", + "datafusion-datasource", "datafusion-doc", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-nested", @@ -1823,32 +1825,20 @@ name = "datafusion-catalog-listing" version = "45.0.0" dependencies = [ "arrow", - "async-compression", "async-trait", - "bytes", - "bzip2 0.5.1", - "chrono", "datafusion-catalog", "datafusion-common", - "datafusion-common-runtime", + "datafusion-datasource", "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", - "flate2", "futures", - "glob", - "itertools 0.14.0", "log", "object_store", - "rand 0.8.5", "tempfile", "tokio", - "tokio-util", - "url", - "xz2", - "zstd", ] [[package]] @@ -1912,6 +1902,37 @@ dependencies = [ "tokio", ] +[[package]] +name = "datafusion-datasource" +version = "45.0.0" +dependencies = [ + "arrow", + "async-compression", + "async-trait", + "bytes", + "bzip2 0.5.1", + "chrono", + "datafusion-catalog", + "datafusion-common", + "datafusion-common-runtime", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-plan", + "flate2", + "futures", + "glob", + "itertools 0.14.0", + "log", + "object_store", + "rand 0.8.5", + "tempfile", + "tokio", + "tokio-util", + "url", + "xz2", + "zstd", +] + [[package]] name = "datafusion-doc" version = "45.0.0" @@ -2551,7 +2572,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -3448,7 +3469,7 @@ checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -4617,7 +4638,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -5039,7 +5060,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -5602,7 +5623,7 @@ dependencies = [ "cfg-if", "libc", "psm", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -5828,16 +5849,16 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" +checksum = "a40f762a77d2afa88c2d919489e390a12bdd261ed568e60cfa7e48d4e20f0d33" dependencies = [ "cfg-if", "fastrand", "getrandom 0.3.1", "once_cell", "rustix", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -6712,7 +6733,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.48.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 1e35b7f42027..099e5f22972c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,9 +99,10 @@ ctor = "0.2.9" dashmap = "6.0.1" datafusion = { path = "datafusion/core", version = "45.0.0", default-features = false } datafusion-catalog = { path = "datafusion/catalog", version = "45.0.0" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "45.0.0", default-features = false } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "45.0.0" } datafusion-common = { path = "datafusion/common", version = "45.0.0", default-features = false } datafusion-common-runtime = { path = "datafusion/common-runtime", version = "45.0.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "45.0.0", default-features = false } datafusion-doc = { path = "datafusion/doc", version = "45.0.0" } datafusion-execution = { path = "datafusion/execution", version = "45.0.0" } datafusion-expr = { path = "datafusion/expr", version = "45.0.0" } diff --git a/README.md b/README.md index 2c2febab09cc..158033d40599 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ![Commit Activity][commit-activity-badge] [![Open Issues][open-issues-badge]][open-issues-url] [![Discord chat][discord-badge]][discord-url] +[![Linkedin][linkedin-badge]][linkedin-url] [crates-badge]: https://img.shields.io/crates/v/datafusion.svg [crates-url]: https://crates.io/crates/datafusion @@ -32,11 +33,13 @@ [license-url]: https://github.com/apache/datafusion/blob/main/LICENSE.txt [actions-badge]: https://github.com/apache/datafusion/actions/workflows/rust.yml/badge.svg [actions-url]: https://github.com/apache/datafusion/actions?query=branch%3Amain -[discord-badge]: https://img.shields.io/discord/885562378132000778.svg?logo=discord&style=flat-square +[discord-badge]: https://img.shields.io/badge/Chat-Discord-purple [discord-url]: https://discord.com/invite/Qw5gKqHxUM [commit-activity-badge]: https://img.shields.io/github/commit-activity/m/apache/datafusion [open-issues-badge]: https://img.shields.io/github/issues-raw/apache/datafusion [open-issues-url]: https://github.com/apache/datafusion/issues +[linkedin-badge]: https://img.shields.io/badge/Follow-Linkedin-blue +[linkedin-url]: https://www.linkedin.com/company/apache-datafusion/ [Website](https://datafusion.apache.org/) | [API Docs](https://docs.rs/datafusion/latest/datafusion/) | diff --git a/benchmarks/lineprotocol.py b/benchmarks/lineprotocol.py new file mode 100644 index 000000000000..75e09b662e3e --- /dev/null +++ b/benchmarks/lineprotocol.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +Converts a given json to LineProtocol format that can be +visualised by grafana/other systems that support LineProtocol. + +Usage example: +$ python3 lineprotocol.py sort.json +benchmark,name=sort,version=28.0.0,datafusion_version=28.0.0,num_cpus=8 query="sort utf8",iteration=0,row_count=10838832,elapsed_ms=85626006 1691105678000000000 +benchmark,name=sort,version=28.0.0,datafusion_version=28.0.0,num_cpus=8 query="sort utf8",iteration=1,row_count=10838832,elapsed_ms=68694468 1691105678000000000 +benchmark,name=sort,version=28.0.0,datafusion_version=28.0.0,num_cpus=8 query="sort utf8",iteration=2,row_count=10838832,elapsed_ms=63392883 1691105678000000000 +benchmark,name=sort,version=28.0.0,datafusion_version=28.0.0,num_cpus=8 query="sort utf8",iteration=3,row_count=10838832,elapsed_ms=66388367 1691105678000000000 +""" + +# sort.json +""" +{ + "queries": [ + { + "iterations": [ + { + "elapsed": 85626.006132, + "row_count": 10838832 + }, + { + "elapsed": 68694.467851, + "row_count": 10838832 + }, + { + "elapsed": 63392.883406, + "row_count": 10838832 + }, + { + "elapsed": 66388.367387, + "row_count": 10838832 + }, + ], + "query": "sort utf8", + "start_time": 1691105678 + }, + ], + "context": { + "arguments": [ + "sort", + "--path", + "benchmarks/data", + "--scale-factor", + "1.0", + "--iterations", + "4", + "-o", + "sort.json" + ], + "benchmark_version": "28.0.0", + "datafusion_version": "28.0.0", + "num_cpus": 8, + "start_time": 1691105678 + } +} +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Dict, List, Any +from pathlib import Path +from argparse import ArgumentParser +import sys +print = sys.stdout.write + + +@dataclass +class QueryResult: + elapsed: float + row_count: int + + @classmethod + def load_from(cls, data: Dict[str, Any]) -> QueryResult: + return cls(elapsed=data["elapsed"], row_count=data["row_count"]) + + +@dataclass +class QueryRun: + query: int + iterations: List[QueryResult] + start_time: int + + @classmethod + def load_from(cls, data: Dict[str, Any]) -> QueryRun: + return cls( + query=data["query"], + iterations=[QueryResult(**iteration) for iteration in data["iterations"]], + start_time=data["start_time"], + ) + + @property + def execution_time(self) -> float: + assert len(self.iterations) >= 1 + + # Use minimum execution time to account for variations / other + # things the system was doing + return min(iteration.elapsed for iteration in self.iterations) + + +@dataclass +class Context: + benchmark_version: str + datafusion_version: str + num_cpus: int + start_time: int + arguments: List[str] + name: str + + @classmethod + def load_from(cls, data: Dict[str, Any]) -> Context: + return cls( + benchmark_version=data["benchmark_version"], + datafusion_version=data["datafusion_version"], + num_cpus=data["num_cpus"], + start_time=data["start_time"], + arguments=data["arguments"], + name=data["arguments"][0] + ) + + +@dataclass +class BenchmarkRun: + context: Context + queries: List[QueryRun] + + @classmethod + def load_from(cls, data: Dict[str, Any]) -> BenchmarkRun: + return cls( + context=Context.load_from(data["context"]), + queries=[QueryRun.load_from(result) for result in data["queries"]], + ) + + @classmethod + def load_from_file(cls, path: Path) -> BenchmarkRun: + with open(path, "r") as f: + return cls.load_from(json.load(f)) + + +def lineformat( + baseline: Path, +) -> None: + baseline = 
BenchmarkRun.load_from_file(baseline) + context = baseline.context + benchmark_str = f"benchmark,name={context.name},version={context.benchmark_version},datafusion_version={context.datafusion_version},num_cpus={context.num_cpus}" + for query in baseline.queries: + query_str = f"query=\"{query.query}\"" + timestamp = f"{query.start_time*10**9}" + for iter_num, result in enumerate(query.iterations): + print(f"{benchmark_str} {query_str},iteration={iter_num},row_count={result.row_count},elapsed_ms={result.elapsed*1000:.0f} {timestamp}\n") + +def main() -> None: + parser = ArgumentParser() + parser.add_argument( + "path", + type=Path, + help="Path to the benchmark file.", + ) + options = parser.parse_args() + + lineformat(options.path) + + + +if __name__ == "__main__": + main() diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs index a2fb75dd1941..578f71f8275d 100644 --- a/benchmarks/src/bin/external_aggr.rs +++ b/benchmarks/src/bin/external_aggr.rs @@ -17,6 +17,8 @@ //! external_aggr binary entrypoint +use datafusion::execution::memory_pool::GreedyMemoryPool; +use datafusion::execution::memory_pool::MemoryPool; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -41,7 +43,7 @@ use datafusion::prelude::*; use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt}; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; -use datafusion_common::{exec_datafusion_err, exec_err, DEFAULT_PARQUET_EXTENSION}; +use datafusion_common::{exec_err, DEFAULT_PARQUET_EXTENSION}; #[derive(Debug, StructOpt)] #[structopt( @@ -58,10 +60,6 @@ struct ExternalAggrConfig { #[structopt(short, long)] query: Option<usize>, - /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query.
- #[structopt(long)] - memory_limit: Option<String>, - /// Common options #[structopt(flatten)] common: CommonOpt, @@ -129,10 +127,8 @@ impl ExternalAggrConfig { pub async fn run(&self) -> Result<()> { let mut benchmark_run = BenchmarkRun::new(); - let memory_limit = match &self.memory_limit { - Some(limit) => Some(Self::parse_memory_limit(limit)?), - None => None, - }; + let memory_limit = self.common.memory_limit.map(|limit| limit as u64); + let mem_pool_type = self.common.mem_pool_type.as_str(); let query_range = match self.query { Some(query_id) => query_id..=query_id, @@ -171,7 +167,9 @@ impl ExternalAggrConfig { human_readable_size(mem_limit as usize) )); - let query_results = self.benchmark_query(query_id, mem_limit).await?; + let query_results = self + .benchmark_query(query_id, mem_limit, mem_pool_type) + .await?; for iter in query_results { benchmark_run.write_iter(iter.elapsed, iter.row_count); } @@ -187,12 +185,20 @@ impl ExternalAggrConfig { &self, query_id: usize, mem_limit: u64, + mem_pool_type: &str, ) -> Result<Vec<QueryResult>> { let query_name = format!("Q{query_id}({})", human_readable_size(mem_limit as usize)); let config = self.common.config(); + let memory_pool: Arc<dyn MemoryPool> = match mem_pool_type { + "fair" => Arc::new(FairSpillPool::new(mem_limit as usize)), + "greedy" => Arc::new(GreedyMemoryPool::new(mem_limit as usize)), + _ => { + return exec_err!("Invalid memory pool type: {}", mem_pool_type); + } + }; let runtime_env = RuntimeEnvBuilder::new() - .with_memory_pool(Arc::new(FairSpillPool::new(mem_limit as usize))) + .with_memory_pool(memory_pool) .build_arc()?; let state = SessionStateBuilder::new() .with_config(config) @@ -331,22 +337,6 @@ impl ExternalAggrConfig { .partitions .unwrap_or(get_available_parallelism()) } - - /// Parse memory limit from string to number of bytes - /// e.g.
'1.5G', '100M' -> 1572864 - fn parse_memory_limit(limit: &str) -> Result { - let (number, unit) = limit.split_at(limit.len() - 1); - let number: f64 = number.parse().map_err(|_| { - exec_datafusion_err!("Failed to parse number from memory limit '{}'", limit) - })?; - - match unit { - "K" => Ok((number * 1024.0) as u64), - "M" => Ok((number * 1024.0 * 1024.0) as u64), - "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as u64), - _ => exec_err!("Unsupported unit '{}' in memory limit '{}'", unit, limit), - } - } } #[tokio::main] @@ -359,31 +349,3 @@ pub async fn main() -> Result<()> { Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_memory_limit_all() { - // Test valid inputs - assert_eq!( - ExternalAggrConfig::parse_memory_limit("100K").unwrap(), - 102400 - ); - assert_eq!( - ExternalAggrConfig::parse_memory_limit("1.5M").unwrap(), - 1572864 - ); - assert_eq!( - ExternalAggrConfig::parse_memory_limit("2G").unwrap(), - 2147483648 - ); - - // Test invalid unit - assert!(ExternalAggrConfig::parse_memory_limit("500X").is_err()); - - // Test invalid number - assert!(ExternalAggrConfig::parse_memory_limit("abcM").is_err()); - } -} diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 6b7c75ed4bab..a9750d9b4b84 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -124,7 +124,8 @@ impl RunOpt { parquet_options.binary_as_string = true; } - let ctx = SessionContext::new_with_config(config); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); self.register_hits(&ctx).await?; let iterations = self.common.iterations; diff --git a/benchmarks/src/h2o.rs b/benchmarks/src/h2o.rs index 53a516ceb56d..eae7f67f1d62 100644 --- a/benchmarks/src/h2o.rs +++ b/benchmarks/src/h2o.rs @@ -68,7 +68,8 @@ impl RunOpt { }; let config = self.common.config(); - let ctx = SessionContext::new_with_config(config); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); // Register data self.register_data(&ctx).await?; diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index 8d2317c62ef1..d7d7a56d0540 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -306,8 +306,8 @@ impl RunOpt { .config() .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; - - let ctx = SessionContext::new_with_config(config); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); // register tables self.register_tables(&ctx).await?; @@ -515,6 +515,9 @@ mod tests { iterations: 1, partitions: Some(2), batch_size: 8192, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, debug: false, }; let opt = RunOpt { @@ -548,6 +551,9 @@ mod tests { iterations: 1, partitions: Some(2), batch_size: 8192, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, debug: false, }; let opt = RunOpt { diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs index 566a5ea62c2d..b1997b40e09e 100644 --- a/benchmarks/src/sort_tpch.rs +++ b/benchmarks/src/sort_tpch.rs @@ -188,8 +188,10 @@ impl RunOpt { /// Benchmark query `query_id` in `SORT_QUERIES` async fn benchmark_query(&self, query_id: usize) -> Result> { let config = self.common.config(); + let 
rt_builder = self.common.runtime_env_builder()?; let state = SessionStateBuilder::new() .with_config(config) + .with_runtime_env(rt_builder.build_arc()?) .with_default_features() .build(); let ctx = SessionContext::from(state); diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index de3ee3d67db2..eb9db821db02 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -121,7 +121,8 @@ impl RunOpt { .config() .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; - let ctx = SessionContext::new_with_config(config); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); // register tables self.register_tables(&ctx).await?; @@ -342,6 +343,9 @@ mod tests { iterations: 1, partitions: Some(2), batch_size: 8192, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, debug: false, }; let opt = RunOpt { @@ -375,6 +379,9 @@ mod tests { iterations: 1, partitions: Some(2), batch_size: 8192, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, debug: false, }; let opt = RunOpt { diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index b1570a1d1bc1..a1cf31525dd9 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -15,8 +15,17 @@ // specific language governing permissions and limitations // under the License. -use datafusion::prelude::SessionConfig; -use datafusion_common::utils::get_available_parallelism; +use std::{num::NonZeroUsize, sync::Arc}; + +use datafusion::{ + execution::{ + disk_manager::DiskManagerConfig, + memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool}, + runtime_env::RuntimeEnvBuilder, + }, + prelude::SessionConfig, +}; +use datafusion_common::{utils::get_available_parallelism, DataFusionError, Result}; use structopt::StructOpt; // Common benchmark options (don't use doc comments otherwise this doc @@ -35,6 +44,20 @@ pub struct CommonOpt { #[structopt(short = "s", long = "batch-size", default_value = "8192")] pub batch_size: usize, + /// The memory pool type to use, should be one of "fair" or "greedy" + #[structopt(long = "mem-pool-type", default_value = "fair")] + pub mem_pool_type: String, + + /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query + /// if there's any, otherwise run with no memory limit. + #[structopt(long = "memory-limit", parse(try_from_str = parse_memory_limit))] + pub memory_limit: Option, + + /// The amount of memory to reserve for sort spill operations. DataFusion's default value will be used + /// if not specified. 
+ #[structopt(long = "sort-spill-reservation-bytes", parse(try_from_str = parse_memory_limit))] + pub sort_spill_reservation_bytes: Option, + /// Activate debug mode to see more details #[structopt(short, long)] pub debug: bool, @@ -48,10 +71,81 @@ impl CommonOpt { /// Modify the existing config appropriately pub fn update_config(&self, config: SessionConfig) -> SessionConfig { - config + let mut config = config .with_target_partitions( self.partitions.unwrap_or(get_available_parallelism()), ) - .with_batch_size(self.batch_size) + .with_batch_size(self.batch_size); + if let Some(sort_spill_reservation_bytes) = self.sort_spill_reservation_bytes { + config = + config.with_sort_spill_reservation_bytes(sort_spill_reservation_bytes); + } + config + } + + /// Return an appropriately configured `RuntimeEnvBuilder` + pub fn runtime_env_builder(&self) -> Result { + let mut rt_builder = RuntimeEnvBuilder::new(); + const NUM_TRACKED_CONSUMERS: usize = 5; + if let Some(memory_limit) = self.memory_limit { + let pool: Arc = match self.mem_pool_type.as_str() { + "fair" => Arc::new(TrackConsumersPool::new( + FairSpillPool::new(memory_limit), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )), + "greedy" => Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(memory_limit), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )), + _ => { + return Err(DataFusionError::Configuration(format!( + "Invalid memory pool type: {}", + self.mem_pool_type + ))) + } + }; + rt_builder = rt_builder + .with_memory_pool(pool) + .with_disk_manager(DiskManagerConfig::NewOs); + } + Ok(rt_builder) + } +} + +/// Parse memory limit from string to number of bytes +/// e.g. '1.5G', '100M' -> 1572864 +fn parse_memory_limit(limit: &str) -> Result { + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number + .parse() + .map_err(|_| format!("Failed to parse number from memory limit '{}'", limit))?; + + match unit { + "K" => Ok((number * 1024.0) as usize), + "M" => Ok((number * 1024.0 * 1024.0) as usize), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), + _ => Err(format!( + "Unsupported unit '{}' in memory limit '{}'", + unit, limit + )), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_memory_limit_all() { + // Test valid inputs + assert_eq!(parse_memory_limit("100K").unwrap(), 102400); + assert_eq!(parse_memory_limit("1.5M").unwrap(), 1572864); + assert_eq!(parse_memory_limit("2G").unwrap(), 2147483648); + + // Test invalid unit + assert!(parse_memory_limit("500X").is_err()); + + // Test invalid number + assert!(parse_memory_limit("abcM").is_err()); } } diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/advanced_parquet_index.rs index 43dc592b997e..bb1cf3c8f78d 100644 --- a/datafusion-examples/examples/advanced_parquet_index.rs +++ b/datafusion-examples/examples/advanced_parquet_index.rs @@ -504,7 +504,7 @@ impl TableProvider for IndexTableProvider { .with_file(partitioned_file); // Finally, put it all together into a DataSourceExec - Ok(file_scan_config.new_exec()) + Ok(file_scan_config.build()) } /// Tell DataFusion to push filters down to the scan method diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index 0206c7cd157e..63f17484809e 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -20,8 +20,8 @@ use arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, 
TreeNode}; use datafusion::common::{assert_batches_eq, Result, ScalarValue}; use datafusion::logical_expr::{ - BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, + BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion::optimizer::ApplyOrder; use datafusion::optimizer::{OptimizerConfig, OptimizerRule}; @@ -205,11 +205,7 @@ impl ScalarUDFImpl for MyEq { Ok(DataType::Boolean) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { // this example simply returns "true" which is not what a real // implementation would do. Ok(ColumnarValue::Scalar(ScalarValue::from(true))) diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/parquet_index.rs index f465699abed2..3851dca2a775 100644 --- a/datafusion-examples/examples/parquet_index.rs +++ b/datafusion-examples/examples/parquet_index.rs @@ -258,7 +258,7 @@ impl TableProvider for IndexTableProvider { file_size, )); } - Ok(file_scan_config.new_exec()) + Ok(file_scan_config.build()) } /// Tell DataFusion to push filters down to the scan method diff --git a/datafusion/catalog-listing/Cargo.toml b/datafusion/catalog-listing/Cargo.toml index 0aa2083ebca9..68d0ca3a149f 100644 --- a/datafusion/catalog-listing/Cargo.toml +++ b/datafusion/catalog-listing/Cargo.toml @@ -27,43 +27,21 @@ repository.workspace = true rust-version.workspace = true version.workspace = true -[features] -compression = ["async-compression", "xz2", "bzip2", "flate2", "zstd", "tokio-util"] -default = ["compression"] - [dependencies] arrow = { workspace = true } -async-compression = { version = "0.4.0", features = [ - "bzip2", - "gzip", - "xz", - "zstd", - "tokio", -], optional = true } async-trait = { workspace = true } -bytes = { workspace = true } -bzip2 = { version = "0.5.1", optional = true } -chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } -datafusion-common-runtime = { workspace = true } +datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } -flate2 = { version = "1.0.24", optional = true } futures = { workspace = true } -glob = "0.3.0" -itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } -rand = { workspace = true } tokio = { workspace = true } -tokio-util = { version = "0.7.4", features = ["io"], optional = true } -url = { workspace = true } -xz2 = { version = "0.1", optional = true, features = ["static"] } -zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] tempfile = { workspace = true } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index ceacde2494e2..cf475263535a 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -20,11 +20,11 @@ use std::mem; use std::sync::Arc; -use super::ListingTableUrl; -use super::PartitionedFile; use datafusion_catalog::Session; use datafusion_common::internal_err; use datafusion_common::{HashMap, Result, ScalarValue}; +use datafusion_datasource::ListingTableUrl; 
+use datafusion_datasource::PartitionedFile; use datafusion_expr::{BinaryExpr, Operator}; use arrow::{ diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index 9eb79ec07ac8..b98790e86455 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -15,270 +15,4 @@ // specific language governing permissions and limitations // under the License. -//! A table that uses the `ObjectStore` listing capability -//! to get the list of files to process. - -pub mod file_compression_type; -pub mod file_groups; -pub mod file_meta; -pub mod file_scan_config; -pub mod file_sink_config; -pub mod file_stream; pub mod helpers; -pub mod url; -pub mod write; -use chrono::TimeZone; -use datafusion_common::Result; -use datafusion_common::{ScalarValue, Statistics}; -use futures::Stream; -use object_store::{path::Path, ObjectMeta}; -use std::pin::Pin; -use std::sync::Arc; - -pub use self::url::ListingTableUrl; - -/// Stream of files get listed from object store -pub type PartitionedFileStream = - Pin> + Send + Sync + 'static>>; - -/// Only scan a subset of Row Groups from the Parquet file whose data "midpoint" -/// lies within the [start, end) byte offsets. This option can be used to scan non-overlapping -/// sections of a Parquet file in parallel. -#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] -pub struct FileRange { - /// Range start - pub start: i64, - /// Range end - pub end: i64, -} - -impl FileRange { - /// returns true if this file range contains the specified offset - pub fn contains(&self, offset: i64) -> bool { - offset >= self.start && offset < self.end - } -} - -#[derive(Debug, Clone)] -/// A single file or part of a file that should be read, along with its schema, statistics -/// and partition column values that need to be appended to each row. -pub struct PartitionedFile { - /// Path for the file (e.g. URL, filesystem path, etc) - pub object_meta: ObjectMeta, - /// Values of partition columns to be appended to each row. - /// - /// These MUST have the same count, order, and type than the [`table_partition_cols`]. - /// - /// You may use [`wrap_partition_value_in_dict`] to wrap them if you have used [`wrap_partition_type_in_dict`] to wrap the column type. - /// - /// - /// [`wrap_partition_type_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L55 - /// [`wrap_partition_value_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L62 - /// [`table_partition_cols`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/file_format/options.rs#L190 - pub partition_values: Vec, - /// An optional file range for a more fine-grained parallel execution - pub range: Option, - /// Optional statistics that describe the data in this file if known. - /// - /// DataFusion relies on these statistics for planning (in particular to sort file groups), - /// so if they are incorrect, incorrect answers may result. 
- pub statistics: Option, - /// An optional field for user defined per object metadata - pub extensions: Option>, - /// The estimated size of the parquet metadata, in bytes - pub metadata_size_hint: Option, -} - -impl PartitionedFile { - /// Create a simple file without metadata or partition - pub fn new(path: impl Into, size: u64) -> Self { - Self { - object_meta: ObjectMeta { - location: Path::from(path.into()), - last_modified: chrono::Utc.timestamp_nanos(0), - size: size as usize, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - } - } - - /// Create a file range without metadata or partition - pub fn new_with_range(path: String, size: u64, start: i64, end: i64) -> Self { - Self { - object_meta: ObjectMeta { - location: Path::from(path), - last_modified: chrono::Utc.timestamp_nanos(0), - size: size as usize, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: Some(FileRange { start, end }), - statistics: None, - extensions: None, - metadata_size_hint: None, - } - .with_range(start, end) - } - - /// Provide a hint to the size of the file metadata. If a hint is provided - /// the reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. - /// Without an appropriate hint, two read may be required to fetch the metadata. - pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { - self.metadata_size_hint = Some(metadata_size_hint); - self - } - - /// Return a file reference from the given path - pub fn from_path(path: String) -> Result { - let size = std::fs::metadata(path.clone())?.len(); - Ok(Self::new(path, size)) - } - - /// Return the path of this partitioned file - pub fn path(&self) -> &Path { - &self.object_meta.location - } - - /// Update the file to only scan the specified range (in bytes) - pub fn with_range(mut self, start: i64, end: i64) -> Self { - self.range = Some(FileRange { start, end }); - self - } - - /// Update the user defined extensions for this file. - /// - /// This can be used to pass reader specific information. 
- pub fn with_extensions( - mut self, - extensions: Arc, - ) -> Self { - self.extensions = Some(extensions); - self - } -} - -impl From for PartitionedFile { - fn from(object_meta: ObjectMeta) -> Self { - PartitionedFile { - object_meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - } - } -} - -#[cfg(test)] -mod tests { - use super::ListingTableUrl; - use datafusion_execution::object_store::{ - DefaultObjectStoreRegistry, ObjectStoreRegistry, - }; - use object_store::{local::LocalFileSystem, path::Path}; - use std::{ops::Not, sync::Arc}; - use url::Url; - - #[test] - fn test_object_store_listing_url() { - let listing = ListingTableUrl::parse("file:///").unwrap(); - let store = listing.object_store(); - assert_eq!(store.as_str(), "file:///"); - - let listing = ListingTableUrl::parse("s3://bucket/").unwrap(); - let store = listing.object_store(); - assert_eq!(store.as_str(), "s3://bucket/"); - } - - #[test] - fn test_get_store_hdfs() { - let sut = DefaultObjectStoreRegistry::default(); - let url = Url::parse("hdfs://localhost:8020").unwrap(); - sut.register_store(&url, Arc::new(LocalFileSystem::new())); - let url = ListingTableUrl::parse("hdfs://localhost:8020/key").unwrap(); - sut.get_store(url.as_ref()).unwrap(); - } - - #[test] - fn test_get_store_s3() { - let sut = DefaultObjectStoreRegistry::default(); - let url = Url::parse("s3://bucket/key").unwrap(); - sut.register_store(&url, Arc::new(LocalFileSystem::new())); - let url = ListingTableUrl::parse("s3://bucket/key").unwrap(); - sut.get_store(url.as_ref()).unwrap(); - } - - #[test] - fn test_get_store_file() { - let sut = DefaultObjectStoreRegistry::default(); - let url = ListingTableUrl::parse("file:///bucket/key").unwrap(); - sut.get_store(url.as_ref()).unwrap(); - } - - #[test] - fn test_get_store_local() { - let sut = DefaultObjectStoreRegistry::default(); - let url = ListingTableUrl::parse("../").unwrap(); - sut.get_store(url.as_ref()).unwrap(); - } - - #[test] - fn test_url_contains() { - let url = ListingTableUrl::parse("file:///var/data/mytable/").unwrap(); - - // standard case with default config - assert!(url.contains( - &Path::parse("/var/data/mytable/data.parquet").unwrap(), - true - )); - - // standard case with `ignore_subdirectory` set to false - assert!(url.contains( - &Path::parse("/var/data/mytable/data.parquet").unwrap(), - false - )); - - // as per documentation, when `ignore_subdirectory` is true, we should ignore files that aren't - // a direct child of the `url` - assert!(url - .contains( - &Path::parse("/var/data/mytable/mysubfolder/data.parquet").unwrap(), - true - ) - .not()); - - // when we set `ignore_subdirectory` to false, we should not ignore the file - assert!(url.contains( - &Path::parse("/var/data/mytable/mysubfolder/data.parquet").unwrap(), - false - )); - - // as above, `ignore_subdirectory` is false, so we include the file - assert!(url.contains( - &Path::parse("/var/data/mytable/year=2024/data.parquet").unwrap(), - false - )); - - // in this case, we include the file even when `ignore_subdirectory` is true because the - // path segment is a hive partition which doesn't count as a subdirectory for the purposes - // of `Url::contains` - assert!(url.contains( - &Path::parse("/var/data/mytable/year=2024/data.parquet").unwrap(), - true - )); - - // testing an empty path with default config - assert!(url.contains(&Path::parse("/var/data/mytable/").unwrap(), true)); - - // testing an empty path with `ignore_subdirectory` set to false - 
assert!(url.contains(&Path::parse("/var/data/mytable/").unwrap(), false)); - } -} diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 88d2d8bde51e..ecc792f73d30 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -33,18 +33,19 @@ use datafusion_expr::{ }; use datafusion_physical_plan::ExecutionPlan; -/// A named table which can be queried. +/// A table which can be queried and modified. /// /// Please see [`CatalogProvider`] for details of implementing a custom catalog. /// /// [`TableProvider`] represents a source of data which can provide data as -/// Apache Arrow `RecordBatch`es. Implementations of this trait provide +/// Apache Arrow [`RecordBatch`]es. Implementations of this trait provide /// important information for planning such as: /// /// 1. [`Self::schema`]: The schema (columns and their types) of the table /// 2. [`Self::supports_filters_pushdown`]: Should filters be pushed into this scan /// 2. [`Self::scan`]: An [`ExecutionPlan`] that can read data /// +/// [`RecordBatch`]: https://docs.rs/arrow/latest/arrow/record_batch/struct.RecordBatch.html /// [`CatalogProvider`]: super::CatalogProvider #[async_trait] pub trait TableProvider: Debug + Sync + Send { diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index bc37e59c9b92..28202c6684b5 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -20,7 +20,7 @@ //! but provide an error message rather than a panic, as the corresponding //! kernels in arrow-rs such as `as_boolean_array` do. -use crate::{downcast_value, DataFusionError, Result}; +use crate::{downcast_value, Result}; use arrow::array::{ BinaryViewArray, Float16Array, Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray, UInt16Array, diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index c9900204b97f..5e8317c081d9 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -1241,35 +1241,72 @@ macro_rules! extensions_options { Box::new(self.clone()) } - fn set(&mut self, key: &str, value: &str) -> $crate::Result<()> { - match key { - $( - stringify!($field_name) => { - self.$field_name = value.parse().map_err(|e| { - $crate::DataFusionError::Context( - format!(concat!("Error parsing {} as ", stringify!($t),), value), - Box::new($crate::DataFusionError::External(Box::new(e))), - ) - })?; - Ok(()) - } - )* - _ => Err($crate::DataFusionError::Configuration( - format!(concat!("Config value \"{}\" not found on ", stringify!($struct_name)), key) - )) - } + fn set(&mut self, key: &str, value: &str) -> $crate::error::Result<()> { + $crate::config::ConfigField::set(self, key, value) } fn entries(&self) -> Vec<$crate::config::ConfigEntry> { - vec![ + struct Visitor(Vec<$crate::config::ConfigEntry>); + + impl $crate::config::Visit for Visitor { + fn some( + &mut self, + key: &str, + value: V, + description: &'static str, + ) { + self.0.push($crate::config::ConfigEntry { + key: key.to_string(), + value: Some(value.to_string()), + description, + }) + } + + fn none(&mut self, key: &str, description: &'static str) { + self.0.push($crate::config::ConfigEntry { + key: key.to_string(), + value: None, + description, + }) + } + } + + let mut v = Visitor(vec![]); + // The prefix is not used for extensions. + // The description is generated in ConfigField::visit. + // We can just pass empty strings here. 
+ $crate::config::ConfigField::visit(self, &mut v, "", ""); + v.0 + } + } + + impl $crate::config::ConfigField for $struct_name { + fn set(&mut self, key: &str, value: &str) -> $crate::error::Result<()> { + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { $( - $crate::config::ConfigEntry { - key: stringify!($field_name).to_owned(), - value: (self.$field_name != $default).then(|| self.$field_name.to_string()), - description: concat!($($d),*).trim(), + stringify!($field_name) => { + // Safely apply deprecated attribute if present + // $(#[allow(deprecated)])? + { + #[allow(deprecated)] + self.$field_name.set(rem, value.as_ref()) + } }, )* - ] + _ => return $crate::error::_config_err!( + "Config value \"{}\" not found on {}", key, stringify!($struct_name) + ) + } + } + + fn visit(&self, v: &mut V, _key_prefix: &str, _description: &'static str) { + $( + let key = stringify!($field_name).to_string(); + let desc = concat!($($d),*).trim(); + #[allow(deprecated)] + self.$field_name.visit(v, key.as_str(), desc); + )* } } } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 7e9025dee1f4..99fb179c76a3 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -1002,12 +1002,14 @@ pub trait SchemaExt { /// It works the same as [`DFSchema::equivalent_names_and_types`]. fn equivalent_names_and_types(&self, other: &Self) -> bool; - /// Returns true if the two schemas have the same qualified named - /// fields with logically equivalent data types. Returns false otherwise. + /// Returns nothing if the two schemas have the same qualified named + /// fields with logically equivalent data types. Returns internal error otherwise. /// /// Use [DFSchema]::equivalent_names_and_types for stricter semantic type /// equivalence checking. - fn logically_equivalent_names_and_types(&self, other: &Self) -> bool; + /// + /// It is only used by insert into cases. + fn logically_equivalent_names_and_types(&self, other: &Self) -> Result<()>; } impl SchemaExt for Schema { @@ -1028,21 +1030,36 @@ impl SchemaExt for Schema { }) } - fn logically_equivalent_names_and_types(&self, other: &Self) -> bool { + // It is only used by insert into cases. + fn logically_equivalent_names_and_types(&self, other: &Self) -> Result<()> { + // case 1 : schema length mismatch if self.fields().len() != other.fields().len() { - return false; + _plan_err!( + "Inserting query must have the same schema length as the table. 
\ + Expected table schema length: {}, got: {}", + self.fields().len(), + other.fields().len() + ) + } else { + // case 2 : schema length match, but fields mismatch + // check if the fields name are the same and have the same data types + self.fields() + .iter() + .zip(other.fields().iter()) + .try_for_each(|(f1, f2)| { + if f1.name() != f2.name() || !DFSchema::datatype_is_logically_equal(f1.data_type(), f2.data_type()) { + _plan_err!( + "Inserting query schema mismatch: Expected table field '{}' with type {:?}, \ + but got '{}' with type {:?}.", + f1.name(), + f1.data_type(), + f2.name(), + f2.data_type()) + } else { + Ok(()) + } + }) } - - self.fields() - .iter() - .zip(other.fields().iter()) - .all(|(f1, f2)| { - f1.name() == f2.name() - && DFSchema::datatype_is_logically_equal( - f1.data_type(), - f2.data_type(), - ) - }) } } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 1ad2a5f0cec3..df1ae100f581 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -104,21 +104,84 @@ pub type HashSet = hashbrown::HashSet; #[macro_export] macro_rules! downcast_value { ($Value: expr, $Type: ident) => {{ - use std::any::type_name; - $Value.as_any().downcast_ref::<$Type>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$Type>() - )) - })? + use $crate::__private::DowncastArrayHelper; + $Value.downcast_array_helper::<$Type>()? }}; ($Value: expr, $Type: ident, $T: tt) => {{ - use std::any::type_name; - $Value.as_any().downcast_ref::<$Type<$T>>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$Type<$T>>() - )) - })? + use $crate::__private::DowncastArrayHelper; + $Value.downcast_array_helper::<$Type<$T>>()? }}; } + +// Not public API. +#[doc(hidden)] +pub mod __private { + use crate::error::_internal_datafusion_err; + use crate::Result; + use arrow::array::Array; + use std::any::{type_name, Any}; + + #[doc(hidden)] + pub trait DowncastArrayHelper { + fn downcast_array_helper(&self) -> Result<&U>; + } + + impl DowncastArrayHelper for T { + fn downcast_array_helper(&self) -> Result<&U> { + self.as_any().downcast_ref().ok_or_else(|| { + _internal_datafusion_err!( + "could not cast array of type {} to {}", + self.data_type(), + type_name::() + ) + }) + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{ArrayRef, Int32Array, UInt64Array}; + use std::any::{type_name, type_name_of_val}; + use std::sync::Arc; + + #[test] + fn test_downcast_value() -> crate::Result<()> { + let boxed: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let array = downcast_value!(&boxed, Int32Array); + assert_eq!(type_name_of_val(&array), type_name::<&Int32Array>()); + + let expected: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + assert_eq!(array, &expected); + Ok(()) + } + + #[test] + fn test_downcast_value_err_message() { + let boxed: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let error: crate::DataFusionError = (|| { + downcast_value!(&boxed, UInt64Array); + Ok(()) + })() + .err() + .unwrap(); + + assert_starts_with( + error.to_string(), + "Internal error: could not cast array of type Int32 to arrow_array::array::primitive_array::PrimitiveArray" + ); + } + + // `err.to_string()` depends on backtrace being present (may have backtrace appended) + // `err.strip_backtrace()` also depends on backtrace being present (may have "This was likely caused by ..." 
stripped) + fn assert_starts_with(actual: impl AsRef, expected_prefix: impl AsRef) { + let actual = actual.as_ref(); + let expected_prefix = expected_prefix.as_ref(); + assert!( + actual.starts_with(expected_prefix), + "Expected '{}' to start with '{}'", + actual, + expected_prefix + ); + } +} diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index cb77cc8e79b1..ff9cdedab8b1 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -590,6 +590,13 @@ pub fn base_type(data_type: &DataType) -> DataType { } } +/// Information about how to coerce lists. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum ListCoercion { + /// [`DataType::FixedSizeList`] should be coerced to [`DataType::List`]. + FixedSizedListToList, +} + /// A helper function to coerce base type in List. /// /// Example @@ -600,16 +607,22 @@ pub fn base_type(data_type: &DataType) -> DataType { /// /// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// let base_type = DataType::Float64; -/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); +/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type, None); /// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, + array_coercion: Option<&ListCoercion>, ) -> DataType { - match data_type { - DataType::List(field) | DataType::FixedSizeList(field, _) => { - let field_type = - coerced_type_with_base_type_only(field.data_type(), base_type); + match (data_type, array_coercion) { + (DataType::List(field), _) + | (DataType::FixedSizeList(field, _), Some(ListCoercion::FixedSizedListToList)) => + { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); DataType::List(Arc::new(Field::new( field.name(), @@ -617,9 +630,24 @@ pub fn coerced_type_with_base_type_only( field.is_nullable(), ))) } - DataType::LargeList(field) => { - let field_type = - coerced_type_with_base_type_only(field.data_type(), base_type); + (DataType::FixedSizeList(field, len), _) => { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); + + DataType::FixedSizeList( + Arc::new(Field::new(field.name(), field_type, field.is_nullable())), + *len, + ) + } + (DataType::LargeList(field), _) => { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); DataType::LargeList(Arc::new(Field::new( field.name(), diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 6492e828e60c..784b2a89aae9 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -43,7 +43,7 @@ array_expressions = ["nested_expressions"] # Used to enable the avro format avro = ["apache-avro", "num-traits", "datafusion-common/avro"] backtrace = ["datafusion-common/backtrace"] -compression = ["xz2", "bzip2", "flate2", "zstd", "datafusion-catalog-listing/compression"] +compression = ["xz2", "bzip2", "flate2", "zstd", "datafusion-datasource/compression"] crypto_expressions = ["datafusion-functions/crypto_expressions"] datetime_expressions = ["datafusion-functions/datetime_expressions"] default = [ @@ -95,8 +95,10 @@ datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } datafusion-common = { workspace = true, features = 
["object_store"] } datafusion-common-runtime = { workspace = true } +datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true, optional = true } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index b1eb2a19e31d..4e5ccab14f93 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -25,7 +25,9 @@ use crate::arrow::util::pretty; use crate::datasource::file_format::csv::CsvFormatFactory; use crate::datasource::file_format::format_as_file_type; use crate::datasource::file_format::json::JsonFormatFactory; -use crate::datasource::{provider_as_source, MemTable, TableProvider}; +use crate::datasource::{ + provider_as_source, DefaultTableSource, MemTable, TableProvider, +}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; @@ -62,6 +64,7 @@ use datafusion_functions_aggregate::expr_fn::{ use async_trait::async_trait; use datafusion_catalog::Session; +use datafusion_sql::TableReference; /// Contains options that control how data is /// written out from a DataFrame @@ -1526,8 +1529,6 @@ impl DataFrame { table_name: &str, write_options: DataFrameWriteOptions, ) -> Result, DataFusionError> { - let arrow_schema = Schema::from(self.schema()); - let plan = if write_options.sort_by.is_empty() { self.plan } else { @@ -1536,10 +1537,19 @@ impl DataFrame { .build()? }; + let table_ref: TableReference = table_name.into(); + let table_schema = self.session_state.schema_for_ref(table_ref.clone())?; + let target = match table_schema.table(table_ref.table()).await? { + Some(ref provider) => Ok(Arc::clone(provider)), + _ => plan_err!("No table named '{table_name}'"), + }?; + + let target = Arc::new(DefaultTableSource::new(target)); + let plan = LogicalPlanBuilder::insert_into( plan, table_name.to_owned(), - &arrow_schema, + target, write_options.insert_op, )? .build()?; @@ -1801,7 +1811,8 @@ impl DataFrame { .iter() .map(|(qualifier, field)| { if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename { - col(Column::from((qualifier, field))).alias(new_name) + col(Column::from((qualifier, field))) + .alias_qualified(qualifier.cloned(), new_name) } else { col(Column::from((qualifier, field))) } diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index 91c1e0ac97fc..541e0b6dfa91 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -26,12 +26,15 @@ use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, Constraints}; use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource, TableType}; -/// DataFusion default table source, wrapping TableProvider. +/// Implements [`TableSource`] for a [`TableProvider`] /// -/// This structure adapts a `TableProvider` (physical plan trait) to the `TableSource` -/// (logical plan trait) and is necessary because the logical plan is contained in -/// the `datafusion_expr` crate, and is not aware of table providers, which exist in -/// the core `datafusion` crate. +/// This structure adapts a [`TableProvider`] (a physical plan trait) to the +/// [`TableSource`] (logical plan trait). 
+/// +/// It is used so logical plans in the `datafusion_expr` crate do not have a +/// direct dependency on physical plans, such as [`TableProvider`]s. +/// +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html pub struct DefaultTableSource { /// table provider pub table_provider: Arc, diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index dd56b4c137ed..09121eba6702 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -171,11 +171,10 @@ impl FileFormat for ArrowFormat { async fn create_physical_plan( &self, _state: &dyn Session, - mut conf: FileScanConfig, + conf: FileScanConfig, _filters: Option<&Arc>, ) -> Result> { - conf = conf.with_source(Arc::new(ArrowSource::default())); - Ok(conf.new_exec()) + Ok(conf.with_source(Arc::new(ArrowSource::default())).build()) } async fn create_writer_physical_plan( diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 100aa4fd51e2..c0c8f25722c2 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -148,11 +148,10 @@ impl FileFormat for AvroFormat { async fn create_physical_plan( &self, _state: &dyn Session, - mut conf: FileScanConfig, + conf: FileScanConfig, _filters: Option<&Arc>, ) -> Result> { - conf = conf.with_source(self.file_source()); - Ok(conf.new_exec()) + Ok(conf.with_source(self.file_source()).build()) } fn file_source(&self) -> Arc { diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 7d06648d7ba8..4991a96dc3d3 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -434,9 +434,7 @@ impl FileFormat for CsvFormat { .with_terminator(self.options.terminator) .with_comment(self.options.comment), ); - conf = conf.with_source(source); - - Ok(conf.new_exec()) + Ok(conf.with_source(source).build()) } async fn create_writer_physical_plan( diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 32a527bc5876..94e74b144499 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -254,9 +254,7 @@ impl FileFormat for JsonFormat { ) -> Result> { let source = Arc::new(JsonSource::new()); conf.file_compression_type = FileCompressionType::from(self.options.compression); - conf = conf.with_source(source); - - Ok(conf.new_exec()) + Ok(conf.with_source(source).build()) } async fn create_writer_physical_plan( diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index dd48a9537187..657fe6ca5511 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -28,8 +28,8 @@ pub mod json; pub mod options; #[cfg(feature = "parquet")] pub mod parquet; -pub use datafusion_catalog_listing::file_compression_type; -pub use datafusion_catalog_listing::write; +pub use datafusion_datasource::file_compression_type; +pub use datafusion_datasource::write; use std::any::Any; use std::collections::{HashMap, VecDeque}; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 
9774792133cd..7dbc510eca09 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -398,7 +398,7 @@ impl FileFormat for ParquetFormat { async fn create_physical_plan( &self, _state: &dyn Session, - mut conf: FileScanConfig, + conf: FileScanConfig, filters: Option<&Arc>, ) -> Result> { let mut predicate = None; @@ -424,8 +424,7 @@ impl FileFormat for ParquetFormat { if let Some(metadata_size_hint) = metadata_size_hint { source = source.with_metadata_size_hint(metadata_size_hint) } - conf = conf.with_source(Arc::new(source)); - Ok(conf.new_exec()) + Ok(conf.with_source(Arc::new(source)).build()) } async fn create_writer_physical_plan( diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index 39323b993d45..a58db55bccb6 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -19,5 +19,8 @@ //! to get the list of files to process. mod table; -pub use datafusion_catalog_listing::*; +pub use datafusion_catalog_listing::helpers; +pub use datafusion_datasource::{ + FileRange, ListingTableUrl, PartitionedFile, PartitionedFileStream, +}; pub use table::{ListingOptions, ListingTable, ListingTableConfig}; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 642ec93f3671..3be8af59ea2a 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -996,27 +996,8 @@ impl TableProvider for ListingTable { insert_op: InsertOp, ) -> Result> { // Check that the schema of the plan matches the schema of this table. - if !self - .schema() - .logically_equivalent_names_and_types(&input.schema()) - { - // Return an error if schema of the input query does not match with the table schema. - return plan_err!( - "Inserting query must have the same schema with the table. \ - Expected: {:?}, got: {:?}", - self.schema() - .fields() - .iter() - .map(|field| field.data_type()) - .collect::>(), - input - .schema() - .fields() - .iter() - .map(|field| field.data_type()) - .collect::>() - ); - } + self.schema() + .logically_equivalent_names_and_types(&input.schema())?; let table_path = &self.table_paths()[0]; if !table_path.is_collection() { @@ -1195,7 +1176,7 @@ mod tests { use crate::datasource::file_format::json::JsonFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; - use crate::datasource::{provider_as_source, MemTable}; + use crate::datasource::{provider_as_source, DefaultTableSource, MemTable}; use crate::execution::options::ArrowReadOptions; use crate::prelude::*; use crate::{ @@ -2065,6 +2046,8 @@ mod tests { session_ctx.register_table("source", source_table.clone())?; // Convert the source table into a provider so that it can be used in a query let source = provider_as_source(source_table); + let target = session_ctx.table_provider("t").await?; + let target = Arc::new(DefaultTableSource::new(target)); // Create a table scan logical plan to read from the source table let scan_plan = LogicalPlanBuilder::scan("source", source, None)? .filter(filter_predicate)? @@ -2073,7 +2056,7 @@ mod tests { // Therefore, we will have 8 partitions in the final plan. // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? 
+ LogicalPlanBuilder::insert_into(scan_plan, "t", target, InsertOp::Append)? .build()?; // Create a physical plan from the insert plan let plan = session_ctx diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index a996990105b3..94c6e45804e8 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -278,26 +278,9 @@ impl TableProvider for MemTable { // Create a physical plan from the logical plan. // Check that the schema of the plan matches the schema of this table. - if !self - .schema() - .logically_equivalent_names_and_types(&input.schema()) - { - return plan_err!( - "Inserting query must have the same schema with the table. \ - Expected: {:?}, got: {:?}", - self.schema() - .fields() - .iter() - .map(|field| field.data_type()) - .collect::>(), - input - .schema() - .fields() - .iter() - .map(|field| field.data_type()) - .collect::>() - ); - } + self.schema() + .logically_equivalent_names_and_types(&input.schema())?; + if insert_op != InsertOp::Append { return not_impl_err!("{insert_op} not implemented for MemoryTable yet"); } @@ -390,7 +373,7 @@ impl DataSink for MemSink { mod tests { use super::*; - use crate::datasource::provider_as_source; + use crate::datasource::{provider_as_source, DefaultTableSource}; use crate::physical_plan::collect; use crate::prelude::SessionContext; @@ -640,6 +623,7 @@ mod tests { // Create and register the initial table with the provided schema and data let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?); session_ctx.register_table("t", initial_table.clone())?; + let target = Arc::new(DefaultTableSource::new(initial_table.clone())); // Create and register the source table with the provided schema and inserted data let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?); session_ctx.register_table("source", source_table.clone())?; @@ -649,7 +633,7 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? + LogicalPlanBuilder::insert_into(scan_plan, "t", target, InsertOp::Append)? 
.build()?; // Create a physical plan from the insert plan let plan = session_ctx diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 6aa330caffab..b0a1d8c8c9e2 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -399,7 +399,7 @@ mod tests { .with_file(meta.into()) .with_projection(Some(vec![0, 1, 2])); - let source_exec = conf.new_exec(); + let source_exec = conf.build(); assert_eq!( source_exec .properties() @@ -472,7 +472,7 @@ mod tests { .with_file(meta.into()) .with_projection(projection); - let source_exec = conf.new_exec(); + let source_exec = conf.build(); assert_eq!( source_exec .properties() @@ -546,7 +546,7 @@ mod tests { .with_file(partitioned_file) .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]); - let source_exec = conf.new_exec(); + let source_exec = conf.build(); assert_eq!( source_exec diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 5e017b992581..c0952229b5e0 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -425,7 +425,7 @@ impl ExecutionPlan for CsvExec { /// let file_scan_config = FileScanConfig::new(object_store_url, file_schema, source) /// .with_file(PartitionedFile::new("file1.csv", 100*1024*1024)) /// .with_newlines_in_values(true); // The file contains newlines in values; -/// let exec = file_scan_config.new_exec(); +/// let exec = file_scan_config.build(); /// ``` #[derive(Debug, Clone, Default)] pub struct CsvSource { @@ -836,14 +836,14 @@ mod tests { )?; let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = partitioned_csv_config(file_schema, file_groups, source) + let config = partitioned_csv_config(file_schema, file_groups, source) .with_file_compression_type(file_compression_type) - .with_newlines_in_values(false); - config.projection = Some(vec![0, 2, 4]); - - let csv = config.new_exec(); + .with_newlines_in_values(false) + .with_projection(Some(vec![0, 2, 4])); assert_eq!(13, config.file_schema.fields().len()); + let csv = config.build(); + assert_eq!(3, csv.schema().fields().len()); let mut stream = csv.execute(0, task_ctx)?; @@ -901,12 +901,12 @@ mod tests { )?; let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = partitioned_csv_config(file_schema, file_groups, source) + let config = partitioned_csv_config(file_schema, file_groups, source) .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()); - config.projection = Some(vec![4, 0, 2]); - let csv = config.new_exec(); + .with_file_compression_type(file_compression_type.to_owned()) + .with_projection(Some(vec![4, 0, 2])); assert_eq!(13, config.file_schema.fields().len()); + let csv = config.build(); assert_eq!(3, csv.schema().fields().len()); let mut stream = csv.execute(0, task_ctx)?; @@ -964,12 +964,12 @@ mod tests { )?; let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = partitioned_csv_config(file_schema, file_groups, source) + let config = partitioned_csv_config(file_schema, file_groups, source) .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()); - config.limit = Some(5); - let csv = config.new_exec(); + .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)); assert_eq!(13, 
config.file_schema.fields().len()); + let csv = config.build(); assert_eq!(13, csv.schema().fields().len()); let mut it = csv.execute(0, task_ctx)?; @@ -1024,12 +1024,12 @@ )?; let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = partitioned_csv_config(file_schema, file_groups, source) + let config = partitioned_csv_config(file_schema, file_groups, source) .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()); - config.limit = Some(5); - let csv = config.new_exec(); + .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)); assert_eq!(14, config.file_schema.fields().len()); + let csv = config.build(); assert_eq!(14, csv.schema().fields().len()); // errors due to https://github.com/apache/datafusion/issues/4918 @@ -1089,8 +1089,8 @@ // we don't have `/date=xx/` in the path but that is ok because // partitions are resolved during scan anyway - let csv = config.new_exec(); assert_eq!(13, config.file_schema.fields().len()); + let csv = config.build(); assert_eq!(2, csv.schema().fields().len()); let mut it = csv.execute(0, task_ctx)?; @@ -1179,7 +1179,7 @@ let config = partitioned_csv_config(file_schema, file_groups, source) .with_newlines_in_values(false) .with_file_compression_type(file_compression_type.to_owned()); - let csv = config.new_exec(); + let csv = config.build(); let it = csv.execute(0, task_ctx).unwrap(); let batches: Vec<_> = it.try_collect().await.unwrap(); diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 3708fe6abd5e..123ecc2f9582 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -35,7 +35,7 @@ use datafusion_common::{ColumnStatistics, Constraints, Statistics}; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, Partitioning}; use crate::datasource::data_source::FileSource; -pub use datafusion_catalog_listing::file_scan_config::*; +pub use datafusion_datasource::file_scan_config::*; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_plan::display::{display_orderings, ProjectSchemaDisplay}; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; @@ -68,21 +68,30 @@ pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) } -/// The base configurations to provide when creating a physical plan for +/// The base configurations for a [`DataSourceExec`], the physical plan for /// any given file format. /// +/// Use [`Self::build`] to create a [`DataSourceExec`] from a `FileScanConfig`.
+/// /// # Example /// ``` /// # use std::sync::Arc; -/// # use arrow::datatypes::Schema; +/// # use arrow::datatypes::{Field, Fields, DataType, Schema}; /// # use datafusion::datasource::listing::PartitionedFile; /// # use datafusion::datasource::physical_plan::FileScanConfig; /// # use datafusion_execution::object_store::ObjectStoreUrl; /// # use datafusion::datasource::physical_plan::ArrowSource; -/// # let file_schema = Arc::new(Schema::empty()); -/// // create FileScan config for reading data from file:// +/// # use datafusion_physical_plan::ExecutionPlan; +/// # let file_schema = Arc::new(Schema::new(vec![ +/// # Field::new("c1", DataType::Int32, false), +/// # Field::new("c2", DataType::Int32, false), +/// # Field::new("c3", DataType::Int32, false), +/// # Field::new("c4", DataType::Int32, false), +/// # ])); +/// // create FileScan config for reading arrow files from file:// /// let object_store_url = ObjectStoreUrl::local_filesystem(); -/// let config = FileScanConfig::new(object_store_url, file_schema, Arc::new(ArrowSource::default())) +/// let file_source = Arc::new(ArrowSource::default()); +/// let config = FileScanConfig::new(object_store_url, file_schema, file_source) /// .with_limit(Some(1000)) // read only the first 1000 records /// .with_projection(Some(vec![2, 3])) // project columns 2 and 3 /// // Read /tmp/file1.parquet with known size of 1234 bytes in a single group @@ -93,6 +102,8 @@ pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { /// PartitionedFile::new("file2.parquet", 56), /// PartitionedFile::new("file3.parquet", 78), /// ]); +/// // create an execution plan from the config +/// let plan: Arc = config.build(); /// ``` #[derive(Clone)] pub struct FileScanConfig { @@ -252,19 +263,20 @@ impl DataSource for FileScanConfig { // If there is any non-column or alias-carrier expression, Projection should not be removed. // This process can be moved into CsvExec, but it would be an overlap of their responsibility. 
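The hunks around this point replace `FileScanConfig::new_exec()` with a consuming `build()` and move direct field mutation (`config.projection = ...`, `config.limit = ...`) to the `with_projection` / `with_limit` builders. As a minimal sketch of what migrated caller code looks like under this change, assuming an illustrative two-column schema and file path (the types and builder calls mirror the doc example in this diff, not any specific caller):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::{ArrowSource, FileScanConfig};
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_physical_plan::ExecutionPlan;

fn arrow_scan() -> Arc<dyn ExecutionPlan> {
    let file_schema = Arc::new(Schema::new(vec![
        Field::new("c1", DataType::Int32, false),
        Field::new("c2", DataType::Int32, false),
    ]));

    // Before this change a caller typically mutated the config and then called
    // `new_exec()`; now every option goes through a `with_*` builder and the
    // consuming `build()` produces the `DataSourceExec`.
    FileScanConfig::new(
        ObjectStoreUrl::local_filesystem(),
        file_schema,
        Arc::new(ArrowSource::default()),
    )
    .with_file(PartitionedFile::new("file1.arrow", 1024)) // illustrative path/size
    .with_projection(Some(vec![0]))
    .with_limit(Some(100))
    .build()
}
```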
Ok(all_alias_free_columns(projection.expr()).then(|| { - let mut file_scan = self.clone(); + let file_scan = self.clone(); let source = Arc::clone(&file_scan.source); let new_projections = new_projections_for_columns( projection, &file_scan .projection + .clone() .unwrap_or((0..self.file_schema.fields().len()).collect()), ); - file_scan.projection = Some(new_projections); - // Assign projected statistics to source - file_scan = file_scan.with_source(source); - - file_scan.new_exec() as _ + file_scan + // Assign projected statistics to source + .with_projection(Some(new_projections)) + .with_source(source) + .build() as _ })) } } @@ -574,9 +586,9 @@ impl FileScanConfig { } // TODO: This function should be moved into DataSourceExec once FileScanConfig moved out of datafusion/core - /// Returns a new [`DataSourceExec`] from file configurations - pub fn new_exec(&self) -> Arc { - Arc::new(DataSourceExec::new(Arc::new(self.clone()))) + /// Returns a new [`DataSourceExec`] to scan the files specified by this config + pub fn build(self) -> Arc { + Arc::new(DataSourceExec::new(Arc::new(self))) } /// Write the data_type based on file_source diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index c88d4c4458a5..7944d6fa9020 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -37,11 +37,9 @@ use crate::physical_plan::RecordBatchStream; use arrow::datatypes::SchemaRef; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; -pub use datafusion_catalog_listing::file_stream::{FileOpenFuture, FileOpener, OnError}; -use datafusion_catalog_listing::file_stream::{ - FileStreamMetrics, FileStreamState, NextOpen, -}; use datafusion_common::ScalarValue; +pub use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener, OnError}; +use datafusion_datasource::file_stream::{FileStreamMetrics, FileStreamState, NextOpen}; use futures::{ready, FutureExt, Stream, StreamExt}; diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 51e0a46d942e..590b1cb88dcd 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -589,7 +589,7 @@ mod tests { .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()); - let exec = conf.new_exec(); + let exec = conf.build(); // TODO: this is not where schema inference should be tested @@ -660,7 +660,7 @@ mod tests { .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()); - let exec = conf.new_exec(); + let exec = conf.build(); let mut it = exec.execute(0, task_ctx)?; let batch = it.next().await.unwrap()?; @@ -700,7 +700,7 @@ mod tests { .with_file_groups(file_groups) .with_projection(Some(vec![0, 2])) .with_file_compression_type(file_compression_type.to_owned()); - let exec = conf.new_exec(); + let exec = conf.build(); let inferred_schema = exec.schema(); assert_eq!(inferred_schema.fields().len(), 2); @@ -745,7 +745,7 @@ mod tests { .with_file_groups(file_groups) .with_projection(Some(vec![3, 0, 2])) .with_file_compression_type(file_compression_type.to_owned()); - let exec = conf.new_exec(); + let exec = conf.build(); let inferred_schema = exec.schema(); assert_eq!(inferred_schema.fields().len(), 3); diff --git 
a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 18174bd54e4f..953c99322e16 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -50,9 +50,9 @@ pub use avro::AvroSource; #[allow(deprecated)] pub use csv::{CsvExec, CsvExecBuilder}; pub use csv::{CsvOpener, CsvSource}; -pub use datafusion_catalog_listing::file_groups::FileGroupPartitioner; -pub use datafusion_catalog_listing::file_meta::FileMeta; -pub use datafusion_catalog_listing::file_sink_config::*; +pub use datafusion_datasource::file_groups::FileGroupPartitioner; +pub use datafusion_datasource::file_meta::FileMeta; +pub use datafusion_datasource::file_sink_config::*; pub use file_scan_config::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, }; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index a1c2bb4207ef..4bd43cd1aaca 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -708,7 +708,7 @@ mod tests { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let parquet_exec = base_config.new_exec(); + let parquet_exec = base_config.clone().build(); RoundTripResult { batches: collect(parquet_exec.clone(), task_ctx).await, parquet_exec, @@ -1354,7 +1354,7 @@ mod tests { Arc::new(ParquetSource::default()), ) .with_file_groups(file_groups) - .new_exec(); + .build(); assert_eq!( parquet_exec .properties() @@ -1468,7 +1468,7 @@ mod tests { false, ), ]) - .new_exec(); + .build(); let partition_count = parquet_exec .source() .output_partitioning() @@ -1531,7 +1531,7 @@ mod tests { Arc::new(ParquetSource::default()), ) .with_file(partitioned_file) - .new_exec(); + .build(); let mut results = parquet_exec.execute(0, state.task_ctx())?; let batch = results.next().await.unwrap(); @@ -2188,7 +2188,7 @@ mod tests { extensions: None, metadata_size_hint: None, }) - .new_exec(); + .build(); let res = collect(exec, ctx.task_ctx()).await.unwrap(); assert_eq!(res.len(), 2); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/source.rs b/datafusion/core/src/datasource/physical_plan/parquet/source.rs index a98524b0bead..21881112075d 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/source.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/source.rs @@ -94,7 +94,7 @@ use object_store::ObjectStore; /// // Create a DataSourceExec for reading `file1.parquet` with a file size of 100MB /// let file_scan_config = FileScanConfig::new(object_store_url, file_schema, source) /// .with_file(PartitionedFile::new("file1.parquet", 100*1024*1024)); -/// let exec = file_scan_config.new_exec(); +/// let exec = file_scan_config.build(); /// ``` /// /// # Features @@ -176,7 +176,7 @@ use object_store::ObjectStore; /// .clone() /// .with_file_groups(vec![file_group.clone()]); /// -/// new_config.new_exec() +/// new_config.build() /// }) /// .collect::>(); /// ``` @@ -219,7 +219,7 @@ use object_store::ObjectStore; /// .with_file(partitioned_file); /// // this parquet DataSourceExec will not even try to read row groups 2 and 4. Additional /// // pruning based on predicates may also happen -/// let exec = file_scan_config.new_exec(); +/// let exec = file_scan_config.build(); /// ``` /// /// For a complete example, see the [`advanced_parquet_index` example]). 
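Several hunks in this diff (the `ListingTable` and `MemTable` test changes above, and the `physical_planner.rs` change just below) switch DML planning from a table name plus `&Schema` to a `TableSource` target that the planner can downcast back to a `DefaultTableSource`. A minimal sketch of building an insert plan under the new `LogicalPlanBuilder::insert_into` signature, assuming hypothetical in-memory tables named "source" and "t":

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::{provider_as_source, DefaultTableSource, MemTable};
use datafusion_common::Result;
use datafusion_expr::logical_plan::dml::InsertOp;
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};

fn build_insert_plan() -> Result<LogicalPlan> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));

    // Hypothetical source and target tables; in real code these would already be
    // registered `TableProvider`s.
    let source = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?);
    let target = Arc::new(MemTable::try_new(schema, vec![vec![]])?);

    // Scanning still goes through a `TableSource`.
    let scan =
        LogicalPlanBuilder::scan("source", provider_as_source(source), None)?.build()?;

    // `insert_into` now takes the target as a `TableSource` as well: the provider
    // is wrapped in `DefaultTableSource` instead of passing its `Schema`, which is
    // what lets the physical planner later downcast it back to the provider and
    // call `insert_into` on it.
    LogicalPlanBuilder::insert_into(
        scan,
        "t",
        Arc::new(DefaultTableSource::new(target)),
        InsertOp::Append,
    )?
    .build()
}
```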
diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index e59d7b669ce0..41e375cf81f8 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -507,7 +507,7 @@ mod tests { FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema, source) .with_file(partitioned_file); - let parquet_exec = base_conf.new_exec(); + let parquet_exec = base_conf.build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 48ee8e46bc0f..f4aa366500ef 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -229,9 +229,9 @@ //! 1. The query string is parsed to an Abstract Syntax Tree (AST) //! [`Statement`] using [sqlparser]. //! -//! 2. The AST is converted to a [`LogicalPlan`] and logical -//! expressions [`Expr`]s to compute the desired result by the -//! [`SqlToRel`] planner. +//! 2. The AST is converted to a [`LogicalPlan`] and logical expressions +//! [`Expr`]s to compute the desired result by [`SqlToRel`]. This phase +//! also includes name and type resolution ("binding"). //! //! [`Statement`]: https://docs.rs/sqlparser/latest/sqlparser/ast/enum.Statement.html //! @@ -737,6 +737,11 @@ pub mod logical_expr { pub use datafusion_expr::*; } +/// re-export of [`datafusion_expr_common`] crate +pub mod logical_expr_common { + pub use datafusion_expr_common::*; +} + /// re-export of [`datafusion_optimizer`] crate pub mod optimizer { pub use datafusion_optimizer::*; @@ -920,6 +925,12 @@ doc_comment::doctest!( user_guide_cli_usage ); +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/features.md", + user_guide_features +); + #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/user-guide/sql/aggregate_functions.md", @@ -962,6 +973,12 @@ doc_comment::doctest!( user_guide_sql_operators ); +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/prepared_statements.md", + user_guide_prepared_statements +); + #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/user-guide/sql/scalar_functions.md", @@ -980,12 +997,6 @@ doc_comment::doctest!( user_guide_sql_special_functions ); -#[cfg(doctest)] -doc_comment::doctest!( - "../../../docs/source/user-guide/sql/sql_status.md", - user_guide_sql_status -); - #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/user-guide/sql/subqueries.md", diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 47dee391c751..2303574e88af 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; -use crate::datasource::source_as_provider; +use crate::datasource::{source_as_provider, DefaultTableSource}; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; @@ -541,19 +541,22 @@ impl DefaultPhysicalPlanner { .await? } LogicalPlan::Dml(DmlStatement { - table_name, + target, op: WriteOp::Insert(insert_op), .. }) => { - let name = table_name.table(); - let schema = session_state.schema_for_ref(table_name.clone())?; - if let Some(provider) = schema.table(name).await? 
{ + if let Some(provider) = + target.as_any().downcast_ref::() + { let input_exec = children.one()?; provider + .table_provider .insert_into(session_state, input_exec, *insert_op) .await? } else { - return exec_err!("Table '{table_name}' does not exist"); + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); } } LogicalPlan::Window(Window { window_expr, .. }) => { diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 5b7a9d8a16eb..ba85f9afb6da 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -93,7 +93,7 @@ pub fn scan_partitioned_csv( let source = Arc::new(CsvSource::new(true, b'"', b'"')); let config = partitioned_csv_config(schema, file_groups, source) .with_file_compression_type(FileCompressionType::UNCOMPRESSED); - Ok(config.new_exec()) + Ok(config.build()) } /// Returns file groups [`Vec>`] for scanning `partitions` of `filename` diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 67e0e1726917..0e0090ef028e 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -156,7 +156,7 @@ impl TestParquetFile { ) -> Result> { let parquet_options = ctx.copied_table_options().parquet; let source = Arc::new(ParquetSource::new(parquet_options.clone())); - let mut scan_config = FileScanConfig::new( + let scan_config = FileScanConfig::new( self.object_store_url.clone(), Arc::clone(&self.schema), source, @@ -185,13 +185,12 @@ impl TestParquetFile { Arc::clone(&scan_config.file_schema), Arc::clone(&physical_filter_expr), )); - scan_config = scan_config.with_source(source); - let parquet_exec = scan_config.new_exec(); + let parquet_exec = scan_config.with_source(source).build(); let exec = Arc::new(FilterExec::try_new(physical_filter_expr, parquet_exec)?); Ok(exec) } else { - Ok(scan_config.new_exec()) + Ok(scan_config.build()) } } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 8155fd6a2ff9..d545157607c7 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -63,7 +63,7 @@ use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion_catalog::TableProvider; use datafusion_common::{ assert_contains, Constraint, Constraints, DataFusionError, ParamValues, ScalarValue, - UnnestOptions, + TableReference, UnnestOptions, }; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::config::SessionConfig; @@ -1617,9 +1617,25 @@ async fn with_column_renamed() -> Result<()> { // accepts table qualifier .with_column_renamed("aggregate_test_100.c2", "two")? // no-op for missing column - .with_column_renamed("c4", "boom")? 
- .collect() - .await?; + .with_column_renamed("c4", "boom")?; + + let references: Vec<_> = df_sum_renamed + .schema() + .iter() + .map(|(a, _)| a.cloned()) + .collect(); + + assert_eq!( + references, + vec![ + Some(TableReference::bare("aggregate_test_100")), // table name is preserved + Some(TableReference::bare("aggregate_test_100")), + Some(TableReference::bare("aggregate_test_100")), + None // total column + ] + ); + + let batches = &df_sum_renamed.collect().await?; assert_batches_sorted_eq!( [ @@ -1629,7 +1645,7 @@ async fn with_column_renamed() -> Result<()> { "| a | 3 | -72 | -69 |", "+-----+-----+-----+-------+", ], - &df_sum_renamed + batches ); Ok(()) @@ -5274,3 +5290,55 @@ async fn register_non_parquet_file() { "1.json' does not match the expected extension '.parquet'" ); } + +// Test inserting into checking. +#[tokio::test] +async fn test_insert_into_checking() -> Result<()> { + // Create a new schema with one field called "a" of type Int64, and setting nullable to false + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + let session_ctx = SessionContext::new(); + + // Create and register the initial table with the provided schema and data + let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); + session_ctx.register_table("t", initial_table.clone())?; + + // There are two cases we need to check + // 1. The len of the schema of the plan and the schema of the table should be the same + // 2. The datatype of the schema of the plan and the schema of the table should be the same + + // Test case 1: + let write_df = session_ctx.sql("values (1, 2), (3, 4)").await.unwrap(); + + let e = write_df + .write_table("t", DataFrameWriteOptions::new()) + .await + .unwrap_err(); + + assert_contains!( + e.to_string(), + "Inserting query must have the same schema length as the table." 
+ ); + + // Setting nullable to true + // Make sure the nullable check go through + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + + let session_ctx = SessionContext::new(); + + // Create and register the initial table with the provided schema and data + let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); + session_ctx.register_table("t", initial_table.clone())?; + + // Test case 2: + let write_df = session_ctx.sql("values ('a123'), ('b456')").await.unwrap(); + + let e = write_df + .write_table("t", DataFrameWriteOptions::new()) + .await + .unwrap_err(); + + assert_contains!(e.to_string(), "Inserting query schema mismatch: Expected table field 'a' with type Int64, but got 'column1' with type Utf8"); + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index aa6bba8083a1..d4b41b686631 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -29,7 +29,9 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -581,12 +583,8 @@ impl ScalarUDFImpl for TestScalarUDF { Ok(input[0].sort_properties) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new({ diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs index e11d472b9b8a..c6876d4a7e96 100644 --- a/datafusion/core/tests/fuzz_cases/pruning.rs +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -110,6 +110,13 @@ async fn test_utf8_not_like_prefix() { .await; } +#[tokio::test] +async fn test_utf8_not_like_ecsape() { + Utf8Test::new(|value| col("a").not_like(lit(format!("\\%{}%", value)))) + .run() + .await; +} + #[tokio::test] async fn test_utf8_not_like_suffix() { Utf8Test::new(|value| col("a").not_like(lit(format!("{}%", value)))) @@ -117,6 +124,13 @@ async fn test_utf8_not_like_suffix() { .await; } +#[tokio::test] +async fn test_utf8_not_like_suffix_one() { + Utf8Test::new(|value| col("a").not_like(lit(format!("{}_", value)))) + .run() + .await; +} + /// Fuzz testing for UTF8 predicate pruning /// The basic idea is that query results should always be the same with or without stats/pruning /// If we get this right we at least guarantee that there are no incorrect results @@ -321,7 +335,7 @@ async fn execute_with_predicate( }) .collect(), ); - let exec = scan.new_exec(); + let exec = scan.build(); let exec = Arc::new(FilterExec::try_new(predicate, exec).unwrap()) as Arc; diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 
928b650e0300..b12b3be2d435 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -90,7 +90,7 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { ) .with_file_group(file_group); - let parquet_exec = base_config.new_exec(); + let parquet_exec = base_config.build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index cbf6580b7e4b..1eacbe42c525 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -351,7 +351,7 @@ impl TestFull { let config = FileScanConfig::new(object_store_url, schema.clone(), source) .with_file(partitioned_file); - let plan: Arc = config.new_exec(); + let plan: Arc = config.build(); // run the DataSourceExec and collect the results let results = diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index 90793028f209..4cbbcf12f32b 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -64,7 +64,7 @@ async fn multi_parquet_coercion() { FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema, source) .with_file_group(file_group); - let parquet_exec = conf.new_exec(); + let parquet_exec = conf.build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -121,7 +121,7 @@ async fn multi_parquet_coercion_projection() { ) .with_file_group(file_group) .with_projection(Some(vec![1, 0, 2])) - .new_exec(); + .build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 855550dc748a..50c67f09c704 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -183,7 +183,7 @@ fn parquet_exec_multiple_sorted( vec![PartitionedFile::new("y".to_string(), 100)], ]) .with_output_ordering(output_ordering) - .new_exec() + .build() } fn csv_exec() -> Arc { @@ -198,7 +198,7 @@ fn csv_exec_with_sort(output_ordering: Vec) -> Arc ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(output_ordering) - .new_exec() + .build() } fn csv_exec_multiple() -> Arc { @@ -217,7 +217,7 @@ fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc Result<()> { ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_file_compression_type(compression_type) - .new_exec(), + .build(), vec![("a".to_string(), "a".to_string())], ); assert_optimized!(expected, plan, true, false, 2, true, 10, false); diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index 4b358e47361b..3412b962d859 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -69,7 +69,7 @@ fn csv_exec_ordered( ) .with_file(PartitionedFile::new("file_path".to_string(), 100)) .with_output_ordering(vec![sort_exprs]) - .new_exec() + .build() } /// Created a sorted parquet exec @@ -87,7 +87,7 @@ pub fn parquet_exec_sorted( ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(vec![sort_exprs]) - .new_exec() + 
.build() } /// Create a sorted Csv exec @@ -104,7 +104,7 @@ fn csv_exec_sorted( ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(vec![sort_exprs]) - .new_exec() + .build() } /// Runs the sort enforcement optimizer and asserts the plan diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index c9eadf009130..89bd97881e3a 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -27,9 +27,7 @@ use datafusion_common::Result; use datafusion_common::{JoinSide, JoinType, ScalarValue}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::{ - ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility, -}; +use datafusion_expr::{Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility}; use datafusion_physical_expr::expressions::{ binary, col, BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; @@ -91,14 +89,6 @@ impl ScalarUDFImpl for DummyUDF { fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Int32) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!("DummyUDF::invoke") - } } #[test] @@ -382,7 +372,7 @@ fn create_simple_csv_exec() -> Arc { ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_projection(Some(vec![0, 1, 2, 3, 4])) - .new_exec() + .build() } fn create_projecting_csv_exec() -> Arc { @@ -399,7 +389,7 @@ fn create_projecting_csv_exec() -> Arc { ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_projection(Some(vec![3, 2, 1])) - .new_exec() + .build() } fn create_projecting_memory_exec() -> Arc { diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 5e486a715b41..162f93facc90 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -75,7 +75,7 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { Arc::new(ParquetSource::default()), ) .with_file(PartitionedFile::new("x".to_string(), 100)) - .new_exec() + .build() } /// Create a single parquet file that is sorted @@ -89,7 +89,7 @@ pub(crate) fn parquet_exec_with_sort( ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(output_ordering) - .new_exec() + .build() } pub fn schema() -> SchemaRef { diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 2c3577a137ad..43e7ec9e45e4 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -31,15 +31,16 @@ use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::utils::take_function_args; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err, - not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, not_impl_err, + plan_err, DFSchema, DataFusionError, 
HashMap, Result, ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, - OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, + OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -207,11 +208,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF { Ok(self.return_type.clone()) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100)))) } } @@ -518,16 +515,13 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { Ok(self.return_type.clone()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - number_rows: usize, - ) -> Result { - let answer = match &args[0] { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [arg] = take_function_args(self.name(), &args.args)?; + let answer = match arg { // When called with static arguments, the result is returned as an array. ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { let mut answer = vec![]; - for index in 1..=number_rows { + for index in 1..=args.number_rows { // When calling a function with immutable arguments, the result is returned with ")". // Example: SELECT add_index_to_string('const_value') FROM table; answer.push(index.to_string() + ") " + value); @@ -713,14 +707,6 @@ impl ScalarUDFImpl for CastToI64UDF { // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!("Function should have been simplified prior to evaluation") - } } #[tokio::test] @@ -850,17 +836,14 @@ impl ScalarUDFImpl for TakeUDF { } // The actual implementation - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let take_idx = match &args[2] { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [_arg0, _arg1, arg2] = take_function_args(self.name(), &args.args)?; + let take_idx = match arg2 { ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "0" => 0, ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "1" => 1, _ => unreachable!(), }; - match &args[take_idx] { + match &args.args[take_idx] { ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())), ColumnarValue::Scalar(_) => unimplemented!(), } @@ -963,14 +946,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { Ok(self.return_type.clone()) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - internal_err!("This function should not get invoked!") - } - fn simplify( &self, args: Vec, diff --git a/datafusion/datasource/Cargo.toml b/datafusion/datasource/Cargo.toml new file mode 100644 index 000000000000..caf1c60a785d --- /dev/null +++ b/datafusion/datasource/Cargo.toml @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-datasource" +description = "datafusion-datasource" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +[features] +compression = ["async-compression", "xz2", "bzip2", "flate2", "zstd", "tokio-util"] +default = ["compression"] + +[dependencies] +arrow = { workspace = true } +async-compression = { version = "0.4.0", features = [ + "bzip2", + "gzip", + "xz", + "zstd", + "tokio", +], optional = true } +async-trait = { workspace = true } +bytes = { workspace = true } +bzip2 = { version = "0.5.1", optional = true } +chrono = { workspace = true } +datafusion-catalog = { workspace = true } +datafusion-common = { workspace = true, features = ["object_store"] } +datafusion-common-runtime = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-plan = { workspace = true } +flate2 = { version = "1.0.24", optional = true } +futures = { workspace = true } +glob = "0.3.0" +itertools = { workspace = true } +log = { workspace = true } +object_store = { workspace = true } +rand = { workspace = true } +tokio = { workspace = true } +tokio-util = { version = "0.7.4", features = ["io"], optional = true } +url = { workspace = true } +xz2 = { version = "0.1", optional = true, features = ["static"] } +zstd = { version = "0.13", optional = true, default-features = false } + +[dev-dependencies] +tempfile = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_datasource" +path = "src/mod.rs" diff --git a/datafusion/datasource/LICENSE.txt b/datafusion/datasource/LICENSE.txt new file mode 120000 index 000000000000..1ef648f64b34 --- /dev/null +++ b/datafusion/datasource/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/datasource/NOTICE.txt b/datafusion/datasource/NOTICE.txt new file mode 120000 index 000000000000..fb051c92b10b --- /dev/null +++ b/datafusion/datasource/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/datasource/README.md b/datafusion/datasource/README.md new file mode 100644 index 000000000000..2479a28ae68d --- /dev/null +++ b/datafusion/datasource/README.md @@ -0,0 +1,24 @@ + + +# DataFusion datasource + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that defines common DataSource related components like FileScanConfig, FileCompression etc. 
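The new `datafusion-datasource` crate takes over the modules that used to live in `datafusion-catalog-listing` (the file renames follow below), and `datafusion::datasource::listing` keeps re-exporting the moved types so existing imports continue to compile. A minimal sketch of using the crate directly, with illustrative file paths:

```rust
use datafusion_common::Result;
use datafusion_datasource::{ListingTableUrl, PartitionedFile};

fn main() -> Result<()> {
    // `ListingTableUrl` and `PartitionedFile` are now exported from the crate root
    // of `datafusion-datasource`.
    let url = ListingTableUrl::parse("file:///foo/bar.csv")?;
    assert_eq!(url.file_extension(), Some("csv"));

    // A 1 KiB file, restricted to the byte range [0, 512).
    let file = PartitionedFile::new("foo/bar.parquet", 1024).with_range(0, 512);
    assert!(file.range.as_ref().is_some_and(|r| r.contains(100)));
    Ok(())
}
```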
diff --git a/datafusion/catalog-listing/src/file_compression_type.rs b/datafusion/datasource/src/file_compression_type.rs similarity index 100% rename from datafusion/catalog-listing/src/file_compression_type.rs rename to datafusion/datasource/src/file_compression_type.rs diff --git a/datafusion/catalog-listing/src/file_groups.rs b/datafusion/datasource/src/file_groups.rs similarity index 100% rename from datafusion/catalog-listing/src/file_groups.rs rename to datafusion/datasource/src/file_groups.rs diff --git a/datafusion/catalog-listing/src/file_meta.rs b/datafusion/datasource/src/file_meta.rs similarity index 100% rename from datafusion/catalog-listing/src/file_meta.rs rename to datafusion/datasource/src/file_meta.rs diff --git a/datafusion/catalog-listing/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs similarity index 100% rename from datafusion/catalog-listing/src/file_scan_config.rs rename to datafusion/datasource/src/file_scan_config.rs diff --git a/datafusion/catalog-listing/src/file_sink_config.rs b/datafusion/datasource/src/file_sink_config.rs similarity index 100% rename from datafusion/catalog-listing/src/file_sink_config.rs rename to datafusion/datasource/src/file_sink_config.rs diff --git a/datafusion/catalog-listing/src/file_stream.rs b/datafusion/datasource/src/file_stream.rs similarity index 100% rename from datafusion/catalog-listing/src/file_stream.rs rename to datafusion/datasource/src/file_stream.rs diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs new file mode 100644 index 000000000000..c735c3108b3d --- /dev/null +++ b/datafusion/datasource/src/mod.rs @@ -0,0 +1,283 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A table that uses the `ObjectStore` listing capability +//! to get the list of files to process. + +pub mod file_compression_type; +pub mod file_groups; +pub mod file_meta; +pub mod file_scan_config; +pub mod file_sink_config; +pub mod file_stream; +pub mod url; +pub mod write; +use chrono::TimeZone; +use datafusion_common::Result; +use datafusion_common::{ScalarValue, Statistics}; +use futures::Stream; +use object_store::{path::Path, ObjectMeta}; +use std::pin::Pin; +use std::sync::Arc; + +pub use self::url::ListingTableUrl; + +/// Stream of files get listed from object store +pub type PartitionedFileStream = + Pin> + Send + Sync + 'static>>; + +/// Only scan a subset of Row Groups from the Parquet file whose data "midpoint" +/// lies within the [start, end) byte offsets. This option can be used to scan non-overlapping +/// sections of a Parquet file in parallel. 
+#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] +pub struct FileRange { + /// Range start + pub start: i64, + /// Range end + pub end: i64, +} + +impl FileRange { + /// returns true if this file range contains the specified offset + pub fn contains(&self, offset: i64) -> bool { + offset >= self.start && offset < self.end + } +} + +#[derive(Debug, Clone)] +/// A single file or part of a file that should be read, along with its schema, statistics +/// and partition column values that need to be appended to each row. +pub struct PartitionedFile { + /// Path for the file (e.g. URL, filesystem path, etc) + pub object_meta: ObjectMeta, + /// Values of partition columns to be appended to each row. + /// + /// These MUST have the same count, order, and type as the [`table_partition_cols`]. + /// + /// You may use [`wrap_partition_value_in_dict`] to wrap them if you have used [`wrap_partition_type_in_dict`] to wrap the column type. + /// + /// + /// [`wrap_partition_type_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L55 + /// [`wrap_partition_value_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L62 + /// [`table_partition_cols`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/file_format/options.rs#L190 + pub partition_values: Vec, + /// An optional file range for a more fine-grained parallel execution + pub range: Option, + /// Optional statistics that describe the data in this file if known. + /// + /// DataFusion relies on these statistics for planning (in particular to sort file groups), + /// so if they are incorrect, incorrect answers may result. + pub statistics: Option, + /// An optional field for user defined per object metadata + pub extensions: Option>, + /// The estimated size of the parquet metadata, in bytes + pub metadata_size_hint: Option, +} + +impl PartitionedFile { + /// Create a simple file without metadata or partition + pub fn new(path: impl Into, size: u64) -> Self { + Self { + object_meta: ObjectMeta { + location: Path::from(path.into()), + last_modified: chrono::Utc.timestamp_nanos(0), + size: size as usize, + e_tag: None, + version: None, + }, + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + metadata_size_hint: None, + } + } + + /// Create a file range without metadata or partition + pub fn new_with_range(path: String, size: u64, start: i64, end: i64) -> Self { + Self { + object_meta: ObjectMeta { + location: Path::from(path), + last_modified: chrono::Utc.timestamp_nanos(0), + size: size as usize, + e_tag: None, + version: None, + }, + partition_values: vec![], + range: Some(FileRange { start, end }), + statistics: None, + extensions: None, + metadata_size_hint: None, + } + .with_range(start, end) + } + + /// Provide a hint to the size of the file metadata. If a hint is provided + /// the reader will try to fetch the last `size_hint` bytes of the parquet file optimistically. + /// Without an appropriate hint, two reads may be required to fetch the metadata.
+ pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { + self.metadata_size_hint = Some(metadata_size_hint); + self + } + + /// Return a file reference from the given path + pub fn from_path(path: String) -> Result { + let size = std::fs::metadata(path.clone())?.len(); + Ok(Self::new(path, size)) + } + + /// Return the path of this partitioned file + pub fn path(&self) -> &Path { + &self.object_meta.location + } + + /// Update the file to only scan the specified range (in bytes) + pub fn with_range(mut self, start: i64, end: i64) -> Self { + self.range = Some(FileRange { start, end }); + self + } + + /// Update the user defined extensions for this file. + /// + /// This can be used to pass reader specific information. + pub fn with_extensions( + mut self, + extensions: Arc, + ) -> Self { + self.extensions = Some(extensions); + self + } +} + +impl From for PartitionedFile { + fn from(object_meta: ObjectMeta) -> Self { + PartitionedFile { + object_meta, + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + metadata_size_hint: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::ListingTableUrl; + use datafusion_execution::object_store::{ + DefaultObjectStoreRegistry, ObjectStoreRegistry, + }; + use object_store::{local::LocalFileSystem, path::Path}; + use std::{ops::Not, sync::Arc}; + use url::Url; + + #[test] + fn test_object_store_listing_url() { + let listing = ListingTableUrl::parse("file:///").unwrap(); + let store = listing.object_store(); + assert_eq!(store.as_str(), "file:///"); + + let listing = ListingTableUrl::parse("s3://bucket/").unwrap(); + let store = listing.object_store(); + assert_eq!(store.as_str(), "s3://bucket/"); + } + + #[test] + fn test_get_store_hdfs() { + let sut = DefaultObjectStoreRegistry::default(); + let url = Url::parse("hdfs://localhost:8020").unwrap(); + sut.register_store(&url, Arc::new(LocalFileSystem::new())); + let url = ListingTableUrl::parse("hdfs://localhost:8020/key").unwrap(); + sut.get_store(url.as_ref()).unwrap(); + } + + #[test] + fn test_get_store_s3() { + let sut = DefaultObjectStoreRegistry::default(); + let url = Url::parse("s3://bucket/key").unwrap(); + sut.register_store(&url, Arc::new(LocalFileSystem::new())); + let url = ListingTableUrl::parse("s3://bucket/key").unwrap(); + sut.get_store(url.as_ref()).unwrap(); + } + + #[test] + fn test_get_store_file() { + let sut = DefaultObjectStoreRegistry::default(); + let url = ListingTableUrl::parse("file:///bucket/key").unwrap(); + sut.get_store(url.as_ref()).unwrap(); + } + + #[test] + fn test_get_store_local() { + let sut = DefaultObjectStoreRegistry::default(); + let url = ListingTableUrl::parse("../").unwrap(); + sut.get_store(url.as_ref()).unwrap(); + } + + #[test] + fn test_url_contains() { + let url = ListingTableUrl::parse("file:///var/data/mytable/").unwrap(); + + // standard case with default config + assert!(url.contains( + &Path::parse("/var/data/mytable/data.parquet").unwrap(), + true + )); + + // standard case with `ignore_subdirectory` set to false + assert!(url.contains( + &Path::parse("/var/data/mytable/data.parquet").unwrap(), + false + )); + + // as per documentation, when `ignore_subdirectory` is true, we should ignore files that aren't + // a direct child of the `url` + assert!(url + .contains( + &Path::parse("/var/data/mytable/mysubfolder/data.parquet").unwrap(), + true + ) + .not()); + + // when we set `ignore_subdirectory` to false, we should not ignore the file + assert!(url.contains( + 
&Path::parse("/var/data/mytable/mysubfolder/data.parquet").unwrap(), + false + )); + + // as above, `ignore_subdirectory` is false, so we include the file + assert!(url.contains( + &Path::parse("/var/data/mytable/year=2024/data.parquet").unwrap(), + false + )); + + // in this case, we include the file even when `ignore_subdirectory` is true because the + // path segment is a hive partition which doesn't count as a subdirectory for the purposes + // of `Url::contains` + assert!(url.contains( + &Path::parse("/var/data/mytable/year=2024/data.parquet").unwrap(), + true + )); + + // testing an empty path with default config + assert!(url.contains(&Path::parse("/var/data/mytable/").unwrap(), true)); + + // testing an empty path with `ignore_subdirectory` set to false + assert!(url.contains(&Path::parse("/var/data/mytable/").unwrap(), false)); + } +} diff --git a/datafusion/catalog-listing/src/url.rs b/datafusion/datasource/src/url.rs similarity index 99% rename from datafusion/catalog-listing/src/url.rs rename to datafusion/datasource/src/url.rs index 2e6415ba3b2b..89e73a8a2b26 100644 --- a/datafusion/catalog-listing/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -193,7 +193,7 @@ impl ListingTableUrl { /// /// Examples: /// ```rust - /// use datafusion_catalog_listing::ListingTableUrl; + /// use datafusion_datasource::ListingTableUrl; /// let url = ListingTableUrl::parse("file:///foo/bar.csv").unwrap(); /// assert_eq!(url.file_extension(), Some("csv")); /// let url = ListingTableUrl::parse("file:///foo/bar").unwrap(); diff --git a/datafusion/catalog-listing/src/write/demux.rs b/datafusion/datasource/src/write/demux.rs similarity index 100% rename from datafusion/catalog-listing/src/write/demux.rs rename to datafusion/datasource/src/write/demux.rs diff --git a/datafusion/catalog-listing/src/write/mod.rs b/datafusion/datasource/src/write/mod.rs similarity index 100% rename from datafusion/catalog-listing/src/write/mod.rs rename to datafusion/datasource/src/write/mod.rs diff --git a/datafusion/catalog-listing/src/write/orchestration.rs b/datafusion/datasource/src/write/orchestration.rs similarity index 100% rename from datafusion/catalog-listing/src/write/orchestration.rs rename to datafusion/datasource/src/write/orchestration.rs diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 7cdb53c90d0e..b11596c4a30f 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -214,6 +214,7 @@ mod tests { extensions_options! 
{ struct TestExtension { value: usize, default = 42 + option_value: Option, default = None } } @@ -229,6 +230,7 @@ mod tests { let mut config = ConfigOptions::new().with_extensions(extensions); config.set("test.value", "24")?; + config.set("test.option_value", "42")?; let session_config = SessionConfig::from(config); let task_context = TaskContext::new( @@ -249,6 +251,39 @@ mod tests { assert!(test.is_some()); assert_eq!(test.unwrap().value, 24); + assert_eq!(test.unwrap().option_value, Some(42)); + + Ok(()) + } + + #[test] + fn task_context_extensions_default() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let mut extensions = Extensions::new(); + extensions.insert(TestExtension::default()); + + let config = ConfigOptions::new().with_extensions(extensions); + let session_config = SessionConfig::from(config); + + let task_context = TaskContext::new( + Some("task_id".to_string()), + "session_id".to_string(), + session_config, + HashMap::default(), + HashMap::default(), + HashMap::default(), + runtime, + ); + + let test = task_context + .session_config() + .options() + .extensions + .get::(); + assert!(test.is_some()); + + assert_eq!(test.unwrap().value, 42); + assert_eq!(test.unwrap().option_value, None); Ok(()) } diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 1bfae28af840..4ca4961d7b63 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -19,11 +19,11 @@ //! and return types of functions in DataFusion. use std::fmt::Display; -use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; use datafusion_common::types::{LogicalTypeRef, NativeType}; +use datafusion_common::utils::ListCoercion; use itertools::Itertools; /// Constant that is used as a placeholder for any valid timezone. @@ -227,25 +227,13 @@ impl Display for TypeSignatureClass { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { - /// Specialized Signature for ArrayAppend and similar functions - /// The first argument should be List/LargeList/FixedSizedList, and the second argument should be non-list or list. - /// The second argument's list dimension should be one dimension less than the first argument's list dimension. - /// List dimension of the List/LargeList is equivalent to the number of List. - /// List dimension of the non-list is 0. - ArrayAndElement, - /// Specialized Signature for ArrayPrepend and similar functions - /// The first argument should be non-list or list, and the second argument should be List/LargeList. - /// The first argument's list dimension should be one dimension less than the second argument's list dimension. - ElementAndArray, - /// Specialized Signature for Array functions of the form (List/LargeList, Index+) - /// The first argument should be List/LargeList/FixedSizedList, and the next n arguments should be Int64. - ArrayAndIndexes(NonZeroUsize), - /// Specialized Signature for Array functions of the form (List/LargeList, Element, Optional Index) - ArrayAndElementAndOptionalIndex, - /// Specialized Signature for ArrayEmpty and similar functions - /// The function takes a single argument that must be a List/LargeList/FixedSizeList - /// or something that can be coerced to one of those types. - Array, + /// A function takes at least one List/LargeList/FixedSizeList argument. + Array { + /// A full list of the arguments accepted by this function. 
+ arguments: Vec, + /// Additional information about how array arguments should be coerced. + array_coercion: Option, + }, /// A function takes a single argument that must be a List/LargeList/FixedSizeList /// which gets coerced to List, with element type recursively coerced to List too if it is list-like. RecursiveArray, @@ -257,25 +245,15 @@ pub enum ArrayFunctionSignature { impl Display for ArrayFunctionSignature { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ArrayFunctionSignature::ArrayAndElement => { - write!(f, "array, element") - } - ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => { - write!(f, "array, element, [index]") - } - ArrayFunctionSignature::ElementAndArray => { - write!(f, "element, array") - } - ArrayFunctionSignature::ArrayAndIndexes(count) => { - write!(f, "array")?; - for _ in 0..count.get() { - write!(f, ", index")?; + ArrayFunctionSignature::Array { arguments, .. } => { + for (idx, argument) in arguments.iter().enumerate() { + write!(f, "{argument}")?; + if idx != arguments.len() - 1 { + write!(f, ", ")?; + } } Ok(()) } - ArrayFunctionSignature::Array => { - write!(f, "array") - } ArrayFunctionSignature::RecursiveArray => { write!(f, "recursive_array") } @@ -286,6 +264,34 @@ impl Display for ArrayFunctionSignature { } } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum ArrayFunctionArgument { + /// A non-list or list argument. The list dimensions should be one less than the Array's list + /// dimensions. + Element, + /// An Int64 index argument. + Index, + /// An argument of type List/LargeList/FixedSizeList. All Array arguments must be coercible + /// to the same type. + Array, +} + +impl Display for ArrayFunctionArgument { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ArrayFunctionArgument::Element => { + write!(f, "element") + } + ArrayFunctionArgument::Index => { + write!(f, "index") + } + ArrayFunctionArgument::Array => { + write!(f, "array") + } + } + } +} + impl TypeSignature { pub fn to_string_repr(&self) -> Vec { match self { @@ -580,7 +586,13 @@ impl Signature { pub fn array_and_element(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndElement, + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, ), volatility, } @@ -588,30 +600,38 @@ impl Signature { /// Specialized Signature for Array functions with an optional index pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self { Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndElementAndOptionalIndex, - ), - volatility, - } - } - /// Specialized Signature for ArrayPrepend and similar functions - pub fn element_and_array(volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ElementAndArray, - ), + type_signature: TypeSignature::OneOf(vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ], + array_coercion: None, + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), + ]), volatility, } } + /// 
Specialized Signature for ArrayElement and similar functions pub fn array_and_index(volatility: Volatility) -> Self { - Self::array_and_indexes(volatility, NonZeroUsize::new(1).expect("1 is non-zero")) - } - /// Specialized Signature for ArraySlice and similar functions - pub fn array_and_indexes(volatility: Volatility, count: NonZeroUsize) -> Self { Signature { type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndexes(count), + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }, ), volatility, } @@ -619,7 +639,12 @@ impl Signature { /// Specialized Signature for ArrayEmpty and similar functions pub fn array(volatility: Volatility) -> Self { Signature { - type_signature: TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }, + ), volatility, } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a2de5e7b259f..eb5f98930a00 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -27,7 +27,7 @@ use crate::function::{ }; use crate::{ conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, + AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; use crate::{ @@ -477,12 +477,8 @@ impl ScalarUDFImpl for SimpleScalarUDF { Ok(self.return_type.clone()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - (self.fun)(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + (self.fun)(&args.args) } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index aaa65c676a42..2f04f234eb1d 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -71,8 +71,8 @@ pub use datafusion_expr_common::columnar_value::ColumnarValue; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; pub use datafusion_expr_common::signature::{ - ArrayFunctionSignature, Signature, TypeSignature, TypeSignatureClass, Volatility, - TIMEZONE_WILDCARD, + ArrayFunctionArgument, ArrayFunctionSignature, Signature, TypeSignature, + TypeSignatureClass, Volatility, TIMEZONE_WILDCARD, }; pub use datafusion_expr_common::type_coercion::binary; pub use expr::{ diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4fdfb84aea42..45889e96b57f 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -384,14 +384,12 @@ impl LogicalPlanBuilder { pub fn insert_into( input: LogicalPlan, table_name: impl Into, - table_schema: &Schema, + target: Arc, insert_op: InsertOp, ) -> Result { - let table_schema = table_schema.clone().to_dfschema_ref()?; - Ok(Self::new(LogicalPlan::Dml(DmlStatement::new( table_name.into(), - table_schema, + target, WriteOp::Insert(insert_op), Arc::new(input), )))) @@ -722,6 +720,21 @@ impl LogicalPlanBuilder { union(Arc::unwrap_or_clone(self.plan), plan).map(Self::new) } + /// Apply a union by name, preserving duplicate rows + pub fn union_by_name(self, plan: LogicalPlan) -> Result { + 
union_by_name(Arc::unwrap_or_clone(self.plan), plan).map(Self::new) + } + + /// Apply a union by name, removing duplicate rows + pub fn union_by_name_distinct(self, plan: LogicalPlan) -> Result { + let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan); + let right_plan: LogicalPlan = plan; + + Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new( + union_by_name(left_plan, right_plan)?, + ))))) + } + /// Apply a union, removing duplicate rows pub fn union_distinct(self, plan: LogicalPlan) -> Result { let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan); @@ -834,10 +847,16 @@ impl LogicalPlanBuilder { plan: &LogicalPlan, column: impl Into, ) -> Result { + let column = column.into(); + if column.relation.is_some() { + // column is already normalized + return Ok(column); + } + let schema = plan.schema(); let fallback_schemas = plan.fallback_normalize_schemas(); let using_columns = plan.using_columns()?; - column.into().normalize_with_schemas_and_ambiguity_check( + column.normalize_with_schemas_and_ambiguity_check( &[&[schema], &fallback_schemas], &using_columns, ) @@ -1540,6 +1559,18 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result Result { + Ok(LogicalPlan::Union(Union::try_new_by_name(vec![ + Arc::new(left_plan), + Arc::new(right_plan), + ])?)) +} + /// Create Projection /// # Errors /// This function errors under any of the following conditions: diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 669bc8e8a7d3..d4d50ac4eae4 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -25,7 +25,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchemaRef, TableReference}; -use crate::LogicalPlan; +use crate::{LogicalPlan, TableSource}; /// Operator that copies the contents of a database to file(s) #[derive(Clone)] @@ -91,12 +91,12 @@ impl Hash for CopyTo { /// The operator that modifies the content of a database (adapted from /// substrait WriteRel) -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Clone)] pub struct DmlStatement { /// The table name pub table_name: TableReference, - /// The schema of the table (must align with Rel input) - pub table_schema: DFSchemaRef, + /// this is target table to insert into + pub target: Arc, /// The type of operation to perform pub op: WriteOp, /// The relation that determines the tuples to add/remove/modify the schema must match with table_schema @@ -104,18 +104,51 @@ pub struct DmlStatement { /// The schema of the output relation pub output_schema: DFSchemaRef, } +impl Eq for DmlStatement {} +impl Hash for DmlStatement { + fn hash(&self, state: &mut H) { + self.table_name.hash(state); + self.target.schema().hash(state); + self.op.hash(state); + self.input.hash(state); + self.output_schema.hash(state); + } +} + +impl PartialEq for DmlStatement { + fn eq(&self, other: &Self) -> bool { + self.table_name == other.table_name + && self.target.schema() == other.target.schema() + && self.op == other.op + && self.input == other.input + && self.output_schema == other.output_schema + } +} + +impl Debug for DmlStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("DmlStatement") + .field("table_name", &self.table_name) + .field("target", &"...") + .field("target_schema", &self.target.schema()) + .field("op", &self.op) + .field("input", &self.input) + .field("output_schema", &self.output_schema) + .finish() 
+ } +} impl DmlStatement { /// Creates a new DML statement with the output schema set to a single `count` column. pub fn new( table_name: TableReference, - table_schema: DFSchemaRef, + target: Arc, op: WriteOp, input: Arc, ) -> Self { Self { table_name, - table_schema, + target, op, input, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index daf1a1375eac..a07da8adde78 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -18,7 +18,7 @@ //! Logical plan types use std::cmp::Ordering; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::{Arc, LazyLock}; @@ -705,6 +705,13 @@ impl LogicalPlan { // If inputs are not pruned do not change schema Ok(LogicalPlan::Union(Union { inputs, schema })) } else { + // A note on `Union`s constructed via `try_new_by_name`: + // + // At this point, the schema for each input should have + // the same width. Thus, we do not need to save whether a + // `Union` was created `BY NAME`, and can safely rely on the + // `try_new` initializer to derive the new schema based on + // column positions. Ok(LogicalPlan::Union(Union::try_new(inputs)?)) } } @@ -784,7 +791,7 @@ impl LogicalPlan { } LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, .. }) => { @@ -792,7 +799,7 @@ impl LogicalPlan { let input = self.only_input(inputs)?; Ok(LogicalPlan::Dml(DmlStatement::new( table_name.clone(), - Arc::clone(table_schema), + Arc::clone(target), op.clone(), Arc::new(input), ))) @@ -2648,7 +2655,7 @@ pub struct Union { impl Union { /// Constructs new Union instance deriving schema from inputs. fn try_new(inputs: Vec>) -> Result { - let schema = Self::derive_schema_from_inputs(&inputs, false)?; + let schema = Self::derive_schema_from_inputs(&inputs, false, false)?; Ok(Union { inputs, schema }) } @@ -2657,21 +2664,143 @@ impl Union { /// take type from the first input. // TODO (https://github.com/apache/datafusion/issues/14380): Avoid creating uncoerced union at all. pub fn try_new_with_loose_types(inputs: Vec>) -> Result { - let schema = Self::derive_schema_from_inputs(&inputs, true)?; + let schema = Self::derive_schema_from_inputs(&inputs, true, false)?; Ok(Union { inputs, schema }) } + /// Constructs a new Union instance that combines rows from different tables by name, + /// instead of by position. This means that the specified inputs need not have schemas + /// that are all the same width. + pub fn try_new_by_name(inputs: Vec>) -> Result { + let schema = Self::derive_schema_from_inputs(&inputs, true, true)?; + let inputs = Self::rewrite_inputs_from_schema(&schema, inputs)?; + + Ok(Union { inputs, schema }) + } + + /// When constructing a `UNION BY NAME`, we may need to wrap inputs + /// in an additional `Projection` to account for absence of columns + /// in input schemas. + fn rewrite_inputs_from_schema( + schema: &DFSchema, + inputs: Vec>, + ) -> Result>> { + let schema_width = schema.iter().count(); + let mut wrapped_inputs = Vec::with_capacity(inputs.len()); + for input in inputs { + // If the input plan's schema contains the same number of fields + // as the derived schema, then it does not to be wrapped in an + // additional `Projection`. 
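The projection wrapping described in the comment above is easiest to see through the builder API introduced earlier in this diff. A minimal sketch (the input plans and their column names are hypothetical, not taken from this PR):

```rust
use datafusion_common::Result;
use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
use datafusion_expr::LogicalPlan;

/// Sketch: `left` exposes columns (a, b) while `right` only exposes (a).
/// The derived schema is (a, b), and `right` is wrapped in a projection
/// equivalent to `SELECT a, NULL AS b FROM right` before the union is built.
fn union_by_name_example(left: LogicalPlan, right: LogicalPlan) -> Result<LogicalPlan> {
    LogicalPlanBuilder::from(left).union_by_name(right)?.build()
}
```

`union_by_name_distinct` follows the same path and then wraps the result in `Distinct::All`, mirroring the existing `union_distinct` method.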
+ if input.schema().iter().count() == schema_width { + wrapped_inputs.push(input); + continue; + } + + // Any columns that exist within the derived schema but do not exist + // within an input's schema should be replaced with `NULL` aliased + // to the appropriate column in the derived schema. + let mut expr = Vec::with_capacity(schema_width); + for column in schema.columns() { + if input + .schema() + .has_column_with_unqualified_name(column.name()) + { + expr.push(Expr::Column(column)); + } else { + expr.push(Expr::Literal(ScalarValue::Null).alias(column.name())); + } + } + wrapped_inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new( + expr, input, + )?))); + } + + Ok(wrapped_inputs) + } + /// Constructs new Union instance deriving schema from inputs. /// - /// `loose_types` if true, inputs do not have to have matching types and produced schema will - /// take type from the first input. TODO () this is not necessarily reasonable behavior. + /// If `loose_types` is true, inputs do not need to have matching types and + /// the produced schema will use the type from the first input. + /// TODO (): This is not necessarily reasonable behavior. + /// + /// If `by_name` is `true`, input schemas need not be the same width. That is, + /// the constructed schema follows `UNION BY NAME` semantics. fn derive_schema_from_inputs( inputs: &[Arc], loose_types: bool, + by_name: bool, ) -> Result { if inputs.len() < 2 { return plan_err!("UNION requires at least two inputs"); } + + if by_name { + Self::derive_schema_from_inputs_by_name(inputs, loose_types) + } else { + Self::derive_schema_from_inputs_by_position(inputs, loose_types) + } + } + + fn derive_schema_from_inputs_by_name( + inputs: &[Arc], + loose_types: bool, + ) -> Result { + type FieldData<'a> = (&'a DataType, bool, Vec<&'a HashMap>); + // Prefer `BTreeMap` as it produces items in order by key when iterated over + let mut cols: BTreeMap<&str, FieldData> = BTreeMap::new(); + for input in inputs.iter() { + for field in input.schema().fields() { + match cols.entry(field.name()) { + std::collections::btree_map::Entry::Occupied(mut occupied) => { + let (data_type, is_nullable, metadata) = occupied.get_mut(); + if !loose_types && *data_type != field.data_type() { + return plan_err!( + "Found different types for field {}", + field.name() + ); + } + + metadata.push(field.metadata()); + // If the field is nullable in any one of the inputs, + // then the field in the final schema is also nullable. 
+ *is_nullable |= field.is_nullable(); + } + std::collections::btree_map::Entry::Vacant(vacant) => { + vacant.insert(( + field.data_type(), + field.is_nullable(), + vec![field.metadata()], + )); + } + } + } + } + + let union_fields = cols + .into_iter() + .map(|(name, (data_type, is_nullable, unmerged_metadata))| { + let mut field = Field::new(name, data_type.clone(), is_nullable); + field.set_metadata(intersect_maps(unmerged_metadata)); + + (None, Arc::new(field)) + }) + .collect::, _)>>(); + + let union_schema_metadata = + intersect_maps(inputs.iter().map(|input| input.schema().metadata())); + + // Functional Dependencies are not preserved after UNION operation + let schema = DFSchema::new_with_metadata(union_fields, union_schema_metadata)?; + let schema = Arc::new(schema); + + Ok(schema) + } + + fn derive_schema_from_inputs_by_position( + inputs: &[Arc], + loose_types: bool, + ) -> Result { let first_schema = inputs[0].schema(); let fields_count = first_schema.fields().len(); for input in inputs.iter().skip(1) { @@ -2727,7 +2856,7 @@ impl Union { let union_schema_metadata = intersect_maps(inputs.iter().map(|input| input.schema().metadata())); - // Functional Dependencies doesn't preserve after UNION operation + // Functional Dependencies are not preserved after UNION operation let schema = DFSchema::new_with_metadata(union_fields, union_schema_metadata)?; let schema = Arc::new(schema); diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 9a6103afd4b4..dfc18c74c70a 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -228,14 +228,14 @@ impl TreeNode for LogicalPlan { }), LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, input, output_schema, }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, input, output_schema, diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 42047e8e6caa..04cc26c910cb 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -29,13 +29,18 @@ use sqlparser::ast; use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; -/// Provides the `SQL` query planner meta-data about tables and -/// functions referenced in SQL statements, without a direct dependency on other -/// DataFusion structures +/// Provides the `SQL` query planner meta-data about tables and +/// functions referenced in SQL statements, without a direct dependency on the +/// `datafusion` Catalog structures such as [`TableProvider`] +/// +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html pub trait ContextProvider { - /// Getter for a datasource + /// Returns a table by reference, if it exists fn get_table_source(&self, name: TableReference) -> Result>; + /// Return the type of a file based on its extension (e.g. 
`.parquet`) + /// + /// This is used to plan `COPY` statements fn get_file_type(&self, _ext: &str) -> Result> { not_impl_err!("Registered file types are not supported") } @@ -49,11 +54,20 @@ pub trait ContextProvider { not_impl_err!("Table Functions are not supported") } - /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) - /// We don't directly implement this in the logical plan's ['SqlToRel`] - /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency - /// of the sql crate (namely, the `CteWorktable`). + /// Provides an intermediate table that is used to store the results of a CTE during execution + /// + /// CTE stands for "Common Table Expression" + /// + /// # Notes + /// We don't directly implement this in [`SqlToRel`] as implementing this function + /// often requires access to a table that contains + /// execution-related types that can't be a direct dependency + /// of the sql crate (for example [`CteWorkTable`]). + /// /// The [`ContextProvider`] provides a way to "hide" this dependency. + /// + /// [`SqlToRel`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html + /// [`CteWorkTable`]: https://docs.rs/datafusion/latest/datafusion/datasource/cte_worktable/struct.CteWorkTable.html fn create_cte_work_table( &self, _name: &str, @@ -62,39 +76,44 @@ pub trait ContextProvider { not_impl_err!("Recursive CTE is not implemented") } - /// Getter for expr planners + /// Return [`ExprPlanner`] extensions for planning expressions fn get_expr_planners(&self) -> &[Arc] { &[] } - /// Getter for the data type planner + /// Return [`TypePlanner`] extensions for planning data types fn get_type_planner(&self) -> Option> { None } - /// Getter for a UDF description + /// Return the scalar function with a given name, if any fn get_function_meta(&self, name: &str) -> Option>; - /// Getter for a UDAF description + + /// Return the aggregate function with a given name, if any fn get_aggregate_meta(&self, name: &str) -> Option>; - /// Getter for a UDWF + + /// Return the window function with a given name, if any fn get_window_meta(&self, name: &str) -> Option>; - /// Getter for system/user-defined variable type + + /// Return the system/user-defined variable type, if any + /// + /// A user defined variable is typically accessed via `@var_name` fn get_variable_type(&self, variable_names: &[String]) -> Option; - /// Get configuration options + /// Return overall configuration options fn options(&self) -> &ConfigOptions; - /// Get all user defined scalar function names + /// Return all scalar function names fn udf_names(&self) -> Vec; - /// Get all user defined aggregate function names + /// Return all aggregate function names fn udaf_names(&self) -> Vec; - /// Get all user defined window function names + /// Return all window function names fn udwf_names(&self) -> Vec; } -/// This trait allows users to customize the behavior of the SQL planner +/// Customize planning of SQL AST expressions to [`Expr`]s pub trait ExprPlanner: Debug + Send + Sync { /// Plan the binary operation between two expressions, returns original /// BinaryExpr if not possible @@ -106,9 +125,9 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(expr)) } - /// Plan the field access expression + /// Plan the field access expression, such as `foo.bar` /// - /// returns original FieldAccessExpr if not possible + /// returns original [`RawFieldAccessExpr`] if not 
possible fn plan_field_access( &self, expr: RawFieldAccessExpr, @@ -117,7 +136,7 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(expr)) } - /// Plan the array literal, returns OriginalArray if not possible + /// Plan an array literal, such as `[1, 2, 3]` /// /// Returns origin expression arguments if not possible fn plan_array_literal( @@ -128,13 +147,14 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(exprs)) } - // Plan the POSITION expression, e.g., POSITION(<expr> in <expr>) - // returns origin expression arguments if not possible + /// Plan a `POSITION` expression, such as `POSITION(<expr> in <expr>)` + /// + /// returns origin expression arguments if not possible fn plan_position(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> { Ok(PlannerResult::Original(args)) } - /// Plan the dictionary literal `{ key: value, ...}` + /// Plan a dictionary literal, such as `{ key: value, ...}` /// /// Returns origin expression arguments if not possible fn plan_dictionary_literal( @@ -145,27 +165,26 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(expr)) } - /// Plan an extract expression, e.g., `EXTRACT(month FROM foo)` + /// Plan an extract expression, such as `EXTRACT(month FROM foo)` /// /// Returns origin expression arguments if not possible fn plan_extract(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> { Ok(PlannerResult::Original(args)) } - /// Plan an substring expression, e.g., `SUBSTRING(<expr> [FROM <expr>] [FOR <expr>])` + /// Plan a substring expression, such as `SUBSTRING(<expr> [FROM <expr>] [FOR <expr>])` /// /// Returns origin expression arguments if not possible fn plan_substring(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> { Ok(PlannerResult::Original(args)) } - /// Plans a struct `struct(expression1[, ..., expression_n])` - /// literal based on the given input expressions. - /// This function takes a vector of expressions and a boolean flag indicating whether - /// the struct uses the optional name + /// Plans a struct literal, such as `{'field1' : expr1, 'field2' : expr2, ...}` + /// + /// This function takes a vector of expressions and a boolean flag + /// indicating whether the struct uses the optional name /// - /// Returns a `PlannerResult` containing either the planned struct expressions or the original - /// input expressions if planning is not possible. + /// Returns the original input expressions if planning is not possible. fn plan_struct_literal( &self, args: Vec<Expr>, @@ -174,26 +193,26 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(args)) } - /// Plans an overlay expression eg `overlay(str PLACING substr FROM pos [FOR count])` + /// Plans an overlay expression, such as `overlay(str PLACING substr FROM pos [FOR count])` /// /// Returns origin expression arguments if not possible fn plan_overlay(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> { Ok(PlannerResult::Original(args)) } - /// Plan a make_map expression, e.g., `make_map(key1, value1, key2, value2, ...)` + /// Plans a `make_map` expression, such as `make_map(key1, value1, key2, value2, ...)` /// /// Returns origin expression arguments if not possible fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> { Ok(PlannerResult::Original(args)) } - /// Plans compound identifier eg `db.schema.table` for non-empty nested names + /// Plans a compound identifier such as `db.schema.table` for non-empty nested names /// - /// Note: + /// # Note: /// Currently compound identifier for outer query schema is not supported.
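As a rough illustration of how these hooks are meant to be used (a hypothetical planner, not code from this PR), an `ExprPlanner` overrides only the cases it wants to claim and hands everything else back via `PlannerResult::Original`:

```rust
use datafusion_common::{DFSchema, Result, ScalarValue};
use datafusion_expr::planner::{ExprPlanner, PlannerResult};
use datafusion_expr::Expr;

/// Hypothetical planner: claims empty array literals and defers the rest.
#[derive(Debug)]
struct EmptyArrayPlanner;

impl ExprPlanner for EmptyArrayPlanner {
    fn plan_array_literal(
        &self,
        exprs: Vec<Expr>,
        _schema: &DFSchema,
    ) -> Result<PlannerResult<Vec<Expr>>> {
        if exprs.is_empty() {
            // Hypothetical behavior chosen for the example: plan `[]` as NULL.
            Ok(PlannerResult::Planned(Expr::Literal(ScalarValue::Null)))
        } else {
            // Not handled here; the next registered planner (or the default
            // behavior) gets the original arguments back.
            Ok(PlannerResult::Original(exprs))
        }
    }
}
```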
/// - /// Returns planned expression + /// Returns original expression if not possible fn plan_compound_identifier( &self, _field: &Field, @@ -205,7 +224,7 @@ pub trait ExprPlanner: Debug + Send + Sync { ) } - /// Plans `ANY` expression, e.g., `expr = ANY(array_expr)` + /// Plans `ANY` expression, such as `expr = ANY(array_expr)` /// /// Returns origin binary expression if not possible fn plan_any(&self, expr: RawBinaryExpr) -> Result> { @@ -256,9 +275,9 @@ pub enum PlannerResult { Original(T), } -/// This trait allows users to customize the behavior of the data type planning +/// Customize planning SQL types to DataFusion (Arrow) types. pub trait TypePlanner: Debug + Send + Sync { - /// Plan SQL type to DataFusion data type + /// Plan SQL [`ast::DataType`] to DataFusion [`DataType`] /// /// Returns None if not possible fn plan_type(&self, _sql_type: &ast::DataType) -> Result> { diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index d62484153f53..d6155cfb5dc0 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -71,24 +71,33 @@ impl std::fmt::Display for TableType { } } -/// Access schema information and filter push-down capabilities. +/// Planning time information about a table. /// -/// The TableSource trait is used during logical query planning and -/// optimizations and provides a subset of the functionality of the -/// `TableProvider` trait in the (core) `datafusion` crate. The `TableProvider` -/// trait provides additional capabilities needed for physical query execution -/// (such as the ability to perform a scan). +/// This trait is used during logical query planning and optimizations, and +/// provides a subset of the [`TableProvider`] trait, such as schema information +/// and filter push-down capabilities. The [`TableProvider`] trait provides +/// additional information needed for physical query execution, such as the +/// ability to perform a scan or insert data. +/// +/// # See Also: +/// +/// [`DefaultTableSource`] to go from [`TableProvider`], to `TableSource` +/// +/// # Rationale /// /// The reason for having two separate traits is to avoid having the logical /// plan code be dependent on the DataFusion execution engine. Some projects use /// DataFusion's logical plans and have their own execution engine. +/// +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html +/// [`DefaultTableSource`]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html pub trait TableSource: Sync + Send { fn as_any(&self) -> &dyn Any; /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef; - /// Get primary key indices, if one exists. + /// Get primary key indices, if any fn constraints(&self) -> Option<&Constraints> { None } @@ -110,6 +119,8 @@ pub trait TableSource: Sync + Send { } /// Get the Logical plan of this table provider, if available. + /// + /// For example, a view may have a logical plan, but a CSV file does not. 
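To make the planning-time/execution-time split concrete, here is a minimal source, sketched under the assumption that only `as_any` and `schema` need explicit implementations while every other method keeps its default:

```rust
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_expr::TableSource;

/// Sketch: the smallest table the SQL planner needs to know about.
/// Execution concerns (scanning, inserting) stay in the full `TableProvider`.
#[derive(Debug)]
struct StaticTable {
    schema: SchemaRef,
}

impl StaticTable {
    fn new() -> Self {
        Self {
            schema: Arc::new(Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8, true),
            ])),
        }
    }
}

impl TableSource for StaticTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    // A view-like source could additionally override `get_logical_plan`
    // so the optimizer can see its defining plan.
}
```

The `DefaultTableSource` mentioned in the doc comment above plays the same role for an existing `TableProvider`.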
fn get_logical_plan(&self) -> Option> { None } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 7ac836ef3aeb..7fda92862be9 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,13 +21,14 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::utils::coerced_fixed_size_list_to_list; +use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, types::{LogicalType, NativeType}, utils::list_ndims, Result, }; +use datafusion_expr_common::signature::ArrayFunctionArgument; use datafusion_expr_common::{ signature::{ ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD, @@ -357,88 +358,81 @@ fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], ) -> Result>> { - fn array_element_and_optional_index( + fn array_valid_types( function_name: &str, current_types: &[DataType], + arguments: &[ArrayFunctionArgument], + array_coercion: Option<&ListCoercion>, ) -> Result>> { - // make sure there's 2 or 3 arguments - if !(current_types.len() == 2 || current_types.len() == 3) { + if current_types.len() != arguments.len() { return Ok(vec![vec![]]); } - let first_two_types = ¤t_types[0..2]; - let mut valid_types = - array_append_or_prepend_valid_types(function_name, first_two_types, true)?; - - // Early return if there are only 2 arguments - if current_types.len() == 2 { - return Ok(valid_types); - } - - let valid_types_with_index = valid_types - .iter() - .map(|t| { - let mut t = t.clone(); - t.push(DataType::Int64); - t - }) - .collect::>(); - - valid_types.extend(valid_types_with_index); - - Ok(valid_types) - } - - fn array_append_or_prepend_valid_types( - function_name: &str, - current_types: &[DataType], - is_append: bool, - ) -> Result>> { - if current_types.len() != 2 { - return Ok(vec![vec![]]); - } - - let (array_type, elem_type) = if is_append { - (¤t_types[0], ¤t_types[1]) - } else { - (¤t_types[1], ¤t_types[0]) + let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| { + if *arg == ArrayFunctionArgument::Array { + Some(idx) + } else { + None + } + }); + let Some(array_idx) = array_idx else { + return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument")); }; - - // We follow Postgres on `array_append(Null, T)`, which is not valid. 
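The coercion step that `array_valid_types` performs can be sketched in isolation for the `array_append(List(null), i64)` case cited just below; the helper paths here are my reading of the crates involved, not code from this PR:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field};
use datafusion_common::utils::base_type;
use datafusion_expr_common::type_coercion::binary::comparison_coercion;

fn main() {
    // `array_append(List(Null), Int64)`: strip the list wrapper, unify the
    // base types, then rebuild the list type around the unified element type.
    let array_type = DataType::List(Arc::new(Field::new("item", DataType::Null, true)));
    let element_type = DataType::Int64;

    let unified = comparison_coercion(&base_type(&array_type), &base_type(&element_type));
    assert_eq!(unified, Some(DataType::Int64));
}
```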
- if array_type.eq(&DataType::Null) { + let Some(array_type) = array(¤t_types[array_idx]) else { return Ok(vec![vec![]]); - } + }; // We need to find the coerced base type, mainly for cases like: // `array_append(List(null), i64)` -> `List(i64)` - let array_base_type = datafusion_common::utils::base_type(array_type); - let elem_base_type = datafusion_common::utils::base_type(elem_type); - let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); - - let new_base_type = new_base_type.ok_or_else(|| { - internal_datafusion_err!( - "Function '{function_name}' does not support coercion from {array_base_type:?} to {elem_base_type:?}" - ) - })?; - + let mut new_base_type = datafusion_common::utils::base_type(&array_type); + for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { + match argument_type { + ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => { + new_base_type = + coerce_array_types(function_name, current_type, &new_base_type)?; + } + ArrayFunctionArgument::Index => {} + } + } let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only( - array_type, + &array_type, &new_base_type, + array_coercion, ); - match new_array_type { + let new_elem_type = match new_array_type { DataType::List(ref field) | DataType::LargeList(ref field) - | DataType::FixedSizeList(ref field, _) => { - let new_elem_type = field.data_type(); - if is_append { - Ok(vec![vec![new_array_type.clone(), new_elem_type.clone()]]) - } else { - Ok(vec![vec![new_elem_type.to_owned(), new_array_type.clone()]]) + | DataType::FixedSizeList(ref field, _) => field.data_type(), + _ => return Ok(vec![vec![]]), + }; + + let mut valid_types = Vec::with_capacity(arguments.len()); + for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { + let valid_type = match argument_type { + ArrayFunctionArgument::Element => new_elem_type.clone(), + ArrayFunctionArgument::Index => DataType::Int64, + ArrayFunctionArgument::Array => { + let Some(current_type) = array(current_type) else { + return Ok(vec![vec![]]); + }; + let new_type = + datafusion_common::utils::coerced_type_with_base_type_only( + ¤t_type, + &new_base_type, + array_coercion, + ); + // All array arguments must be coercible to the same type + if new_type != new_array_type { + return Ok(vec![vec![]]); + } + new_type } - } - _ => Ok(vec![vec![]]), + }; + valid_types.push(valid_type); } + + Ok(vec![valid_types]) } fn array(array_type: &DataType) -> Option { @@ -449,6 +443,20 @@ fn get_valid_types( } } + fn coerce_array_types( + function_name: &str, + current_type: &DataType, + base_type: &DataType, + ) -> Result { + let current_base_type = datafusion_common::utils::base_type(current_type); + let new_base_type = comparison_coercion(base_type, ¤t_base_type); + new_base_type.ok_or_else(|| { + internal_datafusion_err!( + "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}" + ) + }) + } + fn recursive_array(array_type: &DataType) -> Option { match array_type { DataType::List(_) @@ -693,40 +701,9 @@ fn get_valid_types( vec![current_types.to_vec()] } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], - TypeSignature::ArraySignature(ref function_signature) => match function_signature - { - ArrayFunctionSignature::ArrayAndElement => { - array_append_or_prepend_valid_types(function_name, current_types, true)? - } - ArrayFunctionSignature::ElementAndArray => { - array_append_or_prepend_valid_types(function_name, current_types, false)? 
- } - ArrayFunctionSignature::ArrayAndIndexes(count) => { - if current_types.len() != count.get() + 1 { - return Ok(vec![vec![]]); - } - array(¤t_types[0]).map_or_else( - || vec![vec![]], - |array_type| { - let mut inner = Vec::with_capacity(count.get() + 1); - inner.push(array_type); - for _ in 0..count.get() { - inner.push(DataType::Int64); - } - vec![inner] - }, - ) - } - ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => { - array_element_and_optional_index(function_name, current_types)? - } - ArrayFunctionSignature::Array => { - if current_types.len() != 1 { - return Ok(vec![vec![]]); - } - - array(¤t_types[0]) - .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) + TypeSignature::ArraySignature(ref function_signature) => match function_signature { + ArrayFunctionSignature::Array { arguments, array_coercion, } => { + array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())? } ArrayFunctionSignature::RecursiveArray => { if current_types.len() != 1 { diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index b41d97520362..74c3c2775c1c 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -899,13 +899,8 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type_from_args(args) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - number_rows: usize, - ) -> Result { - #[allow(deprecated)] - self.inner.invoke_batch(args, number_rows) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + self.inner.invoke_with_args(args) } fn simplify( @@ -980,6 +975,7 @@ pub mod scalar_doc_sections { DOC_SECTION_STRUCT, DOC_SECTION_MAP, DOC_SECTION_HASHING, + DOC_SECTION_UNION, DOC_SECTION_OTHER, ] } @@ -996,6 +992,7 @@ pub mod scalar_doc_sections { DOC_SECTION_STRUCT, DOC_SECTION_MAP, DOC_SECTION_HASHING, + DOC_SECTION_UNION, DOC_SECTION_OTHER, ] } @@ -1070,4 +1067,10 @@ The following regular expression functions are supported:"#, label: "Other Functions", description: None, }; + + pub const DOC_SECTION_UNION: DocSection = DocSection { + include: true, + label: "Union Functions", + description: Some("Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. 
Note: Not related to the SQL UNION operator"), + }; } diff --git a/datafusion/expr/src/var_provider.rs b/datafusion/expr/src/var_provider.rs index e00cf7407237..708cd576c3ff 100644 --- a/datafusion/expr/src/var_provider.rs +++ b/datafusion/expr/src/var_provider.rs @@ -38,6 +38,12 @@ pub trait VarProvider: std::fmt::Debug { fn get_type(&self, var_names: &[String]) -> Option; } +/// Returns true if the specified string is a "system" variable such as +/// `@@version` +/// +/// See [`SessionContext::register_variable`] for more details +/// +/// [`SessionContext::register_variable`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.register_variable pub fn is_system_variables(variable_names: &[String]) -> bool { !variable_names.is_empty() && variable_names[0].get(0..2) == Some("@@") } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 237a4f8de6a7..1fad5f73703c 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -32,7 +32,7 @@ use arrow::{ use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, - DataFusionError, Result, ScalarValue, + Result, ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 29dfc68e0576..1b33a7900c00 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -29,7 +29,7 @@ use arrow::datatypes::Field; use datafusion_common::internal_err; use datafusion_common::{downcast_value, not_impl_err}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index fa04e1aca2c9..cb59042ef468 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -44,7 +44,7 @@ use arrow::{ buffer::BooleanBuffer, }; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, + downcast_value, internal_err, not_impl_err, Result, ScalarValue, }; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 8aa7a40ce320..53e3e0cc56cd 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -27,9 +27,7 @@ use arrow::{ use std::mem::{size_of, size_of_val}; use std::{fmt::Debug, sync::Arc}; -use datafusion_common::{ - downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index ad30c0b540af..886709779917 100644 --- 
a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -30,8 +30,8 @@ use datafusion_common::utils::take_function_args; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -50,7 +50,10 @@ impl Cardinality { Self { signature: Signature::one_of( vec![ - TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), ], Volatility::Immutable, diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index 14d4b958867f..17fb1a3731df 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -26,6 +26,7 @@ use arrow::array::{ }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::ListCoercion; use datafusion_common::Result; use datafusion_common::{ cast::as_generic_list_array, @@ -33,7 +34,8 @@ use datafusion_common::{ utils::{list_ndims, take_function_args}, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; @@ -165,7 +167,18 @@ impl Default for ArrayPrepend { impl ArrayPrepend { pub fn new() -> Self { Self { - signature: Signature::element_and_array(Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Array, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![ String::from("list_prepend"), String::from("array_push_front"), @@ -455,8 +468,8 @@ where }; let res = match list_array.value_type() { - DataType::List(_) => concat_internal::(args)?, - DataType::LargeList(_) => concat_internal::(args)?, + DataType::List(_) => concat_internal::(args)?, + DataType::LargeList(_) => concat_internal::(args)?, data_type => { return generic_append_and_prepend::( list_array, diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 697c868fdea1..6bf4d16db636 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -30,17 +30,19 @@ use arrow::datatypes::{ use datafusion_common::cast::as_int64_array; use datafusion_common::cast::as_large_list_array; use datafusion_common::cast::as_list_array; +use datafusion_common::utils::ListCoercion; use datafusion_common::{ exec_err, internal_datafusion_err, plan_err, utils::take_function_args, DataFusionError, Result, }; -use datafusion_expr::{ArrayFunctionSignature, Expr, TypeSignature}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, Expr, TypeSignature, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use 
datafusion_macros::user_doc; use std::any::Any; -use std::num::NonZeroUsize; use std::sync::Arc; use crate::utils::make_scalar_function; @@ -330,16 +332,23 @@ impl ArraySlice { Self { signature: Signature::one_of( vec![ - TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndexes( - NonZeroUsize::new(2).expect("2 is non-zero"), - ), - ), - TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndexes( - NonZeroUsize::new(3).expect("3 is non-zero"), - ), - ), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), ], Volatility::Immutable, ), @@ -665,7 +674,15 @@ pub(super) struct ArrayPopFront { impl ArrayPopFront { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_pop_front")], } } @@ -765,7 +782,15 @@ pub(super) struct ArrayPopBack { impl ArrayPopBack { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_pop_back")], } } diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 53f43de4108d..6d84e64cba4d 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -25,9 +25,11 @@ use arrow::datatypes::{DataType, Field}; use arrow::buffer::OffsetBuffer; use datafusion_common::cast::as_int64_array; +use datafusion_common::utils::ListCoercion; use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; @@ -91,7 +93,19 @@ impl Default for ArrayReplace { impl ArrayReplace { pub fn new() -> Self { Self { - signature: Signature::any(3, Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Element, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_replace")], } } @@ -160,7 +174,20 @@ pub(super) struct ArrayReplaceN { impl ArrayReplaceN { pub fn new() -> Self { Self { - signature: Signature::any(4, Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + 
ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_replace_n")], } } @@ -228,7 +255,19 @@ pub(super) struct ArrayReplaceAll { impl ArrayReplaceAll { pub fn new() -> Self { Self { - signature: Signature::any(3, Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Element, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_replace_all")], } } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index c4186c39317c..c77e58f0c022 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -108,6 +108,16 @@ harness = false name = "encoding" required-features = ["encoding_expressions"] +[[bench]] +harness = false +name = "uuid" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "to_hex" +required-features = ["string_expressions"] + [[bench]] harness = false name = "regx" diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs new file mode 100644 index 000000000000..ce3767cc4839 --- /dev/null +++ b/datafusion/functions/benches/to_hex.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::{ + datatypes::{Int32Type, Int64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::string; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let hex = string::to_hex(); + let size = 1024; + let i32_array = Arc::new(create_primitive_array::(size, 0.2)); + let batch_len = i32_array.len(); + let i32_args = vec![ColumnarValue::Array(i32_array)]; + c.bench_function(&format!("to_hex i32 array: {}", size), |b| { + b.iter(|| black_box(hex.invoke_batch(&i32_args, batch_len).unwrap())) + }); + let i64_array = Arc::new(create_primitive_array::(size, 0.2)); + let batch_len = i64_array.len(); + let i64_args = vec![ColumnarValue::Array(i64_array)]; + c.bench_function(&format!("to_hex i64 array: {}", size), |b| { + b.iter(|| black_box(hex.invoke_batch(&i64_args, batch_len).unwrap())) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs new file mode 100644 index 000000000000..95cf77de3190 --- /dev/null +++ b/datafusion/functions/benches/uuid.rs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_functions::string; + +fn criterion_benchmark(c: &mut Criterion) { + let uuid = string::uuid(); + c.bench_function("uuid", |b| { + b.iter(|| black_box(uuid.invoke_batch(&[], 1024))) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 76fb4bbe5b47..425ce78decbe 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -34,6 +34,7 @@ pub mod nvl; pub mod nvl2; pub mod planner; pub mod r#struct; +pub mod union_extract; pub mod version; // create UDFs @@ -48,6 +49,7 @@ make_udf_function!(getfield::GetFieldFunc, get_field); make_udf_function!(coalesce::CoalesceFunc, coalesce); make_udf_function!(greatest::GreatestFunc, greatest); make_udf_function!(least::LeastFunc, least); +make_udf_function!(union_extract::UnionExtractFun, union_extract); make_udf_function!(version::VersionFunc, version); pub mod expr_fn { @@ -99,6 +101,11 @@ pub mod expr_fn { pub fn get_field(arg1: Expr, arg2: impl Literal) -> Expr { super::get_field().call(vec![arg1, arg2.lit()]) } + + #[doc = "Returns the value of the field with the given name from the union when it's selected, or NULL otherwise"] + pub fn union_extract(arg1: Expr, arg2: impl Literal) -> Expr { + super::union_extract().call(vec![arg1, arg2.lit()]) + } } /// Returns all DataFusion functions defined in this package @@ -121,6 +128,7 @@ pub fn functions() -> Vec> { coalesce(), greatest(), least(), + union_extract(), version(), r#struct(), ] diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs new file mode 100644 index 000000000000..d54627f73598 --- /dev/null +++ b/datafusion/functions/src/core/union_extract.rs @@ -0,0 +1,255 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::Array; +use arrow::datatypes::{DataType, FieldRef, UnionFields}; +use datafusion_common::cast::as_union_array; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, +}; +use datafusion_doc::Documentation; +use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Union Functions"), + description = "Returns the value of the given field in the union when selected, or NULL otherwise.", + syntax_example = "union_extract(union, field_name)", + sql_example = r#"```sql +❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union; ++--------------+----------------------------------+----------------------------------+ +| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') | ++--------------+----------------------------------+----------------------------------+ +| {a=1} | 1 | | +| {b=3.0} | | 3.0 | +| {a=4} | 4 | | +| {b=} | | | +| {a=} | | | ++--------------+----------------------------------+----------------------------------+ +```"#, + standard_argument(name = "union", prefix = "Union"), + argument( + name = "field_name", + description = "String expression to operate on. Must be a constant." + ) +)] +#[derive(Debug)] +pub struct UnionExtractFun { + signature: Signature, +} + +impl Default for UnionExtractFun { + fn default() -> Self { + Self::new() + } +} + +impl UnionExtractFun { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for UnionExtractFun { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "union_extract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + // should be using return_type_from_exprs and not calling the default implementation + internal_err!("union_extract should return type from exprs") + } + + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + if args.arg_types.len() != 2 { + return exec_err!( + "union_extract expects 2 arguments, got {} instead", + args.arg_types.len() + ); + } + + let DataType::Union(fields, _) = &args.arg_types[0] else { + return exec_err!( + "union_extract first argument must be a union, got {} instead", + args.arg_types[0] + ); + }; + + let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else { + return exec_err!( + "union_extract second argument must be a non-null string literal, got {} instead", + args.arg_types[1] + ); + }; + + let field = find_field(fields, field_name)?.1; + + Ok(ReturnInfo::new_nullable(field.data_type().clone())) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = args.args; + + if args.len() != 2 { + return exec_err!( + "union_extract expects 2 arguments, got {} instead", + args.len() + ); + } + + let target_name = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"), + _ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()), + }; + + match &args[0] { + ColumnarValue::Array(array) => { + let union_array = 
as_union_array(&array).map_err(|_| { + exec_datafusion_err!( + "union_extract first argument must be a union, got {} instead", + array.data_type() + ) + })?; + + Ok(ColumnarValue::Array( + arrow::compute::kernels::union_extract::union_extract( + union_array, + target_name?, + )?, + )) + } + ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => { + let target_name = target_name?; + let (target_type_id, target) = find_field(fields, target_name)?; + + let result = match value { + Some((type_id, value)) if target_type_id == *type_id => { + *value.clone() + } + _ => ScalarValue::try_from(target.data_type())?, + }; + + Ok(ColumnarValue::Scalar(result)) + } + other => exec_err!( + "union_extract first argument must be a union, got {} instead", + other.data_type() + ), + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> { + fields + .iter() + .find(|field| field.1.name() == name) + .ok_or_else(|| exec_datafusion_err!("field {name} not found on union")) +} + +#[cfg(test)] +mod tests { + + use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + + use super::UnionExtractFun; + + // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests + #[test] + fn test_scalar_value() -> Result<()> { + let fun = UnionExtractFun::new(); + + let fields = UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ); + + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; + + assert_scalar(result, ScalarValue::Utf8(None)); + + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; + + assert_scalar(result, ScalarValue::Utf8(None)); + + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; + + assert_scalar(result, ScalarValue::new_utf8("42")); + + Ok(()) + } + + fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { + match value { + ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), + } + } +} diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 48eff4fcd423..d2849c3abba0 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -164,7 +164,8 @@ macro_rules! 
make_math_unary_udf { use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, + Signature, Volatility, }; #[derive(Debug)] @@ -218,12 +219,11 @@ macro_rules! make_math_unary_udf { $EVALUATE_BOUNDS(inputs) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: ScalarFunctionArgs, ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new( args[0] @@ -278,7 +278,8 @@ macro_rules! make_math_binary_udf { use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, + Signature, Volatility, }; #[derive(Debug)] @@ -330,12 +331,11 @@ macro_rules! make_math_binary_udf { $OUTPUT_ORDERING(input) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: ScalarFunctionArgs, ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => { let y = args[0].as_primitive::(); diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index ff6a82113262..0c686a59016a 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -32,7 +32,8 @@ use datafusion_common::{ use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -168,12 +169,8 @@ impl ScalarUDFImpl for AbsFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let [input] = take_function_args(self.name(), args)?; let input_data_type = input.data_type(); diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index 8b4f9317fe5f..4e56212ddbee 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use crate::utils::make_scalar_function; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -87,12 +87,8 @@ impl ScalarUDFImpl for CotFunc { self.doc() } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(cot, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(cot, vec![])(&args.args) } } diff --git 
a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index 18f10863a01b..c2ac21b78f21 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -30,7 +30,8 @@ use datafusion_common::{ arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -76,12 +77,8 @@ impl ScalarUDFImpl for FactorialFunc { Ok(Int64) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(factorial, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(factorial, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 14503701f661..911e00308ab7 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -29,7 +29,8 @@ use datafusion_common::{ arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -77,12 +78,8 @@ impl ScalarUDFImpl for GcdFunc { Ok(Int64) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(gcd, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(gcd, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index 8e72ee285518..bc12dfb7898e 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -25,7 +25,8 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -77,12 +78,8 @@ impl ScalarUDFImpl for IsZeroFunc { Ok(Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(iszero, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(iszero, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index c2c72c89841d..fc6bf9461f28 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -27,7 +27,8 @@ use datafusion_common::{ arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -78,12 +79,8 @@ impl ScalarUDFImpl for LcmFunc { Ok(Int64) } - fn invoke_batch( - &self, - args: 
&[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(lcm, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(lcm, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 88a624806874..fd135f4c5ec0 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -31,7 +31,8 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - lit, ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature::*, + lit, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + TypeSignature::*, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -117,12 +118,8 @@ impl ScalarUDFImpl for LogFunc { } // Support overloaded log(base, x) and log(x) which defaults to log(10, x) - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let mut base = ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))); @@ -267,34 +264,44 @@ mod tests { #[test] #[should_panic] fn test_log_invalid_base_type() { - let args = [ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 10.0, 100.0, 1000.0, 10000.0, - ]))), // num - ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let _ = LogFunc::new().invoke_batch(&args, 4); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), + ], + number_rows: 4, + return_type: &DataType::Float64, + }; + let _ = LogFunc::new().invoke_with_args(args); } #[test] fn test_log_invalid_value() { - let args = [ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let result = LogFunc::new().invoke_batch(&args, 1); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num + ], + number_rows: 1, + return_type: &DataType::Float64, + }; + + let result = LogFunc::new().invoke_with_args(args); result.expect_err("expected error"); } #[test] fn test_log_scalar_f32_unary() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num + ], + number_rows: 1, + return_type: &DataType::Float32, + }; let result = LogFunc::new() - .invoke_batch(&args, 1) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -313,12 +320,15 @@ mod tests { #[test] fn test_log_scalar_f64_unary() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num + ], + 
number_rows: 1, + return_type: &DataType::Float64, + }; let result = LogFunc::new() - .invoke_batch(&args, 1) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -337,13 +347,16 @@ mod tests { #[test] fn test_log_scalar_f32() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num - ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num + ], + number_rows: 1, + return_type: &DataType::Float32, + }; let result = LogFunc::new() - .invoke_batch(&args, 1) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -362,13 +375,16 @@ mod tests { #[test] fn test_log_scalar_f64() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num - ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num + ], + number_rows: 1, + return_type: &DataType::Float64, + }; let result = LogFunc::new() - .invoke_batch(&args, 1) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -387,14 +403,17 @@ mod tests { #[test] fn test_log_f64_unary() { - let args = [ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 10.0, 100.0, 1000.0, 10000.0, - ]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ], + number_rows: 4, + return_type: &DataType::Float64, + }; let result = LogFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -416,14 +435,17 @@ mod tests { #[test] fn test_log_f32_unary() { - let args = [ - ColumnarValue::Array(Arc::new(Float32Array::from(vec![ - 10.0, 100.0, 1000.0, 10000.0, - ]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ], + number_rows: 4, + return_type: &DataType::Float32, + }; let result = LogFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -445,15 +467,20 @@ mod tests { #[test] fn test_log_f64() { - let args = [ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 8.0, 4.0, 81.0, 625.0, - ]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 2.0, 2.0, 3.0, 5.0, + ]))), // base + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 8.0, 4.0, 81.0, 625.0, + ]))), // num + ], + number_rows: 4, + return_type: &DataType::Float64, + }; let result = LogFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -475,15 +502,20 @@ mod tests { #[test] fn 
test_log_f32() { - let args = [ - ColumnarValue::Array(Arc::new(Float32Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base - ColumnarValue::Array(Arc::new(Float32Array::from(vec![ - 8.0, 4.0, 81.0, 625.0, - ]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 2.0, 2.0, 3.0, 5.0, + ]))), // base + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 8.0, 4.0, 81.0, 625.0, + ]))), // num + ], + number_rows: 4, + return_type: &DataType::Float32, + }; let result = LogFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 30c920c29a21..34a5c2a1c16b 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -19,7 +19,7 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, TypeSignature}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, TypeSignature}; use arrow::array::{ArrayRef, AsArray, BooleanArray}; use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; @@ -75,12 +75,8 @@ impl ScalarUDFImpl for IsNanFunc { Ok(DataType::Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new(BooleanArray::from_unary( diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index 33823acce751..9effb82896ee 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -26,7 +26,8 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -87,12 +88,8 @@ impl ScalarUDFImpl for NanvlFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(nanvl, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(nanvl, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index 06f7a01544f8..5339a9b14a28 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -22,7 +22,8 @@ use arrow::datatypes::DataType::Float64; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -67,12 +68,8 @@ impl ScalarUDFImpl for PiFunc { Ok(Float64) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - if 
!args.is_empty() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if !args.args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 7fab858d34a0..028ec2fef793 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -29,7 +29,9 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, TypeSignature, +}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -91,12 +93,8 @@ impl ScalarUDFImpl for PowerFunc { &self.aliases } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => { @@ -195,13 +193,20 @@ mod tests { #[test] fn test_power_f64() { - let args = [ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base - ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 2.0, 2.0, 3.0, 5.0, + ]))), // base + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 3.0, 2.0, 4.0, 4.0, + ]))), // exponent + ], + number_rows: 4, + return_type: &DataType::Float64, + }; let result = PowerFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function power"); match result { @@ -222,13 +227,16 @@ mod tests { #[test] fn test_power_i64() { - let args = [ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base + ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent + ], + number_rows: 4, + return_type: &DataType::Int64, + }; let result = PowerFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function power"); match result { diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 197d065ea408..607f9fb09f2a 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -24,7 +24,7 @@ use arrow::datatypes::DataType::Float64; use rand::{thread_rng, Rng}; use datafusion_common::{internal_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -70,16 +70,12 @@ impl ScalarUDFImpl for RandomFunc { Ok(Float64) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - num_rows: usize, - ) -> Result { - if 
!args.is_empty() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if !args.args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } let mut rng = thread_rng(); - let mut values = vec![0.0; num_rows]; + let mut values = vec![0.0; args.number_rows]; // Equivalent to set each element with rng.gen_range(0.0..1.0), but more efficient rng.fill(&mut values[..]); let array = Float64Array::from(values); diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index b3442c321c99..fc87b7e63a62 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -28,7 +28,8 @@ use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -90,12 +91,8 @@ impl ScalarUDFImpl for RoundFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(round, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(round, vec![])(&args.args) } fn output_ordering(&self, input: &[ExprProperties]) -> Result { diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index f68834db375e..ba5422afa768 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -25,7 +25,8 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -88,12 +89,8 @@ impl ScalarUDFImpl for SignumFunc { Ok(input[0].sort_properties) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(signum, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(signum, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { @@ -140,10 +137,10 @@ pub fn signum(args: &[ArrayRef]) -> Result { mod test { use std::sync::Arc; - use arrow::array::{Float32Array, Float64Array}; - + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use arrow::datatypes::DataType; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use crate::math::signum::SignumFunc; @@ -160,10 +157,13 @@ mod test { f32::INFINITY, f32::NEG_INFINITY, ])); - let batch_size = array.len(); - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + number_rows: array.len(), + return_type: &DataType::Float32, + }; let result = SignumFunc::new() - .invoke_batch(&[ColumnarValue::Array(array)], batch_size) + .invoke_with_args(args) .expect("failed to initialize function signum"); match result { @@ -201,10 +201,13 @@ mod test { 
f64::INFINITY, f64::NEG_INFINITY, ])); - let batch_size = array.len(); - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + number_rows: array.len(), + return_type: &DataType::Float64, + }; let result = SignumFunc::new() - .invoke_batch(&[ColumnarValue::Array(array)], batch_size) + .invoke_with_args(args) .expect("failed to initialize function signum"); match result { diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 8d791370d7f8..2ac291204a0b 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -28,7 +28,8 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -99,12 +100,8 @@ impl ScalarUDFImpl for TruncFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(trunc, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(trunc, vec![])(&args.args) } fn output_ordering(&self, input: &[ExprProperties]) -> Result { diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 64654ef6ef10..5c7c92cc34ed 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -16,9 +16,10 @@ // under the License. use std::any::Any; +use std::fmt::Write; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, GenericStringBuilder, OffsetSizeTrait}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, }; @@ -40,22 +41,30 @@ where { let integer_array = as_primitive_array::(&args[0])?; - let result = integer_array - .iter() - .map(|integer| { - if let Some(value) = integer { - if let Some(value_usize) = value.to_usize() { - Ok(Some(format!("{value_usize:x}"))) - } else if let Some(value_isize) = value.to_isize() { - Ok(Some(format!("{value_isize:x}"))) - } else { - exec_err!("Unsupported data type {integer:?} for function to_hex") - } + let mut result = GenericStringBuilder::::with_capacity( + integer_array.len(), + // * 8 to convert to bits, / 4 bits per hex char + integer_array.len() * (T::Native::get_byte_width() * 8 / 4), + ); + + for integer in integer_array { + if let Some(value) = integer { + if let Some(value_usize) = value.to_usize() { + write!(result, "{value_usize:x}")?; + } else if let Some(value_isize) = value.to_isize() { + write!(result, "{value_isize:x}")?; } else { - Ok(None) + return exec_err!( + "Unsupported data type {integer:?} for function to_hex" + ); } - }) - .collect::>>()?; + result.append_value(""); + } else { + result.append_null(); + } + } + + let result = result.finish(); Ok(Arc::new(result) as ArrayRef) } diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index f6d6a941068d..64065c26b7d4 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -18,9 +18,10 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::GenericStringArray; +use 
arrow::array::GenericStringBuilder; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Utf8; +use rand::Rng; use uuid::Uuid; use datafusion_common::{internal_err, Result}; @@ -87,9 +88,25 @@ impl ScalarUDFImpl for UuidFunc { if !args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } - let values = std::iter::repeat_with(|| Uuid::new_v4().to_string()).take(num_rows); - let array = GenericStringArray::::from_iter_values(values); - Ok(ColumnarValue::Array(Arc::new(array))) + + // Generate random u128 values + let mut rng = rand::thread_rng(); + let mut randoms = vec![0u128; num_rows]; + rng.fill(&mut randoms[..]); + + let mut builder = + GenericStringBuilder::::with_capacity(num_rows, num_rows * 36); + + let mut buffer = [0u8; 36]; + for x in &mut randoms { + // From Uuid::new_v4(): Mask out the version and variant bits + *x = *x & 0xFFFFFFFFFFFF4FFFBFFFFFFFFFFFFFFF | 0x40008000000000000000; + let uuid = Uuid::from_u128(*x); + let fmt = uuid::fmt::Hyphenated::from_uuid(uuid); + builder.append_value(fmt.encode_lower(&mut buffer)); + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 85fc9b31bcdd..f7dc4befb189 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1047,8 +1047,8 @@ mod test { use datafusion_expr::{ cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan, - Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, - Volatility, + Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + SimpleAggregateUDF, Subquery, Volatility, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -1266,11 +1266,7 @@ mod test { Ok(Utf8) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 4b9a83fd3e4c..bfa53a5ce852 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -1703,13 +1703,5 @@ mod test { fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!() - } } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7cb0e7c2f1f7..1dda1c4c0ea1 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1386,8 +1386,9 @@ mod tests { use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension, - LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType, - UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, + LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + TableSource, TableType, UserDefinedLogicalNodeCore, Volatility, + WindowFunctionDefinition, }; use crate::optimizer::Optimizer; @@ -3615,11 +3616,7 @@ 
Projection: a, b Ok(DataType::Int32) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { Ok(ColumnarValue::Scalar(ScalarValue::from(1))) } } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 71c925378218..7e4c7f0e10ba 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -262,7 +262,9 @@ pub(crate) mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{exec_err, DataFusionError, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + }; use petgraph::visit::Bfs; @@ -309,12 +311,8 @@ pub(crate) mod tests { Ok(input[0].sort_properties) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new({ diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 6bfa02adf6dc..420c080f09c2 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -72,8 +72,8 @@ use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties, InputOrde use itertools::izip; -/// This rule inspects [`SortExec`]'s in the given physical plan and removes the -/// ones it can prove unnecessary. +/// This rule inspects [`SortExec`]'s in the given physical plan in order to +/// remove unnecessary sorts, and optimize sort performance across the plan. #[derive(Default, Debug)] pub struct EnforceSorting {} @@ -84,33 +84,43 @@ impl EnforceSorting { } } -/// This object is used within the [`EnforceSorting`] rule to track the closest +/// This context object is used within the [`EnforceSorting`] rule to track the closest /// [`SortExec`] descendant(s) for every child of a plan. The data attribute /// stores whether the plan is a `SortExec` or is connected to a `SortExec` /// via its children. pub type PlanWithCorrespondingSort = PlanContext; -fn update_sort_ctx_children( - mut node: PlanWithCorrespondingSort, +/// For a given node, update the [`PlanContext.data`] attribute. +/// +/// If the node is a `SortExec`, or any of the node's children are a `SortExec`, +/// then set the attribute to true. +/// +/// This requires a bottom-up traversal was previously performed, updating the +/// children previously. +fn update_sort_ctx_children_data( + mut node_and_ctx: PlanWithCorrespondingSort, data: bool, ) -> Result { - for child_node in node.children.iter_mut() { - let plan = &child_node.plan; - child_node.data = if is_sort(plan) { - // Initiate connection: + // Update `child.data` for all children. + for child_node in node_and_ctx.children.iter_mut() { + let child_plan = &child_node.plan; + child_node.data = if is_sort(child_plan) { + // child is sort true - } else if is_limit(plan) { + } else if is_limit(child_plan) { // There is no sort linkage for this path, it starts at a limit. 
false } else { - let is_spm = is_sort_preserving_merge(plan); - let required_orderings = plan.required_input_ordering(); - let flags = plan.maintains_input_order(); + // If a descendent is a sort, and the child maintains the sort. + let is_spm = is_sort_preserving_merge(child_plan); + let required_orderings = child_plan.required_input_ordering(); + let flags = child_plan.maintains_input_order(); // Add parent node to the tree if there is at least one child with // a sort connection: izip!(flags, required_orderings).any(|(maintains, required_ordering)| { let propagates_ordering = (maintains && required_ordering.is_none()) || is_spm; + // `connected_to_sort` only returns the correct answer with bottom-up traversal let connected_to_sort = child_node.children.iter().any(|child| child.data); propagates_ordering && connected_to_sort @@ -118,8 +128,10 @@ fn update_sort_ctx_children( } } - node.data = data; - node.update_plan_from_children() + // set data attribute on current node + node_and_ctx.data = data; + + Ok(node_and_ctx) } /// This object is used within the [`EnforceSorting`] rule to track the closest @@ -151,10 +163,15 @@ fn update_coalesce_ctx_children( }; } -/// The boolean flag `repartition_sorts` defined in the config indicates -/// whether we elect to transform [`CoalescePartitionsExec`] + [`SortExec`] cascades -/// into [`SortExec`] + [`SortPreservingMergeExec`] cascades, which enables us to -/// perform sorting in parallel. +/// Performs optimizations based upon a series of subrules. +/// +/// Refer to each subrule for detailed descriptions of the optimizations performed: +/// [`ensure_sorting`], [`parallelize_sorts`], [`replace_with_order_preserving_variants()`], +/// and [`pushdown_sorts`]. +/// +/// Subrule application is ordering dependent. +/// +/// The subrule `parallelize_sorts` is only applied if `repartition_sorts` is enabled. impl PhysicalOptimizerRule for EnforceSorting { fn optimize( &self, @@ -243,20 +260,66 @@ fn replace_with_partial_sort( Ok(plan) } -/// This function turns plans of the form +/// Transform [`CoalescePartitionsExec`] + [`SortExec`] into +/// [`SortExec`] + [`SortPreservingMergeExec`] as illustrated below: +/// +/// The [`CoalescePartitionsExec`] + [`SortExec`] cascades +/// combine the partitions first, and then sort: +/// ```text +/// ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ +/// ││B│A│D│... ├──┐ +/// └─┴─┴─┘ │ +/// └ ─ ─ ─ ─ ─ ┘ │ ┌────────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ +/// Partition 1 │ │ Coalesce │ ┌─┬─┬─┬─┬─┐ │ │ ┌─┬─┬─┬─┬─┐ +/// ├──▶(no ordering guarantees)│──▶││B│E│A│D│C│...───▶ Sort ├───▶││A│B│C│D│E│... │ +/// │ │ │ └─┴─┴─┴─┴─┘ │ │ └─┴─┴─┴─┴─┘ +/// ┌ ─ ─ ─ ─ ─ ┐ │ └────────────────────────┘ └ ─ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ +/// ┌─┬─┐ │ Partition Partition +/// ││E│C│ ... ├──┘ +/// └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 +/// ``` +/// +/// +/// The [`SortExec`] + [`SortPreservingMergeExec`] cascades +/// sorts each partition first, then merge partitions while retaining the sort: +/// ```text +/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ │ │ ┌─┬─┬─┐ +/// ││B│A│D│... │──▶│ Sort │──▶││A│B│D│... │──┐ +/// └─┴─┴─┘ │ │ └─┴─┴─┘ │ +/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ │ ┌─────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ +/// Partition 1 Partition 1 │ │ │ ┌─┬─┬─┬─┬─┐ +/// ├──▶ SortPreservingMerge ├───▶││A│B│C│D│E│... │ +/// │ │ │ └─┴─┴─┴─┴─┘ +/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ │ └─────────────────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ +/// ┌─┬─┐ │ │ ┌─┬─┐ │ Partition +/// ││E│C│ ... │──▶│ Sort ├──▶││C│E│ ... 
│──┘ +/// └─┴─┘ │ │ └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 Partition 2 +/// ``` +/// +/// The latter [`SortExec`] + [`SortPreservingMergeExec`] cascade performs the +/// sort first on a per-partition basis, thereby parallelizing the sort. +/// +/// +/// The outcome is that plans of the form /// ```text /// "SortExec: expr=\[a@0 ASC\]", -/// " CoalescePartitionsExec", -/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", +/// " ...nodes..." +/// " CoalescePartitionsExec", +/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", /// ``` -/// to +/// are transformed into /// ```text /// "SortPreservingMergeExec: \[a@0 ASC\]", -/// " SortExec: expr=\[a@0 ASC\]", -/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", +/// " ...nodes..." +/// " SortExec: expr=\[a@0 ASC\]", +/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", /// ``` -/// by following connections from [`CoalescePartitionsExec`]s to [`SortExec`]s. -/// By performing sorting in parallel, we can increase performance in some scenarios. pub fn parallelize_sorts( mut requirements: PlanWithCorrespondingCoalescePartitions, ) -> Result> { @@ -318,11 +381,14 @@ pub fn parallelize_sorts( } /// This function enforces sorting requirements and makes optimizations without -/// violating these requirements whenever possible. +/// violating these requirements whenever possible. Requires a bottom-up traversal. pub fn ensure_sorting( mut requirements: PlanWithCorrespondingSort, ) -> Result> { - requirements = update_sort_ctx_children(requirements, false)?; + // Before starting, making requirements' children's ExecutionPlan be same as the requirements' plan's children's ExecutionPlan. + // It should be guaranteed by previous code, but we need to make sure to avoid any potential missing. + requirements = requirements.update_plan_from_children()?; + requirements = update_sort_ctx_children_data(requirements, false)?; // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.children.is_empty() { @@ -353,7 +419,8 @@ pub fn ensure_sorting( child = update_child_to_remove_unnecessary_sort(idx, child, plan)?; } child = add_sort_above(child, required, None); - child = update_sort_ctx_children(child, true)?; + child = child.update_plan_from_children()?; + child = update_sort_ctx_children_data(child, true)?; } } else if physical_ordering.is_none() || !plan.maintains_input_order()[idx] @@ -383,9 +450,10 @@ pub fn ensure_sorting( Arc::new(LocalLimitExec::new(Arc::clone(&child_node.plan), fetch)); } return Ok(Transformed::yes(child_node)); + } else { + requirements = requirements.update_plan_from_children()?; } - - update_sort_ctx_children(requirements, false).map(Transformed::yes) + update_sort_ctx_children_data(requirements, false).map(Transformed::yes) } /// Analyzes a given [`SortExec`] (`plan`) to determine whether its input @@ -609,8 +677,9 @@ fn remove_corresponding_sort_from_sub_plan( } }) .collect::>()?; + node = node.update_plan_from_children()?; if any_connection || node.children.is_empty() { - node = update_sort_ctx_children(node, false)?; + node = update_sort_ctx_children_data(node, false)?; } // Replace with variants that do not preserve order. 
@@ -643,7 +712,8 @@ fn remove_corresponding_sort_from_sub_plan( Arc::new(CoalescePartitionsExec::new(plan)) as _ }; node = PlanWithCorrespondingSort::new(plan, false, vec![node]); - node = update_sort_ctx_children(node, false)?; + node = node.update_plan_from_children()?; + node = update_sort_ctx_children_data(node, false)?; } Ok(node) } diff --git a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs index c542f9261a24..2c5c0d4d510e 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs @@ -45,7 +45,7 @@ use itertools::izip; pub type OrderPreservationContext = PlanContext; /// Updates order-preservation data for all children of the given node. -pub fn update_children(opc: &mut OrderPreservationContext) { +pub fn update_order_preservation_ctx_children_data(opc: &mut OrderPreservationContext) { for PlanContext { plan, children, @@ -244,7 +244,7 @@ pub fn replace_with_order_preserving_variants( is_spm_better: bool, config: &ConfigOptions, ) -> Result> { - update_children(&mut requirements); + update_order_preservation_ctx_children_data(&mut requirements); if !(is_sort(&requirements.plan) && requirements.children[0].data) { return Ok(Transformed::no(requirements)); } diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs index 8bf0ffbd3c32..2004aeafb893 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -1590,6 +1590,7 @@ fn build_statistics_expr( )), )) } + Operator::NotLikeMatch => build_not_like_match(expr_builder)?, Operator::LikeMatch => build_like_match(expr_builder).ok_or_else(|| { plan_datafusion_err!( "LIKE expression with wildcard at the beginning is not supported" @@ -1638,6 +1639,19 @@ fn build_statistics_expr( Ok(statistics_expr) } +/// returns the string literal of the scalar value if it is a string +fn unpack_string(s: &ScalarValue) -> Option<&str> { + s.try_as_str().flatten() +} + +fn extract_string_literal(expr: &Arc) -> Option<&str> { + if let Some(lit) = expr.as_any().downcast_ref::() { + let s = unpack_string(lit.value())?; + return Some(s); + } + None +} + /// Convert `column LIKE literal` where P is a constant prefix of the literal /// to a range check on the column: `P <= column && column < P'`, where P' is the /// lowest string after all P* strings. @@ -1650,19 +1664,6 @@ fn build_like_match( // column LIKE '%foo%' => min <= '' && '' <= max => true // column LIKE 'foo' => min <= 'foo' && 'foo' <= max - /// returns the string literal of the scalar value if it is a string - fn unpack_string(s: &ScalarValue) -> Option<&str> { - s.try_as_str().flatten() - } - - fn extract_string_literal(expr: &Arc) -> Option<&str> { - if let Some(lit) = expr.as_any().downcast_ref::() { - let s = unpack_string(lit.value())?; - return Some(s); - } - None - } - // TODO Handle ILIKE perhaps by making the min lowercase and max uppercase // this may involve building the physical expressions that call lower() and upper() let min_column_expr = expr_builder.min_column_expr().ok()?; @@ -1710,6 +1711,80 @@ fn build_like_match( Some(combined) } +// For predicate `col NOT LIKE 'const_prefix%'`, we rewrite it as `(col_min NOT LIKE 'const_prefix%' OR col_max NOT LIKE 'const_prefix%')`. 
+// +// The intuition is that if both `col_min` and `col_max` begin with `const_prefix`, then +// **all** data in this row group begins with `const_prefix` as well (and therefore the predicate +// looking for rows that don't begin with `const_prefix` can never be true). +fn build_not_like_match( + expr_builder: &mut PruningExpressionBuilder<'_>, +) -> Result<Arc<dyn PhysicalExpr>> { + // col NOT LIKE 'const_prefix%' -> !(col_min LIKE 'const_prefix%' && col_max LIKE 'const_prefix%') -> (col_min NOT LIKE 'const_prefix%' || col_max NOT LIKE 'const_prefix%') + + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + + let scalar_expr = expr_builder.scalar_expr(); + + let pattern = extract_string_literal(scalar_expr).ok_or_else(|| { + plan_datafusion_err!("cannot extract literal from NOT LIKE expression") + })?; + + let (const_prefix, remaining) = split_constant_prefix(pattern); + if const_prefix.is_empty() || remaining != "%" { + // we cannot handle `%` at the beginning or in the middle of the pattern + // Example: For pattern "foo%bar", the row group might include values like + // ["foobar", "food", "foodbar"], making it unsafe to prune. + // Even if the min/max values in the group (e.g., "foobar" and "foodbar") + // match the pattern, intermediate values like "food" may not + // match the full pattern "foo%bar", making pruning unsafe. + // (truncating foo%bar to foo% has the same problem) + + // we cannot handle patterns containing `_` + // Example: For pattern "foo_", row groups might contain ["fooa", "fooaa", "foob"], + // which means not every row is guaranteed to match the pattern. + return Err(plan_datafusion_err!( + "NOT LIKE expressions only support a constant prefix followed by the `%` wildcard" + )); + } + + let min_col_not_like_expr = Arc::new(phys_expr::LikeExpr::new( + true, + false, + Arc::clone(&min_column_expr), + Arc::clone(scalar_expr), + )); + + let max_col_not_like_expr = Arc::new(phys_expr::LikeExpr::new( + true, + false, + Arc::clone(&max_column_expr), + Arc::clone(scalar_expr), + )); + + Ok(Arc::new(phys_expr::BinaryExpr::new( + min_col_not_like_expr, + Operator::Or, + max_col_not_like_expr, + ))) +} + +/// Returns the constant prefix of a LIKE pattern, i.e. everything before the first unescaped wildcard (possibly empty), and the remaining pattern (possibly empty) +fn split_constant_prefix(pattern: &str) -> (&str, &str) { + let char_indices = pattern.char_indices().collect::<Vec<_>>(); + for i in 0..char_indices.len() { + let (idx, char) = char_indices[i]; + if char == '%' || char == '_' { + if i != 0 && char_indices[i - 1].1 == '\\' { + // escaped by `\` + continue; + } + return (&pattern[..idx], &pattern[idx..]); + } + } + (pattern, "") +} + +/// Increment a UTF8 string by one, returning `None` if it can't be incremented. +/// This makes it so that the returned string will always compare greater than the input string +/// or any other string with the same prefix. 
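To make the pruning contract above concrete, here is a small, hedged sketch (not part of the patch) of the outputs `split_constant_prefix` is expected to produce given the implementation above; patterns whose remainder is not exactly a trailing `%`, or whose prefix is empty, are rejected by `build_not_like_match` and therefore left unpruned:

// Illustrative only -- assumes `split_constant_prefix` from the hunk above is in scope.
assert_eq!(split_constant_prefix("foo%"), ("foo", "%"));       // usable: constant prefix + trailing `%`
assert_eq!(split_constant_prefix("foo%bar"), ("foo", "%bar")); // rejected: `%` is not at the end
assert_eq!(split_constant_prefix("foo_"), ("foo", "_"));       // rejected: `_` wildcard
assert_eq!(split_constant_prefix(r"A\%%"), (r"A\%", "%"));     // escaped `%` stays in the constant prefix
assert_eq!(split_constant_prefix("foo"), ("foo", ""));         // rejected: no trailing `%` wildcard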
@@ -4061,6 +4136,132 @@ mod tests { prune_with_expr(expr, &schema, &statistics, expected_ret); } + #[test] + fn prune_utf8_not_like_one() { + let (schema, statistics) = utf8_setup(); + + let expr = col("s1").not_like(lit("A\u{10ffff}_")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> some rows could pass (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate + // orignal (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}") + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + } + + #[test] + fn prune_utf8_not_like_many() { + let (schema, statistics) = utf8_setup(); + + let expr = col("s1").not_like(lit("A\u{10ffff}%")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> some rows could pass (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").not_like(lit("A\u{10ffff}%\u{10ffff}")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> some rows could pass (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").not_like(lit("A\u{10ffff}%\u{10ffff}_")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> some rows could pass (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows 
could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").not_like(lit("A\\%%")); + let statistics = TestStatistics::new().with( + "s1", + ContainerStats::new_utf8( + vec![Some("A%a"), Some("A")], + vec![Some("A%c"), Some("A")], + ), + ); + let expected_ret = &[false, true]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + } + #[test] fn test_rewrite_expr_to_prunable() { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index b84243b1b56b..f0afdaa2de3d 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -78,3 +78,7 @@ tokio = { workspace = true, features = [ [[bench]] harness = false name = "spm" + +[[bench]] +harness = false +name = "partial_ordering" diff --git a/datafusion/physical-plan/benches/partial_ordering.rs b/datafusion/physical-plan/benches/partial_ordering.rs new file mode 100644 index 000000000000..422826abcc8b --- /dev/null +++ b/datafusion/physical-plan/benches/partial_ordering.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array}; +use arrow_schema::{DataType, Field, Schema, SortOptions}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::aggregates::order::GroupOrderingPartial; + +const BATCH_SIZE: usize = 8192; + +fn create_test_arrays(num_columns: usize) -> Vec { + (0..num_columns) + .map(|i| { + Arc::new(Int32Array::from_iter_values( + (0..BATCH_SIZE as i32).map(|x| x * (i + 1) as i32), + )) as ArrayRef + }) + .collect() +} +fn bench_new_groups(c: &mut Criterion) { + let mut group = c.benchmark_group("group_ordering_partial"); + + // Test with 1, 2, 4, and 8 order indices + for num_columns in [1, 2, 4, 8] { + let fields: Vec = (0..num_columns) + .map(|i| Field::new(format!("col{}", i), DataType::Int32, false)) + .collect(); + let schema = Schema::new(fields); + + let order_indices: Vec = (0..num_columns).collect(); + let ordering = LexOrdering::new( + (0..num_columns) + .map(|i| { + PhysicalSortExpr::new( + col(&format!("col{}", i), &schema).unwrap(), + SortOptions::default(), + ) + }) + .collect(), + ); + + group.bench_function(format!("order_indices_{}", num_columns), |b| { + let batch_group_values = create_test_arrays(num_columns); + let group_indices: Vec = (0..BATCH_SIZE).collect(); + + b.iter(|| { + let mut ordering = + GroupOrderingPartial::try_new(&schema, &order_indices, &ordering) + .unwrap(); + ordering + .new_groups(&batch_group_values, &group_indices, BATCH_SIZE) + .unwrap(); + }); + }); + } + group.finish(); +} + +criterion_group!(benches, bench_new_groups); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index f19aefbcf47e..aff69277a4ce 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -16,12 +16,15 @@ // under the License. use arrow::array::ArrayRef; +use arrow::compute::SortOptions; use arrow::datatypes::Schema; -use arrow::row::{OwnedRow, RowConverter, Rows, SortField}; -use datafusion_common::Result; +use arrow_ord::partition::partition; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use datafusion_common::{Result, ScalarValue}; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use std::cmp::Ordering; use std::mem::size_of; use std::sync::Arc; @@ -69,13 +72,9 @@ pub struct GroupOrderingPartial { /// For example if grouping by `id, state` and ordered by `state` /// this would be `[1]`. order_indices: Vec, - - /// Converter for the sort key (used on the group columns - /// specified in `order_indexes`) - row_converter: RowConverter, } -#[derive(Debug, Default)] +#[derive(Debug, Default, PartialEq)] enum State { /// The ordering was temporarily taken. `Self::Taken` is left /// when state must be temporarily taken to satisfy the borrow @@ -93,7 +92,7 @@ enum State { /// Smallest group index with the sort_key current_sort: usize, /// The sort key of group_index `current_sort` - sort_key: OwnedRow, + sort_key: Vec, /// index of the current group for which values are being /// generated current: usize, @@ -103,47 +102,47 @@ enum State { Complete, } +impl State { + fn size(&self) -> usize { + match self { + State::Taken => 0, + State::Start => 0, + State::InProgress { sort_key, .. 
} => sort_key + .iter() + .map(|scalar_value| scalar_value.size()) + .sum(), + State::Complete => 0, + } + } +} + impl GroupOrderingPartial { + /// TODO: Remove unnecessary `input_schema` parameter. pub fn try_new( - input_schema: &Schema, + _input_schema: &Schema, order_indices: &[usize], ordering: &LexOrdering, ) -> Result { assert!(!order_indices.is_empty()); assert!(order_indices.len() <= ordering.len()); - // get only the section of ordering, that consist of group by expressions. - let fields = ordering[0..order_indices.len()] - .iter() - .map(|sort_expr| { - Ok(SortField::new_with_options( - sort_expr.expr.data_type(input_schema)?, - sort_expr.options, - )) - }) - .collect::>>()?; - Ok(Self { state: State::Start, order_indices: order_indices.to_vec(), - row_converter: RowConverter::new(fields)?, }) } - /// Creates sort keys from the group values + /// Select sort keys from the group values /// /// For example, if group_values had `A, B, C` but the input was /// only sorted on `B` and `C` this should return rows for (`B`, /// `C`) - fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Result { + fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Vec { // Take only the columns that are in the sort key - let sort_values: Vec<_> = self - .order_indices + self.order_indices .iter() .map(|&idx| Arc::clone(&group_values[idx])) - .collect(); - - Ok(self.row_converter.convert_columns(&sort_values)?) + .collect() } /// How many groups be emitted, or None if no data can be emitted @@ -194,6 +193,23 @@ impl GroupOrderingPartial { }; } + fn updated_sort_key( + current_sort: usize, + sort_key: Option>, + range_current_sort: usize, + range_sort_key: Vec, + ) -> Result<(usize, Vec)> { + if let Some(sort_key) = sort_key { + let sort_options = vec![SortOptions::new(false, false); sort_key.len()]; + let ordering = compare_rows(&sort_key, &range_sort_key, &sort_options)?; + if ordering == Ordering::Equal { + return Ok((current_sort, sort_key)); + } + } + + Ok((range_current_sort, range_sort_key)) + } + /// Called when new groups are added in a batch. See documentation /// on [`super::GroupOrdering::new_groups`] pub fn new_groups( @@ -207,37 +223,46 @@ impl GroupOrderingPartial { let max_group_index = total_num_groups - 1; - // compute the sort key values for each group - let sort_keys = self.compute_sort_keys(batch_group_values)?; - - let old_state = std::mem::take(&mut self.state); - let (mut current_sort, mut sort_key) = match &old_state { + let (current_sort, sort_key) = match std::mem::take(&mut self.state) { State::Taken => unreachable!("State previously taken"), - State::Start => (0, sort_keys.row(0)), + State::Start => (0, None), State::InProgress { current_sort, sort_key, .. - } => (*current_sort, sort_key.row()), + } => (current_sort, Some(sort_key)), State::Complete => { panic!("Saw new group after the end of input"); } }; - // Find latest sort key - let iter = group_indices.iter().zip(sort_keys.iter()); - for (&group_index, group_sort_key) in iter { - // Does this group have seen a new sort_key? 
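The rewritten `new_groups` that follows replaces the `RowConverter`-based sort keys with `arrow_ord::partition::partition`, which splits the batch's sort-key columns into consecutive ranges of equal rows; only the start of the last range can introduce a new `current_sort` / `sort_key`. A self-contained sketch of that primitive with illustrative values (not part of the patch):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow_ord::partition::partition;

fn main() -> Result<(), arrow::error::ArrowError> {
    // One sort-key column for a batch containing three distinct key values.
    let sort_keys: Vec<ArrayRef> =
        vec![Arc::new(Int32Array::from(vec![1, 1, 2, 2, 3]))];

    // `partition` returns the ranges of equal consecutive rows: [0..2, 2..4, 4..5].
    let ranges = partition(&sort_keys)?.ranges();
    let last = ranges.last().unwrap();
    assert_eq!((last.start, last.end), (4, 5));

    // In `new_groups`, group_indices[last.start] becomes `current_sort` and the
    // row at last.start (via get_row_at_idx) becomes the new `sort_key`; when
    // last.start == 0 the batch has no internal boundary and the previous
    // sort_key is compared with `compare_rows` instead.
    Ok(())
}
```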
- if sort_key != group_sort_key { - current_sort = group_index; - sort_key = group_sort_key; - } - } + // Select the sort key columns + let sort_keys = self.compute_sort_keys(batch_group_values); + + // Check if the sort keys indicate a boundary inside the batch + let ranges = partition(&sort_keys)?.ranges(); + let last_range = ranges.last().unwrap(); + + let range_current_sort = group_indices[last_range.start]; + let range_sort_key = get_row_at_idx(&sort_keys, last_range.start)?; + + let (current_sort, sort_key) = if last_range.start == 0 { + // There was no boundary in the batch. Compare with the previous sort_key (if present) + // to check if there was a boundary between the current batch and the previous one. + Self::updated_sort_key( + current_sort, + sort_key, + range_current_sort, + range_sort_key, + )? + } else { + (range_current_sort, range_sort_key) + }; self.state = State::InProgress { current_sort, - sort_key: sort_key.owned(), current: max_group_index, + sort_key, }; Ok(()) @@ -245,8 +270,104 @@ impl GroupOrderingPartial { /// Return the size of memory allocated by this structure pub(crate) fn size(&self) -> usize { - size_of::() - + self.order_indices.allocated_size() - + self.row_converter.size() + size_of::() + self.order_indices.allocated_size() + self.state.size() + } +} + +#[cfg(test)] +mod tests { + use arrow::array::Int32Array; + use arrow_schema::{DataType, Field}; + use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; + + use super::*; + + #[test] + fn test_group_ordering_partial() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + // Ordered on column a + let order_indices = vec![0]; + + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &schema)?, + SortOptions::default(), + )]); + + let mut group_ordering = + GroupOrderingPartial::try_new(&schema, &order_indices, &ordering)?; + + let batch_group_values: Vec = vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![2, 1, 3])), + ]; + + let group_indices = vec![0, 1, 2]; + let total_num_groups = 3; + + group_ordering.new_groups( + &batch_group_values, + &group_indices, + total_num_groups, + )?; + + assert_eq!( + group_ordering.state, + State::InProgress { + current_sort: 2, + sort_key: vec![ScalarValue::Int32(Some(3))], + current: 2 + } + ); + + // push without a boundary + let batch_group_values: Vec = vec![ + Arc::new(Int32Array::from(vec![3, 3, 3])), + Arc::new(Int32Array::from(vec![2, 1, 7])), + ]; + let group_indices = vec![3, 4, 5]; + let total_num_groups = 6; + + group_ordering.new_groups( + &batch_group_values, + &group_indices, + total_num_groups, + )?; + + assert_eq!( + group_ordering.state, + State::InProgress { + current_sort: 2, + sort_key: vec![ScalarValue::Int32(Some(3))], + current: 5 + } + ); + + // push with only a boundary to previous batch + let batch_group_values: Vec = vec![ + Arc::new(Int32Array::from(vec![4, 4, 4])), + Arc::new(Int32Array::from(vec![1, 1, 1])), + ]; + let group_indices = vec![6, 7, 8]; + let total_num_groups = 9; + + group_ordering.new_groups( + &batch_group_values, + &group_indices, + total_num_groups, + )?; + assert_eq!( + group_ordering.state, + State::InProgress { + current_sort: 6, + sort_key: vec![ScalarValue::Int32(Some(4))], + current: 8 + } + ); + + Ok(()) } } diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index 96bd0de3d37c..69b0a165315e 100644 --- 
a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -39,9 +39,17 @@ impl DynTreeNode for dyn ExecutionPlan { } } -/// A node object beneficial for writing optimizer rules, encapsulating an [`ExecutionPlan`] node with a payload. -/// Since there are two ways to access child plans—directly from the plan and through child nodes—it's recommended +/// A node context object beneficial for writing optimizer rules. +/// This context encapsulating an [`ExecutionPlan`] node with a payload. +/// +/// Since each wrapped node has it's children within both the [`PlanContext.plan.children()`], +/// as well as separately within the [`PlanContext.children`] (which are child nodes wrapped in the context), +/// it's important to keep these child plans in sync when performing mutations. +/// +/// Since there are two ways to access child plans directly -— it's recommended /// to perform mutable operations via [`Self::update_plan_from_children`]. +/// After mutating the `PlanContext.children`, or after creating the `PlanContext`, +/// call `update_plan_from_children` to sync. #[derive(Debug)] pub struct PlanContext { /// The execution plan associated with this context. @@ -61,6 +69,8 @@ impl PlanContext { } } + /// Update the [`PlanContext.plan.children()`] from the [`PlanContext.children`], + /// if the `PlanContext.children` have been changed. pub fn update_plan_from_children(mut self) -> Result { let children_plans = self.children.iter().map(|c| Arc::clone(&c.plan)).collect(); self.plan = with_new_children_if_necessary(self.plan, children_plans)?; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 3bc884257dab..1cdfe6d216e3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -278,7 +278,7 @@ message DmlNode{ Type dml_type = 1; LogicalPlanNode input = 2; TableReference table_name = 3; - datafusion_common.DfSchema schema = 4; + LogicalPlanNode target = 5; } message UnnestNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index add72e4f777e..6e09e9a797ea 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4764,7 +4764,7 @@ impl serde::Serialize for DmlNode { if self.table_name.is_some() { len += 1; } - if self.schema.is_some() { + if self.target.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.DmlNode", len)?; @@ -4779,8 +4779,8 @@ impl serde::Serialize for DmlNode { if let Some(v) = self.table_name.as_ref() { struct_ser.serialize_field("tableName", v)?; } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + if let Some(v) = self.target.as_ref() { + struct_ser.serialize_field("target", v)?; } struct_ser.end() } @@ -4797,7 +4797,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { "input", "table_name", "tableName", - "schema", + "target", ]; #[allow(clippy::enum_variant_names)] @@ -4805,7 +4805,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { DmlType, Input, TableName, - Schema, + Target, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4830,7 +4830,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { "dmlType" | "dml_type" => Ok(GeneratedField::DmlType), "input" => Ok(GeneratedField::Input), "tableName" | "table_name" => Ok(GeneratedField::TableName), - "schema" => Ok(GeneratedField::Schema), + "target" => 
Ok(GeneratedField::Target), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4853,7 +4853,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { let mut dml_type__ = None; let mut input__ = None; let mut table_name__ = None; - let mut schema__ = None; + let mut target__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::DmlType => { @@ -4874,11 +4874,11 @@ impl<'de> serde::Deserialize<'de> for DmlNode { } table_name__ = map_.next_value()?; } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Target => { + if target__.is_some() { + return Err(serde::de::Error::duplicate_field("target")); } - schema__ = map_.next_value()?; + target__ = map_.next_value()?; } } } @@ -4886,7 +4886,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { dml_type: dml_type__.unwrap_or_default(), input: input__, table_name: table_name__, - schema: schema__, + target: target__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index df32c1a70d61..f5ec45da48f2 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -409,8 +409,8 @@ pub struct DmlNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "3")] pub table_name: ::core::option::Option, - #[prost(message, optional, tag = "4")] - pub schema: ::core::option::Option, + #[prost(message, optional, boxed, tag = "5")] + pub target: ::core::option::Option<::prost::alloc::boxed::Box>, } /// Nested message and enum types in `DmlNode`. pub mod dml_node { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 53b683bac66a..641dfe7b5fb8 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -55,8 +55,8 @@ use datafusion::{ }; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - context, internal_datafusion_err, internal_err, not_impl_err, DataFusionError, - Result, TableReference, + context, internal_datafusion_err, internal_err, not_impl_err, plan_err, + DataFusionError, Result, TableReference, ToDFSchema, }; use datafusion_expr::{ dml, @@ -71,7 +71,7 @@ use datafusion_expr::{ }; use datafusion_expr::{ AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, SkipType, - Unnest, + TableSource, Unnest, }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -236,6 +236,45 @@ fn from_table_reference( Ok(table_ref.clone().try_into()?) } +/// Converts [LogicalPlan::TableScan] to [TableSource] +/// method to be used to deserialize nodes +/// serialized by [from_table_source] +fn to_table_source( + node: &Option>, + ctx: &SessionContext, + extension_codec: &dyn LogicalExtensionCodec, +) -> Result> { + if let Some(node) = node { + match node.try_into_logical_plan(ctx, extension_codec)? { + LogicalPlan::TableScan(TableScan { source, .. 
}) => Ok(source), + _ => plan_err!("expected TableScan node"), + } + } else { + plan_err!("LogicalPlanNode should be provided") + } +} + +/// converts [TableSource] to [LogicalPlan::TableScan] +/// using [LogicalPlan::TableScan] was the best approach to +/// serialize [TableSource] to [LogicalPlan::TableScan] +fn from_table_source( + table_name: TableReference, + target: Arc, + extension_codec: &dyn LogicalExtensionCodec, +) -> Result { + let projected_schema = target.schema().to_dfschema_ref()?; + let r = LogicalPlan::TableScan(TableScan { + table_name, + source: target, + projection: None, + projected_schema, + filters: vec![], + fetch: None, + }); + + LogicalPlanNode::try_from_logical_plan(&r, extension_codec) +} + impl AsLogicalPlan for LogicalPlanNode { fn try_decode(buf: &[u8]) -> Result where @@ -454,7 +493,7 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } - CustomScan(scan) => { + LogicalPlanType::CustomScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; let schema = Arc::new(schema); let mut projection = None; @@ -942,7 +981,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Dml(dml_node) => Ok(LogicalPlan::Dml( datafusion::logical_expr::DmlStatement::new( from_table_reference(dml_node.table_name.as_ref(), "DML ")?, - Arc::new(convert_required!(dml_node.schema)?), + to_table_source(&dml_node.target, ctx, extension_codec)?, dml_node.dml_type().into(), Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?), ), @@ -1658,7 +1697,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, input, .. @@ -1669,7 +1708,11 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Dml(Box::new(DmlNode { input: Some(Box::new(input)), - schema: Some(table_schema.try_into()?), + target: Some(Box::new(from_table_source( + table_name.clone(), + Arc::clone(target), + extension_codec, + )?)), table_name: Some(table_name.clone().into()), dml_type: dml_type.into(), }))), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 84b952965958..a575a42d0b6c 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -243,7 +243,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { )? 
.with_newlines_in_values(scan.newlines_in_values) .with_file_compression_type(FileCompressionType::UNCOMPRESSED); - Ok(conf.new_exec()) + Ok(conf.build()) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] PhysicalPlanType::ParquetScan(scan) => { @@ -280,7 +280,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { extension_codec, Arc::new(source), )?; - Ok(base_config.new_exec()) + Ok(base_config.build()) } #[cfg(not(feature = "parquet"))] panic!("Unable to process a Parquet PhysicalPlan when `parquet` feature is not enabled") @@ -292,7 +292,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { extension_codec, Arc::new(AvroSource::new()), )?; - Ok(conf.new_exec()) + Ok(conf.build()) } PhysicalPlanType::CoalesceBatches(coalesce_batches) => { let input: Arc = into_physical_plan( diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index f36b7178313a..25efa2690268 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -22,8 +22,8 @@ use std::fmt::Debug; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl, - Signature, Volatility, WindowUDFImpl, + Accumulator, AggregateUDFImpl, PartitionEvaluator, ScalarUDFImpl, Signature, + Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -69,13 +69,6 @@ impl ScalarUDFImpl for MyRegexUdf { plan_err!("regex_udf only accepts Utf8 arguments") } } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> datafusion_common::Result { - unimplemented!() - } fn aliases(&self) -> &[String] { &self.aliases } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7418184fcac1..a8ecb2d0749e 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -741,7 +741,7 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { source, }; - roundtrip_test(scan_config.new_exec()) + roundtrip_test(scan_config.build()) } #[tokio::test] @@ -772,7 +772,7 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { source, }; - roundtrip_test(scan_config.new_exec()) + roundtrip_test(scan_config.build()) } #[test] @@ -918,7 +918,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { } } - let exec_plan = scan_config.new_exec(); + let exec_plan = scan_config.build(); let ctx = SessionContext::new(); roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec {})?; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index a3ac831a1f78..5fb6ef913d8c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -224,7 +224,24 @@ impl PlannerContext { } } -/// SQL query planner +/// SQL query planner and binder +/// +/// This struct is used to convert a SQL AST into a [`LogicalPlan`]. +/// +/// You can control the behavior of the planner by providing [`ParserOptions`]. +/// +/// It performs the following tasks: +/// +/// 1. Name and type resolution (called "binding" in other systems). This +/// phase looks up table and column names using the [`ContextProvider`]. +/// 2. Mechanical translation of the AST into a [`LogicalPlan`]. 
+/// +/// It does not perform type coercion, or perform optimization, which are done +/// by subsequent passes. +/// +/// Key interfaces are: +/// * [`Self::sql_statement_to_plan`]: Convert a statement (e.g. `SELECT ...`) into a [`LogicalPlan`] +/// * [`Self::sql_to_expr`]: Convert an expression (e.g. `1 + 2`) into an [`Expr`] pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 2579f2397228..a55b3b039087 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -54,15 +54,18 @@ impl SqlToRel<'_, S> { return Err(err); } }; - self.validate_set_expr_num_of_columns( - op, - left_span, - right_span, - &left_plan, - &right_plan, - set_expr_span, - )?; - + if !(set_quantifier == SetQuantifier::ByName + || set_quantifier == SetQuantifier::AllByName) + { + self.validate_set_expr_num_of_columns( + op, + left_span, + right_span, + &left_plan, + &right_plan, + set_expr_span, + )?; + } self.set_operation_to_plan(op, left_plan, right_plan, set_quantifier) } SetExpr::Query(q) => self.query_to_plan(*q, planner_context), @@ -72,17 +75,11 @@ impl SqlToRel<'_, S> { pub(super) fn is_union_all(set_quantifier: SetQuantifier) -> Result { match set_quantifier { - SetQuantifier::All => Ok(true), - SetQuantifier::Distinct | SetQuantifier::None => Ok(false), - SetQuantifier::ByName => { - not_impl_err!("UNION BY NAME not implemented") - } - SetQuantifier::AllByName => { - not_impl_err!("UNION ALL BY NAME not implemented") - } - SetQuantifier::DistinctByName => { - not_impl_err!("UNION DISTINCT BY NAME not implemented") - } + SetQuantifier::All | SetQuantifier::AllByName => Ok(true), + SetQuantifier::Distinct + | SetQuantifier::ByName + | SetQuantifier::DistinctByName + | SetQuantifier::None => Ok(false), } } @@ -127,28 +124,42 @@ impl SqlToRel<'_, S> { right_plan: LogicalPlan, set_quantifier: SetQuantifier, ) -> Result { - let all = Self::is_union_all(set_quantifier)?; - match (op, all) { - (SetOperator::Union, true) => LogicalPlanBuilder::from(left_plan) - .union(right_plan)? - .build(), - (SetOperator::Union, false) => LogicalPlanBuilder::from(left_plan) - .union_distinct(right_plan)? + match (op, set_quantifier) { + (SetOperator::Union, SetQuantifier::All) => { + LogicalPlanBuilder::from(left_plan) + .union(right_plan)? + .build() + } + (SetOperator::Union, SetQuantifier::AllByName) => { + LogicalPlanBuilder::from(left_plan) + .union_by_name(right_plan)? + .build() + } + (SetOperator::Union, SetQuantifier::Distinct | SetQuantifier::None) => { + LogicalPlanBuilder::from(left_plan) + .union_distinct(right_plan)? + .build() + } + ( + SetOperator::Union, + SetQuantifier::ByName | SetQuantifier::DistinctByName, + ) => LogicalPlanBuilder::from(left_plan) + .union_by_name_distinct(right_plan)? 
.build(), - (SetOperator::Intersect, true) => { + (SetOperator::Intersect, SetQuantifier::All) => { LogicalPlanBuilder::intersect(left_plan, right_plan, true) } - (SetOperator::Intersect, false) => { + (SetOperator::Intersect, SetQuantifier::Distinct | SetQuantifier::None) => { LogicalPlanBuilder::intersect(left_plan, right_plan, false) } - (SetOperator::Except, true) => { + (SetOperator::Except, SetQuantifier::All) => { LogicalPlanBuilder::except(left_plan, right_plan, true) } - (SetOperator::Except, false) => { + (SetOperator::Except, SetQuantifier::Distinct | SetQuantifier::None) => { LogicalPlanBuilder::except(left_plan, right_plan, false) } - (SetOperator::Minus, _) => { - not_impl_err!("MINUS Set Operator not implemented") + (op, quantifier) => { + not_impl_err!("{op} {quantifier} not implemented") } } } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 209e9cc787d1..74055d979145 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1709,10 +1709,10 @@ impl SqlToRel<'_, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; let table_source = self.context_provider.get_table_source(table_ref.clone())?; - let schema = (*table_source.schema()).clone(); - let schema = DFSchema::try_from(schema)?; + let schema = table_source.schema().to_dfschema_ref()?; let scan = - LogicalPlanBuilder::scan(table_ref.clone(), table_source, None)?.build()?; + LogicalPlanBuilder::scan(table_ref.clone(), Arc::clone(&table_source), None)? + .build()?; let mut planner_context = PlannerContext::new(); let source = match predicate_expr { @@ -1720,7 +1720,7 @@ impl SqlToRel<'_, S> { Some(predicate_expr) => { let filter_expr = self.sql_to_expr(predicate_expr, &schema, &mut planner_context)?; - let schema = Arc::new(schema.clone()); + let schema = Arc::new(schema); let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( @@ -1734,7 +1734,7 @@ impl SqlToRel<'_, S> { let plan = LogicalPlan::Dml(DmlStatement::new( table_ref, - schema.into(), + table_source, WriteOp::Delete, Arc::new(source), )); @@ -1847,7 +1847,7 @@ impl SqlToRel<'_, S> { let plan = LogicalPlan::Dml(DmlStatement::new( table_name, - table_schema, + table_source, WriteOp::Update, Arc::new(source), )); @@ -1976,7 +1976,7 @@ impl SqlToRel<'_, S> { let plan = LogicalPlan::Dml(DmlStatement::new( table_name, - Arc::new(table_schema), + Arc::clone(&table_source), WriteOp::Insert(insert_op), Arc::new(source), )); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 72618c2b6ab4..7c1bcbd5ac41 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1655,8 +1655,8 @@ mod tests { use datafusion_expr::{ case, cast, col, cube, exists, grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup, - table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, WindowFrame, WindowFunctionDefinition, + table_scan, try_cast, when, wildcard, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions::expr_fn::{get_field, named_struct}; @@ -1705,14 +1705,6 @@ mod tests { fn return_type(&self, _arg_types: &[DataType]) -> Result { 
Ok(DataType::Int32) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!("DummyUDF::invoke") - } } // See sql::tests for E2E tests. diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index b92bc0fd60e7..1df18302687e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -30,8 +30,8 @@ use datafusion_expr::{ col, logical_plan::{LogicalPlan, Prepare}, test::function_stub::sum_udaf, - ColumnarValue, CreateIndex, DdlStatement, ScalarUDF, ScalarUDFImpl, Signature, - Statement, Volatility, + CreateIndex, DdlStatement, ScalarUDF, ScalarUDFImpl, Signature, Statement, + Volatility, }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ @@ -2113,6 +2113,33 @@ fn union() { quick_test(sql, expected); } +#[test] +fn union_by_name_different_columns() { + let sql = "SELECT order_id from orders UNION BY NAME SELECT order_id, 1 FROM orders"; + let expected = "\ + Distinct:\ + \n Union\ + \n Projection: NULL AS Int64(1), order_id\ + \n Projection: orders.order_id\ + \n TableScan: orders\ + \n Projection: orders.order_id, Int64(1)\ + \n TableScan: orders"; + quick_test(sql, expected); +} + +#[test] +fn union_by_name_same_column_names() { + let sql = "SELECT order_id from orders UNION SELECT order_id FROM orders"; + let expected = "\ + Distinct:\ + \n Union\ + \n Projection: orders.order_id\ + \n TableScan: orders\ + \n Projection: orders.order_id\ + \n TableScan: orders"; + quick_test(sql, expected); +} + #[test] fn union_all() { let sql = "SELECT order_id from orders UNION ALL SELECT order_id FROM orders"; @@ -2124,6 +2151,31 @@ fn union_all() { quick_test(sql, expected); } +#[test] +fn union_all_by_name_different_columns() { + let sql = + "SELECT order_id from orders UNION ALL BY NAME SELECT order_id, 1 FROM orders"; + let expected = "\ + Union\ + \n Projection: NULL AS Int64(1), order_id\ + \n Projection: orders.order_id\ + \n TableScan: orders\ + \n Projection: orders.order_id, Int64(1)\ + \n TableScan: orders"; + quick_test(sql, expected); +} + +#[test] +fn union_all_by_name_same_column_names() { + let sql = "SELECT order_id from orders UNION ALL BY NAME SELECT order_id FROM orders"; + let expected = "Union\ + \n Projection: orders.order_id\ + \n TableScan: orders\ + \n Projection: orders.order_id\ + \n TableScan: orders"; + quick_test(sql, expected); +} + #[test] fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; @@ -2646,14 +2698,6 @@ impl ScalarUDFImpl for DummyUDF { fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(self.return_type.clone()) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!("DummyUDF::invoke") - } } /// Create logical plan, write with formatter, compare to expected output diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 31b1e1e8a194..ce819f186454 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -22,10 +22,11 @@ use std::path::Path; use std::sync::Arc; use arrow::array::{ - ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampNanosecondArray, + Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, + LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray, }; -use arrow::datatypes::{DataType, Field, Schema, 
SchemaRef, TimeUnit}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields}; use arrow::record_batch::RecordBatch; use datafusion::catalog::{ CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session, @@ -113,6 +114,10 @@ impl TestContext { info!("Registering metadata table tables"); register_metadata_tables(test_ctx.session_ctx()).await; } + "union_function.slt" => { + info!("Registering table with union column"); + register_union_table(test_ctx.session_ctx()) + } _ => { info!("Using default SessionContext"); } @@ -402,3 +407,24 @@ fn create_example_udf() -> ScalarUDF { adder, ) } + +fn register_union_table(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), + ScalarBuffer::from(vec![3, 3]), + None, + vec![Arc::new(Int32Array::from(vec![1, 2]))], + ) + .unwrap(); + + let schema = Schema::new(vec![Field::new( + "union_column", + union.data_type().clone(), + false, + )]); + + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union)]).unwrap(); + + ctx.register_batch("union_table", batch).unwrap(); +} diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index a2e51cffacf7..3a4d641abf68 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -220,6 +220,7 @@ set datafusion.execution.batch_size = 4; # Inserting into nullable table with batch_size specified above # to prevent creation on single in-memory batch + statement ok CREATE TABLE aggregate_test_100_null ( c2 TINYINT NOT NULL, @@ -506,7 +507,7 @@ SELECT avg(c11) FILTER (WHERE c2 != 5) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; ---- -a 2.5 0.449071887467 +a 2.5 0.449071887467 b 2.642857142857 0.445486298629 c 2.421052631579 0.422882117723 d 2.125 0.518706191331 diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 8f23bfe5ea65..6b5b246aee51 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2656,6 +2656,28 @@ select list_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +# array_prepend scalar function #7 (element is fixed size list) +query ??? +select array_prepend(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)'), make_array(arrow_cast(make_array(2), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(3), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(4), 'FixedSizeList(1, Int64)'))), + array_prepend(arrow_cast(make_array(1.0), 'FixedSizeList(1, Float64)'), make_array(arrow_cast([2.0], 'FixedSizeList(1, Float64)'), arrow_cast([3.0], 'FixedSizeList(1, Float64)'), arrow_cast([4.0], 'FixedSizeList(1, Float64)'))), + array_prepend(arrow_cast(make_array('h'), 'FixedSizeList(1, Utf8)'), make_array(arrow_cast(['e'], 'FixedSizeList(1, Utf8)'), arrow_cast(['l'], 'FixedSizeList(1, Utf8)'), arrow_cast(['l'], 'FixedSizeList(1, Utf8)'), arrow_cast(['o'], 'FixedSizeList(1, Utf8)'))); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +query ??? 
+select array_prepend(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(make_array(2), make_array(3), make_array(4)), 'LargeList(FixedSizeList(1, Int64))')), + array_prepend(arrow_cast(make_array(1.0), 'FixedSizeList(1, Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]), 'LargeList(FixedSizeList(1, Float64))')), + array_prepend(arrow_cast(make_array('h'), 'FixedSizeList(1, Utf8)'), arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(FixedSizeList(1, Utf8))')); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +query ??? +select array_prepend(arrow_cast([1], 'FixedSizeList(1, Int64)'), arrow_cast([[1], [2], [3]], 'FixedSizeList(3, FixedSizeList(1, Int64))')), + array_prepend(arrow_cast([1.0], 'FixedSizeList(1, Float64)'), arrow_cast([[2.0], [3.0], [4.0]], 'FixedSizeList(3, FixedSizeList(1, Float64))')), + array_prepend(arrow_cast(['h'], 'FixedSizeList(1, Utf8)'), arrow_cast([['e'], ['l'], ['l'], ['o']], 'FixedSizeList(4, FixedSizeList(1, Utf8))')); +---- +[[1], [1], [2], [3]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + # array_prepend with columns #1 query ? select array_prepend(column2, column1) from arrays_values; @@ -3563,6 +3585,17 @@ select list_replace( ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] +# array_replace scalar function #4 (null input) +query ? +select array_replace(make_array(1, 2, 3, 4, 5), NULL, NULL); +---- +[1, 2, 3, 4, 5] + +query ? +select array_replace(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, NULL); +---- +[1, 2, 3, 4, 5] + # array_replace scalar function with columns #1 query ? select array_replace(column1, column2, column3) from arrays_with_repeating_elements; @@ -3728,6 +3761,17 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +# array_replace_n scalar function #4 (null input) +query ? +select array_replace_n(make_array(1, 2, 3, 4, 5), NULL, NULL, NULL); +---- +[1, 2, 3, 4, 5] + +query ? +select array_replace_n(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, NULL, NULL); +---- +[1, 2, 3, 4, 5] + # array_replace_n scalar function with columns #1 query ? select @@ -3904,6 +3948,17 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] +# array_replace_all scalar function #4 (null input) +query ? +select array_replace_all(make_array(1, 2, 3, 4, 5), NULL, NULL); +---- +[1, 2, 3, 4, 5] + +query ? +select array_replace_all(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, NULL); +---- +[1, 2, 3, 4, 5] + # array_replace_all scalar function with columns #1 query ? 
select diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index 196e2f30518e..aefc2672b539 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -637,46 +637,6 @@ select * from table_without_values; statement ok set datafusion.catalog.information_schema = true; -statement ok -CREATE OR REPLACE TABLE TABLE_WITH_NORMALIZATION(FIELD1 BIGINT, FIELD2 BIGINT); - -# Check table name is in lowercase -query TTTT -show create table table_with_normalization ----- -datafusion public table_with_normalization NULL - -# Check column name is in uppercase -query TTT -describe table_with_normalization ----- -field1 Int64 YES -field2 Int64 YES - -# Disable ident normalization -statement ok -set datafusion.sql_parser.enable_ident_normalization = false; - -statement ok -CREATE TABLE TABLE_WITHOUT_NORMALIZATION(FIELD1 BIGINT, FIELD2 BIGINT) AS VALUES (1,2); - -# Check table name is in uppercase -query TTTT -show create table TABLE_WITHOUT_NORMALIZATION ----- -datafusion public TABLE_WITHOUT_NORMALIZATION NULL - -# Check column name is in uppercase -query TTT -describe TABLE_WITHOUT_NORMALIZATION ----- -FIELD1 Int64 YES -FIELD2 Int64 YES - -statement ok -set datafusion.sql_parser.enable_ident_normalization = true; - - statement ok create table foo(x int); diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index b9699dfd5c06..de1dbf74c29b 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -720,6 +720,14 @@ select count(distinct u) from uuid_table; ---- 2 +# must be valid uuidv4 format +query B +SELECT REGEXP_LIKE(uuid(), + '^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$') + AS is_valid; +---- +true + statement ok drop table uuid_table diff --git a/datafusion/sqllogictest/test_files/ident_normalization.slt b/datafusion/sqllogictest/test_files/ident_normalization.slt new file mode 100644 index 000000000000..996093c3ad9c --- /dev/null +++ b/datafusion/sqllogictest/test_files/ident_normalization.slt @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
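The new functions.slt case above asserts that uuid() emits RFC 4122 version-4 values: the regex pins the version nibble to `4` and the variant nibble to `8`-`b`. For reference, the same invariant can be checked directly with the `uuid` crate; this standalone snippet is illustrative only and not part of the patch.

```rust
use uuid::Uuid;

fn main() {
    // e.g. "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx" (lowercase, hyphenated)
    let s = Uuid::new_v4().to_string();
    let bytes = s.as_bytes();
    assert_eq!(bytes[14], b'4'); // version nibble
    assert!(matches!(bytes[19], b'8' | b'9' | b'a' | b'b')); // variant nibble
}
```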
+ +# Enable information_schema, so we can execute show create table +statement ok +set datafusion.catalog.information_schema = true; + +# Check ident normalization is enabled by default + +statement ok +CREATE OR REPLACE TABLE TABLE_WITH_NORMALIZATION(FIELD1 BIGINT, FIELD2 BIGINT); + +# Check table name is in lowercase +query TTTT +show create table table_with_normalization +---- +datafusion public table_with_normalization NULL + +# Check column name is in uppercase +query TTT +describe table_with_normalization +---- +field1 Int64 YES +field2 Int64 YES + +# Disable ident normalization +statement ok +set datafusion.sql_parser.enable_ident_normalization = false; + +statement ok +CREATE TABLE TABLE_WITHOUT_NORMALIZATION(FIELD1 BIGINT, FIELD2 BIGINT) AS VALUES (1,2); + +# Check table name is in uppercase +query TTTT +show create table TABLE_WITHOUT_NORMALIZATION +---- +datafusion public TABLE_WITHOUT_NORMALIZATION NULL + +# Check column name is in uppercase +query TTT +describe TABLE_WITHOUT_NORMALIZATION +---- +FIELD1 Int64 YES +FIELD2 Int64 YES + +statement ok +DROP TABLE TABLE_WITHOUT_NORMALIZATION + +############ +## Column Name Normalization +############ + +# Table x (lowercase) with a column named "A" (uppercase) +statement ok +create table x as select 1 "A" + +query TTT +describe x +---- +A Int64 NO + +# Expect error as 'a' is not a column -- "A" is and the identifiers +# are not normalized +query error DataFusion error: Schema error: No field named a\. Valid fields are x\."A"\. +select a from x; + +# should work (note the uppercase 'A') +query I +select A from x; +---- +1 + +statement ok +drop table x; + +############ +## Table Name Normalization +############ + +# Table Y (uppercase) with a column named a (lower case) +statement ok +create table Y as select 1 a; + +query TTT +describe Y +---- +a Int64 NO + +# Expect error as y is not a a table -- "Y" is +query error DataFusion error: Error during planning: table 'datafusion\.public\.y' not found +select * from y; + +# should work (note the uppercase 'Y') +query I +select * from Y; +---- +1 + +statement ok +drop table Y; + +############ +## Function Name Normalization +############ + +## Check function names are still normalized even though column names are not +query I +SELECT length('str'); +---- +3 + +query I +SELECT LENGTH('str'); +---- +3 + +query T +SELECT CONCAT('Hello', 'World') +---- +HelloWorld diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index cbc989841ab3..ee76ee1c5511 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -296,8 +296,11 @@ insert into table_without_values(field1) values(3); 1 # insert NULL values for the missing column (field1), but column is non-nullable -statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +statement error insert into table_without_values(field2) values(300); +---- +DataFusion error: Execution error: Invalid batch column at '0' has null but schema specifies non-nullable + statement error Invalid argument error: Column 'column1' is declared as non-nullable but contains null values insert into table_without_values values(NULL, 300); @@ -358,7 +361,7 @@ statement ok create table test_column_defaults( a int, b int not null default null, - c int default 100*2+300, + c int default 100*2+300, d text default lower('DEFAULT_TEXT'), e timestamp default now() ) @@ -368,8 +371,11 @@ insert into test_column_defaults values(1, 10, 
100, 'ABC', now()) ---- 1 -statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +statement error insert into test_column_defaults(a) values(2) +---- +DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable + query I insert into test_column_defaults(b) values(20) @@ -412,7 +418,7 @@ statement ok create table test_column_defaults( a int, b int not null default null, - c int default 100*2+300, + c int default 100*2+300, d text default lower('DEFAULT_TEXT'), e timestamp default now() ) as values(1, 10, 100, 'ABC', now()) diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index c5fa2b4e1a51..ee1d67c5e26d 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -60,6 +60,7 @@ STORED AS parquet LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned/' PARTITIONED BY (b); +#query error here because PARTITIONED BY (b) will make the b nullable to false query I insert into dictionary_encoded_parquet_partitioned select * from dictionary_encoded_values @@ -81,6 +82,7 @@ STORED AS arrow LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/' PARTITIONED BY (b); +#query error here because PARTITIONED BY (b) will make the b nullable to false query I insert into dictionary_encoded_arrow_partitioned select * from dictionary_encoded_values @@ -543,8 +545,11 @@ insert into table_without_values(field1) values(3); 1 # insert NULL values for the missing column (field1), but column is non-nullable -statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +statement error insert into table_without_values(field2) values(300); +---- +DataFusion error: Execution error: Invalid batch column at '0' has null but schema specifies non-nullable + statement error Invalid argument error: Column 'column1' is declared as non-nullable but contains null values insert into table_without_values values(NULL, 300); @@ -581,8 +586,11 @@ insert into test_column_defaults values(1, 10, 100, 'ABC', now()) ---- 1 -statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +statement error insert into test_column_defaults(a) values(2) +---- +DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable + query I insert into test_column_defaults(b) values(20) diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index a6826a6ef108..66413775b393 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1807,24 +1807,6 @@ SELECT acos(); statement error SELECT isnan(); -# turn off enable_ident_normalization -statement ok -set datafusion.sql_parser.enable_ident_normalization = false; - -query I -SELECT LENGTH('str'); ----- -3 - -query T -SELECT CONCAT('Hello', 'World') ----- -HelloWorld - -# turn on enable_ident_normalization -statement ok -set datafusion.sql_parser.enable_ident_normalization = true; - query I SELECT LENGTH('str'); ---- diff --git a/datafusion/sqllogictest/test_files/union_by_name.slt b/datafusion/sqllogictest/test_files/union_by_name.slt new file mode 100644 index 000000000000..0ba4c32ee5be --- /dev/null +++ 
b/datafusion/sqllogictest/test_files/union_by_name.slt @@ -0,0 +1,288 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Portions of this file are derived from DuckDB and are licensed +# under the MIT License (see below). + +# Copyright 2018-2025 Stichting DuckDB Foundation + +# Permission is hereby granted, free of charge, to any person +# obtaining a copy of this software and associated documentation +# files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +statement ok +CREATE TABLE t1 (x INT, y INT); + +statement ok +INSERT INTO t1 VALUES (3, 3), (3, 3), (1, 1); + +statement ok +CREATE TABLE t2 (y INT, z INT); + +statement ok +INSERT INTO t2 VALUES (2, 2), (4, 4); + + +# Test binding +query I +SELECT t1.x FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY t1.x; +---- +1 +3 + +query I +SELECT t1.x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY t1.x; +---- +1 +1 +3 +3 +3 +3 + +query I +SELECT x FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY t1.x; +---- +1 +3 + +query I +SELECT x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY t1.x; +---- +1 +1 +3 +3 +3 +3 + +query II +(SELECT x FROM t1 UNION ALL SELECT x FROM t1) UNION BY NAME SELECT 5 ORDER BY x; +---- +NULL 1 +NULL 3 +5 NULL + +# TODO: This should pass, but the sanity checker isn't allowing it. +# Commenting out the ordering check in the sanity checker produces the correct result. 
+query error +(SELECT x FROM t1 UNION ALL SELECT x FROM t1) UNION ALL BY NAME SELECT 5 ORDER BY x; +---- +DataFusion error: SanityCheckPlan +caused by +Error during planning: Plan: ["SortPreservingMergeExec: [x@1 ASC NULLS LAST]", " UnionExec", " SortExec: expr=[x@1 ASC NULLS LAST], preserve_partitioning=[true]", " ProjectionExec: expr=[NULL as Int64(5), x@0 as x]", " UnionExec", " DataSourceExec: partitions=1, partition_sizes=[1]", " DataSourceExec: partitions=1, partition_sizes=[1]", " ProjectionExec: expr=[5 as Int64(5), NULL as x]", " PlaceholderRowExec"] does not satisfy order requirements: [x@1 ASC NULLS LAST]. Child-0 order: [] + + +query II +(SELECT x FROM t1 UNION ALL SELECT y FROM t1) UNION BY NAME SELECT 5 ORDER BY x; +---- +NULL 1 +NULL 3 +5 NULL + +# TODO: This should pass, but the sanity checker isn't allowing it. +# Commenting out the ordering check in the sanity checker produces the correct result. +query error +(SELECT x FROM t1 UNION ALL SELECT y FROM t1) UNION ALL BY NAME SELECT 5 ORDER BY x; +---- +DataFusion error: SanityCheckPlan +caused by +Error during planning: Plan: ["SortPreservingMergeExec: [x@1 ASC NULLS LAST]", " UnionExec", " SortExec: expr=[x@1 ASC NULLS LAST], preserve_partitioning=[true]", " ProjectionExec: expr=[NULL as Int64(5), x@0 as x]", " UnionExec", " DataSourceExec: partitions=1, partition_sizes=[1]", " ProjectionExec: expr=[y@0 as x]", " DataSourceExec: partitions=1, partition_sizes=[1]", " ProjectionExec: expr=[5 as Int64(5), NULL as x]", " PlaceholderRowExec"] does not satisfy order requirements: [x@1 ASC NULLS LAST]. Child-0 order: [] + + + +# Ambiguous name + +statement error DataFusion error: Schema error: No field named t1.x. Valid fields are a, b. +SELECT x AS a FROM t1 UNION BY NAME SELECT x AS b FROM t1 ORDER BY t1.x; + +query II +(SELECT y FROM t1 UNION ALL SELECT x FROM t1) UNION BY NAME (SELECT z FROM t2 UNION ALL SELECT y FROM t2) ORDER BY y, z; +---- +1 NULL +3 NULL +NULL 2 +NULL 4 + +query II +(SELECT y FROM t1 UNION ALL SELECT x FROM t1) UNION ALL BY NAME (SELECT z FROM t2 UNION ALL SELECT y FROM t2) ORDER BY y, z; +---- +1 NULL +1 NULL +3 NULL +3 NULL +3 NULL +3 NULL +NULL 2 +NULL 2 +NULL 4 +NULL 4 + +# Limit + +query III +SELECT 1 UNION BY NAME SELECT * FROM unnest(range(2, 100)) UNION BY NAME SELECT 999 ORDER BY 3, 1 LIMIT 5; +---- +NULL NULL 2 +NULL NULL 3 +NULL NULL 4 +NULL NULL 5 +NULL NULL 6 + +# TODO: This should pass, but the sanity checker isn't allowing it. +# Commenting out the ordering check in the sanity checker produces the correct result. 
+query error +SELECT 1 UNION ALL BY NAME SELECT * FROM unnest(range(2, 100)) UNION ALL BY NAME SELECT 999 ORDER BY 3, 1 LIMIT 5; +---- +DataFusion error: SanityCheckPlan +caused by +Error during planning: Plan: ["SortPreservingMergeExec: [UNNEST(range(Int64(2),Int64(100)))@2 ASC NULLS LAST, Int64(1)@0 ASC NULLS LAST], fetch=5", " UnionExec", " SortExec: TopK(fetch=5), expr=[UNNEST(range(Int64(2),Int64(100)))@2 ASC NULLS LAST], preserve_partitioning=[true]", " ProjectionExec: expr=[Int64(1)@0 as Int64(1), NULL as Int64(999), UNNEST(range(Int64(2),Int64(100)))@1 as UNNEST(range(Int64(2),Int64(100)))]", " UnionExec", " ProjectionExec: expr=[1 as Int64(1), NULL as UNNEST(range(Int64(2),Int64(100)))]", " PlaceholderRowExec", " ProjectionExec: expr=[NULL as Int64(1), __unnest_placeholder(range(Int64(2),Int64(100)),depth=1)@0 as UNNEST(range(Int64(2),Int64(100)))]", " UnnestExec", " ProjectionExec: expr=[[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] as __unnest_placeholder(range(Int64(2),Int64(100)))]", " PlaceholderRowExec", " ProjectionExec: expr=[NULL as Int64(1), 999 as Int64(999), NULL as UNNEST(range(Int64(2),Int64(100)))]", " PlaceholderRowExec"] does not satisfy order requirements: [UNNEST(range(Int64(2),Int64(100)))@2 ASC NULLS LAST, Int64(1)@0 ASC NULLS LAST]. Child-0 order: [] + + +# Order by + +query III +SELECT x, y FROM t1 UNION BY NAME SELECT y, z FROM t2 ORDER BY y; +---- +1 1 NULL +NULL 2 2 +3 3 NULL +NULL 4 4 + +query III +SELECT x, y FROM t1 UNION ALL BY NAME SELECT y, z FROM t2 ORDER BY y; +---- +1 1 NULL +NULL 2 2 +3 3 NULL +3 3 NULL +NULL 4 4 + +query III +SELECT x, y FROM t1 UNION BY NAME SELECT y, z FROM t2 ORDER BY 3, 1; +---- +NULL 2 2 +NULL 4 4 +1 1 NULL +3 3 NULL + +query III +SELECT x, y FROM t1 UNION ALL BY NAME SELECT y, z FROM t2 ORDER BY 3, 1; +---- +NULL 2 2 +NULL 4 4 +1 1 NULL +3 3 NULL +3 3 NULL + +statement error +SELECT x, y FROM t1 UNION BY NAME SELECT y, z FROM t2 ORDER BY 4; +---- +DataFusion error: Error during planning: Order by column out of bounds, specified: 4, max: 3 + + +statement error +SELECT x, y FROM t1 UNION ALL BY NAME SELECT y, z FROM t2 ORDER BY 4; +---- +DataFusion error: Error during planning: Order by column out of bounds, specified: 4, max: 3 + + +# Multi set operations + +query IIII rowsort +(SELECT 1 UNION BY NAME SELECT x, y FROM t1) UNION BY NAME SELECT y, z FROM t2; +---- +1 NULL NULL NULL +NULL 1 1 NULL +NULL 3 3 NULL +NULL NULL 2 2 +NULL NULL 4 4 + +query IIII rowsort +(SELECT 1 UNION ALL BY NAME SELECT x, y FROM t1) UNION ALL BY NAME SELECT y, z FROM t2; +---- +1 NULL NULL NULL +NULL 1 1 NULL +NULL 3 3 NULL +NULL 3 3 NULL +NULL NULL 2 2 +NULL NULL 4 4 + +query III +SELECT x, y FROM t1 UNION BY NAME (SELECT y, z FROM t2 INTERSECT SELECT 2, 2 as two FROM t1 ORDER BY 1) ORDER BY 1; +---- +1 1 NULL +3 3 NULL +NULL 2 2 + +query III +SELECT x, y FROM t1 UNION ALL BY NAME (SELECT y, z FROM t2 INTERSECT SELECT 2, 2 as two FROM t1 ORDER BY 1) ORDER BY 1; +---- +1 1 NULL +3 3 NULL +3 3 NULL +NULL 2 2 + +query III +(SELECT x, y FROM t1 UNION BY NAME SELECT y, z FROM t2 ORDER BY 1) EXCEPT SELECT NULL, 2, 2 as two FROM t1 ORDER BY 1; +---- +1 1 NULL +3 3 NULL +NULL 4 4 + +# Alias in select list + +query II 
+SELECT x as a FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY 1, 2; +---- +1 NULL +3 NULL +NULL 1 +NULL 3 + +query II +SELECT x as a FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY 1, 2; +---- +1 NULL +3 NULL +3 NULL +NULL 1 +NULL 3 +NULL 3 + +# Different types + +query T rowsort +SELECT '0' as c UNION ALL BY NAME SELECT 0 as c; +---- +0 +0 diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt new file mode 100644 index 000000000000..9c70b1011f58 --- /dev/null +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## UNION DataType Tests +########## + +query ?I +select union_column, union_extract(union_column, 'int') from union_table; +---- +{int=1} 1 +{int=2} 2 + +query error DataFusion error: Execution error: field bool not found on union +select union_extract(union_column, 'bool') from union_table; + +query error DataFusion error: Error during planning: 'union_extract' does not support zero arguments +select union_extract() from union_table; + +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 1 +select union_extract(union_column) from union_table; + +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 1 +select union_extract('a') from union_table; + +query error DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead +select union_extract('a', union_column) from union_table; + +query error DataFusion error: Execution error: union_extract second argument must be a non\-null string literal, got Int64 instead +select union_extract(union_column, 1) from union_table; + +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3 +select union_extract(union_column, 'a', 'b') from union_table; diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index ce056ddac664..7bbdfc2a5d94 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -152,7 +152,7 @@ pub async fn from_substrait_rel( } } - Ok(base_config.new_exec() as Arc) + Ok(base_config.build() as Arc) } _ => not_impl_err!( "Only LocalFile reads are supported when parsing physical" diff --git a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs index 04c5e8ada758..f1284db2ad46 100644 --- a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs @@ -49,7 +49,7 @@ 
async fn parquet_exec() -> Result<()> { 123, )], ]); - let parquet_exec: Arc = scan_config.new_exec(); + let parquet_exec: Arc = scan_config.build(); let mut extension_info: ( Vec, diff --git a/docs/source/index.rst b/docs/source/index.rst index 03561be3893c..45c4ffafe7f2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -103,6 +103,7 @@ To get started, see user-guide/introduction user-guide/example-usage + user-guide/features user-guide/concepts-readings-events user-guide/crate-configuration user-guide/cli/index diff --git a/docs/source/user-guide/sql/sql_status.md b/docs/source/user-guide/features.md similarity index 65% rename from docs/source/user-guide/sql/sql_status.md rename to docs/source/user-guide/features.md index cb9bc0bb67b3..1f73ce7eac11 100644 --- a/docs/source/user-guide/sql/sql_status.md +++ b/docs/source/user-guide/features.md @@ -17,23 +17,28 @@ under the License. --> -# Status +# Features ## General - [x] SQL Parser - [x] SQL Query Planner +- [x] DataFrame API +- [x] Parallel query execution +- [x] Streaming Execution + +## Optimizations + - [x] Query Optimizer - [x] Constant folding - [x] Join Reordering - [x] Limit Pushdown - [x] Projection push down - [x] Predicate push down -- [x] Type coercion -- [x] Parallel query execution ## SQL Support +- [x] Type coercion - [x] Projection (`SELECT`) - [x] Filter (`WHERE`) - [x] Filter post-aggregate (`HAVING`) @@ -42,23 +47,23 @@ - [x] Aggregate (`GROUP BY`) - [x] cast /try_cast - [x] [`VALUES` lists](https://www.postgresql.org/docs/current/queries-values.html) -- [x] [String Functions](./scalar_functions.md#string-functions) -- [x] [Conditional Functions](./scalar_functions.md#conditional-functions) -- [x] [Time and Date Functions](./scalar_functions.md#time-and-date-functions) -- [x] [Math Functions](./scalar_functions.md#math-functions) -- [x] [Aggregate Functions](./aggregate_functions.md) (`SUM`, `MEDIAN`, and many more) +- [x] [String Functions](./sql/scalar_functions.md#string-functions) +- [x] [Conditional Functions](./sql/scalar_functions.md#conditional-functions) +- [x] [Time and Date Functions](./sql/scalar_functions.md#time-and-date-functions) +- [x] [Math Functions](./sql/scalar_functions.md#math-functions) +- [x] [Aggregate Functions](./sql/aggregate_functions.md) (`SUM`, `MEDIAN`, and many more) - [x] Schema Queries - [x] `SHOW TABLES` - [x] `SHOW COLUMNS FROM ` - [x] `SHOW CREATE TABLE ` - - [x] Basic SQL [Information Schema](./information_schema.md) (`TABLES`, `VIEWS`, `COLUMNS`) - - [ ] Full SQL [Information Schema](./information_schema.md) support -- [ ] Support for nested types (`ARRAY`/`LIST` and `STRUCT`. See [#2326](https://github.com/apache/datafusion/issues/2326) for details) + - [x] Basic SQL [Information Schema](./sql/information_schema.md) (`TABLES`, `VIEWS`, `COLUMNS`) + - [ ] Full SQL [Information Schema](./sql/information_schema.md) support +- [x] Support for nested types (`ARRAY`/`LIST` and `STRUCT`. - [x] Read support - [x] Write support - [x] Field access (`col['field']` and [`col[1]`]) - - [x] [Array Functions](./scalar_functions.md#array-functions) - - [ ] [Struct Functions](./scalar_functions.md#struct-functions) + - [x] [Array Functions](./sql/scalar_functions.md#array-functions) + - [x] [Struct Functions](./sql/scalar_functions.md#struct-functions) - [x] `struct` - [ ] [Postgres JSON operators](https://github.com/apache/datafusion/issues/6631) (`->`, `->>`, etc.) 
- [x] Subqueries @@ -73,12 +78,12 @@ - [x] Catalogs - [x] Schemas (`CREATE / DROP SCHEMA`) - [x] Tables (`CREATE / DROP TABLE`, `CREATE TABLE AS SELECT`) -- [ ] Data Insert +- [x] Data Insert - [x] `INSERT INTO` - - [ ] `COPY .. INTO ..` + - [x] `COPY .. INTO ..` - [x] CSV - - [ ] JSON - - [ ] Parquet + - [x] JSON + - [x] Parquet - [ ] Avro ## Runtime @@ -87,16 +92,22 @@ - [x] Streaming Window Evaluation - [x] Memory limits enforced - [x] Spilling (to disk) Sort -- [ ] Spilling (to disk) Grouping +- [x] Spilling (to disk) Grouping - [ ] Spilling (to disk) Joins ## Data Sources -In addition to allowing arbitrary datasources via the `TableProvider` +In addition to allowing arbitrary datasources via the [`TableProvider`] trait, DataFusion includes built in support for the following formats: - [x] CSV -- [x] Parquet (for all primitive and nested types) +- [x] Parquet + - [x] Primitive and Nested Types + - [x] Row Group and Data Page pruning on min/max statistics + - [x] Row Group pruning on Bloom Filters + - [x] Predicate push down (late materialization) [not by default](https://github.com/apache/datafusion/issues/3463) - [x] JSON - [x] Avro - [x] Arrow + +[`tableprovider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 0508fa12f0f3..8e3f51bf8b0b 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -33,5 +33,5 @@ SQL Reference window_functions scalar_functions special_functions - sql_status write_options + prepared_statements diff --git a/docs/source/user-guide/sql/prepared_statements.md b/docs/source/user-guide/sql/prepared_statements.md new file mode 100644 index 000000000000..6677b212fdf2 --- /dev/null +++ b/docs/source/user-guide/sql/prepared_statements.md @@ -0,0 +1,139 @@ + + +# Prepared Statements + +The `PREPARE` statement allows for the creation and storage of a SQL statement with placeholder arguments. + +The prepared statements can then be executed repeatedly in an efficient manner. 
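+
+Placeholder values can also be bound programmatically on a `DataFrame`, without issuing `PREPARE` / `EXECUTE` statements. The following is a minimal sketch using `DataFrame::with_param_values`; it assumes the same `example` table that is registered in the Rust examples below:
+
+```rust
+use datafusion::common::ScalarValue;
+use datafusion::prelude::*;
+
+#[tokio::main]
+async fn main() -> datafusion::error::Result<()> {
+    let ctx = SessionContext::new();
+    ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
+
+    // Plan the query with a $1 placeholder, then bind the value directly
+    // on the DataFrame instead of running PREPARE / EXECUTE.
+    let df = ctx
+        .sql("SELECT * FROM example WHERE a > $1")
+        .await?
+        .with_param_values(vec![ScalarValue::from(20i32)])?;
+
+    df.show().await?;
+    Ok(())
+}
+```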
+
+**SQL Example**
+
+Create a prepared statement `greater_than` that selects all records where column "a" is greater than the parameter:
+
+```sql
+PREPARE greater_than(INT) AS SELECT * FROM example WHERE a > $1;
+```
+
+The prepared statement can then be executed with parameters as needed:
+
+```sql
+EXECUTE greater_than(20);
+```
+
+**Rust Example**
+
+```rust
+use datafusion::prelude::*;
+
+#[tokio::main]
+async fn main() -> datafusion::error::Result<()> {
+    // Register the table
+    let ctx = SessionContext::new();
+    ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
+
+    // Create the prepared statement `greater_than`
+    let prepare_sql = "PREPARE greater_than(INT) AS SELECT * FROM example WHERE a > $1";
+    ctx.sql(prepare_sql).await?;
+
+    // Execute the prepared statement `greater_than`
+    let execute_sql = "EXECUTE greater_than(20)";
+    let df = ctx.sql(execute_sql).await?;
+
+    // Execute and print results
+    df.show().await?;
+    Ok(())
+}
+```
+
+## Inferred Types
+
+If the parameter type is not specified, it can be inferred at execution time:
+
+**SQL Example**
+
+Create the prepared statement `greater_than`
+
+```sql
+PREPARE greater_than AS SELECT * FROM example WHERE a > $1;
+```
+
+Execute the prepared statement `greater_than`
+
+```sql
+EXECUTE greater_than(20);
+```
+
+**Rust Example**
+
+```rust
+# use datafusion::prelude::*;
+# #[tokio::main]
+# async fn main() -> datafusion::error::Result<()> {
+#     let ctx = SessionContext::new();
+#     ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
+#
+    // Create the prepared statement `greater_than`
+    let prepare_sql = "PREPARE greater_than AS SELECT * FROM example WHERE a > $1";
+    ctx.sql(prepare_sql).await?;
+
+    // Execute the prepared statement `greater_than`
+    let execute_sql = "EXECUTE greater_than(20)";
+    let df = ctx.sql(execute_sql).await?;
+#
+#     Ok(())
+# }
+```
+
+## Positional Arguments
+
+In the case of multiple parameters, prepared statements can use positional arguments:
+
+**SQL Example**
+
+Create the prepared statement `greater_than`
+
+```sql
+PREPARE greater_than(INT, DOUBLE) AS SELECT * FROM example WHERE a > $1 AND b > $2;
+```
+
+Execute the prepared statement `greater_than`
+
+```sql
+EXECUTE greater_than(20, 23.3);
+```
+
+**Rust Example**
+
+```rust
+# use datafusion::prelude::*;
+# #[tokio::main]
+# async fn main() -> datafusion::error::Result<()> {
+#     let ctx = SessionContext::new();
+#     ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
+    // Create the prepared statement `greater_than`
+    let prepare_sql = "PREPARE greater_than(INT, DOUBLE) AS SELECT * FROM example WHERE a > $1 AND b > $2";
+    ctx.sql(prepare_sql).await?;
+
+    // Execute the prepared statement `greater_than`
+    let execute_sql = "EXECUTE greater_than(20, 23.3)";
+    let df = ctx.sql(execute_sql).await?;
+# Ok(())
+# }
+```
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index b14bf5b2cc91..fb4043c33efc 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -4339,6 +4339,40 @@ sha512(expression)
 +-------------------------------------------+
 ```
+## Union Functions
+
+Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types.
+Note: this is not related to the SQL `UNION` operator.
+
+- [union_extract](#union_extract)
+
+### `union_extract`
+
+Returns the value of the given field in the union when it is the currently selected field, or `NULL` otherwise.
+
+```sql
+union_extract(union, field_name)
+```
+
+#### Arguments
+
+- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **field_name**: Name of the field to extract. Must be a non-null string literal.
+
+#### Example
+
+```sql
+❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union;
++--------------+----------------------------------+----------------------------------+
+| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') |
++--------------+----------------------------------+----------------------------------+
+| {a=1}        | 1                                |                                  |
+| {b=3.0}      |                                  | 3.0                              |
+| {a=4}        | 4                                |                                  |
+| {b=}         |                                  |                                  |
+| {a=}         |                                  |                                  |
++--------------+----------------------------------+----------------------------------+
+```
+
 ## Other Functions
 
 - [arrow_cast](#arrow_cast)