Test DataFusion 45.0.0 with Sail (#365)
shehabgamin authored Feb 8, 2025
1 parent 39a1b5a commit a159393
Showing 19 changed files with 489 additions and 336 deletions.
480 changes: 258 additions & 222 deletions Cargo.lock

Large diffs are not rendered by default.

45 changes: 35 additions & 10 deletions Cargo.toml
@@ -50,7 +50,7 @@ ryu = "1.0.18"
 either = "1.12.0"
 num-bigint = "0.4.6"
 num-traits = "0.2.19"
-serde_arrow = { version = "0.12.3", features = ["arrow-53"] }
+serde_arrow = { version = "0.12.3", features = ["arrow-54"] }
 mimalloc = { version = "0.1.43", default-features = false }
 rand = "0.8.5"
 rand_chacha = "0.3.1"
@@ -86,16 +86,16 @@ chumsky = { version = "=1.0.0-alpha.7", default-features = false, features = ["p
 # The versions of the following dependencies are managed manually.
 ######
 
-datafusion = { version = "44.0.0", features = ["serde", "pyarrow", "avro"] }
-datafusion-common = { version = "44.0.0", features = ["object_store", "pyarrow", "avro"] }
-datafusion-expr = { version = "44.0.0" }
-datafusion-expr-common = { version = "44.0.0" }
-datafusion-proto = { version = "44.0.0" }
-datafusion-functions-nested = { version = "44.0.0" }
-datafusion-functions-json = { git = "https://github.com/lakehq/datafusion-functions-json.git", rev = "7bcca26" }
+datafusion = { version = "45.0.0", features = ["serde", "pyarrow", "avro"] }
+datafusion-common = { version = "45.0.0", features = ["object_store", "pyarrow", "avro"] }
+datafusion-expr = { version = "45.0.0" }
+datafusion-expr-common = { version = "45.0.0" }
+datafusion-proto = { version = "45.0.0" }
+datafusion-functions-nested = { version = "45.0.0" }
+datafusion-functions-json = { git = "https://github.com/lakehq/datafusion-functions-json.git", rev = "453183d" }
 # auto-initialize: Changes [`Python::with_gil`] to automatically initialize the Python interpreter if needed.
-pyo3 = { version = "0.22.0", features = ["auto-initialize", "serde"] }
-arrow-flight = { version = "53.3.0" }
+pyo3 = { version = "0.23.4", features = ["auto-initialize", "serde"] }
+arrow-flight = { version = "54.1.0" }
 # The `object_store` version must match the one used in DataFusion.
 object_store = { version = "0.11.0", features = ["aws", "gcp", "azure", "http"] }
 # We use a patched latest version of sqlparser. The version may be different from the one used in DataFusion.
@@ -109,6 +109,31 @@ sqlparser = { git = "https://github.com/lakehq/sqlparser-rs.git", rev = "9ade53d
 [patch.crates-io]
 # Override dependencies to use our forked versions.
 # You can use `path = "..."` to temporarily point to your local copy of the crates to speed up local development.
+datafusion = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-catalog = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-common-runtime = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-doc = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-expr-common = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+#datafusion-ffi = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-functions-aggregate = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-functions-aggregate-common = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-functions-nested = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-functions-table = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-functions-window = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-functions-window-common = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-macros = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-optimizer = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-physical-expr-common = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-physical-optimizer = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-physical-plan = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-proto = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-proto-common = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
+datafusion-sql = { git = "https://github.com/apache/datafusion.git", rev = "a9fb58c" }
 
 [profile.release]
 # https://doc.rust-lang.org/cargo/reference/profiles.html#release
8 changes: 7 additions & 1 deletion crates/sail-cli/src/spark/shell.rs
@@ -1,3 +1,4 @@
+use std::ffi::CString;
 use std::net::Ipv4Addr;
 
 use pyo3::prelude::PyAnyMethods;
@@ -32,7 +33,12 @@ pub fn run_pyspark_shell() -> Result<(), Box<dyn std::error::Error>> {
 })?;
 runtime.spawn(server_task);
 Python::with_gil(|py| -> PyResult<_> {
-let shell = PyModule::from_code_bound(py, SHELL_SOURCE_CODE, "shell.py", "shell")?;
+let shell = PyModule::from_code(
+py,
+CString::new(SHELL_SOURCE_CODE)?.as_c_str(),
+CString::new("shell.py")?.as_c_str(),
+CString::new("shell")?.as_c_str(),
+)?;
 shell
 .getattr("run_pyspark_shell")?
 .call((server_port,), None)?;
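
Note: PyO3 0.23 changes `PyModule::from_code` to take C strings (`&CStr`) for the source, file name, and module name, which is why the shell source is wrapped in `CString::new(...)` above. A minimal standalone sketch of the same pattern (the embedded module and its names are made up for illustration); `CString::new` fails on embedded NUL bytes and PyO3 converts that error into a `PyErr`, so `?` works inside `PyResult` contexts:

    use std::ffi::CString;

    use pyo3::prelude::*;
    use pyo3::types::PyModule;

    fn run_embedded_module(py: Python<'_>) -> PyResult<String> {
        // PyO3 0.23: source, file name, and module name are all &CStr values.
        let source = CString::new("def greet():\n    return 'hi'\n")?;
        let module = PyModule::from_code(
            py,
            source.as_c_str(),
            CString::new("greet.py")?.as_c_str(),
            CString::new("greet")?.as_c_str(),
        )?;
        // Call the function defined in the embedded module and extract its result.
        module.getattr("greet")?.call0()?.extract()
    }
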
17 changes: 16 additions & 1 deletion crates/sail-execution/src/codec.rs
@@ -6,16 +6,20 @@ use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
 use datafusion::common::parsers::CompressionTypeVariant;
 use datafusion::common::{plan_datafusion_err, plan_err, JoinSide, Result};
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
+#[allow(deprecated)]
 use datafusion::datasource::physical_plan::{ArrowExec, NdJsonExec};
+use datafusion::datasource::physical_plan::{ArrowSource, JsonSource};
 use datafusion::execution::FunctionRegistry;
 use datafusion::functions::string::overlay::OverlayFunc;
 use datafusion::logical_expr::{AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl};
 use datafusion::physical_expr::LexOrdering;
 use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
 use datafusion::physical_plan::joins::SortMergeJoinExec;
+#[allow(deprecated)]
 use datafusion::physical_plan::memory::MemoryExec;
 use datafusion::physical_plan::recursive_query::RecursiveQueryExec;
 use datafusion::physical_plan::sorts::partial_sort::PartialSortExec;
+#[allow(deprecated)]
 use datafusion::physical_plan::values::ValuesExec;
 use datafusion::physical_plan::work_table::WorkTableExec;
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
@@ -226,6 +230,7 @@ impl PhysicalExtensionCodec for RemoteExecutionCodec {
 let sort_information =
 self.try_decode_lex_orderings(&sort_information, registry, &schema)?;
 Ok(Arc::new(
+#[allow(deprecated)]
 MemoryExec::try_new(&partitions, Arc::new(schema), projection)?
 .with_show_sizes(show_sizes)
 .try_with_sort_information(sort_information)?,
@@ -234,6 +239,7 @@
 NodeKind::Values(gen::ValuesExecNode { data, schema }) => {
 let schema = self.try_decode_schema(&schema)?;
 let data = read_record_batches(&data)?;
+#[allow(deprecated)]
 Ok(Arc::new(ValuesExec::try_new_from_batches(
 Arc::new(schema),
 data,
@@ -247,9 +253,11 @@ impl PhysicalExtensionCodec for RemoteExecutionCodec {
 &self.try_decode_message(&base_config)?,
 registry,
 self,
+Arc::new(JsonSource::new()), // TODO: Look into configuring this if needed
 )?;
 let file_compression_type: FileCompressionType =
 self.try_decode_file_compression_type(file_compression_type)?;
+#[allow(deprecated)]
 Ok(Arc::new(NdJsonExec::new(
 base_config,
 file_compression_type,
@@ -260,7 +268,9 @@ impl PhysicalExtensionCodec for RemoteExecutionCodec {
 &self.try_decode_message(&base_config)?,
 registry,
 self,
+Arc::new(ArrowSource::default()), // TODO: Look into configuring this if needed
 )?;
+#[allow(deprecated)]
 Ok(Arc::new(ArrowExec::new(base_config)))
 }
 NodeKind::WorkTable(gen::WorkTableExecNode { name, schema }) => {
@@ -334,7 +344,11 @@ impl PhysicalExtensionCodec for RemoteExecutionCodec {
 })
 })
 .collect::<Result<Vec<_>>>()?;
-Some(JoinFilter::new(expression, column_indices, schema))
+Some(JoinFilter::new(
+expression,
+column_indices,
+Arc::new(schema),
+))
 } else {
 None
 };
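
Note: in DataFusion 45, `JoinFilter::new` takes the intermediate filter schema as `Arc<Schema>` rather than by value, hence the `Arc::new(schema)` above. A hedged, self-contained sketch of building a join filter against that signature (the columns and comparison are invented for illustration, not Sail code):

    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::{JoinSide, Result};
    use datafusion::logical_expr::Operator;
    use datafusion::physical_expr::expressions::{col, BinaryExpr};
    use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};

    fn example_join_filter() -> Result<JoinFilter> {
        // Intermediate schema seen by the filter: one column taken from each join side.
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        // Physical predicate `a > b` over the intermediate schema.
        let expression = Arc::new(BinaryExpr::new(
            col("a", &schema)?,
            Operator::Gt,
            col("b", &schema)?,
        ));
        let column_indices = vec![
            ColumnIndex { index: 0, side: JoinSide::Left },
            ColumnIndex { index: 1, side: JoinSide::Right },
        ];
        // DataFusion 45: the filter schema is passed as Arc<Schema>.
        Ok(JoinFilter::new(expression, column_indices, Arc::new(schema)))
    }
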
@@ -364,6 +378,7 @@ impl PhysicalExtensionCodec for RemoteExecutionCodec {
 }
 
 fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
+#[allow(deprecated)]
 let node_kind = if let Some(range) = node.as_any().downcast_ref::<RangeExec>() {
 let schema = self.try_encode_schema(range.schema().as_ref())?;
 NodeKind::Range(gen::RangeExecNode {
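
Note: DataFusion 45 deprecates `MemoryExec`, `ValuesExec`, `NdJsonExec`, and `ArrowExec` (hence the `#[allow(deprecated)]` additions), and decoding a file scan config now passes an explicit file source such as `JsonSource` or `ArrowSource`. The codec keeps the deprecated plans for now and silences the warnings locally. A sketch of that stop-gap pattern outside the codec (the helper is hypothetical, not Sail code):

    use std::sync::Arc;

    use datafusion::arrow::datatypes::SchemaRef;
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::common::Result;
    #[allow(deprecated)]
    use datafusion::physical_plan::memory::MemoryExec;

    // Keep building the deprecated MemoryExec until the plan is migrated to the
    // DataFusion 45 replacements; the attribute scopes the allowance to this item only.
    #[allow(deprecated)]
    fn decode_memory_plan(
        partitions: &[Vec<RecordBatch>],
        schema: SchemaRef,
        projection: Option<Vec<usize>>,
    ) -> Result<Arc<MemoryExec>> {
        Ok(Arc::new(MemoryExec::try_new(partitions, schema, projection)?))
    }
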
11 changes: 7 additions & 4 deletions crates/sail-plan/src/extension/function/spark_array.rs
@@ -9,9 +9,11 @@ use datafusion::arrow::array::{
 use datafusion::arrow::buffer::OffsetBuffer;
 use datafusion::arrow::datatypes::{DataType, Field};
 use datafusion_common::utils::SingleRowListArrayBuilder;
-use datafusion_common::{internal_err, plan_err, ExprSchema, Result};
+use datafusion_common::{internal_err, plan_err, Result};
 use datafusion_expr::type_coercion::binary::comparison_coercion;
-use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, TypeSignature, Volatility};
+use datafusion_expr::{
+ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
 
 use crate::extension::function::functions_nested_utils::make_scalar_function;
 
@@ -75,8 +77,9 @@ impl ScalarUDFImpl for SparkArray {
 }
 }
 
-fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
-false
+fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
+let return_type = self.return_type(args.arg_types)?;
+Ok(ReturnInfo::new_non_nullable(return_type))
 }
 
 fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
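
Note: DataFusion 45 replaces the `is_nullable` hook on `ScalarUDFImpl` with `return_type_from_args`, which reports the return type and nullability together through `ReturnInfo`. A minimal sketch of the shape of that method, detached from any real UDF (the Utf8 return type is only an example):

    use datafusion::arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::{ReturnInfo, ReturnTypeArgs};

    // Mirrors the body of `ScalarUDFImpl::return_type_from_args` for a UDF that
    // always yields a non-nullable Utf8 value, whatever its arguments are.
    fn example_return_type_from_args(args: ReturnTypeArgs) -> Result<ReturnInfo> {
        let _ = args.arg_types; // the input types are available when the output depends on them
        Ok(ReturnInfo::new_non_nullable(DataType::Utf8))
    }
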
13 changes: 10 additions & 3 deletions crates/sail-plan/src/extension/function/spark_concat.rs
@@ -5,7 +5,6 @@ use datafusion::arrow::datatypes::DataType;
 use datafusion::functions::string::concat::ConcatFunc;
 use datafusion_common::utils::list_ndims;
 use datafusion_common::{plan_err, ExprSchema, Result};
-use datafusion_expr::type_coercion::binary::get_wider_type;
 use datafusion_expr::{ColumnarValue, Expr, ExprSchemable, ScalarUDFImpl, Signature, Volatility};
 use datafusion_functions_nested::concat::ArrayConcat;
 
@@ -41,7 +40,7 @@ impl ScalarUDFImpl for SparkConcat {
 &self.signature
 }
 
-/// [Credit]: <https://github.com/apache/datafusion/blob/7b2284c8a0b49234e9607bfef10d73ef788d9458/datafusion/functions-nested/src/concat.rs#L274-L301>
+/// [Credit]: <https://github.com/apache/datafusion/blob/7ccc6d7c55ae9dbcb7dee031f394bf11a03000ba/datafusion/functions-nested/src/concat.rs#L276-L310>
 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
 if arg_types
 .iter()
@@ -56,7 +55,15 @@
 let dims = list_ndims(arg_type);
 expr_type = match max_dims.cmp(&dims) {
 Ordering::Greater => expr_type,
-Ordering::Equal => get_wider_type(&expr_type, arg_type)?,
+Ordering::Equal => {
+if expr_type == DataType::Null {
+arg_type.clone()
+} else if !expr_type.equals_datatype(arg_type) {
+return plan_err!("It is not possible to concatenate arrays of different types. Expected: {expr_type}, got: {arg_type}");
+} else {
+expr_type
+}
+}
 Ordering::Less => {
 max_dims = dims;
 arg_type.clone()
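
Note: the equal-depth branch above no longer widens types via `get_wider_type`; element types must now match exactly, with `Null` deferring to the other side. The same check, pulled out as a standalone helper for clarity (a sketch, not Sail code):

    use datafusion::arrow::datatypes::DataType;
    use datafusion_common::{plan_err, Result};

    // Resolve the result type for two arguments with the same nesting depth:
    // Null adopts the other type, otherwise the types must be identical.
    fn merge_equal_depth_type(expr_type: DataType, arg_type: &DataType) -> Result<DataType> {
        if expr_type == DataType::Null {
            Ok(arg_type.clone())
        } else if !expr_type.equals_datatype(arg_type) {
            plan_err!(
                "It is not possible to concatenate arrays of different types. Expected: {expr_type}, got: {arg_type}"
            )
        } else {
            Ok(expr_type)
        }
    }
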
13 changes: 11 additions & 2 deletions crates/sail-plan/src/extension/source/rename.rs
@@ -49,13 +49,22 @@ impl RenameTableProvider {
 
 fn to_inner_expr(&self, expr: &Expr) -> Result<Expr> {
 let rewrite = |e: Expr| -> Result<Transformed<Expr>> {
-if let Expr::Column(Column { name, relation }) = e {
+if let Expr::Column(Column {
+name,
+relation,
+spans,
+}) = e
+{
 let name = self
 .names
 .get(&name)
 .ok_or_else(|| plan_datafusion_err!("column {name} not found"))?
 .clone();
-Ok(Transformed::yes(Expr::Column(Column { name, relation })))
+Ok(Transformed::yes(Expr::Column(Column {
+name,
+relation,
+spans,
+})))
 } else {
 Ok(Transformed::no(e))
 }
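
Note: DataFusion 45 adds a `spans` field to `Column`, which is why the pattern and the rebuilt literal above now carry it through. A hedged alternative sketch using struct-update syntax, so any further `Column` fields would be forwarded automatically (the lookup map is hypothetical, not Sail's rename table):

    use std::collections::HashMap;

    use datafusion_common::Column;
    use datafusion_expr::Expr;

    // Rename a column expression while copying every other Column field
    // (relation, spans, ...) from the original value.
    fn rename_column(e: Expr, lookup: &HashMap<String, String>) -> Expr {
        match e {
            Expr::Column(col) => {
                let name = lookup
                    .get(&col.name)
                    .cloned()
                    .unwrap_or_else(|| col.name.clone());
                Expr::Column(Column { name, ..col })
            }
            other => other,
        }
    }
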
15 changes: 9 additions & 6 deletions crates/sail-python-udf/src/cereal/pyspark_udf.rs
@@ -1,7 +1,7 @@
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::PyAnyMethods;
 use pyo3::types::PyModule;
-use pyo3::{intern, PyObject, Python, ToPyObject};
+use pyo3::{intern, Bound, IntoPyObject, PyAny, Python};
 use sail_common::spec;
 
 use crate::cereal::{check_python_udf_version, should_write_config};
@@ -11,24 +11,27 @@ use crate::error::{PyUdfError, PyUdfResult};
 pub struct PySparkUdfPayload;
 
 impl PySparkUdfPayload {
-pub fn load(py: Python, data: &[u8]) -> PyUdfResult<PyObject> {
+pub fn load<'py>(py: Python<'py>, data: &[u8]) -> PyUdfResult<Bound<'py, PyAny>> {
 let (eval_type, v) = data
 .split_at_checked(size_of::<i32>())
 .ok_or_else(|| PyUdfError::invalid("missing eval_type"))?;
 let eval_type = eval_type
 .try_into()
 .map_err(|e| PyValueError::new_err(format!("eval_type bytes: {e}")))?;
 let eval_type = i32::from_be_bytes(eval_type);
-let infile = PyModule::import_bound(py, intern!(py, "io"))?
+let infile = PyModule::import(py, intern!(py, "io"))?
 .getattr(intern!(py, "BytesIO"))?
 .call1((v,))?;
-let serializer = PyModule::import_bound(py, intern!(py, "pyspark.serializers"))?
+let serializer = PyModule::import(py, intern!(py, "pyspark.serializers"))?
 .getattr(intern!(py, "CPickleSerializer"))?
 .call0()?;
-let tuple = PyModule::import_bound(py, intern!(py, "pyspark.worker"))?
+let tuple = PyModule::import(py, intern!(py, "pyspark.worker"))?
 .getattr(intern!(py, "read_udfs"))?
 .call1((serializer, infile, eval_type))?;
-Ok(tuple.get_item(0)?.to_object(py))
+tuple
+.get_item(0)?
+.into_pyobject(py)
+.map_err(|e| PyUdfError::PythonError(e.into()))
 }
 
 pub fn build(
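
Note: PyO3 0.23 drops the `_bound` constructors and the `ToPyObject` trait used here before: modules are loaded with `PyModule::import`, and conversions go through `IntoPyObject`, keeping values as GIL-bound `Bound<'py, PyAny>` handles instead of detached `PyObject`s. A small standalone sketch of the import pattern (unrelated to PySpark):

    use pyo3::prelude::*;
    use pyo3::types::PyModule;
    use pyo3::{intern, Bound, PyAny};

    // Open an in-memory reader over raw bytes using Python's io module,
    // returning a handle that stays bound to the GIL lifetime 'py.
    fn bytes_reader<'py>(py: Python<'py>, data: &[u8]) -> PyResult<Bound<'py, PyAny>> {
        PyModule::import(py, intern!(py, "io"))?
            .getattr(intern!(py, "BytesIO"))?
            .call1((data,))
    }
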
17 changes: 10 additions & 7 deletions crates/sail-python-udf/src/cereal/pyspark_udtf.rs
@@ -3,7 +3,7 @@ use datafusion::arrow::pyarrow::ToPyArrow;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::PyAnyMethods;
 use pyo3::types::PyModule;
-use pyo3::{intern, PyObject, PyResult, Python, ToPyObject};
+use pyo3::{intern, Bound, IntoPyObject, PyAny, PyResult, Python};
 use sail_common::spec;
 
 use crate::cereal::{check_python_udf_version, should_write_config};
@@ -13,24 +13,27 @@ use crate::error::{PyUdfError, PyUdfResult};
 pub struct PySparkUdtfPayload;
 
 impl PySparkUdtfPayload {
-pub fn load(py: Python, v: &[u8]) -> PyUdfResult<PyObject> {
+pub fn load<'py>(py: Python<'py>, v: &[u8]) -> PyUdfResult<Bound<'py, PyAny>> {
 let (eval_type, v) = v
 .split_at_checked(size_of::<i32>())
 .ok_or_else(|| PyUdfError::invalid("missing eval_type"))?;
 let eval_type = eval_type
 .try_into()
 .map_err(|e| PyValueError::new_err(format!("eval_type bytes: {e}")))?;
 let eval_type = i32::from_be_bytes(eval_type);
-let infile = PyModule::import_bound(py, intern!(py, "io"))?
+let infile = PyModule::import(py, intern!(py, "io"))?
 .getattr(intern!(py, "BytesIO"))?
 .call1((v,))?;
-let serializer = PyModule::import_bound(py, intern!(py, "pyspark.serializers"))?
+let serializer = PyModule::import(py, intern!(py, "pyspark.serializers"))?
 .getattr(intern!(py, "CPickleSerializer"))?
 .call0()?;
-let tuple = PyModule::import_bound(py, intern!(py, "pyspark.worker"))?
+let tuple = PyModule::import(py, intern!(py, "pyspark.worker"))?
 .getattr(intern!(py, "read_udtf"))?
 .call1((serializer, infile, eval_type))?;
-Ok(tuple.get_item(0)?.to_object(py))
+tuple
+.get_item(0)?
+.into_pyobject(py)
+.map_err(|e| PyUdfError::PythonError(e.into()))
 }
 
 pub fn build(
@@ -70,7 +73,7 @@ impl PySparkUdtfPayload {
 
 let type_string = Python::with_gil(|py| -> PyResult<String> {
 let return_type = return_type.to_pyarrow(py)?.clone_ref(py).into_bound(py);
-PyModule::import_bound(py, intern!(py, "pyspark.sql.pandas.types"))?
+PyModule::import(py, intern!(py, "pyspark.sql.pandas.types"))?
 .getattr(intern!(py, "from_arrow_type"))?
 .call1((return_type,))?
 .getattr(intern!(py, "json"))?
(The diffs for the remaining changed files are not shown.)