Skip to content

Commit

Permalink
fix: check recursion limit in Expr::to_bytes (apache#3970)
Browse files Browse the repository at this point in the history
Install a DataFusion-specific workaround until
tokio-rs/prost#736 is implemented.

Fixes apache#3968.
  • Loading branch information
crepererum authored and jimexist committed Oct 31, 2022
1 parent 7cf1f32 commit 882b8f9
Showing 1 changed file with 75 additions and 3 deletions.
78 changes: 75 additions & 3 deletions datafusion/proto/src/bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ use crate::logical_plan::{AsLogicalPlan, LogicalExtensionCodec};
use crate::{from_proto::parse_expr, protobuf};
use arrow::datatypes::SchemaRef;
use datafusion::datasource::TableProvider;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Expr, Extension, LogicalPlan};
use datafusion_expr::{
create_udaf, create_udf, Expr, Extension, LogicalPlan, Volatility,
};
use prost::{
bytes::{Bytes, BytesMut},
Message,
Expand Down Expand Up @@ -83,7 +86,45 @@ impl Serializeable for Expr {
DataFusionError::Plan(format!("Error encoding protobuf as bytes: {}", e))
})?;

Ok(buffer.into())
let bytes: Bytes = buffer.into();

// the produced byte stream may lead to "recursion limit" errors, see
// https://github.com/apache/arrow-datafusion/issues/3968
// Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) is fixed, we try to
// deserialize the data here and check for errors.
//
// Need to provide some placeholder registry because the stream may contain UDFs
/// Stand-in function registry used only for the decode check performed
/// right after encoding: every name lookup succeeds and yields a dummy
/// function, so serialized UDF / UDAF references can be resolved without
/// access to the caller's real registry.
struct PlaceHolderRegistry;

impl FunctionRegistry for PlaceHolderRegistry {
    /// Reports an empty set of registered scalar UDF names.
    fn udfs(&self) -> std::collections::HashSet<String> {
        std::collections::HashSet::new()
    }

    /// Fabricates a scalar UDF called `name` with no arguments and a `Null`
    /// return type. Its implementation hits `unimplemented!()` if invoked.
    fn udf(&self, name: &str) -> Result<Arc<datafusion_expr::ScalarUDF>> {
        let null_return = Arc::new(arrow::datatypes::DataType::Null);
        let never_called = make_scalar_function(|_| unimplemented!());
        let placeholder =
            create_udf(name, vec![], null_return, Volatility::Immutable, never_called);

        Ok(Arc::new(placeholder))
    }

    /// Fabricates an aggregate UDF called `name` with a `Null` input type,
    /// a `Null` return type and no state fields. Its accumulator factory
    /// hits `unimplemented!()` if invoked.
    fn udaf(&self, name: &str) -> Result<Arc<datafusion_expr::AggregateUDF>> {
        // The closure and the state vector stay inline so that their types
        // are inferred from the parameter types of `create_udaf`.
        let placeholder = create_udaf(
            name,
            arrow::datatypes::DataType::Null,
            Arc::new(arrow::datatypes::DataType::Null),
            Volatility::Immutable,
            Arc::new(|_| unimplemented!()),
            Arc::new(vec![]),
        );

        Ok(Arc::new(placeholder))
    }
}
Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?;

Ok(bytes)
}

fn from_bytes_with_registry(
Expand Down Expand Up @@ -212,7 +253,7 @@ mod test {
use arrow::{array::ArrayRef, datatypes::DataType};
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::SessionContext;
use datafusion_expr::{create_udf, lit, Volatility};
use datafusion_expr::{col, create_udf, lit, Volatility};
use std::sync::Arc;

#[test]
Expand Down Expand Up @@ -280,6 +321,37 @@ mod test {
Expr::from_bytes(&bytes).unwrap();
}

/// Builds ever longer chains of `AND`-ed comparisons and checks that every
/// expression accepted by `to_bytes` decodes back to an equal expression.
///
/// The test succeeds as soon as `to_bytes` rejects an expression, and fails
/// if no expression with up to `n_max` chained `AND`s is rejected.
#[test]
fn roundtrip_deeply_nested() {
    // we need more stack space so this doesn't overflow in dev builds
    let worker = std::thread::Builder::new()
        .stack_size(10_000_000)
        .spawn(|| {
            // don't know what "too much" is, so let's slowly try to increase complexity
            let n_max = 100;

            // Inclusive range, so `n_max` itself is tested and the failure
            // message below reports the depth that was actually reached.
            for n in 1..=n_max {
                println!("testing: {n}");

                let expr_base = col("a").lt(lit(5i32));
                let expr =
                    (0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone()));

                // Convert it to an opaque form
                let bytes = match expr.to_bytes() {
                    Ok(bytes) => bytes,
                    Err(_) => {
                        // found expression that is too deeply nested
                        // (any encoding error is treated as this case)
                        return;
                    }
                };

                // Decode bytes from somewhere (over network, etc.)
                let decoded_expr = Expr::from_bytes(&bytes)
                    .expect("serialization worked, so deserialization should work as well");
                assert_eq!(expr, decoded_expr);
            }

            panic!("did not find a 'too deeply nested' expression, tested up to a depth of {n_max}")
        })
        .expect("spawning thread");

    worker.join().expect("joining thread");
}

/// return a `SessionContext` with a `dummy` function registered as a UDF
fn context_with_udf() -> SessionContext {
let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);
Expand Down

0 comments on commit 882b8f9

Please sign in to comment.