Skip to content

Commit

Permalink
Add tests that arrow IPC data is validated (#7096)
Browse files Browse the repository at this point in the history
* Add tests for validating IPC data read/written

Add tests for invalid arrays

* Test with file decoder too

* consolidate test

* Add test for test_validation_of_invalid_primitive_array

* Rework ArrayData validation to return error rather than panic

* Revert "Rework ArrayData validation to return error rather than panic"

This reverts commit 0b88bbc.

* Revert "Add test for test_validation_of_invalid_primitive_array"

This reverts commit 8d885a1.
  • Loading branch information
alamb authored Feb 12, 2025
1 parent 2bce568 commit 78c9df9
Showing 1 changed file with 167 additions and 22 deletions.
189 changes: 167 additions & 22 deletions arrow-ipc/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1484,10 +1484,11 @@ mod tests {

use super::*;

use crate::root_as_message;
use crate::convert::fb_to_schema;
use crate::{root_as_footer, root_as_message};
use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder};
use arrow_array::types::*;
use arrow_buffer::NullBuffer;
use arrow_buffer::{NullBuffer, OffsetBuffer};
use arrow_data::ArrayDataBuilder;

fn create_test_projection_schema() -> Schema {
Expand Down Expand Up @@ -1724,27 +1725,73 @@ mod tests {
});
}

fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
/// Serialize `rb` into an in-memory buffer using the IPC File format
///
/// Panics if the writer cannot be created or the batch cannot be written.
fn write_ipc(rb: &RecordBatch) -> Vec<u8> {
    let mut bytes = Vec::new();
    {
        // The writer mutably borrows `bytes`; the scope ends that borrow
        // before the buffer is returned
        let mut file_writer =
            crate::writer::FileWriter::try_new(&mut bytes, rb.schema_ref()).unwrap();
        file_writer.write(rb).unwrap();
        file_writer.finish().unwrap();
    }
    bytes
}

let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
reader.next().unwrap().unwrap()
/// Decode `buf` as an IPC File and return the first record batch it yields
///
/// Errors from opening the reader or decoding the batch are returned to the
/// caller. Panics if the reader yields no batch at all.
fn read_ipc(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
    let cursor = std::io::Cursor::new(buf);
    let mut file_reader = FileReader::try_new(cursor, None)?;
    let first_batch = file_reader.next();
    first_batch.unwrap()
}

fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
/// Write `rb` in the IPC File format and read it straight back
///
/// Panics if the batch cannot be read back.
fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
    let encoded = write_ipc(rb);
    let decoded = read_ipc(&encoded);
    decoded.unwrap()
}

/// Decode the single record batch stored in an IPC File buffer by driving
/// the `FileDecoder` API by hand instead of going through `FileReader`
///
/// Footer parsing and decoding errors are returned to the caller. Panics if
/// the footer has no schema or no record batch list, if the file does not
/// hold exactly one record batch, or if decoding that block yields no batch.
fn read_ipc_with_decoder(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
    let buffer = Buffer::from_vec(buf);

    // The last 10 bytes of the file encode the footer length; the footer
    // itself sits immediately before them
    let trailer_start = buffer.len() - 10;
    let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
    let footer_start = trailer_start - footer_len;
    let footer = root_as_footer(&buffer[footer_start..trailer_start])
        .map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid footer: {e}")))?;

    let schema = Arc::new(fb_to_schema(footer.schema().unwrap()));
    let mut decoder = FileDecoder::new(schema, footer.version());

    // Dictionaries have to be registered with the decoder before the batch is read
    for dict_block in footer.dictionaries().iter().flatten() {
        let dict_len = dict_block.bodyLength() as usize + dict_block.metaDataLength() as usize;
        let dict_data = buffer.slice_with_length(dict_block.offset() as _, dict_len);
        decoder.read_dictionary(dict_block, &dict_data)?;
    }

    // The callers only ever write a single batch
    let batch_blocks = footer.recordBatches().unwrap();
    assert_eq!(batch_blocks.len(), 1);

    let batch_block = batch_blocks.get(0);
    let batch_len = batch_block.bodyLength() as usize + batch_block.metaDataLength() as usize;
    let batch_data = buffer.slice_with_length(batch_block.offset() as _, batch_len);
    let decoded = decoder.read_record_batch(batch_block, &batch_data)?;
    Ok(decoded.unwrap())
}

/// Serialize `rb` into an in-memory buffer using the IPC Stream format
///
/// Panics if the writer cannot be created or the batch cannot be written.
fn write_stream(rb: &RecordBatch) -> Vec<u8> {
    let mut bytes = Vec::new();
    {
        // The writer mutably borrows `bytes`; the scope ends that borrow
        // before the buffer is returned
        let mut stream_writer =
            crate::writer::StreamWriter::try_new(&mut bytes, rb.schema_ref()).unwrap();
        stream_writer.write(rb).unwrap();
        stream_writer.finish().unwrap();
    }
    bytes
}

/// Decode `buf` as an IPC Stream and return the first record batch it yields
///
/// Errors from opening the reader or decoding the batch are returned to the
/// caller. Panics if the reader yields no batch at all.
fn read_stream(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
    let cursor = std::io::Cursor::new(buf);
    let mut stream_reader = StreamReader::try_new(cursor, None)?;
    let first_batch = stream_reader.next();
    first_batch.unwrap()
}

let mut reader =
crate::reader::StreamReader::try_new(std::io::Cursor::new(buf), None).unwrap();
reader.next().unwrap().unwrap()
/// Write `rb` in the IPC Stream format and read it straight back
///
/// Panics if the batch cannot be read back.
fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
    let encoded = write_stream(rb);
    let decoded = read_stream(&encoded);
    decoded.unwrap()
}

#[test]
Expand Down Expand Up @@ -2403,17 +2450,10 @@ mod tests {
.build_unchecked(),
)
};

let batch = RecordBatch::try_new(schema.clone(), vec![invalid_struct_arr]).unwrap();

let mut buf = Vec::new();
let mut writer = crate::writer::FileWriter::try_new(&mut buf, schema.as_ref()).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();

let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
let err = reader.next().unwrap().unwrap_err();
assert!(matches!(err, ArrowError::InvalidArgumentError(_)));
expect_ipc_validation_error(
Arc::new(invalid_struct_arr),
"Invalid argument error: Incorrect array length for StructArray field \"b\", expected 4 got 3",
);
}

#[test]
Expand Down Expand Up @@ -2472,4 +2512,109 @@ mod tests {
assert_eq!(decoded_batch.expect("Failed to read RecordBatch"), batch);
});
}

/// Reading a ListArray whose offsets are not valid must fail validation
#[test]
fn test_validation_of_invalid_list_array() {
    let values = Int32Array::from(vec![1, 2, 3]);
    let bad_offsets = ScalarBuffer::<i32>::from(vec![0, 2, 4, 2]); // offsets can't go backwards
    // Only this call needs `unsafe`: it is what lets the INVALID offsets
    // into the array, on purpose
    let offsets = unsafe { OffsetBuffer::new_unchecked(bad_offsets) };
    let field = Arc::new(Field::new_list_field(DataType::Int32, true));
    let invalid_list = ListArray::new(field, offsets, Arc::new(values), None);

    expect_ipc_validation_error(
        Arc::new(invalid_list),
        "Invalid argument error: Offset invariant failure: offset at position 2 out of bounds: 4 > 2"
    );
}

/// Reading a StringArray that contains non UTF-8 bytes must fail validation
#[test]
fn test_validation_of_invalid_string_array() {
    // Exactly three bytes of valid data: the expected error below reports the
    // invalid value at offsets 3..45 (3 valid bytes + 38 ASCII bytes + 4
    // invalid bytes). Written as an array repeat so the length does not depend
    // on whitespace inside a literal.
    let valid: &[u8] = &[b' '; 3];
    let mut invalid = vec![];
    invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
    invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
    let binary_array = BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
    // The data is not valid UTF-8, so a StringArray cannot be built from it
    // through the safe API. Build an invalid StringArray on purpose by
    // reusing the binary array's buffers without checking them.
    let array = unsafe {
        StringArray::new_unchecked(
            binary_array.offsets().clone(),
            binary_array.values().clone(),
            binary_array.nulls().cloned(),
        )
    };
    expect_ipc_validation_error(
        Arc::new(array),
        "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..45): invalid utf-8 sequence of 1 bytes from index 38"
    );
}

/// Reading a StringViewArray that contains non UTF-8 bytes must fail validation
#[test]
fn test_validation_of_invalid_string_view_array() {
    let valid: &[u8] = b" ";
    let invalid: Vec<u8> = [
        b"ThisStringIsCertainlyLongerThan12Bytes".as_slice(),
        INVALID_UTF8_FIRST_CHAR,
    ]
    .concat();
    let binary_views =
        BinaryViewArray::from_iter(vec![None, Some(valid), None, Some(invalid.as_slice())]);

    // The data is not valid UTF-8, so a StringViewArray cannot be built from
    // it through the safe API. Build an invalid StringViewArray on purpose by
    // reusing the binary view array's parts without checking them.
    let views = binary_views.views().clone();
    let data_buffers = binary_views.data_buffers().to_vec();
    let nulls = binary_views.nulls().cloned();
    let invalid_views = unsafe { StringViewArray::new_unchecked(views, data_buffers, nulls) };

    expect_ipc_validation_error(
        Arc::new(invalid_views),
        "Invalid argument error: Encountered non-UTF-8 data at index 3: invalid utf-8 sequence of 1 bytes from index 38"
    );
}

/// Reading a DictionaryArray with a key that is out of bounds for its
/// values must fail validation
#[test]
fn test_validation_of_invalid_dictionary_array() {
    let values = StringArray::from_iter_values(["a", "b", "c"]);
    let keys = Int32Array::from(vec![1, 200]); // keys are not valid for values
    // Only this call needs `unsafe`: it is what lets the invalid key into the
    // array, on purpose
    let invalid_dictionary = unsafe { DictionaryArray::new_unchecked(keys, Arc::new(values)) };

    expect_ipc_validation_error(
        Arc::new(invalid_dictionary),
        "Invalid argument error: Value at position 1 out of bounds: 200 (should be in [0, 2])",
    );
}

/// Bytes that are not valid UTF-8: the sequence is already invalid at its
/// first byte
///
/// Appended to otherwise valid ASCII data by the string validation tests to
/// build arrays that must be rejected when read back.
///
/// <https://stackoverflow.com/questions/1301402/example-invalid-utf8-string>
const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20];

/// Assert that `array` can be written but fails with `expected_err` when it
/// is read back
///
/// The array is wrapped in a single column record batch and checked through
/// three read paths: the IPC Stream reader, the IPC File reader and the
/// `FileDecoder` API. Each must report exactly `expected_err`.
fn expect_ipc_validation_error(array: ArrayRef, expected_err: &str) {
    let batch = RecordBatch::try_from_iter([("a", array)]).unwrap();

    // IPC Stream format: writing succeeds, only reading is expected to fail
    let stream_bytes = write_stream(&batch);
    let stream_err = read_stream(&stream_bytes).unwrap_err();
    assert_eq!(stream_err.to_string(), expected_err);

    // IPC File format: writing succeeds, only reading is expected to fail
    let file_bytes = write_ipc(&batch);
    let file_err = read_ipc(&file_bytes).unwrap_err();
    assert_eq!(file_err.to_string(), expected_err);

    // TODO verify there is no error when validation is disabled
    // see https://github.com/apache/arrow-rs/issues/3287

    // IPC File format again, this time decoded with FileDecoder
    let decoder_err = read_ipc_with_decoder(file_bytes).unwrap_err();
    assert_eq!(decoder_err.to_string(), expected_err);
}
}

0 comments on commit 78c9df9

Please sign in to comment.