diff --git a/native/core/benches/parquet_read.rs b/native/core/benches/parquet_read.rs
index 1f8178cd2..13f21612f 100644
--- a/native/core/benches/parquet_read.rs
+++ b/native/core/benches/parquet_read.rs
@@ -213,6 +213,6 @@ impl Iterator for TestColumnReader {
         }
         self.total_num_values_read += total;
 
-        Some(self.inner.current_batch())
+        Some(self.inner.current_batch().unwrap())
     }
 }
diff --git a/native/core/src/common/buffer.rs b/native/core/src/common/buffer.rs
index f24038a95..291082d10 100644
--- a/native/core/src/common/buffer.rs
+++ b/native/core/src/common/buffer.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use crate::common::bit;
+use crate::execution::operators::ExecutionError;
 use arrow::buffer::Buffer as ArrowBuffer;
 use std::{
     alloc::{handle_alloc_error, Layout},
@@ -43,6 +44,8 @@ pub struct CometBuffer {
     capacity: usize,
     /// Whether this buffer owns the data it points to.
     owned: bool,
+    /// The allocation instance for this buffer.
+    allocation: Arc<CometBufferAllocation>,
 }
 
 unsafe impl Sync for CometBuffer {}
@@ -63,6 +66,7 @@ impl CometBuffer {
                 len: aligned_capacity,
                 capacity: aligned_capacity,
                 owned: true,
+                allocation: Arc::new(CometBufferAllocation::new()),
             }
         }
     }
@@ -84,6 +88,7 @@ impl CometBuffer {
             len,
             capacity,
             owned: false,
+            allocation: Arc::new(CometBufferAllocation::new()),
         }
     }
 
@@ -163,11 +168,28 @@
     /// because of the iterator-style pattern, the content of the original mutable buffer will only
     /// be updated once upstream operators fully consumed the previous output batch. For breaking
     /// operators, they are responsible for copying content out of the buffers.
-    pub unsafe fn to_arrow(&self) -> ArrowBuffer {
+    pub unsafe fn to_arrow(&self) -> Result<ArrowBuffer, ExecutionError> {
         let ptr = NonNull::new_unchecked(self.data.as_ptr());
-        // Uses a dummy `Arc::new(0)` as `Allocation` to ensure the memory region pointed by
-        // `ptr` won't be freed when the returned `ArrowBuffer` goes out of scope.
-        ArrowBuffer::from_custom_allocation(ptr, self.len, Arc::new(0))
+        self.check_reference()?;
+        Ok(ArrowBuffer::from_custom_allocation(
+            ptr,
+            self.len,
+            Arc::<CometBufferAllocation>::clone(&self.allocation),
+        ))
+    }
+
+    /// Checks if this buffer is exclusively owned by Comet. If not, an error is returned.
+    /// We run this check when we want to update the buffer. If the buffer is also shared by
+    /// other components, e.g. one DataFusion operator stores the buffer, Comet cannot safely
+    /// modify the buffer.
+    pub fn check_reference(&self) -> Result<(), ExecutionError> {
+        if Arc::strong_count(&self.allocation) > 1 {
+            Err(ExecutionError::GeneralError(
+                "Error on modifying a buffer which is not exclusively owned by Comet".to_string(),
+            ))
+        } else {
+            Ok(())
+        }
     }
 
     /// Resets this buffer by filling all bytes with zeros.
@@ -242,13 +264,6 @@ impl PartialEq for CometBuffer {
     }
 }
 
-impl From<&ArrowBuffer> for CometBuffer {
-    fn from(value: &ArrowBuffer) -> Self {
-        assert_eq!(value.len(), value.capacity());
-        CometBuffer::from_ptr(value.as_ptr(), value.len(), value.capacity())
-    }
-}
-
 impl std::ops::Deref for CometBuffer {
     type Target = [u8];
 
@@ -264,6 +279,15 @@ impl std::ops::DerefMut for CometBuffer {
     }
 }
 
+#[derive(Debug)]
+struct CometBufferAllocation {}
+
+impl CometBufferAllocation {
+    fn new() -> Self {
+        Self {}
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -319,7 +343,7 @@ mod tests {
         assert_eq!(b"aaaa bbbb cccc dddd", &buf.as_slice()[0..str.len()]);
 
         unsafe {
-            let immutable_buf: ArrowBuffer = buf.to_arrow();
+            let immutable_buf: ArrowBuffer = buf.to_arrow().unwrap();
             assert_eq!(64, immutable_buf.len());
             assert_eq!(str, &immutable_buf.as_slice()[0..str.len()]);
         }
@@ -335,7 +359,7 @@ mod tests {
         assert_eq!(b"hello comet", &buf.as_slice()[0..11]);
 
         unsafe {
-            let arrow_buf2 = buf.to_arrow();
+            let arrow_buf2 = buf.to_arrow().unwrap();
             assert_eq!(arrow_buf, arrow_buf2);
         }
     }
diff --git a/native/core/src/execution/operators/copy.rs b/native/core/src/execution/operators/copy.rs
index 8eeda8a5a..3a3a47717 100644
--- a/native/core/src/execution/operators/copy.rs
+++ b/native/core/src/execution/operators/copy.rs
@@ -258,7 +258,10 @@ fn copy_array(array: &dyn Array) -> ArrayRef {
 /// is a dictionary array, we will cast the dictionary array to primitive type
 /// (i.e., unpack the dictionary array) and copy the primitive array. If the input
 /// array is a primitive array, we simply copy the array.
-fn copy_or_unpack_array(array: &Arc<dyn Array>, mode: &CopyMode) -> Result<ArrayRef, DataFusionError> {
+pub(crate) fn copy_or_unpack_array(
+    array: &Arc<dyn Array>,
+    mode: &CopyMode,
+) -> Result<ArrayRef, DataFusionError> {
     match array.data_type() {
         DataType::Dictionary(_, value_type) => {
             let options = CastOptions::default();
diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs
index 455f19929..da8a3ca15 100644
--- a/native/core/src/parquet/mod.rs
+++ b/native/core/src/parquet/mod.rs
@@ -543,7 +543,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch(
     try_unwrap_or_throw(&e, |_env| {
         let ctx = get_context(handle)?;
         let reader = &mut ctx.column_reader;
-        let data = reader.current_batch();
+        let data = reader.current_batch()?;
         data.move_to_spark(array_addr, schema_addr)
             .map_err(|e| e.into())
     })
diff --git a/native/core/src/parquet/mutable_vector.rs b/native/core/src/parquet/mutable_vector.rs
index 7f30d7d87..bacb2c7a3 100644
--- a/native/core/src/parquet/mutable_vector.rs
+++ b/native/core/src/parquet/mutable_vector.rs
@@ -18,6 +18,7 @@
 use arrow::{array::ArrayData, datatypes::DataType as ArrowDataType};
 
 use crate::common::{bit, CometBuffer};
+use crate::execution::operators::ExecutionError;
 
 const DEFAULT_ARRAY_LEN: usize = 4;
 
@@ -192,7 +193,7 @@
     /// This method is highly unsafe since it calls `CometBuffer::to_arrow` which leaks raw
     /// pointer to the memory region that are tracked by `CometBuffer`. Please see comments on
     /// `to_arrow` buffer to understand the motivation.
-    pub fn get_array_data(&mut self) -> ArrayData {
+    pub fn get_array_data(&mut self) -> Result<ArrayData, ExecutionError> {
         unsafe {
             let data_type = if let Some(d) = &self.dictionary {
                 ArrowDataType::Dictionary(
@@ -204,20 +205,19 @@
             };
             let mut builder = ArrayData::builder(data_type)
                 .len(self.num_values)
-                .add_buffer(self.value_buffer.to_arrow())
-                .null_bit_buffer(Some(self.validity_buffer.to_arrow()))
+                .add_buffer(self.value_buffer.to_arrow()?)
+                .null_bit_buffer(Some(self.validity_buffer.to_arrow()?))
                 .null_count(self.num_nulls);
 
             if Self::is_binary_type(&self.arrow_type) && self.dictionary.is_none() {
                 let child = &mut self.children[0];
-                builder = builder.add_buffer(child.value_buffer.to_arrow());
+                builder = builder.add_buffer(child.value_buffer.to_arrow()?);
             }
 
             if let Some(d) = &mut self.dictionary {
-                builder = builder.add_child_data(d.get_array_data());
+                builder = builder.add_child_data(d.get_array_data()?);
             }
-
-            builder.build_unchecked()
+            Ok(builder.build_unchecked())
         }
     }
 
diff --git a/native/core/src/parquet/read/column.rs b/native/core/src/parquet/read/column.rs
index 73f8df956..3dc19db62 100644
--- a/native/core/src/parquet/read/column.rs
+++ b/native/core/src/parquet/read/column.rs
@@ -39,6 +39,7 @@ use super::{
 };
 
 use crate::common::{bit, bit::log2};
+use crate::execution::operators::ExecutionError;
 
 /// Maximum number of decimal digits an i32 can represent
 const DECIMAL_MAX_INT_DIGITS: i32 = 9;
@@ -601,7 +602,7 @@ impl ColumnReader {
     }
 
     #[inline]
-    pub fn current_batch(&mut self) -> ArrayData {
+    pub fn current_batch(&mut self) -> Result<ArrayData, ExecutionError> {
         make_func_mut!(self, current_batch)
     }
 
@@ -684,7 +685,7 @@ impl TypedColumnReader {
     /// Note: the caller must make sure the returned Arrow vector is fully consumed before calling
     /// `read_batch` again.
     #[inline]
-    pub fn current_batch(&mut self) -> ArrayData {
+    pub fn current_batch(&mut self) -> Result<ArrayData, ExecutionError> {
         self.vector.get_array_data()
     }
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 99007d0c9..3ace67301 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -67,8 +67,10 @@ class CometExecSuite extends CometTestBase {
   test("TopK operator should return correct results on dictionary column with nulls") {
     withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
       withTable("test_data") {
+        val data = (0 to 8000)
+          .flatMap(_ => Seq((1, null, "A"), (2, "BBB", "B"), (3, "BBB", "B"), (4, "BBB", "B")))
         val tableDF = spark.sparkContext
-          .parallelize(Seq((1, null, "A"), (2, "BBB", "B"), (3, "BBB", "B"), (4, "BBB", "B")), 3)
+          .parallelize(data, 3)
           .toDF("c1", "c2", "c3")
         tableDF
           .coalesce(1)
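Note on the ownership check introduced in `CometBuffer::to_arrow`: the buffer now carries an `Arc<CometBufferAllocation>` marker, hands a clone of it to each `ArrowBuffer` exported through `from_custom_allocation`, and `check_reference` returns an error whenever `Arc::strong_count` shows another holder is still alive. The sketch below illustrates that pattern in isolation; `Marker`, `Buffer`, `ExportedView`, `export`, and `try_mutate` are illustrative names only and are not part of the Comet API.

```rust
use std::sync::Arc;

#[derive(Debug)]
struct Marker;

struct Buffer {
    data: Vec<u8>,
    // Shared marker that tracks how many exported views are still alive.
    allocation: Arc<Marker>,
}

struct ExportedView {
    // Keeps the allocation marker alive, mirroring how the exported
    // `ArrowBuffer` keeps the `CometBufferAllocation` alive.
    _allocation: Arc<Marker>,
}

impl Buffer {
    fn new(len: usize) -> Self {
        Buffer {
            data: vec![0u8; len],
            allocation: Arc::new(Marker),
        }
    }

    // Analogue of `to_arrow`: every exported view clones the marker.
    fn export(&self) -> ExportedView {
        ExportedView {
            _allocation: Arc::clone(&self.allocation),
        }
    }

    // Analogue of `check_reference`: a strong count above 1 means some
    // exported view still references this buffer, so mutation is refused.
    fn try_mutate(&mut self) -> Result<(), String> {
        if Arc::strong_count(&self.allocation) > 1 {
            return Err("buffer is not exclusively owned".to_string());
        }
        self.data.fill(0);
        Ok(())
    }
}

fn main() {
    let mut buf = Buffer::new(16);
    let view = buf.export();
    assert!(buf.try_mutate().is_err()); // shared with `view`: mutation refused
    drop(view);
    assert!(buf.try_mutate().is_ok()); // exclusively owned again: mutation allowed
}
```

This is also why the enlarged test data in `CometExecSuite` matters: with many more rows per partition, the reader reuses its mutable buffers across batches, which only works correctly once a still-referenced buffer can no longer be silently overwritten.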