feat: rand expression support #1199

Open. Wants to merge 8 commits into base: main.
Changes from 6 commits.
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion native/core/src/execution/jni_api.rs
@@ -469,7 +469,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
// query plan, we need to defer stream initialization to first time execution.
if exec_context.root_op.is_none() {
let start = Instant::now();
- let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx))
+ let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
Contributor:
This is interesting. Is there any reason the partition is not used in the Comet native physical planner? It is definitely used in the DF physical plan during plan node execution: https://github.com/apache/datafusion/blob/main/datafusion/physical-plan/src/execution_plan.rs#L371
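For reference, the DataFusion API the comment points to takes the partition as an explicit argument of execute(); a paraphrased sketch (not the exact DataFusion source; other trait items and exact paths may differ by version):

use std::sync::Arc;
use datafusion::error::Result;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::SendableRecordBatchStream;

pub trait ExecutionPlan: Send + Sync {
    // Each call produces the output stream for one partition; the caller passes
    // the partition index explicitly, so anything partition-dependent in the plan
    // must already have been built with that index in mind.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream>;
}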

Author:
The Spark partition index is erased when a native DF plan is sent for execution, for some reason: https://github.com/apache/datafusion-comet/blob/main/native/core/src/execution/jni_api.rs#L496

Member:
This is something that I would like to see improved. We currently use partition 0 for each native plan rather than the real partition id.

Author:
@andygrove Can I do it as part of this PR, or would it be better to create a separate one?

.with_exec_id(exec_context_id);
let (scans, root_op) = planner.create_plan(
&exec_context.spark_plan,
11 changes: 10 additions & 1 deletion native/core/src/execution/planner.rs
@@ -83,6 +83,7 @@ use datafusion_comet_proto::{
},
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
};
use datafusion_comet_spark_expr::rand::RandExpr;
use datafusion_comet_spark_expr::{
ArrayInsert, Avg, AvgDecimal, BitwiseNotExpr, Cast, CheckOverflow, Contains, Correlation,
Covariance, CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField,
@@ -128,6 +129,7 @@ pub const TEST_EXEC_CONTEXT_ID: i64 = -1;
pub struct PhysicalPlanner {
// The execution context id of this planner.
exec_context_id: i64,
partition: i32,
execution_props: ExecutionProps,
session_ctx: Arc<SessionContext>,
}
@@ -138,17 +140,19 @@ impl Default for PhysicalPlanner {
let execution_props = ExecutionProps::new();
Self {
exec_context_id: TEST_EXEC_CONTEXT_ID,
partition: 0,
execution_props,
session_ctx,
}
}
}

impl PhysicalPlanner {
- pub fn new(session_ctx: Arc<SessionContext>) -> Self {
+ pub fn new(session_ctx: Arc<SessionContext>, partition: i32) -> Self {
let execution_props = ExecutionProps::new();
Self {
exec_context_id: TEST_EXEC_CONTEXT_ID,
partition,
execution_props,
session_ctx,
}
@@ -157,6 +161,7 @@ impl PhysicalPlanner {
pub fn with_exec_id(self, exec_context_id: i64) -> Self {
Self {
exec_context_id,
partition: self.partition,
execution_props: self.execution_props,
session_ctx: Arc::clone(&self.session_ctx),
}
@@ -735,6 +740,10 @@ impl PhysicalPlanner {
));
Ok(array_has_expr)
}
ExprStruct::Rand(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
Ok(Arc::new(RandExpr::new(child, self.partition)))
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
1 change: 1 addition & 0 deletions native/proto/src/proto/expr.proto
@@ -85,6 +85,7 @@ message Expr {
BinaryExpr array_append = 58;
ArrayInsert array_insert = 59;
BinaryExpr array_contains = 60;
UnaryExpr rand = 61;
}
}

2 changes: 2 additions & 0 deletions native/spark-expr/src/lib.rs
@@ -66,6 +66,8 @@ pub use normalize_nan::NormalizeNaNAndZero;
mod variance;
pub use variance::Variance;
mod comet_scalar_funcs;
pub mod rand;

pub use cast::{spark_cast, Cast, SparkCastOptions};
pub use comet_scalar_funcs::create_comet_physical_fun;
pub use error::{SparkError, SparkResult};
272 changes: 272 additions & 0 deletions native/spark-expr/src/rand.rs
@@ -0,0 +1,272 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::spark_hash::spark_compatible_murmur3_hash;
use arrow_array::builder::Float64Builder;
use arrow_array::{Float64Array, RecordBatch};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use std::any::Any;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};

/// Adaptation of the XOR-shift algorithm used in Apache Spark.
/// See: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala

/// Normalization multiplier used to map a random i64 value into the f64 interval [0.0, 1.0).
/// Corresponds to DOUBLE_UNIT in the Java implementation: https://github.com/openjdk/jdk/blob/master/src/java.base/share/classes/java/util/Random.java#L302
/// The value is 2^-53 (0x1.0p-53 in Java); since Rust lacks hexadecimal float literals, it is written in scientific notation.
const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;

/// Spark-compatible initial hash seed, which is actually part of the Scala standard library's MurmurHash3 implementation.
/// The references:
/// https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala#L63
/// https://github.com/scala/scala/blob/2.13.x/src/library/scala/util/hashing/MurmurHash3.scala#L331
const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c074a61;

#[derive(Debug, Clone)]
struct XorShiftRandom {
seed: i64,
}

impl XorShiftRandom {
fn from_init_seed(init_seed: i64) -> Self {
XorShiftRandom {
seed: Self::init_seed(init_seed),
}
}

fn from_stored_seed(stored_seed: i64) -> Self {
XorShiftRandom { seed: stored_seed }
}

fn next(&mut self, bits: u8) -> i32 {
let mut next_seed = self.seed ^ (self.seed << 21);
next_seed ^= ((next_seed as u64) >> 35) as i64;
next_seed ^= next_seed << 4;
self.seed = next_seed;
(next_seed & ((1i64 << bits) - 1)) as i32
}

pub fn next_f64(&mut self) -> f64 {
let a = self.next(26) as i64;
let b = self.next(27) as i64;
((a << 27) + b) as f64 * DOUBLE_UNIT
}

fn init_seed(init: i64) -> i64 {
let bytes_repr = init.to_be_bytes();
let low_bits = spark_compatible_murmur3_hash(bytes_repr, SPARK_MURMUR_ARRAY_SEED);
let high_bits = spark_compatible_murmur3_hash(bytes_repr, low_bits);
((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64)
}
}

/// Spark-compatible `rand(seed)` physical expression.
///
/// `seed` must evaluate to a literal; `init_seed_shift` is the Spark partition index,
/// which is added to the seed so each partition gets its own deterministic stream;
/// `state_holder` carries the XOR-shift state across batches so consecutive batches
/// continue a single sequence.
#[derive(Debug)]
pub struct RandExpr {
seed: Arc<dyn PhysicalExpr>,
init_seed_shift: i32,
state_holder: Arc<Mutex<Option<i64>>>,
}

impl RandExpr {
pub fn new(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Self {
Self {
seed,
init_seed_shift,
state_holder: Arc::new(Mutex::new(None::<i64>)),
}
}

fn extract_init_state(seed: ScalarValue) -> Result<i64> {
if let ScalarValue::Int64(seed_opt) = seed.cast_to(&DataType::Int64)? {
Ok(seed_opt.unwrap_or(0))
} else {
Err(DataFusionError::Internal(
"unexpected execution branch".to_string(),
))
}
}

fn evaluate_batch(&self, seed: ScalarValue, num_rows: usize) -> Result<ColumnarValue> {
let mut seed_state = self.state_holder.lock().unwrap();
let mut rnd = if seed_state.is_none() {
// First batch in this partition: derive the initial seed from the literal seed plus the partition shift.
let init_seed = RandExpr::extract_init_state(seed)?;
let init_seed = init_seed.wrapping_add(self.init_seed_shift as i64);
*seed_state = Some(init_seed);
XorShiftRandom::from_init_seed(init_seed)
} else {
// Later batches: resume from the generator state stored after the previous batch.
let stored_seed = seed_state.unwrap();
XorShiftRandom::from_stored_seed(stored_seed)
};

let mut arr_builder = Float64Builder::with_capacity(num_rows);
std::iter::repeat_with(|| rnd.next_f64())
.take(num_rows)
.for_each(|v| arr_builder.append_value(v));
let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
*seed_state = Some(rnd.seed);
Ok(ColumnarValue::Array(array_ref))
}
}

impl Display for RandExpr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "RAND({})", self.seed)
}
}

impl PartialEq<dyn Any> for RandExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| self.seed.eq(&x.seed))
.unwrap_or(false)
}
}

impl PhysicalExpr for RandExpr {
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(DataType::Float64)
}

fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
Ok(false)
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
match self.seed.evaluate(batch)? {
ColumnarValue::Scalar(seed) => self.evaluate_batch(seed, batch.num_rows()),
ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
"Only literal seeds are supported for {}",
self
))),
}
}

fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.seed]
}

fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(RandExpr::new(
Arc::clone(&children[0]),
self.init_seed_shift,
)))
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.children().hash(&mut s);
}
}

pub fn rand(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(RandExpr::new(seed, init_seed_shift)))
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::{array::StringArray, compute::concat, datatypes::*};
use arrow_array::{Array, BooleanArray, Float64Array, Int64Array};
use datafusion_common::cast::as_float64_array;
use datafusion_physical_expr::expressions::lit;

const SPARK_SEED_42_FIRST_5: [f64; 5] = [
0.619189370225301,
0.5096018842446481,
0.8325259388871524,
0.26322809041172357,
0.6702867696264135,
];

#[test]
fn test_rand_single_batch() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
let rand_expr = rand(lit(42), 0)?;
let result = rand_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
let result = as_float64_array(&result)?;
let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
assert_eq!(result, expected);
Ok(())
}

#[test]
fn test_rand_multi_batch() -> Result<()> {
let first_batch_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
let first_batch_data = Int64Array::from(vec![Some(42), None]);
let second_batch_schema = first_batch_schema.clone();
let second_batch_data = Int64Array::from(vec![None, Some(-42), None]);
let rand_expr = rand(lit(42), 0)?;
let first_batch = RecordBatch::try_new(
Arc::new(first_batch_schema),
vec![Arc::new(first_batch_data)],
)?;
let first_batch_result = rand_expr
.evaluate(&first_batch)?
.into_array(first_batch.num_rows())?;
let second_batch = RecordBatch::try_new(
Arc::new(second_batch_schema),
vec![Arc::new(second_batch_data)],
)?;
let second_batch_result = rand_expr
.evaluate(&second_batch)?
.into_array(second_batch.num_rows())?;
let result_arrays: Vec<&dyn Array> = vec![
as_float64_array(&first_batch_result)?,
as_float64_array(&second_batch_result)?,
];
let result_arrays = &concat(&result_arrays)?;
let final_result = as_float64_array(result_arrays)?;
let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
assert_eq!(final_result, expected);
Ok(())
}

#[test]
fn test_overflow_shift_seed() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
let data = BooleanArray::from(vec![Some(true), Some(false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
let max_seed_and_shift_expr = rand(lit(i64::MAX), 1)?;
let min_seed_no_shift_expr = rand(lit(i64::MIN), 0)?;
let first_expr_result = max_seed_and_shift_expr
.evaluate(&batch)?
.into_array(batch.num_rows())?;
let first_expr_result = as_float64_array(&first_expr_result)?;
let second_expr_result = min_seed_no_shift_expr
.evaluate(&batch)?
.into_array(batch.num_rows())?;
let second_expr_result = as_float64_array(&second_expr_result)?;
assert_eq!(first_expr_result, second_expr_result);
Ok(())
}
}
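A minimal usage sketch of the new expression (assuming the rand() helper and module path added above; lit comes from DataFusion, as in the tests). The second argument is the Spark partition index, which shifts the seed so each partition produces its own deterministic stream, mirroring the ExprStruct::Rand arm in planner.rs:

use std::sync::Arc;
use datafusion::physical_expr::PhysicalExpr;
use datafusion_common::Result;
use datafusion_comet_spark_expr::rand::rand;
use datafusion_physical_expr::expressions::lit;

// Build a Spark-compatible `rand(42)` expression for a given Spark partition.
fn rand_for_partition(partition: i32) -> Result<Arc<dyn PhysicalExpr>> {
    rand(lit(42i64), partition)
}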
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2278,6 +2278,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
expr.children(1),
inputs,
(builder, binaryExpr) => builder.setArrayAppend(binaryExpr))

case Rand(child, _) =>
createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setRand(unaryExpr))

case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
20 changes: 20 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2529,4 +2529,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
}
}

test("rand expression with random parameters") {
val partitionsNumber = Random.nextInt(10) + 1
val rowsNumber = Random.nextInt(500)
val seed = Random.nextLong()
// use this value to have both single-batch and multi-batch partitions
val cometBatchSize = math.max(1, math.ceil(rowsNumber.toDouble / partitionsNumber).toInt)
withSQLConf("spark.comet.batchSize" -> cometBatchSize.toString) {
withParquetDataFrame((0 until rowsNumber).map(Tuple1.apply)) { df =>
val dfWithRandParameters = df.repartition(partitionsNumber).withColumn("rnd", rand(seed))
checkSparkAnswer(dfWithRandParameters)
val dfWithOverflowSeed =
df.repartition(partitionsNumber).withColumn("rnd", rand(Long.MaxValue))
checkSparkAnswer(dfWithOverflowSeed)
val dfWithNullSeed =
df.repartition(partitionsNumber).selectExpr("_1", "rand(null) as rnd")
checkSparkAnswer(dfWithNullSeed)
}
}
}
}