diff --git a/datafusion/core/tests/sql/cast.rs b/datafusion/core/tests/sql/cast.rs index 61bac0eb2c22..36213673fb6b 100644 --- a/datafusion/core/tests/sql/cast.rs +++ b/datafusion/core/tests/sql/cast.rs @@ -26,6 +26,13 @@ async fn execute_sql(sql: &str) -> Vec { execute_to_batches(&ctx, sql).await } +#[tokio::test] +async fn cast_from_subquery() -> Result<()> { + let actual = execute_sql("SELECT cast(c as varchar) FROM (SELECT 1 as c)").await; + assert_eq!(&DataType::Utf8, actual[0].schema().field(0).data_type()); + Ok(()) +} + #[tokio::test] async fn cast_tinyint() -> Result<()> { let actual = execute_sql("SELECT cast(10 as tinyint)").await; diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 47939b73348d..2be9a6465df6 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -25,7 +25,7 @@ use crate::logical_plan::{ Limit, Partitioning, Projection, Repartition, Sort, Subquery, SubqueryAlias, Union, Values, Window, }; -use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; +use crate::{Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -676,6 +676,10 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { Expr::Alias(inner_expr, name) => { Expr::Alias(Box::new(columnize_expr(*inner_expr, input_schema)), name) } + Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast { + expr: Box::new(columnize_expr(*expr, input_schema)), + data_type, + }), Expr::ScalarSubquery(_) => e.clone(), _ => match e.display_name() { Ok(name) => match input_schema.field_with_unqualified_name(&name) { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index b51ce83d5d01..1ecd110c0b38 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -2974,6 +2974,17 @@ mod tests { ); } + #[test] + fn cast_from_subquery() { + quick_test( + "SELECT CAST (a AS FLOAT) FROM (SELECT 1 AS a)", + "Projection: CAST(a AS Float32)\ + \n Projection: a\ + \n Projection: Int64(1) AS a\ + \n EmptyRelation", + ); + } + #[test] fn cast_to_invalid_decimal_type() { // precision == 0