diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs
index 7c8b94e5b358c..3677ea8af5c8d 100644
--- a/datafusion/proto/src/bytes/mod.rs
+++ b/datafusion/proto/src/bytes/mod.rs
@@ -20,8 +20,11 @@ use crate::logical_plan::{AsLogicalPlan, LogicalExtensionCodec};
 use crate::{from_proto::parse_expr, protobuf};
 use arrow::datatypes::SchemaRef;
 use datafusion::datasource::TableProvider;
+use datafusion::physical_plan::functions::make_scalar_function;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Expr, Extension, LogicalPlan};
+use datafusion_expr::{
+    create_udaf, create_udf, Expr, Extension, LogicalPlan, Volatility,
+};
 use prost::{
     bytes::{Bytes, BytesMut},
     Message,
@@ -83,7 +86,45 @@ impl Serializeable for Expr {
             DataFusionError::Plan(format!("Error encoding protobuf as bytes: {}", e))
         })?;
 
-        Ok(buffer.into())
+        let bytes: Bytes = buffer.into();
+
+        // the produced byte stream may lead to "recursion limit" errors, see
+        // https://github.com/apache/arrow-datafusion/issues/3968
+        // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) is fixed, we try to
+        // deserialize the data here and check for errors.
+        //
+        // Need to provide some placeholder registry because the stream may contain UDFs
+        struct PlaceHolderRegistry;
+
+        impl FunctionRegistry for PlaceHolderRegistry {
+            fn udfs(&self) -> std::collections::HashSet<String> {
+                std::collections::HashSet::default()
+            }
+
+            fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
+                Ok(Arc::new(create_udf(
+                    name,
+                    vec![],
+                    Arc::new(arrow::datatypes::DataType::Null),
+                    Volatility::Immutable,
+                    make_scalar_function(|_| unimplemented!()),
+                )))
+            }
+
+            fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
+                Ok(Arc::new(create_udaf(
+                    name,
+                    arrow::datatypes::DataType::Null,
+                    Arc::new(arrow::datatypes::DataType::Null),
+                    Volatility::Immutable,
+                    Arc::new(|_| unimplemented!()),
+                    Arc::new(vec![]),
+                )))
+            }
+        }
+        Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?;
+
+        Ok(bytes)
     }
 
     fn from_bytes_with_registry(
@@ -212,7 +253,7 @@ mod test {
     use arrow::{array::ArrayRef, datatypes::DataType};
     use datafusion::physical_plan::functions::make_scalar_function;
     use datafusion::prelude::SessionContext;
-    use datafusion_expr::{create_udf, lit, Volatility};
+    use datafusion_expr::{col, create_udf, lit, Volatility};
     use std::sync::Arc;
 
     #[test]
@@ -280,6 +321,37 @@ mod test {
         Expr::from_bytes(&bytes).unwrap();
     }
 
+    #[test]
+    fn roundtrip_deeply_nested() {
+        // we need more stack space so this doesn't overflow in dev builds
+        std::thread::Builder::new().stack_size(10_000_000).spawn(|| {
+            // don't know what "too much" is, so let's slowly try to increase complexity
+            let n_max = 100;
+
+            for n in 1..n_max {
+                println!("testing: {n}");
+
+                let expr_base = col("a").lt(lit(5i32));
+                let expr = (0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone()));
+
+                // Convert it to an opaque form
+                let bytes = match expr.to_bytes() {
+                    Ok(bytes) => bytes,
+                    Err(_) => {
+                        // found expression that is too deeply nested
+                        return;
+                    }
+                };
+
+                // Decode bytes from somewhere (over network, etc.)
+                let decoded_expr = Expr::from_bytes(&bytes).expect("serialization worked, so deserialization should work as well");
+                assert_eq!(expr, decoded_expr);
+            }
+
+            panic!("did not find a 'too deeply nested' expression, tested up to a depth of {n_max}")
+        }).expect("spawning thread").join().expect("joining thread");
+    }
+
     /// return a `SessionContext` with a `dummy` function registered as a UDF
     fn context_with_udf() -> SessionContext {
         let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);