From ae70d7206059c37448f430291e8ba35bdcfa6a2f Mon Sep 17 00:00:00 2001
From: Wang
Date: Fri, 4 Mar 2022 18:19:29 +0800
Subject: [PATCH] Refactor ExecutionContext and related conf to support multi-tenancy configurations

Do not use static singleton RuntimeEnv in both DataFusion and Ballista
Add SessionBuilder fn type to SchedulerServer to allow customized SessionContext creation
fix client BallistaContext
---
 .../src/bin/ballista-dataframe.rs | 2 +-
 ballista-examples/src/bin/ballista-sql.rs | 2 +-
 ballista/rust/client/README.md | 2 +-
 ballista/rust/client/src/context.rs | 221 ++--
 ballista/rust/client/src/prelude.rs | 1 +
 ballista/rust/core/proto/ballista.proto | 16 +-
 ballista/rust/core/src/config.rs | 40 +
 .../src/execution_plans/distributed_query.rs | 43 +-
 .../src/execution_plans/shuffle_reader.rs | 13 +-
 .../src/execution_plans/shuffle_writer.rs | 40 +-
 .../src/execution_plans/unresolved_shuffle.rs | 21 +-
 .../rust/core/src/serde/logical_plan/mod.rs | 56 +-
 ballista/rust/core/src/serde/mod.rs | 43 +-
 .../src/serde/physical_plan/from_proto.rs | 27 +-
 .../rust/core/src/serde/physical_plan/mod.rs | 216 +++-
 ballista/rust/core/src/utils.rs | 38 +-
 ballista/rust/executor/src/collect.rs | 16 +-
 ballista/rust/executor/src/execution_loop.rs | 20 +-
 ballista/rust/executor/src/executor.rs | 16 +-
 ballista/rust/executor/src/executor_server.rs | 20 +-
 ballista/rust/executor/src/main.rs | 19 +-
 ballista/rust/executor/src/standalone.rs | 16 +-
 ballista/rust/scheduler/src/main.rs | 11 +-
 ballista/rust/scheduler/src/planner.rs | 22 +-
 .../scheduler/src/scheduler_server/grpc.rs | 156 ++-
 .../scheduler/src/scheduler_server/mod.rs | 114 +-
 .../src/scheduler_server/task_scheduler.rs | 20 +-
 ballista/rust/scheduler/src/standalone.rs | 3 -
 ballista/rust/scheduler/src/state/mod.rs | 64 +-
 ballista/rust/scheduler/src/test_utils.rs | 9 +-
 benchmarks/src/bin/nyctaxi.rs | 21 +-
 benchmarks/src/bin/tpch.rs | 95 +-
 datafusion-cli/src/context.rs | 19 +-
 datafusion-cli/src/main.rs | 10 +-
 datafusion-examples/examples/avro_sql.rs | 2 +-
 datafusion-examples/examples/csv_sql.rs | 4 +-
 .../examples/custom_datasource.rs | 26 +-
 datafusion-examples/examples/dataframe.rs | 4 +-
 .../examples/dataframe_in_memory.rs | 2 +-
 datafusion-examples/examples/flight_server.rs | 4 +-
 datafusion-examples/examples/memtable.rs | 6 +-
 datafusion-examples/examples/parquet_sql.rs | 4 +-
 .../examples/parquet_sql_multiple_files.rs | 4 +-
 datafusion-examples/examples/simple_udaf.rs | 6 +-
 datafusion-examples/examples/simple_udf.rs | 8 +-
 datafusion/Cargo.toml | 1 +
 datafusion/benches/aggregate_query_sql.rs | 8 +-
 datafusion/benches/filter_query_sql.rs | 20 +-
 datafusion/benches/math_query_sql.rs | 8 +-
 datafusion/benches/parquet_query_sql.rs | 4 +-
 datafusion/benches/physical_plan.rs | 3 +-
 datafusion/benches/sort_limit_query_sql.rs | 13 +-
 datafusion/benches/window_query_sql.rs | 8 +-
 datafusion/src/catalog/schema.rs | 4 +-
 datafusion/src/dataframe.rs | 46 +-
 datafusion/src/datasource/datasource.rs | 1 +
 datafusion/src/datasource/empty.rs | 7 +-
 datafusion/src/datasource/file_format/avro.rs | 77 +-
 datafusion/src/datasource/file_format/csv.rs | 34 +-
 datafusion/src/datasource/file_format/json.rs | 34 +-
 datafusion/src/datasource/file_format/mod.rs | 1 +
 .../src/datasource/file_format/parquet.rs | 79 +-
 datafusion/src/datasource/listing/helpers.rs | 4 +-
 datafusion/src/datasource/listing/table.rs | 15 +-
 datafusion/src/datasource/memory.rs | 54 +-
 datafusion/src/datasource/object_store/mod.rs | 2 +-
 datafusion/src/execution/context.rs | 1063 +++++++++--------
 datafusion/src/execution/dataframe_impl.rs | 127 +-
 datafusion/src/execution/runtime_env.rs | 89 +-
 datafusion/src/lib.rs | 4 +-
 datafusion/src/optimizer/filter_push_down.rs | 1 +
 .../aggregate_statistics.rs | 57 +-
 .../physical_optimizer/coalesce_batches.rs | 7 +-
 .../hash_build_probe_order.rs | 13 +-
 .../src/physical_optimizer/merge_exec.rs | 5 +-
 .../src/physical_optimizer/optimizer.rs | 4 +-
 .../src/physical_optimizer/repartition.rs | 12 +-
 datafusion/src/physical_optimizer/utils.rs | 6 +-
 datafusion/src/physical_plan/analyze.rs | 22 +-
 .../src/physical_plan/coalesce_batches.rs | 23 +-
 .../src/physical_plan/coalesce_partitions.rs | 30 +-
 datafusion/src/physical_plan/common.rs | 6 +-
 datafusion/src/physical_plan/cross_join.rs | 12 +-
 datafusion/src/physical_plan/empty.rs | 45 +-
 datafusion/src/physical_plan/explain.rs | 12 +-
 .../src/physical_plan/file_format/avro.rs | 17 +-
 .../src/physical_plan/file_format/csv.rs | 66 +-
 .../src/physical_plan/file_format/json.rs | 96 +-
 .../src/physical_plan/file_format/parquet.rs | 60 +-
 datafusion/src/physical_plan/filter.rs | 17 +-
 .../src/physical_plan/hash_aggregate.rs | 73 +-
 datafusion/src/physical_plan/hash_join.rs | 207 +++-
 datafusion/src/physical_plan/limit.rs | 25 +-
 datafusion/src/physical_plan/memory.rs | 34 +-
 .../src/physical_plan/metrics/tracker.rs | 11 +-
 datafusion/src/physical_plan/mod.rs | 33 +-
 datafusion/src/physical_plan/planner.rs | 168 +--
 datafusion/src/physical_plan/projection.rs | 18 +-
 datafusion/src/physical_plan/repartition.rs | 101 +-
 datafusion/src/physical_plan/sorts/sort.rs | 79 +-
 .../sorts/sort_preserving_merge.rs | 179 +--
 datafusion/src/physical_plan/union.rs | 18 +-
 datafusion/src/physical_plan/values.rs | 20 +-
 datafusion/src/physical_plan/windows/mod.rs | 27 +-
 .../physical_plan/windows/window_agg_exec.rs | 10 +-
 datafusion/src/prelude.rs | 2 +-
 datafusion/src/row/mod.rs | 3 +-
 datafusion/src/test/exec.rs | 67 +-
 datafusion/tests/custom_sources.rs | 26 +-
 datafusion/tests/dataframe.rs | 8 +-
 datafusion/tests/dataframe_functions.rs | 4 +-
 datafusion/tests/merge_fuzz.rs | 27 +-
 datafusion/tests/order_spill_fuzz.rs | 15 +-
 datafusion/tests/parquet_pruning.rs | 18 +-
 datafusion/tests/path_partition.rs | 31 +-
 datafusion/tests/provider_filter_pushdown.rs | 16 +-
 datafusion/tests/sql/aggregates.rs | 224 ++--
 datafusion/tests/sql/avro.rs | 23 +-
 datafusion/tests/sql/create_drop.rs | 22 +-
 datafusion/tests/sql/errors.rs | 24 +-
 datafusion/tests/sql/explain.rs | 4 +-
 datafusion/tests/sql/explain_analyze.rs | 82 +-
 datafusion/tests/sql/expr.rs | 106 +-
 datafusion/tests/sql/functions.rs | 32 +-
 datafusion/tests/sql/group_by.rs | 86 +-
 datafusion/tests/sql/information_schema.rs | 117 +-
 datafusion/tests/sql/intersection.rs | 20 +-
 datafusion/tests/sql/joins.rs | 148 +--
 datafusion/tests/sql/limit.rs | 24 +-
 datafusion/tests/sql/mod.rs | 65 +-
 datafusion/tests/sql/order.rs | 30 +-
 datafusion/tests/sql/parquet.rs | 20 +-
 datafusion/tests/sql/partitioned_csv.rs | 17 +-
 datafusion/tests/sql/predicates.rs | 94 +-
 datafusion/tests/sql/projection.rs | 26 +-
 datafusion/tests/sql/references.rs | 20 +-
 datafusion/tests/sql/select.rs | 164 +--
 datafusion/tests/sql/timestamp.rs | 122 +-
 datafusion/tests/sql/udf.rs | 16 +-
 datafusion/tests/sql/unicode.rs | 4 +-
 datafusion/tests/sql/union.rs | 18 +-
 datafusion/tests/sql/window.rs | 24 +-
 datafusion/tests/statistics.rs | 41 +-
 datafusion/tests/user_defined_plan.rs | 73 +-
 .../user-guide/distributed/clients/rust.md | 6 +-
 docs/source/user-guide/example-usage.md | 4 +-
146 files changed, 3932 insertions(+), 2663 deletions(-) diff --git a/ballista-examples/src/bin/ballista-dataframe.rs b/ballista-examples/src/bin/ballista-dataframe.rs index 8399324ad0e2..7f604d2ddc84 100644 --- a/ballista-examples/src/bin/ballista-dataframe.rs +++ b/ballista-examples/src/bin/ballista-dataframe.rs @@ -25,7 +25,7 @@ async fn main() -> Result<()> { let config = BallistaConfig::builder() .set("ballista.shuffle.partitions", "4") .build()?; - let ctx = BallistaContext::remote("localhost", 50050, &config); + let ctx = BallistaContext::remote("localhost", 50050, &config).await?; let testdata = datafusion::arrow::util::test_util::parquet_test_data(); diff --git a/ballista-examples/src/bin/ballista-sql.rs b/ballista-examples/src/bin/ballista-sql.rs index 3e0df21a73f1..787f41aa41db 100644 --- a/ballista-examples/src/bin/ballista-sql.rs +++ b/ballista-examples/src/bin/ballista-sql.rs @@ -25,7 +25,7 @@ async fn main() -> Result<()> { let config = BallistaConfig::builder() .set("ballista.shuffle.partitions", "4") .build()?; - let ctx = BallistaContext::remote("localhost", 50050, &config); + let ctx = BallistaContext::remote("localhost", 50050, &config).await?; let testdata = datafusion::arrow::util::test_util::arrow_test_data(); diff --git a/ballista/rust/client/README.md b/ballista/rust/client/README.md index c27b83899b83..c190a012e1cf 100644 --- a/ballista/rust/client/README.md +++ b/ballista/rust/client/README.md @@ -106,7 +106,7 @@ async fn main() -> Result<()> { .build()?; // connect to Ballista scheduler - let ctx = BallistaContext::remote("localhost", 50050, &config); + let ctx = BallistaContext::remote("localhost", 50050, &config).await?; // register csv file with the execution context ctx.register_csv( diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index f7362476b60f..d94b9033ac26 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -17,6 +17,7 @@ //! Distributed execution context. 
+use log::info; use parking_lot::Mutex; use sqlparser::ast::Statement; use std::collections::HashMap; @@ -25,7 +26,8 @@ use std::path::PathBuf; use std::sync::Arc; use ballista_core::config::BallistaConfig; -use ballista_core::serde::protobuf::LogicalPlanNode; +use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient; +use ballista_core::serde::protobuf::{ExecuteQueryParams, KeyValuePair, LogicalPlanNode}; use ballista_core::utils::create_df_ctx_with_ballista_query_planner; use datafusion::catalog::TableReference; @@ -35,7 +37,7 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::execution::dataframe_impl::DataFrameImpl; use datafusion::logical_plan::{CreateExternalTable, LogicalPlan, TableScan}; use datafusion::prelude::{ - AvroReadOptions, CsvReadOptions, ExecutionConfig, ExecutionContext, + AvroReadOptions, CsvReadOptions, SessionConfig, SessionContext, }; use datafusion::sql::parser::{DFParser, FileType, Statement as DFStatement}; @@ -64,26 +66,81 @@ impl BallistaContextState { } } + pub fn config(&self) -> &BallistaConfig { + &self.config + } +} + +pub struct BallistaContext { + state: Arc>, + context: Arc, +} + +impl BallistaContext { + /// Create a context for executing queries against a remote Ballista scheduler instance + pub async fn remote(host: &str, port: u16, config: &BallistaConfig) -> Result { + let state = BallistaContextState::new(host.to_owned(), port, config); + let scheduler_url = + format!("http://{}:{}", &state.scheduler_host, state.scheduler_port); + info!( + "Connecting to Ballista scheduler at {}", + scheduler_url.clone() + ); + let mut scheduler = SchedulerGrpcClient::connect(scheduler_url.clone()) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + + let remote_session_id = scheduler + .execute_query(ExecuteQueryParams { + query: None, + settings: config + .settings() + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::>(), + optional_session_id: None, + }) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))? + .into_inner() + .session_id; + + info!( + "Server side SessionContext created with Session id: {}", + remote_session_id + ); + + let ctx = { + create_df_ctx_with_ballista_query_planner::( + scheduler_url, + remote_session_id, + state.config(), + ) + }; + + Ok(Self { + state: Arc::new(Mutex::new(state)), + context: Arc::new(ctx), + }) + } + #[cfg(feature = "standalone")] - pub async fn new_standalone( + pub async fn standalone( config: &BallistaConfig, concurrent_tasks: usize, ) -> ballista_core::error::Result { - use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient; use ballista_core::serde::protobuf::PhysicalPlanNode; use ballista_core::serde::BallistaCodec; log::info!("Running in local mode. 
Scheduler will be run in-proc"); let addr = ballista_scheduler::standalone::new_standalone_scheduler().await?; - - let scheduler = loop { - match SchedulerGrpcClient::connect(format!( - "http://localhost:{}", - addr.port() - )) - .await - { + let scheduler_url = format!("http://localhost:{}", addr.port()); + let mut scheduler = loop { + match SchedulerGrpcClient::connect(scheduler_url.clone()).await { Err(_) => { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; log::info!("Attempting to connect to in-proc scheduler..."); @@ -92,6 +149,37 @@ impl BallistaContextState { } }; + let remote_session_id = scheduler + .execute_query(ExecuteQueryParams { + query: None, + settings: config + .settings() + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::>(), + optional_session_id: None, + }) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))? + .into_inner() + .session_id; + + info!( + "Server side SessionContext created with Session id: {}", + remote_session_id + ); + + let ctx = { + create_df_ctx_with_ballista_query_planner::( + scheduler_url, + remote_session_id, + config, + ) + }; + let default_codec: BallistaCodec = BallistaCodec::default(); @@ -102,43 +190,12 @@ impl BallistaContextState { ) .await?; - Ok(Self { - config: config.clone(), - scheduler_host: "localhost".to_string(), - scheduler_port: addr.port(), - tables: HashMap::new(), - }) - } - - pub fn config(&self) -> &BallistaConfig { - &self.config - } -} - -pub struct BallistaContext { - state: Arc>, -} - -impl BallistaContext { - /// Create a context for executing queries against a remote Ballista scheduler instance - pub fn remote(host: &str, port: u16, config: &BallistaConfig) -> Self { - let state = BallistaContextState::new(host.to_owned(), port, config); - - Self { - state: Arc::new(Mutex::new(state)), - } - } - - #[cfg(feature = "standalone")] - pub async fn standalone( - config: &BallistaConfig, - concurrent_tasks: usize, - ) -> ballista_core::error::Result { let state = - BallistaContextState::new_standalone(config, concurrent_tasks).await?; + BallistaContextState::new("localhost".to_string(), addr.port(), config); Ok(Self { state: Arc::new(Mutex::new(state)), + context: Arc::new(ctx), }) } @@ -154,15 +211,10 @@ impl BallistaContext { let path = fs::canonicalize(&path)?; // use local DataFusion context for now but later this might call the scheduler - let mut ctx = { - let guard = self.state.lock(); - create_df_ctx_with_ballista_query_planner::( - &guard.scheduler_host, - guard.scheduler_port, - guard.config(), - ) - }; - let df = ctx.read_avro(path.to_str().unwrap(), options).await?; + let df = self + .context + .read_avro(path.to_str().unwrap(), options) + .await?; Ok(df) } @@ -174,15 +226,7 @@ impl BallistaContext { let path = fs::canonicalize(&path)?; // use local DataFusion context for now but later this might call the scheduler - let mut ctx = { - let guard = self.state.lock(); - create_df_ctx_with_ballista_query_planner::( - &guard.scheduler_host, - guard.scheduler_port, - guard.config(), - ) - }; - let df = ctx.read_parquet(path.to_str().unwrap()).await?; + let df = self.context.read_parquet(path.to_str().unwrap()).await?; Ok(df) } @@ -198,15 +242,10 @@ impl BallistaContext { let path = fs::canonicalize(&path)?; // use local DataFusion context for now but later this might call the scheduler - let mut ctx = { - let guard = self.state.lock(); - create_df_ctx_with_ballista_query_planner::( - &guard.scheduler_host, - 
guard.scheduler_port, - guard.config(), - ) - }; - let df = ctx.read_csv(path.to_str().unwrap(), options).await?; + let df = self + .context + .read_csv(path.to_str().unwrap(), options) + .await?; Ok(df) } @@ -292,34 +331,30 @@ impl BallistaContext { /// This method is `async` because queries of type `CREATE EXTERNAL TABLE` /// might require the schema to be inferred. pub async fn sql(&self, sql: &str) -> Result> { - let mut ctx = { - let state = self.state.lock(); - create_df_ctx_with_ballista_query_planner::( - &state.scheduler_host, - state.scheduler_port, - state.config(), - ) - }; - + let mut ctx = self.context.clone(); let is_show = self.is_show_statement(sql).await?; // the show tables、 show columns sql can not run at scheduler because the tables is store at client if is_show { let state = self.state.lock(); - ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema( + ctx = Arc::new(SessionContext::with_config( + SessionConfig::new().with_information_schema( state.config.default_with_information_schema(), ), - ); + )); } // register tables with DataFusion context { let state = self.state.lock(); for (name, prov) in &state.tables { - ctx.register_table( - TableReference::Bare { table: name }, - Arc::clone(prov), - )?; + // ctx is shared between queries, check table exists or not before register + let table_ref = TableReference::Bare { table: name }; + if !ctx.table_exist(table_ref)? { + ctx.register_table( + TableReference::Bare { table: name }, + Arc::clone(prov), + )?; + } } } @@ -342,16 +377,16 @@ impl BallistaContext { .has_header(*has_header), ) .await?; - Ok(Arc::new(DataFrameImpl::new(ctx.state, &plan))) + Ok(Arc::new(DataFrameImpl::new(ctx.state.clone(), &plan))) } FileType::Parquet => { self.register_parquet(name, location).await?; - Ok(Arc::new(DataFrameImpl::new(ctx.state, &plan))) + Ok(Arc::new(DataFrameImpl::new(ctx.state.clone(), &plan))) } FileType::Avro => { self.register_avro(name, location, AvroReadOptions::default()) .await?; - Ok(Arc::new(DataFrameImpl::new(ctx.state, &plan))) + Ok(Arc::new(DataFrameImpl::new(ctx.state.clone(), &plan))) } _ => Err(DataFusionError::NotImplemented(format!( "Unsupported file type {:?}.", @@ -476,7 +511,6 @@ mod tests { use datafusion::arrow::datatypes::Schema; use datafusion::arrow::util::pretty; use datafusion::datasource::file_format::csv::CsvFormat; - use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, }; @@ -484,9 +518,6 @@ mod tests { use ballista_core::config::{ BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA, }; - use std::fs::File; - use std::io::Write; - use tempfile::TempDir; let config = BallistaConfigBuilder::default() .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true") .build() diff --git a/ballista/rust/client/src/prelude.rs b/ballista/rust/client/src/prelude.rs index d162d0c017bd..15558a357daa 100644 --- a/ballista/rust/client/src/prelude.rs +++ b/ballista/rust/client/src/prelude.rs @@ -19,6 +19,7 @@ pub use crate::context::BallistaContext; pub use ballista_core::config::BallistaConfig; +pub use ballista_core::config::BALLISTA_DEFAULT_BATCH_SIZE; pub use ballista_core::config::BALLISTA_DEFAULT_SHUFFLE_PARTITIONS; pub use ballista_core::error::{BallistaError, Result}; diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index a835229a6057..d01e59515407 100644 --- a/ballista/rust/core/proto/ballista.proto +++ 
b/ballista/rust/core/proto/ballista.proto @@ -705,6 +705,8 @@ message TaskDefinition { bytes plan = 2; // Output partition for shuffle writer PhysicalHashRepartition output_partitioning = 3; + string session_id = 4; + repeated KeyValuePair props = 5; } message PollWorkResult { @@ -745,12 +747,23 @@ message UpdateTaskStatusResult { bool success = 1; } +message OpenSession { + repeated KeyValuePair settings = 1; +} + +message CloseSession { + string session_id = 1; +} + message ExecuteQueryParams { oneof query { bytes logical_plan = 1; string sql = 2; } - repeated KeyValuePair settings = 3; + oneof optional_session_id { + string session_id = 3; + } + repeated KeyValuePair settings = 4; } message ExecuteSqlParams { @@ -759,6 +772,7 @@ message ExecuteSqlParams { message ExecuteQueryResult { string job_id = 1; + string session_id = 2; } message GetJobStatusParams { diff --git a/ballista/rust/core/src/config.rs b/ballista/rust/core/src/config.rs index 8cdaf1f82952..fffe0ead3d75 100644 --- a/ballista/rust/core/src/config.rs +++ b/ballista/rust/core/src/config.rs @@ -28,6 +28,11 @@ use crate::error::{BallistaError, Result}; use datafusion::arrow::datatypes::DataType; pub const BALLISTA_DEFAULT_SHUFFLE_PARTITIONS: &str = "ballista.shuffle.partitions"; +pub const BALLISTA_DEFAULT_BATCH_SIZE: &str = "ballista.batch.size"; +pub const BALLISTA_REPARTITION_JOINS: &str = "ballista.repartition.joins"; +pub const BALLISTA_REPARTITION_AGGREGATIONS: &str = "ballista.repartition.aggregations"; +pub const BALLISTA_REPARTITION_WINDOWS: &str = "ballista.repartition.windows"; +pub const BALLISTA_PARQUET_PRUNING: &str = "ballista.parquet.pruning"; pub const BALLISTA_WITH_INFORMATION_SCHEMA: &str = "ballista.with_information_schema"; pub type ParseResult = result::Result; @@ -148,6 +153,21 @@ impl BallistaConfig { ConfigEntry::new(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS.to_string(), "Sets the default number of partitions to create when repartitioning query stages".to_string(), DataType::UInt16, Some("2".to_string())), + ConfigEntry::new(BALLISTA_DEFAULT_BATCH_SIZE.to_string(), + "Sets the default batch size".to_string(), + DataType::UInt16, Some("8192".to_string())), + ConfigEntry::new(BALLISTA_REPARTITION_JOINS.to_string(), + "Configuration for repartition joins".to_string(), + DataType::Boolean, Some("true".to_string())), + ConfigEntry::new(BALLISTA_REPARTITION_AGGREGATIONS.to_string(), + "Configuration for repartition aggregations".to_string(), + DataType::Boolean,Some("true".to_string())), + ConfigEntry::new(BALLISTA_REPARTITION_WINDOWS.to_string(), + "Configuration for repartition windows".to_string(), + DataType::Boolean,Some("true".to_string())), + ConfigEntry::new(BALLISTA_PARQUET_PRUNING.to_string(), + "Configuration for parquet prune".to_string(), + DataType::Boolean,Some("true".to_string())), ConfigEntry::new(BALLISTA_WITH_INFORMATION_SCHEMA.to_string(), "Sets whether enable information_schema".to_string(), DataType::Boolean,Some("false".to_string())), @@ -166,6 +186,26 @@ impl BallistaConfig { self.get_usize_setting(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS) } + pub fn default_batch_size(&self) -> usize { + self.get_usize_setting(BALLISTA_DEFAULT_BATCH_SIZE) + } + + pub fn repartition_joins(&self) -> bool { + self.get_bool_setting(BALLISTA_REPARTITION_JOINS) + } + + pub fn repartition_aggregations(&self) -> bool { + self.get_bool_setting(BALLISTA_REPARTITION_AGGREGATIONS) + } + + pub fn repartition_windows(&self) -> bool { + self.get_bool_setting(BALLISTA_REPARTITION_WINDOWS) + } + + pub fn parquet_pruning(&self) 
-> bool { + self.get_bool_setting(BALLISTA_PARQUET_PRUNING) + } + pub fn default_with_information_schema(&self) -> bool { self.get_bool_setting(BALLISTA_WITH_INFORMATION_SCHEMA) } diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs index a226622fe9db..19214ca7899a 100644 --- a/ballista/rust/core/src/execution_plans/distributed_query.rs +++ b/ballista/rust/core/src/execution_plans/distributed_query.rs @@ -40,9 +40,10 @@ use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; +use crate::serde::protobuf::execute_query_params::OptionalSessionId; use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec}; use async_trait::async_trait; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; use futures::future; use futures::StreamExt; use log::{error, info}; @@ -63,16 +64,24 @@ pub struct DistributedQueryExec { extension_codec: Arc, /// Phantom data for serializable plan message plan_repr: PhantomData, + /// Session id + session_id: String, } impl DistributedQueryExec { - pub fn new(scheduler_url: String, config: BallistaConfig, plan: LogicalPlan) -> Self { + pub fn new( + scheduler_url: String, + config: BallistaConfig, + plan: LogicalPlan, + session_id: String, + ) -> Self { Self { scheduler_url, config, plan, extension_codec: Arc::new(DefaultLogicalExtensionCodec {}), plan_repr: PhantomData, + session_id, } } @@ -81,6 +90,7 @@ impl DistributedQueryExec { config: BallistaConfig, plan: LogicalPlan, extension_codec: Arc, + session_id: String, ) -> Self { Self { scheduler_url, @@ -88,6 +98,7 @@ impl DistributedQueryExec { plan, extension_codec, plan_repr: PhantomData, + session_id, } } @@ -97,6 +108,7 @@ impl DistributedQueryExec { plan: LogicalPlan, extension_codec: Arc, plan_repr: PhantomData, + session_id: String, ) -> Self { Self { scheduler_url, @@ -104,6 +116,7 @@ impl DistributedQueryExec { plan, extension_codec, plan_repr, + session_id, } } } @@ -144,18 +157,19 @@ impl ExecutionPlan for DistributedQueryExec { plan: self.plan.clone(), extension_codec: self.extension_codec.clone(), plan_repr: self.plan_repr, + session_id: self.session_id(), })) } async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { assert_eq!(0, partition); info!("Connecting to Ballista scheduler at {}", self.scheduler_url); - + // TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again let mut scheduler = SchedulerGrpcClient::connect(self.scheduler_url.clone()) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; @@ -176,7 +190,7 @@ impl ExecutionPlan for DistributedQueryExec { DataFusionError::Execution(format!("failed to encode logical plan: {:?}", e)) })?; - let job_id = scheduler + let query_result = scheduler .execute_query(ExecuteQueryParams { query: Some(Query::LogicalPlan(buf)), settings: self @@ -188,11 +202,22 @@ impl ExecutionPlan for DistributedQueryExec { value: v.to_owned(), }) .collect::>(), + optional_session_id: Some(OptionalSessionId::SessionId( + self.session_id(), + )), }) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))? - .into_inner() - .job_id; + .into_inner(); + + let response_session_id = query_result.session_id; + assert_eq!( + self.session_id(), + response_session_id, + "Session id inconsistent between Client and Server side in DistributedQueryExec." 
+ ); + + let job_id = query_result.job_id; let mut prev_status: Option = None; @@ -272,6 +297,10 @@ impl ExecutionPlan for DistributedQueryExec { // This implies that we cannot infer the statistics at this stage. Statistics::default() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } async fn fetch_partition( diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs index 3bebcd12e155..d24df4fdb7e1 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs @@ -24,8 +24,6 @@ use crate::serde::scheduler::{PartitionLocation, PartitionStats}; use crate::utils::WrappedStream; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; - -use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::metrics::{ ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, @@ -39,6 +37,7 @@ use datafusion::{ }; use futures::{future, StreamExt}; +use datafusion::execution::context::TaskContext; use log::info; /// ShuffleReaderExec reads partitions that have already been materialized by a ShuffleWriterExec @@ -50,6 +49,8 @@ pub struct ShuffleReaderExec { pub(crate) schema: SchemaRef, /// Execution metrics metrics: ExecutionPlanMetricsSet, + /// Session id + session_id: String, } impl ShuffleReaderExec { @@ -57,11 +58,13 @@ impl ShuffleReaderExec { pub fn try_new( partition: Vec>, schema: SchemaRef, + session_id: String, ) -> Result { Ok(Self { partition, schema, metrics: ExecutionPlanMetricsSet::new(), + session_id, }) } } @@ -106,7 +109,7 @@ impl ExecutionPlan for ShuffleReaderExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { info!("ShuffleReaderExec::execute({})", partition); @@ -174,6 +177,10 @@ impl ExecutionPlan for ShuffleReaderExec { .map(|loc| loc.partition_stats), ) } + + fn session_id(&self) -> String { + self.session_id.clone() + } } fn stats_for_partitions( diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index b80fc8492083..7749c8dfe881 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -42,7 +42,6 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::common::IPCWriter; use datafusion::physical_plan::hash_utils::create_hashes; use datafusion::physical_plan::memory::MemoryStream; @@ -55,6 +54,7 @@ use datafusion::physical_plan::{ }; use futures::StreamExt; +use datafusion::execution::context::TaskContext; use log::{debug, info}; /// ShuffleWriterExec represents a section of a query plan that has consistent partitioning and @@ -138,11 +138,11 @@ impl ShuffleWriterExec { pub async fn execute_shuffle_write( &self, input_partition: usize, - runtime: Arc, + context: Arc, ) -> Result> { let now = Instant::now(); - let mut stream = self.plan.execute(input_partition, runtime).await?; + let mut stream = self.plan.execute(input_partition, context).await?; let mut path = PathBuf::from(&self.work_dir); path.push(&self.job_id); @@ -358,9 +358,9 @@ impl ExecutionPlan for ShuffleWriterExec { async fn execute( &self, partition: usize, - 
runtime: Arc, + context: Arc, ) -> Result { - let part_loc = self.execute_shuffle_write(partition, runtime).await?; + let part_loc = self.execute_shuffle_write(partition, context).await?; // build metadata result batch let num_writers = part_loc.len(); @@ -429,6 +429,10 @@ impl ExecutionPlan for ShuffleWriterExec { fn statistics(&self) -> Statistics { self.plan.statistics() } + + fn session_id(&self) -> String { + self.plan.session_id() + } } fn result_schema() -> SchemaRef { @@ -448,13 +452,13 @@ mod tests { use datafusion::physical_plan::expressions::Column; use datafusion::physical_plan::memory::MemoryExec; + use datafusion::prelude::SessionContext; use tempfile::TempDir; #[tokio::test] async fn test() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); - - let input_plan = Arc::new(CoalescePartitionsExec::new(create_input_plan()?)); + let ctx = SessionContext::new(); + let input_plan = Arc::new(CoalescePartitionsExec::new(create_input_plan(&ctx)?)); let work_dir = TempDir::new()?; let query_stage = ShuffleWriterExec::try_new( "jobOne".to_owned(), @@ -463,7 +467,8 @@ mod tests { work_dir.into_path().to_str().unwrap().to_owned(), Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)), )?; - let mut stream = query_stage.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut stream = query_stage.execute(0, task_ctx).await?; let batches = utils::collect_stream(&mut stream) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; @@ -506,9 +511,8 @@ mod tests { #[tokio::test] async fn test_partitioned() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); - - let input_plan = create_input_plan()?; + let ctx = SessionContext::new(); + let input_plan = create_input_plan(&ctx)?; let work_dir = TempDir::new()?; let query_stage = ShuffleWriterExec::try_new( "jobOne".to_owned(), @@ -517,7 +521,8 @@ mod tests { work_dir.into_path().to_str().unwrap().to_owned(), Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)), )?; - let mut stream = query_stage.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut stream = query_stage.execute(0, task_ctx).await?; let batches = utils::collect_stream(&mut stream) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; @@ -541,7 +546,7 @@ mod tests { Ok(()) } - fn create_input_plan() -> Result> { + fn create_input_plan(ctx: &SessionContext) -> Result> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::UInt32, true), Field::new("b", DataType::Utf8, true), @@ -557,6 +562,11 @@ mod tests { )?; let partition = vec![batch.clone(), batch]; let partitions = vec![partition.clone(), partition]; - Ok(Arc::new(MemoryExec::try_new(&partitions, schema, None)?)) + Ok(Arc::new(MemoryExec::try_new( + &partitions, + schema, + None, + ctx.session_id.clone(), + )?)) } } diff --git a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs index 418546aa389c..a0c806bd8264 100644 --- a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs +++ b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; use datafusion::physical_plan::expressions::PhysicalSortExpr; use 
datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, @@ -33,17 +33,20 @@ use datafusion::physical_plan::{ /// is used as a signal so the scheduler knows it can't start computation until the dependent shuffle has completed. #[derive(Debug, Clone)] pub struct UnresolvedShuffleExec { - // The query stage ids which needs to be computed + /// The query stage ids which needs to be computed pub stage_id: usize, - // The schema this node will have once it is replaced with a ShuffleReaderExec + /// The schema this node will have once it is replaced with a ShuffleReaderExec pub schema: SchemaRef, - // The number of shuffle writer partition tasks that will produce the partitions + /// The number of shuffle writer partition tasks that will produce the partitions pub input_partition_count: usize, - // The partition count this node will have once it is replaced with a ShuffleReaderExec + /// The partition count this node will have once it is replaced with a ShuffleReaderExec pub output_partition_count: usize, + + /// Session id + pub session_id: String, } impl UnresolvedShuffleExec { @@ -53,12 +56,14 @@ impl UnresolvedShuffleExec { schema: SchemaRef, input_partition_count: usize, output_partition_count: usize, + session_id: String, ) -> Self { Self { stage_id, schema, input_partition_count, output_partition_count, + session_id, } } } @@ -104,7 +109,7 @@ impl ExecutionPlan for UnresolvedShuffleExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Err(DataFusionError::Plan( "Ballista UnresolvedShuffleExec does not support execution".to_owned(), @@ -128,4 +133,8 @@ impl ExecutionPlan for UnresolvedShuffleExec { // that replaces this one once the previous stage is completed. 
Statistics::default() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 4970cd600a5a..8a8e05b26232 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -36,7 +36,7 @@ use datafusion::logical_plan::{ Column, CreateExternalTable, CrossJoin, Expr, JoinConstraint, Limit, LogicalPlan, LogicalPlanBuilder, Repartition, TableScan, Values, }; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use prost::bytes::BufMut; use prost::Message; @@ -70,7 +70,7 @@ impl AsLogicalPlan for LogicalPlanNode { fn try_into_logical_plan( &self, - ctx: &ExecutionContext, + ctx: Arc, extension_codec: &dyn LogicalExtensionCodec, ) -> Result { let plan = self.logical_plan_type.as_ref().ok_or_else(|| { @@ -109,7 +109,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Projection(projection) => { let input: LogicalPlan = - into_logical_plan!(projection.input, &ctx, extension_codec)?; + into_logical_plan!(projection.input, ctx, extension_codec)?; let x: Vec = projection .expr .iter() @@ -129,7 +129,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Selection(selection) => { let input: LogicalPlan = - into_logical_plan!(selection.input, &ctx, extension_codec)?; + into_logical_plan!(selection.input, ctx, extension_codec)?; let expr: Expr = selection .expr .as_ref() @@ -144,7 +144,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Window(window) => { let input: LogicalPlan = - into_logical_plan!(window.input, &ctx, extension_codec)?; + into_logical_plan!(window.input, ctx, extension_codec)?; let window_expr = window .window_expr .iter() @@ -157,7 +157,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Aggregate(aggregate) => { let input: LogicalPlan = - into_logical_plan!(aggregate.input, &ctx, extension_codec)?; + into_logical_plan!(aggregate.input, ctx, extension_codec)?; let group_expr = aggregate .group_expr .iter() @@ -224,6 +224,9 @@ impl AsLogicalPlan for LogicalPlanNode { }; let object_store = ctx + .state + .lock() + .runtime .object_store(scan.path.as_str()) .map_err(|e| { BallistaError::NotImplemented(format!( @@ -256,7 +259,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Sort(sort) => { let input: LogicalPlan = - into_logical_plan!(sort.input, &ctx, extension_codec)?; + into_logical_plan!(sort.input, ctx, extension_codec)?; let sort_expr: Vec = sort .expr .iter() @@ -270,7 +273,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Repartition(repartition) => { use datafusion::logical_plan::Partitioning; let input: LogicalPlan = - into_logical_plan!(repartition.input, &ctx, extension_codec)?; + into_logical_plan!(repartition.input, ctx, extension_codec)?; use protobuf::repartition_node::PartitionMethod; let pb_partition_method = repartition.partition_method.clone().ok_or_else(|| { BallistaError::General(String::from( @@ -324,7 +327,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Analyze(analyze) => { let input: LogicalPlan = - into_logical_plan!(analyze.input, &ctx, extension_codec)?; + into_logical_plan!(analyze.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input) .explain(analyze.verbose, true)? 
.build() @@ -332,7 +335,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Explain(explain) => { let input: LogicalPlan = - into_logical_plan!(explain.input, &ctx, extension_codec)?; + into_logical_plan!(explain.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input) .explain(explain.verbose, false)? .build() @@ -340,7 +343,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Limit(limit) => { let input: LogicalPlan = - into_logical_plan!(limit.input, &ctx, extension_codec)?; + into_logical_plan!(limit.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input) .limit(limit.limit as usize)? .build() @@ -370,17 +373,17 @@ impl AsLogicalPlan for LogicalPlanNode { let builder = LogicalPlanBuilder::from(into_logical_plan!( join.left, - &ctx, + ctx.clone(), extension_codec )?); let builder = match join_constraint.into() { JoinConstraint::On => builder.join( - &into_logical_plan!(join.right, &ctx, extension_codec)?, + &into_logical_plan!(join.right, ctx, extension_codec)?, join_type.into(), (left_keys, right_keys), )?, JoinConstraint::Using => builder.join_using( - &into_logical_plan!(join.right, &ctx, extension_codec)?, + &into_logical_plan!(join.right, ctx, extension_codec)?, join_type.into(), left_keys, )?, @@ -389,8 +392,9 @@ impl AsLogicalPlan for LogicalPlanNode { builder.build().map_err(|e| e.into()) } LogicalPlanType::CrossJoin(crossjoin) => { - let left = into_logical_plan!(crossjoin.left, &ctx, extension_codec)?; - let right = into_logical_plan!(crossjoin.right, &ctx, extension_codec)?; + let left = + into_logical_plan!(crossjoin.left, ctx.clone(), extension_codec)?; + let right = into_logical_plan!(crossjoin.right, ctx, extension_codec)?; LogicalPlanBuilder::from(left) .cross_join(&right)? @@ -400,7 +404,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Extension(LogicalExtensionNode { node, inputs }) => { let input_plans: Vec = inputs .iter() - .map(|i| i.try_into_logical_plan(ctx, extension_codec)) + .map(|i| i.try_into_logical_plan(ctx.clone(), extension_codec)) .collect::>()?; let extension_node = extension_codec.try_decode(node, &input_plans)?; @@ -838,7 +842,7 @@ impl AsLogicalPlan for LogicalPlanNode { macro_rules! 
into_logical_plan { ($PB:expr, $CTX:expr, $CODEC:expr) => {{ if let Some(field) = $PB.as_ref() { - field.as_ref().try_into_logical_plan(&$CTX, $CODEC) + field.as_ref().try_into_logical_plan($CTX, $CODEC) } else { Err(proto_error("Missing required field in protobuf")) } @@ -858,6 +862,7 @@ mod roundtrip_tests { FileMetaStream, ListEntryStream, ObjectReader, ObjectStore, SizedFile, }; use datafusion::error::DataFusionError; + use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::{ arrow::datatypes::{DataType, Field, Schema}, datasource::object_store::local::LocalFileSystem, @@ -920,7 +925,7 @@ mod roundtrip_tests { roundtrip_test!($initial_struct, protobuf::LogicalPlanNode, $struct_type); }; ($initial_struct:ident) => { - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let codec: BallistaCodec< protobuf::LogicalPlanNode, protobuf::PhysicalPlanNode, @@ -932,7 +937,7 @@ mod roundtrip_tests { ) .expect("from logical plan"); let round_trip: LogicalPlan = proto - .try_into_logical_plan(&ctx, codec.logical_extension_codec()) + .try_into_logical_plan(Arc::new(ctx), codec.logical_extension_codec()) .expect("to logical plan"); assert_eq!( @@ -949,7 +954,7 @@ mod roundtrip_tests { protobuf::LogicalPlanNode::try_from_logical_plan(&$initial_struct) .expect("from logical plan"); let round_trip: LogicalPlan = proto - .try_into_logical_plan(&$ctx, codec.logical_extension_codec()) + .try_into_logical_plan($ctx, codec.logical_extension_codec()) .expect("to logical plan"); assert_eq!( @@ -1252,13 +1257,14 @@ mod roundtrip_tests { #[tokio::test] async fn roundtrip_logical_plan_custom_ctx() -> Result<()> { - let ctx = ExecutionContext::new(); + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); + let ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime.clone()); let codec: BallistaCodec = BallistaCodec::default(); let custom_object_store = Arc::new(TestObjectStore {}); - ctx.register_object_store("test", custom_object_store.clone()); + runtime.register_object_store("test", custom_object_store.clone()); - let (os, _) = ctx.object_store("test://foo.csv")?; + let (os, _) = runtime.object_store("test://foo.csv")?; println!("Object Store {:?}", os); @@ -1288,7 +1294,7 @@ mod roundtrip_tests { ) .expect("from logical plan"); let round_trip: LogicalPlan = proto - .try_into_logical_plan(&ctx, codec.logical_extension_codec()) + .try_into_logical_plan(Arc::new(ctx), codec.logical_extension_codec()) .expect("to logical plan"); assert_eq!(format!("{:?}", plan), format!("{:?}", round_trip)); diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index b85a957d43f6..c507417f807c 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -28,9 +28,10 @@ use datafusion::logical_plan::{JoinConstraint, JoinType, LogicalPlan, Operator}; use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::plan::Extension; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use prost::Message; // include the generated protobuf source as a submodule @@ -67,7 +68,7 @@ pub trait AsLogicalPlan: Debug + Send + Sync + Clone { fn try_into_logical_plan( &self, - ctx: &ExecutionContext, + ctx: Arc, extension_codec: &dyn LogicalExtensionCodec, ) -> Result; @@ -130,8 +131,9 @@ pub trait AsExecutionPlan: Debug + Send 
+ Sync + Clone { fn try_into_physical_plan( &self, - ctx: &ExecutionContext, + session_id: String, extension_codec: &dyn PhysicalExtensionCodec, + runtime: Arc, ) -> Result, BallistaError>; fn try_from_physical_plan( @@ -345,8 +347,7 @@ mod tests { use datafusion::arrow::datatypes::SchemaRef; use datafusion::datasource::object_store::local::LocalFileSystem; use datafusion::error::DataFusionError; - use datafusion::execution::context::{ExecutionContextState, QueryPlanner}; - use datafusion::execution::runtime_env::RuntimeEnv; + use datafusion::execution::context::{QueryPlanner, SessionState, TaskContext}; use datafusion::logical_plan::plan::Extension; use datafusion::logical_plan::{ col, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, @@ -357,10 +358,11 @@ mod tests { DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalPlanner, SendableRecordBatchStream, Statistics, }; - use datafusion::prelude::{CsvReadOptions, ExecutionConfig, ExecutionContext}; + use datafusion::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use prost::Message; use std::any::Any; + use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use std::convert::TryInto; use std::fmt; use std::fmt::{Debug, Formatter}; @@ -512,7 +514,7 @@ mod tests { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> datafusion::error::Result { Err(DataFusionError::NotImplemented( "not implemented".to_string(), @@ -536,6 +538,10 @@ mod tests { // better statistics inference could be provided Statistics::default() } + + fn session_id(&self) -> String { + self.input.session_id() + } } struct TopKPlanner {} @@ -548,7 +554,7 @@ mod tests { node: &dyn UserDefinedLogicalNode, logical_inputs: &[&LogicalPlan], physical_inputs: &[Arc], - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> datafusion::error::Result>> { Ok( if let Some(topk_node) = node.as_any().downcast_ref::() { @@ -575,7 +581,7 @@ mod tests { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> datafusion::error::Result> { // Teach the default physical planner how to plan TopK nodes. 
let physical_planner = @@ -584,7 +590,7 @@ mod tests { )]); // Delegate most work of physical planning to the default physical planner physical_planner - .create_physical_plan(logical_plan, ctx_state) + .create_physical_plan(logical_plan, session_state) .await } } @@ -693,10 +699,10 @@ mod tests { #[tokio::test] async fn test_extension_plan() -> crate::error::Result<()> { let store = Arc::new(LocalFileSystem {}); - let config = - ExecutionConfig::new().with_query_planner(Arc::new(TopKQueryPlanner {})); - - let ctx = ExecutionContext::with_config(config); + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); + let state = SessionState::with_config(SessionConfig::new(), runtime.clone()) + .with_query_planner(Arc::new(TopKQueryPlanner {})); + let ctx = Arc::new(SessionContext::with_state(state)); let scan = LogicalPlanBuilder::scan_csv( store, @@ -717,7 +723,8 @@ mod tests { let extension_codec = TopKExtensionCodec {}; let proto = LogicalPlanNode::try_from_logical_plan(&topk_plan, &extension_codec)?; - let logical_round_trip = proto.try_into_logical_plan(&ctx, &extension_codec)?; + let logical_round_trip = + proto.try_into_logical_plan(ctx.clone(), &extension_codec)?; assert_eq!( format!("{:?}", topk_plan), @@ -728,7 +735,11 @@ mod tests { topk_exec.clone(), &extension_codec, )?; - let physical_round_trip = proto.try_into_physical_plan(&ctx, &extension_codec)?; + let physical_round_trip = proto.try_into_physical_plan( + ctx.session_id.clone(), + &extension_codec, + runtime, + )?; assert_eq!( format!("{:?}", topk_exec), diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 8daefc904b7c..dfd16e2d7d36 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -26,19 +26,12 @@ use crate::serde::{from_proto_binary_op, proto_error, protobuf}; use crate::{convert_box_required, convert_required}; use chrono::{TimeZone, Utc}; -use datafusion::catalog::catalog::{CatalogList, MemoryCatalogList}; use datafusion::datasource::object_store::local::LocalFileSystem; -use datafusion::datasource::object_store::{FileMeta, ObjectStoreRegistry, SizedFile}; +use datafusion::datasource::object_store::{FileMeta, SizedFile}; use datafusion::datasource::PartitionedFile; -use datafusion::execution::context::{ - ExecutionConfig, ExecutionContextState, ExecutionProps, -}; -use datafusion::execution::runtime_env::RuntimeEnv; - +use datafusion::execution::context::ExecutionProps; use datafusion::physical_plan::file_format::FileScanConfig; - use datafusion::physical_plan::window_functions::WindowFunction; - use datafusion::physical_plan::{ expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, @@ -157,22 +150,10 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { .map(|x| x.try_into()) .collect::, _>>()?; - let catalog_list = - Arc::new(MemoryCatalogList::new()) as Arc; - - let ctx_state = ExecutionContextState { - catalog_list, - scalar_functions: Default::default(), - aggregate_functions: Default::default(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - object_store_registry: Arc::new(ObjectStoreRegistry::new()), - runtime_env: Arc::new(RuntimeEnv::default()), - }; - + let execution_props = ExecutionProps::default(); let fun_expr = functions::create_physical_fun( &(&scalar_function).into(), - &ctx_state.execution_props, + &execution_props, )?; 
Arc::new(ScalarFunctionExpr::new( diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index 83607ae6e555..4334b990d91b 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -35,6 +35,7 @@ use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; use datafusion::datasource::object_store::local::LocalFileSystem; use datafusion::datasource::PartitionedFile; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::window_frames::WindowFrame; use datafusion::physical_plan::aggregates::create_aggregate_expr; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; @@ -56,7 +57,6 @@ use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; use datafusion::physical_plan::{ AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, }; -use datafusion::prelude::ExecutionContext; use prost::bytes::BufMut; use prost::Message; use std::convert::TryInto; @@ -87,8 +87,9 @@ impl AsExecutionPlan for PhysicalPlanNode { fn try_into_physical_plan( &self, - ctx: &ExecutionContext, + session_id: String, extension_codec: &dyn PhysicalExtensionCodec, + runtime: Arc, ) -> Result, BallistaError> { let plan = self.physical_plan_type.as_ref().ok_or_else(|| { proto_error(format!( @@ -98,8 +99,12 @@ impl AsExecutionPlan for PhysicalPlanNode { })?; match plan { PhysicalPlanType::Projection(projection) => { - let input: Arc = - into_physical_plan!(projection.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + projection.input, + session_id, + extension_codec, + runtime + )?; let exprs = projection .expr .iter() @@ -110,8 +115,12 @@ impl AsExecutionPlan for PhysicalPlanNode { Ok(Arc::new(ProjectionExec::try_new(exprs, input)?)) } PhysicalPlanType::Filter(filter) => { - let input: Arc = - into_physical_plan!(filter.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + filter.input, + session_id, + extension_codec, + runtime + )?; let predicate = filter .expr .as_ref() @@ -125,9 +134,10 @@ impl AsExecutionPlan for PhysicalPlanNode { Ok(Arc::new(FilterExec::try_new(predicate, input)?)) } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( - decode_scan_config(scan.base_conf.as_ref().unwrap(), ctx)?, + decode_scan_config(scan.base_conf.as_ref().unwrap(), runtime)?, scan.has_header, str_to_byte(&scan.delimiter)?, + session_id, ))), PhysicalPlanType::ParquetScan(scan) => { let predicate = scan @@ -136,29 +146,43 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| expr.try_into()) .transpose()?; Ok(Arc::new(ParquetExec::new( - decode_scan_config(scan.base_conf.as_ref().unwrap(), ctx)?, + decode_scan_config(scan.base_conf.as_ref().unwrap(), runtime)?, predicate, + session_id, ))) } PhysicalPlanType::AvroScan(scan) => Ok(Arc::new(AvroExec::new( - decode_scan_config(scan.base_conf.as_ref().unwrap(), ctx)?, + decode_scan_config(scan.base_conf.as_ref().unwrap(), runtime)?, + session_id, ))), PhysicalPlanType::CoalesceBatches(coalesce_batches) => { - let input: Arc = - into_physical_plan!(coalesce_batches.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + coalesce_batches.input, + session_id, + extension_codec, + runtime + )?; Ok(Arc::new(CoalesceBatchesExec::new( input, coalesce_batches.target_batch_size as usize, ))) } PhysicalPlanType::Merge(merge) => { - let input: Arc = - into_physical_plan!(merge.input, ctx, 
extension_codec)?; + let input: Arc = into_physical_plan!( + merge.input, + session_id, + extension_codec, + runtime + )?; Ok(Arc::new(CoalescePartitionsExec::new(input))) } PhysicalPlanType::Repartition(repart) => { - let input: Arc = - into_physical_plan!(repart.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + repart.input, + session_id, + extension_codec, + runtime + )?; match repart.partition_method { Some(PartitionMethod::Hash(ref hash_part)) => { let expr = hash_part @@ -197,18 +221,30 @@ impl AsExecutionPlan for PhysicalPlanNode { } } PhysicalPlanType::GlobalLimit(limit) => { - let input: Arc = - into_physical_plan!(limit.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + limit.input, + session_id, + extension_codec, + runtime + )?; Ok(Arc::new(GlobalLimitExec::new(input, limit.limit as usize))) } PhysicalPlanType::LocalLimit(limit) => { - let input: Arc = - into_physical_plan!(limit.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + limit.input, + session_id, + extension_codec, + runtime + )?; Ok(Arc::new(LocalLimitExec::new(input, limit.limit as usize))) } PhysicalPlanType::Window(window_agg) => { - let input: Arc = - into_physical_plan!(window_agg.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + window_agg.input, + session_id, + extension_codec, + runtime + )?; let input_schema = window_agg .input_schema .as_ref() @@ -254,8 +290,12 @@ impl AsExecutionPlan for PhysicalPlanNode { )?)) } PhysicalPlanType::HashAggregate(hash_agg) => { - let input: Arc = - into_physical_plan!(hash_agg.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + hash_agg.input, + session_id, + extension_codec, + runtime + )?; let mode = protobuf::AggregateMode::from_i32(hash_agg.mode).ok_or_else(|| { proto_error(format!( "Received a HashAggregateNode message with unknown AggregateMode {}", @@ -339,10 +379,18 @@ impl AsExecutionPlan for PhysicalPlanNode { )?)) } PhysicalPlanType::HashJoin(hashjoin) => { - let left: Arc = - into_physical_plan!(hashjoin.left, ctx, extension_codec)?; - let right: Arc = - into_physical_plan!(hashjoin.right, ctx, extension_codec)?; + let left: Arc = into_physical_plan!( + hashjoin.left, + session_id.clone(), + extension_codec, + runtime.clone() + )?; + let right: Arc = into_physical_plan!( + hashjoin.right, + session_id, + extension_codec, + runtime + )?; let on: Vec<(Column, Column)> = hashjoin .on .iter() @@ -382,15 +430,27 @@ impl AsExecutionPlan for PhysicalPlanNode { )?)) } PhysicalPlanType::CrossJoin(crossjoin) => { - let left: Arc = - into_physical_plan!(crossjoin.left, ctx, extension_codec)?; - let right: Arc = - into_physical_plan!(crossjoin.right, ctx, extension_codec)?; + let left: Arc = into_physical_plan!( + crossjoin.left, + session_id.clone(), + extension_codec, + runtime.clone() + )?; + let right: Arc = into_physical_plan!( + crossjoin.right, + session_id, + extension_codec, + runtime + )?; Ok(Arc::new(CrossJoinExec::try_new(left, right)?)) } PhysicalPlanType::ShuffleWriter(shuffle_writer) => { - let input: Arc = - into_physical_plan!(shuffle_writer.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + shuffle_writer.input, + session_id, + extension_codec, + runtime + )?; let output_partitioning = parse_protobuf_hash_partitioning( shuffle_writer.output_partitioning.as_ref(), @@ -417,16 +477,24 @@ impl AsExecutionPlan for PhysicalPlanNode { }) .collect::, BallistaError>>()?; let shuffle_reader = - 
ShuffleReaderExec::try_new(partition_location, schema)?; + ShuffleReaderExec::try_new(partition_location, schema, session_id)?; Ok(Arc::new(shuffle_reader)) } PhysicalPlanType::Empty(empty) => { let schema = Arc::new(convert_required!(empty.schema)?); - Ok(Arc::new(EmptyExec::new(empty.produce_one_row, schema))) + Ok(Arc::new(EmptyExec::new( + empty.produce_one_row, + schema, + session_id, + ))) } PhysicalPlanType::Sort(sort) => { - let input: Arc = - into_physical_plan!(sort.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + sort.input, + session_id, + extension_codec, + runtime + )?; let exprs = sort .expr .iter() @@ -474,13 +542,20 @@ impl AsExecutionPlan for PhysicalPlanNode { as usize, output_partition_count: unresolved_shuffle.output_partition_count as usize, + session_id, })) } PhysicalPlanType::Extension(extension) => { let inputs: Vec> = extension .inputs .iter() - .map(|i| i.try_into_physical_plan(ctx, extension_codec)) + .map(|i| { + i.try_into_physical_plan( + session_id.clone(), + extension_codec, + runtime.clone(), + ) + }) .collect::>()?; let extension_node = @@ -883,7 +958,7 @@ impl AsExecutionPlan for PhysicalPlanNode { fn decode_scan_config( proto: &protobuf::FileScanExecConf, - ctx: &ExecutionContext, + runtime: Arc, ) -> Result { let schema = Arc::new(convert_required!(proto.schema)?); let projection = proto @@ -905,7 +980,7 @@ fn decode_scan_config( .collect::, _>>()?; let object_store = if let Some(file) = file_groups.get(0).and_then(|h| h.get(0)) { - ctx.object_store(file.file_meta.path())?.0 + runtime.object_store(file.file_meta.path())?.0 } else { Arc::new(LocalFileSystem {}) }; @@ -923,9 +998,11 @@ fn decode_scan_config( #[macro_export] macro_rules! into_physical_plan { - ($PB:expr, $CTX:expr, $CODEC:expr) => {{ + ($PB:expr, $SESS_ID:expr, $CODEC:expr, $RUNTIME:expr) => {{ if let Some(field) = $PB.as_ref() { - field.as_ref().try_into_physical_plan(&$CTX, $CODEC) + field + .as_ref() + .try_into_physical_plan($SESS_ID, $CODEC, $RUNTIME) } else { Err(proto_error("Missing required field in protobuf")) } @@ -939,8 +1016,8 @@ mod roundtrip_tests { use crate::serde::{AsExecutionPlan, BallistaCodec}; use datafusion::datasource::object_store::local::LocalFileSystem; use datafusion::datasource::PartitionedFile; + use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::physical_plan::sorts::sort::SortExec; - use datafusion::prelude::ExecutionContext; use datafusion::{ arrow::{ compute::kernels::sort::SortOptions, @@ -969,7 +1046,7 @@ mod roundtrip_tests { use crate::serde::protobuf::{LogicalPlanNode, PhysicalPlanNode}; fn roundtrip_test(exec_plan: Arc) -> Result<()> { - let ctx = ExecutionContext::new(); + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); let codec: BallistaCodec = BallistaCodec::default(); let proto: protobuf::PhysicalPlanNode = @@ -979,7 +1056,11 @@ mod roundtrip_tests { ) .expect("to proto"); let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, codec.physical_extension_codec()) + .try_into_physical_plan( + "sess_123".to_owned(), + codec.physical_extension_codec(), + runtime, + ) .expect("from proto"); assert_eq!( format!("{:?}", exec_plan), @@ -990,13 +1071,21 @@ mod roundtrip_tests { #[test] fn roundtrip_empty() -> Result<()> { - roundtrip_test(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))) + roundtrip_test(Arc::new(EmptyExec::new( + false, + Arc::new(Schema::empty()), + "sess_123".to_owned(), + ))) } #[test] fn roundtrip_local_limit() -> Result<()> { 
roundtrip_test(Arc::new(LocalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new( + false, + Arc::new(Schema::empty()), + "sess_123".to_owned(), + )), 25, ))) } @@ -1004,7 +1093,11 @@ mod roundtrip_tests { #[test] fn roundtrip_global_limit() -> Result<()> { roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new( + false, + Arc::new(Schema::empty()), + "sess_123".to_owned(), + )), 25, ))) } @@ -1033,8 +1126,16 @@ mod roundtrip_tests { &[PartitionMode::Partitioned, PartitionMode::CollectLeft] { roundtrip_test(Arc::new(HashJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new( + false, + schema_left.clone(), + "sess_123".to_owned(), + )), + Arc::new(EmptyExec::new( + false, + schema_right.clone(), + "sess_123".to_owned(), + )), on.clone(), join_type, *partition_mode, @@ -1064,7 +1165,7 @@ mod roundtrip_tests { AggregateMode::Final, groups.clone(), aggregates.clone(), - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(false, schema.clone(), "sess_123".to_owned())), schema, )?)) } @@ -1085,9 +1186,10 @@ mod roundtrip_tests { false, )); let and = binary(not, Operator::And, in_list, &schema)?; + let session_id = "sess_123"; roundtrip_test(Arc::new(FilterExec::try_new( and, - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(false, schema.clone(), session_id.to_owned())), )?)) } @@ -1114,7 +1216,7 @@ mod roundtrip_tests { ]; roundtrip_test(Arc::new(SortExec::try_new( sort_exprs, - Arc::new(EmptyExec::new(false, schema)), + Arc::new(EmptyExec::new(false, schema, "sess_123".to_owned())), )?)) } @@ -1127,7 +1229,7 @@ mod roundtrip_tests { roundtrip_test(Arc::new(ShuffleWriterExec::try_new( "job123".to_string(), 123, - Arc::new(EmptyExec::new(false, schema)), + Arc::new(EmptyExec::new(false, schema, "sess_123".to_owned())), "".to_string(), Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4)), )?)) @@ -1158,6 +1260,10 @@ mod roundtrip_tests { }; let predicate = datafusion::prelude::col("col").eq(datafusion::prelude::lit("1")); - roundtrip_test(Arc::new(ParquetExec::new(scan_config, Some(predicate)))) + roundtrip_test(Arc::new(ParquetExec::new( + scan_config, + Some(predicate), + "sess_123".to_owned(), + ))) } } diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index 560d459977dd..960bc68c2562 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -37,16 +37,17 @@ use datafusion::arrow::{ datatypes::SchemaRef, ipc::writer::FileWriter, record_batch::RecordBatch, }; use datafusion::error::DataFusionError; -use datafusion::execution::context::{ - ExecutionConfig, ExecutionContext, ExecutionContextState, QueryPlanner, -}; +use datafusion::execution::context::QueryPlanner; +use datafusion::execution::context::SessionContext as ExecutionContext; +use datafusion::execution::context::SessionState as ExecutionState; +use datafusion::execution::context::{SessionConfig as ExecutionConfig, SessionState}; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::logical_plan::LogicalPlan; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::common::batch_byte_size; use datafusion::physical_plan::empty::EmptyExec; - 
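The serde hunks above replace the ExecutionContext parameter of try_into_physical_plan with an explicit session id plus a shared RuntimeEnv, as exercised by the roundtrip tests. A minimal decode sketch under that assumption (the helper name decode_plan is illustrative and not part of this patch; error handling is kept minimal):

// Sketch only: decode a serialized physical plan with the session-aware API
// shape introduced by this patch.
use std::sync::Arc;
use ballista_core::error::BallistaError;
use ballista_core::serde::protobuf::{LogicalPlanNode, PhysicalPlanNode};
use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::ExecutionPlan;

fn decode_plan(
    bytes: &[u8],
    session_id: String,
    runtime: Arc<RuntimeEnv>,
) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
    let codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> = BallistaCodec::default();
    // The executor performs the same try_decode + try_into_physical_plan steps per task,
    // passing the task's session_id and its shared RuntimeEnv.
    PhysicalPlanNode::try_decode(bytes).and_then(|proto| {
        proto.try_into_physical_plan(session_id, codec.physical_extension_codec(), runtime)
    })
}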
use datafusion::physical_plan::file_format::{CsvExec, ParquetExec}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::HashAggregateExec; @@ -54,6 +55,7 @@ use datafusion::physical_plan::hash_join::HashJoinExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{metrics, ExecutionPlan, RecordBatchStream}; + use futures::{Stream, StreamExt}; /// Stream data to disk in Arrow IPC format @@ -224,21 +226,26 @@ fn build_exec_plan_diagram( Ok(node_id) } -/// Create a DataFusion context that uses the BallistaQueryPlanner to send logical plans +/// Create a client DataFusion Context that uses the BallistaQueryPlanner to send logical plans /// to a Ballista scheduler pub fn create_df_ctx_with_ballista_query_planner( - scheduler_host: &str, - scheduler_port: u16, + scheduler_url: String, + session_id: String, config: &BallistaConfig, ) -> ExecutionContext { - let scheduler_url = format!("http://{}:{}", scheduler_host, scheduler_port); let planner: Arc> = Arc::new(BallistaQueryPlanner::new(scheduler_url, config.clone())); - let config = ExecutionConfig::new() - .with_query_planner(planner) + let session_config = ExecutionConfig::new() .with_target_partitions(config.default_shuffle_partitions()) .with_information_schema(true); - ExecutionContext::with_config(config) + let mut session_state = ExecutionState::with_config( + session_config, + Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()), + ) + .with_query_planner(planner); + session_state.session_id = session_id; + // the ExecutionContext created here is the client side context, but the session_id is from server side. + ExecutionContext::with_state(session_state) } pub struct BallistaQueryPlanner { @@ -291,12 +298,16 @@ impl QueryPlanner for BallistaQueryPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - _ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> std::result::Result, DataFusionError> { match logical_plan { LogicalPlan::CreateExternalTable(_) => { // table state is managed locally in the BallistaContext, not in the scheduler - Ok(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))) + Ok(Arc::new(EmptyExec::new( + false, + Arc::new(Schema::empty()), + session_state.session_id.clone(), + ))) } _ => Ok(Arc::new(DistributedQueryExec::with_repr( self.scheduler_url.clone(), @@ -304,6 +315,7 @@ impl QueryPlanner for BallistaQueryPlanner { logical_plan.clone(), self.extension_codec.clone(), self.plan_repr, + session_state.session_id.clone(), ))), } } diff --git a/ballista/rust/executor/src/collect.rs b/ballista/rust/executor/src/collect.rs index 37a7f7bb0d1b..3832da9ed6d5 100644 --- a/ballista/rust/executor/src/collect.rs +++ b/ballista/rust/executor/src/collect.rs @@ -27,7 +27,7 @@ use datafusion::arrow::{ datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch, }; use datafusion::error::DataFusionError; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, @@ -41,11 +41,13 @@ use futures::Stream; #[derive(Debug, Clone)] pub struct CollectExec { plan: Arc, + /// Session id + session_id: String, } impl CollectExec { - pub fn new(plan: Arc) -> Self { - Self { plan } + pub fn new(plan: Arc, 
session_id: String) -> Self { + Self { plan, session_id } } } @@ -81,12 +83,12 @@ impl ExecutionPlan for CollectExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { assert_eq!(0, partition); let num_partitions = self.plan.output_partitioning().partition_count(); - let futures = (0..num_partitions).map(|i| self.plan.execute(i, runtime.clone())); + let futures = (0..num_partitions).map(|i| self.plan.execute(i, context.clone())); let streams = futures::future::join_all(futures) .await .into_iter() @@ -114,6 +116,10 @@ impl ExecutionPlan for CollectExec { fn statistics(&self) -> Statistics { self.plan.statistics() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } struct MergedRecordBatchStream { diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs index ddb2c972a70f..587a62e9991c 100644 --- a/ballista/rust/executor/src/execution_loop.rs +++ b/ballista/rust/executor/src/execution_loop.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::{Receiver, Sender, TryRecvError}; use std::{sync::Arc, time::Duration}; @@ -34,6 +35,7 @@ use ballista_core::error::BallistaError; use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning; use ballista_core::serde::scheduler::ExecutorSpecification; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; +use datafusion::execution::context::TaskContext; pub async fn poll_loop( mut scheduler: SchedulerGrpcClient, @@ -122,13 +124,27 @@ async fn run_received_tasks = U::try_decode(task.plan.as_slice()).and_then(|proto| { proto.try_into_physical_plan( - executor.ctx.as_ref(), + session_id, codec.physical_extension_codec(), + runtime, ) })?; @@ -142,11 +158,13 @@ async fn run_received_tasks, + /// Runtime environment for Executor + pub runtime: Arc, } impl Executor { @@ -46,12 +46,12 @@ impl Executor { pub fn new( metadata: ExecutorRegistration, work_dir: &str, - ctx: Arc, + runtime: Arc, ) -> Self { Self { metadata, work_dir: work_dir.to_owned(), - ctx, + runtime, } } } @@ -66,6 +66,7 @@ impl Executor { stage_id: usize, part: usize, plan: Arc, + task_context: Arc, _shuffle_output_partitioning: Option, ) -> Result, BallistaError> { let exec = if let Some(shuffle_writer) = @@ -86,10 +87,7 @@ impl Executor { )) }?; - let config = ExecutionConfig::new().with_temp_file_path(self.work_dir.clone()); - let runtime = Arc::new(RuntimeEnv::new(config.runtime)?); - - let partitions = exec.execute_shuffle_write(part, runtime).await?; + let partitions = exec.execute_shuffle_write(part, task_context).await?; println!( "=== [{}/{}/{}] Physical plan with metrics ===\n{}\n", diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs index ad34634265bb..1f612bb436d2 100644 --- a/ballista/rust/executor/src/executor_server.rs +++ b/ballista/rust/executor/src/executor_server.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
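With these executor-side changes, ExecutionPlan::execute receives an Arc of TaskContext rather than an Arc of RuntimeEnv, and that context is derived from the session that owns the query. For a single-process caller the equivalent pattern, mirroring the benchmark updates later in this patch, looks roughly like the following sketch (assuming the SessionContext/TaskContext API at this revision):

// Sketch: plan and execute a query, deriving the TaskContext from the SessionContext.
use std::sync::Arc;
use datafusion::error::Result;
use datafusion::execution::context::{SessionConfig, SessionContext, TaskContext};
use datafusion::physical_plan::collect;

async fn run(sql: &str) -> Result<()> {
    let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(2));
    let plan = ctx.create_logical_plan(sql)?;
    let plan = ctx.optimize(&plan)?;
    let physical = ctx.create_physical_plan(&plan).await?;
    // execute(partition, Arc<TaskContext>) is now driven by a task context, not a RuntimeEnv
    let task_ctx = Arc::new(TaskContext::from(&ctx));
    let _batches = collect(physical, task_ctx).await?;
    Ok(())
}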
+use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::mpsc; @@ -36,6 +37,7 @@ use ballista_core::serde::protobuf::{ }; use ballista_core::serde::scheduler::ExecutorState; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; +use datafusion::execution::context::TaskContext; use datafusion::physical_plan::ExecutionPlan; use crate::as_task_status; @@ -175,15 +177,28 @@ impl ExecutorServer = U::try_decode(encoded_plan).and_then(|proto| { proto.try_into_physical_plan( - self.executor.ctx.as_ref(), + session_id, self.codec.physical_extension_codec(), + runtime, ) })?; @@ -197,6 +212,7 @@ impl ExecutorServer Result<()> { info!("work_dir: {}", work_dir); info!("concurrent_tasks: {}", opt.concurrent_tasks); + // assign this executor an unique ID + let executor_id = Uuid::new_v4().to_string(); let executor_meta = ExecutorRegistration { - id: Uuid::new_v4().to_string(), // assign this executor a unique ID + id: executor_id.clone(), optional_host: external_host .clone() .map(executor_registration::OptionalHost::Host), @@ -111,11 +114,13 @@ async fn main() -> Result<()> { .into(), ), }; - let executor = Arc::new(Executor::new( - executor_meta, - &work_dir, - Arc::new(ExecutionContext::new()), - )); + + let config = RuntimeConfig::new().with_temp_file_path(work_dir.clone()); + let runtime = Arc::new(RuntimeEnv::new(config).map_err(|_| { + BallistaError::Internal("Failed to init Executor RuntimeEnv".to_owned()) + })?); + + let executor = Arc::new(Executor::new(executor_meta, &work_dir, runtime)); let scheduler = SchedulerGrpcClient::connect(scheduler_url) .await diff --git a/ballista/rust/executor/src/standalone.rs b/ballista/rust/executor/src/standalone.rs index 0bc2503e9dfc..5959d4a7bb83 100644 --- a/ballista/rust/executor/src/standalone.rs +++ b/ballista/rust/executor/src/standalone.rs @@ -27,7 +27,7 @@ use ballista_core::{ serde::protobuf::{scheduler_grpc_client::SchedulerGrpcClient, ExecutorRegistration}, BALLISTA_VERSION, }; -use datafusion::prelude::ExecutionContext; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use log::info; use tempfile::TempDir; use tokio::net::TcpListener; @@ -52,8 +52,10 @@ pub async fn new_standalone_executor< BALLISTA_VERSION, addr ); + // assign this executor a unique ID + let executor_id = Uuid::new_v4().to_string(); let executor_meta = ExecutorRegistration { - id: Uuid::new_v4().to_string(), // assign this executor a unique ID + id: executor_id, // assign this executor a unique ID optional_host: Some(OptionalHost::Host("localhost".to_string())), port: addr.port() as u32, // TODO Make it configurable @@ -71,8 +73,14 @@ pub async fn new_standalone_executor< .into_string() .unwrap(); info!("work_dir: {}", work_dir); - let ctx = Arc::new(ExecutionContext::new()); - let executor = Arc::new(Executor::new(executor_meta, &work_dir, ctx)); + + let config = RuntimeConfig::new().with_temp_file_path(work_dir.clone()); + + let executor = Arc::new(Executor::new( + executor_meta, + &work_dir, + Arc::new(RuntimeEnv::new(config).unwrap()), + )); let service = BallistaFlightService::new(executor.clone()); let server = FlightServiceServer::new(service); diff --git a/ballista/rust/scheduler/src/main.rs b/ballista/rust/scheduler/src/main.rs index 6646ce32428a..96589491a3dc 100644 --- a/ballista/rust/scheduler/src/main.rs +++ b/ballista/rust/scheduler/src/main.rs @@ -40,13 +40,11 @@ use ballista_scheduler::state::EtcdClient; #[cfg(feature = "sled")] use 
ballista_scheduler::state::StandaloneClient; -use ballista_scheduler::scheduler_server::SchedulerServer; -use ballista_scheduler::state::{ConfigBackend, ConfigBackendClient}; - use ballista_core::config::TaskSchedulingPolicy; use ballista_core::serde::BallistaCodec; +use ballista_scheduler::scheduler_server::SchedulerServer; +use ballista_scheduler::state::{ConfigBackend, ConfigBackendClient}; use log::info; -use tokio::sync::RwLock; #[macro_use] extern crate configure_me; @@ -62,7 +60,7 @@ mod config { } use config::prelude::*; -use datafusion::prelude::ExecutionContext; +use datafusion::execution::context::default_session_builder; async fn start_server( config_backend: Arc, @@ -85,13 +83,12 @@ async fn start_server( config_backend.clone(), namespace.clone(), policy, - Arc::new(RwLock::new(ExecutionContext::new())), BallistaCodec::default(), + default_session_builder, ), _ => SchedulerServer::new( config_backend.clone(), namespace.clone(), - Arc::new(RwLock::new(ExecutionContext::new())), BallistaCodec::default(), ), }; diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index 68f26e9ffa1d..3c4ab0e00c7d 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -119,6 +119,7 @@ impl DistributedPlanner { .unwrap_or_else(|| { shuffle_writer.output_partitioning().partition_count() }), + shuffle_writer.session_id(), )); stages.push(shuffle_writer); Ok(( @@ -146,6 +147,7 @@ impl DistributedPlanner { .unwrap_or_else(|| { shuffle_writer.output_partitioning().partition_count() }), + shuffle_writer.session_id(), )); stages.push(shuffle_writer); Ok((unresolved_shuffle, stages)) @@ -218,6 +220,7 @@ pub fn remove_unresolved_shuffles( new_children.push(Arc::new(ShuffleReaderExec::try_new( relevant_locations, unresolved_shuffle.schema().clone(), + child.session_id(), )?)) } else { new_children.push(remove_unresolved_shuffles( @@ -238,7 +241,7 @@ fn create_shuffle_writer( Ok(Arc::new(ShuffleWriterExec::try_new( job_id.to_owned(), stage_id, - plan, + plan.clone(), "".to_owned(), // executor will decide on the work_dir path partitioning, )?)) @@ -259,9 +262,9 @@ mod test { coalesce_partitions::CoalescePartitionsExec, projection::ProjectionExec, }; use datafusion::physical_plan::{displayable, ExecutionPlan}; - use datafusion::prelude::ExecutionContext; use ballista_core::serde::protobuf::{LogicalPlanNode, PhysicalPlanNode}; + use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use std::sync::Arc; use uuid::Uuid; @@ -273,7 +276,7 @@ mod test { #[tokio::test] async fn distributed_hash_aggregate_plan() -> Result<(), BallistaError> { - let mut ctx = datafusion_test_context("testdata").await?; + let ctx = datafusion_test_context("testdata").await?; // simplified form of TPC-H query 1 let df = ctx @@ -360,7 +363,7 @@ mod test { #[tokio::test] async fn distributed_join_plan() -> Result<(), BallistaError> { - let mut ctx = datafusion_test_context("testdata").await?; + let ctx = datafusion_test_context("testdata").await?; // simplified form of TPC-H query 12 let df = ctx @@ -535,7 +538,7 @@ order by #[tokio::test] async fn roundtrip_serde_hash_aggregate() -> Result<(), BallistaError> { - let mut ctx = datafusion_test_context("testdata").await?; + let ctx = datafusion_test_context("testdata").await?; // simplified form of TPC-H query 1 let df = ctx @@ -574,7 +577,7 @@ order by fn roundtrip_operator( plan: Arc, ) -> Result, BallistaError> { - let ctx = ExecutionContext::new(); + let runtime = 
Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); let codec: BallistaCodec = BallistaCodec::default(); let proto: protobuf::PhysicalPlanNode = @@ -582,8 +585,11 @@ order by plan.clone(), codec.physical_extension_codec(), )?; - let result_exec_plan: Arc = - (&proto).try_into_physical_plan(&ctx, codec.physical_extension_codec())?; + let result_exec_plan: Arc = (&proto).try_into_physical_plan( + plan.session_id(), + codec.physical_extension_codec(), + runtime, + )?; Ok(result_exec_plan) } } diff --git a/ballista/rust/scheduler/src/scheduler_server/grpc.rs b/ballista/rust/scheduler/src/scheduler_server/grpc.rs index 7f7764cbe5c2..c52630e1842f 100644 --- a/ballista/rust/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/rust/scheduler/src/scheduler_server/grpc.rs @@ -16,18 +16,18 @@ // under the License. use anyhow::Context; -use ballista_core::config::TaskSchedulingPolicy; +use ballista_core::config::{BallistaConfig, TaskSchedulingPolicy}; use ballista_core::error::BallistaError; use ballista_core::execution_plans::ShuffleWriterExec; -use ballista_core::serde::protobuf::execute_query_params::Query; +use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, Query}; use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient; use ballista_core::serde::protobuf::executor_registration::OptionalHost; use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc; use ballista_core::serde::protobuf::{ job_status, ExecuteQueryParams, ExecuteQueryResult, ExecutorHeartbeat, FailedJob, FileType, GetFileMetadataParams, GetFileMetadataResult, GetJobStatusParams, - GetJobStatusResult, HeartBeatParams, HeartBeatResult, JobStatus, PartitionId, - PollWorkParams, PollWorkResult, QueuedJob, RegisterExecutorParams, + GetJobStatusResult, HeartBeatParams, HeartBeatResult, JobStatus, KeyValuePair, + PartitionId, PollWorkParams, PollWorkResult, QueuedJob, RegisterExecutorParams, RegisterExecutorResult, RunningJob, TaskDefinition, TaskStatus, UpdateTaskStatusParams, UpdateTaskStatusResult, }; @@ -39,6 +39,7 @@ use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::object_store::{local::LocalFileSystem, ObjectStore}; use datafusion::logical_plan::LogicalPlan; use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; use futures::StreamExt; use log::{debug, error, info, trace, warn}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; @@ -50,7 +51,9 @@ use tonic::{Request, Response, Status}; use crate::planner::DistributedPlanner; use crate::scheduler_server::event_loop::SchedulerServerEvent; -use crate::scheduler_server::SchedulerServer; +use crate::scheduler_server::{ + create_datafusion_session_context, update_datafusion_session_context, SchedulerServer, +}; #[tonic::async_trait] impl SchedulerGrpc @@ -150,6 +153,14 @@ impl SchedulerGrpc plan_clone ))); }; + let session_props = self + .state + .session_context_registry + .lookup_session(plan.session_id().as_str()) + .await + .expect("SessionContext does not exist in SessionContextRegistry.") + .copied_config() + .to_props(); let mut buf: Vec = vec![]; U::try_from_physical_plan( plan, @@ -162,6 +173,14 @@ impl SchedulerGrpc e )) })?; + + let task_props = session_props + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::>(); Ok(Some(TaskDefinition { plan: buf, task_id: status.task_id, @@ -169,6 +188,8 @@ impl SchedulerGrpc output_partitioning, ) .map_err(|_| Status::internal("TBD".to_string()))?, 
+ session_id: plan_clone.session_id(), + props: task_props, })) } None => Ok(None), @@ -371,31 +392,63 @@ impl SchedulerGrpc &self, request: Request, ) -> std::result::Result, tonic::Status> { + let query_params = request.into_inner(); if let ExecuteQueryParams { query: Some(query), - settings: _, - } = request.into_inner() + settings, + optional_session_id, + } = query_params { - let plan = match query { - Query::LogicalPlan(message) => { - let ctx = self.ctx.read().await; - T::try_decode(message.as_slice()) - .and_then(|m| { - m.try_into_logical_plan( - &ctx, - self.codec.logical_extension_codec(), - ) - }) - .map_err(|e| { - let msg = - format!("Could not parse logical plan protobuf: {}", e); - error!("{}", msg); - tonic::Status::internal(msg) - })? + // parse config + let mut config_builder = BallistaConfig::builder(); + for kv_pair in &settings { + config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); + } + let config = config_builder.build().map_err(|e| { + let msg = format!("Could not parse configs: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + + let df_session = match optional_session_id { + Some(OptionalSessionId::SessionId(session_id)) => { + let session_ctx = self + .state + .session_context_registry + .lookup_session(session_id.as_str()) + .await + .expect( + "SessionContext does not exist in SessionContextRegistry.", + ); + update_datafusion_session_context(session_ctx, &config) + } + _ => { + let df_session = + create_datafusion_session_context(&config, self.session_builder); + let session_id = df_session.session_id.clone(); + self.state + .session_context_registry + .register_session(session_id, df_session.clone()) + .await; + df_session } + }; + + let plan = match query { + Query::LogicalPlan(message) => T::try_decode(message.as_slice()) + .and_then(|m| { + m.try_into_logical_plan( + df_session.clone(), + self.codec.logical_extension_codec(), + ) + }) + .map_err(|e| { + let msg = format!("Could not parse logical plan protobuf: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?, Query::Sql(sql) => { - let mut ctx = self.ctx.write().await; - let df = ctx.sql(&sql).await.map_err(|e| { + let df = df_session.sql(&sql).await.map_err(|e| { let msg = format!("Error parsing SQL: {}", e); error!("{}", msg); tonic::Status::internal(msg) @@ -409,6 +462,8 @@ impl SchedulerGrpc // TODO Maybe the format will be changed in the future let job_id = generate_job_id(); + let session_id = df_session.session_id.clone(); + // Save placeholder job metadata self.state .save_job_metadata( @@ -424,7 +479,7 @@ impl SchedulerGrpc // Create job details for the plan, like stages, tasks, etc // TODO To achieve more throughput, maybe change it to be event-based processing in the future - match create_job(self, job_id.clone(), plan).await { + match create_job(self, df_session.clone(), job_id.clone(), plan).await { Err(error) => { let msg = format!("Job {} failed due to {}", job_id, error); warn!("{}", msg); @@ -441,8 +496,35 @@ impl SchedulerGrpc .unwrap(); return Err(tonic::Status::internal(msg)); } - Ok(_) => Ok(Response::new(ExecuteQueryResult { job_id })), + Ok(_) => Ok(Response::new(ExecuteQueryResult { job_id, session_id })), } + } else if let ExecuteQueryParams { + query: None, + settings, + optional_session_id: None, + } = query_params + { + // parse config for new session + let mut config_builder = BallistaConfig::builder(); + for kv_pair in &settings { + config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); + } + let config = 
config_builder.build().map_err(|e| { + let msg = format!("Could not parse configs: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + let df_session = + create_datafusion_session_context(&config, self.session_builder); + let session_id = df_session.session_id.clone(); + self.state + .session_context_registry + .register_session(session_id.clone(), df_session.clone()) + .await; + Ok(Response::new(ExecuteQueryResult { + job_id: "NA".to_owned(), + session_id, + })) } else { Err(tonic::Status::internal("Error parsing request")) } @@ -472,6 +554,7 @@ fn generate_job_id() -> String { async fn create_job( scheduler_server: &SchedulerServer, + session_ctx: Arc, job_id: String, plan: LogicalPlan, ) -> Result<(), BallistaError> { @@ -479,8 +562,7 @@ async fn create_job( let plan = async move { let start = Instant::now(); - let ctx = scheduler_server.ctx.read().await.clone(); - let optimized_plan = ctx.optimize(&plan).map_err(|e| { + let optimized_plan = session_ctx.optimize(&plan).map_err(|e| { let msg = format!("Could not create optimized logical plan: {}", e); error!("{}", msg); BallistaError::General(msg) @@ -488,7 +570,7 @@ async fn create_job( debug!("Calculated optimized plan: {:?}", optimized_plan); - let plan = ctx + let plan = session_ctx .create_physical_plan(&optimized_plan) .await .map_err(|e| { @@ -562,6 +644,11 @@ async fn create_job( BallistaError::General(msg) })?; } + info!( + "Adding stage {} with {} pending tasks", + shuffle_writer.stage_id(), + num_partitions + ); } if let Some(event_loop) = scheduler_server.event_loop.as_ref() { @@ -579,7 +666,6 @@ async fn create_job( #[cfg(all(test, feature = "sled"))] mod test { use std::sync::Arc; - use tokio::sync::RwLock; use tonic::Request; @@ -591,7 +677,6 @@ mod test { }; use ballista_core::serde::scheduler::ExecutorSpecification; use ballista_core::serde::BallistaCodec; - use datafusion::prelude::ExecutionContext; use super::{SchedulerGrpc, SchedulerServer}; @@ -603,7 +688,6 @@ mod test { SchedulerServer::new( state_storage.clone(), namespace.to_owned(), - Arc::new(RwLock::new(ExecutionContext::new())), BallistaCodec::default(), ); let exec_meta = ExecutorRegistration { @@ -631,8 +715,7 @@ mod test { namespace.to_string(), BallistaCodec::default(), ); - let ctx = scheduler.ctx.read().await; - state.init(&ctx).await?; + state.init().await?; // executor should be registered assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1); @@ -654,8 +737,7 @@ mod test { namespace.to_string(), BallistaCodec::default(), ); - let ctx = scheduler.ctx.read().await; - state.init(&ctx).await?; + state.init().await?; // executor should be registered assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1); Ok(()) diff --git a/ballista/rust/scheduler/src/scheduler_server/mod.rs b/ballista/rust/scheduler/src/scheduler_server/mod.rs index 029b80246538..5d987d4be10f 100644 --- a/ballista/rust/scheduler/src/scheduler_server/mod.rs +++ b/ballista/rust/scheduler/src/scheduler_server/mod.rs @@ -28,7 +28,8 @@ use ballista_core::event_loop::EventLoop; use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; -use datafusion::prelude::{ExecutionConfig, ExecutionContext}; +use datafusion::execution::context::{default_session_builder, SessionState}; +use datafusion::prelude::{SessionConfig, SessionContext}; use crate::scheduler_server::event_loop::{ SchedulerServerEvent, SchedulerServerEventAction, @@ -47,6 +48,7 @@ mod grpc; mod 
task_scheduler; type ExecutorsClient = Arc>>>; +type SessionBuilder = fn(SessionConfig) -> SessionState; #[derive(Clone)] pub struct SchedulerServer { @@ -55,23 +57,38 @@ pub struct SchedulerServer, event_loop: Option>, - ctx: Arc>, codec: BallistaCodec, + /// SessionState Builder + session_builder: SessionBuilder, } impl SchedulerServer { pub fn new( config: Arc, namespace: String, - ctx: Arc>, codec: BallistaCodec, ) -> Self { SchedulerServer::new_with_policy( config, namespace, TaskSchedulingPolicy::PullStaged, - ctx, codec, + default_session_builder, + ) + } + + pub fn new_with_builder( + config: Arc, + namespace: String, + codec: BallistaCodec, + session_builder: SessionBuilder, + ) -> Self { + SchedulerServer::new_with_policy( + config, + namespace, + TaskSchedulingPolicy::PullStaged, + codec, + session_builder, ) } @@ -79,8 +96,8 @@ impl SchedulerServer, namespace: String, policy: TaskSchedulingPolicy, - ctx: Arc>, codec: BallistaCodec, + session_builder: SessionBuilder, ) -> Self { let state = Arc::new(SchedulerState::new(config, namespace, codec.clone())); @@ -107,16 +124,15 @@ impl SchedulerServer Result<()> { { // initialize state - let ctx = self.ctx.read().await; - self.state.init(&ctx).await?; + self.state.init().await?; } { @@ -129,9 +145,81 @@ impl SchedulerServer ExecutionContext { - let config = ExecutionConfig::new() - .with_target_partitions(config.default_shuffle_partitions()); - ExecutionContext::with_config(config) +/// Create a new DataFusion session context from Ballista Configuration +pub fn create_datafusion_session_context( + config: &BallistaConfig, + session_builder: SessionBuilder, +) -> Arc { + let config = SessionConfig::new() + .with_target_partitions(config.default_shuffle_partitions()) + .with_batch_size(config.default_batch_size()) + .with_repartition_joins(config.repartition_joins()) + .with_repartition_aggregations(config.repartition_aggregations()) + .with_repartition_windows(config.repartition_windows()) + .with_parquet_pruning(config.parquet_pruning()); + let session_state = session_builder(config); + Arc::new(SessionContext::with_state(session_state)) +} + +/// Update the existing DataFusion session context with Ballista Configuration +pub fn update_datafusion_session_context( + session_ctx: Arc, + config: &BallistaConfig, +) -> Arc { + let session_config = session_ctx.state.lock().clone().config; + let mut mut_config = session_config.lock(); + mut_config.target_partitions = config.default_shuffle_partitions(); + mut_config.batch_size = config.default_batch_size(); + mut_config.repartition_joins = config.repartition_joins(); + mut_config.repartition_aggregations = config.repartition_aggregations(); + mut_config.repartition_windows = config.repartition_windows(); + mut_config.parquet_pruning = config.parquet_pruning(); + session_ctx +} + +/// A Registry holds all the datafusion session contexts +pub struct SessionContextRegistry { + /// A map from session_id to SessionContext + pub running_sessions: RwLock>>, +} + +impl Default for SessionContextRegistry { + fn default() -> Self { + Self::new() + } +} + +impl SessionContextRegistry { + /// Create the registry that object stores can registered into. + /// ['LocalFileSystem'] store is registered in by default to support read local files natively. + pub fn new() -> Self { + Self { + running_sessions: RwLock::new(HashMap::new()), + } + } + + /// Adds a new session to this registry. 
+ pub async fn register_session( + &self, + session_id: String, + session_ctx: Arc, + ) -> Option> { + let mut sessions = self.running_sessions.write().await; + sessions.insert(session_id, session_ctx) + } + + /// Lookup the session context registered + pub async fn lookup_session(&self, session_id: &str) -> Option> { + let sessions = self.running_sessions.read().await; + sessions.get(session_id).cloned() + } + + /// Remove a session from this registry. + pub async fn unregister_session( + &self, + session_id: &str, + ) -> Option> { + let mut sessions = self.running_sessions.write().await; + sessions.remove(session_id) + } } diff --git a/ballista/rust/scheduler/src/scheduler_server/task_scheduler.rs b/ballista/rust/scheduler/src/scheduler_server/task_scheduler.rs index 3cbf783beb61..2152055ba914 100644 --- a/ballista/rust/scheduler/src/scheduler_server/task_scheduler.rs +++ b/ballista/rust/scheduler/src/scheduler_server/task_scheduler.rs @@ -19,7 +19,7 @@ use crate::state::SchedulerState; use async_trait::async_trait; use ballista_core::error::BallistaError; use ballista_core::execution_plans::ShuffleWriterExec; -use ballista_core::serde::protobuf::TaskDefinition; +use ballista_core::serde::protobuf::{KeyValuePair, TaskDefinition}; use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto; use ballista_core::serde::scheduler::ExecutorData; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan}; @@ -89,6 +89,14 @@ impl TaskScheduler ))); }; + let session_props = self + .session_context_registry + .lookup_session(plan.session_id().as_str()) + .await + .expect("SessionContext does not exist in SessionContextRegistry.") + .copied_config() + .to_props(); + let mut buf: Vec = vec![]; U::try_from_physical_plan( plan, @@ -102,6 +110,14 @@ impl TaskScheduler )) })?; + let task_props = session_props + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::>(); + ret[idx].push(TaskDefinition { plan: buf, task_id: status.task_id, @@ -109,6 +125,8 @@ impl TaskScheduler output_partitioning, ) .map_err(|_| Status::internal("TBD".to_string()))?, + session_id: plan_clone.session_id(), + props: task_props, }); executor.available_task_slots -= 1; num_tasks += 1; diff --git a/ballista/rust/scheduler/src/standalone.rs b/ballista/rust/scheduler/src/standalone.rs index 52e30cd096f3..6ab40ac68f85 100644 --- a/ballista/rust/scheduler/src/standalone.rs +++ b/ballista/rust/scheduler/src/standalone.rs @@ -21,11 +21,9 @@ use ballista_core::{ error::Result, serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer, BALLISTA_VERSION, }; -use datafusion::prelude::ExecutionContext; use log::info; use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpListener; -use tokio::sync::RwLock; use tonic::transport::Server; use crate::{scheduler_server::SchedulerServer, state::StandaloneClient}; @@ -37,7 +35,6 @@ pub async fn new_standalone_scheduler() -> Result { SchedulerServer::new( Arc::new(client), "ballista".to_string(), - Arc::new(RwLock::new(ExecutionContext::new())), BallistaCodec::default(), ); scheduler_server.init().await?; diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs index e30075e827eb..037a17296a80 100644 --- a/ballista/rust/scheduler/src/state/mod.rs +++ b/ballista/rust/scheduler/src/state/mod.rs @@ -37,7 +37,7 @@ use ballista_core::serde::scheduler::{ ExecutorData, ExecutorMetadata, PartitionId, PartitionStats, }; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; -use 
datafusion::prelude::ExecutionContext; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use super::planner::remove_unresolved_shuffles; @@ -46,6 +46,7 @@ mod etcd; #[cfg(feature = "sled")] mod standalone; +use crate::scheduler_server::SessionContextRegistry; use clap::ArgEnum; #[cfg(feature = "etcd")] pub use etcd::EtcdClient; @@ -263,20 +264,23 @@ impl VolatileSchedulerState { } } +// StageKey is job_id + stage_id type StageKey = (String, u32); #[derive(Clone)] struct StableSchedulerState { - // for db + /// for db config_client: Arc, namespace: String, codec: BallistaCodec, - // for in-memory cache + /// for in-memory cache executors_metadata: Arc>>, jobs: Arc>>, stages: Arc>>>, + /// Runtime environment for Scheduler + runtime: Arc, } impl @@ -294,14 +298,15 @@ impl executors_metadata: Arc::new(RwLock::new(HashMap::new())), jobs: Arc::new(RwLock::new(HashMap::new())), stages: Arc::new(RwLock::new(HashMap::new())), + runtime: Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()), } } /// Load the state stored in storage into memory - async fn init(&self, ctx: &ExecutionContext) -> Result<()> { + async fn init(&self) -> Result<()> { self.init_executors_metadata_from_storage().await?; self.init_jobs_from_storage().await?; - self.init_stages_from_storage(ctx).await?; + self.init_stages_from_storage().await?; Ok(()) } @@ -339,7 +344,7 @@ impl Ok(()) } - async fn init_stages_from_storage(&self, ctx: &ExecutionContext) -> Result<()> { + async fn init_stages_from_storage(&self) -> Result<()> { let entries = self .config_client .get_from_prefix(&get_stage_prefix(&self.namespace)) @@ -347,10 +352,14 @@ impl let mut stages = self.stages.write(); for (key, entry) in entries { - let (job_id, stage_id) = extract_stage_id_from_stage_key(&key).unwrap(); + let (session_id, job_id, stage_id) = + extract_stage_id_from_stage_key(&key).unwrap(); let value = U::try_decode(&entry)?; - let plan = value - .try_into_physical_plan(ctx, self.codec.physical_extension_codec())?; + let plan = value.try_into_physical_plan( + session_id.clone(), + self.codec.physical_extension_codec(), + self.runtime.clone(), + )?; stages.insert((job_id, stage_id), plan); } @@ -422,8 +431,14 @@ impl plan: Arc, ) -> Result<()> { { + let session_id = plan.session_id(); // Save in db - let key = get_stage_plan_key(&self.namespace, job_id, stage_id as u32); + let key = get_stage_plan_key( + &self.namespace, + session_id.as_str(), + job_id, + stage_id as u32, + ); let value = { let mut buf: Vec = vec![]; let proto = U::try_from_physical_plan( @@ -485,8 +500,19 @@ fn get_stage_prefix(namespace: &str) -> String { format!("/ballista/{}/stages", namespace,) } -fn get_stage_plan_key(namespace: &str, job_id: &str, stage_id: u32) -> String { - format!("{}/{}/{}", get_stage_prefix(namespace), job_id, stage_id,) +fn get_stage_plan_key( + namespace: &str, + session_id: &str, + job_id: &str, + stage_id: u32, +) -> String { + format!( + "{}/{}/{}/{}", + get_stage_prefix(namespace), + session_id, + job_id, + stage_id + ) } fn extract_job_id_from_job_key(job_key: &str) -> Result<&str> { @@ -495,9 +521,9 @@ fn extract_job_id_from_job_key(job_key: &str) -> Result<&str> { }) } -fn extract_stage_id_from_stage_key(stage_key: &str) -> Result { +fn extract_stage_id_from_stage_key(stage_key: &str) -> Result<(String, String, u32)> { let splits: Vec<&str> = stage_key.split('/').collect(); - if splits.len() < 4 { + if splits.len() < 5 { Err(BallistaError::Internal(format!( "Unexpected stage key: {}", stage_key @@ -505,7 +531,8 @@ fn 
extract_stage_id_from_stage_key(stage_key: &str) -> Result { } else { Ok(( splits.get(2).unwrap().to_string(), - splits.get(3).unwrap().parse::().unwrap(), + splits.get(3).unwrap().to_string(), + splits.get(4).unwrap().parse::().unwrap(), )) } } @@ -593,6 +620,8 @@ pub(super) struct SchedulerState, volatile_state: VolatileSchedulerState, listener: SchedulerStateWatcher, + /// DataFusion session contexts that are registered within the SchedulerServer + pub(crate) session_context_registry: Arc, } impl SchedulerState { @@ -607,6 +636,7 @@ impl SchedulerState SchedulerState Result<()> { - self.stable_state.init(ctx).await?; + pub async fn init(&self) -> Result<()> { + self.stable_state.init().await?; Ok(()) } diff --git a/ballista/rust/scheduler/src/test_utils.rs b/ballista/rust/scheduler/src/test_utils.rs index b9d7ee42f48b..3f8b57670834 100644 --- a/ballista/rust/scheduler/src/test_utils.rs +++ b/ballista/rust/scheduler/src/test_utils.rs @@ -18,18 +18,17 @@ use ballista_core::error::Result; use datafusion::arrow::datatypes::{DataType, Field, Schema}; -use datafusion::execution::context::{ExecutionConfig, ExecutionContext}; +use datafusion::execution::context::{SessionConfig, SessionContext}; use datafusion::prelude::CsvReadOptions; pub const TPCH_TABLES: &[&str] = &[ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", ]; -pub async fn datafusion_test_context(path: &str) -> Result { +pub async fn datafusion_test_context(path: &str) -> Result { let default_shuffle_partitions = 2; - let config = - ExecutionConfig::new().with_target_partitions(default_shuffle_partitions); - let mut ctx = ExecutionContext::with_config(config); + let config = SessionConfig::new().with_target_partitions(default_shuffle_partitions); + let ctx = SessionContext::with_config(config); for table in TPCH_TABLES { let schema = get_tpch_schema(table); let options = CsvReadOptions::new() diff --git a/benchmarks/src/bin/nyctaxi.rs b/benchmarks/src/bin/nyctaxi.rs index 49679f46d7eb..19752907599c 100644 --- a/benchmarks/src/bin/nyctaxi.rs +++ b/benchmarks/src/bin/nyctaxi.rs @@ -18,16 +18,17 @@ //! 
Apache Arrow Rust Benchmarks use std::collections::HashMap; +use std::ops::Deref; use std::path::PathBuf; use std::process; +use std::sync::Arc; use std::time::Instant; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::util::pretty; use datafusion::error::Result; -use datafusion::execution::context::{ExecutionConfig, ExecutionContext}; - +use datafusion::execution::context::{SessionConfig, SessionContext, TaskContext}; use datafusion::physical_plan::collect; use datafusion::prelude::CsvReadOptions; use structopt::StructOpt; @@ -69,10 +70,10 @@ async fn main() -> Result<()> { let opt = Opt::from_args(); println!("Running benchmarks with the following options: {:?}", opt); - let config = ExecutionConfig::new() + let config = SessionConfig::new() .with_target_partitions(opt.partitions) .with_batch_size(opt.batch_size); - let mut ctx = ExecutionContext::with_config(config); + let ctx = Arc::new(SessionContext::with_config(config)); let path = opt.path.to_str().unwrap(); @@ -89,11 +90,11 @@ async fn main() -> Result<()> { } } - datafusion_sql_benchmarks(&mut ctx, opt.iterations, opt.debug).await + datafusion_sql_benchmarks(ctx, opt.iterations, opt.debug).await } async fn datafusion_sql_benchmarks( - ctx: &mut ExecutionContext, + ctx: Arc, iterations: usize, debug: bool, ) -> Result<()> { @@ -103,7 +104,7 @@ async fn datafusion_sql_benchmarks( println!("Executing '{}'", name); for i in 0..iterations { let start = Instant::now(); - execute_sql(ctx, sql, debug).await?; + execute_sql(ctx.clone(), sql, debug).await?; println!( "Query '{}' iteration {} took {} ms", name, @@ -115,15 +116,15 @@ async fn datafusion_sql_benchmarks( Ok(()) } -async fn execute_sql(ctx: &mut ExecutionContext, sql: &str, debug: bool) -> Result<()> { - let runtime = ctx.state.lock().runtime_env.clone(); +async fn execute_sql(ctx: Arc, sql: &str, debug: bool) -> Result<()> { let plan = ctx.create_logical_plan(sql)?; let plan = ctx.optimize(&plan)?; if debug { println!("Optimized logical plan:\n{:?}", plan); } let physical_plan = ctx.create_physical_plan(&plan).await?; - let result = collect(physical_plan, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(ctx.deref())); + let result = collect(physical_plan, task_ctx).await?; if debug { pretty::print_batches(&result)?; } diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 1cc668789110..6bab4f795f0e 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -30,7 +30,9 @@ use std::{ }; use ballista::context::BallistaContext; -use ballista::prelude::{BallistaConfig, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS}; +use ballista::prelude::{ + BallistaConfig, BALLISTA_DEFAULT_BATCH_SIZE, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, +}; use datafusion::datasource::{MemTable, TableProvider}; use datafusion::error::{DataFusionError, Result}; @@ -58,6 +60,7 @@ use datafusion::{ use datafusion::datasource::file_format::csv::DEFAULT_CSV_EXTENSION; use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION; +use datafusion::execution::context::TaskContext; use serde::Serialize; use structopt::StructOpt; @@ -83,9 +86,10 @@ struct BallistaBenchmarkOpt { #[structopt(short = "i", long = "iterations", default_value = "3")] iterations: usize, - // /// Batch size when reading CSV or Parquet files - // #[structopt(short = "s", long = "batch-size", default_value = "8192")] - // batch_size: usize, + /// Batch size when reading CSV or Parquet files + #[structopt(short = "s", long = "batch-size", default_value = 
"8192")] + batch_size: usize, + /// Path to data files #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] path: PathBuf, @@ -174,6 +178,10 @@ struct BallistaLoadtestOpt { #[structopt(short = "n", long = "partitions", default_value = "2")] partitions: usize, + /// Batch size when reading CSV or Parquet files + #[structopt(short = "s", long = "batch-size", default_value = "8192")] + batch_size: usize, + /// Path to data files #[structopt(parse(from_os_str), required = true, short = "p", long = "data-path")] path: PathBuf, @@ -256,6 +264,7 @@ async fn main() -> Result<()> { use LoadtestOpt::*; env_logger::init(); + match TpchOpt::from_args() { TpchOpt::Benchmark(BallistaBenchmark(opt)) => { benchmark_ballista(opt).await.map(|_| ()) @@ -273,11 +282,10 @@ async fn main() -> Result<()> { async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result> { println!("Running benchmarks with the following options: {:?}", opt); let mut benchmark_run = BenchmarkRun::new(opt.query); - let config = ExecutionConfig::new() + let config = SessionConfig::new() .with_target_partitions(opt.partitions) .with_batch_size(opt.batch_size); - let mut ctx = ExecutionContext::with_config(config); - let runtime = ctx.state.lock().runtime_env.clone(); + let ctx = SessionContext::with_config(config); // register tables for table in TABLES { @@ -290,10 +298,9 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result Result = Vec::with_capacity(1); for i in 0..opt.iterations { let start = Instant::now(); - let plan = create_logical_plan(&mut ctx, opt.query)?; - result = execute_query(&mut ctx, &plan, opt.debug).await?; + let plan = create_logical_plan(&ctx, opt.query)?; + result = execute_query(&ctx, &plan, opt.debug).await?; let elapsed = start.elapsed().as_secs_f64() * 1000.0; millis.push(elapsed as f64); let row_count = result.iter().map(|b| b.num_rows()).sum(); @@ -341,11 +348,13 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> { BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, &format!("{}", opt.partitions), ) + .set(BALLISTA_DEFAULT_BATCH_SIZE, &format!("{}", opt.batch_size)) .build() .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; let ctx = - BallistaContext::remote(opt.host.unwrap().as_str(), opt.port.unwrap(), &config); + BallistaContext::remote(opt.host.unwrap().as_str(), opt.port.unwrap(), &config) + .await?; // register tables with Ballista context let path = opt.path.to_str().unwrap(); @@ -421,6 +430,7 @@ async fn loadtest_ballista(opt: BallistaLoadtestOpt) -> Result<()> { BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, &format!("{}", opt.partitions), ) + .set(BALLISTA_DEFAULT_BATCH_SIZE, &format!("{}", opt.batch_size)) .build() .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; @@ -429,11 +439,14 @@ async fn loadtest_ballista(opt: BallistaLoadtestOpt) -> Result<()> { let mut clients = vec![]; for _num in 0..concurrency { - clients.push(BallistaContext::remote( - opt.host.clone().unwrap().as_str(), - opt.port.unwrap(), - &config, - )); + clients.push( + BallistaContext::remote( + opt.host.clone().unwrap().as_str(), + opt.port.unwrap(), + &config, + ) + .await?, + ); } // register tables with Ballista context @@ -584,13 +597,13 @@ fn get_query_sql(query: usize) -> Result { } } -fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result { +fn create_logical_plan(ctx: &SessionContext, query: usize) -> Result { let sql = get_query_sql(query)?; ctx.create_logical_plan(&sql) } async fn execute_query( - ctx: &mut 
ExecutionContext, + ctx: &SessionContext, plan: &LogicalPlan, debug: bool, ) -> Result> { @@ -608,8 +621,8 @@ async fn execute_query( displayable(physical_plan.as_ref()).indent() ); } - let runtime = ctx.state.lock().runtime_env.clone(); - let result = collect(physical_plan.clone(), runtime).await?; + let task_ctx = Arc::new(TaskContext::from(ctx)); + let result = collect(physical_plan.clone(), task_ctx).await?; if debug { println!( "=== Physical plan with metrics ===\n{}\n", @@ -632,8 +645,8 @@ async fn convert_tbl(opt: ConvertOpt) -> Result<()> { .delimiter(b'|') .file_extension(".tbl"); - let config = ExecutionConfig::new().with_batch_size(opt.batch_size); - let mut ctx = ExecutionContext::with_config(config); + let config = SessionConfig::new().with_batch_size(opt.batch_size); + let ctx = SessionContext::with_config(config); // build plan to read the TBL file let mut csv = ctx.read_csv(&input_path, options).await?; @@ -1281,11 +1294,10 @@ mod tests { async fn run_query(n: usize) -> Result<()> { // Tests running query with empty tables, to see whether they run succesfully. - - let config = ExecutionConfig::new() + let config = SessionConfig::new() .with_target_partitions(1) .with_batch_size(10); - let mut ctx = ExecutionContext::with_config(config); + let ctx = SessionContext::with_config(config); for &table in TABLES { let schema = get_schema(table); @@ -1296,8 +1308,8 @@ mod tests { ctx.register_table(table, Arc::new(provider))?; } - let plan = create_logical_plan(&mut ctx, n)?; - execute_query(&mut ctx, &plan, false).await?; + let plan = create_logical_plan(&ctx, n)?; + execute_query(&ctx, &plan, false).await?; Ok(()) } @@ -1307,7 +1319,7 @@ mod tests { // load expected answers from tpch-dbgen // read csv as all strings, trim and cast to expected type as the csv string // to value parser does not handle data with leading/trailing spaces - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let schema = string_schema(get_answer_schema(n)); let options = CsvReadOptions::new() .schema(&schema) @@ -1376,13 +1388,15 @@ mod tests { use ballista_core::serde::{ protobuf, AsExecutionPlan, AsLogicalPlan, BallistaCodec, }; + use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::physical_plan::ExecutionPlan; async fn round_trip_query(n: usize) -> Result<()> { - let config = ExecutionConfig::new() + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); + let config = SessionConfig::new() .with_target_partitions(1) .with_batch_size(10); - let mut ctx = ExecutionContext::with_config(config); + let ctx = SessionContext::with_config_rt(config, runtime.clone()); let codec: BallistaCodec< protobuf::LogicalPlanNode, protobuf::PhysicalPlanNode, @@ -1412,15 +1426,16 @@ mod tests { } // test logical plan round trip - let plan = create_logical_plan(&mut ctx, n)?; + let plan = create_logical_plan(&ctx, n)?; let proto: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( &plan, codec.logical_extension_codec(), ) .unwrap(); + let ref_ctx = Arc::new(ctx); let round_trip: LogicalPlan = (&proto) - .try_into_logical_plan(&ctx, codec.logical_extension_codec()) + .try_into_logical_plan(ref_ctx.clone(), codec.logical_extension_codec()) .unwrap(); assert_eq!( format!("{:?}", plan), @@ -1429,7 +1444,7 @@ mod tests { ); // test optimized logical plan round trip - let plan = ctx.optimize(&plan)?; + let plan = ref_ctx.optimize(&plan)?; let proto: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( 
&plan, @@ -1437,7 +1452,7 @@ mod tests { ) .unwrap(); let round_trip: LogicalPlan = (&proto) - .try_into_logical_plan(&ctx, codec.logical_extension_codec()) + .try_into_logical_plan(ref_ctx.clone(), codec.logical_extension_codec()) .unwrap(); assert_eq!( format!("{:?}", plan), @@ -1447,7 +1462,7 @@ mod tests { // test physical plan roundtrip if env::var("TPCH_DATA").is_ok() { - let physical_plan = ctx.create_physical_plan(&plan).await?; + let physical_plan = ref_ctx.create_physical_plan(&plan).await?; let proto: protobuf::PhysicalPlanNode = protobuf::PhysicalPlanNode::try_from_physical_plan( physical_plan.clone(), @@ -1455,7 +1470,11 @@ mod tests { ) .unwrap(); let round_trip: Arc = (&proto) - .try_into_physical_plan(&ctx, codec.physical_extension_codec()) + .try_into_physical_plan( + ref_ctx.session_id.clone(), + codec.physical_extension_codec(), + runtime, + ) .unwrap(); assert_eq!( format!("{:?}", physical_plan), diff --git a/datafusion-cli/src/context.rs b/datafusion-cli/src/context.rs index 0b746afa3662..a6eeab30ac08 100644 --- a/datafusion-cli/src/context.rs +++ b/datafusion-cli/src/context.rs @@ -19,26 +19,26 @@ use datafusion::dataframe::DataFrame; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::context::{ExecutionConfig, ExecutionContext}; +use datafusion::execution::context::{SessionConfig, SessionContext}; use std::sync::Arc; /// The CLI supports using a local DataFusion context or a distributed BallistaContext pub enum Context { /// In-process execution with DataFusion - Local(ExecutionContext), + Local(SessionContext), /// Distributed execution with Ballista (if available) Remote(BallistaContext), } impl Context { /// create a new remote context with given host and port - pub fn new_remote(host: &str, port: u16) -> Result { - Ok(Context::Remote(BallistaContext::try_new(host, port)?)) + pub async fn new_remote(host: &str, port: u16) -> Result { + Ok(Context::Remote(BallistaContext::try_new(host, port).await?)) } /// create a local context using the given config - pub fn new_local(config: &ExecutionConfig) -> Context { - Context::Local(ExecutionContext::with_config(config.clone())) + pub fn new_local(config: &SessionConfig) -> Context { + Context::Local(SessionContext::with_config(config.clone())) } /// execute an SQL statement against the context @@ -56,13 +56,14 @@ impl Context { pub struct BallistaContext(ballista::context::BallistaContext); #[cfg(feature = "ballista")] impl BallistaContext { - pub fn try_new(host: &str, port: u16) -> Result { + pub async fn try_new(host: &str, port: u16) -> Result { use ballista::context::BallistaContext; use ballista::prelude::BallistaConfig; let config: BallistaConfig = BallistaConfig::new() .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; - Ok(Self(BallistaContext::remote(host, port, &config))) + Ok(Self(BallistaContext::remote(host, port, &config).await?)) } + pub async fn sql(&mut self, sql: &str) -> Result> { self.0.sql(sql).await } @@ -72,7 +73,7 @@ impl BallistaContext { pub struct BallistaContext(); #[cfg(not(feature = "ballista"))] impl BallistaContext { - pub fn try_new(_host: &str, _port: u16) -> Result { + pub async fn try_new(_host: &str, _port: u16) -> Result { Err(DataFusionError::NotImplemented( "Remote execution not supported. 
Compile with feature 'ballista' to enable" .to_string(), diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 08878f9c70eb..d76fe38e5fb6 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -17,7 +17,7 @@ use clap::Parser; use datafusion::error::Result; -use datafusion::execution::context::ExecutionConfig; +use datafusion::execution::context::SessionConfig; use datafusion_cli::{ context::Context, exec, print_format::PrintFormat, print_options::PrintOptions, DATAFUSION_CLI_VERSION, @@ -98,15 +98,15 @@ pub async fn main() -> Result<()> { env::set_current_dir(&p).unwrap(); }; - let mut execution_config = ExecutionConfig::new().with_information_schema(true); + let mut session_config = SessionConfig::new().with_information_schema(true); if let Some(batch_size) = args.batch_size { - execution_config = execution_config.with_batch_size(batch_size); + session_config = session_config.with_batch_size(batch_size); }; let mut ctx: Context = match (args.host, args.port) { - (Some(ref h), Some(p)) => Context::new_remote(h, p)?, - _ => Context::new_local(&execution_config), + (Some(ref h), Some(p)) => Context::new_remote(h, p).await?, + _ => Context::new_local(&session_config), }; let mut print_options = PrintOptions { diff --git a/datafusion-examples/examples/avro_sql.rs b/datafusion-examples/examples/avro_sql.rs index f08c12bbb73a..400a3327f786 100644 --- a/datafusion-examples/examples/avro_sql.rs +++ b/datafusion-examples/examples/avro_sql.rs @@ -25,7 +25,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> Result<()> { // create local execution context - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::arrow::util::test_util::arrow_test_data(); diff --git a/datafusion-examples/examples/csv_sql.rs b/datafusion-examples/examples/csv_sql.rs index 5ad9bd7d4385..9299bfc39e52 100644 --- a/datafusion-examples/examples/csv_sql.rs +++ b/datafusion-examples/examples/csv_sql.rs @@ -22,8 +22,8 @@ use datafusion::prelude::*; /// fetching results #[tokio::main] async fn main() -> Result<()> { - // create local execution context - let mut ctx = ExecutionContext::new(); + // create local session context + let ctx = SessionContext::new(); let testdata = datafusion::test_util::arrow_test_data(); diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index aad153a99c90..389499b5eb55 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -21,8 +21,8 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::TableProvider; use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::context::TaskContext; use datafusion::execution::dataframe_impl::DataFrameImpl; -use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::{Expr, LogicalPlanBuilder}; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::memory::MemoryStream; @@ -57,7 +57,7 @@ async fn search_accounts( expected_result_length: usize, ) -> Result<()> { // create local execution context - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // create logical plan composed of a single TableScan let logical_plan = @@ -115,8 +115,14 @@ impl CustomDataSource { &self, projections: &Option>, schema: SchemaRef, + session_id: String, ) -> Result> { - 
Ok(Arc::new(CustomExec::new(projections, schema, self.clone()))) + Ok(Arc::new(CustomExec::new( + projections, + schema, + self.clone(), + session_id, + ))) } pub(crate) fn populate_users(&self) { @@ -171,8 +177,11 @@ impl TableProvider for CustomDataSource { // filters and limit can be used here to inject some push-down operations if needed _filters: &[Expr], _limit: Option, + session_id: String, ) -> Result> { - return self.create_physical_plan(projection, self.schema()).await; + return self + .create_physical_plan(projection, self.schema(), session_id) + .await; } } @@ -180,6 +189,7 @@ impl TableProvider for CustomDataSource { struct CustomExec { db: CustomDataSource, projected_schema: SchemaRef, + session_id: String, } impl CustomExec { @@ -187,11 +197,13 @@ impl CustomExec { projections: &Option>, schema: SchemaRef, db: CustomDataSource, + session_id: String, ) -> Self { let projected_schema = project_schema(&schema, projections.as_ref()).unwrap(); Self { db, projected_schema, + session_id, } } } @@ -235,7 +247,7 @@ impl ExecutionPlan for CustomExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { let users: Vec = { let db = self.db.inner.lock().unwrap(); @@ -263,6 +275,10 @@ impl ExecutionPlan for CustomExec { )?)); } + fn session_id(&self) -> String { + self.session_id.clone() + } + fn statistics(&self) -> Statistics { todo!() } diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 6fd34610ba5c..fbee553ec034 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -22,8 +22,8 @@ use datafusion::prelude::*; /// fetching results, using the DataFrame trait #[tokio::main] async fn main() -> Result<()> { - // create local execution context - let mut ctx = ExecutionContext::new(); + // create local session context + let ctx = SessionContext::new(); let testdata = datafusion::arrow::util::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/dataframe_in_memory.rs b/datafusion-examples/examples/dataframe_in_memory.rs index e17c69ed1ded..0cbca7ba5416 100644 --- a/datafusion-examples/examples/dataframe_in_memory.rs +++ b/datafusion-examples/examples/dataframe_in_memory.rs @@ -44,7 +44,7 @@ async fn main() -> Result<()> { )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
let provider = MemTable::try_new(schema, vec![vec![batch]])?; diff --git a/datafusion-examples/examples/flight_server.rs b/datafusion-examples/examples/flight_server.rs index c26dcce59f69..0c51a28adb7f 100644 --- a/datafusion-examples/examples/flight_server.rs +++ b/datafusion-examples/examples/flight_server.rs @@ -89,8 +89,8 @@ impl FlightService for FlightServiceImpl { Ok(sql) => { println!("do_get: {}", sql); - // create local execution context - let mut ctx = ExecutionContext::new(); + // create local session context + let ctx = SessionContext::new(); let testdata = datafusion::arrow::util::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs index e113d98db677..826e8d6aad33 100644 --- a/datafusion-examples/examples/memtable.rs +++ b/datafusion-examples/examples/memtable.rs @@ -20,7 +20,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; use datafusion::error::Result; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use std::sync::Arc; use std::time::Duration; use tokio::time::timeout; @@ -30,8 +30,8 @@ use tokio::time::timeout; async fn main() -> Result<()> { let mem_table = create_memtable()?; - // create local execution context - let mut ctx = ExecutionContext::new(); + // create local session context + let ctx = SessionContext::new(); // Register the in-memory table containing the data ctx.register_table("users", Arc::new(mem_table))?; diff --git a/datafusion-examples/examples/parquet_sql.rs b/datafusion-examples/examples/parquet_sql.rs index e74ed39c68ce..a34dd7188134 100644 --- a/datafusion-examples/examples/parquet_sql.rs +++ b/datafusion-examples/examples/parquet_sql.rs @@ -22,8 +22,8 @@ use datafusion::prelude::*; /// fetching results #[tokio::main] async fn main() -> Result<()> { - // create local execution context - let mut ctx = ExecutionContext::new(); + // create local session context + let ctx = SessionContext::new(); let testdata = datafusion::arrow::util::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs index 7485bc72f193..c12c69fc63ef 100644 --- a/datafusion-examples/examples/parquet_sql_multiple_files.rs +++ b/datafusion-examples/examples/parquet_sql_multiple_files.rs @@ -27,8 +27,8 @@ use std::sync::Arc; /// with multiple Parquet files) and fetching results #[tokio::main] async fn main() -> Result<()> { - // create local execution context - let mut ctx = ExecutionContext::new(); + // create local session context + let ctx = SessionContext::new(); let testdata = datafusion::arrow::util::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 3acace27e4de..f3829c9ffa11 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -28,8 +28,8 @@ use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumu use datafusion::{prelude::*, scalar::ScalarValue}; use std::sync::Arc; -// create local execution context with an in-memory table -fn create_context() -> Result { +// create local session context with an in-memory table +fn create_context() -> Result { use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::datasource::MemTable; // define a schema. 
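The example updates above all follow the same migration: ExecutionContext::new() becomes SessionContext::new(), and the context no longer needs to be mutable to register tables or run queries. Below is a minimal sketch of the resulting pattern, assuming only the SessionContext API introduced in this patch; the table name, schema, and query are purely illustrative.

use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

async fn query_in_memory_table() -> Result<()> {
    // The context is not declared `mut`; registration goes through the
    // internally synchronized session state.
    let ctx = SessionContext::new();

    // Illustrative one-column in-memory table.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    ctx.register_table("t", Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))?;

    // Run SQL against the registered table and collect the results.
    let df = ctx.sql("SELECT a FROM t").await?;
    let _batches = df.collect().await?;
    Ok(())
}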
@@ -46,7 +46,7 @@ fn create_context() -> Result { )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 33242c7b9870..e6c8b6714158 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -29,8 +29,8 @@ use datafusion::prelude::*; use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use std::sync::Arc; -// create local execution context with an in-memory table -fn create_context() -> Result { +// create local session context with an in-memory table +fn create_context() -> Result { use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::datasource::MemTable; // define a schema. @@ -49,7 +49,7 @@ fn create_context() -> Result { )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch]])?; @@ -60,7 +60,7 @@ fn create_context() -> Result { /// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b #[tokio::main] async fn main() -> Result<()> { - let mut ctx = create_context()?; + let ctx = create_context()?; // First, declare the actual implementation of the calculation let pow = |args: &[ArrayRef]| { diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 80842272d613..e2523b2b9334 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -81,6 +81,7 @@ num-traits = { version = "0.2", optional = true } pyo3 = { version = "0.16", optional = true } tempfile = "3" parking_lot = "0.12" +uuid = { version = "0.8", features = ["v4"] } [dev-dependencies] criterion = "0.3" diff --git a/datafusion/benches/aggregate_query_sql.rs b/datafusion/benches/aggregate_query_sql.rs index e587fe58cd44..807e64ff5e27 100644 --- a/datafusion/benches/aggregate_query_sql.rs +++ b/datafusion/benches/aggregate_query_sql.rs @@ -24,12 +24,12 @@ mod data_utils; use crate::criterion::Criterion; use data_utils::create_table_provider; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use parking_lot::Mutex; use std::sync::Arc; use tokio::runtime::Runtime; -fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); criterion::black_box(rt.block_on(df.collect()).unwrap()); @@ -39,8 +39,8 @@ fn create_context( partitions_len: usize, array_len: usize, batch_size: usize, -) -> Result>> { - let mut ctx = ExecutionContext::new(); +) -> Result>> { + let ctx = SessionContext::new(); let provider = create_table_provider(partitions_len, array_len, batch_size)?; ctx.register_table("t", provider)?; Ok(Arc::new(Mutex::new(ctx))) diff --git a/datafusion/benches/filter_query_sql.rs b/datafusion/benches/filter_query_sql.rs index 9885918de229..f077dbcda7ab 100644 --- a/datafusion/benches/filter_query_sql.rs +++ b/datafusion/benches/filter_query_sql.rs @@ -22,13 +22,13 @@ use arrow::{ }; use 
criterion::{criterion_group, criterion_main, Criterion}; use datafusion::from_slice::FromSlice; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use futures::executor::block_on; use std::sync::Arc; use tokio::runtime::Runtime; -async fn query(ctx: &mut ExecutionContext, sql: &str) { +async fn query(ctx: &SessionContext, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query @@ -36,7 +36,7 @@ async fn query(ctx: &mut ExecutionContext, sql: &str) { criterion::black_box(rt.block_on(df.collect()).unwrap()); } -fn create_context(array_len: usize, batch_size: usize) -> Result { +fn create_context(array_len: usize, batch_size: usize) -> Result { // define a schema. let schema = Arc::new(Schema::new(vec![ Field::new("f32", DataType::Float32, false), @@ -57,7 +57,7 @@ fn create_context(array_len: usize, batch_size: usize) -> Result>(); - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![batches])?; @@ -71,25 +71,25 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_size = 4096; // 2^12 c.bench_function("filter_array", |b| { - let mut ctx = create_context(array_len, batch_size).unwrap(); - b.iter(|| block_on(query(&mut ctx, "select f32, f64 from t where f32 >= f64"))) + let ctx = create_context(array_len, batch_size).unwrap(); + b.iter(|| block_on(query(&ctx, "select f32, f64 from t where f32 >= f64"))) }); c.bench_function("filter_scalar", |b| { - let mut ctx = create_context(array_len, batch_size).unwrap(); + let ctx = create_context(array_len, batch_size).unwrap(); b.iter(|| { block_on(query( - &mut ctx, + &ctx, "select f32, f64 from t where f32 >= 250 and f64 > 250", )) }) }); c.bench_function("filter_scalar in list", |b| { - let mut ctx = create_context(array_len, batch_size).unwrap(); + let ctx = create_context(array_len, batch_size).unwrap(); b.iter(|| { block_on(query( - &mut ctx, + &ctx, "select f32, f64 from t where f32 in (10, 20, 30, 40)", )) }) diff --git a/datafusion/benches/math_query_sql.rs b/datafusion/benches/math_query_sql.rs index 6195937dc4e5..11b107f02f4b 100644 --- a/datafusion/benches/math_query_sql.rs +++ b/datafusion/benches/math_query_sql.rs @@ -34,10 +34,10 @@ use arrow::{ }; use datafusion::datasource::MemTable; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use datafusion::from_slice::FromSlice; -fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query @@ -48,7 +48,7 @@ fn query(ctx: Arc>, sql: &str) { fn create_context( array_len: usize, batch_size: usize, -) -> Result>> { +) -> Result>> { // define a schema. let schema = Arc::new(Schema::new(vec![ Field::new("f32", DataType::Float32, false), @@ -69,7 +69,7 @@ fn create_context( }) .collect::>(); - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
let provider = MemTable::try_new(schema, vec![batches])?; diff --git a/datafusion/benches/parquet_query_sql.rs b/datafusion/benches/parquet_query_sql.rs index 17bc78bd038a..22f94b4e3621 100644 --- a/datafusion/benches/parquet_query_sql.rs +++ b/datafusion/benches/parquet_query_sql.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{ }; use arrow::record_batch::RecordBatch; use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use parquet::arrow::ArrowWriter; use parquet::file::properties::{WriterProperties, WriterVersion}; use rand::distributions::uniform::SampleUniform; @@ -193,7 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { assert!(Path::new(&file_path).exists(), "path not found"); println!("Using parquet file {}", file_path); - let mut context = ExecutionContext::new(); + let context = SessionContext::new(); let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap(); rt.block_on(context.register_parquet("t", file_path.as_str())) diff --git a/datafusion/benches/physical_plan.rs b/datafusion/benches/physical_plan.rs index 8dd1f49d183e..3aece08287b0 100644 --- a/datafusion/benches/physical_plan.rs +++ b/datafusion/benches/physical_plan.rs @@ -59,8 +59,7 @@ fn sort_preserving_merge_operator(batches: Vec, sort: &[&str]) { let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); let rt = Runtime::new().unwrap(); - let rt_env = Arc::new(RuntimeEnv::default()); - rt.block_on(collect(merge, rt_env)).unwrap(); + rt.block_on(collect(merge)).unwrap(); } // Produces `n` record batches of row size `m`. Each record batch will have diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs index 2434341ae51c..cd305b2b5d2d 100644 --- a/datafusion/benches/sort_limit_query_sql.rs +++ b/datafusion/benches/sort_limit_query_sql.rs @@ -31,11 +31,11 @@ extern crate datafusion; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::MemTable; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use tokio::runtime::Runtime; -fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query @@ -43,7 +43,7 @@ fn query(ctx: Arc>, sql: &str) { rt.block_on(df.collect()).unwrap(); } -fn create_context() -> Arc> { +fn create_context() -> Arc> { // define schema for data source (csv file) let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Utf8, false), @@ -76,18 +76,17 @@ fn create_context() -> Arc> { let rt = Runtime::new().unwrap(); - let ctx_holder: Arc>>>> = + let ctx_holder: Arc>>>> = Arc::new(Mutex::new(vec![])); let partitions = 16; rt.block_on(async { // create local execution context - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.state.lock().config.target_partitions = 1; - let runtime = ctx.state.lock().runtime_env.clone(); - let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions), runtime) + let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions)) .await .unwrap(); ctx.register_table("aggregate_test_100", Arc::new(mem_table)) diff --git a/datafusion/benches/window_query_sql.rs b/datafusion/benches/window_query_sql.rs index dad838eb7f62..42a1e51be361 100644 --- a/datafusion/benches/window_query_sql.rs +++ b/datafusion/benches/window_query_sql.rs @@ -24,12 +24,12 @@ mod data_utils; use crate::criterion::Criterion; use 
data_utils::create_table_provider; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use parking_lot::Mutex; use std::sync::Arc; use tokio::runtime::Runtime; -fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); criterion::black_box(rt.block_on(df.collect()).unwrap()); @@ -39,8 +39,8 @@ fn create_context( partitions_len: usize, array_len: usize, batch_size: usize, -) -> Result>> { - let mut ctx = ExecutionContext::new(); +) -> Result>> { + let ctx = SessionContext::new(); let provider = create_table_provider(partitions_len, array_len, batch_size)?; ctx.register_table("t", provider)?; Ok(Arc::new(Mutex::new(ctx))) diff --git a/datafusion/src/catalog/schema.rs b/datafusion/src/catalog/schema.rs index a97590af216e..90fece9dd786 100644 --- a/datafusion/src/catalog/schema.rs +++ b/datafusion/src/catalog/schema.rs @@ -251,7 +251,7 @@ mod tests { }; use crate::datasource::empty::EmptyTable; use crate::datasource::object_store::local::LocalFileSystem; - use crate::execution::context::ExecutionContext; + use crate::execution::context::SessionContext; use futures::StreamExt; @@ -290,7 +290,7 @@ mod tests { let catalog = MemoryCatalogProvider::new(); catalog.register_schema("active", Arc::new(schema)); - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_catalog("cat", Arc::new(catalog)); diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs index 7748a832a21c..00c107345322 100644 --- a/datafusion/src/dataframe.rs +++ b/datafusion/src/dataframe.rs @@ -33,7 +33,7 @@ use async_trait::async_trait; /// [Spark DataFrame](https://spark.apache.org/docs/latest/sql-programming-guide.html) /// /// DataFrames are typically created by the `read_csv` and `read_parquet` methods on the -/// [ExecutionContext](../execution/context/struct.ExecutionContext.html) and can then be modified +/// [SessionContext](../execution/context/struct.SessionContext.html) and can then be modified /// by calling the transformation methods, such as `filter`, `select`, `aggregate`, and `limit` /// to build up a query definition. /// @@ -44,7 +44,7 @@ use async_trait::async_trait; /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let mut ctx = ExecutionContext::new(); +/// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))? /// .aggregate(vec![col("a")], vec![min(col("b"))])? 
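The benchmark files above (aggregate_query_sql, filter_query_sql, math_query_sql, sort_limit_query_sql, window_query_sql) now share one pattern: the context is held as Arc<Mutex<SessionContext>> using parking_lot, and the synchronous Criterion harness drives async queries through a Tokio runtime. A minimal sketch of that pattern, assuming only crates this patch already depends on; the helper name is illustrative.

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use parking_lot::Mutex;
use tokio::runtime::Runtime;

// Execute one SQL statement to completion from synchronous benchmark code.
fn run_query(ctx: Arc<Mutex<SessionContext>>, sql: &str) -> Result<()> {
    let rt = Runtime::new().unwrap();
    // The lock is held only while the query is planned; the guard is dropped
    // at the end of this statement, before the results are collected.
    let df = rt.block_on(ctx.lock().sql(sql))?;
    let _batches = rt.block_on(df.collect())?;
    Ok(())
}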
@@ -63,7 +63,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.select_columns(&["a", "b"])?; /// # Ok(()) @@ -78,7 +78,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.select(vec![col("a") * col("b"), col("c")])?; /// # Ok(()) @@ -93,7 +93,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))?; /// # Ok(()) @@ -108,7 +108,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// /// // The following use is the equivalent of "SELECT MIN(b) GROUP BY a" @@ -132,7 +132,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.limit(100)?; /// # Ok(()) @@ -147,7 +147,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.union(df.clone())?; /// # Ok(()) @@ -162,7 +162,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.union(df.clone())?; /// let df = df.distinct()?; @@ -179,7 +179,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?; /// # Ok(()) @@ -194,7 +194,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await? 
/// .select(vec![ @@ -223,7 +223,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df1 = df.repartition(Partitioning::RoundRobinBatch(4))?; /// # Ok(()) @@ -241,7 +241,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.collect().await?; /// # Ok(()) @@ -256,7 +256,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// df.show().await?; /// # Ok(()) @@ -271,7 +271,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// df.show_limit(10).await?; /// # Ok(()) @@ -286,7 +286,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let stream = df.execute_stream().await?; /// # Ok(()) @@ -302,7 +302,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.collect_partitioned().await?; /// # Ok(()) @@ -317,7 +317,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.execute_stream_partitioned().await?; /// # Ok(()) @@ -333,7 +333,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let schema = df.schema(); /// # Ok(()) @@ -353,7 +353,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.limit(100)?.explain(false, false)?.collect().await?; /// # Ok(()) @@ -368,7 +368,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = 
ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let f = df.registry(); /// // use f.udf("name", vec![...]) to use the udf @@ -384,7 +384,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.intersect(df.clone())?; /// # Ok(()) @@ -399,7 +399,7 @@ pub trait DataFrame: Send + Sync { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.except(df.clone())?; /// # Ok(()) diff --git a/datafusion/src/datasource/datasource.rs b/datafusion/src/datasource/datasource.rs index 1b59c857fb07..013269881229 100644 --- a/datafusion/src/datasource/datasource.rs +++ b/datafusion/src/datasource/datasource.rs @@ -83,6 +83,7 @@ pub trait TableProvider: Sync + Send { // If set, it contains the amount of rows needed by the `LogicalPlan`, // The datasource should return *at least* this number of rows if available. limit: Option, + session_id: String, ) -> Result>; /// Tests whether the table provider can make use of a filter expression diff --git a/datafusion/src/datasource/empty.rs b/datafusion/src/datasource/empty.rs index 5622d15a0d67..31b34624a45f 100644 --- a/datafusion/src/datasource/empty.rs +++ b/datafusion/src/datasource/empty.rs @@ -56,9 +56,14 @@ impl TableProvider for EmptyTable { projection: &Option>, _filters: &[Expr], _limit: Option, + session_id: String, ) -> Result> { // even though there is no data, projections apply let projected_schema = project_schema(&self.schema, projection.as_ref())?; - Ok(Arc::new(EmptyExec::new(false, projected_schema))) + Ok(Arc::new(EmptyExec::new( + false, + projected_schema, + session_id, + ))) } } diff --git a/datafusion/src/datasource/file_format/avro.rs b/datafusion/src/datasource/file_format/avro.rs index fa02d1ae2833..b9eb681418f7 100644 --- a/datafusion/src/datasource/file_format/avro.rs +++ b/datafusion/src/datasource/file_format/avro.rs @@ -65,8 +65,9 @@ impl FileFormat for AvroFormat { &self, conf: FileScanConfig, _filters: &[Expr], + session_id: String, ) -> Result> { - let exec = AvroExec::new(conf); + let exec = AvroExec::new(conf, session_id); Ok(Arc::new(exec)) } } @@ -83,7 +84,9 @@ mod tests { }; use super::*; + use crate::execution::context::TaskContext; use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{ BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, TimestampMicrosecondArray, @@ -92,10 +95,12 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); + let config = SessionConfig::new().with_batch_size(2); + let ctx = SessionContext::with_config(config); let projection = None; - let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let stream = exec.execute(0, runtime).await?; + let exec = get_exec("alltypes_plain.avro", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let stream = exec.execute(0, task_ctx).await?; let 
tt_batches = stream .map(|batch| { @@ -113,10 +118,11 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec("alltypes_plain.avro", &projection, Some(1)).await?; - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.avro", &projection, Some(1), &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -126,9 +132,9 @@ mod tests { #[tokio::test] async fn read_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec("alltypes_plain.avro", &projection, None).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.avro", &projection, None, &ctx).await?; let x: Vec = exec .schema() @@ -153,7 +159,8 @@ mod tests { x ); - let batches = collect(exec, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -177,11 +184,11 @@ mod tests { #[tokio::test] async fn read_bool_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![1]); - let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.avro", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -206,11 +213,11 @@ mod tests { #[tokio::test] async fn read_i32_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); - let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.avro", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -232,11 +239,11 @@ mod tests { #[tokio::test] async fn read_i96_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![10]); - let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.avro", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -258,11 +265,11 @@ mod tests { #[tokio::test] async fn read_f32_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![6]); - let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.avro", &projection, 
None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -287,11 +294,11 @@ #[tokio::test] async fn read_f64_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![7]); - let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.avro", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -316,11 +323,11 @@ #[tokio::test] async fn read_binary_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![9]); - let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.avro", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -347,6 +354,7 @@ mod tests { file_name: &str, projection: &Option>, limit: Option, + session_ctx: &SessionContext, ) -> Result> { let testdata = crate::test_util::arrow_test_data(); let filename = format!("{}/avro/{}", testdata, file_name); @@ -372,6 +380,7 @@ table_partition_cols: vec![], }, &[], + session_ctx.session_id.clone(), ) .await?; Ok(exec) diff --git a/datafusion/src/datasource/file_format/csv.rs b/datafusion/src/datasource/file_format/csv.rs index 6aa0d21235a4..44f4586facb5 100644 --- a/datafusion/src/datasource/file_format/csv.rs +++ b/datafusion/src/datasource/file_format/csv.rs @@ -127,8 +127,9 @@ impl FileFormat for CsvFormat { &self, conf: FileScanConfig, _filters: &[Expr], + session_id: String, ) -> Result> { - let exec = CsvExec::new(conf, self.has_header, self.delimiter); + let exec = CsvExec::new(conf, self.has_header, self.delimiter, session_id); Ok(Arc::new(exec)) } } @@ -138,7 +139,8 @@ mod tests { use arrow::array::StringArray; use super::*; - use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::execution::context::TaskContext; + use crate::prelude::{SessionConfig, SessionContext}; use crate::{ datasource::{ file_format::FileScanConfig, @@ -152,11 +154,13 @@ #[tokio::test] async fn read_small_batches() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); + let config = SessionConfig::new().with_batch_size(2); + let ctx = SessionContext::with_config(config); // skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work) let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]); - let exec = get_exec("aggregate_test_100.csv", &projection, None).await?; - let stream = exec.execute(0, runtime).await?; + let exec = get_exec("aggregate_test_100.csv", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let stream = exec.execute(0, task_ctx).await?; let tt_batches: i32 = stream .map(|batch| { @@ -178,10 +182,11 @@
#[tokio::test] async fn read_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0, 1, 2, 3]); - let exec = get_exec("aggregate_test_100.csv", &projection, Some(1)).await?; - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("aggregate_test_100.csv", &projection, Some(1), &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(4, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -192,7 +197,8 @@ mod tests { #[tokio::test] async fn infer_schema() -> Result<()> { let projection = None; - let exec = get_exec("aggregate_test_100.csv", &projection, None).await?; + let ctx = SessionContext::new(); + let exec = get_exec("aggregate_test_100.csv", &projection, None, &ctx).await?; let x: Vec = exec .schema() @@ -224,11 +230,11 @@ mod tests { #[tokio::test] async fn read_char_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); - let exec = get_exec("aggregate_test_100.csv", &projection, None).await?; - - let batches = collect(exec, runtime).await.expect("Collect batches"); + let ctx = SessionContext::new(); + let exec = get_exec("aggregate_test_100.csv", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await.expect("Collect batches"); assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -253,6 +259,7 @@ mod tests { file_name: &str, projection: &Option>, limit: Option, + session_ctx: &SessionContext, ) -> Result> { let testdata = crate::test_util::arrow_test_data(); let filename = format!("{}/csv/{}", testdata, file_name); @@ -278,6 +285,7 @@ mod tests { table_partition_cols: vec![], }, &[], + session_ctx.session_id.clone(), ) .await?; Ok(exec) diff --git a/datafusion/src/datasource/file_format/json.rs b/datafusion/src/datasource/file_format/json.rs index bdd5ef81d559..7fe18173d9e4 100644 --- a/datafusion/src/datasource/file_format/json.rs +++ b/datafusion/src/datasource/file_format/json.rs @@ -89,8 +89,9 @@ impl FileFormat for JsonFormat { &self, conf: FileScanConfig, _filters: &[Expr], + session_id: String, ) -> Result> { - let exec = NdJsonExec::new(conf); + let exec = NdJsonExec::new(conf, session_id); Ok(Arc::new(exec)) } } @@ -100,7 +101,8 @@ mod tests { use arrow::array::Int64Array; use super::*; - use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::execution::context::TaskContext; + use crate::prelude::{SessionConfig, SessionContext}; use crate::{ datasource::{ file_format::FileScanConfig, @@ -114,10 +116,12 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); + let config = SessionConfig::new().with_batch_size(2); + let ctx = SessionContext::with_config(config); let projection = None; - let exec = get_exec(&projection, None).await?; - let stream = exec.execute(0, runtime).await?; + let exec = get_exec(&projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let stream = exec.execute(0, task_ctx).await?; let tt_batches: i32 = stream .map(|batch| { @@ -139,10 +143,11 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec(&projection, Some(1)).await?; - let batches = 
collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec(&projection, Some(1), &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(4, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -153,7 +158,8 @@ mod tests { #[tokio::test] async fn infer_schema() -> Result<()> { let projection = None; - let exec = get_exec(&projection, None).await?; + let ctx = SessionContext::new(); + let exec = get_exec(&projection, None, &ctx).await?; let x: Vec = exec .schema() @@ -168,11 +174,11 @@ mod tests { #[tokio::test] async fn read_int_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); - let exec = get_exec(&projection, None).await?; - - let batches = collect(exec, runtime).await.expect("Collect batches"); + let ctx = SessionContext::new(); + let exec = get_exec(&projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await.expect("Collect batches"); assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -199,6 +205,7 @@ mod tests { async fn get_exec( projection: &Option>, limit: Option, + session_ctx: &SessionContext, ) -> Result> { let filename = "tests/jsons/2.json"; let format = JsonFormat::default(); @@ -223,6 +230,7 @@ mod tests { table_partition_cols: vec![], }, &[], + session_ctx.session_id.clone(), ) .await?; Ok(exec) diff --git a/datafusion/src/datasource/file_format/mod.rs b/datafusion/src/datasource/file_format/mod.rs index 21da2e1e6a27..1539d88dccb5 100644 --- a/datafusion/src/datasource/file_format/mod.rs +++ b/datafusion/src/datasource/file_format/mod.rs @@ -61,5 +61,6 @@ pub trait FileFormat: Send + Sync + fmt::Debug { &self, conf: FileScanConfig, filters: &[Expr], + session_id: String, ) -> Result>; } diff --git a/datafusion/src/datasource/file_format/parquet.rs b/datafusion/src/datasource/file_format/parquet.rs index d1d26e2c6d42..bf10fe09c6e6 100644 --- a/datafusion/src/datasource/file_format/parquet.rs +++ b/datafusion/src/datasource/file_format/parquet.rs @@ -107,6 +107,7 @@ impl FileFormat for ParquetFormat { &self, conf: FileScanConfig, filters: &[Expr], + session_id: String, ) -> Result> { // If enable pruning then combine the filters to build the predicate. 
// If disable pruning then set the predicate to None, thus readers @@ -117,7 +118,7 @@ impl FileFormat for ParquetFormat { None }; - Ok(Arc::new(ParquetExec::new(conf, predicate))) + Ok(Arc::new(ParquetExec::new(conf, predicate, session_id))) } } @@ -367,7 +368,8 @@ mod tests { use super::*; - use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::execution::context::TaskContext; + use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{ BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, TimestampNanosecondArray, @@ -376,10 +378,12 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); + let config = SessionConfig::new().with_batch_size(2); + let ctx = SessionContext::with_config(config); let projection = None; - let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let stream = exec.execute(0, runtime).await?; + let exec = get_exec("alltypes_plain.parquet", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let stream = exec.execute(0, task_ctx).await?; let tt_batches = stream .map(|batch| { @@ -401,15 +405,16 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec("alltypes_plain.parquet", &projection, Some(1)).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.parquet", &projection, Some(1), &ctx).await?; // note: even if the limit is set, the executor rounds up to the batch size assert_eq!(exec.statistics().num_rows, Some(8)); assert_eq!(exec.statistics().total_byte_size, Some(671)); assert!(exec.statistics().is_exact); - let batches = collect(exec, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -419,9 +424,9 @@ mod tests { #[tokio::test] async fn read_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.parquet", &projection, None, &ctx).await?; let x: Vec = exec .schema() @@ -444,8 +449,8 @@ mod tests { timestamp_col: Timestamp(Nanosecond, None)", y ); - - let batches = collect(exec, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); @@ -456,11 +461,11 @@ mod tests { #[tokio::test] async fn read_bool_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![1]); - let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.parquet", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -485,11 +490,11 @@ mod tests { #[tokio::test] async fn read_i32_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let 
projection = Some(vec![0]); - let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.parquet", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -511,11 +516,11 @@ mod tests { #[tokio::test] async fn read_i96_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![10]); - let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.parquet", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -537,11 +542,11 @@ mod tests { #[tokio::test] async fn read_f32_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![6]); - let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.parquet", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -566,11 +571,11 @@ mod tests { #[tokio::test] async fn read_f64_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![7]); - let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.parquet", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -595,11 +600,11 @@ mod tests { #[tokio::test] async fn read_binary_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![9]); - let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - - let batches = collect(exec, runtime).await?; + let ctx = SessionContext::new(); + let exec = get_exec("alltypes_plain.parquet", &projection, None, &ctx).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -626,6 +631,7 @@ mod tests { file_name: &str, projection: &Option>, limit: Option, + session_ctx: &SessionContext, ) -> Result> { let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, file_name); @@ -651,6 +657,7 @@ mod tests { table_partition_cols: vec![], }, &[], + session_ctx.session_id.clone(), ) .await?; Ok(exec) diff --git a/datafusion/src/datasource/listing/helpers.rs b/datafusion/src/datasource/listing/helpers.rs index 335d8275fcce..b1db1b6ca249 100644 --- 
a/datafusion/src/datasource/listing/helpers.rs +++ b/datafusion/src/datasource/listing/helpers.rs @@ -37,7 +37,7 @@ use log::debug; use crate::{ error::Result, - execution::context::ExecutionContext, + execution::context::SessionContext, logical_plan::{self, Expr, ExprVisitable, ExpressionVisitor, Recursion}, physical_plan::functions::Volatility, scalar::ScalarValue, @@ -242,7 +242,7 @@ pub async fn pruned_partition_list( // Filter the partitions using a local datafusion context // TODO having the external context would allow us to resolve `Volatility::Stable` // scalar functions (`ScalarFunction` & `ScalarUDF`) and `ScalarVariable`s - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let mut df = ctx.read_table(Arc::new(mem_table))?; for filter in applicable_filters { df = df.filter(filter.clone())?; diff --git a/datafusion/src/datasource/listing/table.rs b/datafusion/src/datasource/listing/table.rs index 3fbd6c12397d..6262ddde4e75 100644 --- a/datafusion/src/datasource/listing/table.rs +++ b/datafusion/src/datasource/listing/table.rs @@ -301,6 +301,7 @@ impl TableProvider for ListingTable { projection: &Option>, filters: &[Expr], limit: Option, + session_id: String, ) -> Result> { let (partitioned_file_lists, statistics) = self.list_files_for_scan(filters, limit).await?; @@ -309,7 +310,11 @@ impl TableProvider for ListingTable { if partitioned_file_lists.is_empty() { let schema = self.schema(); let projected_schema = project_schema(&schema, projection.as_ref())?; - return Ok(Arc::new(EmptyExec::new(false, projected_schema))); + return Ok(Arc::new(EmptyExec::new( + false, + projected_schema, + session_id.clone(), + ))); } // create the execution plan @@ -326,6 +331,7 @@ impl TableProvider for ListingTable { table_partition_cols: self.options.table_partition_cols.clone(), }, filters, + session_id, ) .await } @@ -411,7 +417,7 @@ mod tests { let table = load_table("alltypes_plain.parquet").await?; let projection = None; let exec = table - .scan(&projection, &[], None) + .scan(&projection, &[], None, "sess_123".to_owned()) .await .expect("Scan table"); @@ -437,7 +443,7 @@ mod tests { .with_listing_options(opt) .with_schema(schema); let table = ListingTable::try_new(config)?; - let exec = table.scan(&None, &[], None).await?; + let exec = table.scan(&None, &[], None, "sess_123".to_owned()).await?; assert_eq!(exec.statistics().num_rows, Some(8)); assert_eq!(exec.statistics().total_byte_size, Some(671)); @@ -471,9 +477,8 @@ mod tests { // this will filter out the only file in the store let filter = Expr::not_eq(col("p1"), lit("v1")); - let scan = table - .scan(&None, &[filter], None) + .scan(&None, &[filter], None, "sess_123".to_owned()) .await .expect("Empty execution plan"); diff --git a/datafusion/src/datasource/memory.rs b/datafusion/src/datasource/memory.rs index 5fad702672ef..b4959b6726d1 100644 --- a/datafusion/src/datasource/memory.rs +++ b/datafusion/src/datasource/memory.rs @@ -29,7 +29,7 @@ use async_trait::async_trait; use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::logical_plan::Expr; use crate::physical_plan::common; use crate::physical_plan::memory::MemoryExec; @@ -65,18 +65,18 @@ impl MemTable { pub async fn load( t: Arc, output_partitions: Option, - runtime: Arc, + context: Arc, ) -> Result { let schema = t.schema(); - let exec = t.scan(&None, &[], None).await?; + let exec = t.scan(&None, &[], None, 
context.session_id.clone()).await?; let partition_count = exec.output_partitioning().partition_count(); let tasks = (0..partition_count) .map(|part_i| { - let runtime1 = runtime.clone(); + let ctx_clone = context.clone(); let exec = exec.clone(); tokio::spawn(async move { - let stream = exec.execute(part_i, runtime1.clone()).await?; + let stream = exec.execute(part_i, ctx_clone).await?; common::collect(stream).await }) }) @@ -91,7 +91,8 @@ impl MemTable { data.push(result); } - let exec = MemoryExec::try_new(&data, schema.clone(), None)?; + let exec = + MemoryExec::try_new(&data, schema.clone(), None, context.session_id.clone())?; if let Some(num_partitions) = output_partitions { let exec = RepartitionExec::try_new( @@ -103,7 +104,7 @@ impl MemTable { let mut output_partitions = vec![]; for i in 0..exec.output_partitioning().partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i, runtime.clone()).await?; + let mut stream = exec.execute(i, context.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -132,11 +133,13 @@ impl TableProvider for MemTable { projection: &Option>, _filters: &[Expr], _limit: Option, + session_id: String, ) -> Result> { Ok(Arc::new(MemoryExec::try_new( &self.batches.clone(), self.schema(), projection.clone(), + session_id, )?)) } } @@ -145,6 +148,7 @@ impl TableProvider for MemTable { mod tests { use super::*; use crate::from_slice::FromSlice; + use crate::prelude::SessionContext; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; @@ -153,7 +157,6 @@ mod tests { #[tokio::test] async fn test_with_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -171,11 +174,14 @@ mod tests { ], )?; + let session_ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); let provider = MemTable::try_new(schema, vec![vec![batch]])?; - // scan with projection - let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?; - let mut it = exec.execute(0, runtime).await?; + let exec = provider + .scan(&Some(vec![2, 1]), &[], None, session_ctx.session_id.clone()) + .await?; + let mut it = exec.execute(0, task_ctx).await?; let batch2 = it.next().await.unwrap()?; assert_eq!(2, batch2.schema().fields().len()); assert_eq!("c", batch2.schema().field(0).name()); @@ -187,7 +193,6 @@ mod tests { #[tokio::test] async fn test_without_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -204,9 +209,12 @@ mod tests { )?; let provider = MemTable::try_new(schema, vec![vec![batch]])?; - - let exec = provider.scan(&None, &[], None).await?; - let mut it = exec.execute(0, runtime).await?; + let session_ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let exec = provider + .scan(&None, &[], None, session_ctx.session_id.clone()) + .await?; + let mut it = exec.execute(0, task_ctx).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); @@ -234,8 +242,10 @@ mod tests { let provider = MemTable::try_new(schema, vec![vec![batch]])?; let projection: Vec = vec![0, 4]; - - match 
provider.scan(&Some(projection), &[], None).await { + match provider + .scan(&Some(projection), &[], None, "sess_123".to_owned()) + .await + { Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => { assert_eq!( "\"project index 4 out of bounds, max field 3\"", @@ -316,7 +326,6 @@ mod tests { #[tokio::test] async fn test_merged_schema() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let mut metadata = HashMap::new(); metadata.insert("foo".to_string(), "bar".to_string()); @@ -359,9 +368,12 @@ mod tests { let provider = MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?; - - let exec = provider.scan(&None, &[], None).await?; - let mut it = exec.execute(0, runtime).await?; + let session_ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let exec = provider + .scan(&None, &[], None, session_ctx.session_id.clone()) + .await?; + let mut it = exec.execute(0, task_ctx).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); diff --git a/datafusion/src/datasource/object_store/mod.rs b/datafusion/src/datasource/object_store/mod.rs index aad70e70a308..3a9da6701700 100644 --- a/datafusion/src/datasource/object_store/mod.rs +++ b/datafusion/src/datasource/object_store/mod.rs @@ -163,7 +163,7 @@ pub trait ObjectStore: Sync + Send + Debug { static LOCAL_SCHEME: &str = "file"; -/// A Registry holds all the object stores at runtime with a scheme for each store. +/// A Registry holds all the object stores at Runtime with a scheme for each store. /// This allows the user to extend DataFusion with different storage systems such as S3 or HDFS /// and query data inside these systems. pub struct ObjectStoreRegistry { diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 8c43ecd715d9..dfb60382bf38 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! ExecutionContext contains methods for registering data sources and executing queries +//! SessionContext contains methods for registering data sources and executing queries use crate::{ catalog::{ catalog::{CatalogList, MemoryCatalogList}, @@ -41,7 +41,6 @@ use crate::{ use log::{debug, trace}; use parking_lot::Mutex; use std::collections::{HashMap, HashSet}; -use std::path::PathBuf; use std::string::String; use std::sync::Arc; @@ -53,7 +52,6 @@ use crate::catalog::{ ResolvedTableReference, TableReference, }; use crate::datasource::listing::ListingTableConfig; -use crate::datasource::object_store::{ObjectStore, ObjectStoreRegistry}; use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; use crate::execution::dataframe_impl::DataFrameImpl; @@ -90,16 +88,13 @@ use crate::{dataframe::DataFrame, physical_plan::udaf::AggregateUDF}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use parquet::file::properties::WriterProperties; +use uuid::Uuid; -use super::{ - disk_manager::DiskManagerConfig, - memory_manager::MemoryManagerConfig, - options::{AvroReadOptions, CsvReadOptions}, - DiskManager, MemoryManager, -}; +use super::options::{AvroReadOptions, CsvReadOptions}; -/// ExecutionContext is the main interface for executing queries with DataFusion. The context -/// provides the following functionality: +/// SessionContext is the main interface for executing queries with DataFusion. 
It stands for +/// the connection between user and DataFusion/Ballista cluster. +/// The context provides the following functionality: /// /// * Create DataFrame from a CSV or Parquet data source. /// * Register a CSV or Parquet data source as a table that can be referenced from a SQL query. @@ -114,7 +109,7 @@ use super::{ /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let mut ctx = ExecutionContext::new(); +/// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))? /// .aggregate(vec![col("a")], vec![min(col("b"))])? @@ -132,80 +127,72 @@ use super::{ /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let mut ctx = ExecutionContext::new(); +/// let ctx = SessionContext::new(); /// ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?; /// let results = ctx.sql("SELECT a, MIN(b) FROM example GROUP BY a LIMIT 100").await?; /// # Ok(()) /// # } /// ``` -#[derive(Clone)] -pub struct ExecutionContext { - /// Internal state for the context - pub state: Arc>, +pub struct SessionContext { + /// Uuid for the session + pub session_id: String, + /// Session start time + pub session_start_time: DateTime, + /// Shared session state for the context + pub state: Arc>, } -impl Default for ExecutionContext { +impl Default for SessionContext { fn default() -> Self { Self::new() } } -impl ExecutionContext { - /// Creates a new execution context using a default configuration. +impl SessionContext { + /// Creates a new session context using a default session configuration. pub fn new() -> Self { - Self::with_config(ExecutionConfig::new()) + Self::with_config(SessionConfig::new()) } - /// Creates a new execution context using the provided configuration. - pub fn with_config(config: ExecutionConfig) -> Self { - let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; - - if config.create_default_catalog_and_schema { - let default_catalog = MemoryCatalogProvider::new(); - - default_catalog.register_schema( - config.default_schema.clone(), - Arc::new(MemorySchemaProvider::new()), - ); - - let default_catalog: Arc = if config.information_schema { - Arc::new(CatalogWithInformationSchema::new( - Arc::downgrade(&catalog_list), - Arc::new(default_catalog), - )) - } else { - Arc::new(default_catalog) - }; + /// Creates a new session context using the provided configuration. + pub fn with_config(config: SessionConfig) -> Self { + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); + Self::with_config_rt(config, runtime) + } - catalog_list - .register_catalog(config.default_catalog.clone(), default_catalog); + /// Creates a new session context using the provided configuration and RuntimeEnv. + pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { + let state = SessionState::with_config(config, runtime); + Self { + session_id: state.session_id.clone(), + session_start_time: chrono::Utc::now(), + state: Arc::new(Mutex::new(state)), } + } - let runtime_env = Arc::new(RuntimeEnv::new(config.runtime.clone()).unwrap()); - + /// Creates a new session context using the provided session state. 
+ pub fn with_state(state: SessionState) -> Self { Self { - state: Arc::new(Mutex::new(ExecutionContextState { - catalog_list, - scalar_functions: HashMap::new(), - aggregate_functions: HashMap::new(), - config, - execution_props: ExecutionProps::new(), - object_store_registry: Arc::new(ObjectStoreRegistry::new()), - runtime_env, - })), + session_id: state.session_id.clone(), + session_start_time: chrono::Utc::now(), + state: Arc::new(Mutex::new(state)), } } - /// Return the [RuntimeEnv] used to run queries with this [ExecutionContext] - pub fn runtime_env(&self) -> Arc { - self.state.lock().runtime_env.clone() + /// Return a copied version of config for this Session + pub fn copied_config(&self) -> SessionConfig { + let config = { + // Clone the session config + self.state.lock().config.lock().clone() + }; + config } /// Creates a dataframe that will execute a SQL query. /// /// This method is `async` because queries of type `CREATE EXTERNAL TABLE` /// might require the schema to be inferred. - pub async fn sql(&mut self, sql: &str) -> Result> { + pub async fn sql(&self, sql: &str) -> Result> { let plan = self.create_logical_plan(sql)?; match plan { LogicalPlan::CreateExternalTable(CreateExternalTable { @@ -239,7 +226,7 @@ impl ExecutionContext { format: file_format, collect_stat: false, file_extension: file_extension.to_owned(), - target_partitions: self.state.lock().config.target_partitions, + target_partitions: self.copied_config().target_partitions, table_partition_cols: vec![], }; @@ -313,7 +300,7 @@ impl ExecutionContext { /// Registers a variable provider within this context. pub fn register_variable( - &mut self, + &self, variable_type: VarType, provider: Arc, ) { @@ -330,7 +317,7 @@ impl ExecutionContext { /// /// `SELECT MY_FUNC(x)...` will look for a function named `"my_func"` /// `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"` - pub fn register_udf(&mut self, f: ScalarUDF) { + pub fn register_udf(&self, f: ScalarUDF) { self.state .lock() .scalar_functions @@ -344,7 +331,7 @@ impl ExecutionContext { /// /// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"` /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"` - pub fn register_udaf(&mut self, f: AggregateUDF) { + pub fn register_udaf(&self, f: AggregateUDF) { self.state .lock() .aggregate_functions @@ -354,13 +341,13 @@ impl ExecutionContext { /// Creates a DataFrame for reading an Avro data source. pub async fn read_avro( - &mut self, + &self, uri: impl Into, options: AvroReadOptions<'_>, ) -> Result> { let uri: String = uri.into(); - let (object_store, path) = self.object_store(&uri)?; - let target_partitions = self.state.lock().config.target_partitions; + let (object_store, path) = self.state.lock().runtime.object_store(&uri)?; + let target_partitions = self.copied_config().target_partitions; Ok(Arc::new(DataFrameImpl::new( self.state.clone(), &LogicalPlanBuilder::scan_avro( @@ -385,13 +372,13 @@ impl ExecutionContext { /// Creates a DataFrame for reading a CSV data source. 
pub async fn read_csv( - &mut self, + &self, uri: impl Into, options: CsvReadOptions<'_>, ) -> Result> { let uri: String = uri.into(); - let (object_store, path) = self.object_store(&uri)?; - let target_partitions = self.state.lock().config.target_partitions; + let (object_store, path) = self.state.lock().runtime.object_store(&uri)?; + let target_partitions = self.copied_config().target_partitions; Ok(Arc::new(DataFrameImpl::new( self.state.clone(), &LogicalPlanBuilder::scan_csv( @@ -408,12 +395,12 @@ impl ExecutionContext { /// Creates a DataFrame for reading a Parquet data source. pub async fn read_parquet( - &mut self, + &self, uri: impl Into, ) -> Result> { let uri: String = uri.into(); - let (object_store, path) = self.object_store(&uri)?; - let target_partitions = self.state.lock().config.target_partitions; + let (object_store, path) = self.state.lock().runtime.object_store(&uri)?; + let target_partitions = self.copied_config().target_partitions; let logical_plan = LogicalPlanBuilder::scan_parquet(object_store, path, None, target_partitions) .await? @@ -426,7 +413,7 @@ impl ExecutionContext { /// Creates a DataFrame for reading a custom TableProvider. pub fn read_table( - &mut self, + &self, provider: Arc, ) -> Result> { Ok(Arc::new(DataFrameImpl::new( @@ -439,13 +426,13 @@ impl ExecutionContext { /// find the files to be processed /// This is async because it might need to resolve the schema. pub async fn register_listing_table<'a>( - &'a mut self, + &'a self, name: &'a str, uri: &'a str, options: ListingOptions, provided_schema: Option, ) -> Result<()> { - let (object_store, path) = self.object_store(uri)?; + let (object_store, path) = self.state.lock().runtime.object_store(uri)?; let resolved_schema = match provided_schema { None => { options @@ -465,13 +452,13 @@ impl ExecutionContext { /// Registers a CSV data source so that it can be referenced from SQL statements /// executed against this context. pub async fn register_csv( - &mut self, + &self, name: &str, uri: &str, options: CsvReadOptions<'_>, ) -> Result<()> { let listing_options = - options.to_listing_options(self.state.lock().config.target_partitions); + options.to_listing_options(self.copied_config().target_partitions); self.register_listing_table( name, @@ -486,10 +473,10 @@ impl ExecutionContext { /// Registers a Parquet data source so that it can be referenced from SQL statements /// executed against this context. - pub async fn register_parquet(&mut self, name: &str, uri: &str) -> Result<()> { + pub async fn register_parquet(&self, name: &str, uri: &str) -> Result<()> { let (target_partitions, enable_pruning) = { - let m = self.state.lock(); - (m.config.target_partitions, m.config.parquet_pruning) + let c = self.copied_config(); + (c.target_partitions, c.parquet_pruning) }; let file_format = ParquetFormat::default().with_enable_pruning(enable_pruning); @@ -509,13 +496,13 @@ impl ExecutionContext { /// Registers an Avro data source so that it can be referenced from SQL statements /// executed against this context. 
pub async fn register_avro( - &mut self, + &self, name: &str, uri: &str, options: AvroReadOptions<'_>, ) -> Result<()> { let listing_options = - options.to_listing_options(self.state.lock().config.target_partitions); + options.to_listing_options(self.copied_config().target_partitions); self.register_listing_table(name, uri, listing_options, options.schema) .await?; @@ -534,9 +521,9 @@ impl ExecutionContext { catalog: Arc, ) -> Option> { let name = name.into(); - + let information_schema = self.copied_config().information_schema; let state = self.state.lock(); - let catalog = if state.config.information_schema { + let catalog = if information_schema { Arc::new(CatalogWithInformationSchema::new( Arc::downgrade(&state.catalog_list), catalog, @@ -553,35 +540,6 @@ impl ExecutionContext { self.state.lock().catalog_list.catalog(name) } - /// Registers a object store with scheme using a custom `ObjectStore` so that - /// an external file system or object storage system could be used against this context. - /// - /// Returns the `ObjectStore` previously registered for this scheme, if any - pub fn register_object_store( - &self, - scheme: impl Into, - object_store: Arc, - ) -> Option> { - let scheme = scheme.into(); - - self.state - .lock() - .object_store_registry - .register_store(scheme, object_store) - } - - /// Retrieves a `ObjectStore` instance by scheme - pub fn object_store<'a>( - &self, - uri: &'a str, - ) -> Result<(Arc, &'a str)> { - self.state - .lock() - .object_store_registry - .get_by_uri(uri) - .map_err(DataFusionError::from) - } - /// Registers a table using a custom `TableProvider` so that /// it can be referenced from SQL statements executed against this /// context. @@ -589,7 +547,7 @@ impl ExecutionContext { /// Returns the `TableProvider` previously registered for this /// reference, if any pub fn register_table<'a>( - &'a mut self, + &'a self, table_ref: impl Into>, provider: Arc, ) -> Result>> { @@ -604,7 +562,7 @@ impl ExecutionContext { /// /// Returns the registered provider, if any pub fn deregister_table<'a>( - &'a mut self, + &'a self, table_ref: impl Into>, ) -> Result>> { let table_ref = table_ref.into(); @@ -614,6 +572,20 @@ impl ExecutionContext { .deregister_table(table_ref.table()) } + /// Check whether the given table exists in the schema provider or not + /// Returns true if the table exists. + pub fn table_exist<'a>( + &'a self, + table_ref: impl Into>, + ) -> Result { + let table_ref = table_ref.into(); + Ok(self + .state + .lock() + .schema_for_ref(table_ref)? + .table_exist(table_ref.table())) + } + /// Retrieves a DataFrame representing a table previously registered by calling the /// register_table function. /// @@ -645,9 +617,9 @@ impl ExecutionContext { /// /// Use [`table`] to get a specific table. /// - /// [`table`]: ExecutionContext::table + /// [`table`]:SessionContext::table #[deprecated( - note = "Please use the catalog provider interface (`ExecutionContext::catalog`) to examine available catalogs, schemas, and tables" + note = "Please use the catalog provider interface (`SessionContext::catalog`) to examine available catalogs, schemas, and tables" )] pub fn tables(&self) -> Result> { Ok(self @@ -663,26 +635,7 @@ impl ExecutionContext { /// Optimizes the logical plan by applying optimizer rules. 
pub fn optimize(&self, plan: &LogicalPlan) -> Result { - if let LogicalPlan::Explain(e) = plan { - let mut stringified_plans = e.stringified_plans.clone(); - - // optimize the child plan, capturing the output of each optimizer - let plan = - self.optimize_internal(e.plan.as_ref(), |optimized_plan, optimizer| { - let optimizer_name = optimizer.name().to_string(); - let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; - stringified_plans.push(optimized_plan.to_stringified(plan_type)); - })?; - - Ok(LogicalPlan::Explain(Explain { - verbose: e.verbose, - plan: Arc::new(plan), - stringified_plans, - schema: e.schema.clone(), - })) - } else { - self.optimize_internal(plan, |_, _| {}) - } + self.state.lock().optimize(plan) } /// Creates a physical plan from a logical plan. @@ -690,7 +643,7 @@ impl ExecutionContext { &self, logical_plan: &LogicalPlan, ) -> Result> { - let (state, planner) = { + let state_cloned = { let mut state = self.state.lock(); state.execution_props.start_execution(); @@ -703,10 +656,9 @@ impl ExecutionContext { // original state after it has been cloned, they will not be picked up by the // clone but that is okay, as it is equivalent to postponing the state update // by keeping the lock until the end of the function scope. - (state.clone(), Arc::clone(&state.config.query_planner)) + state.clone() }; - - planner.create_physical_plan(logical_plan, &state).await + state_cloned.create_physical_plan(logical_plan).await } /// Executes a query and writes the results to a partitioned CSV file. @@ -715,7 +667,8 @@ impl ExecutionContext { plan: Arc, path: impl AsRef, ) -> Result<()> { - plan_to_csv(self, plan, path).await + let state = self.state.lock().clone(); + plan_to_csv(&state, plan, path).await } /// Executes a query and writes the results to a partitioned Parquet file. 
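// ---------------------------------------------------------------------------
// [Editor's note] A minimal, hedged usage sketch -- not part of the patch --
// showing how the SessionContext / SessionState / TaskContext pieces
// introduced in this file are expected to fit together after the refactor.
// The CSV path, table name, task id, and import paths below are illustrative
// assumptions taken from this diff, not verified against the final crate
// layout.

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::{SessionConfig, SessionContext, TaskContext};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::physical_plan::common;
use datafusion::prelude::CsvReadOptions;

// End-to-end flow on the client side: build a SessionConfig, plan through the
// shared SessionState, and execute a partition with a TaskContext instead of
// the old bare RuntimeEnv.
async fn session_context_usage_sketch() -> Result<()> {
    let config = SessionConfig::new()
        .with_target_partitions(4)
        .with_batch_size(4096);
    let ctx = SessionContext::with_config(config);

    // Registration methods now take `&self`, so no `mut` binding is needed.
    ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new())
        .await?;

    // Planning is delegated to the shared SessionState held inside the context.
    let plan = ctx.create_logical_plan("SELECT a, MIN(b) FROM example GROUP BY a")?;
    let plan = ctx.optimize(&plan)?;
    let physical_plan = ctx.create_physical_plan(&plan).await?;

    // Execution receives a per-task TaskContext carrying the session id,
    // SessionConfig, and RuntimeEnv.
    let task_ctx = Arc::new(TaskContext::from(&ctx));
    let stream = physical_plan.execute(0, task_ctx).await?;
    let _batches = common::collect(stream).await?;
    Ok(())
}

// Distributed (Ballista) flow: the same settings travel to executors as
// key-value pairs. SessionConfig::to_props() serializes them and
// TaskContext::new() rebuilds a per-task configuration on the other side.
fn rebuild_task_context_sketch(ctx: &SessionContext) -> Result<TaskContext> {
    let props = ctx.copied_config().to_props();
    let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default())?);
    let task_ctx = TaskContext::new(
        "task_0".to_owned(), // illustrative task id
        ctx.session_id.clone(),
        props,
        runtime,
    );
    let _per_task_config = task_ctx.session_config();
    Ok(task_ctx)
}
// ---------------------------------------------------------------------------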
@@ -725,55 +678,129 @@ impl ExecutionContext { path: impl AsRef, writer_properties: Option, ) -> Result<()> { - plan_to_parquet(self, plan, path, writer_properties).await + let state = self.state.lock().clone(); + plan_to_parquet(&state, plan, path, writer_properties).await } +} - /// Optimizes the logical plan by applying optimizer rules, and - /// invoking observer function after each call - fn optimize_internal( - &self, - plan: &LogicalPlan, - mut observer: F, - ) -> Result - where - F: FnMut(&LogicalPlan, &dyn OptimizerRule), - { - let state = &mut self.state.lock(); - let execution_props = &mut state.execution_props.clone(); - let optimizers = &state.config.optimizers; +impl FunctionRegistry for SessionContext { + fn udfs(&self) -> HashSet { + self.state.lock().udfs() + } - let execution_props = execution_props.start_execution(); + fn udf(&self, name: &str) -> Result> { + self.state.lock().udf(name) + } - let mut new_plan = plan.clone(); - debug!("Input logical plan:\n{}\n", plan.display_indent()); - trace!("Full input logical plan:\n{:?}", plan); - for optimizer in optimizers { - new_plan = optimizer.optimize(&new_plan, execution_props)?; - observer(&new_plan, optimizer.as_ref()); - } - debug!("Optimized logical plan:\n{}\n", new_plan.display_indent()); - trace!("Full Optimized logical plan:\n {:?}", plan); - Ok(new_plan) + fn udaf(&self, name: &str) -> Result> { + self.state.lock().udaf(name) } } -impl From>> for ExecutionContext { - fn from(state: Arc>) -> Self { - ExecutionContext { state } +/// Task Context Properties +pub enum TaskProperties { + ///SessionConfig + SessionConfig(SessionConfig), + /// KV pairs + KVPairs(HashMap), +} + +/// Ballista Task Context +pub struct TaskContext { + /// Optional Task Identify + pub task_id: Option, + /// Session Id + pub session_id: String, + /// Task settings + pub task_settings: TaskProperties, + /// Runtime environment associated with this task context + pub runtime: Arc, +} + +impl TaskContext { + /// Create a new task context instance + pub fn new( + task_id: String, + session_id: String, + task_settings: HashMap, + runtime: Arc, + ) -> Self { + Self { + task_id: Some(task_id), + session_id, + task_settings: TaskProperties::KVPairs(task_settings), + runtime, + } } } -impl FunctionRegistry for ExecutionContext { - fn udfs(&self) -> HashSet { - self.state.lock().udfs() +impl TaskContext { + /// Return session id + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Return the SessionConfig associated with the Task + pub fn session_config(&self) -> SessionConfig { + let task_settings = &self.task_settings; + match task_settings { + TaskProperties::KVPairs(props) => { + let session_config = SessionConfig::new(); + session_config + .with_batch_size(props.get(BATCH_SIZE).unwrap().parse().unwrap()) + .with_target_partitions( + props.get(TARGET_PARTITIONS).unwrap().parse().unwrap(), + ) + .with_repartition_joins( + props.get(REPARTITION_JOINS).unwrap().parse().unwrap(), + ) + .with_repartition_aggregations( + props + .get(REPARTITION_AGGREGATIONS) + .unwrap() + .parse() + .unwrap(), + ) + .with_repartition_windows( + props.get(REPARTITION_WINDOWS).unwrap().parse().unwrap(), + ) + .with_parquet_pruning( + props.get(PARQUET_PRUNING).unwrap().parse().unwrap(), + ) + } + TaskProperties::SessionConfig(session_config) => session_config.clone(), + } } +} - fn udf(&self, name: &str) -> Result> { - self.state.lock().udf(name) +/// Create a new task context instance from SessionContext +impl From<&SessionContext> for TaskContext { + fn 
from(session: &SessionContext) -> Self { + let state_clone = session.state.lock(); + let session_id = session.session_id.clone(); + let config = state_clone.config.lock().clone(); + let runtime = state_clone.runtime.clone(); + Self { + task_id: None, + session_id, + task_settings: TaskProperties::SessionConfig(config), + runtime, + } } +} - fn udaf(&self, name: &str) -> Result> { - self.state.lock().udaf(name) +/// Create a new task context instance from SessionState +impl From<&SessionState> for TaskContext { + fn from(state: &SessionState) -> Self { + let session_id = state.session_id.clone(); + let config = state.config.lock().clone(); + let runtime = state.runtime.clone(); + Self { + task_id: None, + session_id, + task_settings: TaskProperties::SessionConfig(config), + runtime, + } } } @@ -784,7 +811,7 @@ pub trait QueryPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>; } @@ -797,33 +824,37 @@ impl QueryPlanner for DefaultQueryPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { let planner = DefaultPhysicalPlanner::default(); - planner.create_physical_plan(logical_plan, ctx_state).await + planner + .create_physical_plan(logical_plan, session_state) + .await } } -/// Configuration options for execution context +/// Session Configuration entry name +pub const BATCH_SIZE: &str = "batch_size"; +/// Session Configuration entry name +pub const TARGET_PARTITIONS: &str = "target_partitions"; +/// Session Configuration entry name +pub const REPARTITION_JOINS: &str = "repartition_joins"; +/// Session Configuration entry name +pub const REPARTITION_AGGREGATIONS: &str = "repartition_aggregations"; +/// Session Configuration entry name +pub const REPARTITION_WINDOWS: &str = "repartition_windows"; +/// Session Configuration entry name +pub const PARQUET_PRUNING: &str = "parquet_pruning"; + +/// Configuration options for session context #[derive(Clone)] -pub struct ExecutionConfig { +pub struct SessionConfig { + /// Default batch size while creating new batches, it's especially useful + /// for buffer-in-memory batches since creating tiny batches would results + /// in too much metadata memory consumption. + pub batch_size: usize, /// Number of partitions for query execution. Increasing partitions can increase concurrency. 
pub target_partitions: usize, - /// Responsible for optimizing a logical plan - optimizers: Vec>, - /// Responsible for optimizing a physical execution plan - pub physical_optimizers: Vec>, - /// Responsible for planning `LogicalPlan`s, and `ExecutionPlan` - query_planner: Arc, - /// Default catalog name for table resolution - default_catalog: String, - /// Default schema name for table resolution - default_schema: String, - /// Whether the default catalog and schema should be created automatically - create_default_catalog_and_schema: bool, - /// Should DataFusion provide access to `information_schema` - /// virtual tables for displaying schema information - information_schema: bool, /// Should DataFusion repartition data using the join keys to execute joins in parallel /// using the provided `target_partitions` level pub repartition_joins: bool, @@ -834,115 +865,56 @@ pub struct ExecutionConfig { /// parallel using the provided `target_partitions` level pub repartition_windows: bool, /// Should Datafusion parquet reader using the predicate to prune data - parquet_pruning: bool, - /// Runtime configurations such as memory threshold and local disk for spill - pub runtime: RuntimeConfig, + pub parquet_pruning: bool, + + /// Default catalog name for table resolution + default_catalog: String, + /// Default schema name for table resolution + default_schema: String, + /// Whether the default catalog and schema should be created automatically + /// Whether the default catalog and schema should be created automatically + create_default_catalog_and_schema: bool, + /// Should DataFusion provide access to `information_schema` + /// virtual tables for displaying schema information + information_schema: bool, } -impl Default for ExecutionConfig { +impl Default for SessionConfig { fn default() -> Self { Self { + batch_size: 8192, target_partitions: num_cpus::get(), - optimizers: vec![ - // Simplify expressions first to maximize the chance - // of applying other optimizations - Arc::new(SimplifyExpressions::new()), - Arc::new(CommonSubexprEliminate::new()), - Arc::new(EliminateLimit::new()), - Arc::new(ProjectionPushDown::new()), - Arc::new(FilterPushDown::new()), - Arc::new(LimitPushDown::new()), - Arc::new(SingleDistinctToGroupBy::new()), - // ToApproxPerc must be applied last because - // it rewrites only the function and may interfere with - // other rules - Arc::new(ToApproxPerc::new()), - ], - physical_optimizers: vec![ - Arc::new(AggregateStatistics::new()), - Arc::new(HashBuildProbeOrder::new()), - Arc::new(CoalesceBatches::new()), - Arc::new(Repartition::new()), - Arc::new(AddCoalescePartitionsExec::new()), - ], - query_planner: Arc::new(DefaultQueryPlanner {}), - default_catalog: "datafusion".to_owned(), - default_schema: "public".to_owned(), - create_default_catalog_and_schema: true, - information_schema: false, repartition_joins: true, repartition_aggregations: true, repartition_windows: true, parquet_pruning: true, - runtime: RuntimeConfig::default(), + default_catalog: "datafusion".to_owned(), + default_schema: "public".to_owned(), + create_default_catalog_and_schema: true, + information_schema: false, } } } -impl ExecutionConfig { - /// Create an execution config with default setting +impl SessionConfig { + /// Create an session config with default setting pub fn new() -> Self { Default::default() } - /// Customize target_partitions - pub fn with_target_partitions(mut self, n: usize) -> Self { - // partition count must be greater than zero - assert!(n > 0); - self.target_partitions = n; - 
self - } - /// Customize batch size pub fn with_batch_size(mut self, n: usize) -> Self { // batch size must be greater than zero assert!(n > 0); - self.runtime.batch_size = n; - self - } - - /// Replace the default query planner - pub fn with_query_planner( - mut self, - query_planner: Arc, - ) -> Self { - self.query_planner = query_planner; - self - } - - /// Replace the optimizer rules - pub fn with_optimizer_rules( - mut self, - optimizers: Vec>, - ) -> Self { - self.optimizers = optimizers; - self - } - - /// Replace the physical optimizer rules - pub fn with_physical_optimizer_rules( - mut self, - physical_optimizers: Vec>, - ) -> Self { - self.physical_optimizers = physical_optimizers; + self.batch_size = n; self } - /// Adds a new [`OptimizerRule`] - pub fn add_optimizer_rule( - mut self, - optimizer_rule: Arc, - ) -> Self { - self.optimizers.push(optimizer_rule); - self - } - - /// Adds a new [`PhysicalOptimizerRule`] - pub fn add_physical_optimizer_rule( - mut self, - optimizer_rule: Arc, - ) -> Self { - self.physical_optimizers.push(optimizer_rule); + /// Customize target_partitions + pub fn with_target_partitions(mut self, n: usize) -> Self { + // partition count must be greater than zero + assert!(n > 0); + self.target_partitions = n; self } @@ -993,55 +965,35 @@ impl ExecutionConfig { self } - /// Customize runtime config - pub fn with_runtime_config(mut self, config: RuntimeConfig) -> Self { - self.runtime = config; - self - } - - /// Use an an existing [MemoryManager] - pub fn with_existing_memory_manager(mut self, existing: Arc) -> Self { - self.runtime = self - .runtime - .with_memory_manager(MemoryManagerConfig::new_existing(existing)); - self - } - - /// Specify the total memory to use while running the DataFusion - /// plan to `max_memory * memory_fraction` in bytes. - /// - /// Note DataFusion does not yet respect this limit in all cases. - pub fn with_memory_limit( - mut self, - max_memory: usize, - memory_fraction: f64, - ) -> Result { - self.runtime = - self.runtime - .with_memory_manager(MemoryManagerConfig::try_new_limit( - max_memory, - memory_fraction, - )?); - Ok(self) - } - - /// Use an an existing [DiskManager] - pub fn with_existing_disk_manager(mut self, existing: Arc) -> Self { - self.runtime = self - .runtime - .with_disk_manager(DiskManagerConfig::new_existing(existing)); - self - } - - /// Use the specified path to create any needed temporary files - pub fn with_temp_file_path(mut self, path: impl Into) -> Self { - self.runtime = self - .runtime - .with_disk_manager(DiskManagerConfig::new_specified(vec![path.into()])); - self + /// Convert configuration to name-value pairs + pub fn to_props(&self) -> HashMap { + let mut map = HashMap::new(); + map.insert(BATCH_SIZE.to_owned(), format!("{}", self.batch_size)); + map.insert( + TARGET_PARTITIONS.to_owned(), + format!("{}", self.target_partitions), + ); + map.insert( + REPARTITION_JOINS.to_owned(), + format!("{}", self.repartition_joins), + ); + map.insert( + REPARTITION_AGGREGATIONS.to_owned(), + format!("{}", self.repartition_aggregations), + ); + map.insert( + REPARTITION_WINDOWS.to_owned(), + format!("{}", self.repartition_windows), + ); + map.insert( + PARQUET_PRUNING.to_owned(), + format!("{}", self.parquet_pruning), + ); + map } } +// TODO refact this later, need to add a new QueryExecutionContext to track per query execution properties /// Holds per-execution properties and data (such as starting timestamps, etc). 
/// An instance of this struct is created each time a [`LogicalPlan`] is prepared for /// execution (optimized). If the same plan is optimized multiple times, a new @@ -1104,42 +1056,105 @@ impl ExecutionProps { } } -/// Execution context for registering data sources and executing queries +/// Session state for registering data sources and executing queries #[derive(Clone)] -pub struct ExecutionContextState { +pub struct SessionState { + /// Uuid for the session + pub session_id: String, + /// Responsible for optimizing a logical plan + pub optimizers: Vec>, + /// Responsible for optimizing a physical execution plan + pub physical_optimizers: Vec>, + /// Responsible for planning `LogicalPlan`s, and `ExecutionPlan` + pub query_planner: Arc, /// Collection of catalogs containing schemas and ultimately TableProviders pub catalog_list: Arc, /// Scalar functions that are registered with the context pub scalar_functions: HashMap>, /// Aggregate functions registered in the context pub aggregate_functions: HashMap>, - /// Context configuration - pub config: ExecutionConfig, /// Execution properties pub execution_props: ExecutionProps, - /// Object Store that are registered with the context - pub object_store_registry: Arc, - /// Runtime environment - pub runtime_env: Arc, + /// Default catalog name for table resolution + pub default_catalog: String, + /// Default schema name for table resolution + pub default_schema: String, + /// Session configuration + pub config: Arc>, + /// Runtime environment associated with this session + pub runtime: Arc, } -impl Default for ExecutionContextState { - fn default() -> Self { - Self::new() - } +/// Default session builder using the provided configuration +pub fn default_session_builder(config: SessionConfig) -> SessionState { + SessionState::with_config( + config, + Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()), + ) } -impl ExecutionContextState { - /// Returns new ExecutionContextState - pub fn new() -> Self { - ExecutionContextState { - catalog_list: Arc::new(MemoryCatalogList::new()), +impl SessionState { + /// Returns new SessionState using the provided configuration and runtime + pub fn with_config(config: SessionConfig, runtime: Arc) -> Self { + let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; + + if config.create_default_catalog_and_schema { + let default_catalog = MemoryCatalogProvider::new(); + + default_catalog.register_schema( + config.default_schema.clone(), + Arc::new(MemorySchemaProvider::new()), + ); + + let default_catalog: Arc = if config.information_schema { + Arc::new(CatalogWithInformationSchema::new( + Arc::downgrade(&catalog_list), + Arc::new(default_catalog), + )) + } else { + Arc::new(default_catalog) + }; + catalog_list + .register_catalog(config.default_catalog.clone(), default_catalog); + } + let session_id = Uuid::new_v4().to_string(); + let default_catalog = config.default_catalog.clone(); + let default_schema = config.default_schema.clone(); + let shared_conf = Arc::new(Mutex::new(config)); + + SessionState { + session_id, + optimizers: vec![ + // Simplify expressions first to maximize the chance + // of applying other optimizations + Arc::new(SimplifyExpressions::new()), + Arc::new(CommonSubexprEliminate::new()), + Arc::new(EliminateLimit::new()), + Arc::new(ProjectionPushDown::new()), + Arc::new(FilterPushDown::new()), + Arc::new(LimitPushDown::new()), + Arc::new(SingleDistinctToGroupBy::new()), + // ToApproxPerc must be applied last because + // it rewrites only the function and may interfere with + 
// other rules + Arc::new(ToApproxPerc::new()), + ], + physical_optimizers: vec![ + Arc::new(AggregateStatistics::new()), + Arc::new(HashBuildProbeOrder::new()), + Arc::new(CoalesceBatches::new()), + Arc::new(Repartition::new()), + Arc::new(AddCoalescePartitionsExec::new()), + ], + query_planner: Arc::new(DefaultQueryPlanner {}), + catalog_list, scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), - config: ExecutionConfig::new(), execution_props: ExecutionProps::new(), - object_store_registry: Arc::new(ObjectStoreRegistry::new()), - runtime_env: Arc::new(RuntimeEnv::default()), + default_catalog, + default_schema, + config: shared_conf, + runtime, } } @@ -1149,7 +1164,7 @@ impl ExecutionContextState { ) -> ResolvedTableReference<'a> { table_ref .into() - .resolve(&self.config.default_catalog, &self.config.default_schema) + .resolve(&self.default_catalog, &self.default_schema) } fn schema_for_ref<'a>( @@ -1174,9 +1189,125 @@ impl ExecutionContextState { )) }) } + + /// Replace the default query planner + pub fn with_query_planner( + mut self, + query_planner: Arc, + ) -> Self { + self.query_planner = query_planner; + self + } + + /// Replace the optimizer rules + pub fn with_optimizer_rules( + mut self, + optimizers: Vec>, + ) -> Self { + self.optimizers = optimizers; + self + } + + /// Replace the physical optimizer rules + pub fn with_physical_optimizer_rules( + mut self, + physical_optimizers: Vec>, + ) -> Self { + self.physical_optimizers = physical_optimizers; + self + } + + /// Adds a new [`OptimizerRule`] + pub fn add_optimizer_rule( + mut self, + optimizer_rule: Arc, + ) -> Self { + self.optimizers.push(optimizer_rule); + self + } + + /// Adds a new [`PhysicalOptimizerRule`] + pub fn add_physical_optimizer_rule( + mut self, + optimizer_rule: Arc, + ) -> Self { + self.physical_optimizers.push(optimizer_rule); + self + } + + /// Selects a name for the default catalog and schema + pub fn with_default_catalog_and_schema( + mut self, + catalog: impl Into, + schema: impl Into, + ) -> Self { + self.default_catalog = catalog.into(); + self.default_schema = schema.into(); + self + } + + /// Optimizes the logical plan by applying optimizer rules. 
+ pub fn optimize(&self, plan: &LogicalPlan) -> Result { + if let LogicalPlan::Explain(e) = plan { + let mut stringified_plans = e.stringified_plans.clone(); + + // optimize the child plan, capturing the output of each optimizer + let plan = + self.optimize_internal(e.plan.as_ref(), |optimized_plan, optimizer| { + let optimizer_name = optimizer.name().to_string(); + let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; + stringified_plans.push(optimized_plan.to_stringified(plan_type)); + })?; + + Ok(LogicalPlan::Explain(Explain { + verbose: e.verbose, + plan: Arc::new(plan), + stringified_plans, + schema: e.schema.clone(), + })) + } else { + self.optimize_internal(plan, |_, _| {}) + } + } + + /// Optimizes the logical plan by applying optimizer rules, and + /// invoking observer function after each call + fn optimize_internal( + &self, + plan: &LogicalPlan, + mut observer: F, + ) -> Result + where + F: FnMut(&LogicalPlan, &dyn OptimizerRule), + { + let execution_props = &mut self.execution_props.clone(); + let optimizers = &self.optimizers; + + let execution_props = execution_props.start_execution(); + + let mut new_plan = plan.clone(); + debug!("Input logical plan:\n{}\n", plan.display_indent()); + trace!("Full input logical plan:\n{:?}", plan); + for optimizer in optimizers { + new_plan = optimizer.optimize(&new_plan, execution_props)?; + observer(&new_plan, optimizer.as_ref()); + } + debug!("Optimized logical plan:\n{}\n", new_plan.display_indent()); + trace!("Full Optimized logical plan:\n {:?}", plan); + Ok(new_plan) + } + + /// Creates a physical plan from a logical plan. + pub async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + ) -> Result> { + let planner = self.query_planner.clone(); + planner.create_physical_plan(logical_plan, self).await + } } -impl ContextProvider for ExecutionContextState { +impl ContextProvider for SessionState { fn get_table_provider(&self, name: TableReference) -> Option> { let resolved_ref = self.resolve_table_ref(name); let schema = self.schema_for_ref(resolved_ref).ok()?; @@ -1192,7 +1323,7 @@ impl ContextProvider for ExecutionContextState { } } -impl FunctionRegistry for ExecutionContextState { +impl FunctionRegistry for SessionState { fn udfs(&self) -> HashSet { self.scalar_functions.keys().cloned().collect() } @@ -1251,45 +1382,11 @@ mod tests { use std::{io::prelude::*, sync::Mutex}; use tempfile::TempDir; - #[tokio::test] - async fn shared_memory_and_disk_manager() { - // Demonstrate the ability to share DiskManager and - // MemoryManager between two different executions. 
- let ctx1 = ExecutionContext::new(); - - // configure with same memory / disk manager - let memory_manager = ctx1.runtime_env().memory_manager.clone(); - let disk_manager = ctx1.runtime_env().disk_manager.clone(); - let config = ExecutionConfig::new() - .with_existing_memory_manager(memory_manager.clone()) - .with_existing_disk_manager(disk_manager.clone()); - - let ctx2 = ExecutionContext::with_config(config); - - assert!(std::ptr::eq( - Arc::as_ptr(&memory_manager), - Arc::as_ptr(&ctx1.runtime_env().memory_manager) - )); - assert!(std::ptr::eq( - Arc::as_ptr(&memory_manager), - Arc::as_ptr(&ctx2.runtime_env().memory_manager) - )); - - assert!(std::ptr::eq( - Arc::as_ptr(&disk_manager), - Arc::as_ptr(&ctx1.runtime_env().disk_manager) - )); - assert!(std::ptr::eq( - Arc::as_ptr(&disk_manager), - Arc::as_ptr(&ctx2.runtime_env().disk_manager) - )); - } - #[tokio::test] async fn create_variable_expr() -> Result<()> { let tmp_dir = TempDir::new()?; let partition_count = 4; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; + let ctx = create_ctx(&tmp_dir, partition_count).await?; let variable_provider = test::variable::SystemVar::new(); ctx.register_variable(VarType::System, Arc::new(variable_provider)); @@ -1299,8 +1396,7 @@ mod tests { let provider = test::create_table_dual(); ctx.register_table("dual", provider)?; - let results = - plan_and_collect(&mut ctx, "SELECT @@version, @name FROM dual").await?; + let results = plan_and_collect(&ctx, "SELECT @@version, @name FROM dual").await?; let expected = vec![ "+----------------------+------------------------+", @@ -1318,7 +1414,7 @@ mod tests { async fn register_deregister() -> Result<()> { let tmp_dir = TempDir::new()?; let partition_count = 4; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; + let ctx = create_ctx(&tmp_dir, partition_count).await?; let provider = test::create_table_dual(); ctx.register_table("dual", provider)?; @@ -1577,11 +1673,11 @@ mod tests { #[tokio::test] async fn aggregate_decimal_min() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select min(c1) from d_table") + let result = plan_and_collect(&ctx, "select min(c1) from d_table") .await .unwrap(); let expected = vec![ @@ -1601,12 +1697,12 @@ mod tests { #[tokio::test] async fn aggregate_decimal_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select max(c1) from d_table") + let result = plan_and_collect(&ctx, "select max(c1) from d_table") .await .unwrap(); let expected = vec![ @@ -1626,11 +1722,11 @@ mod tests { #[tokio::test] async fn aggregate_decimal_sum() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select sum(c1) from d_table") + let result = plan_and_collect(&ctx, "select sum(c1) from d_table") .await .unwrap(); let expected = vec![ @@ -1650,11 +1746,11 @@ mod tests { #[tokio::test] async fn aggregate_decimal_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) 
ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select avg(c1) from d_table") + let result = plan_and_collect(&ctx, "select avg(c1) from d_table") .await .unwrap(); let expected = vec![ @@ -1965,7 +2061,7 @@ mod tests { #[tokio::test] async fn group_by_date_trunc() -> Result<()> { let tmp_dir = TempDir::new()?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("c2", DataType::UInt64, false), Field::new( @@ -1996,7 +2092,7 @@ mod tests { .await?; let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT date_trunc('week', t1) as week, SUM(c2) FROM test GROUP BY date_trunc('week', t1)", ).await?; @@ -2016,7 +2112,7 @@ mod tests { #[tokio::test] async fn group_by_largeutf8() { { - let mut ctx = ExecutionContext::new(); + let ctx = Arc::new(SessionContext::new()); // input data looks like: // A, 1 @@ -2047,7 +2143,7 @@ mod tests { ctx.register_table("t", Arc::new(provider)).unwrap(); let results = - plan_and_collect(&mut ctx, "SELECT str, count(val) FROM t GROUP BY str") + plan_and_collect(&ctx, "SELECT str, count(val) FROM t GROUP BY str") .await .expect("ran plan correctly"); @@ -2066,7 +2162,7 @@ mod tests { #[tokio::test] async fn unprojected_filter() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let df = ctx .read_table(test::table_with_sequence(1, 3).unwrap()) .unwrap(); @@ -2091,7 +2187,7 @@ mod tests { #[tokio::test] async fn group_by_dictionary() { async fn run_test_case() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // input data looks like: // A, 1 @@ -2119,12 +2215,10 @@ mod tests { let provider = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap(); ctx.register_table("t", Arc::new(provider)).unwrap(); - let results = plan_and_collect( - &mut ctx, - "SELECT dict, count(val) FROM t GROUP BY dict", - ) - .await - .expect("ran plan correctly"); + let results = + plan_and_collect(&ctx, "SELECT dict, count(val) FROM t GROUP BY dict") + .await + .expect("ran plan correctly"); let expected = vec![ "+------+--------------+", @@ -2139,7 +2233,7 @@ mod tests { // Now, use dict as an aggregate let results = - plan_and_collect(&mut ctx, "SELECT val, count(dict) FROM t GROUP BY val") + plan_and_collect(&ctx, "SELECT val, count(dict) FROM t GROUP BY val") .await .expect("ran plan correctly"); @@ -2156,7 +2250,7 @@ mod tests { // Now, use dict as an aggregate let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT val, count(distinct dict) FROM t GROUP BY val", ) .await @@ -2188,7 +2282,7 @@ mod tests { partitions: Vec>, ) -> Result> { let tmp_dir = TempDir::new()?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("c_group", DataType::Utf8, false), Field::new("c_int8", DataType::Int8, false), @@ -2227,7 +2321,7 @@ mod tests { .await?; let results = plan_and_collect( - &mut ctx, + &ctx, " SELECT c_group, @@ -2337,14 +2431,13 @@ mod tests { #[tokio::test] async fn limit() -> Result<()> { let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; + let ctx = create_ctx(&tmp_dir, 1).await?; ctx.register_table("t", test::table_with_sequence(1, 1000).unwrap()) .unwrap(); - let results = - plan_and_collect(&mut ctx, "SELECT i FROM t ORDER BY i DESC limit 3") - .await - .unwrap(); + let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i DESC limit 3") + .await + .unwrap(); let 
expected = vec![ "+------+", "| i |", "+------+", "| 1000 |", "| 999 |", "| 998 |", @@ -2353,7 +2446,7 @@ mod tests { assert_batches_eq!(expected, &results); - let results = plan_and_collect(&mut ctx, "SELECT i FROM t ORDER BY i limit 3") + let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i limit 3") .await .unwrap(); @@ -2363,7 +2456,7 @@ mod tests { assert_batches_eq!(expected, &results); - let results = plan_and_collect(&mut ctx, "SELECT i FROM t limit 3") + let results = plan_and_collect(&ctx, "SELECT i FROM t limit 3") .await .unwrap(); @@ -2377,7 +2470,7 @@ mod tests { #[tokio::test] async fn limit_multi_partitions() -> Result<()> { let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; + let ctx = create_ctx(&tmp_dir, 1).await?; let partitions = vec![ vec![test::make_partition(0)], @@ -2393,14 +2486,14 @@ mod tests { ctx.register_table("t", provider).unwrap(); // select all rows - let results = plan_and_collect(&mut ctx, "SELECT i FROM t").await.unwrap(); + let results = plan_and_collect(&ctx, "SELECT i FROM t").await.unwrap(); let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); assert_eq!(num_rows, 15); for limit in 1..10 { let query = format!("SELECT i FROM t limit {}", limit); - let results = plan_and_collect(&mut ctx, &query).await.unwrap(); + let results = plan_and_collect(&ctx, &query).await.unwrap(); let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); assert_eq!(num_rows, limit, "mismatch with query {}", query); @@ -2411,7 +2504,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_functions() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2423,19 +2516,19 @@ mod tests { "+-----------+", ]; - let results = plan_and_collect(&mut ctx, "SELECT sqrt(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT sqrt(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); - let results = plan_and_collect(&mut ctx, "SELECT SQRT(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT SQRT(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); // Using double quotes allows specifying the function name with capitalization - let err = plan_and_collect(&mut ctx, "SELECT \"SQRT\"(i) FROM t") + let err = plan_and_collect(&ctx, "SELECT \"SQRT\"(i) FROM t") .await .unwrap_err(); assert_eq!( @@ -2443,7 +2536,7 @@ mod tests { "Error during planning: Invalid function 'SQRT'" ); - let results = plan_and_collect(&mut ctx, "SELECT \"sqrt\"(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT \"sqrt\"(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); @@ -2451,7 +2544,7 @@ mod tests { #[tokio::test] async fn case_builtin_math_expression() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let type_values = vec![ ( @@ -2511,7 +2604,7 @@ mod tests { "| 1 |", "+-----------+", ]; - let results = plan_and_collect(&mut ctx, "SELECT sqrt(v) FROM t") + let results = plan_and_collect(&ctx, "SELECT sqrt(v) FROM t") .await .unwrap(); @@ -2521,7 +2614,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2537,7 +2630,7 @@ mod tests { )); // doesn't work as it was registered with non lowercase - let err = 
plan_and_collect(&mut ctx, "SELECT MY_FUNC(i) FROM t") + let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t") .await .unwrap_err(); assert_eq!( @@ -2546,7 +2639,7 @@ mod tests { ); // Can call it if you put quotes - let result = plan_and_collect(&mut ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; + let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; let expected = vec![ "+--------------+", @@ -2562,7 +2655,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_aggregates() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2574,19 +2667,19 @@ mod tests { "+----------+", ]; - let results = plan_and_collect(&mut ctx, "SELECT max(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT max(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); - let results = plan_and_collect(&mut ctx, "SELECT MAX(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT MAX(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); // Using double quotes allows specifying the function name with capitalization - let err = plan_and_collect(&mut ctx, "SELECT \"MAX\"(i) FROM t") + let err = plan_and_collect(&ctx, "SELECT \"MAX\"(i) FROM t") .await .unwrap_err(); assert_eq!( @@ -2594,7 +2687,7 @@ mod tests { "Error during planning: Invalid function 'MAX'" ); - let results = plan_and_collect(&mut ctx, "SELECT \"max\"(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT \"max\"(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); @@ -2602,7 +2695,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2619,7 +2712,7 @@ mod tests { ctx.register_udaf(my_avg); // doesn't work as it was registered as non lowercase - let err = plan_and_collect(&mut ctx, "SELECT MY_AVG(i) FROM t") + let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t") .await .unwrap_err(); assert_eq!( @@ -2628,7 +2721,7 @@ mod tests { ); // Can call it if you put quotes - let result = plan_and_collect(&mut ctx, "SELECT \"MY_AVG\"(i) FROM t").await?; + let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?; let expected = vec![ "+-------------+", @@ -2649,7 +2742,7 @@ mod tests { // The main stipulation of this test: use a file extension that isn't .csv. let file_extension = ".tst"; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?; ctx.register_csv( "test", @@ -2660,8 +2753,7 @@ mod tests { ) .await?; let results = - plan_and_collect(&mut ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test") - .await?; + plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?; assert_eq!(results.len(), 1); let expected = vec![ @@ -2678,7 +2770,7 @@ mod tests { #[tokio::test] async fn send_context_to_threads() -> Result<()> { - // ensure ExecutionContexts can be used in a multi-threaded + // ensure SessionContexts can be used in a multi-threaded // environment. Usecase is for concurrent planing. 
let tmp_dir = TempDir::new()?; let partition_count = 4; @@ -2705,7 +2797,7 @@ mod tests { #[tokio::test] async fn ctx_sql_should_optimize_plan() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let plan1 = ctx .create_logical_plan("SELECT * FROM (SELECT 1) AS one WHERE TRUE AND TRUE")?; @@ -2736,13 +2828,13 @@ mod tests { vec![Arc::new(Int32Array::from_slice(&[4, 5]))], )?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; ctx.register_table("t", Arc::new(provider))?; - let result = plan_and_collect(&mut ctx, "SELECT AVG(a) FROM t").await?; + let result = plan_and_collect(&ctx, "SELECT AVG(a) FROM t").await?; let batch = &result[0]; assert_eq!(1, batch.num_columns()); @@ -2761,9 +2853,10 @@ mod tests { #[tokio::test] async fn custom_query_planner() -> Result<()> { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_query_planner(Arc::new(MyQueryPlanner {})), - ); + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); + let state = SessionState::with_config(SessionConfig::new(), runtime) + .with_query_planner(Arc::new(MyQueryPlanner {})); + let ctx = SessionContext::with_state(state); let df = ctx.sql("SELECT 1").await?; df.collect().await.expect_err("query not supported"); @@ -2772,8 +2865,8 @@ mod tests { #[tokio::test] async fn disabled_default_catalog_and_schema() -> Result<()> { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().create_default_catalog_and_schema(false), + let ctx = SessionContext::with_config( + SessionConfig::new().create_default_catalog_and_schema(false), ); assert!(matches!( @@ -2791,8 +2884,8 @@ mod tests { #[tokio::test] async fn custom_catalog_and_schema() -> Result<()> { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new() + let ctx = SessionContext::with_config( + SessionConfig::new() .create_default_catalog_and_schema(false) .with_default_catalog_and_schema("my_catalog", "my_schema"), ); @@ -2805,7 +2898,7 @@ mod tests { for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] { let result = plan_and_collect( - &mut ctx, + &ctx, &format!("SELECT COUNT(*) AS count FROM {}", table_ref), ) .await?; @@ -2825,8 +2918,7 @@ mod tests { #[tokio::test] async fn cross_catalog_access() -> Result<()> { - let mut ctx = ExecutionContext::new(); - + let ctx = SessionContext::new(); let catalog_a = MemoryCatalogProvider::new(); let schema_a = MemorySchemaProvider::new(); schema_a @@ -2842,7 +2934,7 @@ mod tests { ctx.register_catalog("catalog_b", Arc::new(catalog_b)); let result = plan_and_collect( - &mut ctx, + &ctx, "SELECT cat, SUM(i) AS total FROM ( SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a UNION ALL @@ -2870,8 +2962,8 @@ mod tests { #[tokio::test] async fn catalogs_not_leaked() { // the information schema used to introduce cyclic Arcs - let ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), + let ctx = SessionContext::with_config( + SessionConfig::new().with_information_schema(true), ); // register a single catalog @@ -2893,7 +2985,7 @@ mod tests { #[tokio::test] async fn normalized_column_identifiers() { // create local execution context - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // register csv file with the execution context ctx.register_csv( @@ -2905,7 +2997,7 @@ mod tests { .unwrap(); let sql = "SELECT A, b FROM 
case_insensitive_test"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -2918,7 +3010,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = "SELECT t.A, b FROM case_insensitive_test AS t"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -2933,7 +3025,7 @@ mod tests { // Aliases let sql = "SELECT t.A as x, b FROM case_insensitive_test AS t"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -2946,7 +3038,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = "SELECT t.A AS X, b FROM case_insensitive_test AS t"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -2959,7 +3051,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t"#; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -2974,7 +3066,7 @@ mod tests { // Order by let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY x"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -2987,7 +3079,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY X"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3000,7 +3092,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t ORDER BY "X""#; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3015,7 +3107,7 @@ mod tests { // Where let sql = "SELECT a, b FROM case_insensitive_test where A IS NOT null"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3030,7 +3122,7 @@ mod tests { // Group by let sql = "SELECT a as x, count(*) as c FROM case_insensitive_test GROUP BY X"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3044,7 +3136,7 @@ mod tests { let sql = r#"SELECT a as "X", count(*) as c FROM case_insensitive_test GROUP BY "X""#; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3064,7 +3156,7 @@ mod tests { async fn create_physical_plan( &self, _logical_plan: &LogicalPlan, - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> Result> { Err(DataFusionError::NotImplemented( "query not supported".to_string(), @@ -3076,7 +3168,7 @@ mod tests { _expr: &Expr, _input_dfschema: &crate::logical_plan::DFSchema, _input_schema: &Schema, - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> Result> { unimplemented!() } @@ -3089,18 +3181,18 @@ mod tests { async fn 
create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { let physical_planner = MyPhysicalPlanner {}; physical_planner - .create_physical_plan(logical_plan, ctx_state) + .create_physical_plan(logical_plan, session_state) .await } } /// Execute SQL and return results async fn plan_and_collect( - ctx: &mut ExecutionContext, + ctx: &SessionContext, sql: &str, ) -> Result> { ctx.sql(sql).await?.collect().await @@ -3109,8 +3201,8 @@ mod tests { /// Execute SQL and return results async fn execute(sql: &str, partition_count: usize) -> Result> { let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; - plan_and_collect(&mut ctx, sql).await + let ctx = create_ctx(&tmp_dir, partition_count).await?; + plan_and_collect(&ctx, sql).await } /// Generate CSV partitions within the supplied directory @@ -3146,10 +3238,9 @@ mod tests { async fn create_ctx( tmp_dir: &TempDir, partition_count: usize, - ) -> Result { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_target_partitions(8), - ); + ) -> Result { + let ctx = + SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; @@ -3178,19 +3269,19 @@ mod tests { #[async_trait] impl CallReadTrait for CallRead { async fn call_read_csv(&self) -> Arc { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap() } async fn call_read_avro(&self) -> Arc { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.read_avro("dummy", AvroReadOptions::default()) .await .unwrap() } async fn call_read_parquet(&self) -> Arc { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.read_parquet("dummy").await.unwrap() } } diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 2af1cd41c35d..509d22900080 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -17,15 +17,15 @@ //! Implementation of DataFrame API. 
-use parking_lot::Mutex; use std::any::Any; +use std::ops::Deref; use std::sync::Arc; use crate::arrow::datatypes::Schema; use crate::arrow::datatypes::SchemaRef; use crate::arrow::record_batch::RecordBatch; use crate::error::Result; -use crate::execution::context::{ExecutionContext, ExecutionContextState}; +use crate::execution::context::{SessionState, TaskContext}; use crate::logical_plan::{ col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, @@ -46,28 +46,35 @@ use crate::physical_plan::{ }; use crate::sql::utils::find_window_exprs; use async_trait::async_trait; +use parking_lot::Mutex; /// Implementation of DataFrame API pub struct DataFrameImpl { - ctx_state: Arc>, + session_state: Arc>, plan: LogicalPlan, } impl DataFrameImpl { /// Create a new Table based on an existing logical plan - pub fn new(ctx_state: Arc>, plan: &LogicalPlan) -> Self { + pub fn new(session_state: Arc>, plan: &LogicalPlan) -> Self { Self { - ctx_state, + session_state, plan: plan.clone(), } } /// Create a physical plan async fn create_physical_plan(&self) -> Result> { - let state = self.ctx_state.lock().clone(); - let ctx = ExecutionContext::from(Arc::new(Mutex::new(state))); - let plan = ctx.optimize(&self.plan)?; - ctx.create_physical_plan(&plan).await + // We need to clone `state` to release the lock that is not `Send`. + // Cloning `state` here is fine as we then pass it as immutable `&state`, which + // means that we avoid write consistency issues as the cloned version will not + // be written to. As for eventual modifications that would be applied to the + // original state after it has been cloned, they will not be picked up by the + // clone but that is okay, as it is equivalent to postponing the state update + // by keeping the lock until the end of the function scope. + let state = self.session_state.lock().clone(); + let optimized_plan = state.optimize(&self.plan)?; + state.create_physical_plan(&optimized_plan).await } } @@ -91,12 +98,16 @@ impl TableProvider for DataFrameImpl { projection: &Option>, filters: &[Expr], limit: Option, + _session_id: String, ) -> Result> { let expr = projection .as_ref() // construct projections .map_or_else( - || Ok(Arc::new(Self::new(self.ctx_state.clone(), &self.plan)) as Arc<_>), + || { + Ok(Arc::new(Self::new(self.session_state.clone(), &self.plan)) + as Arc<_>) + }, |projection| { let schema = TableProvider::schema(self).project(projection)?; let names = schema @@ -114,7 +125,7 @@ impl TableProvider for DataFrameImpl { ))?; // add a limit if given Self::new( - self.ctx_state.clone(), + self.session_state.clone(), &limit .map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))? .to_logical_plan(), @@ -146,7 +157,7 @@ impl DataFrame for DataFrameImpl { }; let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; Ok(Arc::new(DataFrameImpl::new( - self.ctx_state.clone(), + self.session_state.clone(), &project_plan, ))) } @@ -156,7 +167,10 @@ impl DataFrame for DataFrameImpl { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .filter(predicate)? .build()?; - Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrameImpl::new( + self.session_state.clone(), + &plan, + ))) } /// Perform an aggregate query @@ -168,7 +182,10 @@ impl DataFrame for DataFrameImpl { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .aggregate(group_expr, aggr_expr)? 
.build()?; - Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrameImpl::new( + self.session_state.clone(), + &plan, + ))) } /// Limit the number of rows @@ -176,7 +193,10 @@ impl DataFrame for DataFrameImpl { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .limit(n)? .build()?; - Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrameImpl::new( + self.session_state.clone(), + &plan, + ))) } /// Sort by specified sorting expressions @@ -184,7 +204,10 @@ impl DataFrame for DataFrameImpl { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .sort(expr)? .build()?; - Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrameImpl::new( + self.session_state.clone(), + &plan, + ))) } /// Join with another DataFrame @@ -202,7 +225,10 @@ impl DataFrame for DataFrameImpl { (left_cols.to_vec(), right_cols.to_vec()), )? .build()?; - Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrameImpl::new( + self.session_state.clone(), + &plan, + ))) } fn repartition( @@ -212,7 +238,10 @@ impl DataFrame for DataFrameImpl { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .repartition(partitioning_scheme)? .build()?; - Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrameImpl::new( + self.session_state.clone(), + &plan, + ))) } /// Convert to logical plan @@ -223,9 +252,9 @@ impl DataFrame for DataFrameImpl { /// Convert the logical plan represented by this DataFrame into a physical plan and /// execute it, collecting all resulting batches into memory async fn collect(&self) -> Result> { + let task_ctx = Arc::new(TaskContext::from(self.session_state.lock().deref())); let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().runtime_env.clone(); - Ok(collect(plan, runtime).await?) + Ok(collect(plan, task_ctx).await?) } /// Print results. @@ -243,26 +272,26 @@ impl DataFrame for DataFrameImpl { /// Convert the logical plan represented by this DataFrame into a physical plan and /// execute it, returning a stream over a single partition async fn execute_stream(&self) -> Result { + let task_ctx = Arc::new(TaskContext::from(self.session_state.lock().deref())); let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().runtime_env.clone(); - execute_stream(plan, runtime).await + execute_stream(plan, task_ctx).await } /// Convert the logical plan represented by this DataFrame into a physical plan and /// execute it, collecting all resulting batches into memory while maintaining /// partitioning async fn collect_partitioned(&self) -> Result>> { + let task_ctx = Arc::new(TaskContext::from(self.session_state.lock().deref())); let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().runtime_env.clone(); - Ok(collect_partitioned(plan, runtime).await?) + Ok(collect_partitioned(plan, task_ctx).await?) } /// Convert the logical plan represented by this DataFrame into a physical plan and /// execute it, returning a stream for each partition async fn execute_stream_partitioned(&self) -> Result> { + let task_ctx = Arc::new(TaskContext::from(self.session_state.lock().deref())); let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().runtime_env.clone(); - Ok(execute_stream_partitioned(plan, runtime).await?) + Ok(execute_stream_partitioned(plan, task_ctx).await?) 
} /// Returns the schema from the logical plan @@ -274,11 +303,14 @@ impl DataFrame for DataFrameImpl { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .explain(verbose, analyze)? .build()?; - Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrameImpl::new( + self.session_state.clone(), + &plan, + ))) } fn registry(&self) -> Arc { - let registry = self.ctx_state.lock().clone(); + let registry = self.session_state.lock().clone(); Arc::new(registry) } @@ -286,12 +318,15 @@ impl DataFrame for DataFrameImpl { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .union(dataframe.to_logical_plan())? .build()?; - Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrameImpl::new( + self.session_state.clone(), + &plan, + ))) } fn distinct(&self) -> Result> { Ok(Arc::new(DataFrameImpl::new( - self.ctx_state.clone(), + self.session_state.clone(), &LogicalPlanBuilder::from(self.to_logical_plan()) .distinct()? .build()?, @@ -302,7 +337,7 @@ impl DataFrame for DataFrameImpl { let left_plan = self.to_logical_plan(); let right_plan = dataframe.to_logical_plan(); Ok(Arc::new(DataFrameImpl::new( - self.ctx_state.clone(), + self.session_state.clone(), &LogicalPlanBuilder::intersect(left_plan, right_plan, true)?, ))) } @@ -311,16 +346,15 @@ impl DataFrame for DataFrameImpl { let left_plan = self.to_logical_plan(); let right_plan = dataframe.to_logical_plan(); Ok(Arc::new(DataFrameImpl::new( - self.ctx_state.clone(), + self.session_state.clone(), &LogicalPlanBuilder::except(left_plan, right_plan, true)?, ))) } async fn write_csv(&self, path: &str) -> Result<()> { let plan = self.create_physical_plan().await?; - let state = self.ctx_state.lock().clone(); - let ctx = ExecutionContext::from(Arc::new(Mutex::new(state))); - plan_to_csv(&ctx, plan, path).await + let state = self.session_state.lock().clone(); + plan_to_csv(&state, plan, path).await } async fn write_parquet( @@ -329,9 +363,8 @@ impl DataFrame for DataFrameImpl { writer_properties: Option, ) -> Result<()> { let plan = self.create_physical_plan().await?; - let state = self.ctx_state.lock().clone(); - let ctx = ExecutionContext::from(Arc::new(Mutex::new(state))); - plan_to_parquet(&ctx, plan, path, writer_properties).await + let state = self.session_state.lock().clone(); + plan_to_parquet(&state, plan, path, writer_properties).await } } @@ -342,7 +375,7 @@ mod tests { use super::*; use crate::execution::options::CsvReadOptions; use crate::physical_plan::{window_functions, ColumnarValue}; - use crate::{assert_batches_sorted_eq, execution::context::ExecutionContext}; + use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; use crate::{logical_plan::*, test_util}; use arrow::datatypes::DataType; use datafusion_expr::ScalarFunctionImplementation; @@ -495,8 +528,8 @@ mod tests { #[tokio::test] async fn registry() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx, "aggregate_test_100").await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; // declare the udf let my_fn: ScalarFunctionImplementation = @@ -571,7 +604,7 @@ mod tests { #[tokio::test] async fn register_table() -> Result<()> { let df = test_table().await?.select_columns(&["c1", "c12"])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let df_impl = Arc::new(DataFrameImpl::new(ctx.state.clone(), &df.to_logical_plan())); @@ -630,14 +663,14 @@ mod tests { /// Create a logical plan from 
a SQL query async fn create_plan(sql: &str) -> Result { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx, "aggregate_test_100").await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; ctx.create_logical_plan(sql) } async fn test_table_with_name(name: &str) -> Result> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx, name).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, name).await?; ctx.table(name) } @@ -646,7 +679,7 @@ mod tests { } async fn register_aggregate_csv( - ctx: &mut ExecutionContext, + ctx: &SessionContext, table_name: &str, ) -> Result<()> { let schema = test_util::aggr_test_schema(); diff --git a/datafusion/src/execution/runtime_env.rs b/datafusion/src/execution/runtime_env.rs index e993b385ecd4..305bd7716dfc 100644 --- a/datafusion/src/execution/runtime_env.rs +++ b/datafusion/src/execution/runtime_env.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Execution runtime environment that tracks memory, disk and various configurations +//! Execution runtime environment that tracks object Store, memory, disk and various configurations //! that are used during physical plan execution. use crate::{ @@ -26,48 +26,36 @@ use crate::{ }, }; -use std::fmt::{Debug, Formatter}; +use crate::datasource::object_store::{ObjectStore, ObjectStoreRegistry}; +use datafusion_common::DataFusionError; +use std::path::PathBuf; use std::sync::Arc; #[derive(Clone)] -/// Execution runtime environment. This structure is passed to the -/// physical plans when they are run. +/// Execution runtime environment. pub struct RuntimeEnv { - /// Default batch size while creating new batches - pub batch_size: usize, /// Runtime memory management pub memory_manager: Arc, /// Manage temporary files during query execution pub disk_manager: Arc, -} - -impl Debug for RuntimeEnv { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - write!(f, "RuntimeEnv") - } + /// Object Store Registry + pub object_store_registry: Arc, } impl RuntimeEnv { /// Create env based on configuration pub fn new(config: RuntimeConfig) -> Result { let RuntimeConfig { - batch_size, memory_manager, disk_manager, } = config; - Ok(Self { - batch_size, memory_manager: MemoryManager::new(memory_manager), disk_manager: DiskManager::try_new(disk_manager)?, + object_store_registry: Arc::new(ObjectStoreRegistry::new()), }) } - /// Get execution batch size based on config - pub fn batch_size(&self) -> usize { - self.batch_size - } - /// Register the consumer to get it tracked pub fn register_requester(&self, id: &MemoryConsumerId) { self.memory_manager.register_requester(id); @@ -87,21 +75,35 @@ impl RuntimeEnv { pub fn shrink_tracker_usage(&self, delta: usize) { self.memory_manager.shrink_tracker_usage(delta) } -} -impl Default for RuntimeEnv { - fn default() -> Self { - RuntimeEnv::new(RuntimeConfig::new()).unwrap() + /// Registers a object store with scheme using a custom `ObjectStore` so that + /// an external file system or object storage system could be used against this context. 
+ /// + /// Returns the `ObjectStore` previously registered for this scheme, if any + pub fn register_object_store( + &self, + scheme: impl Into, + object_store: Arc, + ) -> Option> { + let scheme = scheme.into(); + self.object_store_registry + .register_store(scheme, object_store) + } + + /// Retrieves a `ObjectStore` instance by scheme + pub fn object_store<'a>( + &self, + uri: &'a str, + ) -> Result<(Arc, &'a str)> { + self.object_store_registry + .get_by_uri(uri) + .map_err(DataFusionError::from) } } -#[derive(Clone)] +#[derive(Clone, Default)] /// Execution runtime configuration pub struct RuntimeConfig { - /// Default batch size while creating new batches, it's especially useful - /// for buffer-in-memory batches since creating tiny batches would results - /// in too much metadata memory consumption. - pub batch_size: usize, /// DiskManager to manage temporary disk file usage pub disk_manager: DiskManagerConfig, /// MemoryManager to limit access to memory @@ -114,14 +116,6 @@ impl RuntimeConfig { Default::default() } - /// Customize batch size - pub fn with_batch_size(mut self, n: usize) -> Self { - // batch size must be greater than zero - assert!(n > 0); - self.batch_size = n; - self - } - /// Customize disk manager pub fn with_disk_manager(mut self, disk_manager: DiskManagerConfig) -> Self { self.disk_manager = disk_manager; @@ -133,14 +127,19 @@ impl RuntimeConfig { self.memory_manager = memory_manager; self } -} -impl Default for RuntimeConfig { - fn default() -> Self { - Self { - batch_size: 8192, - disk_manager: DiskManagerConfig::default(), - memory_manager: MemoryManagerConfig::default(), - } + /// Specify the total memory to use while running the DataFusion + /// plan to `max_memory * memory_fraction` in bytes. + /// + /// Note DataFusion does not yet respect this limit in all cases. + pub fn with_memory_limit(self, max_memory: usize, memory_fraction: f64) -> Self { + self.with_memory_manager( + MemoryManagerConfig::try_new_limit(max_memory, memory_fraction).unwrap(), + ) + } + + /// Use the specified path to create any needed temporary files + pub fn with_temp_file_path(self, path: impl Into) -> Self { + self.with_disk_manager(DiskManagerConfig::new_specified(vec![path.into()])) } } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 0f2fb1418e7b..4bb5bcd35dea 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -34,7 +34,7 @@ //! //! # #[tokio::main] //! # async fn main() -> Result<()> { -//! let mut ctx = ExecutionContext::new(); +//! let ctx = SessionContext::new(); //! //! // create the dataframe //! let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; @@ -73,7 +73,7 @@ //! //! # #[tokio::main] //! # async fn main() -> Result<()> { -//! let mut ctx = ExecutionContext::new(); +//! let ctx = SessionContext::new(); //! //! ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?; //! 
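For orientation, the construction path exercised by the updated custom_query_planner test above, combined with the RuntimeConfig builders added here, composes roughly as the sketch below. This is not part of the patch: the memory limit, spill path, and the external module paths (datafusion::execution::..., use of tokio::main) are assumptions for illustration only.

    use std::sync::Arc;
    use datafusion::error::Result;
    use datafusion::execution::context::{SessionConfig, SessionContext, SessionState};
    use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};

    #[tokio::main]
    async fn main() -> Result<()> {
        // Runtime-level resources (memory, spill directory, object stores) now live on
        // RuntimeEnv; the limit and path below are placeholders.
        let runtime = Arc::new(RuntimeEnv::new(
            RuntimeConfig::new()
                .with_memory_limit(512 * 1024 * 1024, 1.0)
                .with_temp_file_path("/tmp/datafusion-spill"),
        )?);

        // Session-level options (e.g. target partitions) move to SessionConfig; a
        // SessionState ties config and runtime together before the context is built.
        let config = SessionConfig::new().with_target_partitions(8);
        let state = SessionState::with_config(config, runtime);
        let ctx = SessionContext::with_state(state);

        let df = ctx.sql("SELECT 1").await?;
        let results = df.collect().await?;
        println!("{} batch(es)", results.len());
        Ok(())
    }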
diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index d8e43ed2175b..8875c4c6be01 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -1384,6 +1384,7 @@ mod tests { _: &Option>, _: &[Expr], _: Option, + _: String, ) -> Result> { unimplemented!() } diff --git a/datafusion/src/physical_optimizer/aggregate_statistics.rs b/datafusion/src/physical_optimizer/aggregate_statistics.rs index 4ae6ce3638cc..0b2b3090d987 100644 --- a/datafusion/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/src/physical_optimizer/aggregate_statistics.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use crate::execution::context::ExecutionConfig; +use crate::execution::context::SessionConfig; use crate::physical_plan::empty::EmptyExec; use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use crate::physical_plan::projection::ProjectionExec; @@ -48,7 +48,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { fn optimize( &self, plan: Arc, - execution_config: &ExecutionConfig, + session_config: &SessionConfig, ) -> Result> { if let Some(partial_agg_exec) = take_optimizable(&*plan) { let partial_agg_exec = partial_agg_exec @@ -81,13 +81,17 @@ impl PhysicalOptimizerRule for AggregateStatistics { // input can be entirely removed Ok(Arc::new(ProjectionExec::try_new( projections, - Arc::new(EmptyExec::new(true, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new( + true, + Arc::new(Schema::empty()), + plan.session_id(), + )), )?)) } else { - optimize_children(self, plan, execution_config) + optimize_children(self, plan, session_config) } } else { - optimize_children(self, plan, execution_config) + optimize_children(self, plan, session_config) } } @@ -259,7 +263,7 @@ mod tests { use arrow::record_batch::RecordBatch; use crate::error::Result; - use crate::execution::runtime_env::RuntimeEnv; + use crate::execution::context::TaskContext; use crate::logical_plan::Operator; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; @@ -267,9 +271,10 @@ mod tests { use crate::physical_plan::filter::FilterExec; use crate::physical_plan::hash_aggregate::HashAggregateExec; use crate::physical_plan::memory::MemoryExec; + use crate::prelude::SessionContext; /// Mock data using a MemoryExec which has an exact count statistic - fn mock_data() -> Result> { + fn mock_data(session_ctx: &SessionContext) -> Result> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -287,6 +292,7 @@ mod tests { &[vec![batch]], Arc::clone(&schema), None, + session_ctx.session_id.clone(), )?)) } @@ -294,9 +300,9 @@ mod tests { async fn assert_count_optim_success( plan: HashAggregateExec, nulls: bool, + session_ctx: &SessionContext, ) -> Result<()> { - let conf = ExecutionConfig::new(); - let runtime = Arc::new(RuntimeEnv::default()); + let conf = session_ctx.copied_config(); let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?; let (col, count) = match nulls { @@ -306,7 +312,8 @@ mod tests { // A ProjectionExec is a sign that the count optimization was applied assert!(optimized.as_any().is::()); - let result = common::collect(optimized.execute(0, runtime).await?).await?; + let task_ctx = Arc::new(TaskContext::from(session_ctx)); + let result = common::collect(optimized.execute(0, task_ctx).await?).await?; assert_eq!(result[0].schema(), 
Arc::new(Schema::new(vec![col]))); assert_eq!( result[0] @@ -332,7 +339,8 @@ mod tests { #[tokio::test] async fn test_count_partial_direct_child() -> Result<()> { // basic test case with the aggregation applied on a source with exact statistics - let source = mock_data()?; + let session_ctx = SessionContext::new(); + let source = mock_data(&session_ctx)?; let schema = source.schema(); let partial_agg = HashAggregateExec::try_new( @@ -351,7 +359,7 @@ mod tests { Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, false).await?; + assert_count_optim_success(final_agg, false, &session_ctx).await?; Ok(()) } @@ -359,7 +367,8 @@ mod tests { #[tokio::test] async fn test_count_partial_with_nulls_direct_child() -> Result<()> { // basic test case with the aggregation applied on a source with exact statistics - let source = mock_data()?; + let session_ctx = SessionContext::new(); + let source = mock_data(&session_ctx)?; let schema = source.schema(); let partial_agg = HashAggregateExec::try_new( @@ -378,14 +387,15 @@ mod tests { Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, true).await?; + assert_count_optim_success(final_agg, true, &session_ctx).await?; Ok(()) } #[tokio::test] async fn test_count_partial_indirect_child() -> Result<()> { - let source = mock_data()?; + let session_ctx = SessionContext::new(); + let source = mock_data(&session_ctx)?; let schema = source.schema(); let partial_agg = HashAggregateExec::try_new( @@ -407,14 +417,15 @@ mod tests { Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, false).await?; + assert_count_optim_success(final_agg, false, &session_ctx).await?; Ok(()) } #[tokio::test] async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { - let source = mock_data()?; + let session_ctx = SessionContext::new(); + let source = mock_data(&session_ctx)?; let schema = source.schema(); let partial_agg = HashAggregateExec::try_new( @@ -436,14 +447,15 @@ mod tests { Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, true).await?; + assert_count_optim_success(final_agg, true, &session_ctx).await?; Ok(()) } #[tokio::test] async fn test_count_inexact_stat() -> Result<()> { - let source = mock_data()?; + let session_ctx = SessionContext::new(); + let source = mock_data(&session_ctx)?; let schema = source.schema(); // adding a filter makes the statistics inexact @@ -473,7 +485,7 @@ mod tests { Arc::clone(&schema), )?; - let conf = ExecutionConfig::new(); + let conf = session_ctx.copied_config(); let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; @@ -485,7 +497,8 @@ mod tests { #[tokio::test] async fn test_count_with_nulls_inexact_stat() -> Result<()> { - let source = mock_data()?; + let session_ctx = SessionContext::new(); + let source = mock_data(&session_ctx)?; let schema = source.schema(); // adding a filter makes the statistics inexact @@ -515,7 +528,7 @@ mod tests { Arc::clone(&schema), )?; - let conf = ExecutionConfig::new(); + let conf = session_ctx.copied_config(); let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; diff --git a/datafusion/src/physical_optimizer/coalesce_batches.rs b/datafusion/src/physical_optimizer/coalesce_batches.rs index 98e65a2b1281..6e6fc900e403 100644 --- a/datafusion/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/src/physical_optimizer/coalesce_batches.rs @@ -19,6 +19,7 @@ //! 
in bigger batches to avoid overhead with small batches use super::optimizer::PhysicalOptimizerRule; +use crate::prelude::SessionConfig; use crate::{ error::Result, physical_plan::{ @@ -42,14 +43,14 @@ impl PhysicalOptimizerRule for CoalesceBatches { fn optimize( &self, plan: Arc, - config: &crate::execution::context::ExecutionConfig, + session_config: &SessionConfig, ) -> Result> { // wrap operators in CoalesceBatches to avoid lots of tiny batches when we have // highly selective filters let children = plan .children() .iter() - .map(|child| self.optimize(child.clone(), config)) + .map(|child| self.optimize(child.clone(), session_config)) .collect::>>()?; let plan_any = plan.as_any(); @@ -75,7 +76,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { // we should do that once https://issues.apache.org/jira/browse/ARROW-11059 is // implemented. For now, we choose half the configured batch size to avoid copies // when a small number of rows are removed from a batch - let target_batch_size = config.runtime.batch_size / 2; + let target_batch_size = session_config.batch_size / 2; Arc::new(CoalesceBatchesExec::new(plan.clone(), target_batch_size)) } else { plan.clone() diff --git a/datafusion/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/src/physical_optimizer/hash_build_probe_order.rs index 244eb6a560b6..b5f680d56673 100644 --- a/datafusion/src/physical_optimizer/hash_build_probe_order.rs +++ b/datafusion/src/physical_optimizer/hash_build_probe_order.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use crate::execution::context::ExecutionConfig; +use crate::execution::context::SessionConfig; use crate::logical_plan::JoinType; use crate::physical_plan::cross_join::CrossJoinExec; use crate::physical_plan::expressions::Column; @@ -113,9 +113,9 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder { fn optimize( &self, plan: Arc, - execution_config: &ExecutionConfig, + session_config: &SessionConfig, ) -> Result> { - let plan = optimize_children(self, plan, execution_config)?; + let plan = optimize_children(self, plan, session_config)?; if let Some(hash_join) = plan.as_any().downcast_ref::() { let left = hash_join.left(); let right = hash_join.right(); @@ -174,6 +174,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; fn create_big_and_small() -> (Arc, Arc) { + let session_id = "sess_123"; let big = Arc::new(StatisticsExec::new( Statistics { num_rows: Some(10), @@ -181,6 +182,7 @@ mod tests { ..Default::default() }, Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), + session_id.to_owned(), )); let small = Arc::new(StatisticsExec::new( @@ -190,6 +192,7 @@ mod tests { ..Default::default() }, Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), + session_id.to_owned(), )); (big, small) } @@ -212,7 +215,7 @@ mod tests { .unwrap(); let optimized_join = HashBuildProbeOrder::new() - .optimize(Arc::new(join), &ExecutionConfig::new()) + .optimize(Arc::new(join), &SessionConfig::new()) .unwrap(); let swapping_projection = optimized_join @@ -259,7 +262,7 @@ mod tests { .unwrap(); let optimized_join = HashBuildProbeOrder::new() - .optimize(Arc::new(join), &ExecutionConfig::new()) + .optimize(Arc::new(join), &SessionConfig::new()) .unwrap(); let swapped_join = optimized_join diff --git a/datafusion/src/physical_optimizer/merge_exec.rs b/datafusion/src/physical_optimizer/merge_exec.rs index 58823a665b16..5ee6c363842a 100644 --- a/datafusion/src/physical_optimizer/merge_exec.rs +++ b/datafusion/src/physical_optimizer/merge_exec.rs 
@@ -19,6 +19,7 @@ //! with more than one partition, to coalesce them into one partition //! when the node needs a single partition use super::optimizer::PhysicalOptimizerRule; +use crate::prelude::SessionConfig; use crate::{ error::Result, physical_plan::{coalesce_partitions::CoalescePartitionsExec, Distribution}, @@ -40,7 +41,7 @@ impl PhysicalOptimizerRule for AddCoalescePartitionsExec { fn optimize( &self, plan: Arc, - config: &crate::execution::context::ExecutionConfig, + session_config: &SessionConfig, ) -> Result> { if plan.children().is_empty() { // leaf node, children cannot be replaced @@ -49,7 +50,7 @@ impl PhysicalOptimizerRule for AddCoalescePartitionsExec { let children = plan .children() .iter() - .map(|child| self.optimize(child.clone(), config)) + .map(|child| self.optimize(child.clone(), session_config)) .collect::>>()?; match plan.required_child_distribution() { Distribution::UnspecifiedDistribution => plan.with_new_children(children), diff --git a/datafusion/src/physical_optimizer/optimizer.rs b/datafusion/src/physical_optimizer/optimizer.rs index e2f40ae95402..6d52d2036337 100644 --- a/datafusion/src/physical_optimizer/optimizer.rs +++ b/datafusion/src/physical_optimizer/optimizer.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use crate::{ - error::Result, execution::context::ExecutionConfig, physical_plan::ExecutionPlan, + error::Result, execution::context::SessionConfig, physical_plan::ExecutionPlan, }; /// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which @@ -31,7 +31,7 @@ pub trait PhysicalOptimizerRule { fn optimize( &self, plan: Arc, - config: &ExecutionConfig, + session_config: &SessionConfig, ) -> Result>; /// A human readable name for this optimizer rule diff --git a/datafusion/src/physical_optimizer/repartition.rs b/datafusion/src/physical_optimizer/repartition.rs index ae074d2893da..e652fb2efebe 100644 --- a/datafusion/src/physical_optimizer/repartition.rs +++ b/datafusion/src/physical_optimizer/repartition.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use super::optimizer::PhysicalOptimizerRule; use crate::physical_plan::Partitioning::*; use crate::physical_plan::{repartition::RepartitionExec, ExecutionPlan}; -use crate::{error::Result, execution::context::ExecutionConfig}; +use crate::{error::Result, execution::context::SessionConfig}; /// Optimizer that introduces repartition to introduce more /// parallelism in the plan @@ -154,7 +154,6 @@ fn optimize_partitions( ) -> Result> { // Recurse into children bottom-up (attempt to repartition as // early as possible) - let new_plan = if plan.children().is_empty() { // leaf node - don't replace children plan @@ -218,13 +217,13 @@ impl PhysicalOptimizerRule for Repartition { fn optimize( &self, plan: Arc, - config: &ExecutionConfig, + session_config: &SessionConfig, ) -> Result> { // Don't run optimizer if target_partitions == 1 - if config.target_partitions == 1 { + if session_config.target_partitions == 1 { Ok(plan) } else { - optimize_partitions(config.target_partitions, plan, false, false) + optimize_partitions(session_config.target_partitions, plan, false, false) } } @@ -267,6 +266,7 @@ mod tests { table_partition_cols: vec![], }, None, + "sess_123".to_owned(), )) } @@ -343,7 +343,7 @@ mod tests { // run optimizer let optimizer = Repartition {}; let optimized = optimizer - .optimize($PLAN, &ExecutionConfig::new().with_target_partitions(10))?; + .optimize($PLAN, &SessionConfig::new().with_target_partitions(10))?; // Now format correctly let plan = 
displayable(optimized.as_ref()).indent().to_string(); diff --git a/datafusion/src/physical_optimizer/utils.rs b/datafusion/src/physical_optimizer/utils.rs index 962b8ce14557..fbafa8948e10 100644 --- a/datafusion/src/physical_optimizer/utils.rs +++ b/datafusion/src/physical_optimizer/utils.rs @@ -18,7 +18,7 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use super::optimizer::PhysicalOptimizerRule; -use crate::execution::context::ExecutionConfig; +use crate::execution::context::SessionConfig; use crate::error::Result; use crate::physical_plan::ExecutionPlan; @@ -31,12 +31,12 @@ use std::sync::Arc; pub fn optimize_children( optimizer: &impl PhysicalOptimizerRule, plan: Arc, - execution_config: &ExecutionConfig, + config: &SessionConfig, ) -> Result> { let children = plan .children() .iter() - .map(|child| optimizer.optimize(Arc::clone(child), execution_config)) + .map(|child| optimizer.optimize(Arc::clone(child), config)) .collect::>>()?; if children.is_empty() { diff --git a/datafusion/src/physical_plan/analyze.rs b/datafusion/src/physical_plan/analyze.rs index 6857ad532273..505c9831f8b4 100644 --- a/datafusion/src/physical_plan/analyze.rs +++ b/datafusion/src/physical_plan/analyze.rs @@ -32,7 +32,7 @@ use futures::StreamExt; use super::expressions::PhysicalSortExpr; use super::{stream::RecordBatchReceiverStream, Distribution, SendableRecordBatchStream}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; /// `EXPLAIN ANALYZE` execution plan operator. This operator runs its input, @@ -112,7 +112,7 @@ impl ExecutionPlan for AnalyzeExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -133,7 +133,7 @@ impl ExecutionPlan for AnalyzeExec { let (tx, rx) = tokio::sync::mpsc::channel(input_partitions); let captured_input = self.input.clone(); - let mut input_stream = captured_input.execute(0, runtime).await?; + let mut input_stream = captured_input.execute(0, context).await?; let captured_schema = self.schema.clone(); let verbose = self.verbose; @@ -231,6 +231,10 @@ impl ExecutionPlan for AnalyzeExec { // Statistics an an ANALYZE plan are not relevant Statistics::default() } + + fn session_id(&self) -> String { + self.input.session_id() + } } #[cfg(test)] @@ -238,6 +242,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use futures::FutureExt; + use crate::prelude::SessionContext; use crate::{ physical_plan::collect, test::{ @@ -250,15 +255,20 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let session_ctx = SessionContext::new(); + let blocking_exec = Arc::new(BlockingExec::new( + Arc::clone(&schema), + 1, + session_ctx.session_id.clone(), + )); let refs = blocking_exec.refs(); let analyze_exec = Arc::new(AnalyzeExec::new(true, blocking_exec, schema)); - let fut = collect(analyze_exec, runtime); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let fut = collect(analyze_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/coalesce_batches.rs b/datafusion/src/physical_plan/coalesce_batches.rs index 0d6fe38636f6..3c81f0809297 100644 --- 
a/datafusion/src/physical_plan/coalesce_batches.rs +++ b/datafusion/src/physical_plan/coalesce_batches.rs @@ -29,7 +29,7 @@ use crate::physical_plan::{ SendableRecordBatchStream, }; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use arrow::compute::kernels::concat::concat; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; @@ -124,10 +124,10 @@ impl ExecutionPlan for CoalesceBatchesExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { Ok(Box::pin(CoalesceBatchesStream { - input: self.input.execute(partition, runtime).await?, + input: self.input.execute(partition, context).await?, schema: self.input.schema(), target_batch_size: self.target_batch_size, buffer: Vec::new(), @@ -160,6 +160,10 @@ impl ExecutionPlan for CoalesceBatchesExec { fn statistics(&self) -> Statistics { self.input.statistics() } + + fn session_id(&self) -> String { + self.input.session_id() + } } struct CoalesceBatchesStream { @@ -305,6 +309,7 @@ pub fn concat_batches( mod tests { use super::*; use crate::physical_plan::{memory::MemoryExec, repartition::RepartitionExec}; + use crate::prelude::SessionContext; use crate::test::create_vec_batches; use arrow::datatypes::{DataType, Field, Schema}; @@ -338,8 +343,14 @@ mod tests { input_partitions: Vec>, target_batch_size: usize, ) -> Result>> { + let session_ctx = SessionContext::new(); // create physical plan - let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; + let exec = MemoryExec::try_new( + &input_partitions, + schema.clone(), + None, + session_ctx.session_id.clone(), + )?; let exec = RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?; let exec: Arc = @@ -348,10 +359,10 @@ mod tests { // execute and collect results let output_partition_count = exec.output_partitioning().partition_count(); let mut output_partitions = Vec::with_capacity(output_partition_count); - let runtime = Arc::new(RuntimeEnv::default()); for i in 0..output_partition_count { + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); // execute this *output* partition and collect all batches - let mut stream = exec.execute(i, runtime.clone()).await?; + let mut stream = exec.execute(i, task_ctx).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); diff --git a/datafusion/src/physical_plan/coalesce_partitions.rs b/datafusion/src/physical_plan/coalesce_partitions.rs index 20b548733715..0011de8da1c3 100644 --- a/datafusion/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/src/physical_plan/coalesce_partitions.rs @@ -38,7 +38,7 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; use super::SendableRecordBatchStream; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::common::spawn_execution; use pin_project_lite::pin_project; @@ -110,7 +110,7 @@ impl ExecutionPlan for CoalescePartitionsExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { // CoalescePartitionsExec produces a single partition if 0 != partition { @@ -127,7 +127,7 @@ impl ExecutionPlan for CoalescePartitionsExec { )), 1 => { // bypass any threading / metrics if there is a single partition - self.input.execute(0, runtime).await + self.input.execute(0, context.clone()).await } _ => { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); 
@@ -150,7 +150,7 @@ impl ExecutionPlan for CoalescePartitionsExec { self.input.clone(), sender.clone(), part_i, - runtime.clone(), + context.clone(), )); } @@ -183,6 +183,10 @@ impl ExecutionPlan for CoalescePartitionsExec { fn statistics(&self) -> Statistics { self.input.statistics() } + + fn session_id(&self) -> String { + self.input.session_id() + } } pin_project! { @@ -224,18 +228,19 @@ mod tests { use crate::datasource::object_store::local::LocalFileSystem; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::{collect, common}; + use crate::prelude::SessionContext; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{self, assert_is_pending}; use crate::test_util; #[tokio::test] async fn merge() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let num_partitions = 4; let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", num_partitions)?; + let session_ctx = SessionContext::new(); let csv = CsvExec::new( FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -248,6 +253,7 @@ mod tests { }, true, b',', + session_ctx.session_id.clone(), ); // input should have 4 partitions @@ -259,7 +265,8 @@ mod tests { assert_eq!(merge.output_partitioning().partition_count(), 1); // the result should contain 4 batches (one per input partition) - let iter = merge.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let iter = merge.execute(0, task_ctx).await?; let batches = common::collect(iter).await?; assert_eq!(batches.len(), num_partitions); @@ -272,16 +279,21 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); + let session_ctx = SessionContext::new(); + let blocking_exec = Arc::new(BlockingExec::new( + Arc::clone(&schema), + 2, + session_ctx.session_id.clone(), + )); let refs = blocking_exec.refs(); let coaelesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(blocking_exec)); - let fut = collect(coaelesce_partitions_exec, runtime); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let fut = collect(coaelesce_partitions_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs index bc4400d98186..b7313b9f25fd 100644 --- a/datafusion/src/physical_plan/common.rs +++ b/datafusion/src/physical_plan/common.rs @@ -19,7 +19,7 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::metrics::MemTrackingMetrics; use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::compute::concat; @@ -176,10 +176,10 @@ pub(crate) fn spawn_execution( input: Arc, mut output: mpsc::Sender>, partition: usize, - runtime: Arc, + context: Arc, ) -> JoinHandle<()> { tokio::spawn(async move { - let mut stream = match input.execute(partition, runtime).await { + let mut stream = match input.execute(partition, context).await { Err(e) => { // If send fails, plan being torn // down, no place to send the error diff --git a/datafusion/src/physical_plan/cross_join.rs 
b/datafusion/src/physical_plan/cross_join.rs index 82ee5618f5f0..778911b787db 100644 --- a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/cross_join.rs @@ -43,7 +43,7 @@ use super::{ coalesce_batches::concat_batches, memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use log::debug; /// Data of the left side @@ -149,7 +149,7 @@ impl ExecutionPlan for CrossJoinExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { // we only want to compute the build side once let left_data = { @@ -162,7 +162,7 @@ impl ExecutionPlan for CrossJoinExec { // merge all left parts into a single stream let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0, runtime.clone()).await?; + let stream = merge.execute(0, context.clone()).await?; // Load all batches and count the rows let (batches, num_rows) = stream @@ -187,7 +187,7 @@ impl ExecutionPlan for CrossJoinExec { } }; - let stream = self.right.execute(partition, runtime.clone()).await?; + let stream = self.right.execute(partition, context).await?; if left_data.num_rows() == 0 { return Ok(Box::pin(MemoryStream::try_new( @@ -231,6 +231,10 @@ impl ExecutionPlan for CrossJoinExec { self.right.schema().fields().len(), ) } + + fn session_id(&self) -> String { + self.left.session_id() + } } /// [left/right]_col_count are required in case the column statistics are None diff --git a/datafusion/src/physical_plan/empty.rs b/datafusion/src/physical_plan/empty.rs index 045026b70ed5..4dcd2c307926 100644 --- a/datafusion/src/physical_plan/empty.rs +++ b/datafusion/src/physical_plan/empty.rs @@ -31,7 +31,7 @@ use arrow::record_batch::RecordBatch; use super::expressions::PhysicalSortExpr; use super::{common, SendableRecordBatchStream, Statistics}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; /// Execution plan for empty relation (produces no rows) @@ -41,14 +41,17 @@ pub struct EmptyExec { produce_one_row: bool, /// The schema for the produced row schema: SchemaRef, + /// Session id + session_id: String, } impl EmptyExec { /// Create a new EmptyExec - pub fn new(produce_one_row: bool, schema: SchemaRef) -> Self { + pub fn new(produce_one_row: bool, schema: SchemaRef, session_id: String) -> Self { EmptyExec { produce_one_row, schema, + session_id, } } @@ -111,6 +114,7 @@ impl ExecutionPlan for EmptyExec { 0 => Ok(Arc::new(EmptyExec::new( self.produce_one_row, self.schema.clone(), + self.session_id(), ))), _ => Err(DataFusionError::Internal( "EmptyExec wrong number of children".to_string(), @@ -121,7 +125,7 @@ impl ExecutionPlan for EmptyExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { @@ -156,23 +160,29 @@ impl ExecutionPlan for EmptyExec { .expect("Create empty RecordBatch should not fail"); common::compute_record_batch_statistics(&[batch], &self.schema, None) } + + fn session_id(&self) -> String { + self.session_id.clone() + } } #[cfg(test)] mod tests { use super::*; + use crate::prelude::SessionContext; use crate::{physical_plan::common, test_util}; #[tokio::test] async fn empty() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); - let empty = 
EmptyExec::new(false, schema.clone()); + let ctx = SessionContext::new(); + let empty = EmptyExec::new(false, schema.clone(), ctx.session_id.clone()); assert_eq!(empty.schema(), schema); // we should have no results - let iter = empty.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let iter = empty.execute(0, task_ctx).await?; let batches = common::collect(iter).await?; assert!(batches.is_empty()); @@ -182,8 +192,9 @@ mod tests { #[test] fn with_new_children() -> Result<()> { let schema = test_util::aggr_test_schema(); - let empty = EmptyExec::new(false, schema.clone()); - let empty_with_row = EmptyExec::new(true, schema); + let ctx = SessionContext::new(); + let empty = EmptyExec::new(false, schema.clone(), ctx.session_id.clone()); + let empty_with_row = EmptyExec::new(true, schema, ctx.session_id.clone()); let empty2 = empty.with_new_children(vec![])?; assert_eq!(empty.schema(), empty2.schema()); @@ -201,23 +212,25 @@ mod tests { #[tokio::test] async fn invalid_execute() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); - let empty = EmptyExec::new(false, schema); + let ctx = SessionContext::new(); + let empty = EmptyExec::new(false, schema, ctx.session_id.clone()); // ask for the wrong partition - assert!(empty.execute(1, runtime.clone()).await.is_err()); - assert!(empty.execute(20, runtime.clone()).await.is_err()); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + assert!(empty.execute(1, task_ctx).await.is_err()); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + assert!(empty.execute(20, task_ctx).await.is_err()); Ok(()) } #[tokio::test] async fn produce_one_row() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); - let empty = EmptyExec::new(true, schema); - - let iter = empty.execute(0, runtime).await?; + let ctx = SessionContext::new(); + let empty = EmptyExec::new(true, schema, ctx.session_id.clone()); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let iter = empty.execute(0, task_ctx).await?; let batches = common::collect(iter).await?; // should have one item diff --git a/datafusion/src/physical_plan/explain.rs b/datafusion/src/physical_plan/explain.rs index 0955655a1929..a6b13bad7ad0 100644 --- a/datafusion/src/physical_plan/explain.rs +++ b/datafusion/src/physical_plan/explain.rs @@ -31,7 +31,7 @@ use crate::{ use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; use super::{expressions::PhysicalSortExpr, SendableRecordBatchStream}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics}; use async_trait::async_trait; @@ -46,6 +46,8 @@ pub struct ExplainExec { stringified_plans: Vec, /// control which plans to print verbose: bool, + /// Session id + session_id: String, } impl ExplainExec { @@ -54,11 +56,13 @@ impl ExplainExec { schema: SchemaRef, stringified_plans: Vec, verbose: bool, + session_id: String, ) -> Self { ExplainExec { schema, stringified_plans, verbose, + session_id, } } @@ -114,7 +118,7 @@ impl ExecutionPlan for ExplainExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -181,6 +185,10 @@ impl ExecutionPlan for ExplainExec { // Statistics an EXPLAIN plan are not relevant Statistics::default() } + + fn session_id(&self) -> String { + 
self.session_id.clone() + } } /// If this plan should be shown, given the previous plan that was diff --git a/datafusion/src/physical_plan/file_format/avro.rs b/datafusion/src/physical_plan/file_format/avro.rs index ba0873d78b2b..384d20132f53 100644 --- a/datafusion/src/physical_plan/file_format/avro.rs +++ b/datafusion/src/physical_plan/file_format/avro.rs @@ -27,7 +27,7 @@ use arrow::datatypes::SchemaRef; #[cfg(feature = "avro")] use arrow::error::ArrowError; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; use std::any::Any; use std::sync::Arc; @@ -42,17 +42,20 @@ pub struct AvroExec { base_config: FileScanConfig, projected_statistics: Statistics, projected_schema: SchemaRef, + /// Session id + session_id: String, } impl AvroExec { /// Create a new Avro reader execution plan provided base configurations - pub fn new(base_config: FileScanConfig) -> Self { + pub fn new(base_config: FileScanConfig, session_id: String) -> Self { let (projected_schema, projected_statistics) = base_config.project(); Self { base_config, projected_schema, projected_statistics, + session_id, } } /// Ref to the base configs @@ -105,7 +108,7 @@ impl ExecutionPlan for AvroExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Err(DataFusionError::NotImplemented( "Cannot execute avro plan without avro feature enabled".to_string(), @@ -116,11 +119,11 @@ impl ExecutionPlan for AvroExec { async fn execute( &self, partition: usize, - runtime: Arc, + _context: Arc, ) -> Result { let proj = self.base_config.projected_file_column_names(); - let batch_size = runtime.batch_size(); + let batch_size = self.session_config().batch_size; let file_schema = Arc::clone(&self.base_config.file_schema); // The avro reader cannot limit the number of records, so `remaining` is ignored. @@ -169,6 +172,10 @@ impl ExecutionPlan for AvroExec { fn statistics(&self) -> Statistics { self.projected_statistics.clone() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } #[cfg(test)] diff --git a/datafusion/src/physical_plan/file_format/csv.rs b/datafusion/src/physical_plan/file_format/csv.rs index d9f4706fdf0b..1aadb61a906d 100644 --- a/datafusion/src/physical_plan/file_format/csv.rs +++ b/datafusion/src/physical_plan/file_format/csv.rs @@ -18,13 +18,14 @@ //! 
Execution plan for reading CSV files use crate::error::{DataFusionError, Result}; -use crate::execution::context::ExecutionContext; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; -use crate::execution::runtime_env::RuntimeEnv; +use super::file_stream::{BatchIter, FileStream}; +use super::FileScanConfig; +use crate::execution::context::{SessionState, TaskContext}; use arrow::csv; use arrow::datatypes::SchemaRef; use async_trait::async_trait; @@ -35,9 +36,6 @@ use std::path::Path; use std::sync::Arc; use tokio::task::{self, JoinHandle}; -use super::file_stream::{BatchIter, FileStream}; -use super::FileScanConfig; - /// Execution plan for scanning a CSV file #[derive(Debug, Clone)] pub struct CsvExec { @@ -46,11 +44,18 @@ pub struct CsvExec { projected_schema: SchemaRef, has_header: bool, delimiter: u8, + /// Session id + session_id: String, } impl CsvExec { /// Create a new CSV reader execution plan provided base and specific configurations - pub fn new(base_config: FileScanConfig, has_header: bool, delimiter: u8) -> Self { + pub fn new( + base_config: FileScanConfig, + has_header: bool, + delimiter: u8, + session_id: String, + ) -> Self { let (projected_schema, projected_statistics) = base_config.project(); Self { @@ -59,6 +64,7 @@ impl CsvExec { projected_statistics, has_header, delimiter, + session_id, } } @@ -123,9 +129,9 @@ impl ExecutionPlan for CsvExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { - let batch_size = runtime.batch_size(); + let batch_size = context.session_config().batch_size; let file_schema = Arc::clone(&self.base_config.file_schema); let file_projection = self.base_config.file_column_projection_indices(); let has_header = self.has_header; @@ -178,17 +184,20 @@ impl ExecutionPlan for CsvExec { fn statistics(&self) -> Statistics { self.projected_statistics.clone() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } pub async fn plan_to_csv( - context: &ExecutionContext, + state: &SessionState, plan: Arc, path: impl AsRef, ) -> Result<()> { let path = path.as_ref(); // create directory to contain the CSV files (one per partition) let fs_path = Path::new(path); - let runtime = context.runtime_env(); match fs::create_dir(fs_path) { Ok(()) => { let mut tasks = vec![]; @@ -198,7 +207,8 @@ pub async fn plan_to_csv( let path = fs_path.join(&filename); let file = fs::File::create(path)?; let mut writer = csv::Writer::new(file); - let stream = plan.execute(i, runtime.clone()).await?; + let task_ctx = Arc::new(TaskContext::from(state)); + let stream = plan.execute(i, task_ctx).await?; let handle: JoinHandle> = task::spawn(async move { stream .map(|batch| writer.write(&batch?)) @@ -236,11 +246,11 @@ mod tests { #[tokio::test] async fn csv_exec_with_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; let path = format!("{}/csv/{}", testdata, filename); + let ctx = SessionContext::new(); let csv = CsvExec::new( FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -253,12 +263,13 @@ mod tests { }, true, b',', + ctx.session_id.clone(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(3, csv.projected_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); - - let mut stream = csv.execute(0, 
runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut stream = csv.execute(0, task_ctx).await?; let batch = stream.next().await.unwrap()?; assert_eq!(3, batch.num_columns()); assert_eq!(100, batch.num_rows()); @@ -282,11 +293,11 @@ mod tests { #[tokio::test] async fn csv_exec_with_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; let path = format!("{}/csv/{}", testdata, filename); + let ctx = SessionContext::new(); let csv = CsvExec::new( FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -299,12 +310,13 @@ mod tests { }, true, b',', + ctx.session_id.clone(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(13, csv.projected_schema.fields().len()); assert_eq!(13, csv.schema().fields().len()); - - let mut it = csv.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut it = csv.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(13, batch.num_columns()); assert_eq!(5, batch.num_rows()); @@ -328,11 +340,11 @@ mod tests { #[tokio::test] async fn csv_exec_with_missing_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema_with_missing_col(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; let path = format!("{}/csv/{}", testdata, filename); + let ctx = SessionContext::new(); let csv = CsvExec::new( FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -345,12 +357,13 @@ mod tests { }, true, b',', + ctx.session_id.clone(), ); assert_eq!(14, csv.base_config.file_schema.fields().len()); assert_eq!(14, csv.projected_schema.fields().len()); assert_eq!(14, csv.schema().fields().len()); - - let mut it = csv.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut it = csv.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(14, batch.num_columns()); assert_eq!(5, batch.num_rows()); @@ -374,7 +387,6 @@ mod tests { #[tokio::test] async fn csv_exec_with_partition() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; @@ -384,6 +396,7 @@ mod tests { let mut partitioned_file = local_unpartitioned_file(path); partitioned_file.partition_values = vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + let ctx = SessionContext::new(); let csv = CsvExec::new( FileScanConfig { // we should be able to project on the partition column @@ -398,12 +411,14 @@ mod tests { }, true, b',', + ctx.session_id.clone(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(2, csv.projected_schema.fields().len()); assert_eq!(2, csv.schema().fields().len()); - let mut it = csv.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut it = csv.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(2, batch.num_columns()); assert_eq!(100, batch.num_rows()); @@ -457,9 +472,8 @@ mod tests { async fn write_csv_results() -> Result<()> { // create partitioned input file and context let tmp_dir = TempDir::new()?; - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_target_partitions(8), - ); + let ctx = + 
SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(&tmp_dir, 8, ".csv")?; @@ -477,7 +491,7 @@ mod tests { df.write_csv(&out_dir).await?; // create a new context and verify that the results were saved to a partitioned csv file - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::UInt32, false), diff --git a/datafusion/src/physical_plan/file_format/json.rs b/datafusion/src/physical_plan/file_format/json.rs index 6c5ffcd99eac..ea2fe805b568 100644 --- a/datafusion/src/physical_plan/file_format/json.rs +++ b/datafusion/src/physical_plan/file_format/json.rs @@ -19,7 +19,7 @@ use async_trait::async_trait; use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, @@ -37,17 +37,20 @@ pub struct NdJsonExec { base_config: FileScanConfig, projected_statistics: Statistics, projected_schema: SchemaRef, + /// Session id + session_id: String, } impl NdJsonExec { /// Create a new JSON reader execution plan provided base configurations - pub fn new(base_config: FileScanConfig) -> Self { + pub fn new(base_config: FileScanConfig, session_id: String) -> Self { let (projected_schema, projected_statistics) = base_config.project(); Self { base_config, projected_schema, projected_statistics, + session_id, } } } @@ -95,11 +98,11 @@ impl ExecutionPlan for NdJsonExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let proj = self.base_config.projected_file_column_names(); - let batch_size = runtime.batch_size(); + let batch_size = context.session_config().batch_size; let file_schema = Arc::clone(&self.base_config.file_schema); // The json reader cannot limit the number of records, so `remaining` is ignored. 
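Note (illustrative sketch, not part of the patch): the recurring pattern in these hunks is that `execute()` now receives an `Arc<TaskContext>` derived from the owning `SessionContext` instead of a shared `RuntimeEnv`, and per-session settings such as `batch_size` come from the session config carried by that context. Assuming `plan` is an already constructed `Arc<dyn ExecutionPlan>`, driving it from crate-internal code looks roughly like:

    use std::sync::Arc;
    use crate::execution::context::TaskContext;
    use crate::prelude::SessionContext;

    let ctx = SessionContext::new();
    let task_ctx = Arc::new(TaskContext::from(&ctx));
    // the same value an operator reads inside execute() as context.session_config().batch_size
    println!("batch size: {}", task_ctx.session_config().batch_size);
    let batches = crate::physical_plan::collect(plan, task_ctx).await?;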
@@ -142,6 +145,10 @@ impl ExecutionPlan for NdJsonExec { fn statistics(&self) -> Statistics { self.projected_statistics.clone() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } #[cfg(test)] @@ -156,6 +163,7 @@ mod tests { local_object_reader_stream, local_unpartitioned_file, LocalFileSystem, }, }; + use crate::prelude::SessionContext; use super::*; @@ -169,18 +177,21 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_without_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); use arrow::datatypes::DataType; let path = format!("{}/1.json", TEST_DATA_BASE); - let exec = NdJsonExec::new(FileScanConfig { - object_store: Arc::new(LocalFileSystem {}), - file_groups: vec![vec![local_unpartitioned_file(path.clone())]], - file_schema: infer_schema(path).await?, - statistics: Statistics::default(), - projection: None, - limit: Some(3), - table_partition_cols: vec![], - }); + let ctx = SessionContext::new(); + let exec = NdJsonExec::new( + FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_groups: vec![vec![local_unpartitioned_file(path.clone())]], + file_schema: infer_schema(path).await?, + statistics: Statistics::default(), + projection: None, + limit: Some(3), + table_partition_cols: vec![], + }, + ctx.session_id.clone(), + ); // TODO: this is not where schema inference should be tested @@ -206,7 +217,8 @@ mod tests { &DataType::Utf8 ); - let mut it = exec.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut it = exec.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 3); @@ -224,7 +236,6 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_with_missing_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); use arrow::datatypes::DataType; let path = format!("{}/1.json", TEST_DATA_BASE); @@ -236,17 +247,22 @@ mod tests { let file_schema = Arc::new(Schema::new(fields)); - let exec = NdJsonExec::new(FileScanConfig { - object_store: Arc::new(LocalFileSystem {}), - file_groups: vec![vec![local_unpartitioned_file(path.clone())]], - file_schema, - statistics: Statistics::default(), - projection: None, - limit: Some(3), - table_partition_cols: vec![], - }); - - let mut it = exec.execute(0, runtime).await?; + let ctx = SessionContext::new(); + let exec = NdJsonExec::new( + FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_groups: vec![vec![local_unpartitioned_file(path.clone())]], + file_schema, + statistics: Statistics::default(), + projection: None, + limit: Some(3), + table_partition_cols: vec![], + }, + ctx.session_id.clone(), + ); + + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut it = exec.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 3); @@ -265,17 +281,20 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let path = format!("{}/1.json", TEST_DATA_BASE); - let exec = NdJsonExec::new(FileScanConfig { - object_store: Arc::new(LocalFileSystem {}), - file_groups: vec![vec![local_unpartitioned_file(path.clone())]], - file_schema: infer_schema(path).await?, - statistics: Statistics::default(), - projection: Some(vec![0, 2]), - limit: None, - table_partition_cols: vec![], - }); + let ctx = SessionContext::new(); + let exec = NdJsonExec::new( + FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_groups: 
vec![vec![local_unpartitioned_file(path.clone())]], + file_schema: infer_schema(path).await?, + statistics: Statistics::default(), + projection: Some(vec![0, 2]), + limit: None, + table_partition_cols: vec![], + }, + ctx.session_id.clone(), + ); let inferred_schema = exec.schema(); assert_eq!(inferred_schema.fields().len(), 2); @@ -284,7 +303,8 @@ mod tests { inferred_schema.field_with_name("c").unwrap(); inferred_schema.field_with_name("d").unwrap_err(); - let mut it = exec.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut it = exec.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 4); diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index 2d23ca1c3ada..ad55abba7c52 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -27,7 +27,6 @@ use std::{any::Any, convert::TryInto}; use crate::datasource::file_format::parquet::ChunkObjectReader; use crate::datasource::object_store::ObjectStore; use crate::datasource::PartitionedFile; -use crate::execution::context::ExecutionContext; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::{ error::{DataFusionError, Result}, @@ -68,7 +67,7 @@ use tokio::{ task, }; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::{SessionState, TaskContext}; use crate::physical_plan::file_format::SchemaAdapter; use async_trait::async_trait; @@ -84,6 +83,8 @@ pub struct ParquetExec { metrics: ExecutionPlanMetricsSet, /// Optional predicate for pruning row groups pruning_predicate: Option, + /// Session id + session_id: String, } /// Stores metrics about the parquet execution for a particular parquet file @@ -98,7 +99,11 @@ struct ParquetFileMetrics { impl ParquetExec { /// Create a new Parquet reader execution plan provided file list and schema. /// Even if `limit` is set, ParquetExec rounds up the number of records to the next `batch_size`. 
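Note (illustrative sketch, not part of the patch): each leaf scan operator touched here (AvroExec, CsvExec, NdJsonExec, ParquetExec, MemoryExec) now stores the id of the session that created it, so its constructor gains a trailing `session_id: String` argument and the value is reported back through the new `ExecutionPlan::session_id()` method. Assuming `scan_config` is a `FileScanConfig` built as in the tests further down:

    let ctx = SessionContext::new();
    let parquet_exec = ParquetExec::new(scan_config, None, ctx.session_id.clone());
    assert_eq!(parquet_exec.session_id(), ctx.session_id);

Operators that merely wrap a child (FilterExec, the limit operators, HashAggregateExec, HashJoinExec) do not store the id themselves; their `session_id()` forwards to the child, e.g. `self.input.session_id()`, as the later hunks show.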
- pub fn new(base_config: FileScanConfig, predicate: Option) -> Self { + pub fn new( + base_config: FileScanConfig, + predicate: Option, + session_id: String, + ) -> Self { debug!("Creating ParquetExec, files: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", base_config.file_groups, base_config.projection, predicate, base_config.limit); @@ -128,6 +133,7 @@ impl ParquetExec { projected_statistics, metrics, pruning_predicate, + session_id, } } @@ -210,7 +216,7 @@ impl ExecutionPlan for ParquetExec { async fn execute( &self, partition_index: usize, - runtime: Arc, + context: Arc, ) -> Result { // because the parquet implementation is not thread-safe, it is necessary to execute // on a thread and communicate with channels @@ -226,7 +232,7 @@ impl ExecutionPlan for ParquetExec { None => (0..self.base_config.file_schema.fields().len()).collect(), }; let pruning_predicate = self.pruning_predicate.clone(); - let batch_size = runtime.batch_size(); + let batch_size = context.session_config().batch_size; let limit = self.base_config.limit; let object_store = Arc::clone(&self.base_config.object_store); let partition_col_proj = PartitionColumnProjector::new( @@ -294,6 +300,10 @@ impl ExecutionPlan for ParquetExec { fn statistics(&self) -> Statistics { self.projected_statistics.clone() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } fn send_result( @@ -528,7 +538,7 @@ fn read_partition( /// Executes a query and writes the results to a partitioned Parquet file. pub async fn plan_to_parquet( - context: &ExecutionContext, + state: &SessionState, plan: Arc, path: impl AsRef, writer_properties: Option, @@ -536,7 +546,6 @@ pub async fn plan_to_parquet( let path = path.as_ref(); // create directory to contain the Parquet files (one per partition) let fs_path = Path::new(path); - let runtime = context.runtime_env(); match fs::create_dir(fs_path) { Ok(()) => { let mut tasks = vec![]; @@ -550,7 +559,8 @@ pub async fn plan_to_parquet( plan.schema(), writer_properties.clone(), )?; - let stream = plan.execute(i, runtime.clone()).await?; + let task_ctx = Arc::new(TaskContext::from(state)); + let stream = plan.execute(i, task_ctx).await?; let handle: JoinHandle> = task::spawn(async move { stream .map(|batch| writer.write(&batch?)) @@ -589,7 +599,7 @@ mod tests { use super::*; use crate::execution::options::CsvReadOptions; - use crate::prelude::ExecutionConfig; + use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::Float32Array; use arrow::{ array::{Int64Array, Int8Array, StringArray}, @@ -655,6 +665,7 @@ mod tests { .expect("inferring schema"), }; + let session_ctx = SessionContext::new(); // prepare the scan let parquet_exec = ParquetExec::new( FileScanConfig { @@ -667,10 +678,10 @@ mod tests { table_partition_cols: vec![], }, None, + session_ctx.session_id.clone(), ); - - let runtime = Arc::new(RuntimeEnv::default()); - collect(Arc::new(parquet_exec), runtime).await + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + collect(Arc::new(parquet_exec), task_ctx).await } // Add a new column with the specified field name to the RecordBatch @@ -887,9 +898,9 @@ mod tests { #[tokio::test] async fn parquet_exec_with_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); + let ctx = SessionContext::new(); let parquet_exec = ParquetExec::new( FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -903,10 +914,12 @@ mod tests { 
table_partition_cols: vec![], }, None, + ctx.session_id.clone(), ); assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); - let mut results = parquet_exec.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut results = parquet_exec.execute(0, task_ctx).await?; let batch = results.next().await.unwrap()?; assert_eq!(8, batch.num_rows()); @@ -931,7 +944,6 @@ mod tests { #[tokio::test] async fn parquet_exec_with_partition() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); let mut partitioned_file = local_unpartitioned_file(filename.clone()); @@ -940,6 +952,7 @@ mod tests { ScalarValue::Utf8(Some("10".to_owned())), ScalarValue::Utf8(Some("26".to_owned())), ]; + let ctx = SessionContext::new(); let parquet_exec = ParquetExec::new( FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -958,10 +971,12 @@ mod tests { ], }, None, + ctx.session_id.clone(), ); assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); - let mut results = parquet_exec.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut results = parquet_exec.execute(0, task_ctx).await?; let batch = results.next().await.unwrap()?; let expected = vec![ "+----+----------+-------------+-------+", @@ -987,7 +1002,6 @@ mod tests { #[tokio::test] async fn parquet_exec_with_error() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); let partitioned_file = PartitionedFile { @@ -1001,6 +1015,7 @@ mod tests { partition_values: vec![], }; + let ctx = SessionContext::new(); let parquet_exec = ParquetExec::new( FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -1014,9 +1029,11 @@ mod tests { table_partition_cols: vec![], }, None, + ctx.session_id.clone(), ); - let mut results = parquet_exec.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut results = parquet_exec.execute(0, task_ctx.clone()).await?; let batch = results.next().await.unwrap(); // invalid file should produce an error to that effect assert_contains!( @@ -1318,9 +1335,8 @@ mod tests { // create partitioned input file and context let tmp_dir = TempDir::new()?; // let mut ctx = create_ctx(&tmp_dir, 4).await?; - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_target_partitions(8), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?; // register csv file with the execution context ctx.register_csv( @@ -1337,7 +1353,7 @@ mod tests { // write_parquet(&mut ctx, "SELECT c1, c2 FROM test", &out_dir, None).await?; // create a new context and verify that the results were saved to a partitioned csv file - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // register each partition as well as the top level dir ctx.register_parquet("part0", &format!("{}/part-0.parquet", out_dir)) diff --git a/datafusion/src/physical_plan/filter.rs b/datafusion/src/physical_plan/filter.rs index 69ff6bfc995b..9a3c5743568f 100644 --- a/datafusion/src/physical_plan/filter.rs +++ b/datafusion/src/physical_plan/filter.rs @@ -38,7 +38,7 @@ use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use crate::execution::runtime_env::RuntimeEnv; 
+use crate::execution::context::TaskContext; use futures::stream::{Stream, StreamExt}; /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to @@ -136,14 +136,14 @@ impl ExecutionPlan for FilterExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); Ok(Box::pin(FilterExecStream { schema: self.input.schema().clone(), predicate: self.predicate.clone(), - input: self.input.execute(partition, runtime).await?, + input: self.input.execute(partition, context).await?, baseline_metrics, })) } @@ -168,6 +168,10 @@ impl ExecutionPlan for FilterExec { fn statistics(&self) -> Statistics { Statistics::default() } + + fn session_id(&self) -> String { + self.input.session_id() + } } /// The FilterExec streams wraps the input iterator and applies the predicate expression to @@ -246,6 +250,7 @@ mod tests { use crate::physical_plan::expressions::*; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; + use crate::prelude::SessionContext; use crate::scalar::ScalarValue; use crate::test; use crate::test_util; @@ -254,12 +259,12 @@ mod tests { #[tokio::test] async fn simple_predicate() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; + let ctx = SessionContext::new(); let csv = CsvExec::new( FileScanConfig { @@ -273,6 +278,7 @@ mod tests { }, true, b',', + ctx.session_id.clone(), ); let predicate: Arc = binary( @@ -295,7 +301,8 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, Arc::new(csv))?); - let results = collect(filter, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let results = collect(filter, task_ctx).await?; results .iter() diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 33d3bccbba53..f3b2d0df3e3f 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -48,7 +48,7 @@ use arrow::{ use hashbrown::raw::RawTable; use pin_project_lite::pin_project; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; use super::common::AbortOnDropSingle; @@ -233,9 +233,9 @@ impl ExecutionPlan for HashAggregateExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { - let input = self.input.execute(partition, runtime).await?; + let input = self.input.execute(partition, context).await?; let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect(); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -334,6 +334,10 @@ impl ExecutionPlan for HashAggregateExec { _ => Statistics::default(), } } + + fn session_id(&self) -> String { + self.input.session_id() + } } /* @@ -1030,6 +1034,7 @@ mod tests { use futures::FutureExt; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use crate::prelude::SessionContext; /// some mock data to aggregates fn some_data() -> (Arc, Vec) { @@ -1064,7 +1069,10 @@ mod tests { } /// build the aggregates on the data from some_data() and check the results - async fn check_aggregates(input: Arc) -> Result<()> { + async fn check_aggregates( + input: Arc, + session_ctx: &SessionContext, + ) -> Result<()> { let 
input_schema = input.schema(); let groups: Vec<(Arc, String)> = @@ -1076,8 +1084,6 @@ mod tests { DataType::Float64, ))]; - let runtime = Arc::new(RuntimeEnv::default()); - let partial_aggregate = Arc::new(HashAggregateExec::try_new( AggregateMode::Partial, groups.clone(), @@ -1085,9 +1091,9 @@ mod tests { input, input_schema.clone(), )?); - + let task_ctx = Arc::new(TaskContext::from(session_ctx)); let result = - common::collect(partial_aggregate.execute(0, runtime.clone()).await?).await?; + common::collect(partial_aggregate.execute(0, task_ctx).await?).await?; let expected = vec![ "+---+---------------+-------------+", @@ -1118,8 +1124,9 @@ mod tests { input_schema, )?); + let task_ctx = Arc::new(TaskContext::from(session_ctx)); let result = - common::collect(merged_aggregate.execute(0, runtime.clone()).await?).await?; + common::collect(merged_aggregate.execute(0, task_ctx).await?).await?; assert_eq!(result.len(), 1); let batch = &result[0]; @@ -1151,6 +1158,8 @@ mod tests { struct TestYieldingExec { /// True if this exec should yield back to runtime the first time it is polled pub yield_first: bool, + /// Session id + session_id: String, } #[async_trait] @@ -1187,7 +1196,7 @@ mod tests { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { let stream = if self.yield_first { TestYieldingStream::New @@ -1202,6 +1211,10 @@ mod tests { let (_, batches) = some_data(); common::compute_record_batch_statistics(&[batches], &self.schema(), None) } + + fn session_id(&self) -> String { + self.session_id.clone() + } } /// A stream using the demo data. If inited as new, it will first yield to runtime before returning records @@ -1248,23 +1261,28 @@ mod tests { #[tokio::test] async fn aggregate_source_not_yielding() -> Result<()> { - let input: Arc = - Arc::new(TestYieldingExec { yield_first: false }); + let session_ctx = SessionContext::new(); + let input: Arc = Arc::new(TestYieldingExec { + yield_first: false, + session_id: session_ctx.session_id.clone(), + }); - check_aggregates(input).await + check_aggregates(input, &session_ctx).await } #[tokio::test] async fn aggregate_source_with_yielding() -> Result<()> { - let input: Arc = - Arc::new(TestYieldingExec { yield_first: true }); + let session_ctx = SessionContext::new(); + let input: Arc = Arc::new(TestYieldingExec { + yield_first: true, + session_id: session_ctx.session_id.clone(), + }); - check_aggregates(input).await + check_aggregates(input, &session_ctx).await } #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1276,7 +1294,12 @@ mod tests { DataType::Float64, ))]; - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let session_ctx = SessionContext::new(); + let blocking_exec = Arc::new(BlockingExec::new( + Arc::clone(&schema), + 1, + session_ctx.session_id.clone(), + )); let refs = blocking_exec.refs(); let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new( AggregateMode::Partial, @@ -1286,7 +1309,8 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let fut = crate::physical_plan::collect(hash_aggregate_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1298,7 +1322,6 @@ mod tests { #[tokio::test] async fn test_drop_cancel_with_groups() -> Result<()> { - let runtime = 
Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float32, true), @@ -1313,7 +1336,12 @@ mod tests { DataType::Float64, ))]; - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let session_ctx = SessionContext::new(); + let blocking_exec = Arc::new(BlockingExec::new( + Arc::clone(&schema), + 1, + session_ctx.session_id.clone(), + )); let refs = blocking_exec.refs(); let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new( AggregateMode::Partial, @@ -1323,7 +1351,8 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let fut = crate::physical_plan::collect(hash_aggregate_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index d276ac2e72de..b3531fb19a26 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -71,7 +71,7 @@ use super::{ }; use crate::arrow::array::BooleanBufferBuilder; use crate::arrow::datatypes::TimeUnit; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::PhysicalExpr; use log::debug; @@ -290,7 +290,7 @@ impl ExecutionPlan for HashJoinExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); // we only want to compute the build side once for PartitionMode::CollectLeft @@ -306,7 +306,7 @@ impl ExecutionPlan for HashJoinExec { // merge all left parts into a single stream let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0, runtime.clone()).await?; + let stream = merge.execute(0, context.clone()).await?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -359,7 +359,7 @@ impl ExecutionPlan for HashJoinExec { let start = Instant::now(); // Load 1 partition of left side in memory - let stream = self.left.execute(partition, runtime.clone()).await?; + let stream = self.left.execute(partition, context.clone()).await?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -410,7 +410,7 @@ impl ExecutionPlan for HashJoinExec { // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. 
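Note (illustrative sketch, not part of the patch): for multi-child operators the single `Arc<TaskContext>` is simply cloned for each child, replacing the old `runtime.clone()`, and the join tests further below thread a `&SessionContext` through their table-building helpers so in-memory inputs are tagged with the same session as the plan that consumes them. That helper pattern reduces to roughly:

    let session_ctx = SessionContext::new();
    let batch = build_table_i32(("a", &vec![1, 2]), ("b", &vec![3, 4]), ("c", &vec![5, 6]));
    let schema = batch.schema();
    // in-memory input carries the session id of the context that will execute it
    let input = Arc::new(MemoryExec::try_new(
        &[vec![batch]],
        schema,
        None,
        session_ctx.session_id.clone(),
    )?);
    let task_ctx = Arc::new(TaskContext::from(&session_ctx));
    let stream = input.execute(0, task_ctx).await?;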
- let right_stream = self.right.execute(partition, runtime.clone()).await?; + let right_stream = self.right.execute(partition, context).await?; let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); let num_rows = left_data.1.num_rows(); @@ -465,6 +465,10 @@ impl ExecutionPlan for HashJoinExec { // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` Statistics::default() } + + fn session_id(&self) -> String { + self.left.session_id() + } } /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, @@ -1063,16 +1067,26 @@ mod tests { }; use super::*; + use crate::prelude::SessionContext; use std::sync::Arc; fn build_table( a: (&str, &Vec), b: (&str, &Vec), c: (&str, &Vec), + session_ctx: &SessionContext, ) -> Arc { let batch = build_table_i32(a, b, c); let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + Arc::new( + MemoryExec::try_new( + &[vec![batch]], + schema, + None, + session_ctx.session_id.clone(), + ) + .unwrap(), + ) } fn join( @@ -1098,12 +1112,13 @@ mod tests { on: JoinOn, join_type: &JoinType, null_equals_null: bool, - runtime: Arc, + session_ctx: &SessionContext, ) -> Result<(Vec, Vec)> { let join = join(left, right, on, join_type, null_equals_null)?; let columns = columns(&join.schema()); + let task_ctx = Arc::new(TaskContext::from(session_ctx)); - let stream = join.execute(0, runtime).await?; + let stream = join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; Ok((columns, batches)) @@ -1115,7 +1130,7 @@ mod tests { on: JoinOn, join_type: &JoinType, null_equals_null: bool, - runtime: Arc, + session_ctx: &SessionContext, ) -> Result<(Vec, Vec)> { let partition_count = 4; @@ -1148,7 +1163,8 @@ mod tests { let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i, runtime.clone()).await?; + let task_ctx = Arc::new(TaskContext::from(session_ctx)); + let stream = join.execute(i, task_ctx).await?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -1163,16 +1179,18 @@ mod tests { #[tokio::test] async fn join_inner_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( @@ -1186,7 +1204,7 @@ mod tests { on.clone(), &JoinType::Inner, false, - runtime, + &session_ctx, ) .await?; @@ -1208,16 +1226,18 @@ mod tests { #[tokio::test] async fn partitioned_join_inner_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema())?, @@ -1230,7 +1250,7 @@ mod tests { on.clone(), &JoinType::Inner, false, - runtime, + &session_ctx, ) .await?; @@ -1252,16 +1272,18 @@ mod tests { #[tokio::test] async fn join_inner_one_no_shared_column_names() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), 
("b1", &vec![4, 5, 5]), // this has a repetition ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30]), ("b2", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema())?, @@ -1269,7 +1291,7 @@ mod tests { )]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; + join_collect(left, right, on, &JoinType::Inner, false, &session_ctx).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1290,16 +1312,18 @@ mod tests { #[tokio::test] async fn join_inner_two() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a1", &vec![1, 2, 3]), ("b2", &vec![1, 2, 2]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![ ( @@ -1313,7 +1337,7 @@ mod tests { ]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; + join_collect(left, right, on, &JoinType::Inner, false, &session_ctx).await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -1337,7 +1361,6 @@ mod tests { /// Test where the left has 2 parts, the right with 1 part => 1 part #[tokio::test] async fn join_inner_one_two_parts_left() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -1346,14 +1369,22 @@ mod tests { let batch2 = build_table_i32(("a1", &vec![2]), ("b2", &vec![2]), ("c1", &vec![9])); let schema = batch1.schema(); + let session_ctx = SessionContext::new(); let left = Arc::new( - MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + MemoryExec::try_new( + &[vec![batch1], vec![batch2]], + schema, + None, + session_ctx.session_id.clone(), + ) + .unwrap(), ); let right = build_table( ("a1", &vec![1, 2, 3]), ("b2", &vec![1, 2, 2]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![ ( @@ -1367,7 +1398,7 @@ mod tests { ]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; + join_collect(left, right, on, &JoinType::Inner, false, &session_ctx).await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -1391,11 +1422,12 @@ mod tests { /// Test where the left has 1 part, the right has 2 parts => 2 parts #[tokio::test] async fn join_inner_one_two_parts_right() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition ("c1", &vec![7, 8, 9]), + &session_ctx, ); let batch1 = build_table_i32( @@ -1407,7 +1439,13 @@ mod tests { build_table_i32(("a2", &vec![30]), ("b1", &vec![5]), ("c2", &vec![90])); let schema = batch1.schema(); let right = Arc::new( - MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + MemoryExec::try_new( + &[vec![batch1], vec![batch2]], + schema, + None, + session_ctx.session_id.clone(), + ) + .unwrap(), ); let on = vec![( @@ -1421,7 +1459,8 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); // first part - let stream = join.execute(0, runtime.clone()).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let stream = join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; 
assert_eq!(batches.len(), 1); @@ -1435,7 +1474,8 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); // second part - let stream = join.execute(1, runtime.clone()).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let stream = join.execute(1, task_ctx).await?; let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -1456,26 +1496,35 @@ mod tests { a: (&str, &Vec), b: (&str, &Vec), c: (&str, &Vec), + session_ctx: &SessionContext, ) -> Arc { let batch = build_table_i32(a, b, c); let schema = batch.schema(); Arc::new( - MemoryExec::try_new(&[vec![batch.clone(), batch]], schema, None).unwrap(), + MemoryExec::try_new( + &[vec![batch.clone(), batch]], + schema, + None, + session_ctx.session_id.clone(), + ) + .unwrap(), ) } #[tokio::test] async fn join_left_multi_batch() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table_two_batches( ("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema()).unwrap(), @@ -1487,7 +1536,8 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let stream = join.execute(0, task_ctx).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1507,17 +1557,19 @@ mod tests { #[tokio::test] async fn join_full_multi_batch() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right ("c1", &vec![7, 8, 9]), + &session_ctx, ); // create two identical batches for the right side let right = build_table_two_batches( ("a2", &vec![10, 20, 30]), ("b2", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema()).unwrap(), @@ -1529,7 +1581,8 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let stream = join.execute(0, task_ctx).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1551,11 +1604,12 @@ mod tests { #[tokio::test] async fn join_left_empty_right() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); let on = vec![( @@ -1563,13 +1617,22 @@ mod tests { Column::new_with_schema("b1", &right.schema()).unwrap(), )]; let schema = right.schema(); - let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); + let right = Arc::new( + MemoryExec::try_new( + &[vec![right]], + schema, + None, + session_ctx.session_id.clone(), + ) + .unwrap(), + ); let join = join(left, right, on, &JoinType::Left, false).unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); 
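Note (illustrative sketch, not part of the patch): when a plan has several output partitions, as in the `partitioned_join_collect` helper above, each partition is executed with its own `TaskContext` derived from the same `SessionContext`:

    let session_ctx = SessionContext::new();
    let mut batches = vec![];
    // `join` is assumed to be an Arc<dyn ExecutionPlan> built as in the surrounding tests
    for i in 0..join.output_partitioning().partition_count() {
        let task_ctx = Arc::new(TaskContext::from(&session_ctx));
        batches.extend(common::collect(join.execute(i, task_ctx).await?).await?);
    }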
- let stream = join.execute(0, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let stream = join.execute(0, task_ctx).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1587,11 +1650,12 @@ mod tests { #[tokio::test] async fn join_full_empty_right() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![])); let on = vec![( @@ -1599,13 +1663,22 @@ mod tests { Column::new_with_schema("b2", &right.schema()).unwrap(), )]; let schema = right.schema(); - let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); + let right = Arc::new( + MemoryExec::try_new( + &[vec![right]], + schema, + None, + session_ctx.session_id.clone(), + ) + .unwrap(), + ); let join = join(left, right, on, &JoinType::Full, false).unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let stream = join.execute(0, task_ctx).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1623,16 +1696,18 @@ mod tests { #[tokio::test] async fn join_left_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema())?, @@ -1645,7 +1720,7 @@ mod tests { on.clone(), &JoinType::Left, false, - runtime, + &session_ctx, ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1666,16 +1741,18 @@ mod tests { #[tokio::test] async fn partitioned_join_left_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema())?, @@ -1688,7 +1765,7 @@ mod tests { on.clone(), &JoinType::Left, false, - runtime, + &session_ctx, ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1709,16 +1786,18 @@ mod tests { #[tokio::test] async fn join_semi() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 2, 3]), ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right ("c1", &vec![7, 8, 8, 9]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30, 40]), ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right ("c2", &vec![70, 80, 90, 100]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema())?, @@ -1730,7 +1809,8 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); - let stream = join.execute(0, runtime).await?; + let task_ctx = 
Arc::new(TaskContext::from(&session_ctx)); + let stream = join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1749,16 +1829,18 @@ mod tests { #[tokio::test] async fn join_anti() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 2, 3, 5]), ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right ("c1", &vec![7, 8, 8, 9, 11]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30, 40]), ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right ("c2", &vec![70, 80, 90, 100]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema())?, @@ -1770,7 +1852,8 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); - let stream = join.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let stream = join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1787,16 +1870,18 @@ mod tests { #[tokio::test] async fn join_right_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]), // 6 does not exist on the left ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema())?, @@ -1804,7 +1889,7 @@ mod tests { )]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Right, false, runtime).await?; + join_collect(left, right, on, &JoinType::Right, false, &session_ctx).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1825,25 +1910,33 @@ mod tests { #[tokio::test] async fn partitioned_join_right_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]), // 6 does not exist on the left ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema())?, Column::new_with_schema("b1", &right.schema())?, )]; - let (columns, batches) = - partitioned_join_collect(left, right, on, &JoinType::Right, false, runtime) - .await?; + let (columns, batches) = partitioned_join_collect( + left, + right, + on, + &JoinType::Right, + false, + &session_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1864,16 +1957,18 @@ mod tests { #[tokio::test] async fn join_full_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right ("c1", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a2", &vec![10, 20, 30]), ("b2", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( Column::new_with_schema("b1", &left.schema()).unwrap(), @@ -1885,7 +1980,8 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let stream = 
join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1955,16 +2051,18 @@ mod tests { #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let left = build_table( ("a", &vec![1, 2, 3]), ("b", &vec![4, 5, 7]), ("c", &vec![7, 8, 9]), + &session_ctx, ); let right = build_table( ("a", &vec![10, 20, 30]), ("b", &vec![1, 2, 7]), ("c", &vec![70, 80, 90]), + &session_ctx, ); let on = vec![( // join on a=b so there are duplicate column names on unjoined columns @@ -1977,7 +2075,8 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); - let stream = join.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let stream = join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; let expected = vec![ diff --git a/datafusion/src/physical_plan/limit.rs b/datafusion/src/physical_plan/limit.rs index f150c5601294..884fe750c444 100644 --- a/datafusion/src/physical_plan/limit.rs +++ b/datafusion/src/physical_plan/limit.rs @@ -41,7 +41,7 @@ use super::{ RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; /// Limit execution plan @@ -134,7 +134,7 @@ impl ExecutionPlan for GlobalLimitExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { @@ -152,7 +152,7 @@ impl ExecutionPlan for GlobalLimitExec { } let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let stream = self.input.execute(0, runtime).await?; + let stream = self.input.execute(0, context).await?; Ok(Box::pin(LimitStream::new( stream, self.limit, @@ -196,6 +196,10 @@ impl ExecutionPlan for GlobalLimitExec { _ => Statistics::default(), } } + + fn session_id(&self) -> String { + self.input.session_id() + } } /// LocalLimitExec applies a limit to a single partition @@ -285,10 +289,10 @@ impl ExecutionPlan for LocalLimitExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let stream = self.input.execute(partition, runtime).await?; + let stream = self.input.execute(partition, context).await?; Ok(Box::pin(LimitStream::new( stream, self.limit, @@ -335,6 +339,10 @@ impl ExecutionPlan for LocalLimitExec { _ => Statistics::default(), } } + + fn session_id(&self) -> String { + self.input.session_id() + } } /// Truncate a RecordBatch to maximum of n rows @@ -432,16 +440,17 @@ mod tests { use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; + use crate::prelude::SessionContext; use crate::{test, test_util}; #[tokio::test] async fn limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let num_partitions = 4; let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", num_partitions)?; + let ctx = SessionContext::new(); let csv = CsvExec::new( FileScanConfig { @@ -455,6 +464,7 @@ mod tests { }, true, b',', + ctx.session_id.clone(), ); // input should have 4 partitions @@ -464,7 +474,8 @@ mod tests { 
GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), 7); // the result should contain 4 batches (one per input partition) - let iter = limit.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let iter = limit.execute(0, task_ctx).await?; let batches = common::collect(iter).await?; // there should be a total of 100 rows diff --git a/datafusion/src/physical_plan/memory.rs b/datafusion/src/physical_plan/memory.rs index cc8208346516..d625de9f2df0 100644 --- a/datafusion/src/physical_plan/memory.rs +++ b/datafusion/src/physical_plan/memory.rs @@ -32,7 +32,7 @@ use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; use futures::Stream; @@ -46,6 +46,8 @@ pub struct MemoryExec { projected_schema: SchemaRef, /// Optional projection projection: Option>, + /// Session id + session_id: String, } impl fmt::Debug for MemoryExec { @@ -99,7 +101,7 @@ impl ExecutionPlan for MemoryExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Ok(Box::pin(MemoryStream::try_new( self.partitions[partition].clone(), @@ -135,6 +137,10 @@ impl ExecutionPlan for MemoryExec { self.projection.clone(), ) } + + fn session_id(&self) -> String { + self.session_id.clone() + } } impl MemoryExec { @@ -144,6 +150,7 @@ impl MemoryExec { partitions: &[Vec], schema: SchemaRef, projection: Option>, + session_id: String, ) -> Result { let projected_schema = project_schema(&schema, projection.as_ref())?; Ok(Self { @@ -151,6 +158,7 @@ impl MemoryExec { schema, projected_schema, projection, + session_id, }) } } @@ -223,6 +231,7 @@ mod tests { use super::*; use crate::from_slice::FromSlice; use crate::physical_plan::ColumnStatistics; + use crate::prelude::SessionContext; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; use futures::StreamExt; @@ -250,10 +259,15 @@ mod tests { #[tokio::test] async fn test_with_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let (schema, batch) = mock_data()?; - let executor = MemoryExec::try_new(&[vec![batch]], schema, Some(vec![2, 1]))?; + let ctx = SessionContext::new(); + let executor = MemoryExec::try_new( + &[vec![batch]], + schema, + Some(vec![2, 1]), + ctx.session_id.clone(), + )?; let statistics = executor.statistics(); assert_eq!(statistics.num_rows, Some(3)); @@ -276,7 +290,8 @@ mod tests { ); // scan with projection - let mut it = executor.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut it = executor.execute(0, task_ctx).await?; let batch2 = it.next().await.unwrap()?; assert_eq!(2, batch2.schema().fields().len()); assert_eq!("c", batch2.schema().field(0).name()); @@ -288,10 +303,11 @@ mod tests { #[tokio::test] async fn test_without_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let (schema, batch) = mock_data()?; - let executor = MemoryExec::try_new(&[vec![batch]], schema, None)?; + let ctx = SessionContext::new(); + let executor = + MemoryExec::try_new(&[vec![batch]], schema, None, ctx.session_id.clone())?; let statistics = executor.statistics(); assert_eq!(statistics.num_rows, Some(3)); @@ -324,8 +340,8 @@ mod tests { }, ]) ); - - let mut it = executor.execute(0, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let mut it = executor.execute(0, task_ctx).await?; let batch1 
= it.next().await.unwrap()?; assert_eq!(4, batch1.schema().fields().len()); assert_eq!(4, batch1.num_columns()); diff --git a/datafusion/src/physical_plan/metrics/tracker.rs b/datafusion/src/physical_plan/metrics/tracker.rs index d8017b95ae8d..b9fde54f6574 100644 --- a/datafusion/src/physical_plan/metrics/tracker.rs +++ b/datafusion/src/physical_plan/metrics/tracker.rs @@ -22,6 +22,7 @@ use crate::execution::MemoryConsumerId; use crate::physical_plan::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, Time, }; +use std::fmt; use std::sync::Arc; use std::task::Poll; @@ -32,7 +33,6 @@ use arrow::{error::ArrowError, record_batch::RecordBatch}; /// /// You could use this to replace [BaselineMetrics], report the memory, /// and get the memory usage bookkeeping in the memory manager easily. -#[derive(Debug)] pub struct MemTrackingMetrics { id: MemoryConsumerId, runtime: Option>, @@ -119,6 +119,15 @@ impl MemTrackingMetrics { } } +impl fmt::Debug for MemTrackingMetrics { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MemTrackingMetrics") + .field("id", &self.id) + .field("metrics", &self.metrics) + .finish() + } +} + impl Drop for MemTrackingMetrics { fn drop(&mut self) { self.metrics.try_done(); diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index e2ce99f2bdf4..c9448f603ae1 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -23,7 +23,7 @@ use self::{ coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, }; use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::{error::Result, execution::runtime_env::RuntimeEnv, scalar::ScalarValue}; +use crate::{error::Result, scalar::ScalarValue}; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; @@ -226,7 +226,7 @@ pub trait ExecutionPlan: Debug + Send + Sync { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result; /// Return a snapshot of the set of [`Metric`]s for this @@ -254,6 +254,9 @@ pub trait ExecutionPlan: Debug + Send + Sync { write!(f, "ExecutionPlan(PlaceHolder)") } + /// Return the Session id associated with the execution plan. + fn session_id(&self) -> String; + /// Returns the global output statistics for this `ExecutionPlan` node. 
fn statistics(&self) -> Statistics; } @@ -269,9 +272,9 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// #[tokio::main] /// async fn main() { /// // Hard code target_partitions as it appears in the RepartitionExec output -/// let config = ExecutionConfig::new() -/// .with_target_partitions(3); -/// let mut ctx = ExecutionContext::with_config(config); +/// let config = SessionConfig::new() +/// .with_target_partitions(3);/// +/// let ctx = SessionContext::with_config(config); /// /// // register the a table /// ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await.unwrap(); @@ -385,26 +388,26 @@ pub fn visit_execution_plan( /// Execute the [ExecutionPlan] and collect the results in memory pub async fn collect( plan: Arc, - runtime: Arc, + context: Arc, ) -> Result> { - let stream = execute_stream(plan, runtime).await?; + let stream = execute_stream(plan, context).await?; common::collect(stream).await } /// Execute the [ExecutionPlan] and return a single stream of results pub async fn execute_stream( plan: Arc, - runtime: Arc, + context: Arc, ) -> Result { match plan.output_partitioning().partition_count() { 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), - 1 => plan.execute(0, runtime).await, + 1 => plan.execute(0, context).await, _ => { // merge into a single partition let plan = CoalescePartitionsExec::new(plan.clone()); // CoalescePartitionsExec must produce a single partition assert_eq!(1, plan.output_partitioning().partition_count()); - plan.execute(0, runtime).await + plan.execute(0, context).await } } } @@ -412,9 +415,9 @@ pub async fn execute_stream( /// Execute the [ExecutionPlan] and collect the results in memory pub async fn collect_partitioned( plan: Arc, - runtime: Arc, + context: Arc, ) -> Result>> { - let streams = execute_stream_partitioned(plan, runtime).await?; + let streams = execute_stream_partitioned(plan, context).await?; let mut batches = Vec::with_capacity(streams.len()); for stream in streams { batches.push(common::collect(stream).await?); @@ -425,12 +428,12 @@ pub async fn collect_partitioned( /// Execute the [ExecutionPlan] and return a vec with one stream per output partition pub async fn execute_stream_partitioned( plan: Arc, - runtime: Arc, + context: Arc, ) -> Result> { let num_partitions = plan.output_partitioning().partition_count(); let mut streams = Vec::with_capacity(num_partitions); for i in 0..num_partitions { - streams.push(plan.execute(i, runtime.clone()).await?); + streams.push(plan.execute(i, context.clone()).await?); } Ok(streams) } @@ -521,7 +524,9 @@ pub mod cross_join; pub mod display; pub mod empty; pub mod explain; +use crate::execution::context::TaskContext; pub use datafusion_physical_expr::expressions; + pub mod aggregate_rule; pub mod file_format; pub mod filter; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 88ff9fa4dc98..8b187e32686c 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -22,7 +22,7 @@ use super::{ aggregates, empty::EmptyExec, expressions::binary, functions, hash_join::PartitionMode, udaf, union::UnionExec, values::ValuesExec, windows, }; -use crate::execution::context::{ExecutionContextState, ExecutionProps}; +use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_plan::plan::{ Aggregate, EmptyRelation, Filter, Join, Projection, Sort, TableScan, Window, }; @@ -217,7 +217,7 @@ pub trait PhysicalPlanner { async fn create_physical_plan( &self, 
logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>; /// Create a physical expression from a logical expression @@ -233,7 +233,7 @@ pub trait PhysicalPlanner { expr: &Expr, input_dfschema: &DFSchema, input_schema: &Schema, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>; } @@ -255,7 +255,7 @@ pub trait ExtensionPlanner { node: &dyn UserDefinedLogicalNode, logical_inputs: &[&LogicalPlan], physical_inputs: &[Arc], - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>>; } @@ -272,13 +272,15 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { - match self.handle_explain(logical_plan, ctx_state).await? { + match self.handle_explain(logical_plan, session_state).await? { Some(plan) => Ok(plan), None => { - let plan = self.create_initial_plan(logical_plan, ctx_state).await?; - self.optimize_internal(plan, ctx_state, |_, _| {}) + let plan = self + .create_initial_plan(logical_plan, session_state) + .await?; + self.optimize_internal(plan, session_state, |_, _| {}) } } } @@ -296,13 +298,13 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { expr: &Expr, input_dfschema: &DFSchema, input_schema: &Schema, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { create_physical_expr( expr, input_dfschema, input_schema, - &ctx_state.execution_props, + &session_state.execution_props, ) } } @@ -322,9 +324,13 @@ impl DefaultPhysicalPlanner { fn create_initial_plan<'a>( &'a self, logical_plan: &'a LogicalPlan, - ctx_state: &'a ExecutionContextState, + session_state: &'a SessionState, ) -> BoxFuture<'a, Result>> { async move { + let config = { + // Clone the session config + session_state.config.lock().clone() + }; let exec_plan: Result> = match logical_plan { LogicalPlan::TableScan (TableScan { source, @@ -338,7 +344,7 @@ impl DefaultPhysicalPlanner { // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); let unaliased: Vec = filters.into_iter().map(unalias).collect(); - source.scan(projection, &unaliased, *limit).await + source.scan(projection, &unaliased, *limit, session_state.session_id.clone()).await } LogicalPlan::Values(Values { values, @@ -352,7 +358,7 @@ impl DefaultPhysicalPlanner { expr, schema, &exec_schema, - ctx_state, + session_state, ) }) .collect::>>>() @@ -360,7 +366,8 @@ impl DefaultPhysicalPlanner { .collect::>>()?; let value_exec = ValuesExec::try_new( SchemaRef::new(exec_schema), - exprs + exprs, + session_state.session_id.clone(), )?; Ok(Arc::new(value_exec)) } @@ -373,15 +380,15 @@ impl DefaultPhysicalPlanner { )); } - let input_exec = self.create_initial_plan(input, ctx_state).await?; + let input_exec = self.create_initial_plan(input, session_state).await?; // at this moment we are guaranteed by the logical planner // to have all the window_expr to have equal sort key let partition_keys = window_expr_common_partition_keys(window_expr)?; let can_repartition = !partition_keys.is_empty() - && ctx_state.config.target_partitions > 1 - && ctx_state.config.repartition_windows; + && config.target_partitions > 1 + && config.repartition_windows; let input_exec = if can_repartition { let partition_keys = partition_keys @@ -391,7 +398,7 @@ impl DefaultPhysicalPlanner { e, input.schema(), &input_exec.schema(), - ctx_state, + session_state, ) }) .collect::>>>()?; @@ -399,7 
+406,7 @@ impl DefaultPhysicalPlanner { input_exec, Partitioning::Hash( partition_keys, - ctx_state.config.target_partitions, + config.target_partitions, ), )?) } else { @@ -446,7 +453,7 @@ impl DefaultPhysicalPlanner { descending: !*asc, nulls_first: *nulls_first, }, - &ctx_state.execution_props, + &session_state.execution_props, ), _ => unreachable!(), }) @@ -466,7 +473,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - &ctx_state.execution_props, + &session_state.execution_props, ) }) .collect::>>()?; @@ -484,7 +491,7 @@ impl DefaultPhysicalPlanner { .. }) => { // Initially need to perform the aggregate and then merge the partitions - let input_exec = self.create_initial_plan(input, ctx_state).await?; + let input_exec = self.create_initial_plan(input, session_state).await?; let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); @@ -496,7 +503,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - ctx_state, + session_state, ), physical_name(e), )) @@ -509,7 +516,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - &ctx_state.execution_props, + &session_state.execution_props, ) }) .collect::>>()?; @@ -532,8 +539,8 @@ impl DefaultPhysicalPlanner { .any(|x| matches!(x, DataType::Dictionary(_, _))); let can_repartition = !groups.is_empty() - && ctx_state.config.target_partitions > 1 - && ctx_state.config.repartition_aggregations + && config.target_partitions > 1 + && config.repartition_aggregations && !contains_dict; let (initial_aggr, next_partition_mode): ( @@ -545,7 +552,7 @@ impl DefaultPhysicalPlanner { initial_aggr, Partitioning::Hash( final_group.clone(), - ctx_state.config.target_partitions, + config.target_partitions, ), )?); // Combine hash aggregates within the partition @@ -569,7 +576,7 @@ impl DefaultPhysicalPlanner { )?) ) } LogicalPlan::Projection(Projection { input, expr, .. }) => { - let input_exec = self.create_initial_plan(input, ctx_state).await?; + let input_exec = self.create_initial_plan(input, session_state).await?; let input_schema = input.as_ref().schema(); let physical_exprs = expr @@ -608,7 +615,7 @@ impl DefaultPhysicalPlanner { e, input_schema, &input_exec.schema(), - ctx_state, + session_state, ), physical_name, )) @@ -623,7 +630,7 @@ impl DefaultPhysicalPlanner { LogicalPlan::Filter(Filter { input, predicate, .. }) => { - let physical_input = self.create_initial_plan(input, ctx_state).await?; + let physical_input = self.create_initial_plan(input, session_state).await?; let input_schema = physical_input.as_ref().schema(); let input_dfschema = input.as_ref().schema(); @@ -631,22 +638,22 @@ impl DefaultPhysicalPlanner { predicate, input_dfschema, &input_schema, - ctx_state, + session_state, )?; Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?) ) } LogicalPlan::Union(Union { inputs, .. 
}) => { let physical_plans = futures::stream::iter(inputs) - .then(|lp| self.create_initial_plan(lp, ctx_state)) + .then(|lp| self.create_initial_plan(lp, session_state)) .try_collect::>() .await?; - Ok(Arc::new(UnionExec::new(physical_plans)) ) + Ok(Arc::new(UnionExec::new(physical_plans))) } LogicalPlan::Repartition(Repartition { input, partitioning_scheme, }) => { - let physical_input = self.create_initial_plan(input, ctx_state).await?; + let physical_input = self.create_initial_plan(input, session_state).await?; let input_schema = physical_input.schema(); let input_dfschema = input.as_ref().schema(); let physical_partitioning = match partitioning_scheme { @@ -661,7 +668,7 @@ impl DefaultPhysicalPlanner { e, input_dfschema, &input_schema, - ctx_state, + session_state, ) }) .collect::>>()?; @@ -674,7 +681,7 @@ impl DefaultPhysicalPlanner { )?) ) } LogicalPlan::Sort(Sort { expr, input, .. }) => { - let physical_input = self.create_initial_plan(input, ctx_state).await?; + let physical_input = self.create_initial_plan(input, session_state).await?; let input_schema = physical_input.as_ref().schema(); let input_dfschema = input.as_ref().schema(); let sort_expr = expr @@ -692,7 +699,7 @@ impl DefaultPhysicalPlanner { descending: !*asc, nulls_first: *nulls_first, }, - &ctx_state.execution_props, + &session_state.execution_props, ), _ => Err(DataFusionError::Plan( "Sort only accepts sort expressions".to_string(), @@ -710,9 +717,9 @@ impl DefaultPhysicalPlanner { .. }) => { let left_df_schema = left.schema(); - let physical_left = self.create_initial_plan(left, ctx_state).await?; + let physical_left = self.create_initial_plan(left, session_state).await?; let right_df_schema = right.schema(); - let physical_right = self.create_initial_plan(right, ctx_state).await?; + let physical_right = self.create_initial_plan(right, session_state).await?; let join_on = keys .iter() .map(|(l, r)| { @@ -723,8 +730,8 @@ impl DefaultPhysicalPlanner { }) .collect::>()?; - if ctx_state.config.target_partitions > 1 - && ctx_state.config.repartition_joins + if config.target_partitions > 1 + && config.repartition_joins { let (left_expr, right_expr) = join_on .iter() @@ -742,14 +749,14 @@ impl DefaultPhysicalPlanner { physical_left, Partitioning::Hash( left_expr, - ctx_state.config.target_partitions, + config.target_partitions, ), )?), Arc::new(RepartitionExec::try_new( physical_right, Partitioning::Hash( right_expr, - ctx_state.config.target_partitions, + config.target_partitions, ), )?), join_on, @@ -769,8 +776,8 @@ impl DefaultPhysicalPlanner { } } LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let left = self.create_initial_plan(left, ctx_state).await?; - let right = self.create_initial_plan(right, ctx_state).await?; + let left = self.create_initial_plan(left, session_state).await?; + let right = self.create_initial_plan(right, session_state).await?; Ok(Arc::new(CrossJoinExec::try_new(left, right)?)) } LogicalPlan::EmptyRelation(EmptyRelation { @@ -778,11 +785,12 @@ impl DefaultPhysicalPlanner { schema, }) => Ok(Arc::new(EmptyExec::new( *produce_one_row, - SchemaRef::new(schema.as_ref().to_owned().into()), + SchemaRef::new(schema.as_ref().to_owned().into()),session_state.session_id.clone(), + ))), LogicalPlan::Limit(Limit { input, n, .. 
}) => { let limit = *n; - let input = self.create_initial_plan(input, ctx_state).await?; + let input = self.create_initial_plan(input, session_state).await?; // GlobalLimitExec requires a single partition for input let input = if input.output_partitioning().partition_count() == 1 { @@ -809,19 +817,20 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(EmptyExec::new( false, SchemaRef::new(Schema::empty()), + session_state.session_id.clone(), ))) } LogicalPlan::Explain (_) => Err(DataFusionError::Internal( "Unsupported logical plan: Explain must be root of the plan".to_string(), )), LogicalPlan::Analyze(a) => { - let input = self.create_initial_plan(&a.input, ctx_state).await?; + let input = self.create_initial_plan(&a.input, session_state).await?; let schema = SchemaRef::new((*a.schema).clone().into()); Ok(Arc::new(AnalyzeExec::new(a.verbose, input, schema))) } LogicalPlan::Extension(e) => { let physical_inputs = futures::stream::iter(e.node.inputs()) - .then(|lp| self.create_initial_plan(lp, ctx_state)) + .then(|lp| self.create_initial_plan(lp, session_state)) .try_collect::>() .await?; @@ -836,7 +845,7 @@ impl DefaultPhysicalPlanner { e.node.as_ref(), &e.node.inputs(), &physical_inputs, - ctx_state, + session_state, ) } }, @@ -1351,7 +1360,7 @@ impl DefaultPhysicalPlanner { async fn handle_explain( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>> { if let LogicalPlan::Explain(e) = logical_plan { use PlanType::*; @@ -1359,16 +1368,19 @@ impl DefaultPhysicalPlanner { stringified_plans.push(e.plan.to_stringified(FinalLogicalPlan)); - let input = self.create_initial_plan(e.plan.as_ref(), ctx_state).await?; + let input = self + .create_initial_plan(e.plan.as_ref(), session_state) + .await?; stringified_plans .push(displayable(input.as_ref()).to_stringified(InitialPhysicalPlan)); - let input = self.optimize_internal(input, ctx_state, |plan, optimizer| { - let optimizer_name = optimizer.name().to_string(); - let plan_type = OptimizedPhysicalPlan { optimizer_name }; - stringified_plans.push(displayable(plan).to_stringified(plan_type)); - })?; + let input = + self.optimize_internal(input, session_state, |plan, optimizer| { + let optimizer_name = optimizer.name().to_string(); + let plan_type = OptimizedPhysicalPlan { optimizer_name }; + stringified_plans.push(displayable(plan).to_stringified(plan_type)); + })?; stringified_plans .push(displayable(input.as_ref()).to_stringified(FinalPhysicalPlan)); @@ -1377,6 +1389,7 @@ impl DefaultPhysicalPlanner { SchemaRef::new(e.schema.as_ref().to_owned().into()), stringified_plans, e.verbose, + session_state.session_id.clone(), )))) } else { Ok(None) @@ -1388,13 +1401,13 @@ impl DefaultPhysicalPlanner { fn optimize_internal( &self, plan: Arc, - ctx_state: &ExecutionContextState, + session_state: &SessionState, mut observer: F, ) -> Result> where F: FnMut(&dyn ExecutionPlan, &dyn PhysicalOptimizerRule), { - let optimizers = &ctx_state.config.physical_optimizers; + let optimizers = &session_state.physical_optimizers; debug!( "Input physical plan:\n{}\n", displayable(plan.as_ref()).indent() @@ -1403,7 +1416,7 @@ impl DefaultPhysicalPlanner { let mut new_plan = plan; for optimizer in optimizers { - new_plan = optimizer.optimize(new_plan, &ctx_state.config)?; + new_plan = optimizer.optimize(new_plan, &session_state.config.lock())?; observer(new_plan.as_ref(), optimizer.as_ref()) } debug!( @@ -1428,12 +1441,14 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { mod tests { use super::*; use 
crate::datasource::object_store::local::LocalFileSystem; + use crate::execution::context::TaskContext; use crate::execution::options::CsvReadOptions; - use crate::execution::runtime_env::RuntimeEnv; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use crate::logical_plan::plan::Extension; use crate::physical_plan::{ expressions, DisplayFormatType, Partitioning, Statistics, }; + use crate::prelude::SessionConfig; use crate::scalar::ScalarValue; use crate::{ logical_plan::LogicalPlanBuilder, physical_plan::SendableRecordBatchStream, @@ -1447,15 +1462,18 @@ mod tests { use std::convert::TryFrom; use std::{any::Any, fmt}; - fn make_ctx_state() -> ExecutionContextState { - ExecutionContextState::new() + fn make_session_state() -> SessionState { + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); + SessionState::with_config(SessionConfig::new(), runtime) } async fn plan(logical_plan: &LogicalPlan) -> Result> { - let mut ctx_state = make_ctx_state(); - ctx_state.config.target_partitions = 4; + let session_state = make_session_state(); + session_state.config.lock().target_partitions = 4; let planner = DefaultPhysicalPlanner::default(); - planner.create_physical_plan(logical_plan, &ctx_state).await + planner + .create_physical_plan(logical_plan, &session_state) + .await } #[tokio::test] @@ -1501,7 +1519,7 @@ mod tests { &col("a").not(), &dfschema, &schema, - &make_ctx_state(), + &make_session_state(), )?; let expected = expressions::not(expressions::col("a", &schema)?, &schema)?; @@ -1580,13 +1598,13 @@ mod tests { #[tokio::test] async fn default_extension_planner() { - let ctx_state = make_ctx_state(); + let session_state = make_session_state(); let planner = DefaultPhysicalPlanner::default(); let logical_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoOpExtensionNode::default()), }); let plan = planner - .create_physical_plan(&logical_plan, &ctx_state) + .create_physical_plan(&logical_plan, &session_state) .await; let expected_error = @@ -1606,7 +1624,7 @@ mod tests { async fn bad_extension_planner() { // Test that creating an execution plan whose schema doesn't // match the logical plan's schema generates an error. 
- let ctx_state = make_ctx_state(); + let session_state = make_session_state(); let planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( BadExtensionPlanner {}, )]); @@ -1615,7 +1633,7 @@ mod tests { node: Arc::new(NoOpExtensionNode::default()), }); let plan = planner - .create_physical_plan(&logical_plan, &ctx_state) + .create_physical_plan(&logical_plan, &session_state) .await; let expected_error: &str = "Error during planning: \ @@ -1901,7 +1919,7 @@ mod tests { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { unimplemented!("NoOpExecutionPlan::execute"); } @@ -1921,6 +1939,10 @@ mod tests { fn statistics(&self) -> Statistics { unimplemented!("NoOpExecutionPlan::statistics"); } + + fn session_id(&self) -> String { + unimplemented!("NoOpExecutionPlan::session_id"); + } } // Produces an execution plan where the schema is mismatched from @@ -1935,7 +1957,7 @@ mod tests { _node: &dyn UserDefinedLogicalNode, _logical_inputs: &[&LogicalPlan], _physical_inputs: &[Arc], - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> Result>> { Ok(Some(Arc::new(NoOpExecutionPlan { schema: SchemaRef::new(Schema::new(vec![Field::new( diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs index 5940b64957c1..206f24d38307 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -37,7 +37,7 @@ use arrow::record_batch::RecordBatch; use super::expressions::{Column, PhysicalSortExpr}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; use futures::stream::Stream; use futures::stream::StreamExt; @@ -153,12 +153,12 @@ impl ExecutionPlan for ProjectionExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { Ok(Box::pin(ProjectionStream { schema: self.schema.clone(), expr: self.expr.iter().map(|x| x.0.clone()).collect(), - input: self.input.execute(partition, runtime).await?, + input: self.input.execute(partition, context).await?, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) } @@ -198,6 +198,10 @@ impl ExecutionPlan for ProjectionExec { self.expr.iter().map(|(e, _)| Arc::clone(e)), ) } + + fn session_id(&self) -> String { + self.input.session_id() + } } /// If e is a direct column reference, returns the field level @@ -303,6 +307,7 @@ mod tests { use crate::datasource::object_store::local::LocalFileSystem; use crate::physical_plan::expressions::{self, col}; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; + use crate::prelude::SessionContext; use crate::scalar::ScalarValue; use crate::test::{self}; use crate::test_util; @@ -310,13 +315,12 @@ mod tests { #[tokio::test] async fn project_first_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; - + let ctx = SessionContext::new(); let csv = CsvExec::new( FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -329,6 +333,7 @@ mod tests { }, true, b',', + ctx.session_id.clone(), ); // pick column c1 and name it column c1 in the output schema @@ -346,7 +351,8 @@ mod tests { let mut row_count = 0; for partition in 
0..projection.output_partitioning().partition_count() { partition_count += 1; - let stream = projection.execute(partition, runtime.clone()).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let stream = projection.execute(partition, task_ctx).await?; row_count += stream .map(|batch| { diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 55328c40c951..a395904029eb 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -37,7 +37,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream}; use async_trait::async_trait; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use futures::stream::Stream; use futures::StreamExt; use hashbrown::HashMap; @@ -177,7 +177,7 @@ impl ExecutionPlan for RepartitionExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { // lock mutexes let mut state = self.state.lock().await; @@ -221,7 +221,7 @@ impl ExecutionPlan for RepartitionExec { txs.clone(), self.partitioning.clone(), r_metrics, - runtime.clone(), + context.clone(), )); // In a separate task, wait for each input to be done @@ -268,6 +268,10 @@ impl ExecutionPlan for RepartitionExec { fn statistics(&self) -> Statistics { self.input.statistics() } + + fn session_id(&self) -> String { + self.input.session_id() + } } impl RepartitionExec { @@ -300,13 +304,13 @@ impl RepartitionExec { mut txs: HashMap>>>, partitioning: Partitioning, r_metrics: RepartitionMetrics, - runtime: Arc, + context: Arc, ) -> Result<()> { let num_output_partitions = txs.len(); // execute the child operator let timer = r_metrics.fetch_time.timer(); - let mut stream = input.execute(i, runtime).await?; + let mut stream = input.execute(i, context).await?; timer.done(); let mut counter = 0; @@ -503,6 +507,7 @@ impl RecordBatchStream for RepartitionStream { mod tests { use super::*; use crate::from_slice::FromSlice; + use crate::prelude::SessionContext; use crate::test::create_vec_batches; use crate::{ assert_batches_sorted_eq, @@ -616,16 +621,22 @@ mod tests { input_partitions: Vec>, partitioning: Partitioning, ) -> Result>> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); // create physical plan - let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; + let exec = MemoryExec::try_new( + &input_partitions, + schema.clone(), + None, + session_ctx.session_id.clone(), + )?; let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; // execute and collect results let mut output_partitions = vec![]; for i in 0..exec.partitioning.partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i, runtime.clone()).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let mut stream = exec.execute(i, task_ctx).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -665,7 +676,6 @@ mod tests { #[tokio::test] async fn unsupported_partitioning() { - let runtime = Arc::new(RuntimeEnv::default()); // have to send at least one batch through to provoke error let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", @@ -680,7 +690,9 @@ mod tests { // returned and no results produced let partitioning = Partitioning::UnknownPartitioning(1); let exec = 
RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream = exec.execute(0, runtime).await.unwrap(); + let ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream = exec.execute(0, task_ctx).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -700,14 +712,15 @@ mod tests { // This generates an error on a call to execute. The error // should be returned and no results produced. - let runtime = Arc::new(RuntimeEnv::default()); - let input = ErrorExec::new(); + let ctx = SessionContext::new(); + let input = ErrorExec::new(ctx.session_id.clone()); let partitioning = Partitioning::RoundRobinBatch(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); // Note: this should pass (the stream can be created) but the // error when the input is executed should get passed back - let output_stream = exec.execute(0, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream = exec.execute(0, task_ctx).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -723,7 +736,6 @@ mod tests { #[tokio::test] async fn repartition_with_error_in_stream() { - let runtime = Arc::new(RuntimeEnv::default()); let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(StringArray::from_slice(&["foo", "bar"])) as ArrayRef, @@ -741,7 +753,9 @@ mod tests { // Note: this should pass (the stream can be created) but the // error when the input is executed should get passed back - let output_stream = exec.execute(0, runtime).await.unwrap(); + let ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream = exec.execute(0, task_ctx).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -757,7 +771,6 @@ mod tests { #[tokio::test] async fn repartition_with_delayed_stream() { - let runtime = Arc::new(RuntimeEnv::default()); let batch1 = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(StringArray::from_slice(&["foo", "bar"])) as ArrayRef, @@ -792,7 +805,9 @@ mod tests { assert_batches_sorted_eq!(&expected, &expected_batches); - let output_stream = exec.execute(0, runtime).await.unwrap(); + let ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream = exec.execute(0, task_ctx).await.unwrap(); let batches = crate::physical_plan::common::collect(output_stream) .await .unwrap(); @@ -802,17 +817,19 @@ mod tests { #[tokio::test] async fn robin_repartition_with_dropping_output_stream() { - let runtime = Arc::new(RuntimeEnv::default()); let partitioning = Partitioning::RoundRobinBatch(2); // The barrier exec waits to be pinged // requires the input to wait at least once) - let input = Arc::new(make_barrier_exec()); + let ctx = SessionContext::new(); + let input = Arc::new(make_barrier_exec(&ctx)); // partition into two output streams let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); - let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream0 = exec.execute(0, task_ctx).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream1 = exec.execute(1, 
task_ctx).await.unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced @@ -845,7 +862,6 @@ mod tests { // wiht different compilers, we will compare the same execution with // and without droping the output stream. async fn hash_repartition_with_dropping_output_stream() { - let runtime = Arc::new(RuntimeEnv::default()); let partitioning = Partitioning::Hash( vec![Arc::new(crate::physical_plan::expressions::Column::new( "my_awesome_field", @@ -855,9 +871,11 @@ mod tests { ); // We first collect the results without droping the output stream. - let input = Arc::new(make_barrier_exec()); + let ctx = SessionContext::new(); + let input = Arc::new(make_barrier_exec(&ctx)); let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap(); - let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream1 = exec.execute(1, task_ctx).await.unwrap(); input.wait().await; let batches_without_drop = crate::physical_plan::common::collect(output_stream1) .await @@ -875,10 +893,12 @@ mod tests { assert_eq!(items_set.difference(&source_str_set).count(), 0); // Now do the same but dropping the stream before waiting for the barrier - let input = Arc::new(make_barrier_exec()); + let input = Arc::new(make_barrier_exec(&ctx)); let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); - let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream0 = exec.execute(0, task_ctx).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream1 = exec.execute(1, task_ctx).await.unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced std::mem::drop(output_stream0); @@ -910,7 +930,7 @@ mod tests { } /// Create a BarrierExec that returns two partitions of two batches each - fn make_barrier_exec() -> BarrierExec { + fn make_barrier_exec(ctx: &SessionContext) -> BarrierExec { let batch1 = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(StringArray::from_slice(&["foo", "bar"])) as ArrayRef, @@ -938,23 +958,32 @@ mod tests { // The barrier exec waits to be pinged // requires the input to wait at least once) let schema = batch1.schema(); - BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema) + BarrierExec::new( + vec![vec![batch1, batch2], vec![batch3, batch4]], + schema, + ctx.session_id.clone(), + ) } #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); + let session_ctx = SessionContext::new(); + let blocking_exec = Arc::new(BlockingExec::new( + Arc::clone(&schema), + 2, + session_ctx.session_id.clone(), + )); let refs = blocking_exec.refs(); let repartition_exec = Arc::new(RepartitionExec::try_new( blocking_exec, Partitioning::UnknownPartitioning(1), )?); - let fut = collect(repartition_exec, runtime); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let fut = collect(repartition_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -966,7 +995,6 @@ mod tests { #[tokio::test] async fn hash_repartition_avoid_empty_batch() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let batch = 
RecordBatch::try_from_iter(vec![( "a", Arc::new(StringArray::from_slice(&["foo"])) as ArrayRef, @@ -981,11 +1009,14 @@ mod tests { let schema = batch.schema(); let input = MockExec::new(vec![Ok(batch)], schema); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); + let ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream0 = exec.execute(0, task_ctx).await.unwrap(); let batch0 = crate::physical_plan::common::collect(output_stream0) .await .unwrap(); - let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let output_stream1 = exec.execute(1, task_ctx).await.unwrap(); let batch1 = crate::physical_plan::common::collect(output_stream1) .await .unwrap(); diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index 1428e1627d8f..06742c5eb9f0 100644 --- a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -20,6 +20,7 @@ //! but spills to disk if needed. use crate::error::{DataFusionError, Result}; +use crate::execution::context::TaskContext; use crate::execution::memory_manager::{ human_readable_size, ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, }; @@ -36,6 +37,7 @@ use crate::physical_plan::{ common, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; +use crate::prelude::SessionConfig; use arrow::array::ArrayRef; pub use arrow::compute::SortOptions; use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions}; @@ -75,6 +77,7 @@ struct ExternalSorter { /// Sort expressions expr: Vec, runtime: Arc, + session_config: Arc, metrics_set: CompositeMetricsSet, metrics: BaselineMetrics, } @@ -86,6 +89,7 @@ impl ExternalSorter { expr: Vec, metrics_set: CompositeMetricsSet, runtime: Arc, + session_config: Arc, ) -> Self { let metrics = metrics_set.new_intermediate_baseline(partition_id); Self { @@ -95,6 +99,7 @@ impl ExternalSorter { spills: Mutex::new(vec![]), expr, runtime, + session_config, metrics_set, metrics, } @@ -151,7 +156,7 @@ impl ExternalSorter { self.schema.clone(), &self.expr, tracking_metrics, - self.runtime.clone(), + self.session_config.batch_size, ))) } else if in_mem_batches.len() > 0 { let tracking_metrics = self @@ -476,7 +481,7 @@ impl ExecutionPlan for SortExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { if !self.preserve_partitioning { if 0 != partition { @@ -494,14 +499,16 @@ impl ExecutionPlan for SortExec { } } - let input = self.input.execute(partition, runtime.clone()).await?; + let session_config = context.session_config(); + let input = self.input.execute(partition, context.clone()).await?; do_sort( input, partition, self.expr.clone(), self.metrics_set.clone(), - runtime, + context.runtime.clone(), + Arc::new(session_config), ) .await } @@ -526,6 +533,10 @@ impl ExecutionPlan for SortExec { fn statistics(&self) -> Statistics { self.input.statistics() } + + fn session_id(&self) -> String { + self.input.session_id() + } } fn sort_batch( @@ -569,6 +580,7 @@ async fn do_sort( expr: Vec, metrics_set: CompositeMetricsSet, runtime: Arc, + session_config: Arc, ) -> Result { let schema = input.schema(); let sorter = ExternalSorter::new( @@ -577,6 +589,7 @@ async fn do_sort( expr, metrics_set, runtime.clone(), + session_config.clone(), 
); runtime.register_requester(sorter.id()); while let Some(batch) = input.next().await { @@ -590,7 +603,7 @@ async fn do_sort( mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; - use crate::execution::context::ExecutionConfig; + use crate::execution::runtime_env::RuntimeConfig; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::expressions::col; use crate::physical_plan::memory::MemoryExec; @@ -598,6 +611,7 @@ mod tests { collect, file_format::{CsvExec, FileScanConfig}, }; + use crate::prelude::SessionContext; use crate::test; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; @@ -610,11 +624,11 @@ mod tests { #[tokio::test] async fn test_in_mem_sort() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; + let session_ctx = SessionContext::new(); let csv = CsvExec::new( FileScanConfig { @@ -628,6 +642,7 @@ mod tests { }, true, b',', + session_ctx.session_id.clone(), ); let sort_exec = Arc::new(SortExec::try_new( @@ -651,7 +666,8 @@ mod tests { Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), )?); - let result = collect(sort_exec, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let result = collect(sort_exec, task_ctx).await?; assert_eq!(result.len(), 1); @@ -675,13 +691,14 @@ mod tests { #[tokio::test] async fn test_sort_spill() -> Result<()> { // trigger spill there will be 4 batches with 5.5KB for each - let config = ExecutionConfig::new().with_memory_limit(12288, 1.0)?; - let runtime = Arc::new(RuntimeEnv::new(config.runtime)?); + let config = RuntimeConfig::new().with_memory_limit(12288, 1.0); + let runtime = Arc::new(RuntimeEnv::new(config)?); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; + let session_ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime); let csv = CsvExec::new( FileScanConfig { @@ -695,6 +712,7 @@ mod tests { }, true, b',', + session_ctx.session_id.clone(), ); let sort_exec = Arc::new(SortExec::try_new( @@ -718,7 +736,8 @@ mod tests { Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), )?); - let result = collect(sort_exec.clone(), runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let result = collect(sort_exec.clone(), task_ctx).await?; assert_eq!(result.len(), 1); @@ -749,7 +768,6 @@ mod tests { #[tokio::test] async fn test_sort_metadata() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let field_metadata: BTreeMap = vec![("foo".to_string(), "bar".to_string())] .into_iter() @@ -768,8 +786,16 @@ mod tests { Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::()); let batch = RecordBatch::try_new(schema.clone(), vec![data]).unwrap(); - let input = - Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None).unwrap()); + let session_ctx = SessionContext::new(); + let input = Arc::new( + MemoryExec::try_new( + &[vec![batch]], + schema.clone(), + None, + session_ctx.session_id.clone(), + ) + .unwrap(), + ); let sort_exec = Arc::new(SortExec::try_new( vec![PhysicalSortExpr { @@ -779,7 +805,8 @@ mod tests { input, )?); - let result: Vec = collect(sort_exec, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let result: Vec = 
collect(sort_exec, task_ctx).await?; let expected_data: ArrayRef = Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); @@ -801,7 +828,6 @@ mod tests { #[tokio::test] async fn test_lex_sort_by_float() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float64, true), @@ -834,6 +860,7 @@ mod tests { ], )?; + let session_ctx = SessionContext::new(); let sort_exec = Arc::new(SortExec::try_new( vec![ PhysicalSortExpr { @@ -851,13 +878,19 @@ mod tests { }, }, ], - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?), + Arc::new(MemoryExec::try_new( + &[vec![batch]], + schema, + None, + session_ctx.session_id.clone(), + )?), )?); assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type()); assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type()); - let result: Vec = collect(sort_exec.clone(), runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let result: Vec = collect(sort_exec.clone(), task_ctx).await?; let metrics = sort_exec.metrics().unwrap(); assert!(metrics.elapsed_compute().unwrap() > 0); assert_eq!(metrics.output_rows().unwrap(), 8); @@ -906,11 +939,15 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let session_ctx = SessionContext::new(); + let blocking_exec = Arc::new(BlockingExec::new( + Arc::clone(&schema), + 1, + session_ctx.session_id.clone(), + )); let refs = blocking_exec.refs(); let sort_exec = Arc::new(SortExec::try_new( vec![PhysicalSortExpr { @@ -919,8 +956,8 @@ mod tests { }], blocking_exec, )?); - - let fut = collect(sort_exec, runtime); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let fut = collect(sort_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs index 780e2cc67659..4076c1f51b45 100644 --- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs @@ -42,7 +42,7 @@ use futures::stream::FusedStream; use futures::{Stream, StreamExt}; use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream, StreamWrapper}; use crate::physical_plan::{ common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType, @@ -158,7 +158,7 @@ impl ExecutionPlan for SortPreservingMergeExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -177,7 +177,7 @@ impl ExecutionPlan for SortPreservingMergeExec { )), 1 => { // bypass if there is only one partition to merge (no metrics in this case either) - self.input.execute(0, runtime).await + self.input.execute(0, context.clone()).await } _ => { let (receivers, join_handles) = (0..input_partitions) @@ -188,7 +188,7 @@ impl ExecutionPlan for SortPreservingMergeExec { self.input.clone(), sender, part_i, - runtime.clone(), + context.clone(), ); (receiver, join_handle) }) @@ -200,7 +200,7 @@ impl ExecutionPlan for 
SortPreservingMergeExec { self.schema(), &self.expr, tracking_metrics, - runtime, + context.session_config().batch_size, ))) } } @@ -226,6 +226,10 @@ impl ExecutionPlan for SortPreservingMergeExec { fn statistics(&self) -> Statistics { self.input.statistics() } + + fn session_id(&self) -> String { + self.input.session_id() + } } #[derive(Debug)] @@ -302,7 +306,7 @@ impl SortPreservingMergeStream { schema: SchemaRef, expressions: &[PhysicalSortExpr], tracking_metrics: MemTrackingMetrics, - runtime: Arc, + batch_size: usize, ) -> Self { let stream_count = receivers.len(); let batches = (0..stream_count) @@ -324,7 +328,7 @@ impl SortPreservingMergeStream { in_progress: vec![], next_batch_id: 0, min_heap: BinaryHeap::with_capacity(stream_count), - batch_size: runtime.batch_size(), + batch_size, } } @@ -333,7 +337,7 @@ impl SortPreservingMergeStream { schema: SchemaRef, expressions: &[PhysicalSortExpr], tracking_metrics: MemTrackingMetrics, - runtime: Arc, + batch_size: usize, ) -> Self { let stream_count = streams.len(); let batches = (0..stream_count) @@ -359,7 +363,7 @@ impl SortPreservingMergeStream { in_progress: vec![], next_batch_id: 0, min_heap: BinaryHeap::with_capacity(stream_count), - batch_size: runtime.batch_size(), + batch_size, } } @@ -623,14 +627,13 @@ mod tests { use crate::{assert_batches_eq, test_util}; use super::*; - use crate::execution::runtime_env::RuntimeConfig; + use crate::prelude::{SessionConfig, SessionContext}; use arrow::datatypes::{DataType, Field, Schema}; use futures::{FutureExt, SinkExt}; use tokio_stream::StreamExt; #[tokio::test] async fn test_merge_interleave() { - let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -671,14 +674,12 @@ mod tests { "| 3 | j | 1970-01-01 00:00:00.000000008 |", "+----+---+-------------------------------+", ], - runtime, ) .await; } #[tokio::test] async fn test_merge_some_overlap() { - let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -719,14 +720,12 @@ mod tests { "| 110 | g | 1970-01-01 00:00:00.000000006 |", "+-----+---+-------------------------------+", ], - runtime, ) .await; } #[tokio::test] async fn test_merge_no_overlap() { - let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -767,14 +766,12 @@ mod tests { "| 30 | j | 1970-01-01 00:00:00.000000006 |", "+----+---+-------------------------------+", ], - runtime, ) .await; } #[tokio::test] async fn test_merge_three_partitions() { - let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -832,16 +829,11 @@ mod tests { "| 30 | j | 1970-01-01 00:00:00.000000060 |", "+-----+---+-------------------------------+", ], - runtime, ) .await; } - async fn _test_merge( - partitions: &[Vec], - exp: &[&str], - runtime: Arc, - ) { + async fn _test_merge(partitions: &[Vec], exp: &[&str]) { let schema = partitions[0][0].schema(); let sort = vec![ PhysicalSortExpr { @@ -853,20 +845,25 @@ mod tests { options: Default::default(), }, ]; - let exec = MemoryExec::try_new(partitions, schema, None).unwrap(); + let session_ctx = SessionContext::new(); + let task_ctx = 
Arc::new(TaskContext::from(&session_ctx)); + let exec = + MemoryExec::try_new(partitions, schema, None, session_ctx.session_id.clone()) + .unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge, runtime).await.unwrap(); + let collected = collect(merge, task_ctx).await.unwrap(); assert_batches_eq!(exp, collected.as_slice()); } async fn sorted_merge( input: Arc, sort: Vec, - runtime: Arc, + session_ctx: &SessionContext, ) -> RecordBatch { let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); - let mut result = collect(merge, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(session_ctx)); + let mut result = collect(merge, task_ctx).await.unwrap(); assert_eq!(result.len(), 1); result.remove(0) } @@ -874,33 +871,34 @@ mod tests { async fn partition_sort( input: Arc, sort: Vec, - runtime: Arc, + session_ctx: &SessionContext, ) -> RecordBatch { let sort_exec = Arc::new(SortExec::new_with_partitioning(sort.clone(), input, true)); - sorted_merge(sort_exec, sort, runtime).await + sorted_merge(sort_exec, sort, session_ctx).await } async fn basic_sort( src: Arc, sort: Vec, - runtime: Arc, + session_ctx: &SessionContext, ) -> RecordBatch { let merge = Arc::new(CoalescePartitionsExec::new(src)); let sort_exec = Arc::new(SortExec::try_new(sort, merge).unwrap()); - let mut result = collect(sort_exec, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(session_ctx)); + let mut result = collect(sort_exec, task_ctx).await.unwrap(); assert_eq!(result.len(), 1); result.remove(0) } #[tokio::test] async fn test_partition_sort() { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", partitions).unwrap(); + let session_ctx = SessionContext::new(); let csv = Arc::new(CsvExec::new( FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -913,6 +911,7 @@ mod tests { }, true, b',', + session_ctx.session_id.clone(), )); let sort = vec![ @@ -937,8 +936,8 @@ mod tests { }, ]; - let basic = basic_sort(csv.clone(), sort.clone(), runtime.clone()).await; - let partition = partition_sort(csv, sort, runtime.clone()).await; + let basic = basic_sort(csv.clone(), sort.clone(), &session_ctx).await; + let partition = partition_sort(csv, sort, &session_ctx).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -981,7 +980,7 @@ mod tests { async fn sorted_partitioned_input( sort: Vec, sizes: &[usize], - runtime: Arc, + session_ctx: &SessionContext, ) -> Arc { let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -1000,17 +999,25 @@ mod tests { }, true, b',', + session_ctx.session_id.clone(), )); - let sorted = basic_sort(csv, sort, runtime).await; + let sorted = basic_sort(csv, sort, session_ctx).await; let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect(); - Arc::new(MemoryExec::try_new(&split, sorted.schema(), None).unwrap()) + Arc::new( + MemoryExec::try_new( + &split, + sorted.schema(), + None, + session_ctx.session_id.clone(), + ) + .unwrap(), + ) } #[tokio::test] async fn test_partition_sort_streaming_input() { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let sort = vec![ // uint8 @@ -1035,10 +1042,11 @@ mod tests { }, ]; + let session_ctx = SessionContext::new(); let input = - sorted_partitioned_input(sort.clone(), &[10, 3, 11], runtime.clone()).await; - let basic = 
basic_sort(input.clone(), sort.clone(), runtime.clone()).await; - let partition = sorted_merge(input, sort, runtime.clone()).await; + sorted_partitioned_input(sort.clone(), &[10, 3, 11], &session_ctx).await; + let basic = basic_sort(input.clone(), sort.clone(), &session_ctx).await; + let partition = sorted_merge(input, sort, &session_ctx).await; assert_eq!(basic.num_rows(), 300); assert_eq!(partition.num_rows(), 300); @@ -1056,7 +1064,6 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input_output() { let schema = test_util::aggr_test_schema(); - let sort = vec![ // float64 PhysicalSortExpr { @@ -1070,15 +1077,21 @@ mod tests { }, ]; - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); let input = - sorted_partitioned_input(sort.clone(), &[10, 5, 13], runtime.clone()).await; - let basic = basic_sort(input.clone(), sort.clone(), runtime).await; + sorted_partitioned_input(sort.clone(), &[10, 5, 13], &session_ctx).await; + let basic = basic_sort(input.clone(), sort.clone(), &session_ctx).await; - let runtime_bs_23 = - Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(23)).unwrap()); - let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); - let merged = collect(merge, runtime_bs_23).await.unwrap(); + let session_ctx_bs_23 = + SessionContext::with_config(SessionConfig::new().with_batch_size(23)); + let input_bs_23 = + sorted_partitioned_input(sort.clone(), &[10, 5, 13], &session_ctx_bs_23) + .await; + let merge = Arc::new(SortPreservingMergeExec::new(sort, input_bs_23)); + + let task_ctx = Arc::new(TaskContext::from(&session_ctx_bs_23)); + assert_eq!(task_ctx.session_config().batch_size, 23); + let merged = collect(merge, task_ctx).await.unwrap(); assert_eq!(merged.len(), 14); @@ -1097,7 +1110,6 @@ mod tests { #[tokio::test] async fn test_nulls() { - let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ None, @@ -1149,10 +1161,20 @@ mod tests { }, }, ]; - let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); + + let session_ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + + let exec = MemoryExec::try_new( + &[vec![b1], vec![b2]], + schema, + None, + session_ctx.session_id.clone(), + ) + .unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge, runtime).await.unwrap(); + let collected = collect(merge, task_ctx).await.unwrap(); assert_eq!(collected.len(), 1); assert_batches_eq!( @@ -1178,15 +1200,16 @@ mod tests { #[tokio::test] async fn test_async() { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let sort = vec![PhysicalSortExpr { expr: col("c12", &schema).unwrap(), options: SortOptions::default(), }]; + let config = SessionConfig::new(); + let session_ctx = SessionContext::with_config(config.clone()); let batches = - sorted_partitioned_input(sort.clone(), &[5, 7, 3], runtime.clone()).await; + sorted_partitioned_input(sort.clone(), &[5, 7, 3], &session_ctx).await; let partition_count = batches.output_partitioning().partition_count(); let mut join_handles = Vec::with_capacity(partition_count); @@ -1194,7 +1217,8 @@ mod tests { for partition in 0..partition_count { let (mut sender, receiver) = mpsc::channel(1); - let mut stream = batches.execute(partition, runtime.clone()).await.unwrap(); + let task_ctx = 
Arc::new(TaskContext::from(&session_ctx)); + let mut stream = batches.execute(partition, task_ctx).await.unwrap(); let join_handle = tokio::spawn(async move { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); @@ -1216,7 +1240,7 @@ mod tests { batches.schema(), sort.as_slice(), tracking_metrics, - runtime.clone(), + config.batch_size, ); let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap(); @@ -1228,7 +1252,7 @@ mod tests { assert_eq!(merged.len(), 1); let merged = merged.remove(0); - let basic = basic_sort(batches, sort.clone(), runtime.clone()).await; + let basic = basic_sort(batches, sort.clone(), &session_ctx).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -1246,7 +1270,6 @@ mod tests { #[tokio::test] async fn test_merge_metrics() { - let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")])); let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); @@ -1260,10 +1283,20 @@ mod tests { expr: col("b", &schema).unwrap(), options: Default::default(), }]; - let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); + + let session_ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + + let exec = MemoryExec::try_new( + &[vec![b1], vec![b2]], + schema, + None, + session_ctx.session_id.clone(), + ) + .unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge.clone(), runtime).await.unwrap(); + let collected = collect(merge.clone(), task_ctx).await.unwrap(); let expected = vec![ "+----+---+", "| a | b |", @@ -1302,11 +1335,17 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); + let session_ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + + let blocking_exec = Arc::new(BlockingExec::new( + Arc::clone(&schema), + 2, + session_ctx.session_id.clone(), + )); let refs = blocking_exec.refs(); let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new( vec![PhysicalSortExpr { @@ -1316,7 +1355,7 @@ mod tests { blocking_exec, )); - let fut = collect(sort_preserving_merge_exec, runtime); + let fut = collect(sort_preserving_merge_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1328,8 +1367,6 @@ mod tests { #[tokio::test] async fn test_stable_sort() { - let runtime = Arc::new(RuntimeEnv::default()); - // Create record batches like: // batch_number |value // -------------+------ @@ -1365,10 +1402,18 @@ mod tests { }, }]; - let exec = MemoryExec::try_new(&partitions, schema, None).unwrap(); + let session_ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let exec = MemoryExec::try_new( + &partitions, + schema, + None, + session_ctx.session_id.clone(), + ) + .unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge, runtime).await.unwrap(); + let collected = collect(merge, task_ctx).await.unwrap(); assert_eq!(collected.len(), 1); // Expect the data to be sorted first by "batch_number" (because diff --git a/datafusion/src/physical_plan/union.rs 
b/datafusion/src/physical_plan/union.rs index 48f7b280b80e..25510badf9ac 100644 --- a/datafusion/src/physical_plan/union.rs +++ b/datafusion/src/physical_plan/union.rs @@ -32,7 +32,7 @@ use super::{ ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::{ error::Result, physical_plan::{expressions, metrics::BaselineMetrics}, @@ -104,7 +104,7 @@ impl ExecutionPlan for UnionExec { async fn execute( &self, mut partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); // record the tiny amount of work done in this function so @@ -116,7 +116,7 @@ impl ExecutionPlan for UnionExec { for input in self.inputs.iter() { // Calculate whether partition belongs to the current partition if partition < input.output_partitioning().partition_count() { - let stream = input.execute(partition, runtime.clone()).await?; + let stream = input.execute(partition, context).await?; return Ok(Box::pin(ObservedStream::new(stream, baseline_metrics))); } else { partition -= input.output_partitioning().partition_count(); @@ -156,6 +156,10 @@ impl ExecutionPlan for UnionExec { fn benefits_from_input_partitioning(&self) -> bool { false } + + fn session_id(&self) -> String { + self.inputs[0].session_id() + } } /// Stream wrapper that records `BaselineMetrics` for a particular @@ -237,6 +241,7 @@ mod tests { use crate::datasource::object_store::{local::LocalFileSystem, ObjectStore}; use crate::{test, test_util}; + use crate::prelude::SessionContext; use crate::{ physical_plan::{ collect, @@ -248,13 +253,14 @@ mod tests { #[tokio::test] async fn test_union_partitions() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let fs: Arc = Arc::new(LocalFileSystem {}); // Create csv's with different partitioning let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", 4)?; let (_, files2) = test::create_partitioned_csv("aggregate_test_100.csv", 5)?; + let session_ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); let csv = CsvExec::new( FileScanConfig { @@ -268,6 +274,7 @@ mod tests { }, true, b',', + session_ctx.session_id.clone(), ); let csv2 = CsvExec::new( @@ -282,6 +289,7 @@ mod tests { }, true, b',', + session_ctx.session_id.clone(), ); let union_exec = Arc::new(UnionExec::new(vec![Arc::new(csv), Arc::new(csv2)])); @@ -289,7 +297,7 @@ mod tests { // Should have 9 partitions and 9 output batches assert_eq!(union_exec.output_partitioning().partition_count(), 9); - let result: Vec = collect(union_exec, runtime).await?; + let result: Vec = collect(union_exec, task_ctx).await?; assert_eq!(result.len(), 9); Ok(()) diff --git a/datafusion/src/physical_plan/values.rs b/datafusion/src/physical_plan/values.rs index c65082ef0677..448ba74be1bb 100644 --- a/datafusion/src/physical_plan/values.rs +++ b/datafusion/src/physical_plan/values.rs @@ -20,7 +20,7 @@ use super::expressions::PhysicalSortExpr; use super::{common, SendableRecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::{ memory::MemoryStream, ColumnarValue, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, @@ -40,6 +40,8 @@ pub struct ValuesExec { schema: 
SchemaRef, /// The data data: Vec, + /// Session id + session_id: String, } impl ValuesExec { @@ -47,6 +49,7 @@ impl ValuesExec { pub fn try_new( schema: SchemaRef, data: Vec>>, + session_id: String, ) -> Result { if data.is_empty() { return Err(DataFusionError::Plan("Values list cannot be empty".into())); @@ -87,7 +90,11 @@ impl ValuesExec { .collect::>>()?; let batch = RecordBatch::try_new(schema.clone(), arr)?; let data: Vec = vec![batch]; - Ok(Self { schema, data }) + Ok(Self { + schema, + data, + session_id, + }) } /// provides the data @@ -136,6 +143,7 @@ impl ExecutionPlan for ValuesExec { 0 => Ok(Arc::new(ValuesExec { schema: self.schema.clone(), data: self.data.clone(), + session_id: self.session_id(), })), _ => Err(DataFusionError::Internal( "ValuesExec wrong number of children".to_string(), @@ -146,7 +154,7 @@ impl ExecutionPlan for ValuesExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { @@ -179,6 +187,10 @@ impl ExecutionPlan for ValuesExec { let batch = self.data(); common::compute_record_batch_statistics(&[batch], &self.schema, None) } + + fn session_id(&self) -> String { + self.session_id.clone() + } } #[cfg(test)] @@ -189,7 +201,7 @@ mod tests { #[tokio::test] async fn values_empty_case() -> Result<()> { let schema = test_util::aggr_test_schema(); - let empty = ValuesExec::try_new(schema, vec![]); + let empty = ValuesExec::try_new(schema, vec![], "sess_123".to_owned()); assert!(empty.is_err()); Ok(()) } diff --git a/datafusion/src/physical_plan/windows/mod.rs b/datafusion/src/physical_plan/windows/mod.rs index e833c57c5b5e..9c77e18df577 100644 --- a/datafusion/src/physical_plan/windows/mod.rs +++ b/datafusion/src/physical_plan/windows/mod.rs @@ -153,11 +153,12 @@ fn create_built_in_window_expr( mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; - use crate::execution::runtime_env::RuntimeEnv; + use crate::execution::context::TaskContext; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::col; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::{collect, Statistics}; + use crate::prelude::SessionContext; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{self, assert_is_pending}; use crate::test_util::{self, aggr_test_schema}; @@ -166,7 +167,10 @@ mod tests { use arrow::record_batch::RecordBatch; use futures::FutureExt; - fn create_test_schema(partitions: usize) -> Result<(Arc, SchemaRef)> { + fn create_test_schema( + partitions: usize, + session_ctx: &SessionContext, + ) -> Result<(Arc, SchemaRef)> { let schema = test_util::aggr_test_schema(); let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; @@ -182,6 +186,7 @@ mod tests { }, true, b',', + session_ctx.session_id.clone(), ); let input = Arc::new(csv); @@ -190,8 +195,9 @@ mod tests { #[tokio::test] async fn window_function() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); - let (input, schema) = create_test_schema(1)?; + let session_ctx = SessionContext::new(); + let (input, schema) = create_test_schema(1, &session_ctx)?; + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); let window_exec = Arc::new(WindowAggExec::try_new( vec![ @@ -227,7 +233,7 @@ mod tests { schema.clone(), )?); - let result: Vec = collect(window_exec, runtime).await?; + let result: Vec = collect(window_exec, 
task_ctx).await?; assert_eq!(result.len(), 1); let columns = result[0].columns(); @@ -251,11 +257,16 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + let session_ctx = SessionContext::new(); + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let blocking_exec = Arc::new(BlockingExec::new( + Arc::clone(&schema), + 1, + session_ctx.session_id.clone(), + )); let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( @@ -271,7 +282,7 @@ mod tests { schema, )?); - let fut = collect(window_agg_exec, runtime); + let fut = collect(window_agg_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/windows/window_agg_exec.rs b/datafusion/src/physical_plan/windows/window_agg_exec.rs index 163868d07838..c65d68f1e9d7 100644 --- a/datafusion/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/src/physical_plan/windows/window_agg_exec.rs @@ -18,7 +18,7 @@ //! Stream and channel implementations for window function expressions. use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ @@ -158,9 +158,9 @@ impl ExecutionPlan for WindowAggExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { - let input = self.input.execute(partition, runtime).await?; + let input = self.input.execute(partition, context).await?; let stream = Box::pin(WindowAggStream::new( self.schema.clone(), self.window_expr.clone(), @@ -212,6 +212,10 @@ impl ExecutionPlan for WindowAggExec { total_byte_size: None, } } + + fn session_id(&self) -> String { + self.input.session_id() + } } fn create_schema( diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index 0aff006c7896..e40693e71566 100644 --- a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -26,7 +26,7 @@ //! 
``` pub use crate::dataframe::DataFrame; -pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; +pub use crate::execution::context::{SessionConfig, SessionContext}; pub use crate::execution::options::AvroReadOptions; pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions}; pub use crate::logical_plan::{ diff --git a/datafusion/src/row/mod.rs b/datafusion/src/row/mod.rs index 531dbfe3e41e..06b1e79fefe3 100644 --- a/datafusion/src/row/mod.rs +++ b/datafusion/src/row/mod.rs @@ -556,12 +556,11 @@ mod tests { #[tokio::test] async fn test_with_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; let schema = exec.schema().clone(); - let batches = collect(exec, runtime).await?; + let batches = collect(exec).await?; assert_eq!(1, batches.len()); let batch = &batches[0]; diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs index 5a6b27865d13..8bc235f49bc5 100644 --- a/datafusion/src/test/exec.rs +++ b/datafusion/src/test/exec.rs @@ -33,6 +33,8 @@ use arrow::{ }; use futures::Stream; +use crate::execution::context::TaskContext; +use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::{ common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -41,9 +43,6 @@ use crate::{ error::{DataFusionError, Result}, physical_plan::stream::RecordBatchReceiverStream, }; -use crate::{ - execution::runtime_env::RuntimeEnv, physical_plan::expressions::PhysicalSortExpr, -}; /// Index into the data that has been returned so far #[derive(Debug, Default, Clone)] @@ -172,7 +171,7 @@ impl ExecutionPlan for MockExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { assert_eq!(partition, 0); @@ -235,6 +234,10 @@ impl ExecutionPlan for MockExec { common::compute_record_batch_statistics(&[data], &self.schema, None) } + + fn session_id(&self) -> String { + "mock_session".to_owned() + } } fn clone_error(e: &ArrowError) -> ArrowError { @@ -256,17 +259,24 @@ pub struct BarrierExec { /// all streams wait on this barrier to produce barrier: Arc, + /// Session id + session_id: String, } impl BarrierExec { /// Create a new exec with some number of partitions. 
- pub fn new(data: Vec>, schema: SchemaRef) -> Self { + pub fn new( + data: Vec>, + schema: SchemaRef, + session_id: String, + ) -> Self { // wait for all streams and the input let barrier = Arc::new(Barrier::new(data.len() + 1)); Self { data, schema, barrier, + session_id, } } @@ -311,7 +321,7 @@ impl ExecutionPlan for BarrierExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { assert!(partition < self.data.len()); @@ -354,28 +364,28 @@ impl ExecutionPlan for BarrierExec { fn statistics(&self) -> Statistics { common::compute_record_batch_statistics(&self.data, &self.schema, None) } + + fn session_id(&self) -> String { + self.session_id.clone() + } } /// A mock execution plan that errors on a call to execute #[derive(Debug)] pub struct ErrorExec { schema: SchemaRef, -} - -impl Default for ErrorExec { - fn default() -> Self { - Self::new() - } + /// Session id + session_id: String, } impl ErrorExec { - pub fn new() -> Self { + pub fn new(session_id: String) -> Self { let schema = Arc::new(Schema::new(vec![Field::new( "dummy", DataType::Int64, true, )])); - Self { schema } + Self { schema, session_id } } } @@ -412,7 +422,7 @@ impl ExecutionPlan for ErrorExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Err(DataFusionError::Internal(format!( "ErrorExec, unsurprisingly, errored in partition {}", @@ -435,6 +445,10 @@ impl ExecutionPlan for ErrorExec { fn statistics(&self) -> Statistics { Statistics::default() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } /// A mock execution plan that simply returns the provided statistics @@ -442,9 +456,11 @@ impl ExecutionPlan for ErrorExec { pub struct StatisticsExec { stats: Statistics, schema: Arc, + /// Session id + session_id: String, } impl StatisticsExec { - pub fn new(stats: Statistics, schema: Schema) -> Self { + pub fn new(stats: Statistics, schema: Schema, session_id: String) -> Self { assert!( stats .column_statistics @@ -456,6 +472,7 @@ impl StatisticsExec { Self { stats, schema: Arc::new(schema), + session_id, } } } @@ -497,7 +514,7 @@ impl ExecutionPlan for StatisticsExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { unimplemented!("This plan only serves for testing statistics") } @@ -522,6 +539,10 @@ impl ExecutionPlan for StatisticsExec { } } } + + fn session_id(&self) -> String { + self.session_id.clone() + } } /// Execution plan that emits streams that block forever. @@ -537,15 +558,19 @@ pub struct BlockingExec { /// Ref-counting helper to check if the plan and the produced stream are still in memory. refs: Arc<()>, + + /// Session id + session_id: String, } impl BlockingExec { /// Create new [`BlockingExec`] with a give schema and number of partitions. - pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { + pub fn new(schema: SchemaRef, n_partitions: usize, session_id: String) -> Self { Self { schema, n_partitions, refs: Default::default(), + session_id, } } @@ -595,7 +620,7 @@ impl ExecutionPlan for BlockingExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Ok(Box::pin(BlockingStream { schema: Arc::clone(&self.schema), @@ -618,6 +643,10 @@ impl ExecutionPlan for BlockingExec { fn statistics(&self) -> Statistics { unimplemented!() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } /// A [`RecordBatchStream`] that is pending forever. 
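The test hunks above all apply the same recipe: build a `SessionContext`, pass its `session_id` into the operator constructor, derive a per-query `TaskContext` from the session, and hand that to `execute`/`collect` instead of the old shared `RuntimeEnv`. A minimal sketch of the recipe, reusing only constructors and calls that appear in this patch; the helper name and the handling of the input batches are illustrative, not part of the patch:

use std::sync::Arc;

use arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::{collect, memory::MemoryExec};
use datafusion::prelude::SessionContext;

/// Run a MemoryExec over `partitions` inside its own session.
/// Assumes at least one non-empty partition so the schema can be read from it.
async fn collect_in_session(partitions: Vec<Vec<RecordBatch>>) -> Result<Vec<RecordBatch>> {
    // One SessionContext per session/tenant; each plan node records which session created it.
    let session_ctx = SessionContext::new();
    let schema = partitions[0][0].schema();
    let exec =
        MemoryExec::try_new(&partitions, schema, None, session_ctx.session_id.clone())?;

    // Execution receives a per-query TaskContext derived from the session,
    // rather than a process-wide RuntimeEnv.
    let task_ctx = Arc::new(TaskContext::from(&session_ctx));
    collect(Arc::new(exec), task_ctx).await
}

Two sessions created this way stay independent: each plan reports its own `session_id()` and executes against its own task state.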
diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs index 926a017f14af..6975db76d1e0 100644 --- a/datafusion/tests/custom_sources.rs +++ b/datafusion/tests/custom_sources.rs @@ -30,7 +30,6 @@ use datafusion::{ physical_plan::DisplayFormatType, }; -use datafusion::execution::context::ExecutionContext; use datafusion::logical_plan::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; @@ -46,8 +45,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use async_trait::async_trait; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; use datafusion::logical_plan::plan::Projection; +use datafusion::prelude::SessionContext; //// Custom source dataframe tests //// @@ -55,6 +55,8 @@ struct CustomTableProvider; #[derive(Debug, Clone)] struct CustomExecutionPlan { projection: Option>, + /// Session id + session_id: String, } struct TestCustomRecordBatchStream { /// the nb of batches of TEST_CUSTOM_RECORD_BATCH generated @@ -135,7 +137,7 @@ impl ExecutionPlan for CustomExecutionPlan { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) } @@ -185,6 +187,10 @@ impl ExecutionPlan for CustomExecutionPlan { ), } } + + fn session_id(&self) -> String { + self.session_id.clone() + } } #[async_trait] @@ -202,16 +208,18 @@ impl TableProvider for CustomTableProvider { projection: &Option>, _filters: &[Expr], _limit: Option, + session_id: String, ) -> Result> { Ok(Arc::new(CustomExecutionPlan { projection: projection.clone(), + session_id, })) } } #[tokio::test] async fn custom_source_dataframe() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table = ctx.read_table(Arc::new(CustomTableProvider))?; let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) @@ -246,8 +254,8 @@ async fn custom_source_dataframe() -> Result<()> { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let runtime = ctx.state.lock().runtime_env.clone(); - let batches = collect(physical_plan, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(physical_plan, task_ctx).await?; let origin_rec_batch = TEST_CUSTOM_RECORD_BATCH!()?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -258,7 +266,7 @@ async fn custom_source_dataframe() -> Result<()> { #[tokio::test] async fn optimizers_catch_all_statistics() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(CustomTableProvider)) .unwrap(); @@ -293,8 +301,8 @@ async fn optimizers_catch_all_statistics() { ) .unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let actual = collect(physical_plan, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let actual = collect(physical_plan, task_ctx).await.unwrap(); assert_eq!(actual.len(), 1); assert_eq!(format!("{:?}", actual[0]), format!("{:?}", expected)); diff --git a/datafusion/tests/dataframe.rs b/datafusion/tests/dataframe.rs index 116315e9b9b2..1c0f300e7ef2 100644 --- a/datafusion/tests/dataframe.rs +++ b/datafusion/tests/dataframe.rs @@ -25,7 +25,7 @@ use std::sync::Arc; use datafusion::assert_batches_eq; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use 
datafusion::logical_plan::{col, Expr}; use datafusion::{datasource::MemTable, prelude::JoinType}; use datafusion_expr::lit; @@ -58,7 +58,7 @@ async fn join() -> Result<()> { ], )?; - let mut ctx = ExecutionContext::new(); + let ctx = Arc::new(SessionContext::new()); let table1 = MemTable::try_new(schema1, vec![vec![batch1]])?; let table2 = MemTable::try_new(schema2, vec![vec![batch2]])?; @@ -96,7 +96,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { ) .unwrap(); - let mut ctx = ExecutionContext::new(); + let ctx = Arc::new(SessionContext::new()); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); ctx.register_table("t", Arc::new(provider)).unwrap(); @@ -132,7 +132,7 @@ async fn filter_with_alias_overwrite() -> Result<()> { ) .unwrap(); - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); ctx.register_table("t", Arc::new(provider)).unwrap(); diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs index ae521a0050ff..7266b8d3acc8 100644 --- a/datafusion/tests/dataframe_functions.rs +++ b/datafusion/tests/dataframe_functions.rs @@ -31,8 +31,6 @@ use datafusion::error::Result; // use datafusion::logical_plan::Expr; use datafusion::prelude::*; -use datafusion::execution::context::ExecutionContext; - use datafusion::assert_batches_eq; fn create_test_table() -> Result> { @@ -55,7 +53,7 @@ fn create_test_table() -> Result> { ], )?; - let mut ctx = ExecutionContext::new(); + let ctx = Arc::new(SessionContext::new()); let table = MemTable::try_new(schema, vec![vec![batch]])?; diff --git a/datafusion/tests/merge_fuzz.rs b/datafusion/tests/merge_fuzz.rs index d874ec507c49..49e6482cf8d9 100644 --- a/datafusion/tests/merge_fuzz.rs +++ b/datafusion/tests/merge_fuzz.rs @@ -23,15 +23,14 @@ use arrow::{ compute::SortOptions, record_batch::RecordBatch, }; -use datafusion::{ - execution::runtime_env::{RuntimeConfig, RuntimeEnv}, - physical_plan::{ - collect, - expressions::{col, PhysicalSortExpr}, - memory::MemoryExec, - sorts::sort_preserving_merge::SortPreservingMergeExec, - }, +use datafusion::execution::context::TaskContext; +use datafusion::physical_plan::{ + collect, + expressions::{col, PhysicalSortExpr}, + memory::MemoryExec, + sorts::sort_preserving_merge::SortPreservingMergeExec, }; +use datafusion::prelude::{SessionConfig, SessionContext}; use fuzz_utils::{add_empty_batches, batches_to_vec, partitions_to_sorted_vec}; use rand::{prelude::StdRng, Rng, SeedableRng}; @@ -116,14 +115,14 @@ async fn run_merge_test(input: Vec>) { nulls_first: true, }, }]; - - let exec = MemoryExec::try_new(&input, schema, None).unwrap(); + let session_config = SessionConfig::new().with_batch_size(batch_size); + let ctx = SessionContext::with_config(session_config); + let exec = + MemoryExec::try_new(&input, schema, None, ctx.session_id.clone()).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let runtime_config = RuntimeConfig::new().with_batch_size(batch_size); - - let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); - let collected = collect(merge, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let collected = collect(merge, task_ctx).await.unwrap(); // verify the output batch size: all batches except the last // should contain `batch_size` rows diff --git a/datafusion/tests/order_spill_fuzz.rs b/datafusion/tests/order_spill_fuzz.rs index 
b1586f06c02c..25eb9040ff99 100644 --- a/datafusion/tests/order_spill_fuzz.rs +++ b/datafusion/tests/order_spill_fuzz.rs @@ -22,12 +22,14 @@ use arrow::{ compute::SortOptions, record_batch::RecordBatch, }; +use datafusion::execution::context::TaskContext; use datafusion::execution::memory_manager::MemoryManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::physical_plan::expressions::{col, PhysicalSortExpr}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::prelude::{SessionConfig, SessionContext}; use fuzz_utils::{add_empty_batches, batches_to_vec, partitions_to_sorted_vec}; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; @@ -71,14 +73,19 @@ async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) { }, }]; - let exec = MemoryExec::try_new(&input, schema, None).unwrap(); - let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec)).unwrap()); - let runtime_config = RuntimeConfig::new().with_memory_manager( MemoryManagerConfig::try_new_limit(pool_size, 1.0).unwrap(), ); let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); - let collected = collect(sort.clone(), runtime).await.unwrap(); + let session_ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime); + + let exec = + MemoryExec::try_new(&input, schema, None, session_ctx.session_id.clone()) + .unwrap(); + let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec)).unwrap()); + + let task_ctx = Arc::new(TaskContext::from(&session_ctx)); + let collected = collect(sort.clone(), task_ctx).await.unwrap(); let expected = partitions_to_sorted_vec(&input); let actual = batches_to_vec(&collected); diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index 9869a1f6b16a..cbbe4ca939ea 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -30,6 +30,7 @@ use arrow::{ util::pretty::pretty_format_batches, }; use chrono::{Datelike, Duration}; +use datafusion::execution::context::TaskContext; use datafusion::{ datasource::TableProvider, logical_plan::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}, @@ -37,7 +38,7 @@ use datafusion::{ accept, file_format::ParquetExec, metrics::MetricsSet, ExecutionPlan, ExecutionPlanVisitor, }, - prelude::{ExecutionConfig, ExecutionContext}, + prelude::{SessionConfig, SessionContext}, scalar::ScalarValue, }; use parquet::{arrow::ArrowWriter, file::properties::WriterProperties}; @@ -161,7 +162,7 @@ async fn prune_disabled() { ); // same query, without pruning - let config = ExecutionConfig::new().with_parquet_pruning(false); + let config = SessionConfig::new().with_parquet_pruning(false); let output = ContextWithParquet::with_config(Scenario::Timestamps, config) .await @@ -424,7 +425,7 @@ struct ContextWithParquet { /// when dropped file: NamedTempFile, provider: Arc, - ctx: ExecutionContext, + ctx: SessionContext, } /// The output of running one of the test cases @@ -472,15 +473,15 @@ impl TestOutput { /// and the appropriate scenario impl ContextWithParquet { async fn new(scenario: Scenario) -> Self { - Self::with_config(scenario, ExecutionConfig::new()).await + Self::with_config(scenario, SessionConfig::new()).await } - async fn with_config(scenario: Scenario, config: ExecutionConfig) -> Self { + async fn with_config(scenario: Scenario, config: SessionConfig) -> Self { let file = make_test_file(scenario).await; let parquet_path = 
file.path().to_string_lossy(); // now, setup a the file as a data source and run a query against it - let mut ctx = ExecutionContext::with_config(config); + let ctx = SessionContext::with_config(config); ctx.register_parquet("t", &parquet_path).await.unwrap(); let provider = ctx.deregister_table("t").unwrap().unwrap(); @@ -536,9 +537,8 @@ impl ContextWithParquet { .create_physical_plan(&logical_plan) .await .expect("creating physical plan"); - - let runtime = self.ctx.state.lock().runtime_env.clone(); - let results = datafusion::physical_plan::collect(physical_plan.clone(), runtime) + let task_ctx = Arc::new(TaskContext::from(&self.ctx)); + let results = datafusion::physical_plan::collect(physical_plan.clone(), task_ctx) .await .expect("Running"); diff --git a/datafusion/tests/path_partition.rs b/datafusion/tests/path_partition.rs index 178e318775c9..827f7983e72c 100644 --- a/datafusion/tests/path_partition.rs +++ b/datafusion/tests/path_partition.rs @@ -20,6 +20,7 @@ use std::{fs, io, sync::Arc}; use async_trait::async_trait; +use datafusion::prelude::SessionContext; use datafusion::{ assert_batches_sorted_eq, datasource::{ @@ -32,17 +33,15 @@ use datafusion::{ }, error::{DataFusionError, Result}, physical_plan::ColumnStatistics, - prelude::ExecutionContext, test_util::{self, arrow_test_data, parquet_test_data}, }; use futures::{stream, StreamExt}; #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { - let mut ctx = ExecutionContext::new(); - + let ctx = Arc::new(SessionContext::new()); register_partitioned_aggregate_csv( - &mut ctx, + ctx.clone(), &[ "mytable/date=2021-10-27/file.csv", "mytable/date=2021-10-28/file.csv", @@ -75,10 +74,10 @@ async fn csv_filter_with_file_col() -> Result<()> { #[tokio::test] async fn csv_projection_on_partition() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = Arc::new(SessionContext::new()); register_partitioned_aggregate_csv( - &mut ctx, + ctx.clone(), &[ "mytable/date=2021-10-27/file.csv", "mytable/date=2021-10-28/file.csv", @@ -111,10 +110,10 @@ async fn csv_projection_on_partition() -> Result<()> { #[tokio::test] async fn csv_grouping_by_partition() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = Arc::new(SessionContext::new()); register_partitioned_aggregate_csv( - &mut ctx, + ctx.clone(), &[ "mytable/date=2021-10-26/file.csv", "mytable/date=2021-10-27/file.csv", @@ -145,10 +144,10 @@ async fn csv_grouping_by_partition() -> Result<()> { #[tokio::test] async fn parquet_multiple_partitions() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = Arc::new(SessionContext::new()); register_partitioned_alltypes_parquet( - &mut ctx, + ctx.clone(), &[ "year=2021/month=09/day=09/file.parquet", "year=2021/month=10/day=09/file.parquet", @@ -187,10 +186,10 @@ async fn parquet_multiple_partitions() -> Result<()> { #[tokio::test] async fn parquet_statistics() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = Arc::new(SessionContext::new()); register_partitioned_alltypes_parquet( - &mut ctx, + ctx.clone(), &[ "year=2021/month=09/day=09/file.parquet", "year=2021/month=10/day=09/file.parquet", @@ -246,11 +245,11 @@ async fn parquet_statistics() -> Result<()> { #[tokio::test] async fn parquet_overlapping_columns() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = Arc::new(SessionContext::new()); // `id` is both a column of the file and a partitioning col register_partitioned_alltypes_parquet( - &mut ctx, + ctx.clone(), &[ "id=1/file.parquet", 
"id=2/file.parquet", @@ -272,7 +271,7 @@ async fn parquet_overlapping_columns() -> Result<()> { } fn register_partitioned_aggregate_csv( - ctx: &mut ExecutionContext, + ctx: Arc, store_paths: &[&str], partition_cols: &[&str], table_path: &str, @@ -295,7 +294,7 @@ fn register_partitioned_aggregate_csv( } async fn register_partitioned_alltypes_parquet( - ctx: &mut ExecutionContext, + ctx: Arc, store_paths: &[&str], partition_cols: &[&str], table_path: &str, diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs index 203fb7ce56ff..078a01841ed4 100644 --- a/datafusion/tests/provider_filter_pushdown.rs +++ b/datafusion/tests/provider_filter_pushdown.rs @@ -21,8 +21,7 @@ use arrow::record_batch::RecordBatch; use async_trait::async_trait; use datafusion::datasource::datasource::{TableProvider, TableProviderFilterPushDown}; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::{SessionContext, TaskContext}; use datafusion::logical_plan::Expr; use datafusion::physical_plan::common::SizedRecordBatchStream; use datafusion::physical_plan::expressions::PhysicalSortExpr; @@ -54,6 +53,8 @@ fn create_batch(value: i32, num_rows: usize) -> Result { struct CustomPlan { schema: SchemaRef, batches: Vec>, + /// Session id + session_id: String, } #[async_trait] @@ -88,7 +89,7 @@ impl ExecutionPlan for CustomPlan { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { let metrics = ExecutionPlanMetricsSet::new(); let tracking_metrics = MemTrackingMetrics::new(&metrics, partition); @@ -116,6 +117,10 @@ impl ExecutionPlan for CustomPlan { // but we want to test the filter pushdown not the CBOs Statistics::default() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } #[derive(Clone)] @@ -139,6 +144,7 @@ impl TableProvider for CustomProvider { _: &Option>, filters: &[Expr], _: Option, + session_id: String, ) -> Result> { match &filters[0] { Expr::BinaryExpr { right, .. } => { @@ -154,11 +160,13 @@ impl TableProvider for CustomProvider { 1 => vec![Arc::new(self.one_batch.clone())], _ => vec![], }, + session_id: session_id.clone(), })) } _ => Ok(Arc::new(CustomPlan { schema: self.zero_batch.schema(), batches: vec![], + session_id, })), } } @@ -174,7 +182,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<() one_batch: create_batch(1, 5)?, }; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let df = ctx .read_table(Arc::new(provider.clone()))? .filter(col("flag").eq(lit(value)))? 
diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 187778c02fe9..447f761f40e9 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -20,14 +20,14 @@ use datafusion::scalar::ScalarValue; #[tokio::test] async fn csv_query_avg_multi_batch() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(c12) FROM aggregate_test_100"; let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let results = collect(plan, task_ctx).await.unwrap(); let batch = &results[0]; let column = batch.column(0); let array = column.as_any().downcast_ref::().unwrap(); @@ -41,10 +41,10 @@ async fn csv_query_avg_multi_batch() -> Result<()> { #[tokio::test] async fn csv_query_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.5089725099127211"]]; assert_float_eq(&expected, &actual); @@ -53,10 +53,10 @@ async fn csv_query_avg() -> Result<()> { #[tokio::test] async fn csv_query_covariance_1() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT covar_pop(c2, c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["-0.07916932235380847"]]; assert_float_eq(&expected, &actual); @@ -65,10 +65,10 @@ async fn csv_query_covariance_1() -> Result<()> { #[tokio::test] async fn csv_query_covariance_2() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT covar(c2, c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["-0.07996901247859442"]]; assert_float_eq(&expected, &actual); @@ -77,10 +77,10 @@ async fn csv_query_covariance_2() -> Result<()> { #[tokio::test] async fn csv_query_correlation() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT corr(c2, c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["-0.19064544190576607"]]; assert_float_eq(&expected, &actual); @@ -89,10 +89,10 @@ async fn csv_query_correlation() -> Result<()> { #[tokio::test] async fn csv_query_variance_1() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var_pop(c2) FROM aggregate_test_100"; - let mut 
actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["1.8675"]]; assert_float_eq(&expected, &actual); @@ -101,10 +101,10 @@ async fn csv_query_variance_1() -> Result<()> { #[tokio::test] async fn csv_query_variance_2() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var_pop(c6) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["26156334342021890000000000000000000000"]]; assert_float_eq(&expected, &actual); @@ -113,10 +113,10 @@ async fn csv_query_variance_2() -> Result<()> { #[tokio::test] async fn csv_query_variance_3() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var_pop(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.09234223721582163"]]; assert_float_eq(&expected, &actual); @@ -125,10 +125,10 @@ async fn csv_query_variance_3() -> Result<()> { #[tokio::test] async fn csv_query_variance_4() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var(c2) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["1.8863636363636365"]]; assert_float_eq(&expected, &actual); @@ -137,10 +137,10 @@ async fn csv_query_variance_4() -> Result<()> { #[tokio::test] async fn csv_query_variance_5() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var_samp(c2) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["1.8863636363636365"]]; assert_float_eq(&expected, &actual); @@ -149,10 +149,10 @@ async fn csv_query_variance_5() -> Result<()> { #[tokio::test] async fn csv_query_stddev_1() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev_pop(c2) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["1.3665650368716449"]]; assert_float_eq(&expected, &actual); @@ -161,10 +161,10 @@ async fn csv_query_stddev_1() -> Result<()> { #[tokio::test] async fn csv_query_stddev_2() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev_pop(c6) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["5114326382039172000"]]; assert_float_eq(&expected, &actual); @@ -173,10 +173,10 @@ async fn csv_query_stddev_2() -> Result<()> { #[tokio::test] async fn 
csv_query_stddev_3() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev_pop(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.30387865541334363"]]; assert_float_eq(&expected, &actual); @@ -185,10 +185,10 @@ async fn csv_query_stddev_3() -> Result<()> { #[tokio::test] async fn csv_query_stddev_4() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.3054095399405338"]]; assert_float_eq(&expected, &actual); @@ -197,10 +197,10 @@ async fn csv_query_stddev_4() -> Result<()> { #[tokio::test] async fn csv_query_stddev_5() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev_samp(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.3054095399405338"]]; assert_float_eq(&expected, &actual); @@ -209,10 +209,10 @@ async fn csv_query_stddev_5() -> Result<()> { #[tokio::test] async fn csv_query_stddev_6() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.9504384952922168"]]; assert_float_eq(&expected, &actual); @@ -221,10 +221,10 @@ async fn csv_query_stddev_6() -> Result<()> { #[tokio::test] async fn csv_query_median_1() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_median(c2) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["3"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -232,10 +232,10 @@ async fn csv_query_median_1() -> Result<()> { #[tokio::test] async fn csv_query_median_2() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_median(c6) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["1146409980542786560"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -243,10 +243,10 @@ async fn csv_query_median_2() -> Result<()> { #[tokio::test] async fn csv_query_median_3() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_median(c12) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = 
execute(&ctx, sql).await; let expected = vec![vec!["0.5550065410522981"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -254,10 +254,10 @@ async fn csv_query_median_3() -> Result<()> { #[tokio::test] async fn csv_query_external_table_count() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT COUNT(c12) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", "| COUNT(aggregate_test_100.c12) |", @@ -271,12 +271,12 @@ async fn csv_query_external_table_count() { #[tokio::test] async fn csv_query_external_table_sum() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // cast smallint and int to bigint to avoid overflow during calculation - register_aggregate_csv_by_sql(&mut ctx).await; + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------------------+-------------------------------------------+", "| SUM(CAST(aggregate_test_100.c7 AS Int64)) | SUM(CAST(aggregate_test_100.c8 AS Int64)) |", @@ -289,10 +289,10 @@ async fn csv_query_external_table_sum() { #[tokio::test] async fn csv_query_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT count(c12) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", "| COUNT(aggregate_test_100.c12) |", @@ -306,10 +306,10 @@ async fn csv_query_count() -> Result<()> { #[tokio::test] async fn csv_query_count_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT count(distinct c2) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------------------+", "| COUNT(DISTINCT aggregate_test_100.c2) |", @@ -323,10 +323,10 @@ async fn csv_query_count_distinct() -> Result<()> { #[tokio::test] async fn csv_query_count_distinct_expr() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT count(distinct c2 % 2) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------------------------------+", "| COUNT(DISTINCT aggregate_test_100.c2 % Int64(2)) |", @@ -340,10 +340,10 @@ async fn csv_query_count_distinct_expr() -> Result<()> { #[tokio::test] async fn csv_query_count_star() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT COUNT(*) FROM aggregate_test_100"; - let actual = 
execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -356,10 +356,10 @@ async fn csv_query_count_star() { #[tokio::test] async fn csv_query_count_one() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT COUNT(1) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -372,10 +372,10 @@ async fn csv_query_count_one() { #[tokio::test] async fn csv_query_approx_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+--------------+", "| count_c9 | count_c9_str |", @@ -412,15 +412,15 @@ async fn csv_query_approx_count() -> Result<()> { // float values. #[tokio::test] async fn csv_query_approx_percentile_cont() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; // Generate an assertion that the estimated $percentile value for $column is // within 5% of the $actual percentile value. macro_rules! percentile_test { ($ctx:ident, column=$column:literal, percentile=$percentile:literal, actual=$actual:literal) => { let sql = format!("SELECT (ABS(1 - CAST(approx_percentile_cont({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $percentile, $actual); - let actual = execute_to_batches(&mut ctx, &sql).await; + let actual = execute_to_batches(&ctx, &sql).await; // // "+------+", // "| q |", @@ -478,10 +478,10 @@ async fn csv_query_approx_percentile_cont() -> Result<()> { #[tokio::test] async fn csv_query_sum_crossjoin() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+-----------+", "| c1 | c1 | SUM(a.c2) |", @@ -518,9 +518,9 @@ async fn csv_query_sum_crossjoin() { #[tokio::test] async fn query_count_without_from() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT count(1 + 1)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+", "| COUNT(Int64(1) + Int64(1)) |", @@ -534,11 +534,11 @@ async fn query_count_without_from() -> Result<()> { #[tokio::test] async fn csv_query_array_agg() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT array_agg(c13) FROM (SELECT * FROM 
aggregate_test_100 ORDER BY c13 LIMIT 2) test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------------------------------------------------+", "| ARRAYAGG(test.c13) |", @@ -552,11 +552,11 @@ async fn csv_query_array_agg() -> Result<()> { #[tokio::test] async fn csv_query_array_agg_empty() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------+", "| ARRAYAGG(test.c13) |", @@ -570,11 +570,11 @@ async fn csv_query_array_agg_empty() -> Result<()> { #[tokio::test] async fn csv_query_array_agg_one() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------------+", "| ARRAYAGG(test.c13) |", @@ -588,10 +588,10 @@ async fn csv_query_array_agg_one() -> Result<()> { #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT array_agg(distinct c2) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // The results for this query should be something like the following: // +------------------------------------------+ @@ -638,11 +638,11 @@ async fn csv_query_array_agg_distinct() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_sum() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t", ) .await @@ -655,11 +655,11 @@ async fn aggregate_timestamps_sum() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT count(nanos), count(micros), count(millis), count(secs) FROM t", ) .await; @@ -678,11 +678,11 @@ async fn aggregate_timestamps_count() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_min() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT min(nanos), min(micros), min(millis), min(secs) FROM t", ) .await; @@ -701,11 +701,11 @@ async fn aggregate_timestamps_min() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = 
execute_to_batches( - &mut ctx, + &ctx, "SELECT max(nanos), max(micros), max(millis), max(secs) FROM t", ) .await; @@ -724,11 +724,11 @@ async fn aggregate_timestamps_max() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t", ) .await diff --git a/datafusion/tests/sql/avro.rs b/datafusion/tests/sql/avro.rs index 82d91a0bd481..565cc76978e9 100644 --- a/datafusion/tests/sql/avro.rs +++ b/datafusion/tests/sql/avro.rs @@ -17,7 +17,7 @@ use super::*; -async fn register_alltypes_avro(ctx: &mut ExecutionContext) { +async fn register_alltypes_avro(ctx: &SessionContext) { let testdata = datafusion::test_util::arrow_test_data(); ctx.register_avro( "alltypes_plain", @@ -30,12 +30,12 @@ async fn register_alltypes_avro(ctx: &mut ExecutionContext) { #[tokio::test] async fn avro_query() { - let mut ctx = ExecutionContext::new(); - register_alltypes_avro(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_avro(&ctx).await; // NOTE that string_col is actually a binary column and does not have the UTF8 logical type // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------------------+", "| id | CAST(alltypes_plain.string_col AS Utf8) |", @@ -71,7 +71,7 @@ async fn avro_query_multiple_files() { ) .unwrap(); - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_avro( "alltypes_plain", table_path.display().to_string().as_str(), @@ -82,7 +82,7 @@ async fn avro_query_multiple_files() { // NOTE that string_col is actually a binary column and does not have the UTF8 logical type // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------------------+", "| id | CAST(alltypes_plain.string_col AS Utf8) |", @@ -111,7 +111,7 @@ async fn avro_query_multiple_files() { #[tokio::test] async fn avro_single_nan_schema() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::arrow_test_data(); ctx.register_avro( "single_nan", @@ -124,8 +124,7 @@ async fn avro_single_nan_schema() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.unwrap(); + let results = collect(plan).await.unwrap(); for batch in results { assert_eq!(1, batch.num_rows()); assert_eq!(1, batch.num_columns()); @@ -134,11 +133,11 @@ async fn avro_single_nan_schema() { #[tokio::test] async fn avro_explain() { - let mut ctx = ExecutionContext::new(); - register_alltypes_avro(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_avro(&ctx).await; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); let expected = 
vec![ vec![ diff --git a/datafusion/tests/sql/create_drop.rs b/datafusion/tests/sql/create_drop.rs index 45f2a36047c5..748bd7ae960e 100644 --- a/datafusion/tests/sql/create_drop.rs +++ b/datafusion/tests/sql/create_drop.rs @@ -23,14 +23,14 @@ use super::*; #[tokio::test] async fn create_table_as() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; ctx.sql(sql).await.unwrap(); let sql_all = "SELECT * FROM my_table order by c1 LIMIT 1"; - let results_all = execute_to_batches(&mut ctx, sql_all).await; + let results_all = execute_to_batches(&ctx, sql_all).await; let expected = vec![ "+---------+----------------+------+", @@ -47,8 +47,8 @@ async fn create_table_as() -> Result<()> { #[tokio::test] async fn drop_table() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; ctx.sql(sql).await.unwrap(); @@ -67,10 +67,10 @@ async fn drop_table() -> Result<()> { #[tokio::test] async fn csv_query_create_external_table() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT c1, c2, c3, c4, c5, c6, c7, c8, c9, 10, c11, c12, c13 FROM aggregate_test_100 LIMIT 1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+----+-------+------------+----------------------+----+-------+------------+-----------+-------------+--------------------+--------------------------------+", "| c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | Int64(10) | c11 | c12 | c13 |", @@ -83,7 +83,7 @@ async fn csv_query_create_external_table() { #[tokio::test] async fn create_external_table_with_timestamps() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let data = "Jorge,2018-12-13T12:12:10.011Z\n\ Andrew,2018-11-13T17:11:10.011Z"; @@ -110,12 +110,12 @@ async fn create_external_table_with_timestamps() { file_path.to_str().expect("path is utf8") ); - plan_and_collect(&mut ctx, &sql) + plan_and_collect(&ctx, &sql) .await .expect("Executing CREATE EXTERNAL TABLE"); let sql = "SELECT * from csv_with_timestamps"; - let result = plan_and_collect(&mut ctx, sql).await.unwrap(); + let result = plan_and_collect(&ctx, sql).await.unwrap(); let expected = vec![ "+--------+-------------------------+", "| name | ts |", diff --git a/datafusion/tests/sql/errors.rs b/datafusion/tests/sql/errors.rs index 92b634dd5e96..af509a22dd6e 100644 --- a/datafusion/tests/sql/errors.rs +++ b/datafusion/tests/sql/errors.rs @@ -20,8 +20,8 @@ use super::*; #[tokio::test] async fn csv_query_error() -> Result<()> { // sin(utf8) should error - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; let sql = "SELECT sin(c1) FROM aggregate_test_100"; let plan = ctx.create_logical_plan(sql); assert!(plan.is_err()); @@ -31,14 +31,14 @@ async fn csv_query_error() -> Result<()> { #[tokio::test] async fn test_cast_expressions_error() -> Result<()> { // sin(utf8) should error - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + 
let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; let sql = "SELECT CAST(c1 AS INT) FROM aggregate_test_100"; let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let result = collect(plan, runtime).await; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let result = collect(plan, task_ctx).await; match result { Ok(_) => panic!("expected error"), @@ -54,8 +54,8 @@ async fn test_cast_expressions_error() -> Result<()> { #[tokio::test] async fn test_aggregation_with_bad_arguments() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; let logical_plan = ctx.create_logical_plan(sql); let err = logical_plan.unwrap_err(); @@ -71,7 +71,7 @@ async fn test_aggregation_with_bad_arguments() -> Result<()> { #[tokio::test] async fn query_cte_incorrect() -> Result<()> { - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // self reference let sql = "WITH t AS (SELECT * FROM t) SELECT * from u"; @@ -105,7 +105,7 @@ async fn query_cte_incorrect() -> Result<()> { #[tokio::test] async fn test_select_wildcard_without_table() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * "; let actual = ctx.sql(sql).await; match actual { @@ -122,8 +122,8 @@ async fn test_select_wildcard_without_table() -> Result<()> { #[tokio::test] async fn invalid_qualified_table_references() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; for table_ref in &[ "nonexistentschema.aggregate_test_100", diff --git a/datafusion/tests/sql/explain.rs b/datafusion/tests/sql/explain.rs index 00842b5eb8ab..b85228016e50 100644 --- a/datafusion/tests/sql/explain.rs +++ b/datafusion/tests/sql/explain.rs @@ -18,7 +18,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ logical_plan::{LogicalPlan, LogicalPlanBuilder, PlanType}, - prelude::ExecutionContext, + prelude::SessionContext, }; #[test] @@ -39,7 +39,7 @@ fn optimize_explain() { } // now optimize the plan and expect to see more plans - let optimized_plan = ExecutionContext::new().optimize(&plan).unwrap(); + let optimized_plan = SessionContext::new().optimize(&plan).unwrap(); if let LogicalPlan::Explain(e) = &optimized_plan { // should have more than one plan assert!( diff --git a/datafusion/tests/sql/explain_analyze.rs b/datafusion/tests/sql/explain_analyze.rs index 2051bdd1b80b..9c1936dfc5c6 100644 --- a/datafusion/tests/sql/explain_analyze.rs +++ b/datafusion/tests/sql/explain_analyze.rs @@ -21,9 +21,9 @@ use super::*; async fn explain_analyze_baseline_metrics() { // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE // and then validate the presence of baseline metrics for supported operators - let config = ExecutionConfig::new().with_target_partitions(3); - let mut ctx = ExecutionContext::with_config(config); - register_aggregate_csv_by_sql(&mut ctx).await; + let config = SessionConfig::new().with_target_partitions(3); + let ctx = SessionContext::with_config(config); + register_aggregate_csv_by_sql(&ctx).await; // a query with as many operators as we have metrics for let sql = "EXPLAIN ANALYZE \ 
SELECT count(*) as cnt FROM \ @@ -41,8 +41,8 @@ async fn explain_analyze_baseline_metrics() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(physical_plan.clone(), runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let results = collect(physical_plan.clone(), task_ctx).await.unwrap(); let formatted = arrow::util::pretty::pretty_format_batches(&results) .unwrap() .to_string(); @@ -168,8 +168,8 @@ async fn explain_analyze_baseline_metrics() { async fn csv_explain_plans() { // This test verify the look of each plan in its full cycle plan creation - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; // Logical plan @@ -329,8 +329,8 @@ async fn csv_explain_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.expect(&msg); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let results = collect(plan, task_ctx).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); @@ -342,10 +342,10 @@ async fn csv_explain_plans() { #[tokio::test] async fn csv_explain_verbose() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); @@ -365,8 +365,8 @@ async fn csv_explain_verbose() { async fn csv_explain_verbose_plans() { // This test verify the look of each plan in its full cycle plan creation - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; // Logical plan @@ -527,8 +527,8 @@ async fn csv_explain_verbose_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.expect(&msg); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let results = collect(plan, task_ctx).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); @@ -545,15 +545,15 @@ async fn csv_explain_verbose_plans() { async fn explain_analyze_runs_optimizers() { // repro for https://github.com/apache/arrow-datafusion/issues/917 // where EXPLAIN ANALYZE was not correctly running optiimizer - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // This happens as an optimization pass where count(*) can be // answered using statistics only. 
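// [Editor's note] The hunks above replace the old `ctx.state.lock().runtime_env.clone()`
// handle with a per-query TaskContext when collecting a physical plan. A minimal,
// self-contained sketch of the post-patch flow follows; the import paths for
// TaskContext and `collect` are assumptions based on how this patch uses them and
// may differ slightly between DataFusion versions.
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::collect;
use datafusion::prelude::SessionContext;

async fn collect_with_task_context() -> Result<()> {
    // SessionContext no longer needs to be mutable for planning.
    let ctx = SessionContext::new();

    // Same planning steps the tests above exercise: logical -> optimized -> physical.
    let plan = ctx.create_logical_plan("SELECT 1 + 2")?;
    let plan = ctx.optimize(&plan)?;
    let plan = ctx.create_physical_plan(&plan).await?;

    // Per-query execution state now travels in a TaskContext derived from the
    // session, rather than a RuntimeEnv pulled out of ctx.state.
    let task_ctx = Arc::new(TaskContext::from(&ctx));
    let batches = collect(plan, task_ctx).await?;
    assert_eq!(batches[0].num_rows(), 1);
    Ok(())
}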
let expected = "EmptyExec: produce_one_row=true"; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let actual = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); @@ -561,7 +561,7 @@ async fn explain_analyze_runs_optimizers() { // EXPLAIN ANALYZE should work the same let sql = "EXPLAIN ANALYZE SELECT count(*) from alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let actual = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); @@ -570,12 +570,12 @@ async fn explain_analyze_runs_optimizers() { #[tokio::test] async fn tpch_explain_q10() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); - register_tpch_csv(&mut ctx, "customer").await?; - register_tpch_csv(&mut ctx, "orders").await?; - register_tpch_csv(&mut ctx, "lineitem").await?; - register_tpch_csv(&mut ctx, "nation").await?; + register_tpch_csv(&ctx, "customer").await?; + register_tpch_csv(&ctx, "orders").await?; + register_tpch_csv(&ctx, "lineitem").await?; + register_tpch_csv(&ctx, "nation").await?; let sql = "select c_custkey, @@ -633,9 +633,9 @@ order by #[tokio::test] async fn test_physical_plan_display_indent() { // Hard code target_partitions as it appears in the RepartitionExec output - let config = ExecutionConfig::new().with_target_partitions(3); - let mut ctx = ExecutionContext::with_config(config); - register_aggregate_csv(&mut ctx).await.unwrap(); + let config = SessionConfig::new().with_target_partitions(3); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await.unwrap(); let sql = "SELECT c1, MAX(c12), MIN(c12) as the_min \ FROM aggregate_test_100 \ WHERE c12 < 10 \ @@ -679,10 +679,10 @@ async fn test_physical_plan_display_indent() { #[tokio::test] async fn test_physical_plan_display_indent_multi_children() { // Hard code target_partitions as it appears in the RepartitionExec output - let config = ExecutionConfig::new().with_target_partitions(3); - let mut ctx = ExecutionContext::with_config(config); + let config = SessionConfig::new().with_target_partitions(3); + let ctx = SessionContext::with_config(config); // ensure indenting works for nodes with multiple children - register_aggregate_csv(&mut ctx).await.unwrap(); + register_aggregate_csv(&ctx).await.unwrap(); let sql = "SELECT c1 \ FROM (select c1 from aggregate_test_100) AS a \ JOIN\ @@ -731,10 +731,10 @@ async fn test_physical_plan_display_indent_multi_children() { async fn csv_explain() { // This test uses the execute function that create full plan cycle: logical, optimized logical, and physical, // then execute the physical plan and return the final explain results - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); // Note can't use `assert_batches_eq` as the plan needs to be @@ -758,7 +758,7 @@ async fn csv_explain() { // Also, expect same result with lowercase explain let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let actual = 
normalize_vec_for_explain(actual); assert_eq!(expected, actual); } @@ -766,10 +766,10 @@ async fn csv_explain() { #[tokio::test] async fn csv_explain_analyze() { // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN ANALYZE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let formatted = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); @@ -787,11 +787,11 @@ async fn csv_explain_analyze() { #[tokio::test] async fn csv_explain_analyze_verbose() { // This test uses the execute function to run an actual plan under EXPLAIN VERBOSE ANALYZE - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN ANALYZE VERBOSE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let formatted = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); diff --git a/datafusion/tests/sql/expr.rs b/datafusion/tests/sql/expr.rs index ef8a29c97caa..3d62ddb3a011 100644 --- a/datafusion/tests/sql/expr.rs +++ b/datafusion/tests/sql/expr.rs @@ -19,13 +19,13 @@ use super::*; #[tokio::test] async fn case_when() -> Result<()> { - let mut ctx = create_case_context()?; + let ctx = create_case_context()?; let sql = "SELECT \ CASE WHEN c1 = 'a' THEN 1 \ WHEN c1 = 'b' THEN 2 \ END \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------------------------------------------------------------------+", "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN #t1.c1 = Utf8(\"b\") THEN Int64(2) END |", @@ -42,13 +42,13 @@ async fn case_when() -> Result<()> { #[tokio::test] async fn case_when_else() -> Result<()> { - let mut ctx = create_case_context()?; + let ctx = create_case_context()?; let sql = "SELECT \ CASE WHEN c1 = 'a' THEN 1 \ WHEN c1 = 'b' THEN 2 \ ELSE 999 END \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------------------------------------------------------------------------------------+", "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN #t1.c1 = Utf8(\"b\") THEN Int64(2) ELSE Int64(999) END |", @@ -65,13 +65,13 @@ async fn case_when_else() -> Result<()> { #[tokio::test] async fn case_when_with_base_expr() -> Result<()> { - let mut ctx = create_case_context()?; + let ctx = create_case_context()?; let sql = "SELECT \ CASE c1 WHEN 'a' THEN 1 \ WHEN 'b' THEN 2 \ END \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------------------------------------------------------+", "| CASE #t1.c1 WHEN Utf8(\"a\") THEN Int64(1) WHEN Utf8(\"b\") THEN Int64(2) END |", @@ -88,13 +88,13 @@ async fn case_when_with_base_expr() -> Result<()> { #[tokio::test] async fn case_when_else_with_base_expr() -> Result<()> { - let mut ctx = create_case_context()?; + let ctx = create_case_context()?; let 
sql = "SELECT \ CASE c1 WHEN 'a' THEN 1 \ WHEN 'b' THEN 2 \ ELSE 999 END \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------------------------------------------------------------------+", "| CASE #t1.c1 WHEN Utf8(\"a\") THEN Int64(1) WHEN Utf8(\"b\") THEN Int64(2) ELSE Int64(999) END |", @@ -124,10 +124,10 @@ async fn query_not() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT NOT c1 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------+", "| NOT test.c1 |", @@ -143,12 +143,12 @@ async fn query_not() -> Result<()> { #[tokio::test] async fn csv_query_sum_cast() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; // c8 = i32; c9 = i64 let sql = "SELECT c8 + c9 FROM aggregate_test_100"; // check that the physical and logical schemas are equal - execute(&mut ctx, sql).await; + execute(&ctx, sql).await; } #[tokio::test] @@ -166,10 +166,10 @@ async fn query_is_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT c1 IS NULL FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| test.c1 IS NULL |", @@ -198,10 +198,10 @@ async fn query_is_not_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT c1 IS NOT NULL FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+", "| test.c1 IS NOT NULL |", @@ -219,10 +219,10 @@ async fn query_is_not_null() -> Result<()> { async fn query_without_from() -> Result<()> { // Test for SELECT without FROM. // Should evaluate expressions in project position. 
- let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT 1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| Int64(1) |", @@ -233,7 +233,7 @@ async fn query_without_from() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT 1+2, 3/4, cos(0)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+---------------------+---------------+", "| Int64(1) + Int64(2) | Int64(3) / Int64(4) | cos(Int64(0)) |", @@ -262,10 +262,10 @@ async fn query_scalar_minus_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT 4 - c1 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------+", "| Int64(4) Minus test.c1 |", @@ -658,9 +658,9 @@ async fn test_cast_expressions() -> Result<()> { #[tokio::test] async fn test_random_expression() -> Result<()> { - let mut ctx = create_ctx()?; + let ctx = create_ctx()?; let sql = "SELECT random() r1"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let r1 = actual[0][0].parse::().unwrap(); assert!(0.0 <= r1); assert!(r1 < 1.0); @@ -669,9 +669,9 @@ async fn test_random_expression() -> Result<()> { #[tokio::test] async fn case_with_bool_type_result() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "select case when 'cpu' != 'cpu' then true else false end"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------------------------------------------------------------+", "| CASE WHEN Utf8(\"cpu\") != Utf8(\"cpu\") THEN Boolean(true) ELSE Boolean(false) END |", @@ -685,8 +685,8 @@ async fn case_with_bool_type_result() -> Result<()> { #[tokio::test] async fn in_list_array() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT c1 IN ('a', 'c') AS utf8_in_true ,c1 IN ('x', 'y') AS utf8_in_false @@ -694,7 +694,7 @@ async fn in_list_array() -> Result<()> { ,c1 NOT IN ('a', 'c') AS utf8_not_in_false ,NULL IN ('a', 'c') AS utf8_in_null FROM aggregate_test_100 WHERE c12 < 0.05"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------+---------------+------------------+-------------------+--------------+", "| utf8_in_true | utf8_in_false | utf8_not_in_true | utf8_not_in_false | utf8_in_null |", @@ -791,11 +791,11 @@ async fn test_in_list_scalar() -> Result<()> { #[tokio::test] async fn csv_query_boolean_eq_neq() { - let mut ctx = ExecutionContext::new(); - register_boolean(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_boolean(&ctx).await.unwrap(); // verify the plumbing is all hooked up for eq and neq let sql = "SELECT a, b, a = b as eq, b = true as eq_scalar, a != b as neq, a != true as neq_scalar FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, 
sql).await; let expected = vec![ "+-------+-------+-------+-----------+-------+------------+", @@ -817,11 +817,11 @@ async fn csv_query_boolean_eq_neq() { #[tokio::test] async fn csv_query_boolean_lt_lt_eq() { - let mut ctx = ExecutionContext::new(); - register_boolean(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_boolean(&ctx).await.unwrap(); // verify the plumbing is all hooked up for < and <= let sql = "SELECT a, b, a < b as lt, b = true as lt_scalar, a <= b as lt_eq, a <= true as lt_eq_scalar FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-------+-------+-----------+-------+--------------+", @@ -843,11 +843,11 @@ async fn csv_query_boolean_lt_lt_eq() { #[tokio::test] async fn csv_query_boolean_gt_gt_eq() { - let mut ctx = ExecutionContext::new(); - register_boolean(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_boolean(&ctx).await.unwrap(); // verify the plumbing is all hooked up for > and >= let sql = "SELECT a, b, a > b as gt, b = true as gt_scalar, a >= b as gt_eq, a >= true as gt_eq_scalar FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-------+-------+-----------+-------+--------------+", @@ -869,8 +869,8 @@ async fn csv_query_boolean_gt_gt_eq() { #[tokio::test] async fn csv_query_boolean_distinct_from() { - let mut ctx = ExecutionContext::new(); - register_boolean(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_boolean(&ctx).await.unwrap(); // verify the plumbing is all hooked up for is distinct from and is not distinct from let sql = "SELECT a, b, \ a is distinct from b as df, \ @@ -878,7 +878,7 @@ async fn csv_query_boolean_distinct_from() { a is not distinct from b as ndf, \ a is not distinct from true as ndf_scalar \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-------+-------+-----------+-------+------------+", @@ -900,10 +900,10 @@ async fn csv_query_boolean_distinct_from() { #[tokio::test] async fn csv_query_nullif_divide_by_0() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c8/nullif(c7, 0) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let actual = &actual[80..90]; // We just want to compare rows 80-89 let expected = vec![ vec!["258"], @@ -922,10 +922,10 @@ async fn csv_query_nullif_divide_by_0() -> Result<()> { } #[tokio::test] async fn csv_count_star() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT COUNT(*), COUNT(1) AS c, COUNT(c1) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+-----+------------------------------+", "| COUNT(UInt8(1)) | c | COUNT(aggregate_test_100.c1) |", @@ -939,10 +939,10 @@ async fn csv_count_star() -> Result<()> { #[tokio::test] async fn csv_query_avg_sqrt() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + 
register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.6706002946036462"]]; assert_float_eq(&expected, &actual); @@ -952,10 +952,10 @@ async fn csv_query_avg_sqrt() -> Result<()> { // this query used to deadlock due to the call udf(udf()) #[tokio::test] async fn csv_query_sqrt_sqrt() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; let sql = "SELECT sqrt(sqrt(c12)) FROM aggregate_test_100 LIMIT 1"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; // sqrt(sqrt(c12=0.9294097332465232)) = 0.9818650561397431 let expected = vec![vec!["0.9818650561397431"]]; assert_float_eq(&expected, &actual); diff --git a/datafusion/tests/sql/functions.rs b/datafusion/tests/sql/functions.rs index cf2475792a4e..3b6150da13a8 100644 --- a/datafusion/tests/sql/functions.rs +++ b/datafusion/tests/sql/functions.rs @@ -20,16 +20,16 @@ use super::*; /// sqrt(f32) is slightly different than sqrt(CAST(f32 AS double))) #[tokio::test] async fn sqrt_f32_vs_f64() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; // sqrt(f32)'s plan passes let sql = "SELECT avg(sqrt(c11)) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["0.6584407806396484"]]; assert_eq!(actual, expected); let sql = "SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["0.6584408483418833"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -37,10 +37,10 @@ async fn sqrt_f32_vs_f64() -> Result<()> { #[tokio::test] async fn csv_query_cast() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT CAST(c12 AS float) FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------------------------------+", @@ -57,11 +57,11 @@ async fn csv_query_cast() -> Result<()> { #[tokio::test] async fn csv_query_cast_literal() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c12, CAST(1 AS float) FROM aggregate_test_100 WHERE c12 > CAST(0 AS float) LIMIT 2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------+---------------------------+", @@ -93,10 +93,10 @@ async fn query_concat() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------------------------------+", "| 
concat(test.c1,Utf8(\"-hi-\"),CAST(test.c2 AS Utf8)) |", @@ -129,10 +129,10 @@ async fn query_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![ vec!["[,0]"], vec!["[a,1]"], @@ -160,10 +160,10 @@ async fn query_count_distinct() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(DISTINCT c1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+", "| COUNT(DISTINCT test.c1) |", diff --git a/datafusion/tests/sql/group_by.rs b/datafusion/tests/sql/group_by.rs index 38a0c2e44204..78323310a350 100644 --- a/datafusion/tests/sql/group_by.rs +++ b/datafusion/tests/sql/group_by.rs @@ -19,10 +19,10 @@ use super::*; #[tokio::test] async fn csv_query_group_by_int_min_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------+-----------------------------+", "| c2 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) |", @@ -40,12 +40,12 @@ async fn csv_query_group_by_int_min_max() -> Result<()> { #[tokio::test] async fn csv_query_group_by_float32() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "SELECT COUNT(*) as cnt, c1 FROM aggregate_simple GROUP BY c1 ORDER BY cnt DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+---------+", @@ -65,12 +65,12 @@ async fn csv_query_group_by_float32() -> Result<()> { #[tokio::test] async fn csv_query_group_by_float64() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "SELECT COUNT(*) as cnt, c2 FROM aggregate_simple GROUP BY c2 ORDER BY cnt DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+----------------+", @@ -90,12 +90,12 @@ async fn csv_query_group_by_float64() -> Result<()> { #[tokio::test] async fn csv_query_group_by_boolean() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "SELECT COUNT(*) as cnt, c3 FROM aggregate_simple GROUP BY c3 ORDER BY cnt DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+-------+", @@ -112,10 +112,10 @@ async fn csv_query_group_by_boolean() -> Result<()> { 
#[tokio::test] async fn csv_query_group_by_two_columns() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c2, MIN(c3) FROM aggregate_test_100 GROUP BY c1, c2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+----------------------------+", "| c1 | c2 | MIN(aggregate_test_100.c3) |", @@ -153,10 +153,10 @@ async fn csv_query_group_by_two_columns() -> Result<()> { #[tokio::test] async fn csv_query_group_by_and_having() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, MIN(c3) AS m FROM aggregate_test_100 GROUP BY c1 HAVING m < -100 AND MAX(c3) > 70"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+------+", "| c1 | m |", @@ -171,14 +171,14 @@ async fn csv_query_group_by_and_having() -> Result<()> { #[tokio::test] async fn csv_query_group_by_and_having_and_where() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, MIN(c3) AS m FROM aggregate_test_100 WHERE c1 IN ('a', 'b') GROUP BY c1 HAVING m < -100 AND MAX(c3) > 70"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+------+", "| c1 | m |", @@ -192,10 +192,10 @@ async fn csv_query_group_by_and_having_and_where() -> Result<()> { #[tokio::test] async fn csv_query_having_without_group_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c2, c3 FROM aggregate_test_100 HAVING c2 >= 4 AND c3 > 90"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+-----+", "| c1 | c2 | c3 |", @@ -213,10 +213,10 @@ async fn csv_query_having_without_group_by() -> Result<()> { #[tokio::test] async fn csv_query_group_by_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------+", "| c1 | AVG(aggregate_test_100.c12) |", @@ -234,10 +234,10 @@ async fn csv_query_group_by_avg() -> Result<()> { #[tokio::test] async fn csv_query_group_by_int_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, count(c12) FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-------------------------------+", "| c1 | COUNT(aggregate_test_100.c12) |", @@ -255,10 +255,10 @@ async fn csv_query_group_by_int_count() -> Result<()> { #[tokio::test] async fn 
csv_query_group_with_aliased_aggregate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, count(c12) AS count FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-------+", "| c1 | count |", @@ -276,10 +276,10 @@ async fn csv_query_group_with_aliased_aggregate() -> Result<()> { #[tokio::test] async fn csv_query_group_by_string_min_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------+-----------------------------+", "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) |", @@ -312,11 +312,11 @@ async fn query_group_on_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // Note that the results also // include a row for NULL (c1=NULL, count = 1) @@ -371,11 +371,11 @@ async fn query_group_on_null_multi_col() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // Note that the results also include values for null // include a row for NULL (c1=NULL, count = 1) @@ -393,14 +393,14 @@ async fn query_group_on_null_multi_col() -> Result<()> { // Also run query with group columns reversed (results should be the same) let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_sorted_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn csv_group_by_date() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("date", DataType::Date32, false), Field::new("cnt", DataType::Int32, false), @@ -430,7 +430,7 @@ async fn csv_group_by_date() -> Result<()> { ctx.register_table("dates", Arc::new(table))?; let sql = "SELECT SUM(cnt) FROM dates GROUP BY date"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------+", "| SUM(dates.cnt) |", diff --git a/datafusion/tests/sql/information_schema.rs b/datafusion/tests/sql/information_schema.rs index d93f0d7328d3..991bd01171aa 100644 --- a/datafusion/tests/sql/information_schema.rs +++ b/datafusion/tests/sql/information_schema.rs @@ -29,9 +29,9 @@ use super::*; #[tokio::test] async fn information_schema_tables_not_exist_by_default() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); 
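// [Editor's note] The information_schema hunks below construct the context from a
// SessionConfig and route every query through an immutable borrow. A sketch of that
// pattern, assuming SessionConfig and SessionContext are re-exported from the prelude;
// plan_and_collect mirrors the thin helper this test file defines over ctx.sql().
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::prelude::{SessionConfig, SessionContext};

async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
    ctx.sql(sql).await?.collect().await
}

async fn show_tables_with_information_schema() -> Result<()> {
    // Opt in to the information_schema catalog via SessionConfig, then hand the
    // config to the context at construction time.
    let config = SessionConfig::new().with_information_schema(true);
    let ctx = SessionContext::with_config(config);

    let batches = plan_and_collect(&ctx, "SELECT * FROM information_schema.tables").await?;
    let rows: usize = batches.iter().map(|b| b.num_rows()).sum();
    println!("{} rows visible in information_schema.tables", rows);
    Ok(())
}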
- let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let err = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap_err(); assert_eq!( @@ -42,11 +42,10 @@ async fn information_schema_tables_not_exist_by_default() { #[tokio::test] async fn information_schema_tables_no_tables() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -63,15 +62,14 @@ async fn information_schema_tables_no_tables() { #[tokio::test] async fn information_schema_tables_tables_default_catalog() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); // Now, register an empty table ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -90,7 +88,7 @@ async fn information_schema_tables_tables_default_catalog() { ctx.register_table("t2", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -109,9 +107,9 @@ async fn information_schema_tables_tables_default_catalog() { #[tokio::test] async fn information_schema_tables_tables_with_multiple_catalogs() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); + let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); schema @@ -131,7 +129,7 @@ async fn information_schema_tables_tables_with_multiple_catalogs() { catalog.register_schema("my_other_schema", Arc::new(schema)); ctx.register_catalog("my_other_catalog", Arc::new(catalog)); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -176,14 +174,14 @@ async fn information_schema_tables_table_types() { _: &Option>, _: &[Expr], _: Option, + _: String, ) -> Result> { unimplemented!() } } - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("physical", Arc::new(TestTable(TableType::Base))) .unwrap(); @@ -192,7 +190,7 @@ async fn information_schema_tables_table_types() { ctx.register_table("temp", Arc::new(TestTable(TableType::Temporary))) .unwrap(); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -212,28 +210,27 @@ async fn information_schema_tables_table_types() { #[tokio::test] async fn information_schema_show_tables_no_information_schema() { - let mut ctx = 
ExecutionContext::with_config(ExecutionConfig::new()); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); // use show tables alias - let err = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap_err(); + let err = plan_and_collect(&ctx, "SHOW TABLES").await.unwrap_err(); assert_eq!(err.to_string(), "Error during planning: SHOW TABLES is not supported unless information_schema is enabled"); } #[tokio::test] async fn information_schema_show_tables() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); // use show tables alias - let result = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap(); + let result = plan_and_collect(&ctx, "SHOW TABLES").await.unwrap(); let expected = vec![ "+---------------+--------------------+------------+------------+", @@ -246,19 +243,19 @@ async fn information_schema_show_tables() { ]; assert_batches_sorted_eq!(expected, &result); - let result = plan_and_collect(&mut ctx, "SHOW tables").await.unwrap(); + let result = plan_and_collect(&ctx, "SHOW tables").await.unwrap(); assert_batches_sorted_eq!(expected, &result); } #[tokio::test] async fn information_schema_show_columns_no_information_schema() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") + let err = plan_and_collect(&ctx, "SHOW COLUMNS FROM t") .await .unwrap_err(); @@ -267,7 +264,7 @@ async fn information_schema_show_columns_no_information_schema() { #[tokio::test] async fn information_schema_show_columns_like_where() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -275,12 +272,12 @@ async fn information_schema_show_columns_like_where() { let expected = "Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported"; - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t LIKE 'f'") + let err = plan_and_collect(&ctx, "SHOW COLUMNS FROM t LIKE 'f'") .await .unwrap_err(); assert_eq!(err.to_string(), expected); - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t WHERE column_name = 'bar'") + let err = plan_and_collect(&ctx, "SHOW COLUMNS FROM t WHERE column_name = 'bar'") .await .unwrap_err(); assert_eq!(err.to_string(), expected); @@ -288,16 +285,13 @@ async fn information_schema_show_columns_like_where() { #[tokio::test] async fn information_schema_show_columns() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") - .await - .unwrap(); + let result = plan_and_collect(&ctx, "SHOW COLUMNS FROM t").await.unwrap(); let expected = vec![ "+---------------+--------------+------------+-------------+-----------+-------------+", @@ -308,13 +302,11 @@ async fn information_schema_show_columns() { ]; assert_batches_sorted_eq!(expected, &result); - let result = plan_and_collect(&mut ctx, "SHOW columns from t") - 
.await - .unwrap(); + let result = plan_and_collect(&ctx, "SHOW columns from t").await.unwrap(); assert_batches_sorted_eq!(expected, &result); // This isn't ideal but it is consistent behavior for `SELECT * from T` - let err = plan_and_collect(&mut ctx, "SHOW columns from T") + let err = plan_and_collect(&ctx, "SHOW columns from T") .await .unwrap_err(); assert_eq!( @@ -326,14 +318,13 @@ async fn information_schema_show_columns() { // test errors with WHERE and LIKE #[tokio::test] async fn information_schema_show_columns_full_extended() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SHOW FULL COLUMNS FROM t") + let result = plan_and_collect(&ctx, "SHOW FULL COLUMNS FROM t") .await .unwrap(); let expected = vec![ @@ -345,7 +336,7 @@ async fn information_schema_show_columns_full_extended() { ]; assert_batches_sorted_eq!(expected, &result); - let result = plan_and_collect(&mut ctx, "SHOW EXTENDED COLUMNS FROM t") + let result = plan_and_collect(&ctx, "SHOW EXTENDED COLUMNS FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &result); @@ -353,14 +344,13 @@ async fn information_schema_show_columns_full_extended() { #[tokio::test] async fn information_schema_show_table_table_names() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM public.t") + let result = plan_and_collect(&ctx, "SHOW COLUMNS FROM public.t") .await .unwrap(); @@ -373,12 +363,12 @@ async fn information_schema_show_table_table_names() { ]; assert_batches_sorted_eq!(expected, &result); - let result = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t") + let result = plan_and_collect(&ctx, "SHOW columns from datafusion.public.t") .await .unwrap(); assert_batches_sorted_eq!(expected, &result); - let err = plan_and_collect(&mut ctx, "SHOW columns from t2") + let err = plan_and_collect(&ctx, "SHOW columns from t2") .await .unwrap_err(); assert_eq!( @@ -386,7 +376,7 @@ async fn information_schema_show_table_table_names() { "Error during planning: Unknown relation for SHOW COLUMNS: t2" ); - let err = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t2") + let err = plan_and_collect(&ctx, "SHOW columns from datafusion.public.t2") .await .unwrap_err(); assert_eq!( @@ -397,9 +387,9 @@ async fn information_schema_show_table_table_names() { #[tokio::test] async fn show_unsupported() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + let ctx = SessionContext::new(); - let err = plan_and_collect(&mut ctx, "SHOW SOMETHING_UNKNOWN") + let err = plan_and_collect(&ctx, "SHOW SOMETHING_UNKNOWN") .await .unwrap_err(); @@ -408,9 +398,9 @@ async fn show_unsupported() { #[tokio::test] async fn information_schema_columns_not_exist_by_default() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); - let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns") + let err = plan_and_collect(&ctx, "SELECT * from information_schema.columns") .await .unwrap_err(); assert_eq!( @@ -456,9 +446,9 @@ fn 
table_with_many_types() -> Arc { #[tokio::test] async fn information_schema_columns() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); + let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); @@ -472,7 +462,7 @@ async fn information_schema_columns() { catalog.register_schema("my_schema", Arc::new(schema)); ctx.register_catalog("my_catalog", Arc::new(catalog)); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.columns") .await .unwrap(); @@ -494,9 +484,6 @@ async fn information_schema_columns() { } /// Execute SQL and return results -async fn plan_and_collect( - ctx: &mut ExecutionContext, - sql: &str, -) -> Result> { +async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } diff --git a/datafusion/tests/sql/intersection.rs b/datafusion/tests/sql/intersection.rs index d28dd8079fa9..fadeefee4136 100644 --- a/datafusion/tests/sql/intersection.rs +++ b/datafusion/tests/sql/intersection.rs @@ -23,8 +23,8 @@ async fn intersect_with_null_not_equal() { INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2"; let expected = vec!["++", "++"]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } @@ -41,19 +41,19 @@ async fn intersect_with_null_equal() { "+-----+-----+", ]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } #[tokio::test] async fn test_intersect_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT ALL SELECT int_col, double_col FROM alltypes_plain LIMIT 4"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+------------+", "| int_col | double_col |", @@ -70,11 +70,11 @@ async fn test_intersect_all() -> Result<()> { #[tokio::test] async fn test_intersect_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT SELECT int_col, double_col FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+------------+", "| int_col | double_col |", diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs index 04436ed460b1..d8f5d4ad8a2b 100644 --- a/datafusion/tests/sql/joins.rs +++ b/datafusion/tests/sql/joins.rs @@ -20,7 +20,7 @@ use datafusion::from_slice::FromSlice; #[tokio::test] async fn equijoin() -> Result<()> { - let mut ctx = create_join_context("t1_id", 
"t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id ORDER BY t1_id", @@ -35,11 +35,11 @@ async fn equijoin() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } - let mut ctx = create_join_context_qualified()?; + let ctx = create_join_context_qualified()?; let equivalent_sql = [ "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t1.a = t2.a ORDER BY t1.a", "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t2.a = t1.a ORDER BY t1.a", @@ -54,7 +54,7 @@ async fn equijoin() -> Result<()> { "+---+-----+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -62,7 +62,7 @@ async fn equijoin() -> Result<()> { #[tokio::test] async fn equijoin_multiple_condition_ordering() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name <> t1_name ORDER BY t1_id", @@ -79,7 +79,7 @@ async fn equijoin_multiple_condition_ordering() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -87,10 +87,10 @@ async fn equijoin_multiple_condition_ordering() -> Result<()> { #[tokio::test] async fn equijoin_and_other_condition() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -105,12 +105,12 @@ async fn equijoin_and_other_condition() -> Result<()> { #[tokio::test] async fn equijoin_left_and_condition_from_right() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; let res = ctx.create_logical_plan(sql); assert!(res.is_ok()); - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -128,12 +128,12 @@ async fn equijoin_left_and_condition_from_right() -> Result<()> { #[tokio::test] async fn equijoin_right_and_condition_from_left() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t1_id >= 22 ORDER BY t2_name"; let res = ctx.create_logical_plan(sql); assert!(res.is_ok()); - let actual = execute_to_batches(&mut ctx, sql).await; 
+ let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -163,7 +163,7 @@ async fn equijoin_and_unsupported_condition() -> Result<()> { #[tokio::test] async fn left_join() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", @@ -179,7 +179,7 @@ async fn left_join() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -188,7 +188,7 @@ async fn left_join() -> Result<()> { #[tokio::test] async fn left_join_unbalanced() -> Result<()> { // the t1_id is larger than t2_id so the hash_build_probe_order optimizer should kick in - let mut ctx = create_join_context_unbalanced("t1_id", "t2_id")?; + let ctx = create_join_context_unbalanced("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", @@ -205,7 +205,7 @@ async fn left_join_unbalanced() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -216,7 +216,7 @@ async fn left_join_null_filter() -> Result<()> { // Since t2 is the non-preserved side of the join, we cannot push down a NULL filter. // Note that this is only true because IS NULL does not remove nulls. For filters that // remove nulls, we can rewrite the join as an inner join and then push down the filter. - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NULL ORDER BY t1_id"; let expected = vec![ "+-------+-------+---------+", @@ -229,7 +229,7 @@ async fn left_join_null_filter() -> Result<()> { "+-------+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } @@ -237,7 +237,7 @@ async fn left_join_null_filter() -> Result<()> { #[tokio::test] async fn left_join_null_filter_on_join_column() -> Result<()> { // Again, since t2 is the non-preserved side of the join, we cannot push down a NULL filter. 
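// [Editor's note] With registration and planning taking &self, the join tests above
// can run every equivalent SQL formulation against one shared context. A condensed,
// self-contained sketch under that assumption; the two tiny tables here are
// illustrative stand-ins for the suite's create_join_context fixtures.
use std::sync::Arc;

use datafusion::arrow::array::{Int32Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::assert_batches_eq;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

async fn equijoin_with_shared_context() -> Result<()> {
    let ctx = SessionContext::new();
    for (table, ids, names) in [
        ("t1", vec![11, 22], vec!["a", "b"]),
        ("t2", vec![11, 22], vec!["z", "y"]),
    ] {
        let schema = Arc::new(Schema::new(vec![
            Field::new(&format!("{}_id", table), DataType::Int32, false),
            Field::new(&format!("{}_name", table), DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(names)),
            ],
        )?;
        ctx.register_table(table, Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))?;
    }

    // One immutable context serves both equivalent formulations of the join.
    let equivalent_sql = [
        "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id",
        "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id ORDER BY t1_id",
    ];
    let expected = vec![
        "+-------+---------+---------+",
        "| t1_id | t1_name | t2_name |",
        "+-------+---------+---------+",
        "| 11    | a       | z       |",
        "| 22    | b       | y       |",
        "+-------+---------+---------+",
    ];
    for sql in equivalent_sql {
        let actual = ctx.sql(sql).await?.collect().await?;
        assert_batches_eq!(expected, &actual);
    }
    Ok(())
}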
- let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NULL ORDER BY t1_id"; let expected = vec![ "+-------+-------+---------+", @@ -249,14 +249,14 @@ async fn left_join_null_filter_on_join_column() -> Result<()> { "+-------+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn left_join_not_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NOT NULL ORDER BY t1_id"; let expected = vec![ "+-------+-------+---------+", @@ -268,14 +268,14 @@ async fn left_join_not_null_filter() -> Result<()> { "+-------+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn left_join_not_null_filter_on_join_column() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NOT NULL ORDER BY t1_id"; let expected = vec![ "+-------+-------+---------+", @@ -288,14 +288,14 @@ async fn left_join_not_null_filter_on_join_column() -> Result<()> { "+-------+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t2_id"; let expected = vec![ "+-------+---------+-------+", @@ -306,14 +306,14 @@ async fn right_join_null_filter() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join_null_filter_on_join_column() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NULL ORDER BY t2_id"; let expected = vec![ "+-------+---------+-------+", @@ -323,14 +323,14 @@ async fn right_join_null_filter_on_join_column() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join_not_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t2_id"; let expected = vec![ "+-------+---------+-------+", @@ -342,14 +342,14 @@ async fn right_join_not_null_filter() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; 
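(Editorial sketch, not part of the patch.) The join-test hunks around this point all apply the same mechanical change: helpers borrow the context as `&SessionContext` instead of `&mut ExecutionContext`, because table registration and `sql()` now take `&self`. A minimal, self-contained illustration of that pattern follows; the helper name `run_query`, the test name, and the one-column table are invented for the example and are not part of the patch.

```rust
use std::sync::Arc;

use datafusion::arrow::array::UInt32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

/// Plan and run a query against a shared, immutably borrowed context.
async fn run_query(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
    ctx.sql(sql).await?.collect().await
}

#[tokio::test]
async fn immutable_context_example() -> Result<()> {
    // No `mut` needed: registration works through `&self`.
    let ctx = SessionContext::new();
    let schema = Arc::new(Schema::new(vec![Field::new("t1_id", DataType::UInt32, true)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(UInt32Array::from(vec![11, 22, 33, 44]))],
    )?;
    ctx.register_table("t1", Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))?;

    let batches = run_query(&ctx, "SELECT t1_id FROM t1 ORDER BY t1_id").await?;
    assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 4);
    Ok(())
}
```

Because nothing needs `&mut`, the same `ctx` can be shared across several queries in one test without recreating it, which is exactly what the rewritten join tests do.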
assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join_not_null_filter_on_join_column() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NOT NULL ORDER BY t2_id"; let expected = vec![ "+-------+---------+-------+", @@ -362,14 +362,14 @@ async fn right_join_not_null_filter_on_join_column() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn full_join_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t1_id"; let expected = vec![ "+-------+---------+-------+", @@ -381,14 +381,14 @@ async fn full_join_null_filter() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn full_join_not_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t1_id"; let expected = vec![ "+-------+---------+-------+", @@ -402,14 +402,14 @@ async fn full_join_not_null_filter() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t2_id = t1_id ORDER BY t1_id" @@ -425,7 +425,7 @@ async fn right_join() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -433,7 +433,7 @@ async fn right_join() -> Result<()> { #[tokio::test] async fn full_join() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t2_id = t1_id ORDER BY t1_id", @@ -450,7 +450,7 @@ async fn full_join() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } @@ -459,7 +459,7 @@ async fn full_join() -> Result<()> { "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t2_id = t1_id ORDER BY t1_id", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, 
&actual); } @@ -468,9 +468,9 @@ async fn full_join() -> Result<()> { #[tokio::test] async fn left_join_using() -> Result<()> { - let mut ctx = create_join_context("id", "id")?; + let ctx = create_join_context("id", "id")?; let sql = "SELECT id, t1_name, t2_name FROM t1 LEFT JOIN t2 USING (id) ORDER BY id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+---------+---------+", "| id | t1_name | t2_name |", @@ -487,7 +487,7 @@ async fn left_join_using() -> Result<()> { #[tokio::test] async fn equijoin_implicit_syntax() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id", @@ -502,7 +502,7 @@ async fn equijoin_implicit_syntax() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -510,14 +510,14 @@ async fn equijoin_implicit_syntax() -> Result<()> { #[tokio::test] async fn equijoin_implicit_syntax_with_filter() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name \ FROM t1, t2 \ WHERE t1_id > 0 \ AND t1_id = t2_id \ AND t2_id < 99 \ ORDER BY t1_id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -533,10 +533,10 @@ async fn equijoin_implicit_syntax_with_filter() -> Result<()> { #[tokio::test] async fn equijoin_implicit_syntax_reversed() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -552,24 +552,24 @@ async fn equijoin_implicit_syntax_reversed() -> Result<()> { #[tokio::test] async fn cross_join() { - let mut ctx = create_join_context("t1_id", "t2_id").unwrap(); + let ctx = create_join_context("t1_id", "t2_id").unwrap(); let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4, actual.len()); let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4, actual.len()); let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4, actual.len()); - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -597,13 +597,13 @@ async fn cross_join() { // Two partitions (from UNION) on the left let sql = "SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name 
FROM t1) AS t1 CROSS JOIN t2"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4 * 2, actual.len()); // Two partitions (from UNION) on the right let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN (SELECT t2_name FROM t2 UNION ALL SELECT t2_name FROM t2) AS t2"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4 * 2, actual.len()); } @@ -611,12 +611,12 @@ async fn cross_join() { #[tokio::test] async fn cross_join_unbalanced() { // the t1_id is larger than t2_id so the hash_build_probe_order optimizer should kick in - let mut ctx = create_join_context_unbalanced("t1_id", "t2_id").unwrap(); + let ctx = create_join_context_unbalanced("t1_id", "t2_id").unwrap(); // the order of the values is not determinisitic, so we need to sort to check the values let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name, t2_name"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -648,7 +648,7 @@ async fn cross_join_unbalanced() { #[tokio::test] async fn test_join_timestamp() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // register time table let timestamp_schema = Arc::new(Schema::new(vec![Field::new( @@ -673,7 +673,7 @@ async fn test_join_timestamp() -> Result<()> { JOIN (SELECT * FROM timestamp) as b \ ON a.time = b.time \ ORDER BY a.time"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+-------------------------------+", @@ -691,7 +691,7 @@ async fn test_join_timestamp() -> Result<()> { #[tokio::test] async fn test_join_float32() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // register population table let population_schema = Arc::new(Schema::new(vec![ @@ -714,7 +714,7 @@ async fn test_join_float32() -> Result<()> { JOIN (SELECT * FROM population) as b \ ON a.population = b.population \ ORDER BY a.population"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+------------+------+------------+", @@ -732,7 +732,7 @@ async fn test_join_float32() -> Result<()> { #[tokio::test] async fn test_join_float64() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // register population table let population_schema = Arc::new(Schema::new(vec![ @@ -755,7 +755,7 @@ async fn test_join_float64() -> Result<()> { JOIN (SELECT * FROM population) as b \ ON a.population = b.population \ ORDER BY a.population"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+------------+------+------------+", @@ -799,8 +799,8 @@ async fn inner_join_qualified_names() -> Result<()> { ]; for sql in equivalent_sql.iter() { - let mut ctx = create_join_context_qualified()?; - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified()?; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -813,8 +813,8 @@ async fn inner_join_nulls() { let expected = vec!["++", "++"]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = 
execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; // left and right shouldn't match anything assert_batches_eq!(expected, &actual); @@ -857,14 +857,14 @@ async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Resul .unwrap(); let cities = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("countries", Arc::new(countries))?; ctx.register_table("cities", Arc::new(cities))?; // city.id is not in the on constraint, but the output result will contain both city.id and // country.id let sql = "SELECT t1.id, t2.id, t1.city, t2.country FROM cities AS t1 JOIN countries AS t2 ON t1.country_id = t2.id ORDER BY t1.id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+-----------+---------+", "| id | id | city | country |", @@ -885,7 +885,7 @@ async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Resul #[tokio::test] async fn join_timestamp() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let expected = vec![ @@ -899,7 +899,7 @@ async fn join_timestamp() -> Result<()> { ]; let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT * FROM t as t1 \ JOIN (SELECT * FROM t) as t2 \ ON t1.nanos = t2.nanos", @@ -909,7 +909,7 @@ async fn join_timestamp() -> Result<()> { assert_batches_sorted_eq!(expected, &results); let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT * FROM t as t1 \ JOIN (SELECT * FROM t) as t2 \ ON t1.micros = t2.micros", @@ -919,7 +919,7 @@ async fn join_timestamp() -> Result<()> { assert_batches_sorted_eq!(expected, &results); let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT * FROM t as t1 \ JOIN (SELECT * FROM t) as t2 \ ON t1.millis = t2.millis", diff --git a/datafusion/tests/sql/limit.rs b/datafusion/tests/sql/limit.rs index fd68e330bee1..fc2dc4c95645 100644 --- a/datafusion/tests/sql/limit.rs +++ b/datafusion/tests/sql/limit.rs @@ -19,10 +19,10 @@ use super::*; #[tokio::test] async fn csv_query_limit() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+----+", "| c1 |", "+----+", "| c |", "| d |", "+----+"]; assert_batches_eq!(expected, &actual); Ok(()) @@ -30,10 +30,10 @@ async fn csv_query_limit() -> Result<()> { #[tokio::test] async fn csv_query_limit_bigger_than_nbr_of_rows() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 200"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // println!("{}", pretty_format_batches(&a).unwrap()); let expected = vec![ "+----+", "| c2 |", "+----+", "| 2 |", "| 5 |", "| 1 |", "| 1 |", "| 5 |", @@ -56,10 +56,10 @@ async fn csv_query_limit_bigger_than_nbr_of_rows() -> Result<()> { #[tokio::test] async fn csv_query_limit_with_same_nbr_of_rows() -> Result<()> { - let mut 
ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| c2 |", "+----+", "| 2 |", "| 5 |", "| 1 |", "| 1 |", "| 5 |", "| 4 |", "| 3 |", "| 3 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 3 |", @@ -81,10 +81,10 @@ async fn csv_query_limit_with_same_nbr_of_rows() -> Result<()> { #[tokio::test] async fn csv_query_limit_zero() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 0"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["++", "++"]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index a548d619d635..e1ab5b3d297c 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -30,6 +30,7 @@ use datafusion::assert_batches_sorted_eq; use datafusion::assert_contains; use datafusion::assert_not_contains; use datafusion::datasource::TableProvider; +use datafusion::execution::context::TaskContext; use datafusion::from_slice::FromSlice; use datafusion::logical_plan::plan::{Aggregate, Projection}; use datafusion::logical_plan::LogicalPlan; @@ -45,7 +46,7 @@ use datafusion::{ error::{DataFusionError, Result}, physical_plan::ColumnarValue, }; -use datafusion::{execution::context::ExecutionContext, physical_plan::displayable}; +use datafusion::{execution::context::SessionContext, physical_plan::displayable}; /// A macro to assert that some particular line contains two substrings /// @@ -66,9 +67,9 @@ macro_rules! assert_metrics { macro_rules! 
test_expression { ($SQL:expr, $EXPECTED:expr) => { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = format!("SELECT {}", $SQL); - let actual = execute(&mut ctx, sql.as_str()).await; + let actual = execute(&ctx, sql.as_str()).await; assert_eq!(actual[0][0], $EXPECTED); }; } @@ -120,8 +121,8 @@ where } #[allow(clippy::unnecessary_wraps)] -fn create_ctx() -> Result { - let mut ctx = ExecutionContext::new(); +fn create_ctx() -> Result { + let ctx = SessionContext::new(); // register a custom UDF ctx.register_udf(create_udf( @@ -150,8 +151,8 @@ fn custom_sqrt(args: &[ColumnarValue]) -> Result { } } -fn create_case_context() -> Result { - let mut ctx = ExecutionContext::new(); +fn create_case_context() -> Result { + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)])); let data = RecordBatch::try_new( schema.clone(), @@ -167,11 +168,8 @@ fn create_case_context() -> Result { Ok(ctx) } -fn create_join_context( - column_left: &str, - column_right: &str, -) -> Result { - let mut ctx = ExecutionContext::new(); +fn create_join_context(column_left: &str, column_right: &str) -> Result { + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new(column_left, DataType::UInt32, true), @@ -214,8 +212,8 @@ fn create_join_context( Ok(ctx) } -fn create_join_context_qualified() -> Result { - let mut ctx = ExecutionContext::new(); +fn create_join_context_qualified() -> Result { + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::UInt32, true), @@ -256,8 +254,8 @@ fn create_join_context_qualified() -> Result { fn create_join_context_unbalanced( column_left: &str, column_right: &str, -) -> Result { - let mut ctx = ExecutionContext::new(); +) -> Result { + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new(column_left, DataType::UInt32, true), @@ -302,8 +300,8 @@ fn create_join_context_unbalanced( } // Create memory tables with nulls -fn create_join_context_with_nulls() -> Result { - let mut ctx = ExecutionContext::new(); +fn create_join_context_with_nulls() -> Result { + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new("t1_id", DataType::UInt32, true), @@ -405,7 +403,7 @@ fn get_tpch_table_schema(table: &str) -> Schema { } } -async fn register_tpch_csv(ctx: &mut ExecutionContext, table: &str) -> Result<()> { +async fn register_tpch_csv(ctx: &SessionContext, table: &str) -> Result<()> { let schema = get_tpch_table_schema(table); ctx.register_csv( @@ -417,7 +415,7 @@ async fn register_tpch_csv(ctx: &mut ExecutionContext, table: &str) -> Result<() Ok(()) } -async fn register_aggregate_csv_by_sql(ctx: &mut ExecutionContext) { +async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { let testdata = datafusion::test_util::arrow_test_data(); // TODO: The following c9 should be migrated to UInt32 and c10 should be UInt64 once @@ -459,7 +457,7 @@ async fn register_aggregate_csv_by_sql(ctx: &mut ExecutionContext) { } /// Create table "t1" with two boolean columns "a" and "b" -async fn register_boolean(ctx: &mut ExecutionContext) -> Result<()> { +async fn register_boolean(ctx: &SessionContext) -> Result<()> { let a: BooleanArray = [ Some(true), Some(true), @@ -494,7 +492,7 @@ async fn register_boolean(ctx: &mut ExecutionContext) -> Result<()> { Ok(()) } -async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> { +async fn 
register_aggregate_simple_csv(ctx: &SessionContext) -> Result<()> { // It's not possible to use aggregate_test_100, not enough similar values to test grouping on floats let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Float32, false), @@ -511,7 +509,7 @@ async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> Ok(()) } -async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { +async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { let testdata = datafusion::test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( @@ -524,16 +522,13 @@ async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { } /// Execute SQL and return results as a RecordBatch -async fn plan_and_collect( - ctx: &mut ExecutionContext, - sql: &str, -) -> Result<Vec<RecordBatch>> { +async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> { ctx.sql(sql).await?.collect().await } /// Execute query and return results as a Vec of RecordBatches -async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec<RecordBatch> { - let msg = format!("Creating logical plan for '{}'", sql); +async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> { + let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx.create_logical_plan(sql).expect(&msg); let logical_schema = plan.schema(); @@ -545,8 +540,8 @@ async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec -async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<Vec<String>> { +async fn execute(ctx: &SessionContext, sql: &str) -> Vec<Vec<String>> { result_vec(&execute_to_batches(ctx, sql).await) } @@ -601,7 +596,7 @@ fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> { result } -async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionContext) { +async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &SessionContext) { let df = ctx .sql( "CREATE EXTERNAL TABLE aggregate_simple ( @@ -623,7 +618,7 @@ async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionCo ); } -async fn register_alltypes_parquet(ctx: &mut ExecutionContext) { +async fn register_alltypes_parquet(ctx: &SessionContext) { let testdata = datafusion::test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", @@ -832,7 +827,7 @@ async fn nyc() -> Result<()> { Field::new("total_amount", DataType::Float64, true), ]); - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_csv( "tripdata", "file.csv", diff --git a/datafusion/tests/sql/order.rs b/datafusion/tests/sql/order.rs index d23c81778951..2683f20e7fdd 100644 --- a/datafusion/tests/sql/order.rs +++ b/datafusion/tests/sql/order.rs @@ -19,11 +19,11 @@ use super::*; #[tokio::test] async fn test_sort_unprojected_col() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // execute the query let sql = "SELECT id FROM alltypes_plain ORDER BY int_col, double_col"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| id |", "+----+", "| 4 |", "| 6 |", "| 2 |", "| 0 |", "| 5 |", "| 7 |", "| 3 |", "| 1 |", "+----+", ]; @@ -34,10 +34,10 @@ async fn test_sort_unprojected_col() -> Result<()> { #[tokio::test] async fn test_order_by_agg_expr() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); +
register_aggregate_csv(&ctx).await?; let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------------------+", "| MIN(aggregate_test_100.c12) |", @@ -48,16 +48,16 @@ async fn test_order_by_agg_expr() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12) + 0.1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn test_nulls_first_asc() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+--------+", "| num | letter |", @@ -73,9 +73,9 @@ async fn test_nulls_first_asc() -> Result<()> { #[tokio::test] async fn test_nulls_first_desc() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+--------+", "| num | letter |", @@ -91,9 +91,9 @@ async fn test_nulls_first_desc() -> Result<()> { #[tokio::test] async fn test_specific_nulls_last_desc() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC NULLS LAST"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+--------+", "| num | letter |", @@ -109,9 +109,9 @@ async fn test_specific_nulls_last_desc() -> Result<()> { #[tokio::test] async fn test_specific_nulls_first_asc() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num ASC NULLS FIRST"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+--------+", "| num | letter |", diff --git a/datafusion/tests/sql/parquet.rs b/datafusion/tests/sql/parquet.rs index 37912c8751c8..90ef36deecfb 100644 --- a/datafusion/tests/sql/parquet.rs +++ b/datafusion/tests/sql/parquet.rs @@ -24,12 +24,12 @@ use super::*; #[tokio::test] async fn parquet_query() { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // NOTE that string_col is actually a binary column and does not have the UTF8 logical type // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------------------+", "| id | CAST(alltypes_plain.string_col AS Utf8) |", @@ -50,7 +50,7 @@ async fn parquet_query() { #[tokio::test] async fn parquet_single_nan_schema() { - let mut ctx = 
ExecutionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); ctx.register_parquet("single_nan", &format!("{}/single_nan.parquet", testdata)) .await @@ -59,8 +59,8 @@ async fn parquet_single_nan_schema() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let results = collect(plan, task_ctx).await.unwrap(); for batch in results { assert_eq!(1, batch.num_rows()); assert_eq!(1, batch.num_columns()); @@ -70,7 +70,7 @@ async fn parquet_single_nan_schema() { #[tokio::test] #[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] async fn parquet_list_columns() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); ctx.register_parquet( "list_columns", @@ -96,8 +96,8 @@ async fn parquet_list_columns() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let results = collect(plan, task_ctx).await.unwrap(); // int64_list utf8_list // 0 [1, 2, 3] [abc, efg, hij] @@ -212,7 +212,7 @@ async fn schema_merge_ignores_metadata() { // Read the parquet files into a dataframe to confirm results // (no errors) - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let df = ctx .read_parquet(table_dir.to_str().unwrap().to_string()) .await diff --git a/datafusion/tests/sql/partitioned_csv.rs b/datafusion/tests/sql/partitioned_csv.rs index 3394887ad0b8..3ba370969a1c 100644 --- a/datafusion/tests/sql/partitioned_csv.rs +++ b/datafusion/tests/sql/partitioned_csv.rs @@ -25,23 +25,20 @@ use arrow::{ }; use datafusion::{ error::Result, - prelude::{CsvReadOptions, ExecutionConfig, ExecutionContext}, + prelude::{CsvReadOptions, SessionConfig, SessionContext}, }; use tempfile::TempDir; /// Execute SQL and return results -async fn plan_and_collect( - ctx: &mut ExecutionContext, - sql: &str, -) -> Result> { +async fn plan_and_collect(ctx: SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } /// Execute SQL and return results pub async fn execute(sql: &str, partition_count: usize) -> Result> { let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; - plan_and_collect(&mut ctx, sql).await + let ctx = create_ctx(&tmp_dir, partition_count).await?; + plan_and_collect(ctx, sql).await } /// Generate CSV partitions within the supplied directory @@ -77,9 +74,9 @@ fn populate_csv_partitions( pub async fn create_ctx( tmp_dir: &TempDir, partition_count: usize, -) -> Result { - let mut ctx = - ExecutionContext::with_config(ExecutionConfig::new().with_target_partitions(8)); +) -> Result { + let config = SessionConfig::new().with_target_partitions(8); + let ctx = SessionContext::with_config(config); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; diff --git a/datafusion/tests/sql/predicates.rs b/datafusion/tests/sql/predicates.rs index f4e1f4f4deef..1369baa75f4c 100644 --- a/datafusion/tests/sql/predicates.rs +++ b/datafusion/tests/sql/predicates.rs @@ 
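(Editorial sketch, not part of the patch.) The parquet and partitioned_csv hunks above show the two execution-side changes together: a context is built from a `SessionConfig` via `SessionContext::with_config`, and a physical plan is executed against an `Arc<TaskContext>` derived from the context instead of a `RuntimeEnv` cloned out of `ctx.state`. The sketch below strings those calls together; the query text and partition count are arbitrary, but every API call used here also appears elsewhere in this patch.

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::collect;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::test]
async fn physical_plan_with_task_context_example() -> Result<()> {
    // Build the session from an explicit SessionConfig instead of the old ExecutionConfig.
    let config = SessionConfig::new().with_target_partitions(4);
    let ctx = SessionContext::with_config(config);

    // Same plan -> optimize -> physical-plan pipeline the parquet tests use.
    let plan = ctx.create_logical_plan("SELECT 1 + 1 AS two")?;
    let plan = ctx.optimize(&plan)?;
    let plan = ctx.create_physical_plan(&plan).await?;

    // Execution now takes a per-query TaskContext derived from the SessionContext.
    let task_ctx = Arc::new(TaskContext::from(&ctx));
    let batches = collect(plan, task_ctx).await?;
    assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
    Ok(())
}
```

This is the same shape that `parquet_single_nan_schema` and the projection tests now use after dropping the `ctx.state.lock().runtime_env.clone()` pattern.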
-19,10 +19,10 @@ use super::*; #[tokio::test] async fn csv_query_with_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c12 FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+---------------------+", "| c1 | c12 |", @@ -37,10 +37,10 @@ async fn csv_query_with_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_negative_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c4 FROM aggregate_test_100 WHERE c3 < -55 AND -c4 > 30000"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+--------+", "| c1 | c4 |", @@ -55,10 +55,10 @@ async fn csv_query_with_negative_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_negated_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE NOT(c1 != 'a')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -72,10 +72,10 @@ async fn csv_query_with_negated_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_is_not_null_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NOT NULL"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -89,10 +89,10 @@ async fn csv_query_with_is_not_null_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_is_null_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NULL"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -106,12 +106,12 @@ async fn csv_query_with_is_null_predicate() -> Result<()> { #[tokio::test] async fn query_where_neg_num() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; // Negative numbers do not parse correctly as of Arrow 2.0.0 let sql = "select c7, c8 from aggregate_test_100 where c7 >= -2 and c7 < 10"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-------+", "| c7 | c8 |", @@ -127,18 +127,18 @@ async fn query_where_neg_num() -> Result<()> { // Also check floating point neg numbers let sql = "select c7, c8 from aggregate_test_100 where c7 >= -2.9 and 
c7 < 10"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn like() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT COUNT(c1) FROM aggregate_test_100 WHERE c13 LIKE '%FB%'"; // check that the physical and logical schemas are equal - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------------+", "| COUNT(aggregate_test_100.c1) |", @@ -152,10 +152,10 @@ async fn like() -> Result<()> { #[tokio::test] async fn csv_between_expr() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 BETWEEN 0.995 AND 1.0"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c4 |", @@ -169,10 +169,10 @@ async fn csv_between_expr() -> Result<()> { #[tokio::test] async fn csv_between_expr_negated() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 NOT BETWEEN 0 AND 0.995"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c4 |", @@ -193,11 +193,11 @@ async fn like_on_strings() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -220,11 +220,11 @@ async fn like_on_string_dictionaries() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -247,11 +247,11 @@ async fn test_regexp_is_match() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT * FROM test WHERE c1 ~ 'z'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -262,7 +262,7 @@ async fn test_regexp_is_match() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT * FROM test WHERE c1 ~* 'z'"; - let actual = 
execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -274,7 +274,7 @@ async fn test_regexp_is_match() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT * FROM test WHERE c1 !~ 'z'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -287,7 +287,7 @@ async fn test_regexp_is_match() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT * FROM test WHERE c1 !~* 'z'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -313,8 +313,8 @@ async fn except_with_null_not_equal() { "+-----+-----+", ]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } @@ -325,19 +325,19 @@ async fn except_with_null_equal() { EXCEPT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2"; let expected = vec!["++", "++"]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } #[tokio::test] async fn test_expect_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT ALL SELECT int_col, double_col FROM alltypes_plain where int_col < 1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+------------+", "| int_col | double_col |", @@ -354,11 +354,11 @@ async fn test_expect_all() -> Result<()> { #[tokio::test] async fn test_expect_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT SELECT int_col, double_col FROM alltypes_plain where int_col < 1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+------------+", "| int_col | double_col |", diff --git a/datafusion/tests/sql/projection.rs b/datafusion/tests/sql/projection.rs index 0a956a9411eb..3960ecec2400 100644 --- a/datafusion/tests/sql/projection.rs +++ b/datafusion/tests/sql/projection.rs @@ -22,10 +22,10 @@ use super::*; #[tokio::test] async fn projection_same_fields() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "select (1+1) as a from (select 1 as a) as b;"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| a |", "+---+", "| 2 |", "+---+"]; assert_batches_eq!(expected, &actual); @@ -35,13 +35,13 @@ async fn projection_same_fields() -> Result<()> { #[tokio::test] async fn projection_type_alias() -> Result<()> { - let mut ctx = 
ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; // Query that aliases one column to the name of a different column // that also has a different type (c1 == float32, c3 == boolean) let sql = "SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", @@ -58,10 +58,10 @@ async fn projection_type_alias() -> Result<()> { #[tokio::test] async fn csv_query_group_by_avg_with_projection() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------------------+----+", "| AVG(aggregate_test_100.c12) | c1 |", @@ -139,7 +139,6 @@ async fn projection_on_table_scan() -> Result<()> { let tmp_dir = TempDir::new()?; let partition_count = 4; let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; - let runtime = ctx.state.lock().runtime_env.clone(); let table = ctx.table("test")?; let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) @@ -171,7 +170,8 @@ async fn projection_on_table_scan() -> Result<()> { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let batches = collect(physical_plan, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(physical_plan, task_ctx).await?; assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); Ok(()) @@ -218,7 +218,7 @@ async fn projection_on_memory_scan() -> Result<()> { .build()?; assert_fields_eq(&plan, vec!["b"]); - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let optimized_plan = ctx.optimize(&plan)?; match &optimized_plan { LogicalPlan::Projection(Projection { input, .. 
}) => match &**input { @@ -247,8 +247,8 @@ async fn projection_on_memory_scan() -> Result<()> { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("b", physical_plan.schema().field(0).name().as_str()); - let runtime = ctx.state.lock().runtime_env.clone(); - let batches = collect(physical_plan, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let batches = collect(physical_plan, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(4, batches[0].num_rows()); diff --git a/datafusion/tests/sql/references.rs b/datafusion/tests/sql/references.rs index 779c6a336673..23f6058b8034 100644 --- a/datafusion/tests/sql/references.rs +++ b/datafusion/tests/sql/references.rs @@ -19,8 +19,8 @@ use super::*; #[tokio::test] async fn qualified_table_references() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; for table_ref in &[ "aggregate_test_100", @@ -28,7 +28,7 @@ async fn qualified_table_references() -> Result<()> { "datafusion.public.aggregate_test_100", ] { let sql = format!("SELECT COUNT(*) FROM {}", table_ref); - let actual = execute_to_batches(&mut ctx, &sql).await; + let actual = execute_to_batches(&ctx, &sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -43,7 +43,7 @@ async fn qualified_table_references() -> Result<()> { #[tokio::test] async fn qualified_table_references_and_fields() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let c1: StringArray = vec!["foofoo", "foobar", "foobaz"] .into_iter() @@ -73,7 +73,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { // however, enclosing it in double quotes is ok let sql = r#"SELECT "f.c1" from test"#; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------+", "| f.c1 |", @@ -86,12 +86,12 @@ async fn qualified_table_references_and_fields() -> Result<()> { assert_batches_eq!(expected, &actual); // Works fully qualified too let sql = r#"SELECT test."f.c1" from test"#; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); // check that duplicated table name and column name are ok let sql = r#"SELECT "test.c2" as expr1, test."test.c2" as expr2 from test"#; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-------+", "| expr1 | expr2 |", @@ -107,7 +107,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { // datafusion should run the query, not that someone should write // this let sql = r#"SELECT "....", "...." as c3 from test order by "....""#; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+----+", "| .... 
| c3 |", @@ -123,7 +123,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { #[tokio::test] async fn test_partial_qualified_name() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1.t1_id, t1_name FROM public.t1"; let expected = vec![ "+-------+---------+", @@ -135,7 +135,7 @@ async fn test_partial_qualified_name() -> Result<()> { "| 44 | d |", "+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } diff --git a/datafusion/tests/sql/select.rs b/datafusion/tests/sql/select.rs index 6ba190856a46..8add52ae34d2 100644 --- a/datafusion/tests/sql/select.rs +++ b/datafusion/tests/sql/select.rs @@ -24,12 +24,12 @@ use tempfile::TempDir; #[tokio::test] async fn all_where_empty() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT * FROM aggregate_test_100 WHERE 1=2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["++", "++"]; assert_batches_eq!(expected, &actual); Ok(()) @@ -37,10 +37,10 @@ async fn all_where_empty() -> Result<()> { #[tokio::test] async fn select_values_list() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); { let sql = "VALUES (1)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", "| column1 |", @@ -52,7 +52,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (-1)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", "| column1 |", @@ -64,7 +64,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (2+1,2-1,2>1)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+---------+", "| column1 | column2 | column3 |", @@ -86,7 +86,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1),(2)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", "| column1 |", @@ -104,7 +104,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,'a'),(2,'b')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -137,7 +137,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,'a'),(NULL,'b'),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -151,7 +151,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (NULL,'a'),(NULL,'b'),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -165,7 +165,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (NULL,'a'),(NULL,'b'),(NULL,'c')"; - let actual = 
execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -179,7 +179,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,'a'),(2,NULL),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -193,7 +193,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,NULL),(2,NULL),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -207,7 +207,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,2,3,4,5,6,7,8,9,10,11,12,13,NULL,'F',3.5)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+", "| column1 | column2 | column3 | column4 | column5 | column6 | column7 | column8 | column9 | column10 | column11 | column12 | column13 | column14 | column15 | column16 |", @@ -219,7 +219,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "SELECT * FROM (VALUES (1,'a'),(2,NULL)) AS t(c1, c2)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+", "| c1 | c2 |", @@ -232,7 +232,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "EXPLAIN VALUES (1, 'a', -1, 1.1),(NULL, 'b', -3, 0.5)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------+-----------------------------------------------------------------------------------------------------------+", "| plan_type | plan |", @@ -249,14 +249,14 @@ async fn select_values_list() -> Result<()> { #[tokio::test] async fn select_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "SELECT c1 FROM aggregate_simple order by c1"; - let results = execute_to_batches(&mut ctx, sql).await; + let results = execute_to_batches(&ctx, sql).await; let sql_all = "SELECT ALL c1 FROM aggregate_simple order by c1"; - let results_all = execute_to_batches(&mut ctx, sql_all).await; + let results_all = execute_to_batches(&ctx, sql_all).await; let expected = vec![ "+---------+", @@ -288,11 +288,11 @@ async fn select_all() -> Result<()> { #[tokio::test] async fn select_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "SELECT DISTINCT * FROM aggregate_simple"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let mut dedup = actual.clone(); @@ -305,11 +305,11 @@ async fn select_distinct() -> Result<()> { #[tokio::test] async fn select_distinct_simple_1() { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await.unwrap(); 
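(Editorial sketch, not part of the patch.) The `select_*` tests in this file lean on CSV-registration helpers such as `register_aggregate_simple_csv`, which after this change also take `&SessionContext`. A hedged sketch of such a helper follows; the helper name, path argument, `has_header` setting, and the `c2`/`c3` columns are assumptions for illustration, only the overall shape follows the patch.

```rust
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

/// Register a small CSV file with an explicit schema; `&SessionContext` is enough.
async fn register_example_csv(ctx: &SessionContext, path: &str) -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("c1", DataType::Float32, false),
        Field::new("c2", DataType::Float64, false),
        Field::new("c3", DataType::Boolean, false),
    ]);
    ctx.register_csv(
        "aggregate_simple",
        path,
        CsvReadOptions::new().schema(&schema).has_header(true),
    )
    .await?;
    Ok(())
}
```

A test would then call `register_example_csv(&ctx, "some.csv").await?` once and issue its `SELECT` statements through `execute_to_batches(&ctx, sql)`, as the surrounding hunks do.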
let sql = "SELECT DISTINCT c1 FROM aggregate_simple order by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", @@ -327,11 +327,11 @@ async fn select_distinct_simple_1() { #[tokio::test] async fn select_distinct_simple_2() { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await.unwrap(); let sql = "SELECT DISTINCT c1, c2 FROM aggregate_simple order by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+----------------+", @@ -349,11 +349,11 @@ async fn select_distinct_simple_2() { #[tokio::test] async fn select_distinct_simple_3() { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await.unwrap(); let sql = "SELECT distinct c3 FROM aggregate_simple order by c3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", @@ -368,11 +368,11 @@ async fn select_distinct_simple_3() { #[tokio::test] async fn select_distinct_simple_4() { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await.unwrap(); let sql = "SELECT distinct c1+c2 as a FROM aggregate_simple"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+", @@ -390,7 +390,7 @@ async fn select_distinct_simple_4() { #[tokio::test] async fn select_distinct_from() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "select 1 IS DISTINCT FROM CAST(NULL as INT) as a, @@ -400,7 +400,7 @@ async fn select_distinct_from() { NULL IS DISTINCT FROM NULL as e, NULL IS NOT DISTINCT FROM NULL as f "; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+-------+-------+------+-------+------+", "| a | b | c | d | e | f |", @@ -413,7 +413,7 @@ async fn select_distinct_from() { #[tokio::test] async fn select_distinct_from_utf8() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "select 'x' IS DISTINCT FROM NULL as a, @@ -421,7 +421,7 @@ async fn select_distinct_from_utf8() { 'x' IS NOT DISTINCT FROM NULL as c, 'x' IS NOT DISTINCT FROM 'x' as d "; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+-------+-------+------+", "| a | b | c | d |", @@ -434,10 +434,10 @@ async fn select_distinct_from_utf8() { #[tokio::test] async fn csv_query_with_decimal_by_sql() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_simple_aggregate_csv_with_decimal_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_simple_aggregate_csv_with_decimal_by_sql(&ctx).await; let sql = "SELECT c1 from aggregate_simple"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| c1 |", @@ -465,10 +465,10 @@ async fn csv_query_with_decimal_by_sql() -> Result<()> { #[tokio::test] async fn 
use_between_expression_in_select_query() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT 1 NOT BETWEEN 3 AND 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------------------------+", "| Int64(1) NOT BETWEEN Int64(3) AND Int64(5) |", @@ -484,7 +484,7 @@ async fn use_between_expression_in_select_query() -> Result<()> { ctx.register_table("test", Arc::new(table))?; let sql = "SELECT abs(c1) BETWEEN 0 AND LoG(c1 * 100 ) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // Expect field name to be correctly converted for expr, low and high. let expected = vec![ "+--------------------------------------------------------------------+", @@ -499,7 +499,7 @@ async fn use_between_expression_in_select_query() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "EXPLAIN SELECT c1 BETWEEN 2 AND 3 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let formatted = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); @@ -515,7 +515,7 @@ async fn use_between_expression_in_select_query() -> Result<()> { #[tokio::test] async fn query_get_indexed_field() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![Field::new( "some_list", DataType::List(Box::new(Field::new("item", DataType::Int64, true))), @@ -539,7 +539,7 @@ async fn query_get_indexed_field() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 7 |", "+----+", ]; @@ -549,7 +549,7 @@ async fn query_get_indexed_field() -> Result<()> { #[tokio::test] async fn query_nested_get_indexed_field() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); // Nested schema of { "some_list": [[i64]] } let schema = Arc::new(Schema::new(vec![Field::new( @@ -585,7 +585,7 @@ async fn query_nested_get_indexed_field() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| i0 |", @@ -597,7 +597,7 @@ async fn query_nested_get_indexed_field() -> Result<()> { ]; assert_batches_eq!(expected, &actual); let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+", ]; @@ -607,7 +607,7 @@ async fn query_nested_get_indexed_field() -> Result<()> { #[tokio::test] async fn query_nested_get_indexed_field_on_struct() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); // Nested schema of { "some_struct": { "bar": [i64] } } let struct_fields = 
vec![Field::new("bar", nested_dt.clone(), true)]; @@ -635,7 +635,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------+", "| l0 |", @@ -647,7 +647,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { ]; assert_batches_eq!(expected, &actual); let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 8 |", "+----+", ]; @@ -676,12 +676,12 @@ async fn query_on_string_dictionary() -> Result<()> { .unwrap(); let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; // Basic SELECT let sql = "SELECT d1 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -695,7 +695,7 @@ async fn query_on_string_dictionary() -> Result<()> { // basic filtering let sql = "SELECT d1 FROM test WHERE d1 IS NOT NULL"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -708,7 +708,7 @@ async fn query_on_string_dictionary() -> Result<()> { // comparison with constant let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -720,7 +720,7 @@ async fn query_on_string_dictionary() -> Result<()> { // comparison with another dictionary column let sql = "SELECT d1 FROM test WHERE d1 = d2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -732,7 +732,7 @@ async fn query_on_string_dictionary() -> Result<()> { // order comparison with another dictionary column let sql = "SELECT d1 FROM test WHERE d1 <= d2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -744,7 +744,7 @@ async fn query_on_string_dictionary() -> Result<()> { // comparison with a non dictionary column let sql = "SELECT d1 FROM test WHERE d1 = d3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -756,7 +756,7 @@ async fn query_on_string_dictionary() -> Result<()> { // filtering with constant let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -768,7 +768,7 @@ async fn query_on_string_dictionary() -> Result<()> { // Expression evaluation let sql = "SELECT concat(d1, '-foo') FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------------+", "| concat(test.d1,Utf8(\"-foo\")) |", @@ -782,7 +782,7 @@ 
async fn query_on_string_dictionary() -> Result<()> { // Expression evaluation with two dictionaries let sql = "SELECT concat(d1, d2) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+", "| concat(test.d1,test.d2) |", @@ -796,7 +796,7 @@ async fn query_on_string_dictionary() -> Result<()> { // aggregation let sql = "SELECT COUNT(d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------+", "| COUNT(test.d1) |", @@ -808,7 +808,7 @@ async fn query_on_string_dictionary() -> Result<()> { // aggregation min let sql = "SELECT MIN(d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------+", "| MIN(test.d1) |", @@ -820,7 +820,7 @@ async fn query_on_string_dictionary() -> Result<()> { // aggregation max let sql = "SELECT MAX(d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------+", "| MAX(test.d1) |", @@ -832,7 +832,7 @@ async fn query_on_string_dictionary() -> Result<()> { // grouping let sql = "SELECT d1, COUNT(*) FROM test group by d1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-----------------+", "| d1 | COUNT(UInt8(1)) |", @@ -846,7 +846,7 @@ async fn query_on_string_dictionary() -> Result<()> { // window functions let sql = "SELECT d1, row_number() OVER (partition by d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+--------------+", "| d1 | ROW_NUMBER() |", @@ -865,11 +865,11 @@ async fn query_on_string_dictionary() -> Result<()> { async fn query_cte() -> Result<()> { // Test for SELECT without FROM. // Should evaluate expressions in project position. 
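The query_on_string_dictionary hunks above exercise SQL over dictionary-encoded Utf8 columns, but the batch construction is only partly visible in this excerpt. A hedged sketch of how such a column can be built and registered on the new SessionContext (column name and values are illustrative):

```rust
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, DictionaryArray};
use datafusion::arrow::datatypes::Int32Type;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    // Dictionary-encoded string column containing a null.
    let d1: DictionaryArray<Int32Type> =
        vec![Some("one"), None, Some("three")].into_iter().collect();
    let batch = RecordBatch::try_from_iter(vec![("d1", Arc::new(d1) as ArrayRef)])?;

    let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
    let ctx = SessionContext::new();
    ctx.register_table("test", Arc::new(table))?;

    // Filtering and aggregation run directly on the dictionary column.
    let batches = ctx
        .sql("SELECT d1, COUNT(*) FROM test WHERE d1 IS NOT NULL GROUP BY d1")
        .await?
        .collect()
        .await?;
    assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 2);
    Ok(())
}
```

This is the same kind of input the filtering, aggregate and window-function assertions above run against.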
- let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // simple with let sql = "WITH t AS (SELECT 1) SELECT * FROM t"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| Int64(1) |", @@ -882,19 +882,19 @@ async fn query_cte() -> Result<()> { // with + union let sql = "WITH t AS (SELECT 1 AS a), u AS (SELECT 2 AS a) SELECT * FROM t UNION ALL SELECT * FROM u"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| a |", "+---+", "| 1 |", "| 2 |", "+---+"]; assert_batches_eq!(expected, &actual); // with + join let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT 1 AS id2, 5 as x) SELECT x FROM t JOIN u ON (id1 = id2)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| x |", "+---+", "| 5 |", "+---+"]; assert_batches_eq!(expected, &actual); // backward reference let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT * FROM t) SELECT * from u"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+-----+", "| id1 |", "+-----+", "| 1 |", "+-----+"]; assert_batches_eq!(expected, &actual); @@ -903,8 +903,8 @@ async fn query_cte() -> Result<()> { #[tokio::test] async fn csv_select_nested() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT o1, o2, c3 FROM ( SELECT c1 AS o1, c2 + 1 AS o2, c3 @@ -915,7 +915,7 @@ async fn csv_select_nested() -> Result<()> { ORDER BY c2 ASC, c3 ASC ) AS a ) AS b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+------+", "| o1 | o2 | c3 |", @@ -945,8 +945,8 @@ async fn parallel_query_with_filter() -> Result<()> { let physical_plan = ctx.create_physical_plan(&logical_plan).await?; - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect_partitioned(physical_plan, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let results = collect_partitioned(physical_plan, task_ctx).await?; // note that the order of partitions is not deterministic let mut num_rows = 0; @@ -991,11 +991,11 @@ async fn parallel_query_with_filter() -> Result<()> { #[tokio::test] async fn query_empty_table() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let empty_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))); ctx.register_table("test_tbl", empty_table).unwrap(); let sql = "SELECT * FROM test_tbl"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("Query empty table"); let expected = vec!["++", "++"]; diff --git a/datafusion/tests/sql/timestamp.rs b/datafusion/tests/sql/timestamp.rs index 42aa3f450163..47e41bf892c7 100644 --- a/datafusion/tests/sql/timestamp.rs +++ b/datafusion/tests/sql/timestamp.rs @@ -20,7 +20,7 @@ use datafusion::from_slice::FromSlice; #[tokio::test] async fn query_cast_timestamp_millis() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); let t1_data = RecordBatch::try_new( @@ -35,7 +35,7 @@ async fn 
query_cast_timestamp_millis() -> Result<()> { ctx.register_table("t1", Arc::new(t1_table))?; let sql = "SELECT to_timestamp_millis(ts) FROM t1 LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------+", @@ -52,7 +52,7 @@ async fn query_cast_timestamp_millis() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_micros() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); let t1_data = RecordBatch::try_new( @@ -67,7 +67,7 @@ async fn query_cast_timestamp_micros() -> Result<()> { ctx.register_table("t1", Arc::new(t1_table))?; let sql = "SELECT to_timestamp_micros(ts) FROM t1 LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------+", @@ -85,7 +85,7 @@ async fn query_cast_timestamp_micros() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_seconds() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); let t1_data = RecordBatch::try_new( @@ -98,7 +98,7 @@ async fn query_cast_timestamp_seconds() -> Result<()> { ctx.register_table("t1", Arc::new(t1_table))?; let sql = "SELECT to_timestamp_seconds(ts) FROM t1 LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------+", @@ -116,12 +116,12 @@ async fn query_cast_timestamp_seconds() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_nanos_to_others() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("ts_data", make_timestamp_nano_table()?)?; // Original column is nanos, convert to millis and check timestamp let sql = "SELECT to_timestamp_millis(ts) FROM ts_data LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", @@ -135,7 +135,7 @@ async fn query_cast_timestamp_nanos_to_others() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT to_timestamp_micros(ts) FROM ts_data LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", @@ -149,7 +149,7 @@ async fn query_cast_timestamp_nanos_to_others() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT to_timestamp_seconds(ts) FROM ts_data LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------------+", "| totimestampseconds(ts_data.ts) |", @@ -166,12 +166,12 @@ async fn query_cast_timestamp_nanos_to_others() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_seconds_to_others() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("ts_secs", make_timestamp_table::()?)?; // Original column is seconds, convert to millis and check timestamp let sql = "SELECT to_timestamp_millis(ts) FROM ts_secs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, 
sql).await; let expected = vec![ "+-------------------------------+", "| totimestampmillis(ts_secs.ts) |", @@ -186,7 +186,7 @@ async fn query_cast_timestamp_seconds_to_others() -> Result<()> { // Original column is seconds, convert to micros and check timestamp let sql = "SELECT to_timestamp_micros(ts) FROM ts_secs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", "| totimestampmicros(ts_secs.ts) |", @@ -200,7 +200,7 @@ async fn query_cast_timestamp_seconds_to_others() -> Result<()> { // to nanos let sql = "SELECT to_timestamp(ts) FROM ts_secs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+", "| totimestamp(ts_secs.ts) |", @@ -216,7 +216,7 @@ async fn query_cast_timestamp_seconds_to_others() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_micros_to_others() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table( "ts_micros", make_timestamp_table::()?, @@ -224,7 +224,7 @@ async fn query_cast_timestamp_micros_to_others() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT to_timestamp_millis(ts) FROM ts_micros LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------------+", "| totimestampmillis(ts_micros.ts) |", @@ -238,7 +238,7 @@ async fn query_cast_timestamp_micros_to_others() -> Result<()> { // Original column is micros, convert to seconds and check timestamp let sql = "SELECT to_timestamp_seconds(ts) FROM ts_micros LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------------+", "| totimestampseconds(ts_micros.ts) |", @@ -252,7 +252,7 @@ async fn query_cast_timestamp_micros_to_others() -> Result<()> { // Original column is micros, convert to nanos and check timestamp let sql = "SELECT to_timestamp(ts) FROM ts_micros LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+", "| totimestamp(ts_micros.ts) |", @@ -268,11 +268,11 @@ async fn query_cast_timestamp_micros_to_others() -> Result<()> { #[tokio::test] async fn to_timestamp() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("ts_data", make_timestamp_nano_table()?)?; let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", @@ -287,14 +287,14 @@ async fn to_timestamp() -> Result<()> { #[tokio::test] async fn to_timestamp_millis() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table( "ts_data", make_timestamp_table::()?, )?; let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_millis('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -308,14 +308,14 @@ async fn 
to_timestamp_millis() -> Result<()> { #[tokio::test] async fn to_timestamp_micros() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table( "ts_data", make_timestamp_table::()?, )?; let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_micros('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", @@ -330,11 +330,11 @@ async fn to_timestamp_micros() -> Result<()> { #[tokio::test] async fn to_timestamp_seconds() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("ts_data", make_timestamp_table::()?)?; let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_seconds('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", @@ -349,11 +349,11 @@ async fn to_timestamp_seconds() -> Result<()> { #[tokio::test] async fn count_distinct_timestamps() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("ts_data", make_timestamp_nano_table()?)?; let sql = "SELECT COUNT(DISTINCT(ts)) FROM ts_data"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+", @@ -369,8 +369,8 @@ async fn count_distinct_timestamps() -> Result<()> { #[tokio::test] async fn test_current_timestamp_expressions() -> Result<()> { let t1 = chrono::Utc::now().timestamp(); - let mut ctx = ExecutionContext::new(); - let actual = execute(&mut ctx, "SELECT NOW(), NOW() as t2").await; + let ctx = SessionContext::new(); + let actual = execute(&ctx, "SELECT NOW(), NOW() as t2").await; let res1 = actual[0][0].as_str(); let res2 = actual[0][1].as_str(); let t3 = chrono::Utc::now().timestamp(); @@ -387,7 +387,7 @@ async fn test_current_timestamp_expressions() -> Result<()> { #[tokio::test] async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { let t1 = chrono::Utc::now().timestamp(); - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT NOW(), NOW() as t2"; let msg = format!("Creating logical plan for '{}'", sql); @@ -397,8 +397,8 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { let plan = ctx.create_physical_plan(&plan).await.expect(&msg); let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().runtime_env.clone(); - let res = collect(plan, runtime).await.expect(&msg); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let res = collect(plan, task_ctx).await.expect(&msg); let actual = result_vec(&res); let res1 = actual[0][0].as_str(); @@ -416,7 +416,7 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { #[tokio::test] async fn timestamp_minmax() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_tz_table::(None)?; let table_b = make_timestamp_tz_table::(Some("UTC".to_owned()))?; @@ -424,7 +424,7 @@ async fn timestamp_minmax() -> Result<()> { ctx.register_table("table_b", table_b)?; let sql = "SELECT MIN(table_a.ts), MAX(table_b.ts) FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let 
expected = vec![ "+-------------------------+----------------------------+", "| MIN(table_a.ts) | MAX(table_b.ts) |", @@ -440,7 +440,7 @@ async fn timestamp_minmax() -> Result<()> { #[tokio::test] async fn timestamp_coercion() -> Result<()> { { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_tz_table::(Some("UTC".to_owned()))?; let table_b = @@ -449,7 +449,7 @@ async fn timestamp_coercion() -> Result<()> { ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+-------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -469,14 +469,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -496,14 +496,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -523,14 +523,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+---------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -550,14 +550,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -577,14 +577,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let 
mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -604,14 +604,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+---------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -631,14 +631,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+-------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -658,14 +658,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -685,14 +685,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+---------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -712,14 +712,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT 
table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+-------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -739,14 +739,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -770,7 +770,7 @@ async fn timestamp_coercion() -> Result<()> { #[tokio::test] async fn group_by_timestamp_millis() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new( @@ -802,7 +802,7 @@ async fn group_by_timestamp_millis() -> Result<()> { let sql = "SELECT timestamp, SUM(count) FROM t1 GROUP BY timestamp ORDER BY timestamp ASC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+---------------+", "| timestamp | SUM(t1.count) |", diff --git a/datafusion/tests/sql/udf.rs b/datafusion/tests/sql/udf.rs index 6b714cb368b8..13cb5ab3faf6 100644 --- a/datafusion/tests/sql/udf.rs +++ b/datafusion/tests/sql/udf.rs @@ -27,10 +27,10 @@ use datafusion::{ /// physical plan have the same schema. 
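Several hunks in this file and the earlier ones (parallel_query_with_filter, test_current_timestamp_expressions_non_optimized, and the scalar_udf test just below) stop cloning the RuntimeEnv out of the locked context state and instead derive a TaskContext from the SessionContext before calling collect. A sketch of that lower-level execution path, assuming whatever table `sql` references is already registered:

```rust
use std::sync::Arc;

use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::collect;
use datafusion::prelude::SessionContext;

/// Plan `sql` and drive the physical plan directly, mirroring the reworked tests.
async fn run_manually(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
    let logical = ctx.create_logical_plan(sql)?;
    let optimized = ctx.optimize(&logical)?;
    let physical = ctx.create_physical_plan(&optimized).await?;

    // The plan is executed against a per-query TaskContext derived from the
    // session, rather than against a shared RuntimeEnv handle.
    let task_ctx = Arc::new(TaskContext::from(ctx));
    collect(physical, task_ctx).await
}
```

parallel_query_with_filter uses the partition-preserving variant, collect_partitioned, with the same TaskContext, and on the operator side ExecutionPlan::execute now receives this Arc<TaskContext> instead of an Arc<RuntimeEnv>.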
#[tokio::test] async fn csv_query_custom_udf_with_cast() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["0.6584408483418833"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -51,7 +51,7 @@ async fn scalar_udf() -> Result<()> { ], )?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; ctx.register_table("t", Arc::new(provider))?; @@ -97,8 +97,8 @@ async fn scalar_udf() -> Result<()> { let plan = ctx.optimize(&plan)?; let plan = ctx.create_physical_plan(&plan).await?; - let runtime = ctx.state.lock().runtime_env.clone(); - let result = collect(plan, runtime).await?; + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let result = collect(plan, task_ctx).await?; let expected = vec![ "+-----+-----+-----------------+", @@ -155,7 +155,7 @@ async fn simple_udaf() -> Result<()> { vec![Arc::new(Int32Array::from_slice(&[4, 5]))], )?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; ctx.register_table("t", Arc::new(provider))?; @@ -172,7 +172,7 @@ async fn simple_udaf() -> Result<()> { ctx.register_udaf(my_avg); - let result = plan_and_collect(&mut ctx, "SELECT MY_AVG(a) FROM t").await?; + let result = plan_and_collect(&ctx, "SELECT MY_AVG(a) FROM t").await?; let expected = vec![ "+-------------+", diff --git a/datafusion/tests/sql/unicode.rs b/datafusion/tests/sql/unicode.rs index 55747f2a9ac4..b9c9cbd1c5bc 100644 --- a/datafusion/tests/sql/unicode.rs +++ b/datafusion/tests/sql/unicode.rs @@ -116,10 +116,10 @@ async fn generic_query_length>>( let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT length(c1) FROM test"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["0"], vec!["1"], vec!["2"], vec!["3"]]; assert_eq!(expected, actual); Ok(()) diff --git a/datafusion/tests/sql/union.rs b/datafusion/tests/sql/union.rs index a1f81d24f456..4ebd4e88e3b0 100644 --- a/datafusion/tests/sql/union.rs +++ b/datafusion/tests/sql/union.rs @@ -19,9 +19,9 @@ use super::*; #[tokio::test] async fn union_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT 1 as x UNION ALL SELECT 2 as x"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| x |", "+---+", "| 1 |", "| 2 |", "+---+"]; assert_batches_eq!(expected, &actual); Ok(()) @@ -29,20 +29,20 @@ async fn union_all() -> Result<()> { #[tokio::test] async fn csv_union_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1 FROM aggregate_test_100 UNION ALL SELECT c1 FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(actual.len(), 200); Ok(()) } #[tokio::test] async fn union_distinct() -> Result<()> { - let mut ctx = 
ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT 1 as x UNION SELECT 1 as x"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| x |", "+---+", "| 1 |", "+---+"]; assert_batches_eq!(expected, &actual); Ok(()) @@ -50,10 +50,10 @@ async fn union_distinct() -> Result<()> { #[tokio::test] async fn union_all_with_aggregate() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT SUM(d) FROM (SELECT 1 as c, 2 as d UNION ALL SELECT 1 as c, 3 AS d) as a"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| SUM(a.d) |", diff --git a/datafusion/tests/sql/window.rs b/datafusion/tests/sql/window.rs index 321ab320f5be..ba493e6fe559 100644 --- a/datafusion/tests/sql/window.rs +++ b/datafusion/tests/sql/window.rs @@ -20,8 +20,8 @@ use super::*; /// for window functions without order by the first, last, and nth function call does not make sense #[tokio::test] async fn csv_query_window_with_empty_over() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "select \ c9, \ count(c5) over (), \ @@ -30,7 +30,7 @@ async fn csv_query_window_with_empty_over() -> Result<()> { from aggregate_test_100 \ order by c9 \ limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------+------------------------------+----------------------------+----------------------------+", "| c9 | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) |", @@ -49,8 +49,8 @@ async fn csv_query_window_with_empty_over() -> Result<()> { /// for window functions without order by the first, last, and nth function call does not make sense #[tokio::test] async fn csv_query_window_with_partition_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "select \ c9, \ sum(cast(c4 as Int)) over (partition by c3), \ @@ -61,7 +61,7 @@ async fn csv_query_window_with_partition_by() -> Result<()> { from aggregate_test_100 \ order by c9 \ limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", "| c9 | SUM(CAST(aggregate_test_100.c4 AS Int32)) | AVG(CAST(aggregate_test_100.c4 AS Int32)) | COUNT(CAST(aggregate_test_100.c4 AS Int32)) | MAX(CAST(aggregate_test_100.c4 AS Int32)) | MIN(CAST(aggregate_test_100.c4 AS Int32)) |", @@ -79,8 +79,8 @@ async fn csv_query_window_with_partition_by() -> Result<()> { #[tokio::test] async fn csv_query_window_with_order_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "select \ c9, \ sum(c5) over (order by c9), \ @@ -94,7 +94,7 @@ async fn csv_query_window_with_order_by() -> Result<()> { from aggregate_test_100 \ order by c9 \ 
limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", "| c9 | SUM(aggregate_test_100.c5) | AVG(aggregate_test_100.c5) | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) | FIRST_VALUE(aggregate_test_100.c5) | LAST_VALUE(aggregate_test_100.c5) | NTH_VALUE(aggregate_test_100.c5,Int64(2)) |", @@ -112,8 +112,8 @@ async fn csv_query_window_with_order_by() -> Result<()> { #[tokio::test] async fn csv_query_window_with_partition_by_order_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "select \ c9, \ sum(c5) over (partition by c4 order by c9), \ @@ -127,7 +127,7 @@ async fn csv_query_window_with_partition_by_order_by() -> Result<()> { from aggregate_test_100 \ order by c9 \ limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", "| c9 | SUM(aggregate_test_100.c5) | AVG(aggregate_test_100.c5) | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) | FIRST_VALUE(aggregate_test_100.c5) | LAST_VALUE(aggregate_test_100.c5) | NTH_VALUE(aggregate_test_100.c5,Int64(2)) |", diff --git a/datafusion/tests/statistics.rs b/datafusion/tests/statistics.rs index c5fba894e686..25e0499e7479 100644 --- a/datafusion/tests/statistics.rs +++ b/datafusion/tests/statistics.rs @@ -29,12 +29,12 @@ use datafusion::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }, - prelude::ExecutionContext, scalar::ScalarValue, }; use async_trait::async_trait; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; +use datafusion::prelude::SessionContext; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -42,10 +42,12 @@ use datafusion::execution::runtime_env::RuntimeEnv; struct StatisticsValidation { stats: Statistics, schema: Arc, + /// Session id + session_id: String, } impl StatisticsValidation { - fn new(stats: Statistics, schema: SchemaRef) -> Self { + fn new(stats: Statistics, schema: SchemaRef, session_id: String) -> Self { assert!( stats .column_statistics @@ -54,7 +56,11 @@ impl StatisticsValidation { .unwrap_or(true), "if defined, the column statistics vector length should be the number of fields" ); - Self { stats, schema } + Self { + stats, + schema, + session_id, + } } } @@ -74,6 +80,7 @@ impl TableProvider for StatisticsValidation { filters: &[Expr], // limit is ignored because it is not mandatory for a `TableProvider` to honor it _limit: Option, + session_id: String, ) -> Result> { // Filters should not be pushed down as they are marked as unsupported by default. 
assert_eq!( @@ -102,6 +109,7 @@ impl TableProvider for StatisticsValidation { total_byte_size: None, }, projected_schema, + session_id, ))) } } @@ -144,7 +152,7 @@ impl ExecutionPlan for StatisticsValidation { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { unimplemented!("This plan only serves for testing statistics") } @@ -169,12 +177,19 @@ impl ExecutionPlan for StatisticsValidation { } } } + + fn session_id(&self) -> String { + self.session_id.clone() + } } -fn init_ctx(stats: Statistics, schema: Schema) -> Result { - let mut ctx = ExecutionContext::new(); - let provider: Arc = - Arc::new(StatisticsValidation::new(stats, Arc::new(schema))); +fn init_ctx(stats: Statistics, schema: Schema) -> Result { + let ctx = SessionContext::new(); + let provider: Arc = Arc::new(StatisticsValidation::new( + stats, + Arc::new(schema), + ctx.session_id.clone(), + )); ctx.register_table("stats_table", provider)?; Ok(ctx) } @@ -210,7 +225,7 @@ fn fully_defined() -> (Statistics, Schema) { #[tokio::test] async fn sql_basic() -> Result<()> { let (stats, schema) = fully_defined(); - let mut ctx = init_ctx(stats.clone(), schema)?; + let ctx = init_ctx(stats.clone(), schema)?; let df = ctx.sql("SELECT * from stats_table").await.unwrap(); @@ -228,7 +243,7 @@ async fn sql_basic() -> Result<()> { #[tokio::test] async fn sql_filter() -> Result<()> { let (stats, schema) = fully_defined(); - let mut ctx = init_ctx(stats, schema)?; + let ctx = init_ctx(stats, schema)?; let df = ctx .sql("SELECT * FROM stats_table WHERE c1 = 5") @@ -249,7 +264,7 @@ async fn sql_filter() -> Result<()> { #[tokio::test] async fn sql_limit() -> Result<()> { let (stats, schema) = fully_defined(); - let mut ctx = init_ctx(stats.clone(), schema)?; + let ctx = init_ctx(stats.clone(), schema)?; let df = ctx.sql("SELECT * FROM stats_table LIMIT 5").await.unwrap(); let physical_plan = ctx @@ -284,7 +299,7 @@ async fn sql_limit() -> Result<()> { #[tokio::test] async fn sql_window() -> Result<()> { let (stats, schema) = fully_defined(); - let mut ctx = init_ctx(stats.clone(), schema)?; + let ctx = init_ctx(stats.clone(), schema)?; let df = ctx .sql("SELECT c2, sum(c1) over (partition by c2) FROM stats_table") diff --git a/datafusion/tests/user_defined_plan.rs b/datafusion/tests/user_defined_plan.rs index 17578047378a..f024c639f753 100644 --- a/datafusion/tests/user_defined_plan.rs +++ b/datafusion/tests/user_defined_plan.rs @@ -69,7 +69,6 @@ use arrow::{ }; use datafusion::{ error::{DataFusionError, Result}, - execution::context::ExecutionContextState, execution::context::QueryPlanner, logical_plan::{Expr, LogicalPlan, UserDefinedLogicalNode}, optimizer::{optimizer::OptimizerRule, utils::optimize_children}, @@ -79,21 +78,21 @@ use datafusion::{ DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalPlanner, RecordBatchStream, SendableRecordBatchStream, Statistics, }, - prelude::{ExecutionConfig, ExecutionContext}, + prelude::{SessionConfig, SessionContext}, }; use fmt::Debug; use std::task::{Context, Poll}; use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use async_trait::async_trait; -use datafusion::execution::context::ExecutionProps; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::{ExecutionProps, SessionState, TaskContext}; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::logical_plan::plan::{Extension, Sort}; use datafusion::logical_plan::{DFSchemaRef, Limit}; /// Execute the specified sql and 
return the resulting record batches /// pretty printed as a String. -async fn exec_sql(ctx: &mut ExecutionContext, sql: &str) -> Result { +async fn exec_sql(ctx: &SessionContext, sql: &str) -> Result { let df = ctx.sql(sql).await?; let batches = df.collect().await?; pretty_format_batches(&batches) @@ -102,26 +101,24 @@ async fn exec_sql(ctx: &mut ExecutionContext, sql: &str) -> Result { } /// Create a test table. -async fn setup_table(mut ctx: ExecutionContext) -> Result { +async fn setup_table(ctx: SessionContext) -> Result { let sql = "CREATE EXTERNAL TABLE sales(customer_id VARCHAR, revenue BIGINT) STORED AS CSV location 'tests/customer.csv'"; let expected = vec!["++", "++"]; - let s = exec_sql(&mut ctx, sql).await?; + let s = exec_sql(&ctx, sql).await?; let actual = s.lines().collect::>(); assert_eq!(expected, actual, "Creating table"); Ok(ctx) } -async fn setup_table_without_schemas( - mut ctx: ExecutionContext, -) -> Result { +async fn setup_table_without_schemas(ctx: SessionContext) -> Result { let sql = "CREATE EXTERNAL TABLE sales STORED AS CSV location 'tests/customer.csv'"; let expected = vec!["++", "++"]; - let s = exec_sql(&mut ctx, sql).await?; + let s = exec_sql(&ctx, sql).await?; let actual = s.lines().collect::>(); assert_eq!(expected, actual, "Creating table"); @@ -135,10 +132,7 @@ const QUERY: &str = // Run the query using the specified execution context and compare it // to the known result -async fn run_and_compare_query( - mut ctx: ExecutionContext, - description: &str, -) -> Result<()> { +async fn run_and_compare_query(ctx: &SessionContext, description: &str) -> Result<()> { let expected = vec![ "+-------------+---------+", "| customer_id | revenue |", @@ -149,7 +143,7 @@ async fn run_and_compare_query( "+-------------+---------+", ]; - let s = exec_sql(&mut ctx, QUERY).await?; + let s = exec_sql(ctx, QUERY).await?; let actual = s.lines().collect::>(); assert_eq!( @@ -166,7 +160,7 @@ async fn run_and_compare_query( // Run the query using the specified execution context and compare it // to the known result async fn run_and_compare_query_with_auto_schemas( - mut ctx: ExecutionContext, + ctx: &SessionContext, description: &str, ) -> Result<()> { let expected = vec![ @@ -179,7 +173,7 @@ async fn run_and_compare_query_with_auto_schemas( "+----------+----------+", ]; - let s = exec_sql(&mut ctx, QUERY1).await?; + let s = exec_sql(ctx, QUERY1).await?; let actual = s.lines().collect::>(); assert_eq!( @@ -196,15 +190,15 @@ async fn run_and_compare_query_with_auto_schemas( #[tokio::test] // Run the query using default planners and optimizer async fn normal_query_without_schemas() -> Result<()> { - let ctx = setup_table_without_schemas(ExecutionContext::new()).await?; - run_and_compare_query_with_auto_schemas(ctx, "Default context").await + let ctx = setup_table_without_schemas(SessionContext::new()).await?; + run_and_compare_query_with_auto_schemas(&ctx, "Default context").await } #[tokio::test] // Run the query using default planners and optimizer async fn normal_query() -> Result<()> { - let ctx = setup_table(ExecutionContext::new()).await?; - run_and_compare_query(ctx, "Default context").await + let ctx = setup_table(SessionContext::new()).await?; + run_and_compare_query(&ctx, "Default context").await } #[tokio::test] @@ -212,13 +206,13 @@ async fn normal_query() -> Result<()> { async fn topk_query() -> Result<()> { // Note the only difference is that the top let ctx = setup_table(make_topk_context()).await?; - run_and_compare_query(ctx, "Topk context").await + 
run_and_compare_query(&ctx, "Topk context").await } #[tokio::test] // Run EXPLAIN PLAN and show the plan was in fact rewritten async fn topk_plan() -> Result<()> { - let mut ctx = setup_table(make_topk_context()).await?; + let ctx = setup_table(make_topk_context()).await?; let expected = vec![ "| logical_plan after topk | TopK: k=3 |", @@ -227,7 +221,7 @@ async fn topk_plan() -> Result<()> { ].join("\n"); let explain_query = format!("EXPLAIN VERBOSE {}", QUERY); - let actual_output = exec_sql(&mut ctx, &explain_query).await?; + let actual_output = exec_sql(&ctx, &explain_query).await?; // normalize newlines (output on windows uses \r\n) let actual_output = actual_output.replace("\r\n", "\n"); @@ -247,13 +241,14 @@ async fn topk_plan() -> Result<()> { Ok(()) } -fn make_topk_context() -> ExecutionContext { - let config = ExecutionConfig::new() +fn make_topk_context() -> SessionContext { + let config = SessionConfig::new().with_target_partitions(48); + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); + let state = SessionState::with_config(config, runtime) .with_query_planner(Arc::new(TopKQueryPlanner {})) - .with_target_partitions(48) .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); - - ExecutionContext::with_config(config) + let ctx = SessionContext::with_state(state); + ctx } // ------ The implementation of the TopK code follows ----- @@ -267,7 +262,7 @@ impl QueryPlanner for TopKQueryPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { // Teach the default physical planner how to plan TopK nodes. let physical_planner = @@ -276,7 +271,7 @@ impl QueryPlanner for TopKQueryPlanner { )]); // Delegate most work of physical planning to the default physical planner physical_planner - .create_physical_plan(logical_plan, ctx_state) + .create_physical_plan(logical_plan, session_state) .await } } @@ -386,7 +381,7 @@ impl ExtensionPlanner for TopKPlanner { node: &dyn UserDefinedLogicalNode, logical_inputs: &[&LogicalPlan], physical_inputs: &[Arc], - _ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>> { Ok( if let Some(topk_node) = node.as_any().downcast_ref::() { @@ -396,6 +391,7 @@ impl ExtensionPlanner for TopKPlanner { Some(Arc::new(TopKExec { input: physical_inputs[0].clone(), k: topk_node.k, + session_id: session_state.session_id.clone(), })) } else { None @@ -410,6 +406,8 @@ struct TopKExec { input: Arc, /// The maxium number of values k: usize, + /// Session id + session_id: String, } impl Debug for TopKExec { @@ -457,6 +455,7 @@ impl ExecutionPlan for TopKExec { 1 => Ok(Arc::new(TopKExec { input: children[0].clone(), k: self.k, + session_id: self.session_id(), })), _ => Err(DataFusionError::Internal( "TopKExec wrong number of children".to_string(), @@ -468,7 +467,7 @@ impl ExecutionPlan for TopKExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -478,7 +477,7 @@ impl ExecutionPlan for TopKExec { } Ok(Box::pin(TopKReader { - input: self.input.execute(partition, runtime).await?, + input: self.input.execute(partition, context).await?, k: self.k, done: false, state: BTreeMap::new(), @@ -502,6 +501,10 @@ impl ExecutionPlan for TopKExec { // better statistics inference could be provided Statistics::default() } + + fn session_id(&self) -> String { + self.session_id.clone() + } } // A very specialized TopK implementation diff 
--git a/docs/source/user-guide/distributed/clients/rust.md b/docs/source/user-guide/distributed/clients/rust.md index ccf19aa70e3c..78e026e63e07 100644 --- a/docs/source/user-guide/distributed/clients/rust.md +++ b/docs/source/user-guide/distributed/clients/rust.md @@ -30,7 +30,7 @@ let config = BallistaConfig::builder() .build()?; // connect to Ballista scheduler -let ctx = BallistaContext::remote("localhost", 50050, &config); +let ctx = BallistaContext::remote("localhost", 50050, &config).await?; ``` Here is a full example using the DataFrame API. @@ -43,7 +43,7 @@ async fn main() -> Result<()> { .build()?; // connect to Ballista scheduler - let ctx = BallistaContext::remote("localhost", 50050, &config); + let ctx = BallistaContext::remote("localhost", 50050, &config).await?; let testdata = datafusion::arrow::util::test_util::parquet_test_data(); @@ -72,7 +72,7 @@ async fn main() -> Result<()> { .build()?; // connect to Ballista scheduler - let ctx = BallistaContext::remote("localhost", 50050, &config); + let ctx = BallistaContext::remote("localhost", 50050, &config).await?; let testdata = datafusion::arrow::util::test_util::arrow_test_data(); diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 77930260e038..6be802400a77 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -36,7 +36,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> datafusion::error::Result<()> { // register the table - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?; // create a plan to run a SQL query @@ -56,7 +56,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> datafusion::error::Result<()> { // create the dataframe - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; let df = df.filter(col("a").lt_eq(col("b")))?