From 014e5e90d623befd9f3e179b02864a6b8bcab568 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Wed, 9 Feb 2022 12:50:38 +0800 Subject: [PATCH] move accumulator and columnar value (#1762) --- datafusion-expr/Cargo.toml | 1 + datafusion-expr/src/expr.rs | 698 ++++++++++++++++++ datafusion-expr/src/expr_fn.rs | 32 + datafusion-expr/src/function.rs | 46 ++ datafusion-expr/src/lib.rs | 15 + datafusion-expr/src/literal.rs | 138 ++++ datafusion-expr/src/operator.rs | 43 ++ datafusion-expr/src/udaf.rs | 92 +++ datafusion-expr/src/udf.rs | 93 +++ datafusion/src/execution/dataframe_impl.rs | 4 +- datafusion/src/logical_plan/expr.rs | 811 +-------------------- datafusion/src/logical_plan/mod.rs | 3 +- datafusion/src/logical_plan/operators.rs | 42 -- datafusion/src/physical_plan/aggregates.rs | 11 +- datafusion/src/physical_plan/udaf.rs | 83 +-- datafusion/src/physical_plan/udf.rs | 85 +-- datafusion/src/sql/planner.rs | 3 +- 17 files changed, 1187 insertions(+), 1013 deletions(-) create mode 100644 datafusion-expr/src/expr.rs create mode 100644 datafusion-expr/src/expr_fn.rs create mode 100644 datafusion-expr/src/function.rs create mode 100644 datafusion-expr/src/literal.rs create mode 100644 datafusion-expr/src/udaf.rs create mode 100644 datafusion-expr/src/udf.rs diff --git a/datafusion-expr/Cargo.toml b/datafusion-expr/Cargo.toml index 73a5fcd36152..a6dad528b6b7 100644 --- a/datafusion-expr/Cargo.toml +++ b/datafusion-expr/Cargo.toml @@ -38,3 +38,4 @@ path = "src/lib.rs" datafusion-common = { path = "../datafusion-common", version = "6.0.0" } arrow = { version = "8.0.0", features = ["prettyprint"] } sqlparser = "0.13" +ahash = { version = "0.7", default-features = false } diff --git a/datafusion-expr/src/expr.rs b/datafusion-expr/src/expr.rs new file mode 100644 index 000000000000..f26f1dfa9746 --- /dev/null +++ b/datafusion-expr/src/expr.rs @@ -0,0 +1,698 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregate_function; +use crate::built_in_function; +use crate::expr_fn::binary_expr; +use crate::window_frame; +use crate::window_function; +use crate::AggregateUDF; +use crate::Operator; +use crate::ScalarUDF; +use arrow::datatypes::DataType; +use datafusion_common::Column; +use datafusion_common::{DFSchema, Result}; +use datafusion_common::{DataFusionError, ScalarValue}; +use std::fmt; +use std::hash::{BuildHasher, Hash, Hasher}; +use std::ops::Not; +use std::sync::Arc; + +/// `Expr` is a central struct of DataFusion's query API, and +/// represent logical expressions such as `A + 1`, or `CAST(c1 AS +/// int)`. +/// +/// An `Expr` can compute its [DataType](arrow::datatypes::DataType) +/// and nullability, and has functions for building up complex +/// expressions. +/// +/// # Examples +/// +/// ## Create an expression `c1` referring to column named "c1" +/// ``` +/// # use datafusion_common::Column; +/// # use datafusion_expr::{lit, col, Expr}; +/// let expr = col("c1"); +/// assert_eq!(expr, Expr::Column(Column::from_name("c1"))); +/// ``` +/// +/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together +/// ``` +/// # use datafusion_expr::{lit, col, Operator, Expr}; +/// let expr = col("c1") + col("c2"); +/// +/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); +/// if let Expr::BinaryExpr { left, right, op } = expr { +/// assert_eq!(*left, col("c1")); +/// assert_eq!(*right, col("c2")); +/// assert_eq!(op, Operator::Plus); +/// } +/// ``` +/// +/// ## Create expression `c1 = 42` to compare the value in column "c1" to the literal value `42` +/// ``` +/// # use datafusion_common::ScalarValue; +/// # use datafusion_expr::{lit, col, Operator, Expr}; +/// let expr = col("c1").eq(lit(42_i32)); +/// +/// assert!(matches!(expr, Expr::BinaryExpr { .. } )); +/// if let Expr::BinaryExpr { left, right, op } = expr { +/// assert_eq!(*left, col("c1")); +/// let scalar = ScalarValue::Int32(Some(42)); +/// assert_eq!(*right, Expr::Literal(scalar)); +/// assert_eq!(op, Operator::Eq); +/// } +/// ``` +#[derive(Clone, PartialEq, Hash)] +pub enum Expr { + /// An expression with a specific name. + Alias(Box, String), + /// A named reference to a qualified filed in a schema. + Column(Column), + /// A named reference to a variable in a registry. + ScalarVariable(Vec), + /// A constant value. + Literal(ScalarValue), + /// A binary expression such as "age > 21" + BinaryExpr { + /// Left-hand side of the expression + left: Box, + /// The comparison operator + op: Operator, + /// Right-hand side of the expression + right: Box, + }, + /// Negation of an expression. The expression's type must be a boolean to make sense. + Not(Box), + /// Whether an expression is not Null. This expression is never null. + IsNotNull(Box), + /// Whether an expression is Null. This expression is never null. + IsNull(Box), + /// arithmetic negation of an expression, the operand must be of a signed numeric data type + Negative(Box), + /// Returns the field of a [`ListArray`] or [`StructArray`] by key + GetIndexedField { + /// the expression to take the field from + expr: Box, + /// The name of the field to take + key: ScalarValue, + }, + /// Whether an expression is between a given range. + Between { + /// The value to compare + expr: Box, + /// Whether the expression is negated + negated: bool, + /// The low end of the range + low: Box, + /// The high end of the range + high: Box, + }, + /// The CASE expression is similar to a series of nested if/else and there are two forms that + /// can be used. The first form consists of a series of boolean "when" expressions with + /// corresponding "then" expressions, and an optional "else" expression. + /// + /// CASE WHEN condition THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + /// + /// The second form uses a base expression and then a series of "when" clauses that match on a + /// literal value. + /// + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + Case { + /// Optional base expression that can be compared to literal values in the "when" expressions + expr: Option>, + /// One or more when/then expressions + when_then_expr: Vec<(Box, Box)>, + /// Optional "else" expression + else_expr: Option>, + }, + /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. + /// This expression is guaranteed to have a fixed type. + Cast { + /// The expression being cast + expr: Box, + /// The `DataType` the expression will yield + data_type: DataType, + }, + /// Casts the expression to a given type and will return a null value if the expression cannot be cast. + /// This expression is guaranteed to have a fixed type. + TryCast { + /// The expression being cast + expr: Box, + /// The `DataType` the expression will yield + data_type: DataType, + }, + /// A sort expression, that can be used to sort values. + Sort { + /// The expression to sort on + expr: Box, + /// The direction of the sort + asc: bool, + /// Whether to put Nulls before all other data values + nulls_first: bool, + }, + /// Represents the call of a built-in scalar function with a set of arguments. + ScalarFunction { + /// The function + fun: built_in_function::BuiltinScalarFunction, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, + /// Represents the call of a user-defined scalar function with arguments. + ScalarUDF { + /// The function + fun: Arc, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, + /// Represents the call of an aggregate built-in function with arguments. + AggregateFunction { + /// Name of the function + fun: aggregate_function::AggregateFunction, + /// List of expressions to feed to the functions as arguments + args: Vec, + /// Whether this is a DISTINCT aggregation or not + distinct: bool, + }, + /// Represents the call of a window function with arguments. + WindowFunction { + /// Name of the function + fun: window_function::WindowFunction, + /// List of expressions to feed to the functions as arguments + args: Vec, + /// List of partition by expressions + partition_by: Vec, + /// List of order by expressions + order_by: Vec, + /// Window frame + window_frame: Option, + }, + /// aggregate function + AggregateUDF { + /// The function + fun: Arc, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, + /// Returns whether the list contains the expr value. + InList { + /// The expression to compare + expr: Box, + /// A list of values to compare against + list: Vec, + /// Whether the expression is negated + negated: bool, + }, + /// Represents a reference to all fields in a schema. + Wildcard, +} + +/// Fixed seed for the hashing so that Ords are consistent across runs +const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); + +impl PartialOrd for Expr { + fn partial_cmp(&self, other: &Self) -> Option { + let mut hasher = SEED.build_hasher(); + self.hash(&mut hasher); + let s = hasher.finish(); + + let mut hasher = SEED.build_hasher(); + other.hash(&mut hasher); + let o = hasher.finish(); + + Some(s.cmp(&o)) + } +} + +impl Expr { + /// Returns the name of this expression based on [crate::logical_plan::DFSchema]. + /// + /// This represents how a column with this expression is named when no alias is chosen + pub fn name(&self, input_schema: &DFSchema) -> Result { + create_name(self, input_schema) + } + + /// Return `self == other` + pub fn eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::Eq, other) + } + + /// Return `self != other` + pub fn not_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::NotEq, other) + } + + /// Return `self > other` + pub fn gt(self, other: Expr) -> Expr { + binary_expr(self, Operator::Gt, other) + } + + /// Return `self >= other` + pub fn gt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::GtEq, other) + } + + /// Return `self < other` + pub fn lt(self, other: Expr) -> Expr { + binary_expr(self, Operator::Lt, other) + } + + /// Return `self <= other` + pub fn lt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::LtEq, other) + } + + /// Return `self && other` + pub fn and(self, other: Expr) -> Expr { + binary_expr(self, Operator::And, other) + } + + /// Return `self || other` + pub fn or(self, other: Expr) -> Expr { + binary_expr(self, Operator::Or, other) + } + + /// Return `!self` + #[allow(clippy::should_implement_trait)] + pub fn not(self) -> Expr { + !self + } + + /// Calculate the modulus of two expressions. + /// Return `self % other` + pub fn modulus(self, other: Expr) -> Expr { + binary_expr(self, Operator::Modulo, other) + } + + /// Return `self LIKE other` + pub fn like(self, other: Expr) -> Expr { + binary_expr(self, Operator::Like, other) + } + + /// Return `self NOT LIKE other` + pub fn not_like(self, other: Expr) -> Expr { + binary_expr(self, Operator::NotLike, other) + } + + /// Return `self AS name` alias expression + pub fn alias(self, name: &str) -> Expr { + Expr::Alias(Box::new(self), name.to_owned()) + } + + /// Return `self IN ` if `negated` is false, otherwise + /// return `self NOT IN `.a + pub fn in_list(self, list: Vec, negated: bool) -> Expr { + Expr::InList { + expr: Box::new(self), + list, + negated, + } + } + + /// Return `IsNull(Box(self)) + #[allow(clippy::wrong_self_convention)] + pub fn is_null(self) -> Expr { + Expr::IsNull(Box::new(self)) + } + + /// Return `IsNotNull(Box(self)) + #[allow(clippy::wrong_self_convention)] + pub fn is_not_null(self) -> Expr { + Expr::IsNotNull(Box::new(self)) + } + + /// Create a sort expression from an existing expression. + /// + /// ``` + /// # use datafusion_expr::col; + /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST + /// ``` + pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { + Expr::Sort { + expr: Box::new(self), + asc, + nulls_first, + } + } +} + +impl Not for Expr { + type Output = Self; + + fn not(self) -> Self::Output { + Expr::Not(Box::new(self)) + } +} + +impl std::fmt::Display for Expr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Expr::BinaryExpr { + ref left, + ref right, + ref op, + } => write!(f, "{} {} {}", left, op, right), + Expr::AggregateFunction { + /// Name of the function + ref fun, + /// List of expressions to feed to the functions as arguments + ref args, + /// Whether this is a DISTINCT aggregation or not + ref distinct, + } => fmt_function(f, &fun.to_string(), *distinct, args, true), + Expr::ScalarFunction { + /// Name of the function + ref fun, + /// List of expressions to feed to the functions as arguments + ref args, + } => fmt_function(f, &fun.to_string(), false, args, true), + _ => write!(f, "{:?}", self), + } + } +} + +impl fmt::Debug for Expr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias), + Expr::Column(c) => write!(f, "{}", c), + Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")), + Expr::Literal(v) => write!(f, "{:?}", v), + Expr::Case { + expr, + when_then_expr, + else_expr, + .. + } => { + write!(f, "CASE ")?; + if let Some(e) = expr { + write!(f, "{:?} ", e)?; + } + for (w, t) in when_then_expr { + write!(f, "WHEN {:?} THEN {:?} ", w, t)?; + } + if let Some(e) = else_expr { + write!(f, "ELSE {:?} ", e)?; + } + write!(f, "END") + } + Expr::Cast { expr, data_type } => { + write!(f, "CAST({:?} AS {:?})", expr, data_type) + } + Expr::TryCast { expr, data_type } => { + write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type) + } + Expr::Not(expr) => write!(f, "NOT {:?}", expr), + Expr::Negative(expr) => write!(f, "(- {:?})", expr), + Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr), + Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr), + Expr::BinaryExpr { left, op, right } => { + write!(f, "{:?} {} {:?}", left, op, right) + } + Expr::Sort { + expr, + asc, + nulls_first, + } => { + if *asc { + write!(f, "{:?} ASC", expr)?; + } else { + write!(f, "{:?} DESC", expr)?; + } + if *nulls_first { + write!(f, " NULLS FIRST") + } else { + write!(f, " NULLS LAST") + } + } + Expr::ScalarFunction { fun, args, .. } => { + fmt_function(f, &fun.to_string(), false, args, false) + } + Expr::ScalarUDF { fun, ref args, .. } => { + fmt_function(f, &fun.name, false, args, false) + } + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { + fmt_function(f, &fun.to_string(), false, args, false)?; + if !partition_by.is_empty() { + write!(f, " PARTITION BY {:?}", partition_by)?; + } + if !order_by.is_empty() { + write!(f, " ORDER BY {:?}", order_by)?; + } + if let Some(window_frame) = window_frame { + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + )?; + } + Ok(()) + } + Expr::AggregateFunction { + fun, + distinct, + ref args, + .. + } => fmt_function(f, &fun.to_string(), *distinct, args, true), + Expr::AggregateUDF { fun, ref args, .. } => { + fmt_function(f, &fun.name, false, args, false) + } + Expr::Between { + expr, + negated, + low, + high, + } => { + if *negated { + write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high) + } else { + write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high) + } + } + Expr::InList { + expr, + list, + negated, + } => { + if *negated { + write!(f, "{:?} NOT IN ({:?})", expr, list) + } else { + write!(f, "{:?} IN ({:?})", expr, list) + } + } + Expr::Wildcard => write!(f, "*"), + Expr::GetIndexedField { ref expr, key } => { + write!(f, "({:?})[{}]", expr, key) + } + } + } +} + +fn fmt_function( + f: &mut fmt::Formatter, + fun: &str, + distinct: bool, + args: &[Expr], + display: bool, +) -> fmt::Result { + let args: Vec = match display { + true => args.iter().map(|arg| format!("{}", arg)).collect(), + false => args.iter().map(|arg| format!("{:?}", arg)).collect(), + }; + + // let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) +} + +fn create_function_name( + fun: &str, + distinct: bool, + args: &[Expr], + input_schema: &DFSchema, +) -> Result { + let names: Vec = args + .iter() + .map(|e| create_name(e, input_schema)) + .collect::>()?; + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) +} + +/// Returns a readable name of an expression based on the input schema. +/// This function recursively transverses the expression for names such as "CAST(a > 2)". +fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { + match e { + Expr::Alias(_, name) => Ok(name.clone()), + Expr::Column(c) => Ok(c.flat_name()), + Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")), + Expr::Literal(value) => Ok(format!("{:?}", value)), + Expr::BinaryExpr { left, op, right } => { + let left = create_name(left, input_schema)?; + let right = create_name(right, input_schema)?; + Ok(format!("{} {} {}", left, op, right)) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let mut name = "CASE ".to_string(); + if let Some(e) = expr { + let e = create_name(e, input_schema)?; + name += &format!("{} ", e); + } + for (w, t) in when_then_expr { + let when = create_name(w, input_schema)?; + let then = create_name(t, input_schema)?; + name += &format!("WHEN {} THEN {} ", when, then); + } + if let Some(e) = else_expr { + let e = create_name(e, input_schema)?; + name += &format!("ELSE {} ", e); + } + name += "END"; + Ok(name) + } + Expr::Cast { expr, data_type } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("CAST({} AS {:?})", expr, data_type)) + } + Expr::TryCast { expr, data_type } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) + } + Expr::Not(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("NOT {}", expr)) + } + Expr::Negative(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("(- {})", expr)) + } + Expr::IsNull(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{} IS NULL", expr)) + } + Expr::IsNotNull(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{} IS NOT NULL", expr)) + } + Expr::GetIndexedField { expr, key } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{}[{}]", expr, key)) + } + Expr::ScalarFunction { fun, args, .. } => { + create_function_name(&fun.to_string(), false, args, input_schema) + } + Expr::ScalarUDF { fun, args, .. } => { + create_function_name(&fun.name, false, args, input_schema) + } + Expr::WindowFunction { + fun, + args, + window_frame, + partition_by, + order_by, + } => { + let mut parts: Vec = vec![create_function_name( + &fun.to_string(), + false, + args, + input_schema, + )?]; + if !partition_by.is_empty() { + parts.push(format!("PARTITION BY {:?}", partition_by)); + } + if !order_by.is_empty() { + parts.push(format!("ORDER BY {:?}", order_by)); + } + if let Some(window_frame) = window_frame { + parts.push(format!("{}", window_frame)); + } + Ok(parts.join(" ")) + } + Expr::AggregateFunction { + fun, + distinct, + args, + .. + } => create_function_name(&fun.to_string(), *distinct, args, input_schema), + Expr::AggregateUDF { fun, args } => { + let mut names = Vec::with_capacity(args.len()); + for e in args { + names.push(create_name(e, input_schema)?); + } + Ok(format!("{}({})", fun.name, names.join(","))) + } + Expr::InList { + expr, + list, + negated, + } => { + let expr = create_name(expr, input_schema)?; + let list = list.iter().map(|expr| create_name(expr, input_schema)); + if *negated { + Ok(format!("{} NOT IN ({:?})", expr, list)) + } else { + Ok(format!("{} IN ({:?})", expr, list)) + } + } + Expr::Between { + expr, + negated, + low, + high, + } => { + let expr = create_name(expr, input_schema)?; + let low = create_name(low, input_schema)?; + let high = create_name(high, input_schema)?; + if *negated { + Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high)) + } else { + Ok(format!("{} BETWEEN {} AND {}", expr, low, high)) + } + } + Expr::Sort { .. } => Err(DataFusionError::Internal( + "Create name does not support sort expression".to_string(), + )), + Expr::Wildcard => Err(DataFusionError::Internal( + "Create name does not support wildcard".to_string(), + )), + } +} diff --git a/datafusion-expr/src/expr_fn.rs b/datafusion-expr/src/expr_fn.rs new file mode 100644 index 000000000000..469a82d0ff24 --- /dev/null +++ b/datafusion-expr/src/expr_fn.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{Expr, Operator}; + +/// Create a column expression based on a qualified or unqualified column name +pub fn col(ident: &str) -> Expr { + Expr::Column(ident.into()) +} + +/// return a new expression l r +pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr { + Expr::BinaryExpr { + left: Box::new(l), + op, + right: Box::new(r), + } +} diff --git a/datafusion-expr/src/function.rs b/datafusion-expr/src/function.rs new file mode 100644 index 000000000000..2bacd6ae6227 --- /dev/null +++ b/datafusion-expr/src/function.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Accumulator; +use crate::ColumnarValue; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use std::sync::Arc; + +/// Scalar function +/// +/// The Fn param is the wrapped function but be aware that the function will +/// be passed with the slice / vec of columnar values (either scalar or array) +/// with the exception of zero param function, where a singular element vec +/// will be passed. In that case the single element is a null array to indicate +/// the batch's row count (so that the generative zero-argument function can know +/// the result array size). +pub type ScalarFunctionImplementation = + Arc Result + Send + Sync>; + +/// A function's return type +pub type ReturnTypeFunction = + Arc Result> + Send + Sync>; + +/// the implementation of an aggregate function +pub type AccumulatorFunctionImplementation = + Arc Result> + Send + Sync>; + +/// This signature corresponds to which types an aggregator serializes +/// its state, given its return datatype. +pub type StateTypeFunction = + Arc Result>> + Send + Sync>; diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs index 2491fcf73ca9..709fa634d52d 100644 --- a/datafusion-expr/src/lib.rs +++ b/datafusion-expr/src/lib.rs @@ -19,8 +19,14 @@ mod accumulator; mod aggregate_function; mod built_in_function; mod columnar_value; +pub mod expr; +pub mod expr_fn; +mod function; +mod literal; mod operator; mod signature; +mod udaf; +mod udf; mod window_frame; mod window_function; @@ -28,7 +34,16 @@ pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::{ColumnarValue, NullColumnarValue}; +pub use expr::Expr; +pub use expr_fn::col; +pub use function::{ + AccumulatorFunctionImplementation, ReturnTypeFunction, ScalarFunctionImplementation, + StateTypeFunction, +}; +pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use operator::Operator; pub use signature::{Signature, TypeSignature, Volatility}; +pub use udaf::AggregateUDF; +pub use udf::ScalarUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion-expr/src/literal.rs b/datafusion-expr/src/literal.rs new file mode 100644 index 000000000000..02c75af69573 --- /dev/null +++ b/datafusion-expr/src/literal.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Expr; +use datafusion_common::ScalarValue; + +/// Create a literal expression +pub fn lit(n: T) -> Expr { + n.lit() +} + +/// Create a literal timestamp expression +pub fn lit_timestamp_nano(n: T) -> Expr { + n.lit_timestamp_nano() +} + +/// Trait for converting a type to a [`Literal`] literal expression. +pub trait Literal { + /// convert the value to a Literal expression + fn lit(&self) -> Expr; +} + +/// Trait for converting a type to a literal timestamp +pub trait TimestampLiteral { + fn lit_timestamp_nano(&self) -> Expr; +} + +impl Literal for &str { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + } +} + +impl Literal for String { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + } +} + +impl Literal for Vec { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + } +} + +impl Literal for &[u8] { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + } +} + +impl Literal for ScalarValue { + fn lit(&self) -> Expr { + Expr::Literal(self.clone()) + } +} + +macro_rules! make_literal { + ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { + #[doc = $DOC] + impl Literal for $TYPE { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) + } + } + }; +} + +macro_rules! make_timestamp_literal { + ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { + #[doc = $DOC] + impl TimestampLiteral for $TYPE { + fn lit_timestamp_nano(&self) -> Expr { + Expr::Literal(ScalarValue::TimestampNanosecond( + Some((self.clone()).into()), + None, + )) + } + } + }; +} + +make_literal!(bool, Boolean, "literal expression containing a bool"); +make_literal!(f32, Float32, "literal expression containing an f32"); +make_literal!(f64, Float64, "literal expression containing an f64"); +make_literal!(i8, Int8, "literal expression containing an i8"); +make_literal!(i16, Int16, "literal expression containing an i16"); +make_literal!(i32, Int32, "literal expression containing an i32"); +make_literal!(i64, Int64, "literal expression containing an i64"); +make_literal!(u8, UInt8, "literal expression containing a u8"); +make_literal!(u16, UInt16, "literal expression containing a u16"); +make_literal!(u32, UInt32, "literal expression containing a u32"); +make_literal!(u64, UInt64, "literal expression containing a u64"); + +make_timestamp_literal!(i8, Int8, "literal expression containing an i8"); +make_timestamp_literal!(i16, Int16, "literal expression containing an i16"); +make_timestamp_literal!(i32, Int32, "literal expression containing an i32"); +make_timestamp_literal!(i64, Int64, "literal expression containing an i64"); +make_timestamp_literal!(u8, UInt8, "literal expression containing a u8"); +make_timestamp_literal!(u16, UInt16, "literal expression containing a u16"); +make_timestamp_literal!(u32, UInt32, "literal expression containing a u32"); + +#[cfg(test)] +mod test { + use super::*; + use crate::expr_fn::col; + use datafusion_common::ScalarValue; + + #[test] + fn test_lit_timestamp_nano() { + let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32 + let expected = + col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10), None))); + assert_eq!(expr, expected); + + let i: i64 = 10; + let expr = col("time").eq(lit_timestamp_nano(i)); + assert_eq!(expr, expected); + + let i: u32 = 10; + let expr = col("time").eq(lit_timestamp_nano(i)); + assert_eq!(expr, expected); + } +} diff --git a/datafusion-expr/src/operator.rs b/datafusion-expr/src/operator.rs index e6b7e35a0a5e..a1cad76cdd97 100644 --- a/datafusion-expr/src/operator.rs +++ b/datafusion-expr/src/operator.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. +use crate::expr_fn::binary_expr; +use crate::Expr; use std::fmt; +use std::ops; /// Operators applied to expressions #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -95,3 +98,43 @@ impl fmt::Display for Operator { write!(f, "{}", display) } } + +impl ops::Add for Expr { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + binary_expr(self, Operator::Plus, rhs) + } +} + +impl ops::Sub for Expr { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + binary_expr(self, Operator::Minus, rhs) + } +} + +impl ops::Mul for Expr { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + binary_expr(self, Operator::Multiply, rhs) + } +} + +impl ops::Div for Expr { + type Output = Self; + + fn div(self, rhs: Self) -> Self { + binary_expr(self, Operator::Divide, rhs) + } +} + +impl ops::Rem for Expr { + type Output = Self; + + fn rem(self, rhs: Self) -> Self { + binary_expr(self, Operator::Modulo, rhs) + } +} diff --git a/datafusion-expr/src/udaf.rs b/datafusion-expr/src/udaf.rs new file mode 100644 index 000000000000..a39d58b622f3 --- /dev/null +++ b/datafusion-expr/src/udaf.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains functions and structs supporting user-defined aggregate functions. + +use crate::Expr; +use crate::{ + AccumulatorFunctionImplementation, ReturnTypeFunction, Signature, StateTypeFunction, +}; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; + +/// Logical representation of a user-defined aggregate function (UDAF) +/// A UDAF is different from a UDF in that it is stateful across batches. +#[derive(Clone)] +pub struct AggregateUDF { + /// name + pub name: String, + /// signature + pub signature: Signature, + /// Return type + pub return_type: ReturnTypeFunction, + /// actual implementation + pub accumulator: AccumulatorFunctionImplementation, + /// the accumulator's state's description as a function of the return type + pub state_type: StateTypeFunction, +} + +impl Debug for AggregateUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl PartialEq for AggregateUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl std::hash::Hash for AggregateUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl AggregateUDF { + /// Create a new AggregateUDF + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + accumulator: &AccumulatorFunctionImplementation, + state_type: &StateTypeFunction, + ) -> Self { + Self { + name: name.to_owned(), + signature: signature.clone(), + return_type: return_type.clone(), + accumulator: accumulator.clone(), + state_type: state_type.clone(), + } + } + + /// creates a logical expression with a call of the UDAF + /// This utility allows using the UDAF without requiring access to the registry. + pub fn call(&self, args: Vec) -> Expr { + Expr::AggregateUDF { + fun: Arc::new(self.clone()), + args, + } + } +} diff --git a/datafusion-expr/src/udf.rs b/datafusion-expr/src/udf.rs new file mode 100644 index 000000000000..79a17a4a2b4b --- /dev/null +++ b/datafusion-expr/src/udf.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! UDF support + +use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use std::fmt; +use std::fmt::Debug; +use std::fmt::Formatter; +use std::sync::Arc; + +/// Logical representation of a UDF. +#[derive(Clone)] +pub struct ScalarUDF { + /// name + pub name: String, + /// signature + pub signature: Signature, + /// Return type + pub return_type: ReturnTypeFunction, + /// actual implementation + /// + /// The fn param is the wrapped function but be aware that the function will + /// be passed with the slice / vec of columnar values (either scalar or array) + /// with the exception of zero param function, where a singular element vec + /// will be passed. In that case the single element is a null array to indicate + /// the batch's row count (so that the generative zero-argument function can know + /// the result array size). + pub fun: ScalarFunctionImplementation, +} + +impl Debug for ScalarUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl PartialEq for ScalarUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl std::hash::Hash for ScalarUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl ScalarUDF { + /// Create a new ScalarUDF + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + fun: &ScalarFunctionImplementation, + ) -> Self { + Self { + name: name.to_owned(), + signature: signature.clone(), + return_type: return_type.clone(), + fun: fun.clone(), + } + } + + /// creates a logical expression with a call of the UDF + /// This utility allows using the UDF without requiring access to the registry. + pub fn call(&self, args: Vec) -> Expr { + Expr::ScalarUDF { + fun: Arc::new(self.clone()), + args, + } + } +} diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 3fcaa28af973..0e3cc61f3b5a 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -321,12 +321,12 @@ mod tests { use super::*; use crate::execution::options::CsvReadOptions; - use crate::physical_plan::functions::ScalarFunctionImplementation; - use crate::physical_plan::functions::Volatility; use crate::physical_plan::{window_functions, ColumnarValue}; use crate::{assert_batches_sorted_eq, execution::context::ExecutionContext}; use crate::{logical_plan::*, test_util}; use arrow::datatypes::DataType; + use datafusion_expr::ScalarFunctionImplementation; + use datafusion_expr::Volatility; #[tokio::test] async fn select_columns() -> Result<()> { diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index f19e9d8d6a35..de052983f770 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -21,379 +21,22 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; use crate::logical_plan::ExprSchemable; -use crate::logical_plan::{window_frames, DFField, DFSchema}; -use crate::physical_plan::functions::Volatility; -use crate::physical_plan::{aggregates, functions, udf::ScalarUDF, window_functions}; -use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; -use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; +use crate::logical_plan::{DFField, DFSchema}; +use crate::physical_plan::udaf::AggregateUDF; +use crate::physical_plan::{aggregates, functions, udf::ScalarUDF}; use arrow::datatypes::DataType; pub use datafusion_common::{Column, ExprSchema}; -use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +pub use datafusion_expr::expr_fn::col; +use datafusion_expr::AccumulatorFunctionImplementation; +pub use datafusion_expr::Expr; +use datafusion_expr::StateTypeFunction; +pub use datafusion_expr::{lit, lit_timestamp_nano, Literal}; +use datafusion_expr::{ + ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility, +}; use std::collections::HashSet; -use std::fmt; -use std::hash::{BuildHasher, Hash, Hasher}; -use std::ops::Not; use std::sync::Arc; -/// `Expr` is a central struct of DataFusion's query API, and -/// represent logical expressions such as `A + 1`, or `CAST(c1 AS -/// int)`. -/// -/// An `Expr` can compute its [DataType](arrow::datatypes::DataType) -/// and nullability, and has functions for building up complex -/// expressions. -/// -/// # Examples -/// -/// ## Create an expression `c1` referring to column named "c1" -/// ``` -/// # use datafusion::logical_plan::*; -/// let expr = col("c1"); -/// assert_eq!(expr, Expr::Column(Column::from_name("c1"))); -/// ``` -/// -/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together -/// ``` -/// # use datafusion::logical_plan::*; -/// let expr = col("c1") + col("c2"); -/// -/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); -/// if let Expr::BinaryExpr { left, right, op } = expr { -/// assert_eq!(*left, col("c1")); -/// assert_eq!(*right, col("c2")); -/// assert_eq!(op, Operator::Plus); -/// } -/// ``` -/// -/// ## Create expression `c1 = 42` to compare the value in coumn "c1" to the literal value `42` -/// ``` -/// # use datafusion::logical_plan::*; -/// # use datafusion::scalar::*; -/// let expr = col("c1").eq(lit(42)); -/// -/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); -/// if let Expr::BinaryExpr { left, right, op } = expr { -/// assert_eq!(*left, col("c1")); -/// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*right, Expr::Literal(scalar)); -/// assert_eq!(op, Operator::Eq); -/// } -/// ``` -#[derive(Clone, PartialEq, Hash)] -pub enum Expr { - /// An expression with a specific name. - Alias(Box, String), - /// A named reference to a qualified filed in a schema. - Column(Column), - /// A named reference to a variable in a registry. - ScalarVariable(Vec), - /// A constant value. - Literal(ScalarValue), - /// A binary expression such as "age > 21" - BinaryExpr { - /// Left-hand side of the expression - left: Box, - /// The comparison operator - op: Operator, - /// Right-hand side of the expression - right: Box, - }, - /// Negation of an expression. The expression's type must be a boolean to make sense. - Not(Box), - /// Whether an expression is not Null. This expression is never null. - IsNotNull(Box), - /// Whether an expression is Null. This expression is never null. - IsNull(Box), - /// arithmetic negation of an expression, the operand must be of a signed numeric data type - Negative(Box), - /// Returns the field of a [`ListArray`] or [`StructArray`] by key - GetIndexedField { - /// the expression to take the field from - expr: Box, - /// The name of the field to take - key: ScalarValue, - }, - /// Whether an expression is between a given range. - Between { - /// The value to compare - expr: Box, - /// Whether the expression is negated - negated: bool, - /// The low end of the range - low: Box, - /// The high end of the range - high: Box, - }, - /// The CASE expression is similar to a series of nested if/else and there are two forms that - /// can be used. The first form consists of a series of boolean "when" expressions with - /// corresponding "then" expressions, and an optional "else" expression. - /// - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// - /// The second form uses a base expression and then a series of "when" clauses that match on a - /// literal value. - /// - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - Case { - /// Optional base expression that can be compared to literal values in the "when" expressions - expr: Option>, - /// One or more when/then expressions - when_then_expr: Vec<(Box, Box)>, - /// Optional "else" expression - else_expr: Option>, - }, - /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. - /// This expression is guaranteed to have a fixed type. - Cast { - /// The expression being cast - expr: Box, - /// The `DataType` the expression will yield - data_type: DataType, - }, - /// Casts the expression to a given type and will return a null value if the expression cannot be cast. - /// This expression is guaranteed to have a fixed type. - TryCast { - /// The expression being cast - expr: Box, - /// The `DataType` the expression will yield - data_type: DataType, - }, - /// A sort expression, that can be used to sort values. - Sort { - /// The expression to sort on - expr: Box, - /// The direction of the sort - asc: bool, - /// Whether to put Nulls before all other data values - nulls_first: bool, - }, - /// Represents the call of a built-in scalar function with a set of arguments. - ScalarFunction { - /// The function - fun: functions::BuiltinScalarFunction, - /// List of expressions to feed to the functions as arguments - args: Vec, - }, - /// Represents the call of a user-defined scalar function with arguments. - ScalarUDF { - /// The function - fun: Arc, - /// List of expressions to feed to the functions as arguments - args: Vec, - }, - /// Represents the call of an aggregate built-in function with arguments. - AggregateFunction { - /// Name of the function - fun: aggregates::AggregateFunction, - /// List of expressions to feed to the functions as arguments - args: Vec, - /// Whether this is a DISTINCT aggregation or not - distinct: bool, - }, - /// Represents the call of a window function with arguments. - WindowFunction { - /// Name of the function - fun: window_functions::WindowFunction, - /// List of expressions to feed to the functions as arguments - args: Vec, - /// List of partition by expressions - partition_by: Vec, - /// List of order by expressions - order_by: Vec, - /// Window frame - window_frame: Option, - }, - /// aggregate function - AggregateUDF { - /// The function - fun: Arc, - /// List of expressions to feed to the functions as arguments - args: Vec, - }, - /// Returns whether the list contains the expr value. - InList { - /// The expression to compare - expr: Box, - /// A list of values to compare against - list: Vec, - /// Whether the expression is negated - negated: bool, - }, - /// Represents a reference to all fields in a schema. - Wildcard, -} - -/// Fixed seed for the hashing so that Ords are consistent across runs -const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); - -impl PartialOrd for Expr { - fn partial_cmp(&self, other: &Self) -> Option { - let mut hasher = SEED.build_hasher(); - self.hash(&mut hasher); - let s = hasher.finish(); - - let mut hasher = SEED.build_hasher(); - other.hash(&mut hasher); - let o = hasher.finish(); - - Some(s.cmp(&o)) - } -} - -impl Expr { - /// Returns the name of this expression based on [crate::logical_plan::DFSchema]. - /// - /// This represents how a column with this expression is named when no alias is chosen - pub fn name(&self, input_schema: &DFSchema) -> Result { - create_name(self, input_schema) - } - - /// Return `self == other` - pub fn eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::Eq, other) - } - - /// Return `self != other` - pub fn not_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::NotEq, other) - } - - /// Return `self > other` - pub fn gt(self, other: Expr) -> Expr { - binary_expr(self, Operator::Gt, other) - } - - /// Return `self >= other` - pub fn gt_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::GtEq, other) - } - - /// Return `self < other` - pub fn lt(self, other: Expr) -> Expr { - binary_expr(self, Operator::Lt, other) - } - - /// Return `self <= other` - pub fn lt_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::LtEq, other) - } - - /// Return `self && other` - pub fn and(self, other: Expr) -> Expr { - binary_expr(self, Operator::And, other) - } - - /// Return `self || other` - pub fn or(self, other: Expr) -> Expr { - binary_expr(self, Operator::Or, other) - } - - /// Return `!self` - #[allow(clippy::should_implement_trait)] - pub fn not(self) -> Expr { - !self - } - - /// Calculate the modulus of two expressions. - /// Return `self % other` - pub fn modulus(self, other: Expr) -> Expr { - binary_expr(self, Operator::Modulo, other) - } - - /// Return `self LIKE other` - pub fn like(self, other: Expr) -> Expr { - binary_expr(self, Operator::Like, other) - } - - /// Return `self NOT LIKE other` - pub fn not_like(self, other: Expr) -> Expr { - binary_expr(self, Operator::NotLike, other) - } - - /// Return `self AS name` alias expression - pub fn alias(self, name: &str) -> Expr { - Expr::Alias(Box::new(self), name.to_owned()) - } - - /// Return `self IN ` if `negated` is false, otherwise - /// return `self NOT IN `.a - pub fn in_list(self, list: Vec, negated: bool) -> Expr { - Expr::InList { - expr: Box::new(self), - list, - negated, - } - } - - /// Return `IsNull(Box(self)) - #[allow(clippy::wrong_self_convention)] - pub fn is_null(self) -> Expr { - Expr::IsNull(Box::new(self)) - } - - /// Return `IsNotNull(Box(self)) - #[allow(clippy::wrong_self_convention)] - pub fn is_not_null(self) -> Expr { - Expr::IsNotNull(Box::new(self)) - } - - /// Create a sort expression from an existing expression. - /// - /// ``` - /// # use datafusion::logical_plan::col; - /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST - /// ``` - pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { - Expr::Sort { - expr: Box::new(self), - asc, - nulls_first, - } - } -} - -impl Not for Expr { - type Output = Self; - - fn not(self) -> Self::Output { - Expr::Not(Box::new(self)) - } -} - -impl std::fmt::Display for Expr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Expr::BinaryExpr { - ref left, - ref right, - ref op, - } => write!(f, "{} {} {}", left, op, right), - Expr::AggregateFunction { - /// Name of the function - ref fun, - /// List of expressions to feed to the functions as arguments - ref args, - /// Whether this is a DISTINCT aggregation or not - ref distinct, - } => fmt_function(f, &fun.to_string(), *distinct, args, true), - Expr::ScalarFunction { - /// Name of the function - ref fun, - /// List of expressions to feed to the functions as arguments - ref args, - } => fmt_function(f, &fun.to_string(), false, args, true), - _ => write!(f, "{:?}", self), - } - } -} - /// Helper struct for building [Expr::Case] pub struct CaseBuilder { expr: Option>, @@ -484,15 +127,6 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { } } -/// return a new expression l r -pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr { - Expr::BinaryExpr { - left: Box::new(l), - op, - right: Box::new(r), - } -} - /// return a new expression with a logical AND pub fn and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr { @@ -525,11 +159,6 @@ pub fn or(left: Expr, right: Expr) -> Expr { } } -/// Create a column expression based on a qualified or unqualified column name -pub fn col(ident: &str) -> Expr { - Expr::Column(ident.into()) -} - /// Convert an expression into Column expression if it's already provided as input plan. /// /// For example, it rewrites: @@ -634,102 +263,6 @@ pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { } } -/// Trait for converting a type to a [`Literal`] literal expression. -pub trait Literal { - /// convert the value to a Literal expression - fn lit(&self) -> Expr; -} - -/// Trait for converting a type to a literal timestamp -pub trait TimestampLiteral { - fn lit_timestamp_nano(&self) -> Expr; -} - -impl Literal for &str { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) - } -} - -impl Literal for String { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) - } -} - -impl Literal for Vec { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) - } -} - -impl Literal for &[u8] { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) - } -} - -impl Literal for ScalarValue { - fn lit(&self) -> Expr { - Expr::Literal(self.clone()) - } -} - -macro_rules! make_literal { - ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { - #[doc = $DOC] - impl Literal for $TYPE { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) - } - } - }; -} - -macro_rules! make_timestamp_literal { - ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { - #[doc = $DOC] - impl TimestampLiteral for $TYPE { - fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond( - Some((self.clone()).into()), - None, - )) - } - } - }; -} - -make_literal!(bool, Boolean, "literal expression containing a bool"); -make_literal!(f32, Float32, "literal expression containing an f32"); -make_literal!(f64, Float64, "literal expression containing an f64"); -make_literal!(i8, Int8, "literal expression containing an i8"); -make_literal!(i16, Int16, "literal expression containing an i16"); -make_literal!(i32, Int32, "literal expression containing an i32"); -make_literal!(i64, Int64, "literal expression containing an i64"); -make_literal!(u8, UInt8, "literal expression containing a u8"); -make_literal!(u16, UInt16, "literal expression containing a u16"); -make_literal!(u32, UInt32, "literal expression containing a u32"); -make_literal!(u64, UInt64, "literal expression containing a u64"); - -make_timestamp_literal!(i8, Int8, "literal expression containing an i8"); -make_timestamp_literal!(i16, Int16, "literal expression containing an i16"); -make_timestamp_literal!(i32, Int32, "literal expression containing an i32"); -make_timestamp_literal!(i64, Int64, "literal expression containing an i64"); -make_timestamp_literal!(u8, UInt8, "literal expression containing a u8"); -make_timestamp_literal!(u16, UInt16, "literal expression containing a u16"); -make_timestamp_literal!(u32, UInt32, "literal expression containing a u32"); - -/// Create a literal expression -pub fn lit(n: T) -> Expr { - n.lit() -} - -/// Create a literal timestamp expression -pub fn lit_timestamp_nano(n: T) -> Expr { - n.lit_timestamp_nano() -} - /// Concatenates the text representations of all the arguments. NULL arguments are ignored. pub fn concat(args: &[Expr]) -> Expr { Expr::ScalarFunction { @@ -934,311 +467,6 @@ pub fn create_udaf( ) } -fn fmt_function( - f: &mut fmt::Formatter, - fun: &str, - distinct: bool, - args: &[Expr], - display: bool, -) -> fmt::Result { - let args: Vec = match display { - true => args.iter().map(|arg| format!("{}", arg)).collect(), - false => args.iter().map(|arg| format!("{:?}", arg)).collect(), - }; - - // let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) -} - -impl fmt::Debug for Expr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias), - Expr::Column(c) => write!(f, "{}", c), - Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")), - Expr::Literal(v) => write!(f, "{:?}", v), - Expr::Case { - expr, - when_then_expr, - else_expr, - .. - } => { - write!(f, "CASE ")?; - if let Some(e) = expr { - write!(f, "{:?} ", e)?; - } - for (w, t) in when_then_expr { - write!(f, "WHEN {:?} THEN {:?} ", w, t)?; - } - if let Some(e) = else_expr { - write!(f, "ELSE {:?} ", e)?; - } - write!(f, "END") - } - Expr::Cast { expr, data_type } => { - write!(f, "CAST({:?} AS {:?})", expr, data_type) - } - Expr::TryCast { expr, data_type } => { - write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type) - } - Expr::Not(expr) => write!(f, "NOT {:?}", expr), - Expr::Negative(expr) => write!(f, "(- {:?})", expr), - Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr), - Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr), - Expr::BinaryExpr { left, op, right } => { - write!(f, "{:?} {} {:?}", left, op, right) - } - Expr::Sort { - expr, - asc, - nulls_first, - } => { - if *asc { - write!(f, "{:?} ASC", expr)?; - } else { - write!(f, "{:?} DESC", expr)?; - } - if *nulls_first { - write!(f, " NULLS FIRST") - } else { - write!(f, " NULLS LAST") - } - } - Expr::ScalarFunction { fun, args, .. } => { - fmt_function(f, &fun.to_string(), false, args, false) - } - Expr::ScalarUDF { fun, ref args, .. } => { - fmt_function(f, &fun.name, false, args, false) - } - Expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - } => { - fmt_function(f, &fun.to_string(), false, args, false)?; - if !partition_by.is_empty() { - write!(f, " PARTITION BY {:?}", partition_by)?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY {:?}", order_by)?; - } - if let Some(window_frame) = window_frame { - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, - window_frame.start_bound, - window_frame.end_bound - )?; - } - Ok(()) - } - Expr::AggregateFunction { - fun, - distinct, - ref args, - .. - } => fmt_function(f, &fun.to_string(), *distinct, args, true), - Expr::AggregateUDF { fun, ref args, .. } => { - fmt_function(f, &fun.name, false, args, false) - } - Expr::Between { - expr, - negated, - low, - high, - } => { - if *negated { - write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high) - } else { - write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high) - } - } - Expr::InList { - expr, - list, - negated, - } => { - if *negated { - write!(f, "{:?} NOT IN ({:?})", expr, list) - } else { - write!(f, "{:?} IN ({:?})", expr, list) - } - } - Expr::Wildcard => write!(f, "*"), - Expr::GetIndexedField { ref expr, key } => { - write!(f, "({:?})[{}]", expr, key) - } - } - } -} - -fn create_function_name( - fun: &str, - distinct: bool, - args: &[Expr], - input_schema: &DFSchema, -) -> Result { - let names: Vec = args - .iter() - .map(|e| create_name(e, input_schema)) - .collect::>()?; - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) -} - -/// Returns a readable name of an expression based on the input schema. -/// This function recursively transverses the expression for names such as "CAST(a > 2)". -fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { - match e { - Expr::Alias(_, name) => Ok(name.clone()), - Expr::Column(c) => Ok(c.flat_name()), - Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")), - Expr::Literal(value) => Ok(format!("{:?}", value)), - Expr::BinaryExpr { left, op, right } => { - let left = create_name(left, input_schema)?; - let right = create_name(right, input_schema)?; - Ok(format!("{} {} {}", left, op, right)) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let mut name = "CASE ".to_string(); - if let Some(e) = expr { - let e = create_name(e, input_schema)?; - name += &format!("{} ", e); - } - for (w, t) in when_then_expr { - let when = create_name(w, input_schema)?; - let then = create_name(t, input_schema)?; - name += &format!("WHEN {} THEN {} ", when, then); - } - if let Some(e) = else_expr { - let e = create_name(e, input_schema)?; - name += &format!("ELSE {} ", e); - } - name += "END"; - Ok(name) - } - Expr::Cast { expr, data_type } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("CAST({} AS {:?})", expr, data_type)) - } - Expr::TryCast { expr, data_type } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) - } - Expr::Not(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("NOT {}", expr)) - } - Expr::Negative(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("(- {})", expr)) - } - Expr::IsNull(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{} IS NULL", expr)) - } - Expr::IsNotNull(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{} IS NOT NULL", expr)) - } - Expr::GetIndexedField { expr, key } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{}[{}]", expr, key)) - } - Expr::ScalarFunction { fun, args, .. } => { - create_function_name(&fun.to_string(), false, args, input_schema) - } - Expr::ScalarUDF { fun, args, .. } => { - create_function_name(&fun.name, false, args, input_schema) - } - Expr::WindowFunction { - fun, - args, - window_frame, - partition_by, - order_by, - } => { - let mut parts: Vec = vec![create_function_name( - &fun.to_string(), - false, - args, - input_schema, - )?]; - if !partition_by.is_empty() { - parts.push(format!("PARTITION BY {:?}", partition_by)); - } - if !order_by.is_empty() { - parts.push(format!("ORDER BY {:?}", order_by)); - } - if let Some(window_frame) = window_frame { - parts.push(format!("{}", window_frame)); - } - Ok(parts.join(" ")) - } - Expr::AggregateFunction { - fun, - distinct, - args, - .. - } => create_function_name(&fun.to_string(), *distinct, args, input_schema), - Expr::AggregateUDF { fun, args } => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e, input_schema)?); - } - Ok(format!("{}({})", fun.name, names.join(","))) - } - Expr::InList { - expr, - list, - negated, - } => { - let expr = create_name(expr, input_schema)?; - let list = list.iter().map(|expr| create_name(expr, input_schema)); - if *negated { - Ok(format!("{} NOT IN ({:?})", expr, list)) - } else { - Ok(format!("{} IN ({:?})", expr, list)) - } - } - Expr::Between { - expr, - negated, - low, - high, - } => { - let expr = create_name(expr, input_schema)?; - let low = create_name(low, input_schema)?; - let high = create_name(high, input_schema)?; - if *negated { - Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high)) - } else { - Ok(format!("{} BETWEEN {} AND {}", expr, low, high)) - } - } - Expr::Sort { .. } => Err(DataFusionError::Internal( - "Create name does not support sort expression".to_string(), - )), - Expr::Wildcard => Err(DataFusionError::Internal( - "Create name does not support wildcard".to_string(), - )), - } -} - /// Create field meta-data from an expression, for use in a result set schema pub fn exprlist_to_fields<'a>( expr: impl IntoIterator, @@ -1265,6 +493,7 @@ pub fn call_fn(name: impl AsRef, args: Vec) -> Result { mod tests { use super::super::{col, lit, when}; use super::*; + use datafusion_expr::expr_fn::binary_expr; #[test] fn case_when_same_literal_then_types() -> Result<()> { @@ -1282,22 +511,6 @@ mod tests { assert!(maybe_expr.is_err()); } - #[test] - fn test_lit_timestamp_nano() { - let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32 - let expected = - col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10), None))); - assert_eq!(expr, expected); - - let i: i64 = 10; - let expr = col("time").eq(lit_timestamp_nano(i)); - assert_eq!(expr, expected); - - let i: u32 = 10; - let expr = col("time").eq(lit_timestamp_nano(i)); - assert_eq!(expr, expected); - } - #[test] fn filter_is_null_and_is_not_null() { let col_null = col("col1"); diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index f2ecb0f76278..24d6723210c7 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -37,11 +37,12 @@ pub mod window_frames; pub use builder::{ build_join_schema, union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE, }; +pub use datafusion_expr::expr_fn::binary_expr; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, - avg, binary_expr, bit_length, btrim, call_fn, case, ceil, character_length, chr, col, + avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs index 813f7e0aac70..2f129284fa71 100644 --- a/datafusion/src/logical_plan/operators.rs +++ b/datafusion/src/logical_plan/operators.rs @@ -15,49 +15,7 @@ // specific language governing permissions and limitations // under the License. -use super::{binary_expr, Expr}; pub use datafusion_expr::Operator; -use std::ops; - -impl ops::Add for Expr { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - binary_expr(self, Operator::Plus, rhs) - } -} - -impl ops::Sub for Expr { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - binary_expr(self, Operator::Minus, rhs) - } -} - -impl ops::Mul for Expr { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - binary_expr(self, Operator::Multiply, rhs) - } -} - -impl ops::Div for Expr { - type Output = Self; - - fn div(self, rhs: Self) -> Self { - binary_expr(self, Operator::Divide, rhs) - } -} - -impl ops::Rem for Expr { - type Output = Self; - - fn rem(self, rhs: Self) -> Self { - binary_expr(self, Operator::Modulo, rhs) - } -} #[cfg(test)] mod tests { diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index a1531d4a7b83..10096504bcb4 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -28,7 +28,7 @@ use super::{ functions::{Signature, TypeSignature, Volatility}, - Accumulator, AggregateExpr, PhysicalExpr, + AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types}; @@ -40,15 +40,6 @@ use expressions::{ }; use std::sync::Arc; -/// the implementation of an aggregate function -pub type AccumulatorFunctionImplementation = - Arc Result> + Send + Sync>; - -/// This signature corresponds to which types an aggregator serializes -/// its state, given its return datatype. -pub type StateTypeFunction = - Arc Result>> + Send + Sync>; - pub use datafusion_expr::AggregateFunction; /// Returns the datatype of the aggregate function. diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs index 0de696d61172..71e7e0657596 100644 --- a/datafusion/src/physical_plan/udaf.rs +++ b/datafusion/src/physical_plan/udaf.rs @@ -17,7 +17,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. -use fmt::{Debug, Formatter}; +use fmt::Debug; use std::any::Any; use std::fmt; @@ -26,85 +26,14 @@ use arrow::{ datatypes::{DataType, Schema}, }; -use crate::physical_plan::PhysicalExpr; -use crate::{error::Result, logical_plan::Expr}; - use super::{ - aggregates::AccumulatorFunctionImplementation, - aggregates::StateTypeFunction, - expressions::format_state_name, - functions::{ReturnTypeFunction, Signature}, - type_coercion::coerce, - Accumulator, AggregateExpr, + expressions::format_state_name, type_coercion::coerce, Accumulator, AggregateExpr, }; -use std::sync::Arc; - -/// Logical representation of a user-defined aggregate function (UDAF) -/// A UDAF is different from a UDF in that it is stateful across batches. -#[derive(Clone)] -pub struct AggregateUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - pub accumulator: AccumulatorFunctionImplementation, - /// the accumulator's state's description as a function of the return type - pub state_type: StateTypeFunction, -} - -impl Debug for AggregateUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("AggregateUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } -} - -impl PartialEq for AggregateUDF { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature - } -} - -impl std::hash::Hash for AggregateUDF { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); - } -} - -impl AggregateUDF { - /// Create a new AggregateUDF - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - accumulator: &AccumulatorFunctionImplementation, - state_type: &StateTypeFunction, - ) -> Self { - Self { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - accumulator: accumulator.clone(), - state_type: state_type.clone(), - } - } +use crate::error::Result; +use crate::physical_plan::PhysicalExpr; +pub use datafusion_expr::AggregateUDF; - /// creates a logical expression with a call of the UDAF - /// This utility allows using the UDAF without requiring access to the registry. - pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF { - fun: Arc::new(self.clone()), - args, - } - } -} +use std::sync::Arc; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 7355746a368b..58e66da48a7d 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -17,91 +17,16 @@ //! UDF support -use fmt::{Debug, Formatter}; -use std::fmt; - +use super::type_coercion::coerce; +use crate::error::Result; +use crate::physical_plan::functions::ScalarFunctionExpr; +use crate::physical_plan::PhysicalExpr; use arrow::datatypes::Schema; -use crate::error::Result; -use crate::{logical_plan::Expr, physical_plan::PhysicalExpr}; +pub use datafusion_expr::ScalarUDF; -use super::{ - functions::{ - ReturnTypeFunction, ScalarFunctionExpr, ScalarFunctionImplementation, Signature, - }, - type_coercion::coerce, -}; use std::sync::Arc; -/// Logical representation of a UDF. -#[derive(Clone)] -pub struct ScalarUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed. In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). - pub fun: ScalarFunctionImplementation, -} - -impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } -} - -impl PartialEq for ScalarUDF { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature - } -} - -impl std::hash::Hash for ScalarUDF { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); - } -} - -impl ScalarUDF { - /// Create a new ScalarUDF - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - fun: &ScalarFunctionImplementation, - ) -> Self { - Self { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - fun: fun.clone(), - } - } - - /// creates a logical expression with a call of the UDF - /// This utility allows using the UDF without requiring access to the registry. - pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarUDF { - fun: Arc::new(self.clone()), - args, - } - } -} - /// Create a physical expression of the UDF. /// This function errors when `args`' can't be coerced to a valid argument type of the UDF. pub fn create_physical_expr( diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 682b92ba661f..2e417c75f3f0 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -2172,11 +2172,10 @@ pub fn convert_data_type(sql_type: &SQLDataType) -> Result { #[cfg(test)] mod tests { - use functions::ScalarFunctionImplementation; - use crate::datasource::empty::EmptyTable; use crate::physical_plan::functions::Volatility; use crate::{logical_plan::create_udf, sql::parser::DFParser}; + use datafusion_expr::ScalarFunctionImplementation; use super::*;