Commit e3709ea

fix: Optimize read_side_padding (apache#772)
## Which issue does this PR close?

## Rationale for this change

This PR improves `read_side_padding`, which is used for the CHAR() schema.

## What changes are included in this PR?

Optimized `spark_read_side_padding`.

## How are these changes tested?

Added tests.
1 parent 607ee7d commit e3709ea
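
The core of the change is replacing per-row `String` allocation and grapheme segmentation with a single pre-sized `GenericStringBuilder`, writing each padded value directly into the builder's value buffer. Below is a minimal standalone sketch of that builder pattern, assuming arrow-rs's `GenericStringBuilder` API; the values and capacities are illustrative only and are not taken from the commit.

```rust
// Standalone sketch of the arrow GenericStringBuilder pattern the new
// implementation relies on; values here are illustrative only.
use std::fmt::Write;

use arrow_array::builder::GenericStringBuilder;
use arrow_array::Array;

fn main() -> Result<(), std::fmt::Error> {
    // Pre-size the builder: item capacity and an estimate of the value-buffer size.
    let mut builder = GenericStringBuilder::<i32>::with_capacity(3, 3 * 5);

    // write_str only appends bytes to the value buffer; the entry is not
    // finalized until append_value (or append_null) is called, so the two
    // calls below produce a single element "ab   " without an intermediate
    // String allocation.
    builder.write_str("ab")?;
    builder.append_value("   ");

    builder.append_value("abcdef"); // already long enough, stored as-is
    builder.append_null();          // nulls are tracked in the null buffer

    let array = builder.finish();
    assert_eq!(array.value(0), "ab   ");
    assert_eq!(array.value(1), "abcdef");
    assert!(array.is_null(2));
    Ok(())
}
```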

2 files changed: +32 -31 lines changed


Cargo.toml (-1)
@@ -41,7 +41,6 @@ chrono-tz = { workspace = true }
 num = { workspace = true }
 regex = { workspace = true }
 thiserror = { workspace = true }
-unicode-segmentation = "1.11.0"
 
 [dev-dependencies]
 arrow-data = {workspace = true}

src/scalar_funcs.rs (+32 -30)
@@ -15,15 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{cmp::min, sync::Arc};
-
 use arrow::{
     array::{
-        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
-        Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
+        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
+        Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
+use arrow_array::builder::GenericStringBuilder;
 use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
 use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
@@ -35,7 +34,8 @@ use num::{
     integer::{div_ceil, div_floor},
     BigInt, Signed, ToPrimitive,
 };
-use unicode_segmentation::UnicodeSegmentation;
+use std::fmt::Write;
+use std::{cmp::min, sync::Arc};
 
 mod unhex;
 pub use unhex::spark_unhex;
@@ -387,52 +387,54 @@ pub fn spark_round(
 }
 
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
-pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            match args[0].data_type() {
-                DataType::Utf8 => spark_rpad_internal::<i32>(array, *length),
-                DataType::LargeUtf8 => spark_rpad_internal::<i64>(array, *length),
+            match array.data_type() {
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
+                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
                 // TODO: handle Dictionary types
                 other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad",
+                    "Unsupported data type {other:?} for function read_side_padding",
                 ))),
             }
         }
         other => Err(DataFusionError::Internal(format!(
-            "Unsupported arguments {other:?} for function rpad",
+            "Unsupported arguments {other:?} for function read_side_padding",
         ))),
     }
 }
 
-fn spark_rpad_internal<T: OffsetSizeTrait>(
+fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     length: i32,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
+
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
 
-    let result = string_array
-        .iter()
-        .map(|string| match string {
+    for string in string_array.iter() {
+        match string {
             Some(string) => {
-                let length = if length < 0 { 0 } else { length as usize };
-                if length == 0 {
-                    Ok(Some("".to_string()))
+                // It looks Spark's UTF8String is closer to chars rather than graphemes
+                // https://stackoverflow.com/a/46290728
+                let char_len = string.chars().count();
+                if length <= char_len {
+                    builder.append_value(string);
                 } else {
-                    let graphemes = string.graphemes(true).collect::<Vec<&str>>();
-                    if length < graphemes.len() {
-                        Ok(Some(string.to_string()))
-                    } else {
-                        let mut s = string.to_string();
-                        s.push_str(" ".repeat(length - graphemes.len()).as_str());
-                        Ok(Some(s))
-                    }
+                    // write_str updates only the value buffer, not null nor offset buffer
+                    // This is convenient for concatenating str(s)
+                    builder.write_str(string)?;
+                    builder.append_value(&space_string[char_len..]);
                 }
             }
-            _ => Ok(None),
-        })
-        .collect::<Result<GenericStringArray<T>, DataFusionError>>()?;
-    Ok(ColumnarValue::Array(Arc::new(result)))
+            _ => builder.append_null(),
+        }
+    }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
 }
 
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
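
For illustration, a hypothetical unit test of the kind the commit message refers to might look like the sketch below. The function name and signature come from the diff above; the inputs, expected outputs, import paths, and test name are assumptions, and the sketch presumes `spark_read_side_padding` is in scope (for example, in a test module of `src/scalar_funcs.rs`). It exercises the no-truncation rule and the chars-versus-graphemes point noted in the inline comment.

```rust
// Hypothetical test sketch; not the tests actually added by this commit.
use std::sync::Arc;

use arrow_array::{cast::AsArray, Array, StringArray};
use datafusion::physical_plan::ColumnarValue;
use datafusion::scalar::ScalarValue;

#[test]
fn read_side_padding_pads_by_char_count() {
    let input = ColumnarValue::Array(Arc::new(StringArray::from(vec![
        Some("ab"),     // shorter than 5 chars: padded with 3 trailing spaces
        Some("abcdef"), // longer than 5 chars: left unchanged (no truncation)
        None,           // null: preserved
        Some("héllo"),  // 5 chars but 6 bytes: counted by chars, so no padding
    ])));
    let length = ColumnarValue::Scalar(ScalarValue::Int32(Some(5)));

    let result = spark_read_side_padding(&[input, length]).unwrap();
    let ColumnarValue::Array(array) = result else {
        panic!("expected an array result")
    };
    let strings = array.as_string::<i32>();

    assert_eq!(strings.value(0), "ab   ");
    assert_eq!(strings.value(1), "abcdef");
    assert!(strings.is_null(2));
    assert_eq!(strings.value(3), "héllo");
}
```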
