diff --git a/src/common/src/catalog/schema.rs b/src/common/src/catalog/schema.rs
index e942bfcf4d479..3a10c9743cf33 100644
--- a/src/common/src/catalog/schema.rs
+++ b/src/common/src/catalog/schema.rs
@@ -12,8 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashMap;
 use std::ops::Index;
 
+use arrow_schema::{DataType as ArrowDataType, Schema as ArrowSchema};
 use itertools::Itertools;
 use risingwave_pb::plan_common::{PbColumnDesc, PbField};
 
@@ -21,7 +23,6 @@ use super::ColumnDesc;
 use crate::array::ArrayBuilderImpl;
 use crate::types::{DataType, StructType};
 use crate::util::iter_util::ZipEqFast;
-
 /// The field in the schema of the executor's return data
 #[derive(Clone, PartialEq, Eq, Hash)]
 pub struct Field {
@@ -197,6 +198,32 @@ impl Schema {
             true
         }
     }
+
+    /// Check if this schema matches `arrow_schema` by field name and data type, in any order.
+    pub fn same_as_arrow_schema(&self, arrow_schema: &ArrowSchema) -> bool {
+        if self.fields.len() != arrow_schema.fields().len() {
+            return false;
+        }
+        let mut schema_fields = HashMap::new();
+        self.fields.iter().for_each(|field| {
+            let res = schema_fields.insert(&field.name, &field.data_type);
+            // This assert is to make sure there is no duplicate field name in the schema.
+            assert!(res.is_none())
+        });
+
+        arrow_schema.fields().iter().all(|arrow_field| {
+            schema_fields
+                .get(arrow_field.name())
+                .and_then(|data_type| {
+                    if let Ok(data_type) = TryInto::<ArrowDataType>::try_into(*data_type) && data_type == *arrow_field.data_type() {
+                        Some(())
+                    } else {
+                        None
+                    }
+                })
+                .is_some()
+        })
+    }
 }
 
 impl Field {
@@ -328,3 +355,38 @@ pub mod test_utils {
         decimal_n::<3>()
     }
 }
+
+#[cfg(test)]
+mod test {
+    #[test]
+    fn test_same_as_arrow_schema() {
+        use arrow_schema::{DataType as ArrowDataType, Field as ArrowField};
+
+        use super::*;
+        let risingwave_schema = Schema::new(vec![
+            Field::with_name(DataType::Int32, "a"),
+            Field::with_name(DataType::Int32, "b"),
+            Field::with_name(DataType::Int32, "c"),
+        ]);
+        let arrow_schema = ArrowSchema::new(vec![
+            ArrowField::new("a", ArrowDataType::Int32, false),
+            ArrowField::new("b", ArrowDataType::Int32, false),
+            ArrowField::new("c", ArrowDataType::Int32, false),
+        ]);
+        assert!(risingwave_schema.same_as_arrow_schema(&arrow_schema));
+
+        let risingwave_schema = Schema::new(vec![
+            Field::with_name(DataType::Int32, "d"),
+            Field::with_name(DataType::Int32, "c"),
+            Field::with_name(DataType::Int32, "a"),
+            Field::with_name(DataType::Int32, "b"),
+        ]);
+        let arrow_schema = ArrowSchema::new(vec![
+            ArrowField::new("a", ArrowDataType::Int32, false),
+            ArrowField::new("b", ArrowDataType::Int32, false),
+            ArrowField::new("d", ArrowDataType::Int32, false),
+            ArrowField::new("c", ArrowDataType::Int32, false),
+        ]);
+        assert!(risingwave_schema.same_as_arrow_schema(&arrow_schema));
+    }
+}