diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 63bf72cf226c..300c75137d8a 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -392,20 +392,59 @@ impl PartialOrd for Expr { } } +/// Provides schema information needed by [Expr] methods such as +/// [Expr::nullable] and [Expr::data_type]. +/// +/// Note that this trait is implemented for &[DFSchema] which is +/// widely used in the DataFusion codebase. +pub trait ExprSchema { + /// Is this column reference nullable? + fn nullable(&self, col: &Column) -> Result; + + /// What is the datatype of this column? + fn data_type(&self, col: &Column) -> Result<&DataType>; +} + +// Implement `ExprSchema` for `Arc` +impl> ExprSchema for P { + fn nullable(&self, col: &Column) -> Result { + self.as_ref().nullable(col) + } + + fn data_type(&self, col: &Column) -> Result<&DataType> { + self.as_ref().data_type(col) + } +} + +impl ExprSchema for DFSchema { + fn nullable(&self, col: &Column) -> Result { + Ok(self.field_from_column(col)?.is_nullable()) + } + + fn data_type(&self, col: &Column) -> Result<&DataType> { + Ok(self.field_from_column(col)?.data_type()) + } +} + impl Expr { - /// Returns the [arrow::datatypes::DataType] of the expression based on [arrow::datatypes::Schema]. + /// Returns the [arrow::datatypes::DataType] of the expression + /// based on [ExprSchema] + /// + /// Note: [DFSchema] implements [ExprSchema]. /// /// # Errors /// - /// This function errors when it is not possible to compute its [arrow::datatypes::DataType]. - /// This happens when e.g. the expression refers to a column that does not exist in the schema, or when - /// the expression is incorrectly typed (e.g. `[utf8] + [bool]`). - pub fn get_type(&self, schema: &DFSchema) -> Result { + /// This function errors when it is not possible to compute its + /// [arrow::datatypes::DataType]. This happens when e.g. the + /// expression refers to a column that does not exist in the + /// schema, or when the expression is incorrectly typed + /// (e.g. `[utf8] + [bool]`). + pub fn get_type(&self, schema: &S) -> Result { match self { Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => { expr.get_type(schema) } - Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()), + Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::ScalarVariable(_) => Ok(DataType::Utf8), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), @@ -472,13 +511,16 @@ impl Expr { } } - /// Returns the nullability of the expression based on [arrow::datatypes::Schema]. + /// Returns the nullability of the expression based on [ExprSchema]. + /// + /// Note: [DFSchema] implements [ExprSchema]. /// /// # Errors /// - /// This function errors when it is not possible to compute its nullability. - /// This happens when the expression refers to a column that does not exist in the schema. - pub fn nullable(&self, input_schema: &DFSchema) -> Result { + /// This function errors when it is not possible to compute its + /// nullability. This happens when the expression refers to a + /// column that does not exist in the schema. + pub fn nullable(&self, input_schema: &S) -> Result { match self { Expr::Alias(expr, _) | Expr::Not(expr) @@ -486,7 +528,7 @@ impl Expr { | Expr::Sort { expr, .. } | Expr::Between { expr, .. } | Expr::InList { expr, .. } => expr.nullable(input_schema), - Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()), + Expr::Column(c) => input_schema.nullable(c), Expr::Literal(value) => Ok(value.is_null()), Expr::Case { when_then_expr, @@ -561,7 +603,11 @@ impl Expr { /// /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. - pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result { + pub fn cast_to( + self, + cast_to_type: &DataType, + schema: &S, + ) -> Result { // TODO(kszucs): most of the operations do not validate the type correctness // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? @@ -2557,4 +2603,57 @@ mod tests { combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]); assert_eq!(result, Some(and(and(filter1, filter2), filter3))); } + + #[test] + fn expr_schema_nullability() { + let expr = col("foo").eq(lit(1)); + assert!(!expr.nullable(&MockExprSchema::new()).unwrap()); + assert!(expr + .nullable(&MockExprSchema::new().with_nullable(true)) + .unwrap()); + } + + #[test] + fn expr_schema_data_type() { + let expr = col("foo"); + assert_eq!( + DataType::Utf8, + expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8)) + .unwrap() + ); + } + + struct MockExprSchema { + nullable: bool, + data_type: DataType, + } + + impl MockExprSchema { + fn new() -> Self { + Self { + nullable: false, + data_type: DataType::Null, + } + } + + fn with_nullable(mut self, nullable: bool) -> Self { + self.nullable = nullable; + self + } + + fn with_data_type(mut self, data_type: DataType) -> Self { + self.data_type = data_type; + self + } + } + + impl ExprSchema for MockExprSchema { + fn nullable(&self, _col: &Column) -> Result { + Ok(self.nullable) + } + + fn data_type(&self, _col: &Column) -> Result<&DataType> { + Ok(&self.data_type) + } + } } diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 25714514d78a..22521a1bd1fb 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -46,8 +46,8 @@ pub use expr::{ rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, - Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion, - SimplifyInfo, + Column, Expr, ExprRewriter, ExprSchema, ExpressionVisitor, Literal, Recursion, + RewriteRecursion, SimplifyInfo, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator;