diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index f9d3a7809e..eec8f2b8f2 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -295,6 +295,14 @@ object CometConf extends ShimCometConf { .intConf .createWithDefault(1) + val COMET_SHUFFLE_ENABLE_FAST_ENCODING: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.enableFastEncoding") + .doc("Whether to enable Comet's faster proprietary encoding for shuffle blocks " + + "rather than using Arrow IPC.") + .internal() + .booleanConf + .createWithDefault(true) + val COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.columnar.shuffle.async.enabled") .doc("Whether to enable asynchronous shuffle for Arrow-based shuffle.") diff --git a/native/core/benches/row_columnar.rs b/native/core/benches/row_columnar.rs index a62574111b..a052ab2b1a 100644 --- a/native/core/benches/row_columnar.rs +++ b/native/core/benches/row_columnar.rs @@ -79,6 +79,7 @@ fn benchmark(c: &mut Criterion) { 0, None, &CompressionCodec::Zstd(1), + true, ) .unwrap(); }); diff --git a/native/core/benches/shuffle_writer.rs b/native/core/benches/shuffle_writer.rs index 0d22c62cc2..9e47949c9e 100644 --- a/native/core/benches/shuffle_writer.rs +++ b/native/core/benches/shuffle_writer.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use arrow_array::builder::Int32Builder; +use arrow_array::builder::{Date32Builder, Decimal128Builder, Int32Builder}; use arrow_array::{builder::StringBuilder, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; -use comet::execution::shuffle::{write_ipc_compressed, CompressionCodec, ShuffleWriterExec}; +use comet::execution::shuffle::{CompressionCodec, ShuffleBlockWriter, ShuffleWriterExec}; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion::physical_plan::metrics::Time; use datafusion::{ @@ -31,67 +31,56 @@ use std::sync::Arc; use tokio::runtime::Runtime; fn criterion_benchmark(c: &mut Criterion) { + let batch = create_batch(8192, true); let mut group = c.benchmark_group("shuffle_writer"); - group.bench_function("shuffle_writer: encode (no compression))", |b| { - let batch = create_batch(8192, true); - let mut buffer = vec![]; - let ipc_time = Time::default(); - b.iter(|| { - buffer.clear(); - let mut cursor = Cursor::new(&mut buffer); - write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::None, &ipc_time) - }); - }); - group.bench_function("shuffle_writer: encode and compress (snappy)", |b| { - let batch = create_batch(8192, true); - let mut buffer = vec![]; - let ipc_time = Time::default(); - b.iter(|| { - buffer.clear(); - let mut cursor = Cursor::new(&mut buffer); - write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Snappy, &ipc_time) - }); - }); - group.bench_function("shuffle_writer: encode and compress (lz4)", |b| { - let batch = create_batch(8192, true); - let mut buffer = vec![]; - let ipc_time = Time::default(); - b.iter(|| { - buffer.clear(); - let mut cursor = Cursor::new(&mut buffer); - write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Lz4Frame, &ipc_time) - }); - }); - group.bench_function("shuffle_writer: encode and compress (zstd level 1)", |b| { - let batch = create_batch(8192, true); - let mut buffer = vec![]; - let ipc_time = Time::default(); - b.iter(|| { - buffer.clear(); - let mut cursor = Cursor::new(&mut buffer); - write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Zstd(1), &ipc_time) - }); - }); - group.bench_function("shuffle_writer: encode and compress (zstd level 6)", |b| { - let batch = create_batch(8192, true); - let mut buffer = vec![]; - let ipc_time = Time::default(); - b.iter(|| { - buffer.clear(); - let mut cursor = Cursor::new(&mut buffer); - write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Zstd(6), &ipc_time) - }); - }); - group.bench_function("shuffle_writer: end to end", |b| { - let ctx = SessionContext::new(); - let exec = create_shuffle_writer_exec(CompressionCodec::Zstd(1)); - b.iter(|| { - let task_ctx = ctx.task_ctx(); - let stream = exec.execute(0, task_ctx).unwrap(); - let rt = Runtime::new().unwrap(); - criterion::black_box(rt.block_on(collect(stream)).unwrap()); - }); - }); + for compression_codec in &[ + CompressionCodec::None, + CompressionCodec::Lz4Frame, + CompressionCodec::Snappy, + CompressionCodec::Zstd(1), + CompressionCodec::Zstd(6), + ] { + for enable_fast_encoding in [true, false] { + let name = format!("shuffle_writer: write encoded (enable_fast_encoding={enable_fast_encoding}, compression={compression_codec:?})"); + group.bench_function(name, |b| { + let mut buffer = vec![]; + let ipc_time = Time::default(); + let w = ShuffleBlockWriter::try_new( + &batch.schema(), + enable_fast_encoding, + compression_codec.clone(), + ) + .unwrap(); + b.iter(|| { + buffer.clear(); + let mut cursor = Cursor::new(&mut buffer); + w.write_batch(&batch, &mut cursor, &ipc_time).unwrap(); + }); + }); + } + } + + for compression_codec in [ + CompressionCodec::None, + CompressionCodec::Lz4Frame, + CompressionCodec::Snappy, + CompressionCodec::Zstd(1), + CompressionCodec::Zstd(6), + ] { + group.bench_function( + format!("shuffle_writer: end to end (compression = {compression_codec:?}"), + |b| { + let ctx = SessionContext::new(); + let exec = create_shuffle_writer_exec(compression_codec.clone()); + b.iter(|| { + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx).unwrap(); + let rt = Runtime::new().unwrap(); + rt.block_on(collect(stream)).unwrap(); + }); + }, + ); + } } fn create_shuffle_writer_exec(compression_codec: CompressionCodec) -> ShuffleWriterExec { @@ -104,6 +93,7 @@ fn create_shuffle_writer_exec(compression_codec: CompressionCodec) -> ShuffleWri compression_codec, "/tmp/data.out".to_string(), "/tmp/index.out".to_string(), + true, ) .unwrap() } @@ -121,11 +111,19 @@ fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch { let schema = Arc::new(Schema::new(vec![ Field::new("c0", DataType::Int32, true), Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Date32, true), + Field::new("c3", DataType::Decimal128(11, 2), true), ])); let mut a = Int32Builder::new(); let mut b = StringBuilder::new(); + let mut c = Date32Builder::new(); + let mut d = Decimal128Builder::new() + .with_precision_and_scale(11, 2) + .unwrap(); for i in 0..num_rows { a.append_value(i as i32); + c.append_value(i as i32); + d.append_value((i * 1000000) as i128); if allow_nulls && i % 10 == 0 { b.append_null(); } else { @@ -134,7 +132,13 @@ fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch { } let a = a.finish(); let b = b.finish(); - RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap() + let c = c.finish(); + let d = d.finish(); + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(a), Arc::new(b), Arc::new(c), Arc::new(d)], + ) + .unwrap() } fn config() -> Criterion { diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index aaac7ec8ca..b3c33b7948 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -635,6 +635,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative current_checksum: jlong, compression_codec: jstring, compression_level: jint, + enable_fast_encoding: jboolean, ) -> jlongArray { try_unwrap_or_throw(&e, |mut env| unsafe { let data_types = convert_datatype_arrays(&mut env, serialized_datatypes)?; @@ -686,6 +687,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative checksum_algo, current_checksum, &compression_codec, + enable_fast_encoding != JNI_FALSE, )?; let checksum = if let Some(checksum) = checksum { diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 294922f2f1..a1af0808d6 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1084,6 +1084,7 @@ impl PhysicalPlanner { codec, writer.output_data_file.clone(), writer.output_index_file.clone(), + writer.enable_fast_encoding, )?); Ok(( diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs new file mode 100644 index 0000000000..3c735434c6 --- /dev/null +++ b/native/core/src/execution/shuffle/codec.rs @@ -0,0 +1,708 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::parquet::data_type::AsBytes; +use arrow_array::cast::AsArray; +use arrow_array::types::Int32Type; +use arrow_array::{ + Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Decimal128Array, DictionaryArray, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, RecordBatch, + RecordBatchOptions, StringArray, TimestampMicrosecondArray, +}; +use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow_schema::{DataType, Field, Schema, TimeUnit}; +use datafusion_common::DataFusionError; +use std::io::Write; +use std::sync::Arc; + +pub fn fast_codec_supports_type(data_type: &DataType) -> bool { + match data_type { + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Timestamp(TimeUnit::Microsecond, _) + | DataType::Utf8 + | DataType::Binary => true, + DataType::Decimal128(_, s) if *s >= 0 => true, + DataType::Dictionary(k, v) if **k == DataType::Int32 => fast_codec_supports_type(v), + _ => false, + } +} + +enum DataTypeId { + Boolean = 0, + Int8, + Int16, + Int32, + Int64, + Date32, + Timestamp, + TimestampNtz, + Decimal128, + Float32, + Float64, + Utf8, + Binary, + Dictionary, +} + +pub struct BatchWriter { + inner: W, +} + +impl BatchWriter { + pub fn new(inner: W) -> Self { + Self { inner } + } + + /// Encode the schema (just column names because data types can vary per batch) + pub fn write_partial_schema(&mut self, schema: &Schema) -> Result<(), DataFusionError> { + let schema_len = schema.fields().len(); + let mut null_bytes = Vec::with_capacity(schema_len); + self.inner.write_all(&schema_len.to_le_bytes())?; + for field in schema.fields() { + // field name + let field_name = field.name(); + self.inner.write_all(&field_name.len().to_le_bytes())?; + self.inner.write_all(field_name.as_str().as_bytes())?; + // nullable + null_bytes.push(field.is_nullable() as u8); + } + self.inner.write_all(null_bytes.as_bytes())?; + Ok(()) + } + + fn write_data_type(&mut self, data_type: &DataType) -> Result<(), DataFusionError> { + match data_type { + DataType::Boolean => { + self.inner.write_all(&[DataTypeId::Boolean as u8])?; + } + DataType::Int8 => { + self.inner.write_all(&[DataTypeId::Int8 as u8])?; + } + DataType::Int16 => { + self.inner.write_all(&[DataTypeId::Int16 as u8])?; + } + DataType::Int32 => { + self.inner.write_all(&[DataTypeId::Int32 as u8])?; + } + DataType::Int64 => { + self.inner.write_all(&[DataTypeId::Int64 as u8])?; + } + DataType::Float32 => { + self.inner.write_all(&[DataTypeId::Float32 as u8])?; + } + DataType::Float64 => { + self.inner.write_all(&[DataTypeId::Float64 as u8])?; + } + DataType::Date32 => { + self.inner.write_all(&[DataTypeId::Date32 as u8])?; + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { + self.inner.write_all(&[DataTypeId::TimestampNtz as u8])?; + } + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + self.inner.write_all(&[DataTypeId::Timestamp as u8])?; + let tz_bytes = tz.as_bytes(); + self.inner.write_all(&tz_bytes.len().to_le_bytes())?; + self.inner.write_all(tz_bytes)?; + } + DataType::Utf8 => { + self.inner.write_all(&[DataTypeId::Utf8 as u8])?; + } + DataType::Binary => { + self.inner.write_all(&[DataTypeId::Binary as u8])?; + } + DataType::Decimal128(p, s) if *s >= 0 => { + self.inner + .write_all(&[DataTypeId::Decimal128 as u8, *p, *s as u8])?; + } + DataType::Dictionary(k, v) => { + self.inner.write_all(&[DataTypeId::Dictionary as u8])?; + self.write_data_type(k)?; + self.write_data_type(v)?; + } + other => { + return Err(DataFusionError::Internal(format!( + "unsupported type in fast writer {other}" + ))) + } + } + Ok(()) + } + + pub fn write_all(&mut self, bytes: &[u8]) -> std::io::Result<()> { + self.inner.write_all(bytes) + } + + pub fn write_batch(&mut self, batch: &RecordBatch) -> Result<(), DataFusionError> { + self.write_all(&batch.num_rows().to_le_bytes())?; + for i in 0..batch.num_columns() { + self.write_array(batch.column(i))?; + } + Ok(()) + } + + fn write_array(&mut self, col: &dyn Array) -> Result<(), DataFusionError> { + // data type + self.write_data_type(col.data_type())?; + // array contents + match col.data_type() { + DataType::Boolean => { + let arr = col.as_any().downcast_ref::().unwrap(); + // boolean array is the only type we write the array length because it cannot + // be determined from the data buffer size (length is in bits rather than bytes) + self.write_all(&arr.len().to_le_bytes())?; + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Int8 => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Int16 => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Int32 => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Int64 => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Float32 => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Float64 => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Date32 => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let arr = col + .as_any() + .downcast_ref::() + .unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Decimal128(_, _) => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Utf8 => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values())?; + // write offset buffer + let offsets = arr.offsets(); + let scalar_buffer = offsets.inner(); + self.write_buffer(scalar_buffer.inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Binary => { + let arr = col.as_any().downcast_ref::().unwrap(); + // write data buffer + self.write_buffer(arr.values())?; + // write offset buffer + let offsets = arr.offsets(); + let scalar_buffer = offsets.inner(); + let buffer = scalar_buffer.inner(); + self.write_buffer(buffer)?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Dictionary(k, _) if **k == DataType::Int32 => { + let arr = col + .as_any() + .downcast_ref::>() + .unwrap(); + self.write_array(arr.keys())?; + self.write_array(arr.values())?; + } + other => { + return Err(DataFusionError::Internal(format!( + "unsupported type {other}" + ))) + } + } + Ok(()) + } + + fn write_null_buffer( + &mut self, + null_buffer: Option<&NullBuffer>, + ) -> Result<(), DataFusionError> { + if let Some(nulls) = null_buffer { + let buffer = nulls.inner(); + // write null buffer length in bits + self.write_all(&buffer.len().to_le_bytes())?; + // write null buffer + let buffer = buffer.inner(); + self.write_buffer(buffer)?; + } else { + self.inner.write_all(&0_usize.to_le_bytes())?; + } + Ok(()) + } + + fn write_buffer(&mut self, buffer: &Buffer) -> std::io::Result<()> { + // write buffer length + self.inner.write_all(&buffer.len().to_le_bytes())?; + // write buffer data + self.inner.write_all(buffer.as_slice()) + } + + pub fn inner(self) -> W { + self.inner + } +} + +pub struct BatchReader<'a> { + input: &'a [u8], + offset: usize, + /// buffer for reading usize + length: [u8; 8], +} + +impl<'a> BatchReader<'a> { + pub fn new(input: &'a [u8]) -> Self { + Self { + input, + offset: 0, + length: [0; 8], + } + } + + pub fn read_batch(&mut self) -> Result { + let mut length = [0; 8]; + length.copy_from_slice(&self.input[0..8]); + self.offset += 8; + let schema_len = usize::from_le_bytes(length); + + let mut field_names: Vec = Vec::with_capacity(schema_len); + let mut nullable: Vec = Vec::with_capacity(schema_len); + for _ in 0..schema_len { + field_names.push(self.read_string()); + } + for _ in 0..schema_len { + nullable.push(self.read_bool()); + } + + length.copy_from_slice(&self.input[self.offset..self.offset + 8]); + self.offset += 8; + let num_rows = usize::from_le_bytes(length); + + let mut fields: Vec> = Vec::with_capacity(schema_len); + let mut arrays = Vec::with_capacity(schema_len); + for (name, nullable) in field_names.into_iter().zip(&nullable) { + let array = self.read_array()?; + let field = Arc::new(Field::new(name, array.data_type().clone(), *nullable)); + arrays.push(array); + fields.push(field); + } + let schema = Arc::new(Schema::new(fields)); + Ok(RecordBatch::try_new_with_options( + schema, + arrays, + &RecordBatchOptions::new().with_row_count(Some(num_rows)), + ) + .unwrap()) + } + + fn read_array(&mut self) -> Result { + // read data type + let data_type = self.read_data_type()?; + Ok(match data_type { + DataType::Boolean => { + // read array length (number of bits) + let mut length = [0; 8]; + length.copy_from_slice(&self.input[self.offset..self.offset + 8]); + self.offset += 8; + let array_len = usize::from_le_bytes(length); + let buffer = self.read_buffer(); + let data_buffer = BooleanBuffer::new(buffer, 0, array_len); + let null_buffer = self.read_null_buffer(); + Arc::new(BooleanArray::new(data_buffer, null_buffer)) + } + DataType::Int8 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Int8Array::try_new(data_buffer, null_buffer)?) + } + DataType::Int16 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Int16Array::try_new(data_buffer, null_buffer)?) + } + DataType::Int32 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Int32Array::try_new(data_buffer, null_buffer)?) + } + DataType::Int64 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Int64Array::try_new(data_buffer, null_buffer)?) + } + DataType::Float32 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Float32Array::try_new(data_buffer, null_buffer)?) + } + DataType::Float64 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Float64Array::try_new(data_buffer, null_buffer)?) + } + DataType::Date32 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Date32Array::try_new(data_buffer, null_buffer)?) + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(TimestampMicrosecondArray::try_new( + data_buffer, + null_buffer, + )?) + } + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new( + TimestampMicrosecondArray::try_new(data_buffer, null_buffer)?.with_timezone(tz), + ) + } + DataType::Decimal128(p, s) => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new( + Decimal128Array::try_new(data_buffer, null_buffer)? + .with_precision_and_scale(p, s)?, + ) + } + DataType::Utf8 => { + let buffer = self.read_buffer(); + let offset_buffer = self.read_offset_buffer(); + let null_buffer = self.read_null_buffer(); + Arc::new(StringArray::try_new(offset_buffer, buffer, null_buffer)?) + } + DataType::Binary => { + let buffer = self.read_buffer(); + let offset_buffer = self.read_offset_buffer(); + let null_buffer = self.read_null_buffer(); + Arc::new(BinaryArray::try_new(offset_buffer, buffer, null_buffer)?) + } + DataType::Dictionary(k, _) if *k == DataType::Int32 => { + let k = self.read_array()?; + let v = self.read_array()?; + Arc::new(DictionaryArray::try_new( + k.as_primitive::().to_owned(), + v, + )?) + } + other => { + return Err(DataFusionError::Internal(format!( + "unsupported type in fast reader {other}" + ))) + } + }) + } + + fn read_data_type(&mut self) -> Result { + let type_id = self.input[self.offset] as i32; + let data_type = match type_id { + x if x == DataTypeId::Boolean as i32 => DataType::Boolean, + x if x == DataTypeId::Int8 as i32 => DataType::Int8, + x if x == DataTypeId::Int16 as i32 => DataType::Int16, + x if x == DataTypeId::Int32 as i32 => DataType::Int32, + x if x == DataTypeId::Int64 as i32 => DataType::Int64, + x if x == DataTypeId::Float32 as i32 => DataType::Float32, + x if x == DataTypeId::Float64 as i32 => DataType::Float64, + x if x == DataTypeId::Date32 as i32 => DataType::Date32, + x if x == DataTypeId::TimestampNtz as i32 => { + DataType::Timestamp(TimeUnit::Microsecond, None) + } + x if x == DataTypeId::Timestamp as i32 => { + self.offset += 1; + let tz = self.read_string(); + let tz: Arc = Arc::from(tz.into_boxed_str()); + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) + } + x if x == DataTypeId::Utf8 as i32 => DataType::Utf8, + x if x == DataTypeId::Binary as i32 => DataType::Binary, + x if x == DataTypeId::Dictionary as i32 => { + self.offset += 1; + DataType::Dictionary( + Box::new(self.read_data_type()?), + Box::new(self.read_data_type()?), + ) + } + x if x == DataTypeId::Decimal128 as i32 => DataType::Decimal128( + self.input[self.offset + 1], + self.input[self.offset + 2] as i8, + ), + other => { + return Err(DataFusionError::Internal(format!( + "unsupported type {other}" + ))) + } + }; + match data_type { + DataType::Dictionary(_, _) | DataType::Timestamp(_, Some(_)) => { + // no need to increment + } + DataType::Decimal128(_, _) => self.offset += 3, + _ => self.offset += 1, + } + Ok(data_type) + } + + fn read_bool(&mut self) -> bool { + let value = self.input[self.offset] != 0; + self.offset += 1; + value + } + + fn read_string(&mut self) -> String { + // read field name length + self.length + .copy_from_slice(&self.input[self.offset..self.offset + 8]); + let field_name_len = usize::from_le_bytes(self.length); + self.offset += 8; + + // read field name + let field_name_bytes = &self.input[self.offset..self.offset + field_name_len]; + let str = unsafe { String::from_utf8_unchecked(field_name_bytes.into()) }; + self.offset += field_name_len; + str + } + + fn read_offset_buffer(&mut self) -> OffsetBuffer { + let offset_buffer = self.read_buffer(); + let offset_buffer: ScalarBuffer = ScalarBuffer::from(offset_buffer); + OffsetBuffer::new(offset_buffer) + } + + fn read_buffer(&mut self) -> Buffer { + // read data buffer length + let mut length = [0; 8]; + length.copy_from_slice(&self.input[self.offset..self.offset + 8]); + let buffer_len = usize::from_le_bytes(length); + self.offset += 8; + + // read data buffer + let buffer = Buffer::from(&self.input[self.offset..self.offset + buffer_len]); + self.offset += buffer_len; + buffer + } + + fn read_null_buffer(&mut self) -> Option { + // read null buffer length in bits + let mut length = [0; 8]; + length.copy_from_slice(&self.input[self.offset..self.offset + 8]); + let length_bits = usize::from_le_bytes(length); + self.offset += 8; + if length_bits == 0 { + return None; + } + + // read buffer length in bytes + length.copy_from_slice(&self.input[self.offset..self.offset + 8]); + let null_buffer_length = usize::from_le_bytes(length); + self.offset += 8; + + let null_buffer = if null_buffer_length != 0 { + let null_buffer = &self.input[self.offset..self.offset + null_buffer_length]; + Some(NullBuffer::new(BooleanBuffer::new( + Buffer::from(null_buffer), + 0, + length_bits, + ))) + } else { + None + }; + self.offset += null_buffer_length; + null_buffer + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow_array::builder::*; + use std::sync::Arc; + + #[test] + fn roundtrip() { + let batch = create_batch(8192, true); + let buffer = Vec::new(); + let mut writer = BatchWriter::new(buffer); + writer.write_partial_schema(&batch.schema()).unwrap(); + writer.write_batch(&batch).unwrap(); + let buffer = writer.inner(); + + let mut reader = BatchReader::new(&buffer); + let batch2 = reader.read_batch().unwrap(); + assert_eq!(batch, batch2); + } + + fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("bool", DataType::Boolean, true), + Field::new("int8", DataType::Int8, true), + Field::new("int16", DataType::Int16, true), + Field::new("int32", DataType::Int32, true), + Field::new("int64", DataType::Int64, true), + Field::new("float32", DataType::Float32, true), + Field::new("float64", DataType::Float64, true), + Field::new("binary", DataType::Binary, true), + Field::new("utf8", DataType::Utf8, true), + Field::new( + "utf8_dict", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + Field::new("date32", DataType::Date32, true), + Field::new("decimal128", DataType::Decimal128(11, 2), true), + Field::new( + "timestamp_ntz", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + true, + ), + ])); + let mut col_bool = BooleanBuilder::with_capacity(num_rows); + let mut col_i8 = Int8Builder::new(); + let mut col_i16 = Int16Builder::new(); + let mut col_i32 = Int32Builder::new(); + let mut col_i64 = Int64Builder::new(); + let mut col_f32 = Float32Builder::new(); + let mut col_f64 = Float64Builder::new(); + let mut col_binary = BinaryBuilder::new(); + let mut col_utf8 = StringBuilder::new(); + let mut col_utf8_dict: StringDictionaryBuilder = StringDictionaryBuilder::new(); + let mut col_date32 = Date32Builder::new(); + let mut col_decimal128 = Decimal128Builder::new() + .with_precision_and_scale(11, 2) + .unwrap(); + let mut col_timestamp_ntz = TimestampMicrosecondBuilder::with_capacity(num_rows); + let mut col_timestamp = + TimestampMicrosecondBuilder::with_capacity(num_rows).with_timezone("UTC"); + for i in 0..num_rows { + col_i8.append_value(i as i8); + col_i16.append_value(i as i16); + col_i32.append_value(i as i32); + col_i64.append_value(i as i64); + col_f32.append_value(i as f32 * 1.23_f32); + col_f64.append_value(i as f64 * 1.23_f64); + col_date32.append_value(i as i32); + col_decimal128.append_value((i * 1000000) as i128); + col_binary.append_value(format!("{i}").as_bytes()); + if allow_nulls && i % 10 == 0 { + col_utf8.append_null(); + col_utf8_dict.append_null(); + col_bool.append_null(); + col_timestamp_ntz.append_null(); + col_timestamp.append_null(); + } else { + // test for dictionary-encoded strings + col_utf8.append_value(format!("this is string {i}")); + col_utf8_dict.append_value("this string is repeated a lot"); + col_bool.append_value(i % 2 == 0); + col_timestamp_ntz.append_value((i * 100000000) as i64); + col_timestamp.append_value((i * 100000000) as i64); + } + } + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(col_bool.finish()), + Arc::new(col_i8.finish()), + Arc::new(col_i16.finish()), + Arc::new(col_i32.finish()), + Arc::new(col_i64.finish()), + Arc::new(col_f32.finish()), + Arc::new(col_f64.finish()), + Arc::new(col_binary.finish()), + Arc::new(col_utf8.finish()), + Arc::new(col_utf8_dict.finish()), + Arc::new(col_date32.finish()), + Arc::new(col_decimal128.finish()), + Arc::new(col_timestamp_ntz.finish()), + Arc::new(col_timestamp.finish()), + ], + ) + .unwrap() + } +} diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs index 178aff1fad..716034a610 100644 --- a/native/core/src/execution/shuffle/mod.rs +++ b/native/core/src/execution/shuffle/mod.rs @@ -15,10 +15,13 @@ // specific language governing permissions and limitations // under the License. +mod codec; mod list; mod map; pub mod row; mod shuffle_writer; +pub use codec::BatchWriter; + pub use shuffle_writer::{ - read_ipc_compressed, write_ipc_compressed, CompressionCodec, ShuffleWriterExec, + read_ipc_compressed, CompressionCodec, ShuffleBlockWriter, ShuffleWriterExec, }; diff --git a/native/core/src/execution/shuffle/row.rs b/native/core/src/execution/shuffle/row.rs index 9037bd7943..54a9bb31fb 100644 --- a/native/core/src/execution/shuffle/row.rs +++ b/native/core/src/execution/shuffle/row.rs @@ -23,7 +23,7 @@ use crate::{ shuffle::{ list::{append_list_element, SparkUnsafeArray}, map::{append_map_elements, get_map_key_value_dt, SparkUnsafeMap}, - shuffle_writer::{write_ipc_compressed, Checksum}, + shuffle_writer::{Checksum, ShuffleBlockWriter}, }, utils::bytes_to_i128, }, @@ -3298,6 +3298,7 @@ pub fn process_sorted_row_partition( // inside the loop within the method across batches. initial_checksum: Option, codec: &CompressionCodec, + enable_fast_encoding: bool, ) -> Result<(i64, Option), CometError> { // TODO: We can tune this parameter automatically based on row size and cache size. let row_step = 10; @@ -3360,7 +3361,12 @@ pub fn process_sorted_row_partition( // we do not collect metrics in Native_writeSortedFileNative let ipc_time = Time::default(); - written += write_ipc_compressed(&batch, &mut cursor, codec, &ipc_time)?; + let block_writer = ShuffleBlockWriter::try_new( + batch.schema().as_ref(), + enable_fast_encoding, + codec.clone(), + )?; + written += block_writer.write_batch(&batch, &mut cursor, &ipc_time)?; if let Some(checksum) = &mut current_checksum { checksum.update(&mut cursor)?; diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs index e183276d18..70e832a739 100644 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ b/native/core/src/execution/shuffle/shuffle_writer.rs @@ -17,6 +17,8 @@ //! Defines the External shuffle repartition plan. +use crate::execution::shuffle::codec::{fast_codec_supports_type, BatchReader}; +use crate::execution::shuffle::BatchWriter; use crate::{ common::bit::ceil, errors::{CometError, CometResult}, @@ -91,15 +93,23 @@ pub struct ShuffleWriterExec { output_index_file: String, /// Metrics metrics: ExecutionPlanMetricsSet, + /// Cache for expensive-to-compute plan properties cache: PlanProperties, + /// The compression codec to use when compressing shuffle blocks codec: CompressionCodec, + /// When true, Comet will use a fast proprietary encoding rather than using Arrow IPC + enable_fast_encoding: bool, } impl DisplayAs for ShuffleWriterExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "ShuffleWriterExec: partitioning={:?}", self.partitioning) + write!( + f, + "ShuffleWriterExec: partitioning={:?}, fast_encoding={}, compression={:?}", + self.partitioning, self.enable_fast_encoding, self.codec + ) } } } @@ -132,6 +142,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.codec.clone(), self.output_data_file.clone(), self.output_index_file.clone(), + self.enable_fast_encoding, )?)), _ => panic!("ShuffleWriterExec wrong number of children"), } @@ -157,6 +168,7 @@ impl ExecutionPlan for ShuffleWriterExec { metrics, context, self.codec.clone(), + self.enable_fast_encoding, ) .map_err(|e| ArrowError::ExternalError(Box::new(e))), ) @@ -189,6 +201,7 @@ impl ShuffleWriterExec { codec: CompressionCodec, output_data_file: String, output_index_file: String, + enable_fast_encoding: bool, ) -> Result { let cache = PlanProperties::new( EquivalenceProperties::new(Arc::clone(&input.schema())), @@ -205,6 +218,7 @@ impl ShuffleWriterExec { output_index_file, cache, codec, + enable_fast_encoding, }) } } @@ -225,22 +239,25 @@ struct PartitionBuffer { batch_size: usize, /// Memory reservation for this partition buffer. reservation: MemoryReservation, - codec: CompressionCodec, + /// Writer that performs encoding and compression + shuffle_block_writer: ShuffleBlockWriter, } impl PartitionBuffer { - fn new( + fn try_new( schema: SchemaRef, batch_size: usize, partition_id: usize, runtime: &Arc, codec: CompressionCodec, - ) -> Self { + enable_fast_encoding: bool, + ) -> Result { let reservation = MemoryConsumer::new(format!("PartitionBuffer[{}]", partition_id)) .with_can_spill(true) .register(&runtime.memory_pool); - - Self { + let shuffle_block_writer = + ShuffleBlockWriter::try_new(schema.as_ref(), enable_fast_encoding, codec)?; + Ok(Self { schema, frozen: vec![], active: vec![], @@ -248,8 +265,8 @@ impl PartitionBuffer { num_active_rows: 0, batch_size, reservation, - codec, - } + shuffle_block_writer, + }) } /// Initializes active builders if necessary. @@ -353,12 +370,8 @@ impl PartitionBuffer { let frozen_capacity_old = self.frozen.capacity(); let mut cursor = Cursor::new(&mut self.frozen); cursor.seek(SeekFrom::End(0))?; - write_ipc_compressed( - &frozen_batch, - &mut cursor, - &self.codec, - &metrics.encode_time, - )?; + self.shuffle_block_writer + .write_batch(&frozen_batch, &mut cursor, &metrics.encode_time)?; mem_diff += (self.frozen.capacity() - frozen_capacity_old) as isize; Ok(mem_diff) @@ -699,7 +712,7 @@ impl ShuffleRepartitionerMetrics { impl ShuffleRepartitioner { #[allow(clippy::too_many_arguments)] - pub fn new( + pub fn try_new( partition_id: usize, output_data_file: String, output_index_file: String, @@ -709,7 +722,8 @@ impl ShuffleRepartitioner { runtime: Arc, batch_size: usize, codec: CompressionCodec, - ) -> Self { + enable_fast_encoding: bool, + ) -> Result { let num_output_partitions = partitioning.partition_count(); let reservation = MemoryConsumer::new(format!("ShuffleRepartitioner[{}]", partition_id)) .with_can_spill(true) @@ -725,21 +739,22 @@ impl ShuffleRepartitioner { partition_ids.set_len(batch_size); } - Self { + Ok(Self { output_data_file, output_index_file, schema: Arc::clone(&schema), buffered_partitions: (0..num_output_partitions) .map(|partition_id| { - PartitionBuffer::new( + PartitionBuffer::try_new( Arc::clone(&schema), batch_size, partition_id, &runtime, codec.clone(), + enable_fast_encoding, ) }) - .collect::>(), + .collect::>>()?, spills: Mutex::new(vec![]), partitioning, num_output_partitions, @@ -749,7 +764,7 @@ impl ShuffleRepartitioner { hashes_buf, partition_ids, batch_size, - } + }) } /// Shuffles rows in input batch into corresponding partition buffer. @@ -1172,9 +1187,10 @@ async fn external_shuffle( metrics: ShuffleRepartitionerMetrics, context: Arc, codec: CompressionCodec, + enable_fast_encoding: bool, ) -> Result { let schema = input.schema(); - let mut repartitioner = ShuffleRepartitioner::new( + let mut repartitioner = ShuffleRepartitioner::try_new( partition_id, output_data_file, output_index_file, @@ -1184,7 +1200,8 @@ async fn external_shuffle( context.runtime_env(), context.session_config().batch_size(), codec, - ); + enable_fast_encoding, + )?; while let Some(batch) = input.next().await { // Block on the repartitioner to insert the batch and shuffle the rows @@ -1570,110 +1587,244 @@ pub enum CompressionCodec { Snappy, } -/// Writes given record batch as Arrow IPC bytes into given writer. -/// Returns number of bytes written. -pub fn write_ipc_compressed( - batch: &RecordBatch, - output: &mut W, - codec: &CompressionCodec, - ipc_time: &Time, -) -> Result { - if batch.num_rows() == 0 { - return Ok(0); - } +pub struct ShuffleBlockWriter { + fast_encoding: bool, + codec: CompressionCodec, + encoded_schema: Vec, + header_bytes: Vec, +} - let mut timer = ipc_time.timer(); - let start_pos = output.stream_position()?; +impl ShuffleBlockWriter { + pub fn try_new( + schema: &Schema, + enable_fast_encoding: bool, + codec: CompressionCodec, + ) -> Result { + let mut encoded_schema = vec![]; + + let enable_fast_encoding = enable_fast_encoding + && schema + .fields() + .iter() + .all(|f| fast_codec_supports_type(f.data_type())); + + // encode the schema once and then reuse the encoded bytes for each batch + if enable_fast_encoding { + let mut w = BatchWriter::new(&mut encoded_schema); + w.write_partial_schema(schema)?; + } - // seek past ipc_length placeholder - output.seek_relative(8)?; + let header_bytes = Vec::with_capacity(24); + let mut cursor = Cursor::new(header_bytes); - // write number of columns because JVM side needs to know how many addresses to allocate - let field_count = batch.schema().fields().len(); - output.write_all(&field_count.to_le_bytes())?; + // leave space for compressed message length + cursor.seek_relative(8)?; - let output = match codec { - CompressionCodec::None => { - output.write_all(b"NONE")?; - let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - arrow_writer.into_inner()? - } - CompressionCodec::Snappy => { - output.write_all(b"SNAP")?; - let mut wtr = snap::write::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - wtr.into_inner() - .map_err(|e| DataFusionError::Execution(format!("lz4 compression error: {}", e)))? - } - CompressionCodec::Lz4Frame => { - output.write_all(b"LZ4_")?; - let mut wtr = lz4_flex::frame::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - wtr.finish() - .map_err(|e| DataFusionError::Execution(format!("lz4 compression error: {}", e)))? - } - CompressionCodec::Zstd(level) => { - output.write_all(b"ZSTD")?; - let encoder = zstd::Encoder::new(output, *level)?; - let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - let zstd_encoder = arrow_writer.into_inner()?; - zstd_encoder.finish()? + // write number of columns because JVM side needs to know how many addresses to allocate + let field_count = schema.fields().len(); + cursor.write_all(&field_count.to_le_bytes())?; + + // write compression codec to header + let codec_header = match &codec { + CompressionCodec::Snappy => b"SNAP", + CompressionCodec::Lz4Frame => b"LZ4_", + CompressionCodec::Zstd(_) => b"ZSTD", + CompressionCodec::None => b"NONE", + }; + cursor.write_all(codec_header)?; + + // write encoding scheme + if enable_fast_encoding { + cursor.write_all(b"FAST")?; + } else { + cursor.write_all(b"AIPC")?; } - }; - // fill ipc length - let end_pos = output.stream_position()?; - let ipc_length = end_pos - start_pos - 8; - let max_size = i32::MAX as u64; - if ipc_length > max_size { - return Err(DataFusionError::Execution(format!( - "Shuffle block size {ipc_length} exceeds maximum size of {max_size}. \ - Try reducing batch size or increasing compression level" - ))); + let header_bytes = cursor.into_inner(); + + Ok(Self { + fast_encoding: enable_fast_encoding, + codec, + encoded_schema, + header_bytes, + }) } - // fill ipc length - output.seek(SeekFrom::Start(start_pos))?; - output.write_all(&ipc_length.to_le_bytes()[..])?; - output.seek(SeekFrom::Start(end_pos))?; + /// Writes given record batch as Arrow IPC bytes into given writer. + /// Returns number of bytes written. + pub fn write_batch( + &self, + batch: &RecordBatch, + output: &mut W, + ipc_time: &Time, + ) -> Result { + if batch.num_rows() == 0 { + return Ok(0); + } + + let mut timer = ipc_time.timer(); + let start_pos = output.stream_position()?; - timer.stop(); + // write header + output.write_all(&self.header_bytes)?; - Ok((end_pos - start_pos) as usize) + let output = if self.fast_encoding { + match &self.codec { + CompressionCodec::None => { + let mut fast_writer = BatchWriter::new(&mut *output); + fast_writer.write_all(&self.encoded_schema)?; + fast_writer.write_batch(batch)?; + output + } + CompressionCodec::Lz4Frame => { + let mut wtr = lz4_flex::frame::FrameEncoder::new(output); + let mut fast_writer = BatchWriter::new(&mut wtr); + fast_writer.write_all(&self.encoded_schema)?; + fast_writer.write_batch(batch)?; + wtr.finish().map_err(|e| { + DataFusionError::Execution(format!("lz4 compression error: {}", e)) + })? + } + CompressionCodec::Zstd(level) => { + let mut encoder = zstd::Encoder::new(output, *level)?; + let mut fast_writer = BatchWriter::new(&mut encoder); + fast_writer.write_all(&self.encoded_schema)?; + fast_writer.write_batch(batch)?; + encoder.finish()? + } + CompressionCodec::Snappy => { + let mut encoder = snap::write::FrameEncoder::new(output); + let mut fast_writer = BatchWriter::new(&mut encoder); + fast_writer.write_all(&self.encoded_schema)?; + fast_writer.write_batch(batch)?; + encoder.into_inner().map_err(|e| { + DataFusionError::Execution(format!("snappy compression error: {}", e)) + })? + } + } + } else { + match &self.codec { + CompressionCodec::None => { + let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + arrow_writer.into_inner()? + } + CompressionCodec::Lz4Frame => { + let mut wtr = lz4_flex::frame::FrameEncoder::new(output); + let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + wtr.finish().map_err(|e| { + DataFusionError::Execution(format!("lz4 compression error: {}", e)) + })? + } + + CompressionCodec::Zstd(level) => { + let encoder = zstd::Encoder::new(output, *level)?; + let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + let zstd_encoder = arrow_writer.into_inner()?; + zstd_encoder.finish()? + } + + CompressionCodec::Snappy => { + let mut wtr = snap::write::FrameEncoder::new(output); + let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + wtr.into_inner().map_err(|e| { + DataFusionError::Execution(format!("snappy compression error: {}", e)) + })? + } + } + }; + + // fill ipc length + let end_pos = output.stream_position()?; + let ipc_length = end_pos - start_pos - 8; + let max_size = i32::MAX as u64; + if ipc_length > max_size { + return Err(DataFusionError::Execution(format!( + "Shuffle block size {ipc_length} exceeds maximum size of {max_size}. \ + Try reducing batch size or increasing compression level" + ))); + } + + // fill ipc length + output.seek(SeekFrom::Start(start_pos))?; + output.write_all(&ipc_length.to_le_bytes())?; + output.seek(SeekFrom::Start(end_pos))?; + + timer.stop(); + + Ok((end_pos - start_pos) as usize) + } } pub fn read_ipc_compressed(bytes: &[u8]) -> Result { + let fast_encoding = match &bytes[4..8] { + b"AIPC" => false, + b"FAST" => true, + other => { + return Err(DataFusionError::Internal(format!( + "invalid encoding schema: {other:?}" + ))) + } + }; match &bytes[0..4] { b"SNAP" => { - let decoder = snap::read::FrameDecoder::new(&bytes[4..]); - let mut reader = StreamReader::try_new(decoder, None)?; - reader.next().unwrap().map_err(|e| e.into()) + let mut decoder = snap::read::FrameDecoder::new(&bytes[8..]); + if fast_encoding { + // TODO avoid reading bytes into interim buffer + let mut buffer = vec![]; + decoder.read_to_end(&mut buffer)?; + let mut reader = BatchReader::new(&buffer); + reader.read_batch() + } else { + let mut reader = StreamReader::try_new(decoder, None)?; + reader.next().unwrap().map_err(|e| e.into()) + } } b"LZ4_" => { - let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); - let mut reader = StreamReader::try_new(decoder, None)?; - reader.next().unwrap().map_err(|e| e.into()) + let mut decoder = lz4_flex::frame::FrameDecoder::new(&bytes[8..]); + if fast_encoding { + // TODO avoid reading bytes into interim buffer + let mut buffer = vec![]; + decoder.read_to_end(&mut buffer)?; + let mut reader = BatchReader::new(&buffer); + reader.read_batch() + } else { + let mut reader = StreamReader::try_new(decoder, None)?; + reader.next().unwrap().map_err(|e| e.into()) + } } b"ZSTD" => { - let decoder = zstd::Decoder::new(&bytes[4..])?; - let mut reader = StreamReader::try_new(decoder, None)?; - reader.next().unwrap().map_err(|e| e.into()) + let mut decoder = zstd::Decoder::new(&bytes[8..])?; + if fast_encoding { + // TODO avoid reading bytes into interim buffer + let mut buffer = vec![]; + decoder.read_to_end(&mut buffer)?; + let mut reader = BatchReader::new(&buffer); + reader.read_batch() + } else { + let mut reader = StreamReader::try_new(decoder, None)?; + reader.next().unwrap().map_err(|e| e.into()) + } } b"NONE" => { - let mut reader = StreamReader::try_new(&bytes[4..], None)?; - reader.next().unwrap().map_err(|e| e.into()) + if fast_encoding { + let mut reader = BatchReader::new(&bytes[8..]); + reader.read_batch() + } else { + let mut reader = StreamReader::try_new(&bytes[8..], None)?; + reader.next().unwrap().map_err(|e| e.into()) + } } - _ => Err(DataFusionError::Execution( - "Failed to decode batch: invalid compression codec".to_string(), - )), + other => Err(DataFusionError::Execution(format!( + "Failed to decode batch: invalid compression codec: {other:?}" + ))), } } @@ -1728,21 +1879,30 @@ mod test { #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` fn roundtrip_ipc() { let batch = create_batch(8192); - for codec in &[ - CompressionCodec::None, - CompressionCodec::Zstd(1), - CompressionCodec::Snappy, - CompressionCodec::Lz4Frame, - ] { - let mut output = vec![]; - let mut cursor = Cursor::new(&mut output); - let length = - write_ipc_compressed(&batch, &mut cursor, codec, &Time::default()).unwrap(); - assert_eq!(length, output.len()); - - let ipc_without_length_prefix = &output[16..]; - let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap(); - assert_eq!(batch, batch2); + for fast_encoding in [true, false] { + for codec in &[ + CompressionCodec::None, + CompressionCodec::Zstd(1), + CompressionCodec::Snappy, + CompressionCodec::Lz4Frame, + ] { + let mut output = vec![]; + let mut cursor = Cursor::new(&mut output); + let writer = ShuffleBlockWriter::try_new( + batch.schema().as_ref(), + fast_encoding, + codec.clone(), + ) + .unwrap(); + let length = writer + .write_batch(&batch, &mut cursor, &Time::default()) + .unwrap(); + assert_eq!(length, output.len()); + + let ipc_without_length_prefix = &output[16..]; + let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap(); + assert_eq!(batch, batch2); + } } } @@ -1819,6 +1979,7 @@ mod test { CompressionCodec::Zstd(1), "/tmp/data.out".to_string(), "/tmp/index.out".to_string(), + true, ) .unwrap(); diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index cab511faff..4180751477 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -23,6 +23,7 @@ // The clippy throws an error if the reference clone not wrapped into `Arc::clone` // The lint makes easier for code reader/reviewer separate references clones from more heavyweight ones #![deny(clippy::clone_on_ref_ptr)] +extern crate core; use jni::{ objects::{JClass, JString}, diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index a3480086c7..c6f310d65c 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -95,6 +95,7 @@ message ShuffleWriter { string output_index_file = 4; CompressionCodec codec = 5; int32 compression_level = 6; + bool enable_fast_encoding = 7; } enum AggregateMode { diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java index a4f09b4158..f3e814ab04 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java @@ -182,6 +182,7 @@ protected long doSpilling( long start = System.nanoTime(); int batchSize = (int) CometConf.COMET_COLUMNAR_SHUFFLE_BATCH_SIZE().get(); + boolean enableFastEncoding = (boolean) CometConf.COMET_SHUFFLE_ENABLE_FAST_ENCODING().get(); long[] results = nativeLib.writeSortedFileNative( addresses, @@ -194,7 +195,8 @@ protected long doSpilling( checksumAlgo, currentChecksum, compressionCodec, - compressionLevel); + compressionLevel, + enableFastEncoding); long written = results[0]; checksum = results[1]; diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index dbcab15b4f..7a8c061a25 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -139,7 +139,8 @@ class Native extends NativeBase { checksumAlgo: Int, currentChecksum: Long, compressionCodec: String, - compressionLevel: Int): Array[Long] + compressionLevel: Int, + enableFastEncoding: Boolean): Array[Long] // scalastyle:on /** diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 041411b3f0..7c27fb1969 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -555,6 +555,8 @@ class CometShuffleWriteProcessor( val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() shuffleWriterBuilder.setOutputDataFile(dataFile) shuffleWriterBuilder.setOutputIndexFile(indexFile) + shuffleWriterBuilder.setEnableFastEncoding( + CometConf.COMET_SHUFFLE_ENABLE_FAST_ENCODING.get()) if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala index 2839c9bd8c..b8c1669494 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala @@ -144,6 +144,8 @@ case class NativeBatchDecoderIterator( } var dataBuf = threadLocalDataBuf.get() if (dataBuf.capacity() < bytesToRead) { + // it is unlikely that we would overflow here since it would + // require a 1GB compressed shuffle block but we check anyway val newCapacity = (bytesToRead * 2L).min(Integer.MAX_VALUE).toInt dataBuf = ByteBuffer.allocateDirect(newCapacity) threadLocalDataBuf.set(dataBuf) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 8170230bc6..615ca591a3 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -108,7 +108,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { val cometShuffles = collect(df2.queryExecution.executedPlan) { case _: CometShuffleExchangeExec => true } - if (shuffleMode == "jvm") { + if (shuffleMode == "jvm" || shuffleMode == "auto") { assert(cometShuffles.length == 1) } else { // we fall back to Spark for shuffle because we do not support