Skip to content

Commit

Permalink
chore: Add safety check to CometBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 4, 2024
1 parent ac4223c commit 5ccbb7d
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 26 deletions.
2 changes: 1 addition & 1 deletion native/core/benches/parquet_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,6 @@ impl Iterator for TestColumnReader {
}
self.total_num_values_read += total;

Some(self.inner.current_batch())
Some(self.inner.current_batch().unwrap())
}
}
50 changes: 37 additions & 13 deletions native/core/src/common/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use crate::common::bit;
use crate::execution::operators::ExecutionError;
use arrow::buffer::Buffer as ArrowBuffer;
use std::{
alloc::{handle_alloc_error, Layout},
Expand Down Expand Up @@ -43,6 +44,8 @@ pub struct CometBuffer {
capacity: usize,
/// Whether this buffer owns the data it points to.
owned: bool,
/// The allocation instance for this buffer.
allocation: Arc<CometBufferAllocation>,
}

unsafe impl Sync for CometBuffer {}
Expand All @@ -63,6 +66,7 @@ impl CometBuffer {
len: aligned_capacity,
capacity: aligned_capacity,
owned: true,
allocation: Arc::new(CometBufferAllocation::new()),
}
}
}
Expand All @@ -84,6 +88,7 @@ impl CometBuffer {
len,
capacity,
owned: false,
allocation: Arc::new(CometBufferAllocation::new()),
}
}

Expand Down Expand Up @@ -163,11 +168,28 @@ impl CometBuffer {
/// because of the iterator-style pattern, the content of the original mutable buffer will only
/// be updated once upstream operators fully consumed the previous output batch. For breaking
/// operators, they are responsible for copying content out of the buffers.
pub unsafe fn to_arrow(&self) -> ArrowBuffer {
pub unsafe fn to_arrow(&self) -> Result<ArrowBuffer, ExecutionError> {
let ptr = NonNull::new_unchecked(self.data.as_ptr());
// Uses a dummy `Arc::new(0)` as `Allocation` to ensure the memory region pointed by
// `ptr` won't be freed when the returned `ArrowBuffer` goes out of scope.
ArrowBuffer::from_custom_allocation(ptr, self.len, Arc::new(0))
self.check_reference()?;
Ok(ArrowBuffer::from_custom_allocation(
ptr,
self.len,
self.allocation.clone(),
))
}

/// Checks if this buffer is exclusively owned by Comet. If not, an error is returned.
/// We run this check when we want to update the buffer. If the buffer is also shared by
/// other components, e.g. one DataFusion operator stores the buffer, Comet cannot safely
/// modify the buffer.
pub fn check_reference(&self) -> Result<(), ExecutionError> {
if Arc::strong_count(&self.allocation) > 1 {
Err(ExecutionError::GeneralError(
"Error on modifying a buffer which is not exclusively owned by Comet".to_string(),
))
} else {
Ok(())
}
}

/// Resets this buffer by filling all bytes with zeros.
Expand Down Expand Up @@ -242,13 +264,6 @@ impl PartialEq for CometBuffer {
}
}

impl From<&ArrowBuffer> for CometBuffer {
fn from(value: &ArrowBuffer) -> Self {
assert_eq!(value.len(), value.capacity());
CometBuffer::from_ptr(value.as_ptr(), value.len(), value.capacity())
}
}

impl std::ops::Deref for CometBuffer {
type Target = [u8];

Expand All @@ -264,6 +279,15 @@ impl std::ops::DerefMut for CometBuffer {
}
}

#[derive(Debug)]
struct CometBufferAllocation {}

impl CometBufferAllocation {
fn new() -> Self {
Self {}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -319,7 +343,7 @@ mod tests {
assert_eq!(b"aaaa bbbb cccc dddd", &buf.as_slice()[0..str.len()]);

unsafe {
let immutable_buf: ArrowBuffer = buf.to_arrow();
let immutable_buf: ArrowBuffer = buf.to_arrow().unwrap();
assert_eq!(64, immutable_buf.len());
assert_eq!(str, &immutable_buf.as_slice()[0..str.len()]);
}
Expand All @@ -335,7 +359,7 @@ mod tests {
assert_eq!(b"hello comet", &buf.as_slice()[0..11]);

unsafe {
let arrow_buf2 = buf.to_arrow();
let arrow_buf2 = buf.to_arrow().unwrap();
assert_eq!(arrow_buf, arrow_buf2);
}
}
Expand Down
5 changes: 4 additions & 1 deletion native/core/src/execution/operators/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,10 @@ fn copy_array(array: &dyn Array) -> ArrayRef {
/// is a dictionary array, we will cast the dictionary array to primitive type
/// (i.e., unpack the dictionary array) and copy the primitive array. If the input
/// array is a primitive array, we simply copy the array.
fn copy_or_unpack_array(array: &Arc<dyn Array>, mode: &CopyMode) -> Result<ArrayRef, ArrowError> {
pub(crate) fn copy_or_unpack_array(
array: &Arc<dyn Array>,
mode: &CopyMode,
) -> Result<ArrayRef, ArrowError> {
match array.data_type() {
DataType::Dictionary(_, value_type) => {
let options = CastOptions::default();
Expand Down
2 changes: 1 addition & 1 deletion native/core/src/parquet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch(
try_unwrap_or_throw(&e, |_env| {
let ctx = get_context(handle)?;
let reader = &mut ctx.column_reader;
let data = reader.current_batch();
let data = reader.current_batch()?;
data.move_to_spark(array_addr, schema_addr)
.map_err(|e| e.into())
})
Expand Down
14 changes: 7 additions & 7 deletions native/core/src/parquet/mutable_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use arrow::{array::ArrayData, datatypes::DataType as ArrowDataType};

use crate::common::{bit, CometBuffer};
use crate::execution::operators::ExecutionError;

const DEFAULT_ARRAY_LEN: usize = 4;

Expand Down Expand Up @@ -192,7 +193,7 @@ impl ParquetMutableVector {
/// This method is highly unsafe since it calls `CometBuffer::to_arrow` which leaks raw
/// pointer to the memory region that are tracked by `CometBuffer`. Please see comments on
/// `to_arrow` buffer to understand the motivation.
pub fn get_array_data(&mut self) -> ArrayData {
pub fn get_array_data(&mut self) -> Result<ArrayData, ExecutionError> {
unsafe {
let data_type = if let Some(d) = &self.dictionary {
ArrowDataType::Dictionary(
Expand All @@ -204,20 +205,19 @@ impl ParquetMutableVector {
};
let mut builder = ArrayData::builder(data_type)
.len(self.num_values)
.add_buffer(self.value_buffer.to_arrow())
.null_bit_buffer(Some(self.validity_buffer.to_arrow()))
.add_buffer(self.value_buffer.to_arrow()?)
.null_bit_buffer(Some(self.validity_buffer.to_arrow()?))
.null_count(self.num_nulls);

if Self::is_binary_type(&self.arrow_type) && self.dictionary.is_none() {
let child = &mut self.children[0];
builder = builder.add_buffer(child.value_buffer.to_arrow());
builder = builder.add_buffer(child.value_buffer.to_arrow()?);
}

if let Some(d) = &mut self.dictionary {
builder = builder.add_child_data(d.get_array_data());
builder = builder.add_child_data(d.get_array_data()?);
}

builder.build_unchecked()
Ok(builder.build_unchecked())
}
}

Expand Down
5 changes: 3 additions & 2 deletions native/core/src/parquet/read/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use super::{
};

use crate::common::{bit, bit::log2};
use crate::execution::operators::ExecutionError;

/// Maximum number of decimal digits an i32 can represent
const DECIMAL_MAX_INT_DIGITS: i32 = 9;
Expand Down Expand Up @@ -601,7 +602,7 @@ impl ColumnReader {
}

#[inline]
pub fn current_batch(&mut self) -> ArrayData {
pub fn current_batch(&mut self) -> Result<ArrayData, ExecutionError> {
make_func_mut!(self, current_batch)
}

Expand Down Expand Up @@ -684,7 +685,7 @@ impl<T: DataType> TypedColumnReader<T> {
/// Note: the caller must make sure the returned Arrow vector is fully consumed before calling
/// `read_batch` again.
#[inline]
pub fn current_batch(&mut self) -> ArrayData {
pub fn current_batch(&mut self) -> Result<ArrayData, ExecutionError> {
self.vector.get_array_data()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ class CometExecSuite extends CometTestBase {
test("TopK operator should return correct results on dictionary column with nulls") {
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
withTable("test_data") {
val data = (0 to 8000)
.flatMap(_ => Seq((1, null, "A"), (2, "BBB", "B"), (3, "BBB", "B"), (4, "BBB", "B")))
val tableDF = spark.sparkContext
.parallelize(Seq((1, null, "A"), (2, "BBB", "B"), (3, "BBB", "B"), (4, "BBB", "B")), 3)
.parallelize(data, 3)
.toDF("c1", "c2", "c3")
tableDF
.coalesce(1)
Expand Down

0 comments on commit 5ccbb7d

Please sign in to comment.