Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement bloom_filter_agg #987

Merged
merged 28 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e662814
Add test that invokes bloom_filter_agg.
mbutrovich Sep 25, 2024
20f6e67
QueryPlanSerde support for BloomFilterAgg.
mbutrovich Sep 25, 2024
1ec31a2
Add bloom_filter_agg based on sample UDAF. planner instantiates it no…
mbutrovich Sep 27, 2024
3965dc4
Partial work on Accumulator. Need to finish merge_batch and state.
mbutrovich Sep 27, 2024
62e656c
BloomFilterAgg state, merge_state, and evaluate. Need more tests.
mbutrovich Sep 30, 2024
33ef47d
Matches Spark behavior. Need to clean up the code quite a bit, and do…
mbutrovich Sep 30, 2024
2040c76
Merge branch 'apache:main' into bloom_field_agg
mbutrovich Sep 30, 2024
599a8f9
Remove old comment.
mbutrovich Sep 30, 2024
a2a8cf3
Clippy. Increase bloom filter size back to Spark's default.
mbutrovich Sep 30, 2024
22aedd9
API cleanup.
mbutrovich Sep 30, 2024
bf22902
API cleanup.
mbutrovich Oct 1, 2024
4b7000c
Merge branch 'apache:main' into bloom_field_agg
mbutrovich Oct 2, 2024
88adc75
Add BloomFilterAgg benchmark to CometExecBenchmark
mbutrovich Oct 2, 2024
a21e0e3
Docs.
mbutrovich Oct 2, 2024
5c5d0f9
API cleanup, fix merge_bits to update cardinality.
mbutrovich Oct 2, 2024
cd107e3
Refactor merge_bits to update bit_count with the bit merging.
mbutrovich Oct 2, 2024
4f06098
Remove benchmark results file.
mbutrovich Oct 2, 2024
79f6468
Docs.
mbutrovich Oct 2, 2024
57fe742
Add native side benchmarks.
mbutrovich Oct 2, 2024
ec64e4c
Adjust benchmark parameters to match Spark defaults.
mbutrovich Oct 2, 2024
7a81f35
Address review feedback.
mbutrovich Oct 2, 2024
013513e
Merge branch 'apache:main' into bloom_field_agg
mbutrovich Oct 3, 2024
3347923
Add assertion to merge_batch.
mbutrovich Oct 4, 2024
5c82f24
Merge branch 'apache:main' into bloom_field_agg
mbutrovich Oct 9, 2024
c39ff1d
Merge branch 'apache:main' into bloom_field_agg
mbutrovich Oct 13, 2024
1ed99e3
Address some review feedback.
mbutrovich Oct 17, 2024
d41a9d2
Only generate native BloomFilterAgg if child has LongType.
mbutrovich Oct 18, 2024
6d13890
Add TODO with GitHub issue link.
mbutrovich Oct 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions native/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,7 @@ harness = false
[[bench]]
name = "aggregate"
harness = false

[[bench]]
name = "bloom_filter_agg"
harness = false
162 changes: 162 additions & 0 deletions native/core/benches/bloom_filter_agg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.use arrow::array::{ArrayRef, BooleanBuilder, Int32Builder, RecordBatch, StringBuilder};

use arrow::datatypes::{DataType, Field, Schema};
use arrow_array::builder::Int64Builder;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::SchemaRef;
use comet::execution::datafusion::expressions::bloom_filter_agg::BloomFilterAgg;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::ScalarValue;
use datafusion_execution::TaskContext;
use datafusion_expr::AggregateUDF;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::{Column, Literal};
use futures::StreamExt;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Runtime;

fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("bloom_filter_agg");
let num_rows = 8192;
let batch = create_record_batch(num_rows);
let mut batches = Vec::new();
for _ in 0..10 {
batches.push(batch.clone());
}
let partitions = &[batches];
let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
// spark.sql.optimizer.runtime.bloomFilter.expectedNumItems
let num_items_sv = ScalarValue::Int64(Some(1000000_i64));
let num_items: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(num_items_sv));
//spark.sql.optimizer.runtime.bloomFilter.numBits
let num_bits_sv = ScalarValue::Int64(Some(8388608_i64));
let num_bits: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(num_bits_sv));

let rt = Runtime::new().unwrap();

for agg_mode in [
("partial_agg", AggregateMode::Partial),
("single_agg", AggregateMode::Single),
] {
group.bench_function(agg_mode.0, |b| {
let comet_bloom_filter_agg =
Arc::new(AggregateUDF::new_from_impl(BloomFilterAgg::new(
Arc::clone(&c0),
Arc::clone(&num_items),
Arc::clone(&num_bits),
"bloom_filter_agg",
DataType::Binary,
)));
b.to_async(&rt).iter(|| {
black_box(agg_test(
partitions,
c0.clone(),
comet_bloom_filter_agg.clone(),
"bloom_filter_agg",
agg_mode.1,
))
})
});
}

group.finish();
}

async fn agg_test(
partitions: &[Vec<RecordBatch>],
c0: Arc<dyn PhysicalExpr>,
aggregate_udf: Arc<AggregateUDF>,
alias: &str,
mode: AggregateMode,
) {
let schema = &partitions[0][0].schema();
let scan: Arc<dyn ExecutionPlan> =
Arc::new(MemoryExec::try_new(partitions, Arc::clone(schema), None).unwrap());
let aggregate = create_aggregate(scan, c0.clone(), schema, aggregate_udf, alias, mode);
let mut stream = aggregate
.execute(0, Arc::new(TaskContext::default()))
.unwrap();
while let Some(batch) = stream.next().await {
let _batch = batch.unwrap();
}
}

fn create_aggregate(
scan: Arc<dyn ExecutionPlan>,
c0: Arc<dyn PhysicalExpr>,
schema: &SchemaRef,
aggregate_udf: Arc<AggregateUDF>,
alias: &str,
mode: AggregateMode,
) -> Arc<AggregateExec> {
let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c0.clone()])
.schema(schema.clone())
.alias(alias)
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.unwrap();

Arc::new(
AggregateExec::try_new(
mode,
PhysicalGroupBy::new_single(vec![]),
vec![aggr_expr],
vec![None],
scan,
Arc::clone(schema),
)
.unwrap(),
)
}

fn create_record_batch(num_rows: usize) -> RecordBatch {
let mut int64_builder = Int64Builder::with_capacity(num_rows);
for i in 0..num_rows {
int64_builder.append_value(i as i64);
}
let int64_array = Arc::new(int64_builder.finish());

let mut fields = vec![];
let mut columns: Vec<ArrayRef> = vec![];

// int64 column
fields.push(Field::new("c0", DataType::Int64, false));
columns.push(int64_array);

let schema = Schema::new(fields);
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
}

fn config() -> Criterion {
Criterion::default()
.measurement_time(Duration::from_millis(500))
.warm_up_time(Duration::from_millis(500))
}

criterion_group! {
name = benches;
config = config();
targets = criterion_benchmark
}
criterion_main!(benches);
143 changes: 143 additions & 0 deletions native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_schema::Field;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use std::{any::Any, sync::Arc};

use crate::execution::datafusion::util::spark_bloom_filter;
use crate::execution::datafusion::util::spark_bloom_filter::SparkBloomFilter;
use arrow::array::ArrayRef;
use arrow_array::BinaryArray;
use datafusion::error::Result;
use datafusion::physical_expr::PhysicalExpr;
use datafusion_common::{downcast_value, DataFusionError, ScalarValue};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
Accumulator, AggregateUDFImpl, Signature,
};
use datafusion_physical_expr::expressions::Literal;

#[derive(Debug, Clone)]
pub struct BloomFilterAgg {
name: String,
signature: Signature,
expr: Arc<dyn PhysicalExpr>,
num_items: i32,
num_bits: i32,
}

fn extract_i32_from_literal(expr: Arc<dyn PhysicalExpr>) -> i32 {
mbutrovich marked this conversation as resolved.
Show resolved Hide resolved
match expr.as_any().downcast_ref::<Literal>().unwrap().value() {
ScalarValue::Int64(scalar_value) => scalar_value.unwrap() as i32,
_ => {
unreachable!()
}
}
}

impl BloomFilterAgg {
pub fn new(
expr: Arc<dyn PhysicalExpr>,
num_items: Arc<dyn PhysicalExpr>,
num_bits: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
) -> Self {
assert!(matches!(data_type, DataType::Binary));
Self {
name: name.into(),
signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
expr,
num_items: extract_i32_from_literal(num_items),
num_bits: extract_i32_from_literal(num_bits),
}
}
}

impl AggregateUDFImpl for BloomFilterAgg {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"bloom_filter_agg"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SparkBloomFilter::from((
spark_bloom_filter::optimal_num_hash_functions(self.num_items, self.num_bits),
self.num_bits,
))))
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![Field::new("bits", DataType::Binary, false)])
}

fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
false
}
}

impl Accumulator for SparkBloomFilter {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
let arr = &values[0];
(0..arr.len()).try_for_each(|index| {
let v = ScalarValue::try_from_array(arr, index)?;

if let ScalarValue::Int64(Some(value)) = v {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It only supports Int64? Spark BloomFilterAggregate supports Byte, Short, Int, Long and String. If Comet BloomFilterAggregate only support Int64 for now. We need to fallback to Spark for other cases in QueryPlanSerde.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I was going off of their docs which say it only supports Long.

In their implementation, however, if looks like they can cast the fixed width types directly to Long
https://github.com/apache/spark/blob/b078c0d6e2adf7eb0ee7d4742a6c52864440226e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala#L238

and for strings their bloom filter implementation has a putBinary method that we don't currently support. The casts should be easy. I'll look at what putBinary on our bloom filter implementation will take.

Copy link
Contributor Author

@mbutrovich mbutrovich Oct 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see what happened. 3.4 only supports Long, which was the Spark source I was working off of. 3.5 added support for other types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modified it to only generate a native BloomFilterAgg if the child has LongType. I'll open an issue to support more types in the future.

self.put_long(value);
} else {
unreachable!()
}
Ok(())
})
}

fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Binary(Some(self.spark_serialization())))
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}

fn state(&mut self) -> Result<Vec<ScalarValue>> {
// There might be a more efficient way to do this by transmuting since calling state() on an
// Accumulator is considered destructive.
let state_sv = ScalarValue::Binary(Some(self.state_as_bytes()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One way to avoid the copy, which may be too ugly , would be to store bloom filter data as an Option<>

So instead of

pub struct SparkBloomFilter {
    bits: SparkBitArray,
    num_hash_functions: u32,
}

Something like

pub struct SparkBloomFilter {
    bits: Option<SparkBitArray>
    num_hash_functions: u32,
}

And then you could basically use Option::take to take the value and leave a None in its place

let Some(bits) = self.bits.take() else {
  return Err(invalid state)
};

// do whatever you want now you have the owned `bits`

Ok(vec![state_sv])
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let state_sv = downcast_value!(states[0], BinaryArray);
self.merge_filter(state_sv.value_data());
Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ fn evaluate_bloom_filter(
let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
match bloom_filter_bytes {
ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
Ok(v.map(|v| SparkBloomFilter::from(v.as_bytes())))
}
_ => internal_err!("Bloom filter expression should be evaluated as a scalar binary value"),
}
Expand Down
1 change: 1 addition & 0 deletions native/core/src/execution/datafusion/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub use normalize_nan::NormalizeNaNAndZero;
use crate::errors::CometError;
pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_agg;
pub mod bloom_filter_might_contain;
pub mod comet_scalar_funcs;
pub mod correlation;
Expand Down
17 changes: 17 additions & 0 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::{
avg::Avg,
avg_decimal::AvgDecimal,
bitwise_not::BitwiseNotExpr,
bloom_filter_agg::BloomFilterAgg,
bloom_filter_might_contain::BloomFilterMightContain,
checkoverflow::CheckOverflow,
correlation::Correlation,
Expand Down Expand Up @@ -1611,6 +1612,22 @@ impl PhysicalPlanner {
));
Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func)
}
AggExprStruct::BloomFilterAgg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let num_items =
self.create_expr(expr.num_items.as_ref().unwrap(), Arc::clone(&schema))?;
let num_bits =
self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
Arc::clone(&child),
Arc::clone(&num_items),
Arc::clone(&num_bits),
"bloom_filter_agg",
datatype,
));
Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
}
}
}

Expand Down
Loading
Loading