Skip to content

Commit

Permalink
Removed unnecessary cloning of scalar value when going from Rust to Python.
Browse files Browse the repository at this point in the history
Also removed the Rust unit tests copied over from the upstream repo that were failing due to #941 in pyo3.
  • Loading branch information
timsaucer committed Jan 28, 2025
1 parent f0d25a2 commit d904d08
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 102 deletions.
4 changes: 2 additions & 2 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use std::convert::{From, Into};
use std::sync::Arc;
use window::PyWindowFrame;

use arrow::pyarrow::ToPyArrow;
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::functions::core::expr_ext::FieldAccessor;
Expand All @@ -41,6 +40,7 @@ use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
use crate::expr::literal::PyLiteral;
use crate::functions::add_builder_fns_to_window;
use crate::pyarrow_util::scalar_to_pyarrow;
use crate::sql::logical::PyLogicalPlan;

use self::alias::PyAlias;
Expand Down Expand Up @@ -355,7 +355,7 @@ impl PyExpr {
/// Extracts the Expr value into a PyObject that can be shared with Python
pub fn python_value(&self, py: Python) -> PyResult<PyObject> {
match &self.expr {
Expr::Literal(scalar_value) => Ok(PyScalarValue(scalar_value.clone()).to_pyarrow(py)?),
Expr::Literal(scalar_value) => scalar_to_pyarrow(scalar_value, py),
_ => Err(py_type_err(format!(
"Non Expr::Literal encountered in types: {:?}",
&self.expr
Expand Down
7 changes: 2 additions & 5 deletions src/pyarrow_filter_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ use pyo3::prelude::*;
use std::convert::TryFrom;
use std::result::Result;

use arrow::pyarrow::ToPyArrow;
use datafusion::common::{Column, ScalarValue};
use datafusion::logical_expr::{expr::InList, Between, BinaryExpr, Expr, Operator};

use crate::common::data_type::PyScalarValue;
use crate::errors::PyDataFusionError;
use crate::pyarrow_util::scalar_to_pyarrow;

#[derive(Debug)]
#[repr(transparent)]
Expand Down Expand Up @@ -103,9 +102,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
let op_module = Python::import_bound(py, "operator")?;
let pc_expr: Result<Bound<'_, PyAny>, PyDataFusionError> = match expr {
Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
Expr::Literal(scalar) => {
Ok(PyScalarValue(scalar.clone()).to_pyarrow(py)?.into_bound(py))
}
Expr::Literal(scalar) => Ok(scalar_to_pyarrow(scalar, py)?.into_bound(py)),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let operator = operator_to_py(op, &op_module)?;
let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
Expand Down
102 changes: 7 additions & 95 deletions src/pyarrow_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow::array::{Array, ArrayData};
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use datafusion::scalar::ScalarValue;
use pyo3::types::{PyAnyMethods, PyList};
use pyo3::{Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python};
use pyo3::{Bound, FromPyObject, PyAny, PyObject, PyResult, Python};

use crate::common::data_type::PyScalarValue;
use crate::errors::PyDataFusionError;
Expand All @@ -45,105 +45,17 @@ impl FromPyArrow for PyScalarValue {
}
}

impl ToPyArrow for PyScalarValue {
    /// Converts the wrapped scalar into a pyarrow scalar object.
    ///
    /// The scalar is first turned into an Arrow array, that array is handed
    /// to pyarrow, and element `0` of the resulting Python array is returned.
    ///
    /// # Errors
    ///
    /// Returns a `PyErr` if building the array fails, if the export to
    /// pyarrow fails, or if the Python-side `__getitem__` call raises.
    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
        // The array data crosses into Python through the C data interface.
        let exported = self
            .0
            .to_array()
            .map_err(PyDataFusionError::from)?
            .to_data()
            .to_pyarrow(py)?;

        // Index the exported array to obtain the scalar itself.
        exported.call_method1(py, "__getitem__", (0,))
    }
}

impl<'source> FromPyObject<'source> for PyScalarValue {
fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult<Self> {
Self::from_pyarrow_bound(value)
}
}

impl IntoPy<PyObject> for PyScalarValue {
fn into_py(self, py: Python) -> PyObject {
self.to_pyarrow(py).unwrap()
}
}

#[cfg(test)]
mod tests {
use pyo3::prepare_freethreaded_python;
use pyo3::py_run;
use pyo3::types::PyDict;

use super::*;

fn init_python() {
prepare_freethreaded_python();
Python::with_gil(|py| {
if py.run_bound("import pyarrow", None, None).is_err() {
let locals = PyDict::new_bound(py);
py.run_bound(
"import sys; executable = sys.executable; python_path = sys.path",
None,
Some(&locals),
)
.expect("Couldn't get python info");
let executable = locals.get_item("executable").unwrap();
let executable: String = executable.extract().unwrap();

let python_path = locals.get_item("python_path").unwrap();
let python_path: Vec<String> = python_path.extract().unwrap();
pub fn scalar_to_pyarrow(scalar: &ScalarValue, py: Python) -> PyResult<PyObject> {
let array = scalar.to_array().map_err(PyDataFusionError::from)?;
// convert to pyarrow array using C data interface
let pyarray = array.to_data().to_pyarrow(py)?;
let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?;

panic!(
"pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\
HINT: try `pip install pyarrow`\n\
NOTE: On Mac OS, you must compile against a Framework Python \
(default in python.org installers and brew, but not pyenv)\n\
NOTE: On Mac OS, PYO3 might point to incorrect Python library \
path when using virtual environments. Try \
`export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n"
)
}
})
}

#[test]
fn test_roundtrip() {
init_python();

let example_scalars = vec![
ScalarValue::Boolean(Some(true)),
ScalarValue::Int32(Some(23)),
ScalarValue::Float64(Some(12.34)),
ScalarValue::from("Hello!"),
ScalarValue::Date32(Some(1234)),
];

Python::with_gil(|py| {
for scalar in example_scalars.into_iter() {
let scalar = PyScalarValue(scalar);
let result =
PyScalarValue::from_pyarrow_bound(scalar.to_pyarrow(py).unwrap().bind(py))
.unwrap();
assert_eq!(scalar, result);
}
});
}

#[test]
fn test_py_scalar() {
init_python();

// TODO: remove this attribute when bumping pyo3 to v0.23.0
// See: <https://github.com/PyO3/pyo3/blob/v0.23.0/guide/src/migration.md#gil-refs-feature-removed>
#[allow(unexpected_cfgs)]
Python::with_gil(|py| {
let scalar_float = PyScalarValue(ScalarValue::Float64(Some(12.34)));
let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap();
py_run!(py, py_float, "assert py_float == 12.34");

let scalar_string = PyScalarValue(ScalarValue::Utf8(Some("Hello!".to_string())));
let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap();
py_run!(py, py_string, "assert py_string == 'Hello!'");
});
}
Ok(pyscalar)
}

0 comments on commit d904d08

Please sign in to comment.