ARROW-10791: [Rust] StreamReader, read_dictionary duplicating schema info #8820

Closed
68 changes: 58 additions & 10 deletions rust/arrow/src/datatypes.rs
@@ -1196,16 +1196,22 @@ impl Field {
         self.nullable
     }
 
-    /// Returns the dictionary ID
+    /// Returns the dictionary ID, if this is a dictionary type
     #[inline]
-    pub const fn dict_id(&self) -> i64 {
-        self.dict_id
+    pub const fn dict_id(&self) -> Option<i64> {
+        match self.data_type {
+            DataType::Dictionary(_, _) => Some(self.dict_id),
+            _ => None,
+        }
     }
 
-    /// Indicates whether this `Field`'s dictionary is ordered
+    /// Returns whether this `Field`'s dictionary is ordered, if this is a dictionary type
     #[inline]
-    pub const fn dict_is_ordered(&self) -> bool {
-        self.dict_is_ordered
+    pub const fn dict_is_ordered(&self) -> Option<bool> {
+        match self.data_type {
+            DataType::Dictionary(_, _) => Some(self.dict_is_ordered),
+            _ => None,
+        }
     }
 
     /// Parse a `Field` definition from a JSON representation
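
As a quick illustration of the new accessor signatures (not part of the diff; the field name and dictionary ID below are made up), a caller can now tell dictionary-encoded fields apart from plain ones instead of reading a meaningless default of 0/false:

use arrow::datatypes::{DataType, Field};

fn describe(field: &Field) {
    // Option-returning accessors: None for non-dictionary fields.
    match (field.dict_id(), field.dict_is_ordered()) {
        (Some(id), Some(ordered)) => {
            println!("{} uses dictionary {} (ordered: {})", field.name(), id, ordered)
        }
        _ => println!("{} is not dictionary-encoded", field.name()),
    }
}

fn main() {
    let plain = Field::new("tag", DataType::Utf8, false);
    let dict = Field::new_dict(
        "tag",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        true,
        42,    // dict_id (illustrative)
        false, // dict_is_ordered
    );
    describe(&plain); // "tag is not dictionary-encoded"
    describe(&dict);  // "tag uses dictionary 42 (ordered: false)"
}
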
@@ -1642,6 +1648,15 @@ impl Schema
         Ok(&self.fields[self.index_of(name)?])
     }
 
+    /// Returns a vector of immutable references to all `Field` instances selected by
+    /// the dictionary ID they use
+    pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> {
+        self.fields
+            .iter()
+            .filter(|f| f.dict_id() == Some(dict_id))
+            .collect()
+    }
+
     /// Find the index of the column with the given name
     pub fn index_of(&self, name: &str) -> Result<usize> {
         for i in 0..self.fields.len() {
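
A brief sketch of how `fields_with_dict_id` can be used (again not part of the diff; the schema below is illustrative). Because several columns may share one dictionary, the method returns all matching fields rather than just the first:

use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8));
    // Two columns share dictionary id 1; one column is not dictionary-encoded.
    let schema = Schema::new(vec![
        Field::new_dict("country", dict_type.clone(), false, 1, false),
        Field::new_dict("language", dict_type, false, 1, false),
        Field::new("population", DataType::UInt64, false),
    ]);

    let shared: Vec<&str> = schema
        .fields_with_dict_id(1)
        .iter()
        .map(|f| f.name().as_str())
        .collect();
    assert_eq!(shared, vec!["country", "language"]);
    assert!(schema.fields_with_dict_id(2).is_empty());
}
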
Expand Down Expand Up @@ -2512,26 +2527,38 @@ mod tests {
last_name: Utf8, \
address: Struct([\
Field { name: \"street\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false }, \
Field { name: \"zip\", data_type: UInt16, nullable: false, dict_id: 0, dict_is_ordered: false }])")
Field { name: \"zip\", data_type: UInt16, nullable: false, dict_id: 0, dict_is_ordered: false }]), \
interests: Dictionary(Int32, Utf8)")
}

#[test]
fn schema_field_accessors() {
let schema = person_schema();

// test schema accessors
assert_eq!(schema.fields().len(), 3);
assert_eq!(schema.fields().len(), 4);

// test field accessors
let first_name = &schema.fields()[0];
assert_eq!(first_name.name(), "first_name");
assert_eq!(first_name.data_type(), &DataType::Utf8);
assert_eq!(first_name.is_nullable(), false);
assert_eq!(first_name.dict_id(), None);
assert_eq!(first_name.dict_is_ordered(), None);

let interests = &schema.fields()[3];
assert_eq!(interests.name(), "interests");
assert_eq!(
interests.data_type(),
&DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))
);
assert_eq!(interests.dict_id(), Some(123));
assert_eq!(interests.dict_is_ordered(), Some(true));
}

#[test]
#[should_panic(
expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\"]"
expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]"
)]
fn schema_index_of() {
let schema = person_schema();
@@ -2542,7 +2569,7 @@ mod tests {
 
     #[test]
     #[should_panic(
-        expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\"]"
+        expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]"
     )]
     fn schema_field_with_name() {
         let schema = person_schema();
@@ -2557,6 +2584,20 @@ mod tests {
         schema.field_with_name("nickname").unwrap();
     }
 
+    #[test]
+    fn schema_field_with_dict_id() {
+        let schema = person_schema();
+
+        let fields_dict_123: Vec<_> = schema
+            .fields_with_dict_id(123)
+            .iter()
+            .map(|f| f.name())
+            .collect();
+        assert_eq!(fields_dict_123, vec!["interests"]);
+
+        assert!(schema.fields_with_dict_id(456).is_empty());
+    }
+
     #[test]
     fn schema_equality() {
         let schema1 = Schema::new(vec![
@@ -2622,6 +2663,13 @@ mod tests {
                 ]),
                 false,
             ),
+            Field::new_dict(
+                "interests",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                true,
+                123,
+                true,
+            ),
         ])
     }
 
8 changes: 6 additions & 2 deletions rust/arrow/src/ipc/convert.rs
@@ -291,8 +291,12 @@ pub(crate) fn build_field<'a>(
     let fb_dictionary = if let Dictionary(index_type, _) = field.data_type() {
         Some(get_fb_dictionary(
             index_type,
-            field.dict_id(),
-            field.dict_is_ordered(),
+            field
+                .dict_id()
+                .expect("All Dictionary types have `dict_id`"),
+            field
+                .dict_is_ordered()
+                .expect("All Dictionary types have `dict_is_ordered`"),
             fbb,
         ))
     } else {
59 changes: 11 additions & 48 deletions rust/arrow/src/ipc/reader.rs
@@ -463,7 +463,6 @@ pub fn read_record_batch(
 fn read_dictionary(
     buf: &[u8],
     batch: ipc::DictionaryBatch,
-    ipc_schema: &ipc::Schema,
     schema: &Schema,
     dictionaries_by_field: &mut [Option<ArrayRef>],
 ) -> Result<()> {
@@ -474,15 +473,15 @@ fn read_dictionary(
     }
 
     let id = batch.id();
-
-    // As the dictionary batch does not contain the type of the
-    // values array, we need to retrieve this from the schema.
-    let first_field = find_dictionary_field(ipc_schema, id).ok_or_else(|| {
+    let fields_using_this_dictionary = schema.fields_with_dict_id(id);
+    let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
         ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string())
     })?;
 
+    // As the dictionary batch does not contain the type of the
+    // values array, we need to retrieve this from the schema.
     // Get an array representing this dictionary's values.
-    let dictionary_values: ArrayRef = match schema.field(first_field).data_type() {
+    let dictionary_values: ArrayRef = match first_field.data_type() {
         DataType::Dictionary(_, ref value_type) => {
             // Make a fake schema for the dictionary batch.
             let schema = Schema {
@@ -508,33 +507,16 @@ fn read_dictionary(
     // in the reader. Note that a dictionary batch may be shared between many fields.
     // We don't currently record the isOrdered field. This could be general
     // attributes of arrays.
-    let fields = ipc_schema.fields().unwrap();
-    for (i, field) in fields.iter().enumerate() {
-        if let Some(dictionary) = field.dictionary() {
-            if dictionary.id() == id {
-                // Add (possibly multiple) array refs to the dictionaries array.
-                dictionaries_by_field[i] = Some(dictionary_values.clone());
-            }
+    for (i, field) in schema.fields().iter().enumerate() {
+        if field.dict_id() == Some(id) {
+            // Add (possibly multiple) array refs to the dictionaries array.
+            dictionaries_by_field[i] = Some(dictionary_values.clone());
         }
     }
 
     Ok(())
 }
 
-// Linear search for the first dictionary field with a dictionary id.
-fn find_dictionary_field(ipc_schema: &ipc::Schema, id: i64) -> Option<usize> {
-    let fields = ipc_schema.fields().unwrap();
-    for i in 0..fields.len() {
-        let field: ipc::Field = fields.get(i);
-        if let Some(dictionary) = field.dictionary() {
-            if dictionary.id() == id {
-                return Some(i);
-            }
-        }
-    }
-    None
-}
-
 /// Arrow File reader
 pub struct FileReader<R: Read + Seek> {
     /// Buffered file reader that supports reading and seeking
Expand Down Expand Up @@ -639,13 +621,7 @@ impl<R: Read + Seek> FileReader<R> {
))?;
reader.read_exact(&mut buf)?;

read_dictionary(
&buf,
batch,
&ipc_schema,
&schema,
&mut dictionaries_by_field,
)?;
read_dictionary(&buf, batch, &schema, &mut dictionaries_by_field)?;
}
t => {
return Err(ArrowError::IoError(format!(
@@ -781,11 +757,6 @@ pub struct StreamReader<R: Read> {
     /// The schema that is read from the stream's first message
     schema: SchemaRef,
 
-    /// The bytes of the IPC schema that is read from the stream's first message
-    ///
-    /// This is kept in order to interpret dictionary data
-    ipc_schema: Vec<u8>,
-
     /// Optional dictionaries for each schema field.
     ///
     /// Dictionaries may be appended to in the streaming format.
@@ -833,7 +804,6 @@ impl<R: Read> StreamReader<R> {
         Ok(Self {
             reader,
             schema: Arc::new(schema),
-            ipc_schema: meta_buffer,
             finished: false,
             dictionaries_by_field,
         })
@@ -918,15 +888,8 @@ impl<R: Read> StreamReader<R> {
                 let mut buf = vec![0; message.bodyLength() as usize];
                 self.reader.read_exact(&mut buf)?;
 
-                let ipc_schema = ipc::get_root_as_message(&self.ipc_schema).header_as_schema()
-                    .ok_or_else(|| {
-                        ArrowError::IoError(
-                            "Unable to read schema from stored message header".to_string(),
-                        )
-                    })?;
-
                 read_dictionary(
-                    &buf, batch, &ipc_schema, &self.schema, &mut self.dictionaries_by_field
+                    &buf, batch, &self.schema, &mut self.dictionaries_by_field
                 )?;
 
                 // read the next message until we encounter a RecordBatch
8 changes: 6 additions & 2 deletions rust/arrow/src/ipc/writer.rs
@@ -180,7 +180,9 @@ impl<W: Write> FileWriter<W> {
             let column = batch.column(i);
 
             if let DataType::Dictionary(_key_type, _value_type) = column.data_type() {
-                let dict_id = field.dict_id();
+                let dict_id = field
+                    .dict_id()
+                    .expect("All Dictionary types have `dict_id`");
                 let dict_data = column.data();
                 let dict_values = &dict_data.child_data()[0];
 
@@ -317,7 +319,9 @@ impl<W: Write> StreamWriter<W> {
             let column = batch.column(i);
 
             if let DataType::Dictionary(_key_type, _value_type) = column.data_type() {
-                let dict_id = field.dict_id();
+                let dict_id = field
+                    .dict_id()
+                    .expect("All Dictionary types have `dict_id`");
                 let dict_data = column.data();
                 let dict_values = &dict_data.child_data()[0];
 
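
To tie the reader and writer changes together, here is a rough round-trip sketch (not part of the diff): a batch with one dictionary-encoded column is written to an in-memory IPC stream and read back, with the dictionary batch resolved against the Arrow `Schema` alone rather than stored IPC schema bytes. The column name and dict_id are illustrative, and the exact extent of dictionary support in the IPC writer depends on the arrow crate version.

use std::sync::Arc;

use arrow::array::DictionaryArray;
use arrow::datatypes::{DataType, Field, Int32Type, Schema};
use arrow::error::Result;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> Result<()> {
    // One dictionary-encoded column with dict_id 1 (illustrative).
    let field = Field::new_dict(
        "city",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        false,
        1,
        false,
    );
    let schema = Arc::new(Schema::new(vec![field]));

    // Build the dictionary array from plain string values.
    let array: DictionaryArray<Int32Type> = vec!["nyc", "sfo", "nyc"].into_iter().collect();
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)])?;

    // Write the batch to an in-memory IPC stream.
    let mut buf = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    // Read it back; read_dictionary now looks the value type up via the Arrow schema.
    let reader = StreamReader::try_new(buf.as_slice())?;
    for maybe_batch in reader {
        assert_eq!(maybe_batch?.num_rows(), 3);
    }
    Ok(())
}
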
15 changes: 12 additions & 3 deletions rust/integration-testing/src/bin/arrow-json-integration-test.rs
@@ -489,7 +489,12 @@ fn array_from_json(
             Ok(Arc::new(array))
         }
         DataType::Dictionary(key_type, value_type) => {
-            let dict_id = field.dict_id();
+            let dict_id = field.dict_id().ok_or_else(|| {
+                ArrowError::JsonError(format!(
+                    "Unable to find dict_id for field {:?}",
+                    field
+                ))
+            })?;
             // find dictionary
             let dictionary = dictionaries
                 .ok_or_else(|| {

Review thread on this hunk:

Contributor: I apologise, I do remember this change, but I'm not sure of what happened; I presume it happened while I was rebasing.

Contributor (author): No worries! I remember I sent you a rebased PR to your branch that was... complicated... so I'm not surprised :) Thanks for the merge!
@@ -539,8 +544,12 @@ fn dictionary_array_from_json(
         "key",
         dict_key.clone(),
         field.is_nullable(),
-        field.dict_id(),
-        field.dict_is_ordered(),
+        field
+            .dict_id()
+            .expect("Dictionary fields must have a dict_id value"),
+        field
+            .dict_is_ordered()
+            .expect("Dictionary fields must have a dict_is_ordered value"),
     );
     let keys = array_from_json(&key_field, json_col, None)?;
     // note: not enough info on nullability of dictionary