From c4e97f49f545cbaf987751fe3c214ca213ebbefc Mon Sep 17 00:00:00 2001
From: "Carol (Nichols || Goulding)"
Date: Fri, 6 Nov 2020 16:29:36 -0500
Subject: [PATCH 1/3] [Rust] Only return dict_id/dict_is_ordered values for
 Dictionary types

I had submitted this before but I think it got lost in a rebase
somewhere. I think this is more correct and informative.
---
 rust/arrow/src/datatypes.rs                        | 45 ++++++++++++++-----
 rust/arrow/src/ipc/convert.rs                      |  8 +++-
 rust/arrow/src/ipc/writer.rs                       |  8 +++-
 .../src/bin/arrow-json-integration-test.rs         | 15 +++++--
 4 files changed, 59 insertions(+), 17 deletions(-)

diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs
index 0a26d2e5fd283..e95bc230f85e6 100644
--- a/rust/arrow/src/datatypes.rs
+++ b/rust/arrow/src/datatypes.rs
@@ -1196,16 +1196,22 @@ impl Field {
         self.nullable
     }
 
-    /// Returns the dictionary ID
+    /// Returns the dictionary ID, if this is a dictionary type
     #[inline]
-    pub const fn dict_id(&self) -> i64 {
-        self.dict_id
+    pub const fn dict_id(&self) -> Option<i64> {
+        match self.data_type {
+            DataType::Dictionary(_, _) => Some(self.dict_id),
+            _ => None,
+        }
     }
 
-    /// Indicates whether this `Field`'s dictionary is ordered
+    /// Returns whether this `Field`'s dictionary is ordered, if this is a dictionary type
     #[inline]
-    pub const fn dict_is_ordered(&self) -> bool {
-        self.dict_is_ordered
+    pub const fn dict_is_ordered(&self) -> Option<bool> {
+        match self.data_type {
+            DataType::Dictionary(_, _) => Some(self.dict_is_ordered),
+            _ => None,
+        }
     }
 
     /// Parse a `Field` definition from a JSON representation
@@ -2512,7 +2518,8 @@ mod tests {
              last_name: Utf8, \
              address: Struct([\
              Field { name: \"street\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false }, \
-             Field { name: \"zip\", data_type: UInt16, nullable: false, dict_id: 0, dict_is_ordered: false }])")
+             Field { name: \"zip\", data_type: UInt16, nullable: false, dict_id: 0, dict_is_ordered: false }]), \
+             interests: Dictionary(Int32, Utf8)")
     }
 
     #[test]
@@ -2520,18 +2527,29 @@ fn schema_accessors() {
         let schema = person_schema();
 
         // test schema accessors
-        assert_eq!(schema.fields().len(), 3);
+        assert_eq!(schema.fields().len(), 4);
 
         // test field accessors
         let first_name = &schema.fields()[0];
         assert_eq!(first_name.name(), "first_name");
         assert_eq!(first_name.data_type(), &DataType::Utf8);
         assert_eq!(first_name.is_nullable(), false);
+        assert_eq!(first_name.dict_id(), None);
+        assert_eq!(first_name.dict_is_ordered(), None);
+
+        let interests = &schema.fields()[3];
+        assert_eq!(interests.name(), "interests");
+        assert_eq!(
+            interests.data_type(),
+            &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))
+        );
+        assert_eq!(interests.dict_id(), Some(123));
+        assert_eq!(interests.dict_is_ordered(), Some(true));
     }
 
     #[test]
     #[should_panic(
-        expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\"]"
+        expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]"
     )]
     fn schema_index_of() {
         let schema = person_schema();
@@ -2542,7 +2560,7 @@ mod tests {
 
     #[test]
     #[should_panic(
-        expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\"]"
+        expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]"
     )]
     fn schema_field_with_name() {
         let schema = person_schema();
@@ -2622,6 +2640,13 @@ mod tests {
                 ]),
                 false,
             ),
+            Field::new_dict(
+                "interests",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                true,
+                123,
+                true,
+            ),
         ])
     }
 
diff --git a/rust/arrow/src/ipc/convert.rs b/rust/arrow/src/ipc/convert.rs
index 5c5544297a121..c5d54bca7f213 100644
--- a/rust/arrow/src/ipc/convert.rs
+++ b/rust/arrow/src/ipc/convert.rs
@@ -291,8 +291,12 @@ pub(crate) fn build_field<'a>(
     let fb_dictionary = if let Dictionary(index_type, _) = field.data_type() {
         Some(get_fb_dictionary(
             index_type,
-            field.dict_id(),
-            field.dict_is_ordered(),
+            field
+                .dict_id()
+                .expect("All Dictionary types have `dict_id`"),
+            field
+                .dict_is_ordered()
+                .expect("All Dictionary types have `dict_is_ordered`"),
             fbb,
         ))
     } else {
diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs
index cf32eaa2096b5..d6a52a62c5d20 100644
--- a/rust/arrow/src/ipc/writer.rs
+++ b/rust/arrow/src/ipc/writer.rs
@@ -180,7 +180,9 @@ impl FileWriter {
             let column = batch.column(i);
 
             if let DataType::Dictionary(_key_type, _value_type) = column.data_type() {
-                let dict_id = field.dict_id();
+                let dict_id = field
+                    .dict_id()
+                    .expect("All Dictionary types have `dict_id`");
 
                 let dict_data = column.data();
                 let dict_values = &dict_data.child_data()[0];
@@ -317,7 +319,9 @@ impl StreamWriter {
             let column = batch.column(i);
 
             if let DataType::Dictionary(_key_type, _value_type) = column.data_type() {
-                let dict_id = field.dict_id();
+                let dict_id = field
+                    .dict_id()
+                    .expect("All Dictionary types have `dict_id`");
 
                 let dict_data = column.data();
                 let dict_values = &dict_data.child_data()[0];
diff --git a/rust/integration-testing/src/bin/arrow-json-integration-test.rs b/rust/integration-testing/src/bin/arrow-json-integration-test.rs
index 72a113fcdc2bd..b1bec677cf163 100644
--- a/rust/integration-testing/src/bin/arrow-json-integration-test.rs
+++ b/rust/integration-testing/src/bin/arrow-json-integration-test.rs
@@ -489,7 +489,12 @@ fn array_from_json(
             Ok(Arc::new(array))
         }
         DataType::Dictionary(key_type, value_type) => {
-            let dict_id = field.dict_id();
+            let dict_id = field.dict_id().ok_or_else(|| {
+                ArrowError::JsonError(format!(
+                    "Unable to find dict_id for field {:?}",
+                    field
+                ))
+            })?;
             // find dictionary
             let dictionary = dictionaries
                 .ok_or_else(|| {
@@ -539,8 +544,12 @@ fn dictionary_array_from_json(
         "key",
         dict_key.clone(),
         field.is_nullable(),
-        field.dict_id(),
-        field.dict_is_ordered(),
+        field
+            .dict_id()
+            .expect("Dictionary fields must have a dict_id value"),
+        field
+            .dict_is_ordered()
+            .expect("Dictionary fields must have a dict_is_ordered value"),
     );
     let keys = array_from_json(&key_field, json_col, None)?;
     // note: not enough info on nullability of dictionary

From 29ba56c6b335de6123c0700d1cbf04ca9748ae0d Mon Sep 17 00:00:00 2001
From: "Carol (Nichols || Goulding)"
Date: Wed, 2 Dec 2020 12:48:17 -0500
Subject: [PATCH 2/3] [Rust] Enable looking up Schema Fields by dict_id

---
 rust/arrow/src/datatypes.rs | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs
index e95bc230f85e6..c4b01091d8076 100644
--- a/rust/arrow/src/datatypes.rs
+++ b/rust/arrow/src/datatypes.rs
@@ -1648,6 +1648,15 @@ impl Schema {
         Ok(&self.fields[self.index_of(name)?])
     }
 
+    /// Returns a vector of immutable references to all `Field` instances selected by
+    /// the dictionary ID they use
+    pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> {
+        self.fields
+            .iter()
+            .filter(|f| f.dict_id() == Some(dict_id))
+            .collect()
+    }
+
     /// Find the index of the column with the given name
     pub fn index_of(&self, name: &str) -> Result<usize> {
         for i in 0..self.fields.len() {
@@ -2575,6 +2584,20 @@ mod tests {
         schema.field_with_name("nickname").unwrap();
     }
 
+    #[test]
+    fn schema_field_with_dict_id() {
+        let schema = person_schema();
+
+        let fields_dict_123: Vec<_> = schema
+            .fields_with_dict_id(123)
+            .iter()
+            .map(|f| f.name())
+            .collect();
+        assert_eq!(fields_dict_123, vec!["interests"]);
+
+        assert!(schema.fields_with_dict_id(456).is_empty());
+    }
+
     #[test]
     fn schema_equality() {
         let schema1 = Schema::new(vec![

From 222dd64a122b62c1b064dd5260e8d2b657370a97 Mon Sep 17 00:00:00 2001
From: "Carol (Nichols || Goulding)"
Date: Wed, 2 Dec 2020 13:30:45 -0500
Subject: [PATCH 3/3] [Rust] Use Schema for dict info rather than also keeping
 ipc_schema

This simplifies the code in `read_dictionary` and `StreamReader`; the
`ipc_schema` gets parsed and converted into a `Schema` that `StreamReader`
also holds onto, so use the `Schema` to find which fields use dictionaries
rather than keeping both around.
---
 rust/arrow/src/ipc/reader.rs | 59 +++++++-----------------------------
 1 file changed, 11 insertions(+), 48 deletions(-)

diff --git a/rust/arrow/src/ipc/reader.rs b/rust/arrow/src/ipc/reader.rs
index 76ad6b77cf3db..1b4119c9d964f 100644
--- a/rust/arrow/src/ipc/reader.rs
+++ b/rust/arrow/src/ipc/reader.rs
@@ -463,7 +463,6 @@ pub fn read_record_batch(
 fn read_dictionary(
     buf: &[u8],
     batch: ipc::DictionaryBatch,
-    ipc_schema: &ipc::Schema,
     schema: &Schema,
     dictionaries_by_field: &mut [Option<ArrayRef>],
 ) -> Result<()> {
@@ -474,15 +473,15 @@
     }
 
     let id = batch.id();
-
-    // As the dictionary batch does not contain the type of the
-    // values array, we need to retrieve this from the schema.
-    let first_field = find_dictionary_field(ipc_schema, id).ok_or_else(|| {
+    let fields_using_this_dictionary = schema.fields_with_dict_id(id);
+    let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
         ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string())
     })?;
 
+    // As the dictionary batch does not contain the type of the
+    // values array, we need to retrieve this from the schema.
     // Get an array representing this dictionary's values.
-    let dictionary_values: ArrayRef = match schema.field(first_field).data_type() {
+    let dictionary_values: ArrayRef = match first_field.data_type() {
         DataType::Dictionary(_, ref value_type) => {
             // Make a fake schema for the dictionary batch.
             let schema = Schema {
@@ -508,33 +507,16 @@
     // in the reader. Note that a dictionary batch may be shared between many fields.
     // We don't currently record the isOrdered field. This could be general
     // attributes of arrays.
-    let fields = ipc_schema.fields().unwrap();
-    for (i, field) in fields.iter().enumerate() {
-        if let Some(dictionary) = field.dictionary() {
-            if dictionary.id() == id {
-                // Add (possibly multiple) array refs to the dictionaries array.
-                dictionaries_by_field[i] = Some(dictionary_values.clone());
-            }
+    for (i, field) in schema.fields().iter().enumerate() {
+        if field.dict_id() == Some(id) {
+            // Add (possibly multiple) array refs to the dictionaries array.
+            dictionaries_by_field[i] = Some(dictionary_values.clone());
         }
     }
 
     Ok(())
 }
 
-// Linear search for the first dictionary field with a dictionary id.
-fn find_dictionary_field(ipc_schema: &ipc::Schema, id: i64) -> Option<usize> {
-    let fields = ipc_schema.fields().unwrap();
-    for i in 0..fields.len() {
-        let field: ipc::Field = fields.get(i);
-        if let Some(dictionary) = field.dictionary() {
-            if dictionary.id() == id {
-                return Some(i);
-            }
-        }
-    }
-    None
-}
-
 /// Arrow File reader
 pub struct FileReader {
     /// Buffered file reader that supports reading and seeking
@@ -639,13 +621,7 @@ impl FileReader {
                     ))?;
                     reader.read_exact(&mut buf)?;
 
-                    read_dictionary(
-                        &buf,
-                        batch,
-                        &ipc_schema,
-                        &schema,
-                        &mut dictionaries_by_field,
-                    )?;
+                    read_dictionary(&buf, batch, &schema, &mut dictionaries_by_field)?;
                 }
                 t => {
                     return Err(ArrowError::IoError(format!(
@@ -781,11 +757,6 @@ pub struct StreamReader {
     /// The schema that is read from the stream's first message
     schema: SchemaRef,
 
-    /// The bytes of the IPC schema that is read from the stream's first message
-    ///
-    /// This is kept in order to interpret dictionary data
-    ipc_schema: Vec<u8>,
-
     /// Optional dictionaries for each schema field.
     ///
     /// Dictionaries may be appended to in the streaming format.
@@ -833,7 +804,6 @@ impl StreamReader {
         Ok(Self {
            reader,
             schema: Arc::new(schema),
-            ipc_schema: meta_buffer,
             finished: false,
             dictionaries_by_field,
         })
@@ -918,15 +888,8 @@ impl StreamReader {
             let mut buf = vec![0; message.bodyLength() as usize];
             self.reader.read_exact(&mut buf)?;
 
-            let ipc_schema = ipc::get_root_as_message(&self.ipc_schema).header_as_schema()
-                .ok_or_else(|| {
-                    ArrowError::IoError(
-                        "Unable to read schema from stored message header".to_string(),
-                    )
-                })?;
-
             read_dictionary(
-                &buf, batch, &ipc_schema, &self.schema, &mut self.dictionaries_by_field
+                &buf, batch, &self.schema, &mut self.dictionaries_by_field
             )?;
 
             // read the next message until we encounter a RecordBatch