Skip to content

Commit

Permalink
Add RowParser (#3174)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold authored Nov 24, 2022
1 parent cea5146 commit 1d22fe3
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 21 deletions.
97 changes: 82 additions & 15 deletions arrow/src/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,12 @@ impl RowConverter {
})
.collect::<Result<Vec<_>>>()?;

let mut rows = new_empty_rows(columns, &dictionaries, Arc::clone(&self.fields));
let config = RowConfig {
fields: Arc::clone(&self.fields),
// Don't need to validate UTF-8 as came from arrow array
validate_utf8: false,
};
let mut rows = new_empty_rows(columns, &dictionaries, config);

for ((column, field), dictionary) in
columns.iter().zip(self.fields.iter()).zip(dictionaries)
Expand Down Expand Up @@ -465,14 +470,15 @@ impl RowConverter {
where
I: IntoIterator<Item = Row<'a>>,
{
let mut validate_utf8 = false;
let mut rows: Vec<_> = rows
.into_iter()
.map(|row| {
assert!(
Arc::ptr_eq(row.fields, &self.fields),
Arc::ptr_eq(&row.config.fields, &self.fields),
"rows were not produced by this RowConverter"
);

validate_utf8 |= row.config.validate_utf8;
row.data
})
.collect();
Expand All @@ -484,11 +490,18 @@ impl RowConverter {
// SAFETY
// We have validated that the rows came from this [`RowConverter`]
// and therefore must be valid
unsafe { decode_column(field, &mut rows, interner.as_deref()) }
unsafe {
decode_column(field, &mut rows, interner.as_deref(), validate_utf8)
}
})
.collect()
}

/// Returns a [`RowParser`] that can be used to parse [`Row`] from bytes
pub fn parser(&self) -> RowParser {
RowParser::new(Arc::clone(&self.fields))
}

/// Returns the size of this instance in bytes
///
/// Includes the size of `Self`.
Expand All @@ -505,6 +518,43 @@ impl RowConverter {
}
}

/// A [`RowParser`] can be created from a [`RowConverter`] and used to parse bytes to [`Row`]
#[derive(Debug)]
pub struct RowParser {
config: RowConfig,
}

impl RowParser {
fn new(fields: Arc<[SortField]>) -> Self {
Self {
config: RowConfig {
fields,
validate_utf8: true,
},
}
}

/// Creates a [`Row`] from the provided `bytes`.
///
/// `bytes` must be a [`Row`] produced by the [`RowConverter`] associated with
/// this [`RowParser`], otherwise subsequent operations with the produced [`Row`] may panic
pub fn parse<'a>(&'a self, bytes: &'a [u8]) -> Row<'a> {
Row {
data: bytes,
config: &self.config,
}
}
}

/// The config of a given set of [`Row`]
#[derive(Debug, Clone)]
struct RowConfig {
/// The schema for these rows
fields: Arc<[SortField]>,
/// Whether to run UTF-8 validation when converting to arrow arrays
validate_utf8: bool,
}

/// A row-oriented representation of arrow data, that is normalized for comparison.
///
/// See the [module level documentation](self) and [`RowConverter`] for more details.
Expand All @@ -514,8 +564,8 @@ pub struct Rows {
buffer: Box<[u8]>,
/// Row `i` has data `&buffer[offsets[i]..offsets[i+1]]`
offsets: Box<[usize]>,
/// The schema for these rows
fields: Arc<[SortField]>,
/// The config for these rows
config: RowConfig,
}

impl Rows {
Expand All @@ -524,7 +574,7 @@ impl Rows {
let start = self.offsets[row];
Row {
data: &self.buffer[start..end],
fields: &self.fields,
config: &self.config,
}
}

Expand Down Expand Up @@ -614,15 +664,15 @@ impl<'a> DoubleEndedIterator for RowsIter<'a> {
#[derive(Debug, Copy, Clone)]
pub struct Row<'a> {
data: &'a [u8],
fields: &'a Arc<[SortField]>,
config: &'a RowConfig,
}

impl<'a> Row<'a> {
/// Create owned version of the row to detach it from the shared [`Rows`].
pub fn owned(&self) -> OwnedRow {
OwnedRow {
data: self.data.to_vec(),
fields: Arc::clone(self.fields),
config: self.config.clone(),
}
}
}
Expand Down Expand Up @@ -672,7 +722,7 @@ impl<'a> AsRef<[u8]> for Row<'a> {
#[derive(Debug, Clone)]
pub struct OwnedRow {
data: Vec<u8>,
fields: Arc<[SortField]>,
config: RowConfig,
}

impl OwnedRow {
Expand All @@ -682,7 +732,7 @@ impl OwnedRow {
pub fn row(&self) -> Row<'_> {
Row {
data: &self.data,
fields: &self.fields,
config: &self.config,
}
}
}
Expand Down Expand Up @@ -739,7 +789,7 @@ fn null_sentinel(options: SortOptions) -> u8 {
fn new_empty_rows(
cols: &[ArrayRef],
dictionaries: &[Option<Vec<Option<&[u8]>>>],
fields: Arc<[SortField]>,
config: RowConfig,
) -> Rows {
use fixed::FixedLengthEncoding;

Expand Down Expand Up @@ -816,7 +866,7 @@ fn new_empty_rows(
Rows {
buffer: buffer.into(),
offsets: offsets.into(),
fields,
config,
}
}

Expand Down Expand Up @@ -872,6 +922,7 @@ unsafe fn decode_column(
field: &SortField,
rows: &mut [&[u8]],
interner: Option<&OrderPreservingInterner>,
validate_utf8: bool,
) -> Result<ArrayRef> {
let options = field.options;
let data_type = field.data_type.clone();
Expand All @@ -881,8 +932,8 @@ unsafe fn decode_column(
DataType::Boolean => Arc::new(decode_bool(rows, options)),
DataType::Binary => Arc::new(decode_binary::<i32>(rows, options)),
DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options)),
DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options)),
DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options, validate_utf8)),
DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options, validate_utf8)),
DataType::Dictionary(k, v) => match k.as_ref() {
DataType::Int8 => Arc::new(decode_dictionary::<Int8Type>(
interner.unwrap(),
Expand Down Expand Up @@ -1373,6 +1424,22 @@ mod tests {
assert!(rows.row(3) < rows.row(0));
}

#[test]
#[should_panic(expected = "Invalid UTF8 sequence at string")]
fn test_invalid_utf8() {
let mut converter =
RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap();
let array = Arc::new(BinaryArray::from_iter_values([&[0xFF]])) as _;
let rows = converter.convert_columns(&[array]).unwrap();
let binary_row = rows.row(0);

let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap();
let parser = converter.parser();
let utf8_row = parser.parse(binary_row.as_ref());

converter.convert_rows(std::iter::once(utf8_row)).unwrap();
}

#[test]
#[should_panic(expected = "rows were not produced by this RowConverter")]
fn test_different_converter() {
Expand Down
14 changes: 8 additions & 6 deletions arrow/src/row/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,16 +214,18 @@ pub fn decode_binary<I: OffsetSizeTrait>(
pub unsafe fn decode_string<I: OffsetSizeTrait>(
rows: &mut [&[u8]],
options: SortOptions,
validate_utf8: bool,
) -> GenericStringArray<I> {
let d = match I::IS_LARGE {
true => DataType::LargeUtf8,
false => DataType::Utf8,
};
let decoded = decode_binary::<I>(rows, options);

if validate_utf8 {
return GenericStringArray::from(decoded);
}

let builder = decode_binary::<I>(rows, options)
let builder = decoded
.into_data()
.into_builder()
.data_type(d);
.data_type(GenericStringArray::<I>::DATA_TYPE);

// SAFETY:
// Row data must have come from a valid UTF-8 array
Expand Down

0 comments on commit 1d22fe3

Please sign in to comment.