From 70111656e14f8ba226e27768c7741d2388c69bcd Mon Sep 17 00:00:00 2001 From: David Blajda Date: Mon, 2 Oct 2023 16:37:58 -0400 Subject: [PATCH 1/5] fix str repr for predicate + parsing expr with funcs --- rust/src/delta_datafusion/expr.rs | 483 ++++++++++++++++++ .../mod.rs} | 2 + rust/src/operations/delete.rs | 4 +- rust/src/operations/merge.rs | 17 +- rust/src/operations/mod.rs | 12 +- .../transaction/conflict_checker.rs | 5 +- rust/src/operations/transaction/state.rs | 47 +- rust/src/operations/update.rs | 6 +- 8 files changed, 547 insertions(+), 29 deletions(-) create mode 100644 rust/src/delta_datafusion/expr.rs rename rust/src/{delta_datafusion.rs => delta_datafusion/mod.rs} (99%) diff --git a/rust/src/delta_datafusion/expr.rs b/rust/src/delta_datafusion/expr.rs new file mode 100644 index 0000000000..f7d52bc49c --- /dev/null +++ b/rust/src/delta_datafusion/expr.rs @@ -0,0 +1,483 @@ +//! Utility functions for Datafusion's Expressions +//! + +use std::fmt::{self, Display, Formatter, Write}; + +use datafusion_common::ScalarValue; +use datafusion_expr::{ + expr::{InList, ScalarUDF}, + Between, BinaryExpr, Cast, Expr, Like, TryCast, +}; +use sqlparser::ast::escape_quoted_string; + +struct SqlFormat<'a> { + expr: &'a Expr, +} + +macro_rules! expr_vec_fmt { + ( $ARRAY:expr ) => {{ + $ARRAY + .iter() + .map(|e| format!("{}", SqlFormat { expr: e })) + .collect::>() + .join(", ") + }}; +} + +struct BinaryExprFormat<'a> { + expr: &'a BinaryExpr, +} + +impl<'a> Display for BinaryExprFormat<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Put parentheses around child binary expressions so that we can see the difference + // between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed, + // based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are + // equivalent and the parentheses are not necessary. + + fn write_child(f: &mut Formatter<'_>, expr: &Expr, precedence: u8) -> fmt::Result { + match expr { + Expr::BinaryExpr(child) => { + let p = child.op.precedence(); + if p == 0 || p < precedence { + write!(f, "({})", BinaryExprFormat { expr: child })?; + } else { + write!(f, "{}", BinaryExprFormat { expr: child })?; + } + } + _ => write!(f, "{}", SqlFormat { expr })?, + } + Ok(()) + } + + let precedence = self.expr.op.precedence(); + write_child(f, self.expr.left.as_ref(), precedence)?; + write!(f, " {} ", self.expr.op)?; + write_child(f, self.expr.right.as_ref(), precedence) + } +} + +impl<'a> Display for SqlFormat<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.expr { + Expr::Column(c) => write!(f, "{c}"), + Expr::Literal(v) => write!(f, "{}", ScalarValueFormat { scalar: v }), + Expr::Case(case) => { + write!(f, "CASE ")?; + if let Some(e) = &case.expr { + write!(f, "{} ", SqlFormat { expr: e })?; + } + for (w, t) in &case.when_then_expr { + write!( + f, + "WHEN {} THEN {} ", + SqlFormat { expr: w }, + SqlFormat { expr: t } + )?; + } + if let Some(e) = &case.else_expr { + write!(f, "ELSE {} ", SqlFormat { expr: e })?; + } + write!(f, "END") + } + Expr::Cast(Cast { expr, data_type }) => { + write!(f, "CAST({} AS {data_type:?})", SqlFormat { expr }) + } + Expr::TryCast(TryCast { expr, data_type }) => { + write!(f, "TRY_CAST({expr} AS {data_type:?})") + } + Expr::Not(expr) => write!(f, "NOT {}", SqlFormat { expr }), + Expr::Negative(expr) => write!(f, "(- {})", SqlFormat { expr }), + Expr::IsNull(expr) => write!(f, "{} IS NULL", SqlFormat { expr }), + Expr::IsNotNull(expr) => write!(f, "{} IS NOT NULL", SqlFormat { expr }), + Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SqlFormat { expr }), + Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SqlFormat { expr }), + Expr::IsUnknown(expr) => write!(f, "{} IS UNKNOWN", SqlFormat { expr }), + Expr::IsNotTrue(expr) => write!(f, "{} IS NOT TRUE", SqlFormat { expr }), + Expr::IsNotFalse(expr) => write!(f, "{} IS NOT FALSE", SqlFormat { expr }), + Expr::IsNotUnknown(expr) => write!(f, "{} IS NOT UNKNOWN", SqlFormat { expr }), + Expr::BinaryExpr(expr) => write!(f, "{}", BinaryExprFormat { expr }), + Expr::ScalarFunction(func) => { + fmt_function(f, &func.fun.to_string(), false, &func.args, true) + } + Expr::ScalarUDF(ScalarUDF { fun, args }) => { + fmt_function(f, &fun.name, false, args, true) + } + Expr::Between(Between { + expr, + negated, + low, + high, + }) => { + if *negated { + write!( + f, + "{} NOT BETWEEN {} AND {}", + SqlFormat { expr }, + SqlFormat { expr: low }, + SqlFormat { expr: high } + ) + } else { + write!( + f, + "{} BETWEEN {} AND {}", + SqlFormat { expr }, + SqlFormat { expr: low }, + SqlFormat { expr: high } + ) + } + } + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => { + write!(f, "{}", SqlFormat { expr })?; + let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" }; + if *negated { + write!(f, " NOT")?; + } + if let Some(char) = escape_char { + write!( + f, + " {op_name} {} ESCAPE '{char}'", + SqlFormat { expr: pattern } + ) + } else { + write!(f, " {op_name} {}", SqlFormat { expr: pattern }) + } + } + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }) => { + write!(f, "{expr}")?; + if *negated { + write!(f, " NOT")?; + } + if let Some(char) = escape_char { + write!(f, " SIMILAR TO {pattern} ESCAPE '{char}'") + } else { + write!(f, " SIMILAR TO {pattern}") + } + } + Expr::InList(InList { + expr, + list, + negated, + }) => { + if *negated { + write!(f, "{expr} NOT IN ({})", expr_vec_fmt!(list)) + } else { + write!(f, "{expr} IN ({})", expr_vec_fmt!(list)) + } + } + _ => Err(fmt::Error), + } + } +} + +/// Format an `Expr` to a parsable SQL expression +pub fn fmt_expr_to_sql(expr: &Expr) -> Result { + let mut s = String::new(); + write!(&mut s, "{}", SqlFormat { expr })?; + Ok(s) +} + +fn fmt_function( + f: &mut fmt::Formatter, + fun: &str, + distinct: bool, + args: &[Expr], + display: bool, +) -> fmt::Result { + let args: Vec = match display { + true => args + .iter() + .map(|arg| format!("{}", SqlFormat { expr: arg })) + .collect(), + false => todo!("fmt function"), + }; + + // let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) +} + +macro_rules! format_option { + ($F:expr, $EXPR:expr) => {{ + match $EXPR { + Some(e) => write!($F, "{e}"), + None => write!($F, "NULL"), + } + }}; +} + +struct ScalarValueFormat<'a> { + scalar: &'a ScalarValue, +} + +impl<'a> fmt::Display for ScalarValueFormat<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.scalar { + ScalarValue::Boolean(e) => format_option!(f, e)?, + ScalarValue::Float32(e) => format_option!(f, e)?, + ScalarValue::Float64(e) => format_option!(f, e)?, + ScalarValue::Int8(e) => format_option!(f, e)?, + ScalarValue::Int16(e) => format_option!(f, e)?, + ScalarValue::Int32(e) => format_option!(f, e)?, + ScalarValue::Int64(e) => format_option!(f, e)?, + ScalarValue::UInt8(e) => format_option!(f, e)?, + ScalarValue::UInt16(e) => format_option!(f, e)?, + ScalarValue::UInt32(e) => format_option!(f, e)?, + ScalarValue::UInt64(e) => format_option!(f, e)?, + ScalarValue::Utf8(e) | ScalarValue::LargeUtf8(e) => match e { + Some(e) => write!(f, "'{}'", escape_quoted_string(e, '\''))?, + None => write!(f, "NULL")?, + }, + ScalarValue::Binary(e) + | ScalarValue::FixedSizeBinary(_, e) + | ScalarValue::LargeBinary(e) => match e { + Some(l) => write!( + f, + "decode('{}', 'hex')", + l.iter() + .map(|v| format!("{v:02x}")) + .collect::>() + .join("") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::Null => write!(f, "NULL")?, + _ => return Err(fmt::Error) + }; + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use datafusion::prelude::SessionContext; + use datafusion_common::ScalarValue; + use datafusion_expr::{col, decode, lit, substring, Expr}; + + use crate::{DeltaOps, DeltaTable, Schema, SchemaDataType, SchemaField}; + + use super::fmt_expr_to_sql; + + struct ParseTest { + expr: Expr, + expected: String, + override_expected_expr: Option, + } + + macro_rules! simple { + ( $EXPR:expr, $STR:expr ) => {{ + ParseTest { + expr: $EXPR, + expected: $STR, + override_expected_expr: None, + } + }}; + } + + async fn setup_table() -> DeltaTable { + let schema = Schema::new(vec![ + SchemaField::new( + "id".to_string(), + SchemaDataType::primitive("string".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "value".to_string(), + SchemaDataType::primitive("integer".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "value2".to_string(), + SchemaDataType::primitive("integer".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "modified".to_string(), + SchemaDataType::primitive("string".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "active".to_string(), + SchemaDataType::primitive("boolean".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "money".to_string(), + SchemaDataType::primitive("decimal(12,2)".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "_date".to_string(), + SchemaDataType::primitive("date".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "_timestamp".to_string(), + SchemaDataType::primitive("timestamp".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "_binary".to_string(), + SchemaDataType::primitive("binary".to_string()), + true, + HashMap::new(), + ), + ]); + + let table = DeltaOps::new_in_memory() + .create() + .with_columns(schema.get_fields().clone()) + .await + .unwrap(); + assert_eq!(table.version(), 0); + table + } + + #[tokio::test] + async fn test_expr_sql() { + let table = setup_table().await; + + // String expression that we output must be parsable for conflict resolution. + let tests = vec![ + // TODO: sql parser will default to i64 by default where just lit(3) is be a i32 + simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()), + simple!(col("active").is_true(), "active IS TRUE".to_string()), + simple!(col("active"), "active".to_string()), + simple!(col("active").eq(lit(true)), "active = true".to_string()), + simple!(col("active").is_null(), "active IS NULL".to_string()), + simple!( + col("modified").eq(lit("2021-02-03")), + "modified = '2021-02-03'".to_string() + ), + simple!( + col("modified").eq(lit("'validate ' escapi\\ng'")), + "modified = '''validate '' escapi\\ng'''".to_string() + ), + simple!(col("money").gt(lit(0.10)), "money > 0.1".to_string()), + ParseTest { + expr: col("_binary").eq(lit(ScalarValue::Binary(Some(vec![0xAA, 0x00, 0xFF])))), + expected: "_binary = decode('aa00ff', 'hex')".to_string(), + override_expected_expr: Some(col("_binary").eq(decode(lit("aa00ff"), lit("hex")))), + }, + simple!( + col("value").between(lit(20_i64), lit(30_i64)), + "value BETWEEN 20 AND 30".to_string() + ), + simple!( + col("value").not_between(lit(20_i64), lit(30_i64)), + "value NOT BETWEEN 20 AND 30".to_string() + ), + simple!( + col("modified").like(lit("abc%")), + "modified LIKE 'abc%'".to_string() + ), + simple!( + col("modified").not_like(lit("abc%")), + "modified NOT LIKE 'abc%'".to_string() + ), + simple!( + (((col("value") * lit(2_i64) + col("value2")) / lit(3_i64)) - col("value")) + .gt(lit(0_i64)), + "(value * 2 + value2) / 3 - value > 0".to_string() + ), + simple!( + col("modified").in_list(vec![lit("a"), lit("c")], false), + "modified IN ('a', 'c')".to_string() + ), + simple!( + col("modified").in_list(vec![lit("a"), lit("c")], true), + "modified NOT IN ('a', 'c')".to_string() + ), + // Validate order of operations is maintained + simple!( + col("modified") + .eq(lit("value")) + .and(col("value").eq(lit(1_i64))) + .or(col("modified") + .eq(lit("value2")) + .and(col("value").gt(lit(1_i64)))), + "modified = 'value' AND value = 1 OR modified = 'value2' AND value > 1".to_string() + ), + simple!( + col("modified") + .eq(lit("value")) + .or(col("value").eq(lit(1_i64))) + .and( + col("modified") + .eq(lit("value2")) + .or(col("value").gt(lit(1_i64))), + ), + "(modified = 'value' OR value = 1) AND (modified = 'value2' OR value > 1)" + .to_string() + ), + // TODO: Need to refactor parse_predicate for this... + simple!( + substring(col("modified"), lit(0_i64), lit(4_i64)).eq(lit("2021")), + "substr(modified, 0, 4) = '2021'".to_string() + ), + ]; + + let session = SessionContext::new(); + + for test in tests { + let actual = fmt_expr_to_sql(&test.expr).unwrap(); + assert_eq!(test.expected, actual); + + let actual_expr = table + .state + .parse_predicate_expression(actual, &session.state()) + .unwrap(); + + match test.override_expected_expr { + None => assert_eq!(test.expr, actual_expr), + Some(expr) => assert_eq!(expr, actual_expr) + } + } + + + let unsupported_types = vec! [ + /* TODO: Determine proper way to display decimal, date, and datetime values in an sql expression*/ + simple! ( + col("money").gt(lit(ScalarValue::Decimal128(Some(100), 12, 2))), + "money > 0.1".to_string() + ), + simple! ( + col("_timestamp").gt(lit(ScalarValue::TimestampMillisecond(Some(100), None))), + "".to_string() + ), + simple! ( + col("_timestamp").gt(lit(ScalarValue::TimestampMillisecond(Some(100), Some("UTC".into())))), + "".to_string() + ), + ]; + + for test in unsupported_types { + assert!(fmt_expr_to_sql(&test.expr).is_err()); + } + } +} diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion/mod.rs similarity index 99% rename from rust/src/delta_datafusion.rs rename to rust/src/delta_datafusion/mod.rs index e542413cfd..166996dddd 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion/mod.rs @@ -76,6 +76,8 @@ use crate::{open_table, open_table_with_storage_options, DeltaTable, Invariant, const PATH_COLUMN: &str = "__delta_rs_path"; +pub mod expr; + impl From for DataFusionError { fn from(err: DeltaTableError) -> Self { match err { diff --git a/rust/src/operations/delete.rs b/rust/src/operations/delete.rs index d7f908680d..78a5b6e3cb 100644 --- a/rust/src/operations/delete.rs +++ b/rust/src/operations/delete.rs @@ -298,7 +298,9 @@ impl std::future::IntoFuture for DeleteBuilder { let predicate = match this.predicate { Some(predicate) => match predicate { Expression::DataFusion(expr) => Some(expr), - Expression::String(s) => Some(this.snapshot.parse_predicate_expression(s)?), + Expression::String(s) => { + Some(this.snapshot.parse_predicate_expression(s, &state)?) + } }, None => None, }; diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index d088fbd3b7..55e41a712a 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -171,6 +171,7 @@ impl MergeBuilder { let builder = builder(UpdateBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, builder.updates, OperationType::Update, @@ -204,6 +205,7 @@ impl MergeBuilder { let builder = builder(DeleteBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, HashMap::default(), OperationType::Delete, @@ -240,6 +242,7 @@ impl MergeBuilder { let builder = builder(InsertBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, builder.set, OperationType::Insert, @@ -278,6 +281,7 @@ impl MergeBuilder { let builder = builder(UpdateBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, builder.updates, OperationType::Update, @@ -311,6 +315,7 @@ impl MergeBuilder { let builder = builder(DeleteBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, HashMap::default(), OperationType::Delete, @@ -448,15 +453,21 @@ struct MergeOperation { impl MergeOperation { pub fn try_new( snapshot: &DeltaTableState, + state: &Option<&SessionState>, predicate: Option, operations: HashMap, r#type: OperationType, ) -> DeltaResult { - let predicate = maybe_into_expr(predicate, snapshot)?; + let context = SessionContext::new(); + let mut s = &context.state(); + if let Some(df_state) = state { + s = df_state; + } + let predicate = maybe_into_expr(predicate, snapshot, s)?; let mut _operations = HashMap::new(); for (column, expr) in operations { - _operations.insert(column, into_expr(expr, snapshot)?); + _operations.insert(column, into_expr(expr, snapshot, s)?); } Ok(MergeOperation { @@ -518,7 +529,7 @@ async fn execute( let predicate = match predicate { Expression::DataFusion(expr) => expr, - Expression::String(s) => snapshot.parse_predicate_expression(s)?, + Expression::String(s) => snapshot.parse_predicate_expression(s, &state)?, }; let schema = snapshot.input_schema()?; diff --git a/rust/src/operations/mod.rs b/rust/src/operations/mod.rs index 7b6cb27ace..c07b81438b 100644 --- a/rust/src/operations/mod.rs +++ b/rust/src/operations/mod.rs @@ -205,6 +205,7 @@ mod datafusion_utils { use arrow_schema::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::Result as DataFusionResult; + use datafusion::execution::context::SessionState; use datafusion::physical_plan::DisplayAs; use datafusion::physical_plan::{ metrics::{ExecutionPlanMetricsSet, MetricsSet}, @@ -240,19 +241,24 @@ mod datafusion_utils { } } - pub(crate) fn into_expr(expr: Expression, snapshot: &DeltaTableState) -> DeltaResult { + pub(crate) fn into_expr( + expr: Expression, + snapshot: &DeltaTableState, + df_state: &SessionState, + ) -> DeltaResult { match expr { Expression::DataFusion(expr) => Ok(expr), - Expression::String(s) => snapshot.parse_predicate_expression(s), + Expression::String(s) => snapshot.parse_predicate_expression(s, df_state), } } pub(crate) fn maybe_into_expr( expr: Option, snapshot: &DeltaTableState, + df_state: &SessionState, ) -> DeltaResult> { Ok(match expr { - Some(predicate) => Some(into_expr(predicate, snapshot)?), + Some(predicate) => Some(into_expr(predicate, snapshot, df_state)?), None => None, }) } diff --git a/rust/src/operations/transaction/conflict_checker.rs b/rust/src/operations/transaction/conflict_checker.rs index d75e401def..d7a9d3fb86 100644 --- a/rust/src/operations/transaction/conflict_checker.rs +++ b/rust/src/operations/transaction/conflict_checker.rs @@ -114,8 +114,11 @@ impl<'a> TransactionInfo<'a> { actions: &'a Vec, read_whole_table: bool, ) -> DeltaResult { + use datafusion::prelude::SessionContext; + + let session = SessionContext::new(); let read_predicates = read_predicates - .map(|pred| read_snapshot.parse_predicate_expression(pred)) + .map(|pred| read_snapshot.parse_predicate_expression(pred, &session.state())) .transpose()?; Ok(Self { txn_id: "".into(), diff --git a/rust/src/operations/transaction/state.rs b/rust/src/operations/transaction/state.rs index 6fe1d65aee..5924609fb7 100644 --- a/rust/src/operations/transaction/state.rs +++ b/rust/src/operations/transaction/state.rs @@ -5,6 +5,7 @@ use arrow::datatypes::{ DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, }; use datafusion::datasource::physical_plan::wrap_partition_type_in_dict; +use datafusion::execution::context::SessionState; use datafusion::optimizer::utils::conjunction; use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use datafusion_common::config::ConfigOptions; @@ -104,7 +105,11 @@ impl DeltaTableState { } /// Parse an expression string into a datafusion [`Expr`] - pub fn parse_predicate_expression(&self, expr: impl AsRef) -> DeltaResult { + pub fn parse_predicate_expression( + &self, + expr: impl AsRef, + df_state: &SessionState, + ) -> DeltaResult { let dialect = &GenericDialect {}; let mut tokenizer = Tokenizer::new(dialect, expr.as_ref()); let tokens = tokenizer @@ -121,7 +126,7 @@ impl DeltaTableState { // TODO should we add the table name as qualifier when available? let df_schema = DFSchema::try_from_qualified_schema("", self.arrow_schema()?.as_ref())?; - let context_provider = DummyContextProvider::default(); + let context_provider = DeltaContextProvider { state: df_state }; let sql_to_rel = SqlToRel::new(&context_provider); Ok(sql_to_rel.sql_to_expr(sql, &df_schema, &mut Default::default())?) @@ -342,34 +347,33 @@ impl PruningStatistics for DeltaTableState { } } -#[derive(Default)] -struct DummyContextProvider { - options: ConfigOptions, +pub(crate) struct DeltaContextProvider<'a> { + state: &'a SessionState, } -impl ContextProvider for DummyContextProvider { +impl<'a> ContextProvider for DeltaContextProvider<'a> { fn get_table_provider(&self, _name: TableReference) -> DFResult> { unimplemented!() } - fn get_function_meta(&self, _name: &str) -> Option> { - unimplemented!() + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() } - fn get_aggregate_meta(&self, _name: &str) -> Option> { - unimplemented!() + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() } - fn get_variable_type(&self, _: &[String]) -> Option { + fn get_variable_type(&self, _var: &[String]) -> Option { unimplemented!() } fn options(&self) -> &ConfigOptions { - &self.options + self.state.config_options() } - fn get_window_meta(&self, _name: &str) -> Option> { - unimplemented!() + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() } } @@ -377,24 +381,29 @@ impl ContextProvider for DummyContextProvider { mod tests { use super::*; use crate::operations::transaction::test_utils::{create_add_action, init_table_actions}; + use datafusion::prelude::SessionContext; use datafusion_expr::{col, lit}; #[test] fn test_parse_predicate_expression() { - let state = DeltaTableState::from_actions(init_table_actions(), 0).unwrap(); + let snapshot = DeltaTableState::from_actions(init_table_actions(), 0).unwrap(); + let session = SessionContext::new(); + let state = session.state(); // parses simple expression - let parsed = state.parse_predicate_expression("value > 10").unwrap(); + let parsed = snapshot + .parse_predicate_expression("value > 10", &state) + .unwrap(); let expected = col("value").gt(lit::(10)); assert_eq!(parsed, expected); // fails for unknown column - let parsed = state.parse_predicate_expression("non_existent > 10"); + let parsed = snapshot.parse_predicate_expression("non_existent > 10", &state); assert!(parsed.is_err()); // parses complex expression - let parsed = state - .parse_predicate_expression("value > 10 OR value <= 0") + let parsed = snapshot + .parse_predicate_expression("value > 10 OR value <= 0", &state) .unwrap(); let expected = col("value") .gt(lit::(10)) diff --git a/rust/src/operations/update.rs b/rust/src/operations/update.rs index b030bc5644..4408cb8b65 100644 --- a/rust/src/operations/update.rs +++ b/rust/src/operations/update.rs @@ -194,7 +194,7 @@ async fn execute( let predicate = match predicate { Some(predicate) => match predicate { Expression::DataFusion(expr) => Some(expr), - Expression::String(s) => Some(snapshot.parse_predicate_expression(s)?), + Expression::String(s) => Some(snapshot.parse_predicate_expression(s, &state)?), }, None => None, }; @@ -203,7 +203,9 @@ async fn execute( .into_iter() .map(|(key, expr)| match expr { Expression::DataFusion(e) => Ok((key, e)), - Expression::String(s) => snapshot.parse_predicate_expression(s).map(|e| (key, e)), + Expression::String(s) => snapshot + .parse_predicate_expression(s, &state) + .map(|e| (key, e)), }) .collect::, _>>()?; From 134419ef0f9918ef00e72f7a8f9a21638c8f6948 Mon Sep 17 00:00:00 2001 From: David Blajda Date: Mon, 2 Oct 2023 20:16:02 -0400 Subject: [PATCH 2/5] remove uses of canonical name --- rust/src/delta_datafusion/expr.rs | 91 ++++++++++++++----------------- rust/src/operations/delete.rs | 11 +++- rust/src/operations/merge.rs | 14 +++-- rust/src/operations/update.rs | 14 ++++- 4 files changed, 70 insertions(+), 60 deletions(-) diff --git a/rust/src/delta_datafusion/expr.rs b/rust/src/delta_datafusion/expr.rs index f7d52bc49c..c815c37a94 100644 --- a/rust/src/delta_datafusion/expr.rs +++ b/rust/src/delta_datafusion/expr.rs @@ -10,6 +10,8 @@ use datafusion_expr::{ }; use sqlparser::ast::escape_quoted_string; +use crate::DeltaTableError; + struct SqlFormat<'a> { expr: &'a Expr, } @@ -65,18 +67,18 @@ impl<'a> Display for SqlFormat<'a> { Expr::Case(case) => { write!(f, "CASE ")?; if let Some(e) = &case.expr { - write!(f, "{} ", SqlFormat { expr: e })?; + write!(f, "{} ", SqlFormat { expr: &e })?; } for (w, t) in &case.when_then_expr { write!( f, "WHEN {} THEN {} ", - SqlFormat { expr: w }, - SqlFormat { expr: t } + SqlFormat { expr: &w }, + SqlFormat { expr: &t } )?; } if let Some(e) = &case.else_expr { - write!(f, "ELSE {} ", SqlFormat { expr: e })?; + write!(f, "ELSE {} ", SqlFormat { expr: &e })?; } write!(f, "END") } @@ -97,12 +99,8 @@ impl<'a> Display for SqlFormat<'a> { Expr::IsNotFalse(expr) => write!(f, "{} IS NOT FALSE", SqlFormat { expr }), Expr::IsNotUnknown(expr) => write!(f, "{} IS NOT UNKNOWN", SqlFormat { expr }), Expr::BinaryExpr(expr) => write!(f, "{}", BinaryExprFormat { expr }), - Expr::ScalarFunction(func) => { - fmt_function(f, &func.fun.to_string(), false, &func.args, true) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - fmt_function(f, &fun.name, false, args, true) - } + Expr::ScalarFunction(func) => fmt_function(f, &func.fun.to_string(), false, &func.args), + Expr::ScalarUDF(ScalarUDF { fun, args }) => fmt_function(f, &fun.name, false, args), Expr::Between(Between { expr, negated, @@ -113,7 +111,7 @@ impl<'a> Display for SqlFormat<'a> { write!( f, "{} NOT BETWEEN {} AND {}", - SqlFormat { expr }, + SqlFormat { expr: expr }, SqlFormat { expr: low }, SqlFormat { expr: high } ) @@ -121,7 +119,7 @@ impl<'a> Display for SqlFormat<'a> { write!( f, "{} BETWEEN {} AND {}", - SqlFormat { expr }, + SqlFormat { expr: expr }, SqlFormat { expr: low }, SqlFormat { expr: high } ) @@ -183,28 +181,20 @@ impl<'a> Display for SqlFormat<'a> { } /// Format an `Expr` to a parsable SQL expression -pub fn fmt_expr_to_sql(expr: &Expr) -> Result { +pub fn fmt_expr_to_sql(expr: &Expr) -> Result { let mut s = String::new(); - write!(&mut s, "{}", SqlFormat { expr })?; + write!(&mut s, "{}", SqlFormat { expr }).map_err(|_| { + DeltaTableError::Generic("Unable to convert expression to string".to_owned()) + })?; Ok(s) } -fn fmt_function( - f: &mut fmt::Formatter, - fun: &str, - distinct: bool, - args: &[Expr], - display: bool, -) -> fmt::Result { - let args: Vec = match display { - true => args - .iter() - .map(|arg| format!("{}", SqlFormat { expr: arg })) - .collect(), - false => todo!("fmt function"), - }; +fn fmt_function(f: &mut fmt::Formatter, fun: &str, distinct: bool, args: &[Expr]) -> fmt::Result { + let args: Vec = args + .iter() + .map(|arg| format!("{}", SqlFormat { expr: &arg })) + .collect(); - // let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); let distinct_str = match distinct { true => "DISTINCT ", false => "", @@ -257,7 +247,7 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> { None => write!(f, "NULL")?, }, ScalarValue::Null => write!(f, "NULL")?, - _ => return Err(fmt::Error) + _ => return Err(fmt::Error), }; Ok(()) } @@ -364,8 +354,7 @@ mod test { // String expression that we output must be parsable for conflict resolution. let tests = vec![ - // TODO: sql parser will default to i64 by default where just lit(3) is be a i32 - simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()), + simple!(col("value").eq(lit(3 as i64)), "value = 3".to_string()), simple!(col("active").is_true(), "active IS TRUE".to_string()), simple!(col("active"), "active".to_string()), simple!(col("active").eq(lit(true)), "active = true".to_string()), @@ -385,11 +374,11 @@ mod test { override_expected_expr: Some(col("_binary").eq(decode(lit("aa00ff"), lit("hex")))), }, simple!( - col("value").between(lit(20_i64), lit(30_i64)), + col("value").between(lit(20 as i64), lit(30 as i64)), "value BETWEEN 20 AND 30".to_string() ), simple!( - col("value").not_between(lit(20_i64), lit(30_i64)), + col("value").not_between(lit(20 as i64), lit(30 as i64)), "value NOT BETWEEN 20 AND 30".to_string() ), simple!( @@ -401,8 +390,8 @@ mod test { "modified NOT LIKE 'abc%'".to_string() ), simple!( - (((col("value") * lit(2_i64) + col("value2")) / lit(3_i64)) - col("value")) - .gt(lit(0_i64)), + (((col("value") * lit(2 as i64) + col("value2")) / lit(3 as i64)) - col("value")) + .gt(lit(0 as i64)), "(value * 2 + value2) / 3 - value > 0".to_string() ), simple!( @@ -417,27 +406,27 @@ mod test { simple!( col("modified") .eq(lit("value")) - .and(col("value").eq(lit(1_i64))) + .and(col("value").eq(lit(1 as i64))) .or(col("modified") .eq(lit("value2")) - .and(col("value").gt(lit(1_i64)))), + .and(col("value").gt(lit(1 as i64)))), "modified = 'value' AND value = 1 OR modified = 'value2' AND value > 1".to_string() ), simple!( col("modified") .eq(lit("value")) - .or(col("value").eq(lit(1_i64))) + .or(col("value").eq(lit(1 as i64))) .and( col("modified") .eq(lit("value2")) - .or(col("value").gt(lit(1_i64))), + .or(col("value").gt(lit(1 as i64))), ), "(modified = 'value' OR value = 1) AND (modified = 'value2' OR value > 1)" .to_string() ), - // TODO: Need to refactor parse_predicate for this... + // Validate functions are correctly parsed simple!( - substring(col("modified"), lit(0_i64), lit(4_i64)).eq(lit("2021")), + substring(col("modified"), lit(0 as i64), lit(4 as i64)).eq(lit("2021")), "substr(modified, 0, 4) = '2021'".to_string() ), ]; @@ -455,23 +444,25 @@ mod test { match test.override_expected_expr { None => assert_eq!(test.expr, actual_expr), - Some(expr) => assert_eq!(expr, actual_expr) + Some(expr) => assert_eq!(expr, actual_expr), } } - - let unsupported_types = vec! [ - /* TODO: Determine proper way to display decimal, date, and datetime values in an sql expression*/ - simple! ( + let unsupported_types = vec![ + /* TODO: Determine proper way to display decimal values in an sql expression*/ + simple!( col("money").gt(lit(ScalarValue::Decimal128(Some(100), 12, 2))), "money > 0.1".to_string() ), - simple! ( + simple!( col("_timestamp").gt(lit(ScalarValue::TimestampMillisecond(Some(100), None))), "".to_string() ), - simple! ( - col("_timestamp").gt(lit(ScalarValue::TimestampMillisecond(Some(100), Some("UTC".into())))), + simple!( + col("_timestamp").gt(lit(ScalarValue::TimestampMillisecond( + Some(100), + Some("UTC".into()) + ))), "".to_string() ), ]; diff --git a/rust/src/operations/delete.rs b/rust/src/operations/delete.rs index 78a5b6e3cb..f07c92e442 100644 --- a/rust/src/operations/delete.rs +++ b/rust/src/operations/delete.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; +use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::protocol::{Action, Add, Remove}; use datafusion::execution::context::{SessionContext, SessionState}; use datafusion::physical_expr::create_physical_expr; @@ -263,7 +264,7 @@ async fn execute( // Do not make a commit when there are zero updates to the state if !actions.is_empty() { let operation = DeltaOperation::Delete { - predicate: Some(predicate.canonical_name()), + predicate: Some(fmt_expr_to_sql(&predicate)?), }; version = commit( object_store.as_ref(), @@ -337,6 +338,7 @@ mod tests { use arrow::record_batch::RecordBatch; use datafusion::assert_batches_sorted_eq; use datafusion::prelude::*; + use serde_json::json; use std::sync::Arc; async fn setup_table(partitions: Option>) -> DeltaTable { @@ -458,7 +460,7 @@ mod tests { assert_eq!(table.version(), 2); assert_eq!(table.get_file_uris().count(), 2); - let (table, metrics) = DeltaOps(table) + let (mut table, metrics) = DeltaOps(table) .delete() .with_predicate(col("value").eq(lit(1))) .await @@ -472,6 +474,11 @@ mod tests { assert_eq!(metrics.num_deleted_rows, Some(1)); assert_eq!(metrics.num_copied_rows, Some(3)); + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[commit_info.len() - 1]; + let parameters = last_commit.operation_parameters.clone().unwrap(); + assert_eq!(parameters["predicate"], json!("value = 1")); + let expected = vec![ "+----+-------+------------+", "| id | value | modified |", diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index 55e41a712a..3a18543756 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -62,6 +62,7 @@ use serde_json::{Map, Value}; use super::datafusion_utils::{into_expr, maybe_into_expr, Expression}; use super::transaction::commit; +use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::{parquet_scan_from_actions, register_store}; use crate::operations::datafusion_utils::MetricObserverExec; use crate::operations::write::write_execution_plan; @@ -686,7 +687,10 @@ async fn execute( }; let action_type = action_type.to_string(); - let predicate = op.predicate.map(|expr| expr.display_name().unwrap()); + let predicate = op + .predicate + .map(|expr| fmt_expr_to_sql(&expr)) + .transpose()?; predicates.push(MergePredicate { action_type, @@ -1046,7 +1050,7 @@ async fn execute( // Do not make a commit when there are zero updates to the state if !actions.is_empty() { let operation = DeltaOperation::Merge { - predicate: Some(predicate.canonical_name()), + predicate: Some(fmt_expr_to_sql(&predicate)?), matched_predicates: match_operations, not_matched_predicates: not_match_target_operations, not_matched_by_source_predicates: not_match_source_operations, @@ -1236,7 +1240,7 @@ mod tests { // Todo: Expected this predicate to actually be 'value = 1'. Predicate should contain a valid sql expression assert_eq!( parameters["notMatchedBySourcePredicates"], - json!(r#"[{"actionType":"update","predicate":"value = Int32(1)"}]"#) + json!(r#"[{"actionType":"update","predicate":"value = 1"}]"#) ); let expected = vec![ @@ -1458,7 +1462,7 @@ mod tests { assert_eq!(parameters["predicate"], json!("id = source.id")); assert_eq!( parameters["matchedPredicates"], - json!(r#"[{"actionType":"delete","predicate":"source.value <= Int32(10)"}]"#) + json!(r#"[{"actionType":"delete","predicate":"source.value <= 10"}]"#) ); let expected = vec![ @@ -1590,7 +1594,7 @@ mod tests { assert_eq!(parameters["predicate"], json!("id = source.id")); assert_eq!( parameters["notMatchedBySourcePredicates"], - json!(r#"[{"actionType":"delete","predicate":"modified > Utf8(\"2021-02-01\")"}]"#) + json!(r#"[{"actionType":"delete","predicate":"modified > '2021-02-01'"}]"#) ); let expected = vec![ diff --git a/rust/src/operations/update.rs b/rust/src/operations/update.rs index 4408cb8b65..3891c04fd9 100644 --- a/rust/src/operations/update.rs +++ b/rust/src/operations/update.rs @@ -43,7 +43,9 @@ use parquet::file::properties::WriterProperties; use serde_json::{Map, Value}; use crate::{ - delta_datafusion::{find_files, parquet_scan_from_actions, register_store}, + delta_datafusion::{ + expr::fmt_expr_to_sql, find_files, parquet_scan_from_actions, register_store, + }, protocol::{Action, DeltaOperation, Remove}, storage::{DeltaObjectStore, ObjectStoreRef}, table::state::DeltaTableState, @@ -418,7 +420,7 @@ async fn execute( metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_millis() as u64; let operation = DeltaOperation::Update { - predicate: Some(predicate.canonical_name()), + predicate: Some(fmt_expr_to_sql(&predicate)?), }; version = commit( object_store.as_ref(), @@ -483,6 +485,7 @@ mod tests { use arrow_array::Int32Array; use datafusion::assert_batches_sorted_eq; use datafusion::prelude::*; + use serde_json::json; use std::sync::Arc; async fn setup_table(partitions: Option>) -> DeltaTable { @@ -605,7 +608,7 @@ mod tests { assert_eq!(table.version(), 1); assert_eq!(table.get_file_uris().count(), 1); - let (table, metrics) = DeltaOps(table) + let (mut table, metrics) = DeltaOps(table) .update() .with_predicate(col("modified").eq(lit("2021-02-03"))) .with_update("modified", lit("2023-05-14")) @@ -619,6 +622,11 @@ mod tests { assert_eq!(metrics.num_updated_rows, 2); assert_eq!(metrics.num_copied_rows, 2); + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[commit_info.len() - 1]; + let parameters = last_commit.operation_parameters.clone().unwrap(); + assert_eq!(parameters["predicate"], json!("modified = '2021-02-03'")); + let expected = vec![ "+----+-------+------------+", "| id | value | modified |", From a516523596bcb454488768fa981b248a7cef6714 Mon Sep 17 00:00:00 2001 From: David Blajda Date: Mon, 2 Oct 2023 20:17:51 -0400 Subject: [PATCH 3/5] clippy --- rust/src/delta_datafusion/expr.rs | 34 +++++++++++++++---------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/rust/src/delta_datafusion/expr.rs b/rust/src/delta_datafusion/expr.rs index c815c37a94..7d62a5413b 100644 --- a/rust/src/delta_datafusion/expr.rs +++ b/rust/src/delta_datafusion/expr.rs @@ -67,18 +67,18 @@ impl<'a> Display for SqlFormat<'a> { Expr::Case(case) => { write!(f, "CASE ")?; if let Some(e) = &case.expr { - write!(f, "{} ", SqlFormat { expr: &e })?; + write!(f, "{} ", SqlFormat { expr: e })?; } for (w, t) in &case.when_then_expr { write!( f, "WHEN {} THEN {} ", - SqlFormat { expr: &w }, - SqlFormat { expr: &t } + SqlFormat { expr: w }, + SqlFormat { expr: t } )?; } if let Some(e) = &case.else_expr { - write!(f, "ELSE {} ", SqlFormat { expr: &e })?; + write!(f, "ELSE {} ", SqlFormat { expr: e })?; } write!(f, "END") } @@ -111,7 +111,7 @@ impl<'a> Display for SqlFormat<'a> { write!( f, "{} NOT BETWEEN {} AND {}", - SqlFormat { expr: expr }, + SqlFormat { expr }, SqlFormat { expr: low }, SqlFormat { expr: high } ) @@ -119,7 +119,7 @@ impl<'a> Display for SqlFormat<'a> { write!( f, "{} BETWEEN {} AND {}", - SqlFormat { expr: expr }, + SqlFormat { expr }, SqlFormat { expr: low }, SqlFormat { expr: high } ) @@ -192,7 +192,7 @@ pub fn fmt_expr_to_sql(expr: &Expr) -> Result { fn fmt_function(f: &mut fmt::Formatter, fun: &str, distinct: bool, args: &[Expr]) -> fmt::Result { let args: Vec = args .iter() - .map(|arg| format!("{}", SqlFormat { expr: &arg })) + .map(|arg| format!("{}", SqlFormat { expr: arg })) .collect(); let distinct_str = match distinct { @@ -354,7 +354,7 @@ mod test { // String expression that we output must be parsable for conflict resolution. let tests = vec![ - simple!(col("value").eq(lit(3 as i64)), "value = 3".to_string()), + simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()), simple!(col("active").is_true(), "active IS TRUE".to_string()), simple!(col("active"), "active".to_string()), simple!(col("active").eq(lit(true)), "active = true".to_string()), @@ -374,11 +374,11 @@ mod test { override_expected_expr: Some(col("_binary").eq(decode(lit("aa00ff"), lit("hex")))), }, simple!( - col("value").between(lit(20 as i64), lit(30 as i64)), + col("value").between(lit(20_i64), lit(30_i64)), "value BETWEEN 20 AND 30".to_string() ), simple!( - col("value").not_between(lit(20 as i64), lit(30 as i64)), + col("value").not_between(lit(20_i64), lit(30_i64)), "value NOT BETWEEN 20 AND 30".to_string() ), simple!( @@ -390,8 +390,8 @@ mod test { "modified NOT LIKE 'abc%'".to_string() ), simple!( - (((col("value") * lit(2 as i64) + col("value2")) / lit(3 as i64)) - col("value")) - .gt(lit(0 as i64)), + (((col("value") * lit(2_i64) + col("value2")) / lit(3_i64)) - col("value")) + .gt(lit(0_i64)), "(value * 2 + value2) / 3 - value > 0".to_string() ), simple!( @@ -406,27 +406,27 @@ mod test { simple!( col("modified") .eq(lit("value")) - .and(col("value").eq(lit(1 as i64))) + .and(col("value").eq(lit(1_i64))) .or(col("modified") .eq(lit("value2")) - .and(col("value").gt(lit(1 as i64)))), + .and(col("value").gt(lit(1_i64)))), "modified = 'value' AND value = 1 OR modified = 'value2' AND value > 1".to_string() ), simple!( col("modified") .eq(lit("value")) - .or(col("value").eq(lit(1 as i64))) + .or(col("value").eq(lit(1_i64))) .and( col("modified") .eq(lit("value2")) - .or(col("value").gt(lit(1 as i64))), + .or(col("value").gt(lit(1_i64))), ), "(modified = 'value' OR value = 1) AND (modified = 'value2' OR value > 1)" .to_string() ), // Validate functions are correctly parsed simple!( - substring(col("modified"), lit(0 as i64), lit(4 as i64)).eq(lit("2021")), + substring(col("modified"), lit(0_i64), lit(4_i64)).eq(lit("2021")), "substr(modified, 0, 4) = '2021'".to_string() ), ]; From 7bb306c4e079a34fab4070583a07e35dacc22dea Mon Sep 17 00:00:00 2001 From: David Blajda Date: Mon, 2 Oct 2023 21:51:39 -0400 Subject: [PATCH 4/5] add license + remove cast support --- rust/src/delta_datafusion/expr.rs | 47 ++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/rust/src/delta_datafusion/expr.rs b/rust/src/delta_datafusion/expr.rs index 7d62a5413b..d2777ddf51 100644 --- a/rust/src/delta_datafusion/expr.rs +++ b/rust/src/delta_datafusion/expr.rs @@ -1,12 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + //! Utility functions for Datafusion's Expressions -//! use std::fmt::{self, Display, Formatter, Write}; use datafusion_common::ScalarValue; use datafusion_expr::{ expr::{InList, ScalarUDF}, - Between, BinaryExpr, Cast, Expr, Like, TryCast, + Between, BinaryExpr, Expr, Like, }; use sqlparser::ast::escape_quoted_string; @@ -82,12 +98,6 @@ impl<'a> Display for SqlFormat<'a> { } write!(f, "END") } - Expr::Cast(Cast { expr, data_type }) => { - write!(f, "CAST({} AS {data_type:?})", SqlFormat { expr }) - } - Expr::TryCast(TryCast { expr, data_type }) => { - write!(f, "TRY_CAST({expr} AS {data_type:?})") - } Expr::Not(expr) => write!(f, "NOT {}", SqlFormat { expr }), Expr::Negative(expr) => write!(f, "(- {})", SqlFormat { expr }), Expr::IsNull(expr) => write!(f, "{} IS NULL", SqlFormat { expr }), @@ -258,8 +268,8 @@ mod test { use std::collections::HashMap; use datafusion::prelude::SessionContext; - use datafusion_common::ScalarValue; - use datafusion_expr::{col, decode, lit, substring, Expr}; + use datafusion_common::{DFSchema, ScalarValue}; + use datafusion_expr::{col, decode, lit, substring, Expr, ExprSchemable}; use crate::{DeltaOps, DeltaTable, Schema, SchemaDataType, SchemaField}; @@ -465,6 +475,23 @@ mod test { ))), "".to_string() ), + simple!( + col("value") + .cast_to::( + &arrow_schema::DataType::Utf8, + &table + .state + .input_schema() + .unwrap() + .as_ref() + .to_owned() + .try_into() + .unwrap() + ) + .unwrap() + .eq(lit("1")), + "CAST(value as STRING) = '1'".to_string() + ), ]; for test in unsupported_types { From fee32946cfbb043283870e82d72d6c656355ab18 Mon Sep 17 00:00:00 2001 From: David Blajda Date: Mon, 2 Oct 2023 23:20:43 -0400 Subject: [PATCH 5/5] remove todo + comment on which code is used from datafusion --- rust/src/delta_datafusion/expr.rs | 4 ++++ rust/src/operations/merge.rs | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/rust/src/delta_datafusion/expr.rs b/rust/src/delta_datafusion/expr.rs index d2777ddf51..d60fe6666c 100644 --- a/rust/src/delta_datafusion/expr.rs +++ b/rust/src/delta_datafusion/expr.rs @@ -14,6 +14,10 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// +// This product includes software from the Datafusion project (Apache 2.0) +// https://github.com/apache/arrow-datafusion +// Display functions and required macros were pulled from https://github.com/apache/arrow-datafusion/blob/ddb95497e2792015d5a5998eec79aac8d37df1eb/datafusion/expr/src/expr.rs //! Utility functions for Datafusion's Expressions diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index 3a18543756..d52dd26819 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -1237,7 +1237,6 @@ mod tests { parameters["notMatchedPredicates"], json!(r#"[{"actionType":"insert"}]"#) ); - // Todo: Expected this predicate to actually be 'value = 1'. Predicate should contain a valid sql expression assert_eq!( parameters["notMatchedBySourcePredicates"], json!(r#"[{"actionType":"update","predicate":"value = 1"}]"#)