
chore: Add safety check to CometBuffer #1050

Merged · 13 commits · Jan 3, 2025 · changes shown from 8 commits
22 changes: 0 additions & 22 deletions common/src/main/java/org/apache/comet/parquet/ColumnReader.java
@@ -172,28 +172,6 @@ public void close() {

/** Returns a decoded {@link CometDecodedVector Comet vector}. */
public CometDecodedVector loadVector() {
// Only re-use Comet vector iff:
Member Author: I don't know how much performance gain we got from reusing CometVector here. Note that it doesn't skip loading data into the buffer on the native side; it just reuses the JVM CometVector.

This reuse conflicts with the safety check: in CometColumnarToRow we close columnar vectors after accessing them, to drop the buffer reference on the JVM side. loadVector is then called to load the next columnar vector (i.e., CometVector), and since the previous vector has already been closed, it can no longer be reused.

Contributor: Even though the conditions for reuse are extremely well documented here, it is still confusing to keep track of memory while debugging. I seem to remember there was some gain from the memory reuse, but if more recent benchmarks show no gain, should we remove this reuse (and simplify the management of our buffer lifetimes)?

Member Author: This needs to be removed. As I explained above, without removing it we cannot close columnar vectors in CometColumnarToRow, which is required to pass the safety check.

// 1. if we're not using dictionary encoding, since with dictionary encoding, the native
// side may fallback to plain encoding and the underlying memory address for the vector
// will change as result.
// 2. if the column type is of fixed width, in other words, string/binary are not supported
// since the native side may resize the vector and therefore change memory address.
// 3. if the last loaded vector contains null values: if values of last vector are all not
// null, Arrow C data API will skip loading the native validity buffer, therefore we
// should not re-use the vector in that case.
// 4. if the last loaded vector doesn't contain any null value, but the current vector also
// are all not null, which means we can also re-use the loaded vector.
// 5. if the new number of value is the same or smaller
if ((hadNull || currentNumNulls == 0)
&& currentVector != null
&& dictionary == null
&& currentVector.isFixedLength()
&& currentVector.numValues() >= currentNumValues) {
currentVector.setNumNulls(currentNumNulls);
currentVector.setNumValues(currentNumValues);
return currentVector;
}

LOG.debug("Reloading vector");

// Close the previous vector first to release struct memory allocated to import Arrow array &
@@ -53,13 +53,13 @@ public ConstantColumnReader(

public ConstantColumnReader(
DataType type, ColumnDescriptor descriptor, Object value, boolean useDecimal128) {
super(type, descriptor, useDecimal128);
super(type, descriptor, useDecimal128, true);
this.value = value;
}

ConstantColumnReader(
DataType type, ColumnDescriptor descriptor, int batchSize, boolean useDecimal128) {
super(type, descriptor, useDecimal128);
super(type, descriptor, useDecimal128, true);
this.batchSize = batchSize;
initNative();
}
@@ -40,9 +40,14 @@ public class MetadataColumnReader extends AbstractColumnReader {
private ArrowArray array = null;
private ArrowSchema schema = null;

public MetadataColumnReader(DataType type, ColumnDescriptor descriptor, boolean useDecimal128) {
private boolean isConstant;

public MetadataColumnReader(
DataType type, ColumnDescriptor descriptor, boolean useDecimal128, boolean isConstant) {
// TODO: should we handle legacy dates & timestamps for metadata columns?
super(type, descriptor, useDecimal128, false);

this.isConstant = isConstant;
}

@Override
@@ -62,7 +67,7 @@ public void readBatch(int total) {

Native.currentBatch(nativeHandle, arrayAddr, schemaAddr);
FieldVector fieldVector = Data.importVector(allocator, array, schema, null);
vector = new CometPlainVector(fieldVector, useDecimal128);
vector = new CometPlainVector(fieldVector, useDecimal128, false, isConstant);
}

vector.setNumValues(total);
@@ -33,7 +33,7 @@ public class RowIndexColumnReader extends MetadataColumnReader {
private long offset;

public RowIndexColumnReader(StructField field, int batchSize, long[] indices) {
super(field.dataType(), TypeUtil.convertToParquet(field), false);
super(field.dataType(), TypeUtil.convertToParquet(field), false, false);
this.indices = indices;
setBatchSize(batchSize);
}
16 changes: 16 additions & 0 deletions common/src/main/java/org/apache/comet/vector/CometPlainVector.java
@@ -38,11 +38,18 @@ public class CometPlainVector extends CometDecodedVector {
private byte booleanByteCache;
private int booleanByteCacheIndex = -1;

private boolean isReused;

public CometPlainVector(ValueVector vector, boolean useDecimal128) {
this(vector, useDecimal128, false);
}

public CometPlainVector(ValueVector vector, boolean useDecimal128, boolean isUuid) {
this(vector, useDecimal128, isUuid, false);
}

public CometPlainVector(
ValueVector vector, boolean useDecimal128, boolean isUuid, boolean isReused) {
super(vector, vector.getField(), useDecimal128, isUuid);
// NullType doesn't have data buffer.
if (vector instanceof NullVector) {
@@ -52,6 +59,15 @@ public CometPlainVector(ValueVector vector, boolean useDecimal128, boolean isUui
}

isBaseFixedWidthVector = valueVector instanceof BaseFixedWidthVector;
this.isReused = isReused;
}

public boolean isReused() {
return isReused;
}

public void setReused(boolean isReused) {
this.isReused = isReused;
}

@Override
2 changes: 1 addition & 1 deletion native/core/benches/parquet_read.rs
@@ -213,6 +213,6 @@ impl Iterator for TestColumnReader {
}
self.total_num_values_read += total;

Some(self.inner.current_batch())
Some(self.inner.current_batch().unwrap())
}
}
50 changes: 37 additions & 13 deletions native/core/src/common/buffer.rs
@@ -16,6 +16,7 @@
// under the License.

use crate::common::bit;
use crate::execution::operators::ExecutionError;
use arrow::buffer::Buffer as ArrowBuffer;
use std::{
alloc::{handle_alloc_error, Layout},
@@ -43,6 +44,8 @@ pub struct CometBuffer {
capacity: usize,
/// Whether this buffer owns the data it points to.
owned: bool,
/// The allocation instance for this buffer.
allocation: Arc<CometBufferAllocation>,
}

unsafe impl Sync for CometBuffer {}
@@ -63,6 +66,7 @@ impl CometBuffer {
len: aligned_capacity,
capacity: aligned_capacity,
owned: true,
allocation: Arc::new(CometBufferAllocation::new()),
}
}
}
@@ -84,6 +88,7 @@ impl CometBuffer {
len,
capacity,
owned: false,
allocation: Arc::new(CometBufferAllocation::new()),
}
}

@@ -163,11 +168,28 @@
/// because of the iterator-style pattern, the content of the original mutable buffer will only
/// be updated once upstream operators fully consumed the previous output batch. For breaking
/// operators, they are responsible for copying content out of the buffers.
pub unsafe fn to_arrow(&self) -> ArrowBuffer {
pub unsafe fn to_arrow(&self) -> Result<ArrowBuffer, ExecutionError> {
let ptr = NonNull::new_unchecked(self.data.as_ptr());
// Uses a dummy `Arc::new(0)` as `Allocation` to ensure the memory region pointed by
// `ptr` won't be freed when the returned `ArrowBuffer` goes out of scope.
ArrowBuffer::from_custom_allocation(ptr, self.len, Arc::new(0))
self.check_reference()?;
Ok(ArrowBuffer::from_custom_allocation(
ptr,
self.len,
Arc::<CometBufferAllocation>::clone(&self.allocation),
))
}

/// Checks if this buffer is exclusively owned by Comet. If not, an error is returned.
/// We run this check when we want to update the buffer. If the buffer is also shared by
/// other components, e.g. one DataFusion operator stores the buffer, Comet cannot safely
/// modify the buffer.
pub fn check_reference(&self) -> Result<(), ExecutionError> {
if Arc::strong_count(&self.allocation) > 1 {
Err(ExecutionError::GeneralError(
"Error on modifying a buffer which is not exclusively owned by Comet".to_string(),
))
} else {
Ok(())
}
}

/// Resets this buffer by filling all bytes with zeros.
@@ -242,13 +264,6 @@
}
}

impl From<&ArrowBuffer> for CometBuffer {
fn from(value: &ArrowBuffer) -> Self {
assert_eq!(value.len(), value.capacity());
CometBuffer::from_ptr(value.as_ptr(), value.len(), value.capacity())
}
}

impl std::ops::Deref for CometBuffer {
type Target = [u8];

@@ -264,6 +279,15 @@ impl std::ops::DerefMut for CometBuffer {
}
}

#[derive(Debug)]
struct CometBufferAllocation {}

impl CometBufferAllocation {
fn new() -> Self {
Self {}
}
}

#[cfg(test)]
mod tests {
use super::*;
@@ -319,7 +343,7 @@ mod tests {
assert_eq!(b"aaaa bbbb cccc dddd", &buf.as_slice()[0..str.len()]);

unsafe {
let immutable_buf: ArrowBuffer = buf.to_arrow();
let immutable_buf: ArrowBuffer = buf.to_arrow().unwrap();
assert_eq!(64, immutable_buf.len());
assert_eq!(str, &immutable_buf.as_slice()[0..str.len()]);
}
@@ -335,7 +359,7 @@
assert_eq!(b"hello comet", &buf.as_slice()[0..11]);

unsafe {
let arrow_buf2 = buf.to_arrow();
let arrow_buf2 = buf.to_arrow().unwrap();
assert_eq!(arrow_buf, arrow_buf2);
}
}
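The new `check_reference` relies on the `Arc` strong count of the `CometBufferAllocation` token: `to_arrow` now clones that `Arc` into the exported `ArrowBuffer`, so a count greater than one means some consumer still holds a view of the memory and Comet must not modify it. Below is a minimal, self-contained Rust sketch of the same pattern; the `Buffer`/`Allocation` types and `try_fill` method are illustrative stand-ins, not Comet's actual API.

```rust
use std::sync::Arc;

// A minimal sketch (not Comet's actual types) of the ownership check added to
// CometBuffer: an allocation token whose Arc strong count tells the writer
// whether any exported read-only view still references the memory.
struct Allocation;

struct Buffer {
    data: Vec<u8>,
    allocation: Arc<Allocation>,
}

impl Buffer {
    fn new(len: usize) -> Self {
        Buffer {
            data: vec![0u8; len],
            allocation: Arc::new(Allocation),
        }
    }

    // Mirrors `to_arrow`: the exported view keeps the allocation alive by
    // cloning the Arc, so the strong count rises above 1 while it exists.
    fn export_view(&self) -> Arc<Allocation> {
        Arc::clone(&self.allocation)
    }

    // Mirrors `check_reference`: the buffer may only be modified while this
    // side holds the sole reference to the allocation.
    fn try_fill(&mut self, value: u8) -> Result<(), String> {
        if Arc::strong_count(&self.allocation) > 1 {
            return Err("buffer is shared, refusing to modify".to_string());
        }
        self.data.fill(value);
        Ok(())
    }
}

fn main() {
    let mut buf = Buffer::new(64);
    assert!(buf.try_fill(1).is_ok());

    let view = buf.export_view(); // e.g. an ArrowBuffer handed to DataFusion
    assert!(buf.try_fill(2).is_err()); // safety check rejects the write

    drop(view); // the consumer releases the batch
    assert!(buf.try_fill(3).is_ok());
}
```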
5 changes: 4 additions & 1 deletion native/core/src/execution/operators/copy.rs
@@ -258,7 +258,10 @@ fn copy_array(array: &dyn Array) -> ArrayRef {
/// is a dictionary array, we will cast the dictionary array to primitive type
/// (i.e., unpack the dictionary array) and copy the primitive array. If the input
/// array is a primitive array, we simply copy the array.
fn copy_or_unpack_array(array: &Arc<dyn Array>, mode: &CopyMode) -> Result<ArrayRef, ArrowError> {
pub(crate) fn copy_or_unpack_array(
array: &Arc<dyn Array>,
mode: &CopyMode,
) -> Result<ArrayRef, ArrowError> {
match array.data_type() {
DataType::Dictionary(_, value_type) => {
let options = CastOptions::default();
2 changes: 1 addition & 1 deletion native/core/src/parquet/mod.rs
@@ -543,7 +543,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch(
try_unwrap_or_throw(&e, |_env| {
let ctx = get_context(handle)?;
let reader = &mut ctx.column_reader;
let data = reader.current_batch();
let data = reader.current_batch()?;
data.move_to_spark(array_addr, schema_addr)
.map_err(|e| e.into())
})
14 changes: 7 additions & 7 deletions native/core/src/parquet/mutable_vector.rs
@@ -18,6 +18,7 @@
use arrow::{array::ArrayData, datatypes::DataType as ArrowDataType};

use crate::common::{bit, CometBuffer};
use crate::execution::operators::ExecutionError;

const DEFAULT_ARRAY_LEN: usize = 4;

@@ -192,7 +193,7 @@ impl ParquetMutableVector {
/// This method is highly unsafe since it calls `CometBuffer::to_arrow` which leaks raw
/// pointer to the memory region that are tracked by `CometBuffer`. Please see comments on
/// `to_arrow` buffer to understand the motivation.
pub fn get_array_data(&mut self) -> ArrayData {
pub fn get_array_data(&mut self) -> Result<ArrayData, ExecutionError> {
unsafe {
let data_type = if let Some(d) = &self.dictionary {
ArrowDataType::Dictionary(
Expand All @@ -204,20 +205,19 @@ impl ParquetMutableVector {
};
let mut builder = ArrayData::builder(data_type)
.len(self.num_values)
.add_buffer(self.value_buffer.to_arrow())
.null_bit_buffer(Some(self.validity_buffer.to_arrow()))
.add_buffer(self.value_buffer.to_arrow()?)
.null_bit_buffer(Some(self.validity_buffer.to_arrow()?))
.null_count(self.num_nulls);

if Self::is_binary_type(&self.arrow_type) && self.dictionary.is_none() {
let child = &mut self.children[0];
builder = builder.add_buffer(child.value_buffer.to_arrow());
builder = builder.add_buffer(child.value_buffer.to_arrow()?);
}

if let Some(d) = &mut self.dictionary {
builder = builder.add_child_data(d.get_array_data());
builder = builder.add_child_data(d.get_array_data()?);
}

builder.build_unchecked()
Ok(builder.build_unchecked())
}
}

5 changes: 3 additions & 2 deletions native/core/src/parquet/read/column.rs
@@ -39,6 +39,7 @@
};

use crate::common::{bit, bit::log2};
use crate::execution::operators::ExecutionError;

/// Maximum number of decimal digits an i32 can represent
const DECIMAL_MAX_INT_DIGITS: i32 = 9;
@@ -601,7 +602,7 @@ impl ColumnReader {
}

#[inline]
pub fn current_batch(&mut self) -> ArrayData {
pub fn current_batch(&mut self) -> Result<ArrayData, ExecutionError> {
make_func_mut!(self, current_batch)
}

@@ -684,7 +685,7 @@ impl<T: DataType> TypedColumnReader<T> {
/// Note: the caller must make sure the returned Arrow vector is fully consumed before calling
/// `read_batch` again.
#[inline]
pub fn current_batch(&mut self) -> ArrayData {
pub fn current_batch(&mut self) -> Result<ArrayData, ExecutionError> {
self.vector.get_array_data()
}

10 changes: 10 additions & 0 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -74,6 +74,7 @@ class CometExecIterator(
}

private var nextBatch: Option[ColumnarBatch] = None
private var prevBatch: ColumnarBatch = null
private var currentBatch: ColumnarBatch = null
private var closed: Boolean = false

@@ -95,6 +96,14 @@
return true
}

// Close previous batch if any.
// This is to guarantee safety at the native side before we overwrite the buffer memory
// shared across batches in the native side.
if (prevBatch != null) {
prevBatch.close()
prevBatch = null
}

nextBatch = getNextBatch()

if (nextBatch.isEmpty) {
@@ -117,6 +126,7 @@
}

currentBatch = nextBatch.get
prevBatch = currentBatch
nextBatch = None
currentBatch
}
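The `prevBatch.close()` added above is the JVM-side half of the contract: the previously exported batch must be released before the native side overwrites the buffers it shares across batches, otherwise `check_reference` fails. A rough, self-contained Rust sketch of that lifecycle follows; `NativeSide`/`Batch` are illustrative names, not Comet's types.

```rust
use std::sync::Arc;

// Sketch of the batch lifecycle enforced by the Scala change: the consumer
// must close the previous batch before the producer may rewrite the shared
// buffer for the next one.
struct Batch {
    generation: u64,
    _allocation: Arc<()>,
}

struct NativeSide {
    allocation: Arc<()>,
    generation: u64,
}

impl NativeSide {
    fn new() -> Self {
        NativeSide {
            allocation: Arc::new(()),
            generation: 0,
        }
    }

    // Overwrites the shared buffer in place, then exports a batch referencing it.
    fn next_batch(&mut self) -> Result<Batch, String> {
        if Arc::strong_count(&self.allocation) > 1 {
            return Err("previous batch still referenced; cannot reuse buffer".to_string());
        }
        self.generation += 1; // stands in for rewriting the shared buffer
        Ok(Batch {
            generation: self.generation,
            _allocation: Arc::clone(&self.allocation),
        })
    }
}

fn main() {
    let mut native = NativeSide::new();
    let first = native.next_batch().unwrap();
    assert_eq!(first.generation, 1);

    // Skipping the close of the previous batch trips the safety check.
    assert!(native.next_batch().is_err());

    drop(first); // what prevBatch.close() achieves in CometExecIterator
    assert_eq!(native.next_batch().unwrap().generation, 2);
}
```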
@@ -1084,6 +1084,17 @@
override def apply(plan: SparkPlan): SparkPlan = {
val eliminatedPlan = plan transformUp {
case ColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) => sparkToColumnar.child
case c @ ColumnarToRowExec(child) =>
val op = CometColumnarToRowExec(child)
if (c.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
c.logicalLink.foreach(op.setLogicalLink)
}
op
case CometColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) =>
sparkToColumnar.child
case CometSparkToColumnarExec(child: CometSparkToColumnarExec) => child
// Spark adds `RowToColumnar` under Comet columnar shuffle. But it's redundant as the
// shuffle takes row-based input.
@@ -1100,6 +1111,8 @@
eliminatedPlan match {
case ColumnarToRowExec(child: CometCollectLimitExec) =>
child
case CometColumnarToRowExec(child: CometCollectLimitExec) =>
child
case other =>
other
}