From 1f0e514643c6fc484a2e03a511d9321644296ed2 Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Wed, 28 Dec 2022 13:11:05 +0800 Subject: [PATCH 1/3] consider union schema Signed-off-by: remzi <13716567376yh@gmail.com> --- datafusion/core/src/physical_plan/planner.rs | 11 +++++-- datafusion/core/src/physical_plan/union.rs | 33 ++++++++++++++++++++ datafusion/core/tests/sql/union.rs | 17 ++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 768c42978936..7158277f4e35 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -773,12 +773,19 @@ impl DefaultPhysicalPlanner { )?; Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) } - LogicalPlan::Union(Union { inputs, .. }) => { + LogicalPlan::Union(Union { inputs, schema }) => { let physical_plans = futures::stream::iter(inputs) .then(|lp| self.create_initial_plan(lp, session_state)) .try_collect::>() .await?; - Ok(Arc::new(UnionExec::new(physical_plans))) + if schema.fields().len() < physical_plans[0].schema().fields().len() { + // `schema` could be a subset of the child schema. For example + // for query "select count(*) from (select a from t union all select a from t)" + // `schema` is empty but child schema contains one field `a`. + Ok(Arc::new(UnionExec::try_new_with_schema(physical_plans, schema.clone())?)) + } else { + Ok(Arc::new(UnionExec::new(physical_plans))) + } } LogicalPlan::Repartition(Repartition { input, diff --git a/datafusion/core/src/physical_plan/union.rs b/datafusion/core/src/physical_plan/union.rs index af57c9ef9cc2..8d17b14bdf1c 100644 --- a/datafusion/core/src/physical_plan/union.rs +++ b/datafusion/core/src/physical_plan/union.rs @@ -30,6 +30,7 @@ use arrow::{ datatypes::{Field, Schema, SchemaRef}, record_batch::RecordBatch, }; +use datafusion_common::{DFSchemaRef, DataFusionError}; use futures::{Stream, StreamExt}; use itertools::Itertools; use log::debug; @@ -63,6 +64,38 @@ pub struct UnionExec { } impl UnionExec { + /// Create a new UnionExec with specified schema. + /// The `schema` should always be a subset of the schema of `inputs`, + /// otherwise, an error will be returned. + pub fn try_new_with_schema( + inputs: Vec>, + schema: DFSchemaRef, + ) -> Result { + let mut exec = Self::new(inputs); + let exec_schema = exec.schema(); + let fields = schema + .fields() + .iter() + .map(|dff| { + exec_schema + .field_with_name(dff.name()) + .cloned() + .map_err(|_| { + DataFusionError::Internal(format!( + "Cannot find the field {:?} in child schema", + dff.name() + )) + }) + }) + .collect::>>()?; + let schema = Arc::new(Schema::new_with_metadata( + fields, + exec.schema().metadata().clone(), + )); + exec.schema = schema; + Ok(exec) + } + /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { let fields: Vec = (0..inputs[0].schema().fields().len()) diff --git a/datafusion/core/tests/sql/union.rs b/datafusion/core/tests/sql/union.rs index 29856a37b1a9..7547a01d6afe 100644 --- a/datafusion/core/tests/sql/union.rs +++ b/datafusion/core/tests/sql/union.rs @@ -80,6 +80,23 @@ async fn union_all_with_aggregate() -> Result<()> { Ok(()) } +#[tokio::test] +async fn union_all_with_count() -> Result<()> { + let ctx = SessionContext::new(); + let sql = + "SELECT COUNT(*) FROM (SELECT 1 as c, 2 as d UNION ALL SELECT 1 as c, 3 AS d)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 2 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn union_schemas() -> Result<()> { let ctx = From 3b4122c0dc284d8d40c243bf8634e00c06ef84f5 Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Wed, 28 Dec 2022 13:51:24 +0800 Subject: [PATCH 2/3] update test Signed-off-by: remzi <13716567376yh@gmail.com> --- datafusion/core/tests/sql/union.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/union.rs b/datafusion/core/tests/sql/union.rs index 7547a01d6afe..7e977de9bd74 100644 --- a/datafusion/core/tests/sql/union.rs +++ b/datafusion/core/tests/sql/union.rs @@ -83,8 +83,9 @@ async fn union_all_with_aggregate() -> Result<()> { #[tokio::test] async fn union_all_with_count() -> Result<()> { let ctx = SessionContext::new(); + execute_to_batches(&ctx, "CREATE table t as SELECT 1 as a").await; let sql = - "SELECT COUNT(*) FROM (SELECT 1 as c, 2 as d UNION ALL SELECT 1 as c, 3 AS d)"; + "SELECT COUNT(*) FROM (SELECT a from t UNION ALL SELECT a from t)"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", From 445d0627ca39fb818301b125a4d39ac1b0f7af20 Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Wed, 28 Dec 2022 14:05:25 +0800 Subject: [PATCH 3/3] fmt Signed-off-by: remzi <13716567376yh@gmail.com> --- datafusion/core/tests/sql/union.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/core/tests/sql/union.rs b/datafusion/core/tests/sql/union.rs index 7e977de9bd74..ac0e39f4d479 100644 --- a/datafusion/core/tests/sql/union.rs +++ b/datafusion/core/tests/sql/union.rs @@ -84,8 +84,7 @@ async fn union_all_with_aggregate() -> Result<()> { async fn union_all_with_count() -> Result<()> { let ctx = SessionContext::new(); execute_to_batches(&ctx, "CREATE table t as SELECT 1 as a").await; - let sql = - "SELECT COUNT(*) FROM (SELECT a from t UNION ALL SELECT a from t)"; + let sql = "SELECT COUNT(*) FROM (SELECT a from t UNION ALL SELECT a from t)"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+",