From a7244f565e9a0ed63ef16368b1bae231ac1033c4 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Fri, 31 May 2024 14:11:30 +0200 Subject: [PATCH 01/21] Add `ExtensionType` for `uuid` and map to parquet logical type --- arrow-schema/src/datatype.rs | 76 +++++++++++++++++++++++++++++++++ arrow-schema/src/field.rs | 33 +++++++++++++- parquet/src/arrow/schema/mod.rs | 6 ++- 3 files changed, 113 insertions(+), 2 deletions(-) diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index 449d363db671..6561f5b8855b 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -672,6 +672,82 @@ impl DataType { } } +/// Canonical extension types. +/// +/// The Arrow columnar format allows defining extension types so as to extend +/// standard Arrow data types with custom semantics. Often these semantics will +/// be specific to a system or application. However, it is beneficial to share +/// the definitions of well-known extension types so as to improve +/// interoperability between different systems integrating Arrow columnar data. +/// +/// https://arrow.apache.org/docs/format/CanonicalExtensions.html +#[non_exhaustive] +#[derive(Debug, Clone)] +pub enum ExtensionType { + /// Extension name: `arrow.uuid`. + /// + /// The storage type of the extension is `FixedSizeBinary` with a length of + /// 16 bytes. + /// + /// Note: + /// A specific UUID version is not required or guaranteed. This extension + /// represents UUIDs as FixedSizeBinary(16) with big-endian notation and + /// does not interpret the bytes in any way. + Uuid, +} + +impl fmt::Display for ExtensionType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +impl ExtensionType { + /// The metadata key for the string name identifying the custom data type. 
+ pub const NAME_KEY: &'static str = "ARROW:extension:name"; + + /// The metadata key for a serialized representation of the ExtensionType + /// necessary to reconstruct the custom type. + pub const METADATA_KEY: &'static str = "ARROW:extension:metadata"; + + /// Returns the name of this extension type. + pub fn name(&self) -> &'static str { + match self { + ExtensionType::Uuid => "arrow.uuid", + } + } + + /// Returns the metadata of this extension type. + pub fn metadata(&self) -> Option { + match self { + ExtensionType::Uuid => None, + } + } + + /// Returns `true` iff the given [`DataType`] can be used as storage type + /// for this extension type. + pub(crate) fn supports_storage_type(&self, data_type: &DataType) -> bool { + match self { + ExtensionType::Uuid => matches!(data_type, DataType::FixedSizeBinary(16)), + } + } + + /// Extract an [`ExtensionType`] from the given [`Field`]. + /// + /// This function returns `None` if the extension type is not supported or + /// recognized. + pub(crate) fn try_from_field(field: &Field) -> Option { + let metadata = field.metadata().get(ExtensionType::METADATA_KEY); + field + .metadata() + .get(ExtensionType::NAME_KEY) + .and_then(|name| match name.as_str() { + "arrow.uuid" if metadata.is_none() => Some(ExtensionType::Uuid), + _ => None, + }) + } +} + /// The maximum precision for [DataType::Decimal128] values pub const DECIMAL128_MAX_PRECISION: u8 = 38; diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index b84a2568ed8a..02a87264a25a 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::datatype::DataType; use crate::schema::SchemaBuilder; -use crate::{Fields, UnionFields, UnionMode}; +use crate::{ExtensionType, Fields, UnionFields, UnionMode}; /// A reference counted [`Field`] pub type FieldRef = Arc; @@ -337,6 +337,37 @@ impl Field { self } + /// Returns the canonical [`ExtensionType`] of this [`Field`], if set. 
+ pub fn extension_type(&self) -> Option { + ExtensionType::try_from_field(self) + } + + /// Updates the metadata of this [`Field`] with the [`ExtensionType::name`] + /// and [`ExtensionType::metadata`] of the given [`ExtensionType`]. + /// + /// # Panics + /// + /// This function panics when the datatype of this field is not a valid + /// storage type for the given extension type. + pub fn with_extension_type(mut self, extension_type: ExtensionType) -> Self { + if extension_type.supports_storage_type(&self.data_type) { + self.metadata.insert( + ExtensionType::NAME_KEY.to_owned(), + extension_type.name().to_owned(), + ); + if let Some(metadata) = extension_type.metadata() { + self.metadata + .insert(ExtensionType::METADATA_KEY.to_owned(), metadata); + } + self + } else { + panic!( + "{extension_type} does not support {} as storage type", + self.data_type + ); + } + } + /// Indicates whether this [`Field`] supports null values. #[inline] pub const fn is_nullable(&self) -> bool { diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 8c583eebac5b..61e14a40a23d 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -29,7 +29,7 @@ use std::collections::HashMap; use std::sync::Arc; use arrow_ipc::writer; -use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; +use arrow_schema::{DataType, ExtensionType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, Type as PhysicalType, @@ -468,6 +468,10 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .with_length(*length) + .with_logical_type(match field.extension_type() { + Some(ExtensionType::Uuid) => Some(LogicalType::Uuid), + _ => None, + }) .build() } DataType::BinaryView => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) From 6b2e7aa8fd759ac8ff86681e4a71886fe9928679 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Fri, 
31 May 2024 15:13:44 +0200 Subject: [PATCH 02/21] Fix docs --- arrow-schema/src/datatype.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index 6561f5b8855b..7e1dafbb7e04 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -680,7 +680,7 @@ impl DataType { /// the definitions of well-known extension types so as to improve /// interoperability between different systems integrating Arrow columnar data. /// -/// https://arrow.apache.org/docs/format/CanonicalExtensions.html +/// #[non_exhaustive] #[derive(Debug, Clone)] pub enum ExtensionType { From bdeab9f47e925a9c06233e502cf743458f158b93 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Thu, 26 Sep 2024 21:10:35 +0200 Subject: [PATCH 03/21] Use an `ExtensionType` trait instead --- arrow-schema/Cargo.toml | 2 +- arrow-schema/src/datatype.rs | 311 +++++++++++++++++++++++++++----- arrow-schema/src/field.rs | 70 +++++-- parquet/src/arrow/schema/mod.rs | 28 ++- 4 files changed, 338 insertions(+), 73 deletions(-) diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml index 628d4a683cac..711543a18677 100644 --- a/arrow-schema/Cargo.toml +++ b/arrow-schema/Cargo.toml @@ -36,6 +36,7 @@ bench = false [dependencies] serde = { version = "1.0", default-features = false, features = ["derive", "std", "rc"], optional = true } bitflags = { version = "2.0.0", default-features = false, optional = true } +serde_json = "1.0" [features] # Enable ffi support @@ -45,5 +46,4 @@ ffi = ["bitflags"] features = ["ffi"] [dev-dependencies] -serde_json = "1.0" bincode = { version = "1.3.3", default-features = false } diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index 7e1dafbb7e04..fb14fe0695ab 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -672,6 +672,76 @@ impl DataType { } } +/// The metadata key for the string name identifying the custom data type. 
+pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; + +/// The metadata key for a serialized representation of the ExtensionType +/// necessary to reconstruct the custom type. +pub const EXTENSION_TYPE_METADATA_KEY: &str = "ARROW:extension:metadata"; + +/// Extension types. +/// +/// +pub trait ExtensionType: Sized { + /// The name of this extension type. + const NAME: &'static str; + + /// The supported storage types of this extension type. + fn storage_types(&self) -> &[DataType]; + + /// The metadata type of this extension type. + type Metadata; + + /// Returns a reference to the metadata of this extension type, or `None` + /// if this extension type has no metadata. + fn metadata(&self) -> Option<&Self::Metadata>; + + /// Returns the serialized representation of the metadata of this extension + /// type, or `None` if this extension type has no metadata. + fn into_serialized_metadata(&self) -> Option; + + /// Deserialize this extension type from the serialized representation of the + /// metadata of this extension. An extension type that has no metadata should + /// expect `None` for for the serialized metadata. + fn from_serialized_metadata(serialized_metadata: Option<&str>) -> Option; +} + +pub(crate) trait ExtensionTypeExt: ExtensionType { + /// Returns `true` if the given data type is supported by this extension + /// type. + fn supports(&self, data_type: &DataType) -> bool { + self.storage_types().contains(data_type) + } + + /// Try to extract this extension type from the given [`Field`]. 
+ /// + /// This function returns `None` if extension type + /// - information is missing + /// - name does not match + /// - metadata deserialization failed + /// - does not support the data type of this field + fn try_from_field(field: &Field) -> Option { + field + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .and_then(|name| { + (name == ::NAME) + .then(|| { + Self::from_serialized_metadata( + field + .metadata() + .get(EXTENSION_TYPE_METADATA_KEY) + .map(String::as_str), + ) + }) + .flatten() + }) + .filter(|extension_type| extension_type.supports(field.data_type())) + } +} + +impl ExtensionTypeExt for T where T: ExtensionType {} + /// Canonical extension types. /// /// The Arrow columnar format allows defining extension types so as to extend @@ -679,11 +749,90 @@ impl DataType { /// be specific to a system or application. However, it is beneficial to share /// the definitions of well-known extension types so as to improve /// interoperability between different systems integrating Arrow columnar data. -/// -/// -#[non_exhaustive] -#[derive(Debug, Clone)] -pub enum ExtensionType { +pub mod canonical_extension_types { + use serde_json::{Map, Value}; + + use super::{DataType, ExtensionType}; + + /// Canonical extension types. + #[non_exhaustive] + #[derive(Debug, Clone, PartialEq)] + pub enum CanonicalExtensionTypes { + /// The extension type for 'JSON'. + Json(Json), + /// The extension type for `UUID`. + Uuid(Uuid), + } + + impl From for CanonicalExtensionTypes { + fn from(value: Json) -> Self { + CanonicalExtensionTypes::Json(value) + } + } + + impl From for CanonicalExtensionTypes { + fn from(value: Uuid) -> Self { + CanonicalExtensionTypes::Uuid(value) + } + } + + /// The extension type for `JSON`. + /// + /// Extension name: `arrow.json`. + /// + /// The storage type of this extension is `String` or `LargeString` or + /// `StringView`. Only UTF-8 encoded JSON as specified in [rfc8259](https://datatracker.ietf.org/doc/html/rfc8259) + /// is supported. 
+ /// + /// This type does not have any parameters. + /// + /// Metadata is either an empty string or a JSON string with an empty + /// object. In the future, additional fields may be added, but they are not + /// required to interpret the array. + /// + /// + #[derive(Debug, Clone, PartialEq)] + pub struct Json(Value); + + impl Default for Json { + fn default() -> Self { + Self(Value::String("".to_owned())) + } + } + + impl ExtensionType for Json { + const NAME: &'static str = "arrow.json"; + + type Metadata = Value; + + fn storage_types(&self) -> &[DataType] { + &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View] + } + + fn metadata(&self) -> Option<&Self::Metadata> { + Some(&self.0) + } + + fn into_serialized_metadata(&self) -> Option { + Some(self.0.to_string()) + } + + fn from_serialized_metadata(serialized_metadata: Option<&str>) -> Option { + serialized_metadata.and_then(|metadata| match metadata { + // Empty string + r#""""# => Some(Default::default()), + // Empty object + value => value + .parse::() + .ok() + .filter(|value| value.as_object().is_some_and(Map::is_empty)) + .map(Self), + }) + } + } + + /// The extension type for `UUID`. + /// /// Extension name: `arrow.uuid`. /// /// The storage type of the extension is `FixedSizeBinary` with a length of @@ -691,60 +840,128 @@ pub enum ExtensionType { /// /// Note: /// A specific UUID version is not required or guaranteed. This extension - /// represents UUIDs as FixedSizeBinary(16) with big-endian notation and + /// represents UUIDs as `FixedSizeBinary(16)` with big-endian notation and /// does not interpret the bytes in any way. 
- Uuid, -} + /// + /// + #[derive(Debug, Default, Clone, Copy, PartialEq)] + pub struct Uuid; -impl fmt::Display for ExtensionType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.name()) - } -} + impl ExtensionType for Uuid { + const NAME: &'static str = "arrow.uuid"; -impl ExtensionType { - /// The metadata key for the string name identifying the custom data type. - pub const NAME_KEY: &'static str = "ARROW:extension:name"; + type Metadata = (); - /// The metadata key for a serialized representation of the ExtensionType - /// necessary to reconstruct the custom type. - pub const METADATA_KEY: &'static str = "ARROW:extension:metadata"; + fn storage_types(&self) -> &[DataType] { + &[DataType::FixedSizeBinary(16)] + } - /// Returns the name of this extension type. - pub fn name(&self) -> &'static str { - match self { - ExtensionType::Uuid => "arrow.uuid", + fn metadata(&self) -> Option<&Self::Metadata> { + None } - } - /// Returns the metadata of this extension type. - pub fn metadata(&self) -> Option { - match self { - ExtensionType::Uuid => None, + fn into_serialized_metadata(&self) -> Option { + None } - } - /// Returns `true` iff the given [`DataType`] can be used as storage type - /// for this extension type. - pub(crate) fn supports_storage_type(&self, data_type: &DataType) -> bool { - match self { - ExtensionType::Uuid => matches!(data_type, DataType::FixedSizeBinary(16)), + fn from_serialized_metadata(serialized_metadata: Option<&str>) -> Option { + serialized_metadata.is_none().then_some(Self) } } - /// Extract an [`ExtensionType`] from the given [`Field`]. - /// - /// This function returns `None` if the extension type is not supported or - /// recognized. 
- pub(crate) fn try_from_field(field: &Field) -> Option { - let metadata = field.metadata().get(ExtensionType::METADATA_KEY); - field - .metadata() - .get(ExtensionType::NAME_KEY) - .and_then(|name| match name.as_str() { - "arrow.uuid" if metadata.is_none() => Some(ExtensionType::Uuid), - _ => None, - }) + #[cfg(test)] + mod tests { + use std::collections::HashMap; + + use serde_json::Map; + + use crate::{ArrowError, Field, EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}; + + use super::*; + + #[test] + fn json() -> Result<(), ArrowError> { + let mut field = Field::new("", DataType::Utf8, false); + field.try_with_extension_type(Json::default())?; + assert_eq!( + field.metadata().get(EXTENSION_TYPE_METADATA_KEY), + Some(&r#""""#.to_owned()) + ); + assert!(field.extension_type::().is_some()); + + let mut field = Field::new("", DataType::LargeUtf8, false); + field.try_with_extension_type(Json(serde_json::Value::Object(Map::default())))?; + assert_eq!( + field.metadata().get(EXTENSION_TYPE_METADATA_KEY), + Some(&"{}".to_owned()) + ); + assert!(field.extension_type::().is_some()); + + let mut field = Field::new("", DataType::Utf8View, false); + field.try_with_extension_type(Json::default())?; + assert!(field.extension_type::().is_some()); + assert_eq!( + field.canonical_extension_type(), + Some(CanonicalExtensionTypes::Json(Json::default())) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "expected Utf8 or LargeUtf8 or Utf8View, found Boolean")] + fn json_bad_type() { + Field::new("", DataType::Boolean, false).with_extension_type(Json::default()); + } + + #[test] + fn json_bad_metadata() { + let field = Field::new("", DataType::Utf8, false).with_metadata(HashMap::from_iter([ + (EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned()), + (EXTENSION_TYPE_METADATA_KEY.to_owned(), "1234".to_owned()), + ])); + // This returns `None` now because this metadata is invalid. 
+ assert!(field.extension_type::().is_none()); + } + + #[test] + fn json_missing_metadata() { + let field = Field::new("", DataType::LargeUtf8, false).with_metadata( + HashMap::from_iter([(EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned())]), + ); + // This returns `None` now because the metadata is missing. + assert!(field.extension_type::().is_none()); + } + + #[test] + fn uuid() -> Result<(), ArrowError> { + let mut field = Field::new("", DataType::FixedSizeBinary(16), false); + field.try_with_extension_type(Uuid)?; + assert!(field.extension_type::().is_some()); + assert_eq!( + field.canonical_extension_type(), + Some(CanonicalExtensionTypes::Uuid(Uuid)) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "expected FixedSizeBinary(16), found FixedSizeBinary(8)")] + fn uuid_bad_type() { + Field::new("", DataType::FixedSizeBinary(8), false).with_extension_type(Uuid); + } + + #[test] + fn uuid_with_metadata() { + // Add metadata that's not expected for uuid. + let field = Field::new("", DataType::FixedSizeBinary(16), false) + .with_metadata(HashMap::from_iter([( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + "".to_owned(), + )])) + .with_extension_type(Uuid); + // This returns `None` now because `Uuid` expects no metadata. + assert!(field.extension_type::().is_none()); + } } } diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index 02a87264a25a..04cc3665223b 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use crate::canonical_extension_types::{CanonicalExtensionTypes, Json, Uuid}; use crate::error::ArrowError; use std::cmp::Ordering; use std::collections::HashMap; @@ -23,7 +24,10 @@ use std::sync::Arc; use crate::datatype::DataType; use crate::schema::SchemaBuilder; -use crate::{ExtensionType, Fields, UnionFields, UnionMode}; +use crate::{ + ExtensionType, ExtensionTypeExt, Fields, UnionFields, UnionMode, EXTENSION_TYPE_METADATA_KEY, + EXTENSION_TYPE_NAME_KEY, +}; /// A reference counted [`Field`] pub type FieldRef = Arc; @@ -337,37 +341,63 @@ impl Field { self } - /// Returns the canonical [`ExtensionType`] of this [`Field`], if set. - pub fn extension_type(&self) -> Option { - ExtensionType::try_from_field(self) + /// Returns the given [`ExtensionType`] of this [`Field`], if set. + /// Returns `None` if this field does not have this extension type. + pub fn extension_type(&self) -> Option { + E::try_from_field(self) } - /// Updates the metadata of this [`Field`] with the [`ExtensionType::name`] + /// Returns the [`CanonicalExtensionTypes`] of this [`Field`], if set. + pub fn canonical_extension_type(&self) -> Option { + Json::try_from_field(self) + .map(Into::into) + .or(Uuid::try_from_field(self).map(Into::into)) + } + + /// Updates the metadata of this [`Field`] with the [`ExtensionType::NAME`] /// and [`ExtensionType::metadata`] of the given [`ExtensionType`]. /// - /// # Panics + /// # Error /// - /// This function panics when the datatype of this field is not a valid - /// storage type for the given extension type. - pub fn with_extension_type(mut self, extension_type: ExtensionType) -> Self { - if extension_type.supports_storage_type(&self.data_type) { - self.metadata.insert( - ExtensionType::NAME_KEY.to_owned(), - extension_type.name().to_owned(), - ); - if let Some(metadata) = extension_type.metadata() { + /// This functions returns an error if the datatype of this field does not + /// match the storage type of the given extension type. 
+ pub fn try_with_extension_type( + &mut self, + extension_type: E, + ) -> Result<(), ArrowError> { + if extension_type.supports(&self.data_type) { + // Insert the name + self.metadata + .insert(EXTENSION_TYPE_NAME_KEY.to_owned(), E::NAME.to_owned()); + // Insert the metadata, if any + if let Some(metadata) = extension_type.into_serialized_metadata() { self.metadata - .insert(ExtensionType::METADATA_KEY.to_owned(), metadata); + .insert(EXTENSION_TYPE_METADATA_KEY.to_owned(), metadata); } - self + Ok(()) } else { - panic!( - "{extension_type} does not support {} as storage type", + Err(ArrowError::InvalidArgumentError(format!( + "storage type of extension type {} does not match field data type, expected {}, found {}", + ::NAME, + extension_type.storage_types().iter().map(ToString::to_string).collect::>().join(" or "), self.data_type - ); + ))) } } + /// Updates the metadata of this [`Field`] with the [`ExtensionType::NAME`] + /// and [`ExtensionType::metadata`] of the given [`ExtensionType`]. + /// + /// # Panics + /// + /// This functions panics if the datatype of this field does match the + /// storage type of the given extension type. + pub fn with_extension_type(mut self, extension_type: E) -> Self { + self.try_with_extension_type(extension_type) + .unwrap_or_else(|e| panic!("{e}")); + self + } + /// Indicates whether this [`Field`] supports null values. #[inline] pub const fn is_nullable(&self) -> bool { diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 61e14a40a23d..3a871aaba9ef 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -23,13 +23,14 @@ //! //! The interfaces for converting arrow schema to parquet schema is coming. 
+use arrow_schema::canonical_extension_types::Uuid; use base64::prelude::BASE64_STANDARD; use base64::Engine; use std::collections::HashMap; use std::sync::Arc; use arrow_ipc::writer; -use arrow_schema::{DataType, ExtensionType, Field, Fields, Schema, TimeUnit}; +use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, Type as PhysicalType, @@ -468,10 +469,8 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .with_length(*length) - .with_logical_type(match field.extension_type() { - Some(ExtensionType::Uuid) => Some(LogicalType::Uuid), - _ => None, - }) + // If set, map arrow uuid extension type to parquet uuid logical type. + .with_logical_type(field.extension_type::().map(|_| LogicalType::Uuid)) .build() } DataType::BinaryView => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) @@ -1913,4 +1912,23 @@ mod tests { fn test_get_arrow_schema_from_metadata() { assert!(get_arrow_schema_from_metadata("").is_err()); } + + #[test] + fn arrow_uuid_to_parquet_uuid() -> Result<()> { + let arrow_schema = Schema::new(vec![Field::new( + "uuid", + DataType::FixedSizeBinary(16), + false, + ) + .with_extension_type(Uuid)]); + + let parquet_schema = arrow_to_parquet_schema(&arrow_schema)?; + + assert_eq!( + parquet_schema.column(0).logical_type(), + Some(LogicalType::Uuid) + ); + + Ok(()) + } } From 84286535639daf2d21275f81cc0d46c37619952f Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Thu, 26 Sep 2024 21:24:17 +0200 Subject: [PATCH 04/21] Fix clippy warnings --- arrow-schema/src/datatype.rs | 10 +++++----- arrow-schema/src/field.rs | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index 2bf9663ef512..b7a326f605f3 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -790,7 +790,7 @@ pub trait ExtensionType: Sized { /// 
Returns the serialized representation of the metadata of this extension /// type, or `None` if this extension type has no metadata. - fn into_serialized_metadata(&self) -> Option; + fn serialized_metadata(&self) -> Option; /// Deserialize this extension type from the serialized representation of the /// metadata of this extension. An extension type that has no metadata should @@ -842,7 +842,7 @@ impl ExtensionTypeExt for T where T: ExtensionType {} /// the definitions of well-known extension types so as to improve /// interoperability between different systems integrating Arrow columnar data. pub mod canonical_extension_types { - use serde_json::{Map, Value}; + use serde_json::Value; use super::{DataType, ExtensionType}; @@ -905,7 +905,7 @@ pub mod canonical_extension_types { Some(&self.0) } - fn into_serialized_metadata(&self) -> Option { + fn serialized_metadata(&self) -> Option { Some(self.0.to_string()) } @@ -917,7 +917,7 @@ pub mod canonical_extension_types { value => value .parse::() .ok() - .filter(|value| value.as_object().is_some_and(Map::is_empty)) + .filter(|value| matches!(value.as_object(), Some(map) if map.is_empty())) .map(Self), }) } @@ -952,7 +952,7 @@ pub mod canonical_extension_types { None } - fn into_serialized_metadata(&self) -> Option { + fn serialized_metadata(&self) -> Option { None } diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index bca3257b4bd6..f16e2f9bbc05 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -370,7 +370,7 @@ impl Field { self.metadata .insert(EXTENSION_TYPE_NAME_KEY.to_owned(), E::NAME.to_owned()); // Insert the metadata, if any - if let Some(metadata) = extension_type.into_serialized_metadata() { + if let Some(metadata) = extension_type.serialized_metadata() { self.metadata .insert(EXTENSION_TYPE_METADATA_KEY.to_owned(), metadata); } From 7896455d4cbd5f084c190d60ac51c0dfc7ccc99b Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Thu, 26 Sep 2024 23:06:33 +0200 Subject: 
[PATCH 05/21] Add type annotation to fix build --- arrow-select/src/dictionary.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-select/src/dictionary.rs b/arrow-select/src/dictionary.rs index 2a532600b6cc..c363b99920a7 100644 --- a/arrow-select/src/dictionary.rs +++ b/arrow-select/src/dictionary.rs @@ -315,7 +315,7 @@ mod tests { assert_eq!(merged.values.as_ref(), &expected); assert_eq!(merged.key_mappings.len(), 2); assert_eq!(&merged.key_mappings[0], &[0, 0, 0, 1, 0]); - assert_eq!(&merged.key_mappings[1], &[]); + assert_eq!(&merged.key_mappings[1], &[] as &[i32; 0]); } #[test] From e35630a86911d5cb67656af8feaadc39cdbd0b52 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Fri, 17 Jan 2025 23:48:38 +0100 Subject: [PATCH 06/21] Update `ExtensionType` trait to support more canonical extension types --- arrow-schema/Cargo.toml | 10 +- arrow-schema/src/datatype.rs | 293 ------------------ arrow-schema/src/extension/canonical/bool8.rs | 72 +++++ .../extension/canonical/fixed_shape_tensor.rs | 257 +++++++++++++++ arrow-schema/src/extension/canonical/json.rs | 171 ++++++++++ arrow-schema/src/extension/canonical/mod.rs | 101 ++++++ .../src/extension/canonical/opaque.rs | 131 ++++++++ arrow-schema/src/extension/canonical/uuid.rs | 118 +++++++ .../canonical/variable_shape_tensor.rs | 186 +++++++++++ arrow-schema/src/extension/mod.rs | 250 +++++++++++++++ arrow-schema/src/field.rs | 192 +++++++++--- arrow-schema/src/lib.rs | 1 + parquet/Cargo.toml | 2 + parquet/src/arrow/schema/mod.rs | 16 +- 14 files changed, 1467 insertions(+), 333 deletions(-) create mode 100644 arrow-schema/src/extension/canonical/bool8.rs create mode 100644 arrow-schema/src/extension/canonical/fixed_shape_tensor.rs create mode 100644 arrow-schema/src/extension/canonical/json.rs create mode 100644 arrow-schema/src/extension/canonical/mod.rs create mode 100644 arrow-schema/src/extension/canonical/opaque.rs create mode 100644 arrow-schema/src/extension/canonical/uuid.rs create 
mode 100644 arrow-schema/src/extension/canonical/variable_shape_tensor.rs create mode 100644 arrow-schema/src/extension/mod.rs diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml index 711543a18677..a90f086bd732 100644 --- a/arrow-schema/Cargo.toml +++ b/arrow-schema/Cargo.toml @@ -34,13 +34,19 @@ path = "src/lib.rs" bench = false [dependencies] -serde = { version = "1.0", default-features = false, features = ["derive", "std", "rc"], optional = true } +serde = { version = "1.0", default-features = false, features = [ + "derive", + "std", + "rc", +], optional = true } bitflags = { version = "2.0.0", default-features = false, optional = true } -serde_json = "1.0" +serde_json = { version = "1.0", optional = true } [features] +canonical-extension-types = ["dep:serde", "dep:serde_json"] # Enable ffi support ffi = ["bitflags"] +serde = ["dep:serde"] [package.metadata.docs.rs] features = ["ffi"] diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index b7a326f605f3..ff5832dfa68c 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -764,299 +764,6 @@ impl DataType { } } -/// The metadata key for the string name identifying the custom data type. -pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; - -/// The metadata key for a serialized representation of the ExtensionType -/// necessary to reconstruct the custom type. -pub const EXTENSION_TYPE_METADATA_KEY: &str = "ARROW:extension:metadata"; - -/// Extension types. -/// -/// -pub trait ExtensionType: Sized { - /// The name of this extension type. - const NAME: &'static str; - - /// The supported storage types of this extension type. - fn storage_types(&self) -> &[DataType]; - - /// The metadata type of this extension type. - type Metadata; - - /// Returns a reference to the metadata of this extension type, or `None` - /// if this extension type has no metadata. 
- fn metadata(&self) -> Option<&Self::Metadata>; - - /// Returns the serialized representation of the metadata of this extension - /// type, or `None` if this extension type has no metadata. - fn serialized_metadata(&self) -> Option; - - /// Deserialize this extension type from the serialized representation of the - /// metadata of this extension. An extension type that has no metadata should - /// expect `None` for for the serialized metadata. - fn from_serialized_metadata(serialized_metadata: Option<&str>) -> Option; -} - -pub(crate) trait ExtensionTypeExt: ExtensionType { - /// Returns `true` if the given data type is supported by this extension - /// type. - fn supports(&self, data_type: &DataType) -> bool { - self.storage_types().contains(data_type) - } - - /// Try to extract this extension type from the given [`Field`]. - /// - /// This function returns `None` if extension type - /// - information is missing - /// - name does not match - /// - metadata deserialization failed - /// - does not support the data type of this field - fn try_from_field(field: &Field) -> Option { - field - .metadata() - .get(EXTENSION_TYPE_NAME_KEY) - .and_then(|name| { - (name == ::NAME) - .then(|| { - Self::from_serialized_metadata( - field - .metadata() - .get(EXTENSION_TYPE_METADATA_KEY) - .map(String::as_str), - ) - }) - .flatten() - }) - .filter(|extension_type| extension_type.supports(field.data_type())) - } -} - -impl ExtensionTypeExt for T where T: ExtensionType {} - -/// Canonical extension types. -/// -/// The Arrow columnar format allows defining extension types so as to extend -/// standard Arrow data types with custom semantics. Often these semantics will -/// be specific to a system or application. However, it is beneficial to share -/// the definitions of well-known extension types so as to improve -/// interoperability between different systems integrating Arrow columnar data. 
-pub mod canonical_extension_types { - use serde_json::Value; - - use super::{DataType, ExtensionType}; - - /// Canonical extension types. - #[non_exhaustive] - #[derive(Debug, Clone, PartialEq)] - pub enum CanonicalExtensionTypes { - /// The extension type for 'JSON'. - Json(Json), - /// The extension type for `UUID`. - Uuid(Uuid), - } - - impl From for CanonicalExtensionTypes { - fn from(value: Json) -> Self { - CanonicalExtensionTypes::Json(value) - } - } - - impl From for CanonicalExtensionTypes { - fn from(value: Uuid) -> Self { - CanonicalExtensionTypes::Uuid(value) - } - } - - /// The extension type for `JSON`. - /// - /// Extension name: `arrow.json`. - /// - /// The storage type of this extension is `String` or `LargeString` or - /// `StringView`. Only UTF-8 encoded JSON as specified in [rfc8259](https://datatracker.ietf.org/doc/html/rfc8259) - /// is supported. - /// - /// This type does not have any parameters. - /// - /// Metadata is either an empty string or a JSON string with an empty - /// object. In the future, additional fields may be added, but they are not - /// required to interpret the array. 
- /// - /// - #[derive(Debug, Clone, PartialEq)] - pub struct Json(Value); - - impl Default for Json { - fn default() -> Self { - Self(Value::String("".to_owned())) - } - } - - impl ExtensionType for Json { - const NAME: &'static str = "arrow.json"; - - type Metadata = Value; - - fn storage_types(&self) -> &[DataType] { - &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View] - } - - fn metadata(&self) -> Option<&Self::Metadata> { - Some(&self.0) - } - - fn serialized_metadata(&self) -> Option { - Some(self.0.to_string()) - } - - fn from_serialized_metadata(serialized_metadata: Option<&str>) -> Option { - serialized_metadata.and_then(|metadata| match metadata { - // Empty string - r#""""# => Some(Default::default()), - // Empty object - value => value - .parse::() - .ok() - .filter(|value| matches!(value.as_object(), Some(map) if map.is_empty())) - .map(Self), - }) - } - } - - /// The extension type for `UUID`. - /// - /// Extension name: `arrow.uuid`. - /// - /// The storage type of the extension is `FixedSizeBinary` with a length of - /// 16 bytes. - /// - /// Note: - /// A specific UUID version is not required or guaranteed. This extension - /// represents UUIDs as `FixedSizeBinary(16)` with big-endian notation and - /// does not interpret the bytes in any way. 
- /// - /// - #[derive(Debug, Default, Clone, Copy, PartialEq)] - pub struct Uuid; - - impl ExtensionType for Uuid { - const NAME: &'static str = "arrow.uuid"; - - type Metadata = (); - - fn storage_types(&self) -> &[DataType] { - &[DataType::FixedSizeBinary(16)] - } - - fn metadata(&self) -> Option<&Self::Metadata> { - None - } - - fn serialized_metadata(&self) -> Option { - None - } - - fn from_serialized_metadata(serialized_metadata: Option<&str>) -> Option { - serialized_metadata.is_none().then_some(Self) - } - } - - #[cfg(test)] - mod tests { - use std::collections::HashMap; - - use serde_json::Map; - - use crate::{ArrowError, Field, EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}; - - use super::*; - - #[test] - fn json() -> Result<(), ArrowError> { - let mut field = Field::new("", DataType::Utf8, false); - field.try_with_extension_type(Json::default())?; - assert_eq!( - field.metadata().get(EXTENSION_TYPE_METADATA_KEY), - Some(&r#""""#.to_owned()) - ); - assert!(field.extension_type::().is_some()); - - let mut field = Field::new("", DataType::LargeUtf8, false); - field.try_with_extension_type(Json(serde_json::Value::Object(Map::default())))?; - assert_eq!( - field.metadata().get(EXTENSION_TYPE_METADATA_KEY), - Some(&"{}".to_owned()) - ); - assert!(field.extension_type::().is_some()); - - let mut field = Field::new("", DataType::Utf8View, false); - field.try_with_extension_type(Json::default())?; - assert!(field.extension_type::().is_some()); - assert_eq!( - field.canonical_extension_type(), - Some(CanonicalExtensionTypes::Json(Json::default())) - ); - Ok(()) - } - - #[test] - #[should_panic(expected = "expected Utf8 or LargeUtf8 or Utf8View, found Boolean")] - fn json_bad_type() { - Field::new("", DataType::Boolean, false).with_extension_type(Json::default()); - } - - #[test] - fn json_bad_metadata() { - let field = Field::new("", DataType::Utf8, false).with_metadata(HashMap::from_iter([ - (EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned()), - 
(EXTENSION_TYPE_METADATA_KEY.to_owned(), "1234".to_owned()), - ])); - // This returns `None` now because this metadata is invalid. - assert!(field.extension_type::().is_none()); - } - - #[test] - fn json_missing_metadata() { - let field = Field::new("", DataType::LargeUtf8, false).with_metadata( - HashMap::from_iter([(EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned())]), - ); - // This returns `None` now because the metadata is missing. - assert!(field.extension_type::().is_none()); - } - - #[test] - fn uuid() -> Result<(), ArrowError> { - let mut field = Field::new("", DataType::FixedSizeBinary(16), false); - field.try_with_extension_type(Uuid)?; - assert!(field.extension_type::().is_some()); - assert_eq!( - field.canonical_extension_type(), - Some(CanonicalExtensionTypes::Uuid(Uuid)) - ); - Ok(()) - } - - #[test] - #[should_panic(expected = "expected FixedSizeBinary(16), found FixedSizeBinary(8)")] - fn uuid_bad_type() { - Field::new("", DataType::FixedSizeBinary(8), false).with_extension_type(Uuid); - } - - #[test] - fn uuid_with_metadata() { - // Add metadata that's not expected for uuid. - let field = Field::new("", DataType::FixedSizeBinary(16), false) - .with_metadata(HashMap::from_iter([( - EXTENSION_TYPE_METADATA_KEY.to_owned(), - "".to_owned(), - )])) - .with_extension_type(Uuid); - // This returns `None` now because `Uuid` expects no metadata. - assert!(field.extension_type::().is_none()); - } - } -} - /// The maximum precision for [DataType::Decimal128] values pub const DECIMAL128_MAX_PRECISION: u8 = 38; diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs new file mode 100644 index 000000000000..0272752f35a8 --- /dev/null +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! 8-bit Boolean
+//!
+//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+
+use crate::{extension::ExtensionType, ArrowError, DataType};
+
+/// The extension type for `8-bit Boolean`.
+///
+/// Extension name: `arrow.bool8`.
+///
+/// The storage type of the extension is `Int8` where:
+/// - false is denoted by the value 0.
+/// - true can be specified using any non-zero value. Preferably 1.
+///
+/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+#[derive(Debug, Default, Clone, Copy, PartialEq)]
+pub struct Bool8;
+
+impl ExtensionType for Bool8 {
+    const NAME: &'static str = "arrow.bool8";
+
+    type Metadata = &'static str;
+
+    /// Returns the metadata of this extension type, which is always the
+    /// empty string.
+    fn metadata(&self) -> &Self::Metadata {
+        &""
+    }
+
+    /// Returns the serialized metadata: an empty string.
+    fn serialize_metadata(&self) -> Option<String> {
+        Some(String::default())
+    }
+
+    /// Accepts only an empty string as metadata.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the metadata is missing or is not empty.
+    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
+        match metadata {
+            Some("") => Ok(""),
+            _ => Err(ArrowError::InvalidArgumentError(
+                "Bool8 extension type expects an empty string as metadata".to_owned(),
+            )),
+        }
+    }
+
+    /// Checks that the storage type is `Int8`.
+    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
+        match data_type {
+            DataType::Int8 => Ok(()),
+            data_type => Err(ArrowError::InvalidArgumentError(format!(
+                "Bool8 data type mismatch, expected Int8, found {data_type}"
+            ))),
+        }
+    }
+
+    /// Returns `Bool8` if `data_type` is a supported storage type.
+    fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result<Self, ArrowError> {
+        Self.supports_data_type(data_type).map(|_| Self)
+    }
+}
diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs
new file mode 100644
index 000000000000..abbfe3f6978f
--- /dev/null
+++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs
@@ -0,0 +1,257 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! FixedShapeTensor
+//!
+//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+
+use serde::{Deserialize, Serialize};
+
+use crate::{extension::ExtensionType, ArrowError, DataType};
+
+/// The extension type for fixed shape tensor.
+///
+/// Extension name: `arrow.fixed_shape_tensor`.
+///
+/// The storage type of the extension: `FixedSizeList` where:
+/// - `value_type` is the data type of individual tensor elements.
+/// - `list_size` is the product of all the elements in tensor shape.
+///
+/// Extension type parameters:
+/// - `value_type`: the Arrow data type of individual tensor elements.
+/// - `shape`: the physical shape of the contained tensors as an array.
+///
+/// Optional parameters describing the logical layout:
+/// - `dim_names`: explicit names to tensor dimensions as an array. The
+///   length of it should be equal to the shape length and equal to the
+///   number of dimensions.
+///   `dim_names` can be used if the dimensions have
+///   well-known names and they map to the physical layout (row-major).
+/// - `permutation`: indices of the desired ordering of the original
+///   dimensions, defined as an array.
+///   The indices contain a permutation of the values `[0, 1, .., N-1]`
+///   where `N` is the number of dimensions. The permutation indicates
+///   which dimension of the logical layout corresponds to which dimension
+///   of the physical tensor (the i-th dimension of the logical view
+///   corresponds to the dimension with number `permutations[i]` of the
+///   physical tensor).
+///   Permutation can be useful in case the logical order of the tensor is
+///   a permutation of the physical order (row-major).
+///   When logical and physical layout are equal, the permutation will
+///   always be `([0, 1, .., N-1])` and can therefore be left out.
+///
+/// Description of the serialization:
+/// The metadata must be a valid JSON object including shape of the
+/// contained tensors as an array with key `shape` plus optional
+/// dimension names with keys `dim_names` and ordering of the
+/// dimensions with key `permutation`.
+/// Example: `{ "shape": [2, 5]}`
+/// Example with `dim_names` metadata for NCHW ordered data:
+/// `{ "shape": [100, 200, 500], "dim_names": ["C", "H", "W"]}`
+/// Example of permuted 3-dimensional tensor:
+/// `{ "shape": [100, 200, 500], "permutation": [2, 0, 1]}`
+///
+/// This is the physical layout shape and the shape of the logical layout
+/// would in this case be `[500, 100, 200]`.
+///
+/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+#[derive(Debug, Clone, PartialEq)]
+pub struct FixedShapeTensor {
+    /// The data type of individual tensor elements.
+    value_type: DataType,
+
+    /// The metadata of this extension type.
+    metadata: FixedShapeTensorMetadata,
+}
+
+impl FixedShapeTensor {
+    /// Returns a new fixed shape tensor extension type.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the number of dimension names or permutation
+    /// indices does not match the number of dimensions in `shape`, or if
+    /// the permutation indices are not a permutation of `[0, 1, .., N-1]`.
+    pub fn try_new(
+        value_type: DataType,
+        shape: impl IntoIterator<Item = usize>,
+        dimension_names: Option<Vec<impl Into<String>>>,
+        permutations: Option<Vec<impl Into<usize>>>,
+    ) -> Result<Self, ArrowError> {
+        let metadata = FixedShapeTensorMetadata {
+            shape: shape.into_iter().collect(),
+            dim_names: dimension_names.map(|names| names.into_iter().map(Into::into).collect()),
+            permutations: permutations
+                .map(|indices| indices.into_iter().map(Into::into).collect()),
+        };
+        metadata.validate()?;
+        Ok(Self {
+            value_type,
+            metadata,
+        })
+    }
+
+    /// Returns the value type of the individual tensor elements.
+    pub fn value_type(&self) -> &DataType {
+        &self.value_type
+    }
+
+    /// Returns the product of all the elements in tensor shape.
+    pub fn list_size(&self) -> usize {
+        self.metadata.list_size()
+    }
+
+    /// Returns the number of dimensions in this fixed shape tensor.
+    pub fn dimensions(&self) -> usize {
+        self.metadata.dimensions()
+    }
+
+    /// Returns the names of the dimensions in this fixed shape tensor, if
+    /// set.
+    pub fn dimension_names(&self) -> Option<&[String]> {
+        self.metadata.dimension_names()
+    }
+
+    /// Returns the indices of the desired ordering of the original
+    /// dimensions, if set.
+    pub fn permutations(&self) -> Option<&[usize]> {
+        self.metadata.permutations()
+    }
+}
+
+/// Extension type metadata for [`FixedShapeTensor`].
+#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
+pub struct FixedShapeTensorMetadata {
+    /// The physical shape of the contained tensors.
+    shape: Vec<usize>,
+
+    /// Explicit names to tensor dimensions.
+    dim_names: Option<Vec<String>>,
+
+    /// Indices of the desired ordering of the original dimensions.
+    ///
+    /// Serialized under the key `permutation`. The key `permutations` is
+    /// still accepted when deserializing.
+    #[serde(rename = "permutation", alias = "permutations")]
+    permutations: Option<Vec<usize>>,
+}
+
+impl FixedShapeTensorMetadata {
+    /// Returns the product of all the elements in tensor shape.
+    pub fn list_size(&self) -> usize {
+        self.shape.iter().product()
+    }
+
+    /// Returns the number of dimensions in this fixed shape tensor.
+    pub fn dimensions(&self) -> usize {
+        self.shape.len()
+    }
+
+    /// Returns the names of the dimensions in this fixed shape tensor, if
+    /// set.
+    pub fn dimension_names(&self) -> Option<&[String]> {
+        self.dim_names.as_ref().map(AsRef::as_ref)
+    }
+
+    /// Returns the indices of the desired ordering of the original
+    /// dimensions, if set.
+    pub fn permutations(&self) -> Option<&[usize]> {
+        self.permutations.as_ref().map(AsRef::as_ref)
+    }
+
+    /// Checks that the dimension names and permutation indices, if set,
+    /// are consistent with the number of dimensions.
+    fn validate(&self) -> Result<(), ArrowError> {
+        let dimensions = self.dimensions();
+        // Make sure the dim names size is correct, if set.
+        if let Some(dim_names_size) = self.dimension_names().map(<[_]>::len) {
+            if dim_names_size != dimensions {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "FixedShapeTensor dimension names size mismatch, expected {dimensions}, found {dim_names_size}"
+                )));
+            }
+        }
+        // Make sure the permutations are correct, if set.
+        if let Some(permutations) = self.permutations() {
+            if permutations.len() != dimensions {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "FixedShapeTensor permutations size mismatch, expected {dimensions}, found {}",
+                    permutations.len()
+                )));
+            }
+            // Sorted indices of a valid permutation are exactly 0, 1, .., N-1.
+            let mut sorted = permutations.to_vec();
+            sorted.sort_unstable();
+            if (0..dimensions).zip(sorted).any(|(a, b)| a != b) {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
+                )));
+            }
+        }
+        Ok(())
+    }
+}
+
+impl ExtensionType for FixedShapeTensor {
+    const NAME: &'static str = "arrow.fixed_shape_tensor";
+
+    type Metadata = FixedShapeTensorMetadata;
+
+    fn metadata(&self) -> &Self::Metadata {
+        &self.metadata
+    }
+
+    fn serialize_metadata(&self) -> Option<String> {
+        Some(serde_json::to_string(&self.metadata).expect("metadata serialization"))
+    }
+
+    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
+        metadata.map_or_else(
+            || {
+                Err(ArrowError::InvalidArgumentError(
+                    "FixedShapeTensor extension types requires metadata".to_owned(),
+                ))
+            },
+            |value| {
+                serde_json::from_str(value).map_err(|e| {
+                    ArrowError::InvalidArgumentError(format!(
+                        "FixedShapeTensor metadata deserialization failed: {e}"
+                    ))
+                })
+            },
+        )
+    }
+
+    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
+        // The list size of a `FixedSizeList` is an `i32`; report a shape that
+        // does not fit as an error instead of panicking.
+        let list_size = i32::try_from(self.list_size()).map_err(|_| {
+            ArrowError::InvalidArgumentError(format!(
+                "FixedShapeTensor list size {} does not fit in a FixedSizeList",
+                self.list_size()
+            ))
+        })?;
+        let expected = DataType::new_fixed_size_list(self.value_type.clone(), list_size, false);
+        data_type
+            .equals_datatype(&expected)
+            .then_some(())
+            .ok_or_else(|| {
+                ArrowError::InvalidArgumentError(format!(
+                    "FixedShapeTensor data type mismatch, expected {expected}, found {data_type}"
+                ))
+            })
+    }
+
+    fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
+        match data_type {
+            DataType::FixedSizeList(field, list_size) if !field.is_nullable() => {
+                // Make sure the shape matches. The comparison is done on
+                // `usize` so that metadata with an oversized shape is
+                // reported as a mismatch instead of panicking.
+                let expected_size = metadata.list_size();
+                if usize::try_from(*list_size).ok() != Some(expected_size) {
+                    return Err(ArrowError::InvalidArgumentError(format!(
+                        "FixedShapeTensor list size mismatch, expected {expected_size} (metadata), found {list_size} (data type)"
+                    )));
+                }
+                // Make sure the dim names and permutations are correct, if set.
+                metadata.validate()?;
+
+                Ok(Self {
+                    value_type: field.data_type().clone(),
+                    metadata,
+                })
+            }
+            data_type => Err(ArrowError::InvalidArgumentError(format!(
+                "FixedShapeTensor data type mismatch, expected FixedSizeList with non-nullable field, found {data_type}"
+            ))),
+        }
+    }
+}
diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs
new file mode 100644
index 000000000000..8e89d611c5ff
--- /dev/null
+++ b/arrow-schema/src/extension/canonical/json.rs
@@ -0,0 +1,171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! JSON
+//!
+//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+
+use serde_json::Value;
+
+use crate::{extension::ExtensionType, ArrowError, DataType};
+
+/// The extension type for `JSON`.
+///
+/// Extension name: `arrow.json`.
+///
+/// The storage type of this extension is `String` or `LargeString` or
+/// `StringView`. Only UTF-8 encoded JSON as specified in [rfc8259](https://datatracker.ietf.org/doc/html/rfc8259)
+/// is supported.
+///
+/// This type does not have any parameters.
+///
+/// Metadata is either an empty string or a JSON string with an empty
+/// object. In the future, additional fields may be added, but they are not
+/// required to interpret the array.
+///
+/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+#[derive(Debug, Clone, Default, PartialEq)]
+pub struct Json(JsonMetadata);
+
+/// Extension type metadata for [`Json`].
+#[derive(Debug, Clone, PartialEq)]
+pub struct JsonMetadata(Value);
+
+impl Default for JsonMetadata {
+    /// Returns the empty-string metadata.
+    fn default() -> Self {
+        Self(Value::String(Default::default()))
+    }
+}
+
+impl ExtensionType for Json {
+    const NAME: &'static str = "arrow.json";
+
+    type Metadata = JsonMetadata;
+
+    fn metadata(&self) -> &Self::Metadata {
+        &self.0
+    }
+
+    /// Returns the JSON encoding of the metadata value.
+    fn serialize_metadata(&self) -> Option<String> {
+        Some(self.metadata().0.to_string())
+    }
+
+    /// Accepts an empty string or a JSON string with an empty object.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the metadata is missing or is neither of the
+    /// accepted forms.
+    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
+        const ERR: &str = "Json extension type metadata is either an empty string or a JSON string with an empty object";
+        let invalid = || ArrowError::InvalidArgumentError(ERR.to_owned());
+        match metadata.ok_or_else(invalid)? {
+            // The empty string: either literally empty, or in its JSON
+            // encoding (`""`), which is what `serialize_metadata` emits for
+            // the default metadata.
+            "" | r#""""# => Ok(JsonMetadata::default()),
+            // Anything else must parse as an empty JSON object.
+            value => value
+                .parse::<Value>()
+                .ok()
+                .filter(|value| matches!(value.as_object(), Some(map) if map.is_empty()))
+                .map(JsonMetadata)
+                .ok_or_else(invalid),
+        }
+    }
+
+    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
+        match data_type {
+            DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Ok(()),
+            data_type => Err(ArrowError::InvalidArgumentError(format!(
+                "Json data type mismatch, expected one of Utf8, LargeUtf8, Utf8View, found {data_type}"
+            ))),
+        }
+    }
+
+    fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
+        let json = Self(metadata);
+        json.supports_data_type(data_type)?;
+        Ok(json)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use serde_json::Map;
+
+    use crate::{
+        extension::{CanonicalExtensionType, EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
+        Field,
+    };
+
+    use super::*;
+
+    #[test]
+    fn json() -> Result<(), ArrowError> {
+        let mut field = Field::new("", DataType::Utf8, false);
+        field.try_with_extension_type(Json::default())?;
+        assert_eq!(
+            field.metadata().get(EXTENSION_TYPE_METADATA_KEY),
+            Some(&r#""""#.to_owned())
+        );
+        assert!(field.try_extension_type::<Json>().is_ok());
+
+        let mut field = Field::new("", DataType::LargeUtf8, false);
+        field.try_with_extension_type(Json(JsonMetadata(serde_json::Value::Object(
+            Map::default(),
+        ))))?;
+        assert_eq!(
+            field.metadata().get(EXTENSION_TYPE_METADATA_KEY),
+            Some(&"{}".to_owned())
+        );
+        assert!(field.try_extension_type::<Json>().is_ok());
+
+        let mut field = Field::new("", DataType::Utf8View, false);
+        field.try_with_extension_type(Json::default())?;
+        assert!(field.try_extension_type::<Json>().is_ok());
+        assert_eq!(
+            field.try_canonical_extension_type().unwrap(),
+            CanonicalExtensionType::Json(Json::default())
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn json_empty_string_metadata() -> Result<(), ArrowError> {
+        // Both the literally empty string and its JSON encoding are accepted.
+        assert_eq!(Json::deserialize_metadata(Some(""))?, JsonMetadata::default());
+        assert_eq!(
+            Json::deserialize_metadata(Some(r#""""#))?,
+            JsonMetadata::default()
+        );
+        Ok(())
+    }
+
+    #[test]
+    #[should_panic(expected = "expected one of Utf8, LargeUtf8, Utf8View, found Null")]
+    fn json_bad_type() {
+        Field::new("", DataType::Null, false).with_extension_type(Json::default());
+    }
+
+    #[test]
+    fn json_bad_metadata() {
+        let field = Field::new("", DataType::Utf8, false).with_metadata(
+            [
+                (EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned()),
+                (EXTENSION_TYPE_METADATA_KEY.to_owned(), "1234".to_owned()),
+            ]
+            .into_iter()
+            .collect(),
+        );
+        // This returns an error because this metadata is invalid.
+        assert!(field.try_extension_type::<Json>().is_err());
+    }
+
+    #[test]
+    fn json_missing_metadata() {
+        let field = Field::new("", DataType::LargeUtf8, false).with_metadata(
+            [(EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned())]
+                .into_iter()
+                .collect(),
+        );
+        // This returns an error because the metadata is missing.
+        assert!(field.try_extension_type::<Json>().is_err());
+    }
+}
diff --git a/arrow-schema/src/extension/canonical/mod.rs b/arrow-schema/src/extension/canonical/mod.rs
new file mode 100644
index 000000000000..778da9ef2c46
--- /dev/null
+++ b/arrow-schema/src/extension/canonical/mod.rs
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Canonical extension types.
+//!
+//! The Arrow columnar format allows defining extension types so as to extend
+//! standard Arrow data types with custom semantics. Often these semantics will
+//! be specific to a system or application. However, it is beneficial to share
+//! the definitions of well-known extension types so as to improve
+//! interoperability between different systems integrating Arrow columnar data.
+//!
+//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+
+mod bool8;
+pub use bool8::Bool8;
+mod fixed_shape_tensor;
+pub use fixed_shape_tensor::{FixedShapeTensor, FixedShapeTensorMetadata};
+mod json;
+pub use json::{Json, JsonMetadata};
+mod opaque;
+pub use opaque::{Opaque, OpaqueMetadata};
+mod uuid;
+pub use uuid::Uuid;
+mod variable_shape_tensor;
+pub use variable_shape_tensor::{VariableShapeTensor, VariableShapeTensorMetadata};
+
+/// Canonical extension types.
+///
+/// Each variant wraps the corresponding extension type exported by this
+/// module.
+///
+/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+#[non_exhaustive]
+#[derive(Debug, Clone, PartialEq)]
+pub enum CanonicalExtensionType {
+    /// The extension type for `FixedShapeTensor`.
+    ///
+    /// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+    FixedShapeTensor(FixedShapeTensor),
+
+    /// The extension type for `VariableShapeTensor`.
+    ///
+    /// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+    VariableShapeTensor(VariableShapeTensor),
+
+    /// The extension type for `JSON`.
+    ///
+    /// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+    Json(Json),
+
+    /// The extension type for `UUID`.
+    ///
+    /// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+    Uuid(Uuid),
+
+    /// The extension type for `Opaque`.
+    ///
+    /// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+    Opaque(Opaque),
+
+    /// The extension type for `8-bit Boolean`.
+    ///
+    /// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+    Bool8(Bool8),
+}
+
+impl From<FixedShapeTensor> for CanonicalExtensionType {
+    fn from(value: FixedShapeTensor) -> Self {
+        CanonicalExtensionType::FixedShapeTensor(value)
+    }
+}
+
+impl From<VariableShapeTensor> for CanonicalExtensionType {
+    fn from(value: VariableShapeTensor) -> Self {
+        CanonicalExtensionType::VariableShapeTensor(value)
+    }
+}
+
+impl From<Json> for CanonicalExtensionType {
+    fn from(value: Json) -> Self {
+        CanonicalExtensionType::Json(value)
+    }
+}
+
+impl From<Uuid> for CanonicalExtensionType {
+    fn from(value: Uuid) -> Self {
+        CanonicalExtensionType::Uuid(value)
+    }
+}
+
+impl From<Opaque> for CanonicalExtensionType {
+    fn from(value: Opaque) -> Self {
+        CanonicalExtensionType::Opaque(value)
+    }
+}
+
+impl From<Bool8> for CanonicalExtensionType {
+    fn from(value: Bool8) -> Self {
+        CanonicalExtensionType::Bool8(value)
+    }
+}
diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs
new file mode 100644
index 000000000000..283ac1e486d6
--- /dev/null
+++ b/arrow-schema/src/extension/canonical/opaque.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Opaque
+//!
+//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+
+use serde::{Deserialize, Serialize};
+
+use crate::{extension::ExtensionType, ArrowError, DataType};
+
+/// The extension type for `Opaque`.
+///
+/// Extension name: `arrow.opaque`.
+///
+/// Opaque represents a type that an Arrow-based system received from an
+/// external (often non-Arrow) system, but that it cannot interpret. In this
+/// case, it can pass on Opaque to its clients to at least show that a field
+/// exists and preserve metadata about the type from the other system.
+///
+/// The storage type of this extension is any type. If there is no underlying
+/// data, the storage type should be Null.
+#[derive(Debug, Clone, PartialEq)]
+pub struct Opaque(OpaqueMetadata);
+
+impl Opaque {
+    /// Returns a new `Opaque` extension type.
+    pub fn new(type_name: impl Into<String>, vendor_name: impl Into<String>) -> Self {
+        Self(OpaqueMetadata::new(type_name, vendor_name))
+    }
+
+    /// Returns the name of the unknown type in the external system.
+    pub fn type_name(&self) -> &str {
+        self.0.type_name()
+    }
+
+    /// Returns the name of the external system.
+    pub fn vendor_name(&self) -> &str {
+        self.0.vendor_name()
+    }
+}
+
+impl From<OpaqueMetadata> for Opaque {
+    fn from(value: OpaqueMetadata) -> Self {
+        Self(value)
+    }
+}
+
+/// Extension type metadata for [`Opaque`].
+#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
+pub struct OpaqueMetadata {
+    /// Name of the unknown type in the external system.
+    type_name: String,
+
+    /// Name of the external system.
+    vendor_name: String,
+}
+
+impl OpaqueMetadata {
+    /// Returns a new `OpaqueMetadata`.
+    pub fn new(type_name: impl Into<String>, vendor_name: impl Into<String>) -> Self {
+        OpaqueMetadata {
+            type_name: type_name.into(),
+            vendor_name: vendor_name.into(),
+        }
+    }
+
+    /// Returns the name of the unknown type in the external system.
+    pub fn type_name(&self) -> &str {
+        &self.type_name
+    }
+
+    /// Returns the name of the external system.
+    pub fn vendor_name(&self) -> &str {
+        &self.vendor_name
+    }
+}
+
+impl ExtensionType for Opaque {
+    const NAME: &'static str = "arrow.opaque";
+
+    type Metadata = OpaqueMetadata;
+
+    fn metadata(&self) -> &Self::Metadata {
+        &self.0
+    }
+
+    /// Returns the metadata serialized as a JSON object.
+    fn serialize_metadata(&self) -> Option<String> {
+        Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
+    }
+
+    /// Parses the metadata from its JSON representation.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the metadata is missing or cannot be
+    /// deserialized.
+    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
+        metadata.map_or_else(
+            || {
+                Err(ArrowError::InvalidArgumentError(
+                    "Opaque extension types requires metadata".to_owned(),
+                ))
+            },
+            |value| {
+                serde_json::from_str(value).map_err(|e| {
+                    ArrowError::InvalidArgumentError(format!(
+                        "Opaque metadata deserialization failed: {e}"
+                    ))
+                })
+            },
+        )
+    }
+
+    fn supports_data_type(&self, _data_type: &DataType) -> Result<(), ArrowError> {
+        // Any type
+        Ok(())
+    }
+
+    fn try_new(_data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
+        Ok(Self::from(metadata))
+    }
+}
diff --git a/arrow-schema/src/extension/canonical/uuid.rs b/arrow-schema/src/extension/canonical/uuid.rs
new file mode 100644
index 000000000000..206856265ae5
--- /dev/null
+++ b/arrow-schema/src/extension/canonical/uuid.rs
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! UUID
+//!
+//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+
+use crate::{extension::ExtensionType, ArrowError, DataType};
+
+/// The extension type for `UUID`.
+///
+/// Extension name: `arrow.uuid`.
+///
+/// The storage type of the extension is `FixedSizeBinary` with a length of
+/// 16 bytes.
+///
+/// Note:
+/// A specific UUID version is not required or guaranteed. This extension
+/// represents UUIDs as `FixedSizeBinary(16)` with big-endian notation and
+/// does not interpret the bytes in any way.
+///
+/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+#[derive(Debug, Default, Clone, Copy, PartialEq)]
+pub struct Uuid;
+
+impl ExtensionType for Uuid {
+    const NAME: &'static str = "arrow.uuid";
+
+    type Metadata = ();
+
+    fn metadata(&self) -> &Self::Metadata {
+        &()
+    }
+
+    /// Returns `None`: this extension type has no metadata.
+    fn serialize_metadata(&self) -> Option<String> {
+        None
+    }
+
+    /// Accepts only absent metadata.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if any metadata is present, including an empty
+    /// string.
+    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
+        match metadata {
+            None => Ok(()),
+            Some(_) => Err(ArrowError::InvalidArgumentError(
+                "Uuid extension type expects no metadata".to_owned(),
+            )),
+        }
+    }
+
+    /// Checks that the storage type is `FixedSizeBinary(16)`.
+    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
+        match data_type {
+            DataType::FixedSizeBinary(16) => Ok(()),
+            data_type => Err(ArrowError::InvalidArgumentError(format!(
+                "Uuid data type mismatch, expected FixedSizeBinary(16), found {data_type}"
+            ))),
+        }
+    }
+
+    fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result<Self, ArrowError> {
+        Self.supports_data_type(data_type).map(|_| Self)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        extension::{CanonicalExtensionType, EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
+        Field,
+    };
+
+    use super::*;
+
+    #[test]
+    fn uuid() -> Result<(), ArrowError> {
+        let mut field = Field::new("", DataType::FixedSizeBinary(16), false);
+        field.try_with_extension_type(Uuid)?;
+        assert!(field.try_extension_type::<Uuid>().is_ok());
+        assert_eq!(
+            field.try_canonical_extension_type().unwrap(),
+            CanonicalExtensionType::Uuid(Uuid)
+        );
+        Ok(())
+    }
+
+    #[test]
+    #[should_panic(expected = "expected FixedSizeBinary(16), found FixedSizeBinary(8)")]
+    fn uuid_bad_type() {
+        Field::new("", DataType::FixedSizeBinary(8), false).with_extension_type(Uuid);
+    }
+
+    #[test]
+    fn uuid_with_metadata() {
+        // Add metadata that's not expected for uuid. The extension name is set
+        // explicitly alongside it, so that the name is present and only the
+        // unexpected metadata can make the lookup fail.
+        let field = Field::new("", DataType::FixedSizeBinary(16), false).with_metadata(
+            [
+                (EXTENSION_TYPE_NAME_KEY.to_owned(), Uuid::NAME.to_owned()),
+                (EXTENSION_TYPE_METADATA_KEY.to_owned(), "".to_owned()),
+            ]
+            .into_iter()
+            .collect(),
+        );
+        // This returns an error now because `Uuid` expects no metadata.
+        assert!(field.try_extension_type::<Uuid>().is_err());
+        assert!(Uuid::deserialize_metadata(Some("")).is_err());
+    }
+}
diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs
new file mode 100644
index 000000000000..8730d6765715
--- /dev/null
+++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs
@@ -0,0 +1,186 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! VariableShapeTensor
+//!
+//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html>
+
+use std::sync::Arc;
+
+use serde::{Deserialize, Serialize};
+
+use crate::{extension::ExtensionType, ArrowError, DataType, Field};
+
+/// The extension type for `VariableShapeTensor`.
+/// +/// +/// +/// +#[derive(Debug, Clone, PartialEq)] +pub struct VariableShapeTensor { + /// The data type of individual tensor elements. + value_type: DataType, + + /// The number of dimensions of the tensor. + dimensions: usize, + + /// The metadata of this extension type. + metadata: VariableShapeTensorMetadata, +} + +impl VariableShapeTensor { + /// Returns a new variable shape tensor extension type. + /// + /// # Error + /// + /// Return an error if the provided dimension names or permutations are + /// invalid. + pub fn try_new( + _value_type: DataType, + _dimensions: usize, + _dimension_names: Option>>, + _permutations: Option>>, + ) -> Result { + todo!() + } + + /// Returns the value type of the individual tensor elements. + pub fn value_type(&self) -> &DataType { + &self.value_type + } + + /// Returns the number of dimensions in this variable shape tensor. + pub fn dimensions(&self) -> usize { + self.dimensions + } + + /// Returns the names of the dimensions in this variable shape tensor, if + /// set. + pub fn dimension_names(&self) -> Option<&[String]> { + self.metadata.dimension_names() + } + + /// Returns the indices of the desired ordering of the original + /// dimensions, if set. + pub fn permutations(&self) -> Option<&[usize]> { + self.metadata.permutations() + } +} + +/// Extension type metadata for [`VariableShapeTensor`]. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct VariableShapeTensorMetadata { + /// Explicit names to tensor dimensions. + dim_names: Option>, + + /// Indices of the desired ordering of the original dimensions. + permutations: Option>, +} + +impl VariableShapeTensorMetadata { + /// Returns the names of the dimensions in this variable shape tensor, if + /// set. + pub fn dimension_names(&self) -> Option<&[String]> { + self.dim_names.as_ref().map(AsRef::as_ref) + } + + /// Returns the indices of the desired ordering of the original + /// dimensions, if set. 
+ pub fn permutations(&self) -> Option<&[usize]> { + self.permutations.as_ref().map(AsRef::as_ref) + } +} + +impl ExtensionType for VariableShapeTensor { + const NAME: &str = "arrow.variable_shape_tensor"; + + type Metadata = VariableShapeTensorMetadata; + + fn metadata(&self) -> &Self::Metadata { + &self.metadata + } + + fn serialize_metadata(&self) -> Option { + Some(serde_json::to_string(self.metadata()).expect("metadata serialization")) + } + + fn deserialize_metadata(metadata: Option<&str>) -> Result { + metadata.map_or_else( + || { + Err(ArrowError::InvalidArgumentError( + "VariableShapeTensor extension types requires metadata".to_owned(), + )) + }, + |value| { + serde_json::from_str(value).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor metadata deserialization failed: {e}" + )) + }) + }, + ) + } + + fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { + let expected = DataType::Struct( + [ + Field::new_list( + "data", + Field::new_list_field(self.value_type.clone(), false), + false, + ), + Field::new( + "shape", + DataType::new_fixed_size_list( + DataType::Int32, + i32::try_from(self.dimensions()).expect("overflow"), + false, + ), + false, + ), + ] + .into_iter() + .map(Arc::new) + .collect(), + ); + data_type + .equals_datatype(&expected) + .then_some(()) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor data type mismatch, expected {expected}, found {data_type}" + )) + }) + } + + fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result { + match data_type { + DataType::Struct(fields) + if fields.len() == 2 + && matches!(fields.find("data"), Some((0, _))) + && matches!(fields.find("shape"), Some((1, _))) => + { + let _data_field = &fields[0]; + let _shape_field = &fields[1]; + todo!() + } + data_type => Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found 
{data_type}" + ))), + } + } +} diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs new file mode 100644 index 000000000000..583334579229 --- /dev/null +++ b/arrow-schema/src/extension/mod.rs @@ -0,0 +1,250 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Extension types. + +#[cfg(feature = "canonical-extension-types")] +mod canonical; +#[cfg(feature = "canonical-extension-types")] +pub use canonical::*; + +use crate::{ArrowError, DataType}; + +/// The metadata key for the string name identifying an [`ExtensionType`]. +pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; + +/// The metadata key for a serialized representation of the [`ExtensionType`] +/// necessary to reconstruct the custom type. +pub const EXTENSION_TYPE_METADATA_KEY: &str = "ARROW:extension:metadata"; + +/// Extension types. +/// +/// User-defined “extension” types can be defined setting certain key value +/// pairs in the [`Field`] metadata structure. These extension keys are: +/// - [`EXTENSION_TYPE_NAME_KEY`] +/// - [`EXTENSION_TYPE_METADATA_KEY`] +/// +/// Canonical extension types support in this crate requires the +/// `canonical-extension-types` feature. 
+/// +/// Extension types may or may not use the [`EXTENSION_TYPE_METADATA_KEY`] +/// field. +/// +/// # Example +/// +/// The example below demonstrates how to implement this trait for a `Uuid` +/// type. Note this is not the canonical extension type for `Uuid`, which does +/// not include information about the `Uuid` version. +/// +/// ``` +/// # use arrow_schema::ArrowError; +/// # fn main() -> Result<(), ArrowError> { +/// use arrow_schema::{DataType, extension::ExtensionType, Field}; +/// use std::{fmt, str::FromStr}; +/// +/// /// The different Uuid versions. +/// #[derive(Clone, Copy, Debug, PartialEq)] +/// enum UuidVersion { +/// V1, +/// V2, +/// V3, +/// V4, +/// V5, +/// V6, +/// V7, +/// V8, +/// } +/// +/// // We'll use `Display` to serialize. +/// impl fmt::Display for UuidVersion { +/// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +/// write!( +/// f, +/// "{}", +/// match self { +/// Self::V1 => "V1", +/// Self::V2 => "V2", +/// Self::V3 => "V3", +/// Self::V4 => "V4", +/// Self::V5 => "V5", +/// Self::V6 => "V6", +/// Self::V7 => "V7", +/// Self::V8 => "V8", +/// } +/// ) +/// } +/// } +/// +/// // And `FromStr` to deserialize. +/// impl FromStr for UuidVersion { +/// type Err = ArrowError; +/// +/// fn from_str(s: &str) -> Result { +/// match s { +/// "V1" => Ok(Self::V1), +/// "V2" => Ok(Self::V2), +/// "V3" => Ok(Self::V3), +/// "V4" => Ok(Self::V4), +/// "V5" => Ok(Self::V5), +/// "V6" => Ok(Self::V6), +/// "V7" => Ok(Self::V7), +/// "V8" => Ok(Self::V8), +/// _ => Err(ArrowError::ParseError("Invalid UuidVersion".to_owned())), +/// } +/// } +/// } +/// +/// /// This is the extension type, not the container for Uuid values. It +/// /// stores the Uuid version (this is the metadata of this extension type). +/// #[derive(Clone, Copy, Debug, PartialEq)] +/// struct Uuid(UuidVersion); +/// +/// impl ExtensionType for Uuid { +/// // We use a namespace as suggested by the specification. 
+/// const NAME: &str = "myorg.example.uuid"; +/// +/// // The metadata type is the Uuid version. +/// type Metadata = UuidVersion; +/// +/// // We just return a reference to the Uuid version. +/// fn metadata(&self) -> &Self::Metadata { +/// &self.0 +/// } +/// +/// // We use the `Display` implementation to serialize the Uuid +/// // version. +/// fn serialize_metadata(&self) -> Option { +/// Some(self.0.to_string()) +/// } +/// +/// // We use the `FromStr` implementation to deserialize the Uuid +/// // version. +/// fn deserialize_metadata(metadata: Option<&str>) -> Result { +/// metadata.map_or_else( +/// || { +/// Err(ArrowError::InvalidArgumentError( +/// "Uuid extension type metadata missing".to_owned(), +/// )) +/// }, +/// str::parse, +/// ) +/// } +/// +/// // The only supported data type is `FixedSizeBinary(16)`. +/// fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { +/// match data_type { +/// DataType::FixedSizeBinary(16) => Ok(()), +/// data_type => Err(ArrowError::InvalidArgumentError(format!( +/// "Uuid data type mismatch, expected FixedSizeBinary(16), found {data_type}" +/// ))), +/// } +/// } +/// +/// // We should always check if the data type is supported before +/// // constructing the extension type. +/// fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result { +/// let uuid = Self(metadata); +/// uuid.supports_data_type(data_type)?; +/// Ok(uuid) +/// } +/// } +/// +/// // We can now construct the extension type. +/// let uuid_v1 = Uuid(UuidVersion::V1); +/// +/// // And add it to a field. +/// let mut field = +/// Field::new("", DataType::FixedSizeBinary(16), false).with_extension_type(uuid_v1); +/// +/// // And extract it from this field. +/// assert_eq!(field.try_extension_type::()?, uuid_v1); +/// +/// // When we try to add this to a field with an unsupported data type we +/// // get an error. 
+/// let result = Field::new("", DataType::Null, false).try_with_extension_type(uuid_v1); +/// assert!(result.is_err()); +/// # Ok(()) } +/// ``` +/// +/// +/// +/// [`Field`]: crate::Field +pub trait ExtensionType: Sized { + /// The name identifying this extension type. + /// + /// This is the string value that is used for the + /// [`EXTENSION_TYPE_NAME_KEY`] in the [`Field::metadata`] of a [`Field`] + /// to identify this extension type. + /// + /// We recommend that you use a “namespace”-style prefix for extension + /// type names to minimize the possibility of conflicts with multiple Arrow + /// readers and writers in the same application. For example, use + /// `myorg.name_of_type` instead of simply `name_of_type`. + /// + /// Extension names beginning with `arrow.` are reserved for canonical + /// extension types, they should not be used for third-party extension + /// types. + /// + /// [`Field`]: crate::Field + /// [`Field::metadata`]: crate::Field::metadata + const NAME: &str; + + /// The metadata type of this extension type. + /// + /// If an extension type defines no metadata it should use `()` to indicate + /// this. + type Metadata; + + /// Returns a reference to the metadata of this extension type, or `&()` if + /// if this extension type defines no metadata (`Self::Metadata=()`). + fn metadata(&self) -> &Self::Metadata; + + /// Returns the serialized representation of the metadata of this extension + /// type, or `None` if this extension type defines no metadata + /// (`Self::Metadata=()`). + /// + /// This is string value that is used for the + /// [`EXTENSION_TYPE_METADATA_KEY`] in the [`Field::metadata`] of a + /// [`Field`]. + /// + /// [`Field`]: crate::Field + /// [`Field::metadata`]: crate::Field::metadata + fn serialize_metadata(&self) -> Option; + + /// Deserialize the metadata of this extension type from the serialized + /// representation of the metadata. 
An extension type that defines no + /// metadata should expect `None` for the serialized metadata and return + /// `Ok(())`. + /// + /// This function should return an error when + /// - expected metadata is missing (for extensions types with non-optional + /// metadata) + /// - unexpected metadata is set (for extension types without metadata) + /// - deserialization of metadata fails + fn deserialize_metadata(metadata: Option<&str>) -> Result; + + /// Returns `OK())` iff the given data type is supported by this extension + /// type. + fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError>; + + /// Construct this extension type for a field with the given data type and + /// metadata. + /// + /// This should return an error if the given data type is not supported by + /// this extension type. + fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result; +} diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index f16e2f9bbc05..5d373496fae7 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::canonical_extension_types::{CanonicalExtensionTypes, Json, Uuid}; use crate::error::ArrowError; use std::cmp::Ordering; use std::collections::HashMap; @@ -23,10 +22,12 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::datatype::DataType; +#[cfg(feature = "canonical-extension-types")] +use crate::extension::CanonicalExtensionType; use crate::schema::SchemaBuilder; use crate::{ - ExtensionType, ExtensionTypeExt, Fields, UnionFields, UnionMode, EXTENSION_TYPE_METADATA_KEY, - EXTENSION_TYPE_NAME_KEY, + extension::{ExtensionType, EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Fields, UnionFields, UnionMode, }; /// A reference counted [`Field`] @@ -341,48 +342,139 @@ impl Field { self } - /// Returns the given [`ExtensionType`] of this [`Field`], if set. 
- /// Returns `None` if this field does not have this extension type. - pub fn extension_type(&self) -> Option { - E::try_from_field(self) + /// Returns the extension type name of this [`Field`], if set. + /// + /// This returns the value of [`EXTENSION_TYPE_NAME_KEY`], if set in + /// [`Field::metadata`]. If the key is missing, there is no extension type + /// name and this returns `None`. + /// + /// # Example + /// + /// ``` + /// # use arrow_schema::{DataType, extension::EXTENSION_TYPE_NAME_KEY, Field}; + /// + /// let field = Field::new("", DataType::Null, false); + /// assert_eq!(field.extension_type_name(), None); + /// + /// let field = Field::new("", DataType::Null, false).with_metadata( + /// [(EXTENSION_TYPE_NAME_KEY.to_owned(), "example".to_owned())] + /// .into_iter() + /// .collect(), + /// ); + /// assert_eq!(field.extension_type_name(), Some("example")); + /// ``` + pub fn extension_type_name(&self) -> Option<&str> { + self.metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_ref) } - /// Returns the [`CanonicalExtensionTypes`] of this [`Field`], if set. - pub fn canonical_extension_type(&self) -> Option { - Json::try_from_field(self) - .map(Into::into) - .or(Uuid::try_from_field(self).map(Into::into)) + /// Returns the extension type metadata of this [`Field`], if set. + /// + /// This returns the value of [`EXTENSION_TYPE_METADATA_KEY`], if set in + /// [`Field::metadata`]. If the key is missing, there is no extension type + /// metadata and this returns `None`. 
+ /// + /// # Example + /// + /// ``` + /// # use arrow_schema::{DataType, extension::EXTENSION_TYPE_METADATA_KEY, Field}; + /// + /// let field = Field::new("", DataType::Null, false); + /// assert_eq!(field.extension_type_metadata(), None); + /// + /// let field = Field::new("", DataType::Null, false).with_metadata( + /// [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "example".to_owned())] + /// .into_iter() + /// .collect(), + /// ); + /// assert_eq!(field.extension_type_metadata(), Some("example")); + /// ``` + pub fn extension_type_metadata(&self) -> Option<&str> { + self.metadata() + .get(EXTENSION_TYPE_METADATA_KEY) + .map(String::as_ref) + } + + /// Returns an instance of the given [`ExtensionType`] of this [`Field`], + /// if set in the [`Field::metadata`]. + /// + /// # Error + /// + /// Returns an error if + /// - this field does not have the name of this extension type + /// ([`ExtensionType::NAME`]) in the [`Field::metadata`] (mismatch or + /// missing) + /// - the deserialization of the metadata + /// ([`ExtensionType::deserialize_metadata`]) fails + /// - the construction of the extension type ([`ExtensionType::try_new`]) + /// fail (for example when the [`Field::data_type`] is not supported by + /// the extension type ([`ExtensionType::supports_data_type`])) + pub fn try_extension_type(&self) -> Result { + // Check the extension name in the metadata + match self.extension_type_name() { + // It should match the name of the given extension type + Some(name) if name == E::NAME => { + // Deserialize the metadata and try to construct the extension + // type + E::deserialize_metadata(self.extension_type_metadata()) + .and_then(|metadata| E::try_new(self.data_type(), metadata)) + } + // Name mismatch + Some(name) => Err(ArrowError::InvalidArgumentError(format!( + "Field extension type name mismatch, expected {}, found {name}", + E::NAME + ))), + // Name missing + None => Err(ArrowError::InvalidArgumentError( + "Field extension type name missing".to_owned(), + 
)), + } + } + + /// Returns an instance of the given [`ExtensionType`] of this [`Field`], + /// panics if this [`Field`] does not have this extension type. + /// + /// # Panic + /// + /// This calls [`Field::try_extension_type`] and panics when it returns an + /// error. + pub fn extension_type(&self) -> E { + self.try_extension_type::() + .unwrap_or_else(|e| panic!("{e}")) } /// Updates the metadata of this [`Field`] with the [`ExtensionType::NAME`] - /// and [`ExtensionType::metadata`] of the given [`ExtensionType`]. + /// and [`ExtensionType::metadata`] of the given [`ExtensionType`], if the + /// given extension type supports the [`Field::data_type`] of this field + /// ([`ExtensionType::supports_data_type`]). + /// + /// If the given extension type defines no metadata, a previously set + /// value of [`EXTENSION_TYPE_METADATA_KEY`] is cleared. /// /// # Error /// - /// This functions returns an error if the datatype of this field does not - /// match the storage type of the given extension type. + /// This functions returns an error if the data type of this field does not + /// match any of the supported storage types of the given extension type. 
pub fn try_with_extension_type( &mut self, extension_type: E, ) -> Result<(), ArrowError> { - if extension_type.supports(&self.data_type) { - // Insert the name - self.metadata - .insert(EXTENSION_TYPE_NAME_KEY.to_owned(), E::NAME.to_owned()); - // Insert the metadata, if any - if let Some(metadata) = extension_type.serialized_metadata() { - self.metadata - .insert(EXTENSION_TYPE_METADATA_KEY.to_owned(), metadata); - } - Ok(()) - } else { - Err(ArrowError::InvalidArgumentError(format!( - "storage type of extension type {} does not match field data type, expected {}, found {}", - ::NAME, - extension_type.storage_types().iter().map(ToString::to_string).collect::>().join(" or "), - self.data_type - ))) - } + // Make sure the data type of this field is supported + extension_type.supports_data_type(&self.data_type)?; + + self.metadata + .insert(EXTENSION_TYPE_NAME_KEY.to_owned(), E::NAME.to_owned()); + match extension_type.serialize_metadata() { + Some(metadata) => self + .metadata + .insert(EXTENSION_TYPE_METADATA_KEY.to_owned(), metadata), + // If this extension type has no metadata, we make sure to + // clear previously set metadata. + None => self.metadata.remove(EXTENSION_TYPE_METADATA_KEY), + }; + + Ok(()) } /// Updates the metadata of this [`Field`] with the [`ExtensionType::NAME`] @@ -390,14 +482,44 @@ impl Field { /// /// # Panics /// - /// This functions panics if the datatype of this field does match the - /// storage type of the given extension type. + /// This calls [`Field::try_with_extension_type`] and panics when it + /// returns an error. pub fn with_extension_type(mut self, extension_type: E) -> Self { self.try_with_extension_type(extension_type) .unwrap_or_else(|e| panic!("{e}")); self } + /// Returns the [`CanonicalExtensionType`] of this [`Field`], if set. 
+ /// + /// # Error + /// + /// Returns an error if + /// - this field does have a canonical extension type (mismatch or missing) + /// - the canonical extension is not supported + /// - the construction of the extension type fails + #[cfg(feature = "canonical-extension-types")] + pub fn try_canonical_extension_type(&self) -> Result { + use crate::extension::{FixedShapeTensor, Json, Uuid}; + + // Canonical extension type names start with `arrow.` + match self.extension_type_name() { + // An extension type name with an `arrow.` prefix + Some(name) if name.starts_with("arrow.") => match name { + FixedShapeTensor::NAME => self.try_extension_type::().map(Into::into), + Json::NAME => self.try_extension_type::().map(Into::into), + Uuid::NAME => self.try_extension_type::().map(Into::into), + _ => Err(ArrowError::InvalidArgumentError(format!("Unsupported canonical extension type: {name}"))), + }, + // Name missing the expected prefix + Some(name) => Err(ArrowError::InvalidArgumentError(format!( + "Field extension type name mismatch, expected a name with an `arrow.` prefix, found {name}" + ))), + // Name missing + None => Err(ArrowError::InvalidArgumentError("Field extension type name missing".to_owned())), + } + } + /// Indicates whether this [`Field`] supports null values. 
#[inline] pub const fn is_nullable(&self) -> bool { diff --git a/arrow-schema/src/lib.rs b/arrow-schema/src/lib.rs index d06382fbcdf7..a83e23e27592 100644 --- a/arrow-schema/src/lib.rs +++ b/arrow-schema/src/lib.rs @@ -25,6 +25,7 @@ use std::fmt::Display; mod datatype_parse; mod error; pub use error::*; +pub mod extension; mod field; pub use field::*; mod fields; diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 1d38e67a0f02..00f42d598be5 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -102,6 +102,8 @@ default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd", "base64"] lz4 = ["lz4_flex"] # Enable arrow reader/writer APIs arrow = ["base64", "arrow-array", "arrow-buffer", "arrow-cast", "arrow-data", "arrow-schema", "arrow-select", "arrow-ipc"] +# Enable support for arrow canonical extension types +arrow-canonical-extension-types = ["arrow-schema?/canonical-extension-types"] # Enable CLI tools cli = ["json", "base64", "clap", "arrow-csv", "serde"] # Enable JSON APIs diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 8a15037825d0..82bcc8db6a8e 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -23,13 +23,14 @@ //! //! The interfaces for converting arrow schema to parquet schema is coming. -use arrow_schema::canonical_extension_types::Uuid; use base64::prelude::BASE64_STANDARD; use base64::Engine; use std::collections::HashMap; use std::sync::Arc; use arrow_ipc::writer; +#[cfg(feature = "arrow-canonical-extension-types")] +use arrow_schema::extension::Uuid; use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ @@ -472,8 +473,16 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .with_length(*length) - // If set, map arrow uuid extension type to parquet uuid logical type. 
- .with_logical_type(field.extension_type::().map(|_| LogicalType::Uuid)) + .with_logical_type( + #[cfg(feature = "arrow-canonical-extension-types")] + // If set, map arrow uuid extension type to parquet uuid logical type. + field + .try_extension_type::() + .ok() + .map(|_| LogicalType::Uuid), + #[cfg(not(feature = "arrow-canonical-extension-types"))] + None, + ) .build() } DataType::BinaryView => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) @@ -1937,6 +1946,7 @@ mod tests { } #[test] + #[cfg(feature = "arrow-canonical-extension-types")] fn arrow_uuid_to_parquet_uuid() -> Result<()> { let arrow_schema = Schema::new(vec![Field::new( "uuid", From 0966a0fdfcf496819bceeb2a27941282d95a0cb3 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Sat, 18 Jan 2025 00:17:47 +0100 Subject: [PATCH 07/21] Add `Json` support to parquet, schema roundtrip not working yet --- parquet/src/arrow/schema/mod.rs | 74 +++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 82bcc8db6a8e..b796271887cf 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow_ipc::writer; #[cfg(feature = "arrow-canonical-extension-types")] -use arrow_schema::extension::Uuid; +use arrow_schema::extension::{Json, Uuid}; use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ @@ -276,12 +276,26 @@ pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result ret.try_with_extension_type(Uuid)?, + LogicalType::Json => ret.try_with_extension_type(Json::default())?, + _ => {} + } + } + if !meta.is_empty() { ret.set_metadata(meta); } @@ -516,13 +530,35 @@ fn arrow_to_parquet_type(field: &Field) -> Result { } DataType::Utf8 | DataType::LargeUtf8 => { Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) - .with_logical_type(Some(LogicalType::String)) + .with_logical_type({ 
+ #[cfg(feature = "arrow-canonical-extension-types")] + { + // Use the Json logical type if the canonical Json + // extension type is set on this field. + field + .try_extension_type::() + .map_or(Some(LogicalType::String), |_| Some(LogicalType::Json)) + } + #[cfg(not(feature = "arrow-canonical-extension-types"))] + Some(LogicalType::String) + }) .with_repetition(repetition) .with_id(id) .build() } DataType::Utf8View => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) - .with_logical_type(Some(LogicalType::String)) + .with_logical_type({ + #[cfg(feature = "arrow-canonical-extension-types")] + { + // Use the Json logical type if the canonical Json + // extension type is set on this field. + field + .try_extension_type::() + .map_or(Some(LogicalType::String), |_| Some(LogicalType::Json)) + } + #[cfg(not(feature = "arrow-canonical-extension-types"))] + Some(LogicalType::String) + }) .with_repetition(repetition) .with_id(id) .build(), @@ -1962,6 +1998,36 @@ mod tests { Some(LogicalType::Uuid) ); + let arrow_schema = parquet_to_arrow_schema(&parquet_schema, None)?; + dbg!(&arrow_schema); + + assert_eq!(arrow_schema.field(0).try_extension_type::()?, Uuid); + + Ok(()) + } + + #[test] + #[cfg(feature = "arrow-canonical-extension-types")] + fn arrow_json_to_parquet_json() -> Result<()> { + let arrow_schema = Schema::new(vec![ + Field::new("json", DataType::Utf8, false).with_extension_type(Json::default()) + ]); + + let parquet_schema = arrow_to_parquet_schema(&arrow_schema)?; + + assert_eq!( + parquet_schema.column(0).logical_type(), + Some(LogicalType::Json) + ); + + let arrow_schema = parquet_to_arrow_schema(&parquet_schema, None)?; + dbg!(&arrow_schema); + + assert_eq!( + arrow_schema.field(0).try_extension_type::()?, + Json::default() + ); + Ok(()) } } From f5c06b11bb831187af65016721c0550a08aecb32 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Sat, 18 Jan 2025 00:28:15 +0100 Subject: [PATCH 08/21] Fix some clippy warnings --- 
arrow-schema/src/extension/canonical/opaque.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index 283ac1e486d6..69b2231eea5b 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -45,12 +45,12 @@ impl Opaque { /// Returns the name of the unknown type in the external system. pub fn type_name(&self) -> &str { - &self.0.type_name() + self.0.type_name() } /// Returns the name of the external system. pub fn vendor_name(&self) -> &str { - &self.0.vendor_name() + self.0.vendor_name() } } From b602412993f6c5078946405104c82dabf89b17c5 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Sat, 18 Jan 2025 00:30:32 +0100 Subject: [PATCH 09/21] Add explicit lifetime, resolving elided lifetime to static in assoc const was added in 1.81 --- arrow-schema/src/extension/canonical/bool8.rs | 2 +- arrow-schema/src/extension/canonical/fixed_shape_tensor.rs | 2 +- arrow-schema/src/extension/canonical/json.rs | 2 +- arrow-schema/src/extension/canonical/opaque.rs | 2 +- arrow-schema/src/extension/canonical/uuid.rs | 2 +- arrow-schema/src/extension/canonical/variable_shape_tensor.rs | 2 +- arrow-schema/src/extension/mod.rs | 4 ++-- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs index 0272752f35a8..d53a3bb875b1 100644 --- a/arrow-schema/src/extension/canonical/bool8.rs +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -34,7 +34,7 @@ use crate::{extension::ExtensionType, ArrowError, DataType}; pub struct Bool8; impl ExtensionType for Bool8 { - const NAME: &str = "arrow.bool8"; + const NAME: &'static str = "arrow.bool8"; type Metadata = &'static str; diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs index 
abbfe3f6978f..faeb64258cbb 100644 --- a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs @@ -160,7 +160,7 @@ impl FixedShapeTensorMetadata { } impl ExtensionType for FixedShapeTensor { - const NAME: &str = "arrow.fixed_shape_tensor"; + const NAME: &'static str = "arrow.fixed_shape_tensor"; type Metadata = FixedShapeTensorMetadata; diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index 8e89d611c5ff..c3131846d1e5 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ b/arrow-schema/src/extension/canonical/json.rs @@ -52,7 +52,7 @@ impl Default for JsonMetadata { } impl ExtensionType for Json { - const NAME: &str = "arrow.json"; + const NAME: &'static str = "arrow.json"; type Metadata = JsonMetadata; diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index 69b2231eea5b..38ba6d5ea7aa 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -91,7 +91,7 @@ impl OpaqueMetadata { } impl ExtensionType for Opaque { - const NAME: &str = "arrow.opaque"; + const NAME: &'static str = "arrow.opaque"; type Metadata = OpaqueMetadata; diff --git a/arrow-schema/src/extension/canonical/uuid.rs b/arrow-schema/src/extension/canonical/uuid.rs index 206856265ae5..7190efeeb38b 100644 --- a/arrow-schema/src/extension/canonical/uuid.rs +++ b/arrow-schema/src/extension/canonical/uuid.rs @@ -38,7 +38,7 @@ use crate::{extension::ExtensionType, ArrowError, DataType}; pub struct Uuid; impl ExtensionType for Uuid { - const NAME: &str = "arrow.uuid"; + const NAME: &'static str = "arrow.uuid"; type Metadata = (); diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs index 8730d6765715..9473893797a2 100644 --- 
a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -106,7 +106,7 @@ impl VariableShapeTensorMetadata { } impl ExtensionType for VariableShapeTensor { - const NAME: &str = "arrow.variable_shape_tensor"; + const NAME: &'static str = "arrow.variable_shape_tensor"; type Metadata = VariableShapeTensorMetadata; diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index 583334579229..76bdff33545a 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -115,7 +115,7 @@ pub const EXTENSION_TYPE_METADATA_KEY: &str = "ARROW:extension:metadata"; /// /// impl ExtensionType for Uuid { /// // We use a namespace as suggested by the specification. -/// const NAME: &str = "myorg.example.uuid"; +/// const NAME: &'static str = "myorg.example.uuid"; /// /// // The metadata type is the Uuid version. /// type Metadata = UuidVersion; @@ -201,7 +201,7 @@ pub trait ExtensionType: Sized { /// /// [`Field`]: crate::Field /// [`Field::metadata`]: crate::Field::metadata - const NAME: &str; + const NAME: &'static str; /// The metadata type of this extension type. 
/// From 81594d92f2dca09ce1b0f65c1f80fed21f7c1139 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Mon, 20 Jan 2025 12:52:09 +0100 Subject: [PATCH 10/21] Replace use of deprecated method, mark roundtrip as todo --- parquet/src/arrow/schema/mod.rs | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 215936bc0b10..d7748537cb6f 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -2222,17 +2222,16 @@ mod tests { ) .with_extension_type(Uuid)]); - let parquet_schema = arrow_to_parquet_schema(&arrow_schema)?; + let parquet_schema = ArrowSchemaConverter::new().convert(&arrow_schema)?; assert_eq!( parquet_schema.column(0).logical_type(), Some(LogicalType::Uuid) ); - let arrow_schema = parquet_to_arrow_schema(&parquet_schema, None)?; - dbg!(&arrow_schema); - - assert_eq!(arrow_schema.field(0).try_extension_type::()?, Uuid); + // TODO: roundtrip + // let arrow_schema = parquet_to_arrow_schema(&parquet_schema, None)?; + // assert_eq!(arrow_schema.field(0).try_extension_type::()?, Uuid); Ok(()) } @@ -2244,20 +2243,19 @@ mod tests { Field::new("json", DataType::Utf8, false).with_extension_type(Json::default()) ]); - let parquet_schema = arrow_to_parquet_schema(&arrow_schema)?; + let parquet_schema = ArrowSchemaConverter::new().convert(&arrow_schema)?; assert_eq!( parquet_schema.column(0).logical_type(), Some(LogicalType::Json) ); - let arrow_schema = parquet_to_arrow_schema(&parquet_schema, None)?; - dbg!(&arrow_schema); - - assert_eq!( - arrow_schema.field(0).try_extension_type::()?, - Json::default() - ); + // TODO: roundtrip + // let arrow_schema = parquet_to_arrow_schema(&parquet_schema, None)?; + // assert_eq!( + // arrow_schema.field(0).try_extension_type::()?, + // Json::default() + // ); Ok(()) } From bb7c86a8fff798e6d7bbd0c7c4471eebfcc2d016 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Mon, 20 Jan 2025 17:44:28 
+0100 Subject: [PATCH 11/21] Add more tests and missing impls --- arrow-schema/src/extension/canonical/bool8.rs | 71 ++++ .../extension/canonical/fixed_shape_tensor.rs | 268 ++++++++++-- arrow-schema/src/extension/canonical/json.rs | 48 ++- arrow-schema/src/extension/canonical/mod.rs | 41 ++ .../src/extension/canonical/opaque.rs | 70 ++++ arrow-schema/src/extension/canonical/uuid.rs | 42 +- .../canonical/variable_shape_tensor.rs | 396 +++++++++++++++++- arrow-schema/src/extension/mod.rs | 2 + arrow-schema/src/field.rs | 19 +- 9 files changed, 851 insertions(+), 106 deletions(-) diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs index d53a3bb875b1..acad27d35030 100644 --- a/arrow-schema/src/extension/canonical/bool8.rs +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -70,3 +70,74 @@ impl ExtensionType for Bool8 { Self.supports_data_type(data_type).map(|_| Self) } } + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical-extension-types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let mut field = Field::new("", DataType::Int8, false); + field.try_with_extension_type(Bool8)?; + field.try_extension_type::()?; + #[cfg(feature = "canonical-extension-types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::Bool8(Bool8) + ); + + Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new("", DataType::Int8, false).with_metadata( + [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "".to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "expected Int8, found Boolean")] + fn invalid_type() { + Field::new("", DataType::Boolean, false).with_extension_type(Bool8); + } + + 
#[test] + #[should_panic(expected = "Bool8 extension type expects an empty string as metadata")] + fn missing_metadata() { + let field = Field::new("", DataType::Int8, false).with_metadata( + [(EXTENSION_TYPE_NAME_KEY.to_owned(), Bool8::NAME.to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "Bool8 extension type expects an empty string as metadata")] + fn invalid_metadata() { + let field = Field::new("", DataType::Int8, false).with_metadata( + [ + (EXTENSION_TYPE_NAME_KEY.to_owned(), Bool8::NAME.to_owned()), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + "non-empty".to_owned(), + ), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } +} diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs index faeb64258cbb..f1ce85b54a57 100644 --- a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs @@ -86,12 +86,18 @@ impl FixedShapeTensor { /// Return an error if the provided dimension names or permutations are /// invalid. pub fn try_new( - _value_type: DataType, - _shape: impl IntoIterator, - _dimension_names: Option>>, - _permutations: Option>>, + value_type: DataType, + shape: impl IntoIterator, + dimension_names: Option>, + permutations: Option>, ) -> Result { - todo!() + // TODO: are all data types are suitable as value type? + FixedShapeTensorMetadata::try_new(shape, dimension_names, permutations).map(|metadata| { + Self { + value_type, + metadata, + } + }) } /// Returns the value type of the individual tensor elements. @@ -136,6 +142,58 @@ pub struct FixedShapeTensorMetadata { } impl FixedShapeTensorMetadata { + /// Returns metadata for a fixed shape tensor extension type. + /// + /// # Error + /// + /// Return an error if the provided dimension names or permutations are + /// invalid. 
+ pub fn try_new( + shape: impl IntoIterator, + dimension_names: Option>, + permutations: Option>, + ) -> Result { + let shape = shape.into_iter().collect::>(); + let dimensions = shape.len(); + + let dim_names = dimension_names.map(|dimension_names| { + if dimension_names.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len() + ))) + } else { + Ok(dimension_names) + } + }).transpose()?; + + let permutations = permutations + .map(|permutations| { + if permutations.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor permutations size mismatch, expected {dimensions}, found {}", + permutations.len() + ))) + } else { + let mut sorted_permutations = permutations.clone(); + sorted_permutations.sort_unstable(); + if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) { + Err(ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}" + ))) + } else { + Ok(permutations) + } + } + }) + .transpose()?; + + Ok(Self { + shape, + dim_names, + permutations, + }) + } + /// Returns the product of all the elements in tensor shape. pub fn list_size(&self) -> usize { self.shape.iter().product() @@ -208,46 +266,24 @@ impl ExtensionType for FixedShapeTensor { fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result { match data_type { DataType::FixedSizeList(field, list_size) if !field.is_nullable() => { - // Make sure the shape matches + // Make sure the metadata is valid. + let metadata = FixedShapeTensorMetadata::try_new( + metadata.shape, + metadata.dim_names, + metadata.permutations, + )?; + // Make sure it is compatible with this data type. 
let expected_size = i32::try_from(metadata.list_size()).expect("overflow"); if *list_size != expected_size { - return Err(ArrowError::InvalidArgumentError(format!( + Err(ArrowError::InvalidArgumentError(format!( "FixedShapeTensor list size mismatch, expected {expected_size} (metadata), found {list_size} (data type)" - ))); - } - // Make sure the dim names size is correct, if set. - if let Some(dim_names_size) = metadata.dimension_names().map(<[_]>::len) { - let expected_size = metadata.dimensions(); - if dim_names_size != expected_size { - return Err(ArrowError::InvalidArgumentError(format!( - "FixedShapeTensor dimension names size mismatch, expected {expected_size}, found {dim_names_size}" - ))); - } - } - // Make sure the permutations are correct, if set. - if let Some(permutations) = metadata.permutations() { - let expected_size = metadata.dimensions(); - if permutations.len() != expected_size { - return Err(ArrowError::InvalidArgumentError(format!( - "FixedShapeTensor permutations size mismatch, expected {expected_size}, found {}", - permutations.len() - ))); - } - // Check if the permutations are valid. 
- let mut permutations = permutations.to_vec(); - permutations.sort_unstable(); - let dimensions = metadata.dimensions(); - if (0..dimensions).zip(permutations).any(|(a, b)| a != b) { - return Err(ArrowError::InvalidArgumentError(format!( - "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}" - ))); - } + ))) + } else { + Ok(Self { + value_type: field.data_type().clone(), + metadata, + }) } - - Ok(Self { - value_type: field.data_type().clone(), - metadata, - }) } data_type => Err(ArrowError::InvalidArgumentError(format!( "FixedShapeTensor data type mismatch, expected FixedSizeList with non-nullable field, found {data_type}" @@ -255,3 +291,153 @@ impl ExtensionType for FixedShapeTensor { } } } + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical-extension-types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let fixed_shape_tensor = FixedShapeTensor::try_new( + DataType::Float32, + [100, 200, 500], + Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]), + Some(vec![2, 0, 1]), + )?; + let mut field = Field::new_fixed_size_list( + "", + Field::new("", DataType::Float32, false), + i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"), + false, + ); + field.try_with_extension_type(fixed_shape_tensor.clone())?; + assert_eq!( + field.try_extension_type::()?, + fixed_shape_tensor + ); + #[cfg(feature = "canonical-extension-types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::FixedShapeTensor(fixed_shape_tensor) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = + Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false) + .with_metadata( + [( + 
EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "shape": [100, 200, 500], }"#.to_owned(), + )] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "FixedShapeTensor data type mismatch, expected FixedSizeList")] + fn invalid_type() { + let fixed_shape_tensor = + FixedShapeTensor::try_new(DataType::Int32, [100, 200, 500], None, None).unwrap(); + let field = Field::new_fixed_size_list( + "", + Field::new("", DataType::Float32, false), + i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"), + false, + ); + field.with_extension_type(fixed_shape_tensor); + } + + #[test] + #[should_panic(expected = "FixedShapeTensor extension types requires metadata")] + fn missing_metadata() { + let field = + Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false) + .with_metadata( + [( + EXTENSION_TYPE_NAME_KEY.to_owned(), + FixedShapeTensor::NAME.to_owned(), + )] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic( + expected = "FixedShapeTensor metadata deserialization failed: missing field `shape`" + )] + fn invalid_metadata() { + let fixed_shape_tensor = + FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, None).unwrap(); + let field = Field::new_fixed_size_list( + "", + Field::new("", DataType::Float32, false), + i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"), + false, + ) + .with_metadata( + [ + ( + EXTENSION_TYPE_NAME_KEY.to_owned(), + FixedShapeTensor::NAME.to_owned(), + ), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "not-shape": [] }"#.to_owned(), + ), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic( + expected = "FixedShapeTensor dimension names size mismatch, expected 3, found 2" + )] + fn invalid_metadata_dimension_names() { + FixedShapeTensor::try_new( + DataType::Float32, + [100, 200, 500], + Some(vec!["a".to_owned(), "b".to_owned()]), + None, + ) 
+ .unwrap(); + } + + #[test] + #[should_panic(expected = "FixedShapeTensor permutations size mismatch, expected 3, found 2")] + fn invalid_metadata_permutations_len() { + FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, Some(vec![1, 0])) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3" + )] + fn invalid_metadata_permutations_values() { + FixedShapeTensor::try_new( + DataType::Float32, + [100, 200, 500], + None, + Some(vec![4, 3, 2]), + ) + .unwrap(); + } +} diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index c3131846d1e5..693afc2b48ef 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ b/arrow-schema/src/extension/canonical/json.rs @@ -99,24 +99,26 @@ impl ExtensionType for Json { #[cfg(test)] mod tests { - use serde_json::Map; - + #[cfg(feature = "canonical-extension-types")] + use crate::extension::CanonicalExtensionType; use crate::{ - extension::{CanonicalExtensionType, EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, Field, }; + use serde_json::Map; + use super::*; #[test] - fn json() -> Result<(), ArrowError> { + fn valid() -> Result<(), ArrowError> { let mut field = Field::new("", DataType::Utf8, false); field.try_with_extension_type(Json::default())?; assert_eq!( field.metadata().get(EXTENSION_TYPE_METADATA_KEY), Some(&r#""""#.to_owned()) ); - assert!(field.try_extension_type::().is_ok()); + field.try_extension_type::()?; let mut field = Field::new("", DataType::LargeUtf8, false); field.try_with_extension_type(Json(JsonMetadata(serde_json::Value::Object( @@ -126,26 +128,41 @@ mod tests { field.metadata().get(EXTENSION_TYPE_METADATA_KEY), Some(&"{}".to_owned()) ); - assert!(field.try_extension_type::().is_ok()); + field.try_extension_type::()?; let mut field = 
Field::new("", DataType::Utf8View, false); field.try_with_extension_type(Json::default())?; - assert!(field.try_extension_type::().is_ok()); + field.try_extension_type::()?; + #[cfg(feature = "canonical-extension-types")] assert_eq!( - field.try_canonical_extension_type().unwrap(), + field.try_canonical_extension_type()?, CanonicalExtensionType::Json(Json::default()) ); Ok(()) } + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new("", DataType::Int8, false).with_metadata( + [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + #[test] #[should_panic(expected = "expected one of Utf8, LargeUtf8, Utf8View, found Null")] - fn json_bad_type() { + fn invalid_type() { Field::new("", DataType::Null, false).with_extension_type(Json::default()); } #[test] - fn json_bad_metadata() { + #[should_panic( + expected = "Json extension type metadata is either an empty string or a JSON string with an empty object" + )] + fn invalid_metadata() { let field = Field::new("", DataType::Utf8, false).with_metadata( [ (EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned()), @@ -154,18 +171,19 @@ mod tests { .into_iter() .collect(), ); - // This returns `None` now because this metadata is invalid. - assert!(field.try_extension_type::().is_err()); + field.extension_type::(); } #[test] - fn json_missing_metadata() { + #[should_panic( + expected = "Json extension type metadata is either an empty string or a JSON string with an empty object" + )] + fn missing_metadata() { let field = Field::new("", DataType::LargeUtf8, false).with_metadata( [(EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned())] .into_iter() .collect(), ); - // This returns `None` now because the metadata is missing. 
- assert!(field.try_extension_type::().is_err()); + field.extension_type::(); } } diff --git a/arrow-schema/src/extension/canonical/mod.rs b/arrow-schema/src/extension/canonical/mod.rs index 778da9ef2c46..3d66299ca885 100644 --- a/arrow-schema/src/extension/canonical/mod.rs +++ b/arrow-schema/src/extension/canonical/mod.rs @@ -38,6 +38,10 @@ pub use uuid::Uuid; mod variable_shape_tensor; pub use variable_shape_tensor::{VariableShapeTensor, VariableShapeTensorMetadata}; +use crate::{ArrowError, Field}; + +use super::ExtensionType; + /// Canonical extension types. /// /// @@ -68,6 +72,37 @@ pub enum CanonicalExtensionType { /// /// Opaque(Opaque), + + /// The extension type for `Bool8`. + /// + /// + Bool8(Bool8), +} + +impl TryFrom<&Field> for CanonicalExtensionType { + type Error = ArrowError; + + fn try_from(value: &Field) -> Result { + // Canonical extension type names start with `arrow.` + match value.extension_type_name() { + // An extension type name with an `arrow.` prefix + Some(name) if name.starts_with("arrow.") => match name { + FixedShapeTensor::NAME => value.try_extension_type::().map(Into::into), + VariableShapeTensor::NAME => value.try_extension_type::().map(Into::into), + Json::NAME => value.try_extension_type::().map(Into::into), + Uuid::NAME => value.try_extension_type::().map(Into::into), + Opaque::NAME => value.try_extension_type::().map(Into::into), + Bool8::NAME => value.try_extension_type::().map(Into::into), + _ => Err(ArrowError::InvalidArgumentError(format!("Unsupported canonical extension type: {name}"))), + }, + // Name missing the expected prefix + Some(name) => Err(ArrowError::InvalidArgumentError(format!( + "Field extension type name mismatch, expected a name with an `arrow.` prefix, found {name}" + ))), + // Name missing + None => Err(ArrowError::InvalidArgumentError("Field extension type name missing".to_owned())), + } + } } impl From for CanonicalExtensionType { @@ -99,3 +134,9 @@ impl From for CanonicalExtensionType { 
CanonicalExtensionType::Opaque(value) } } + +impl From for CanonicalExtensionType { + fn from(value: Bool8) -> Self { + CanonicalExtensionType::Bool8(value) + } +} diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index 38ba6d5ea7aa..0d1baeb395a0 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -129,3 +129,73 @@ impl ExtensionType for Opaque { Ok(Self::from(metadata)) } } + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical-extension-types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let opaque = Opaque::new("name", "vendor"); + let mut field = Field::new("", DataType::Null, false); + field.try_with_extension_type(opaque.clone())?; + assert_eq!(field.try_extension_type::()?, opaque); + #[cfg(feature = "canonical-extension-types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::Opaque(opaque) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new("", DataType::Null, false).with_metadata( + [( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "type_name": "type", "vendor_name": "vendor" }"#.to_owned(), + )] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "Opaque extension types requires metadata")] + fn missing_metadata() { + let field = Field::new("", DataType::Null, false).with_metadata( + [(EXTENSION_TYPE_NAME_KEY.to_owned(), Opaque::NAME.to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic( + expected = "Opaque metadata deserialization failed: missing field `vendor_name`" + )] + fn invalid_metadata() { + let field = Field::new("", 
DataType::Null, false).with_metadata( + [ + (EXTENSION_TYPE_NAME_KEY.to_owned(), Opaque::NAME.to_owned()), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "type_name": "no-vendor" }"#.to_owned(), + ), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } +} diff --git a/arrow-schema/src/extension/canonical/uuid.rs b/arrow-schema/src/extension/canonical/uuid.rs index 7190efeeb38b..c66581b14f74 100644 --- a/arrow-schema/src/extension/canonical/uuid.rs +++ b/arrow-schema/src/extension/canonical/uuid.rs @@ -77,42 +77,52 @@ impl ExtensionType for Uuid { #[cfg(test)] mod tests { + #[cfg(feature = "canonical-extension-types")] + use crate::extension::CanonicalExtensionType; use crate::{ - extension::{CanonicalExtensionType, EXTENSION_TYPE_METADATA_KEY}, + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, Field, }; use super::*; #[test] - fn uuid() -> Result<(), ArrowError> { + fn valid() -> Result<(), ArrowError> { let mut field = Field::new("", DataType::FixedSizeBinary(16), false); field.try_with_extension_type(Uuid)?; - assert!(field.try_extension_type::().is_ok()); + field.try_extension_type::()?; + #[cfg(feature = "canonical-extension-types")] assert_eq!( - field.try_canonical_extension_type().unwrap(), + field.try_canonical_extension_type()?, CanonicalExtensionType::Uuid(Uuid) ); Ok(()) } + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new("", DataType::FixedSizeBinary(16), false); + field.extension_type::(); + } + #[test] #[should_panic(expected = "expected FixedSizeBinary(16), found FixedSizeBinary(8)")] - fn uuid_bad_type() { + fn invalid_type() { Field::new("", DataType::FixedSizeBinary(8), false).with_extension_type(Uuid); } #[test] - fn uuid_with_metadata() { - // Add metadata that's not expected for uuid. 
- let field = Field::new("", DataType::FixedSizeBinary(16), false) - .with_extension_type(Uuid) - .with_metadata( - [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "".to_owned())] - .into_iter() - .collect(), - ); - // This returns an error now because `Uuid` expects no metadata. - assert!(field.try_extension_type::().is_err()); + #[should_panic(expected = "Uuid extension type expects no metadata")] + fn with_metadata() { + let field = Field::new("", DataType::FixedSizeBinary(16), false).with_metadata( + [ + (EXTENSION_TYPE_NAME_KEY.to_owned(), Uuid::NAME.to_owned()), + (EXTENSION_TYPE_METADATA_KEY.to_owned(), "".to_owned()), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); } } diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs index 9473893797a2..d71a568bea6f 100644 --- a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -19,15 +19,53 @@ //! //! -use std::sync::Arc; - use serde::{Deserialize, Serialize}; use crate::{extension::ExtensionType, ArrowError, DataType, Field}; /// The extension type for `VariableShapeTensor`. /// +/// Extension name: `arrow.variable_shape_tensor`. +/// +/// The storage type of the extension is: StructArray where struct is composed +/// of data and shape fields describing a single tensor per row: +/// - `data` is a List holding tensor elements (each list element is a single +/// tensor). The List’s value type is the value type of the tensor, such as +/// an integer or floating-point type. +/// - `shape` is a FixedSizeList[ndim] of the tensor shape where th +/// size of the list ndim is equal to the number of dimensions of the tensor. /// +/// Extension type parameters: +/// `value_type`: the Arrow data type of individual tensor elements. 
+/// +/// Optional parameters describing the logical layout: +/// - `dim_names`: explicit names to tensor dimensions as an array. The length +/// of it should be equal to the shape length and equal to the number of +/// dimensions. +/// `dim_names` can be used if the dimensions have well-known names and they +/// map to the physical layout (row-major). +/// - `permutation`: indices of the desired ordering of the original +/// dimensions, defined as an array. +/// The indices contain a permutation of the values `[0, 1, .., N-1]` where +/// `N` is the number of dimensions. The permutation indicates which +/// dimension of the logical layout corresponds to which dimension of the +/// physical tensor (the i-th dimension of the logical view corresponds to +/// the dimension with number `permutations[i]` of the physical tensor). +/// Permutation can be useful in case the logical order of the tensor is a +/// permutation of the physical order (row-major). +/// When logical and physical layout are equal, the permutation will always +/// be (`[0, 1, .., N-1]`) and can therefore be left out. +/// - `uniform_shape`: sizes of individual tensor’s dimensions which are +/// guaranteed to stay constant in uniform dimensions and can vary in non- +/// uniform dimensions. This holds over all tensors in the array. Sizes in +/// uniform dimensions are represented with int32 values, while sizes of the +/// non-uniform dimensions are not known in advance and are represented with +/// null. If `uniform_shape` is not provided it is assumed that all +/// dimensions are non-uniform. An array containing a tensor with shape (2, +/// 3, 4) and whose first and last dimensions are uniform would have +/// `uniform_shape` (2, null, 4). This allows for interpreting the tensor +/// correctly without accounting for uniform dimensions while still +/// permitting optional optimizations that take advantage of the uniformity. 
/// /// #[derive(Debug, Clone, PartialEq)] @@ -47,15 +85,27 @@ impl VariableShapeTensor { /// /// # Error /// - /// Return an error if the provided dimension names or permutations are - /// invalid. + /// Return an error if the provided dimension names, permutations or + /// uniform shapes are invalid. pub fn try_new( - _value_type: DataType, - _dimensions: usize, - _dimension_names: Option>>, - _permutations: Option>>, + value_type: DataType, + dimensions: usize, + dimension_names: Option>, + permutations: Option>, + uniform_shapes: Option>>, ) -> Result { - todo!() + // TODO: are all data types are suitable as value type? + VariableShapeTensorMetadata::try_new( + dimensions, + dimension_names, + permutations, + uniform_shapes, + ) + .map(|metadata| Self { + value_type, + dimensions, + metadata, + }) } /// Returns the value type of the individual tensor elements. @@ -79,6 +129,13 @@ impl VariableShapeTensor { pub fn permutations(&self) -> Option<&[usize]> { self.metadata.permutations() } + + /// Returns sizes of individual tensor’s dimensions which are guaranteed + /// to stay constant in uniform dimensions and can vary in non-uniform + /// dimensions. + pub fn uniform_shapes(&self) -> Option<&[Option]> { + self.metadata.uniform_shapes() + } } /// Extension type metadata for [`VariableShapeTensor`]. @@ -89,20 +146,94 @@ pub struct VariableShapeTensorMetadata { /// Indices of the desired ordering of the original dimensions. permutations: Option>, + + /// Sizes of individual tensor’s dimensions which are guaranteed to stay + /// constant in uniform dimensions and can vary in non-uniform dimensions. + uniform_shape: Option>>, } impl VariableShapeTensorMetadata { + /// Returns metadata for a variable shape tensor extension type. + /// + /// # Error + /// + /// Return an error if the provided dimension names, permutations or + /// uniform shapes are invalid. 
+ pub fn try_new( + dimensions: usize, + dimension_names: Option>, + permutations: Option>, + uniform_shapes: Option>>, + ) -> Result { + let dim_names = dimension_names.map(|dimension_names| { + if dimension_names.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len() + ))) + } else { + Ok(dimension_names) + } + }).transpose()?; + + let permutations = permutations + .map(|permutations| { + if permutations.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor permutations size mismatch, expected {dimensions}, found {}", + permutations.len() + ))) + } else { + let mut sorted_permutations = permutations.clone(); + sorted_permutations.sort_unstable(); + if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) { + Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}" + ))) + } else { + Ok(permutations) + } + } + }) + .transpose()?; + + let uniform_shape = uniform_shapes + .map(|uniform_shapes| { + if uniform_shapes.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor uniform shapes size mismatch, expected {dimensions}, found {}", + uniform_shapes.len() + ))) + } else { + Ok(uniform_shapes) + } + }) + .transpose()?; + + Ok(Self { + dim_names, + permutations, + uniform_shape, + }) + } + /// Returns the names of the dimensions in this variable shape tensor, if /// set. pub fn dimension_names(&self) -> Option<&[String]> { self.dim_names.as_ref().map(AsRef::as_ref) } - /// Returns the indices of the desired ordering of the original - /// dimensions, if set. + /// Returns the indices of the desired ordering of the original dimensions, + /// if set. 
pub fn permutations(&self) -> Option<&[usize]> { self.permutations.as_ref().map(AsRef::as_ref) } + + /// Returns sizes of individual tensor’s dimensions which are guaranteed + /// to stay constant in uniform dimensions and can vary in non-uniform + /// dimensions. + pub fn uniform_shapes(&self) -> Option<&[Option]> { + self.uniform_shape.as_ref().map(AsRef::as_ref) + } } impl ExtensionType for VariableShapeTensor { @@ -154,7 +285,6 @@ impl ExtensionType for VariableShapeTensor { ), ] .into_iter() - .map(Arc::new) .collect(), ); data_type @@ -167,16 +297,37 @@ impl ExtensionType for VariableShapeTensor { }) } - fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result { + fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result { match data_type { DataType::Struct(fields) if fields.len() == 2 && matches!(fields.find("data"), Some((0, _))) && matches!(fields.find("shape"), Some((1, _))) => { - let _data_field = &fields[0]; - let _shape_field = &fields[1]; - todo!() + let shape_field = &fields[1]; + match shape_field.data_type() { + DataType::FixedSizeList(_, list_size) => { + let dimensions = usize::try_from(*list_size).expect("conversion failed"); + // Make sure the metadata is valid. 
+ let metadata = VariableShapeTensorMetadata::try_new(dimensions, metadata.dim_names, metadata.permutations, metadata.uniform_shape)?; + let data_field = &fields[0]; + match data_field.data_type() { + DataType::List(field) => { + Ok(Self { + value_type: field.data_type().clone(), + dimensions, + metadata + }) + } + data_type => Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor data type mismatch, expected List for data field, found {data_type}" + ))), + } + } + data_type => Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor data type mismatch, expected FixedSizeList for shape field, found {data_type}" + ))), + } } data_type => Err(ArrowError::InvalidArgumentError(format!( "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found {data_type}" @@ -184,3 +335,216 @@ impl ExtensionType for VariableShapeTensor { } } } + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical-extension-types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let variable_shape_tensor = VariableShapeTensor::try_new( + DataType::Float32, + 3, + Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]), + Some(vec![2, 0, 1]), + Some(vec![Some(400), None, Some(3)]), + )?; + let mut field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, + ), + ], + false, + ); + field.try_with_extension_type(variable_shape_tensor.clone())?; + assert_eq!( + field.try_extension_type::()?, + variable_shape_tensor + ); + #[cfg(feature = "canonical-extension-types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::VariableShapeTensor(variable_shape_tensor) + ); + 
Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, + ), + ], + false, + ) + .with_metadata( + [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "VariableShapeTensor data type mismatch, expected Struct")] + fn invalid_type() { + let variable_shape_tensor = + VariableShapeTensor::try_new(DataType::Int32, 3, None, None, None).unwrap(); + let field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, + ), + ], + false, + ); + field.with_extension_type(variable_shape_tensor); + } + + #[test] + #[should_panic(expected = "VariableShapeTensor extension types requires metadata")] + fn missing_metadata() { + let field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, + ), + ], + false, + ) + .with_metadata( + [( + EXTENSION_TYPE_NAME_KEY.to_owned(), + VariableShapeTensor::NAME.to_owned(), + )] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "VariableShapeTensor metadata deserialization failed: invalid type:")] + fn invalid_metadata() { + let field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, 
+ ), + ], + false, + ) + .with_metadata( + [ + ( + EXTENSION_TYPE_NAME_KEY.to_owned(), + VariableShapeTensor::NAME.to_owned(), + ), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "dim_names": [1, null, 3, 4] }"#.to_owned(), + ), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic( + expected = "VariableShapeTensor dimension names size mismatch, expected 3, found 2" + )] + fn invalid_metadata_dimension_names() { + VariableShapeTensor::try_new( + DataType::Float32, + 3, + Some(vec!["a".to_owned(), "b".to_owned()]), + None, + None, + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "VariableShapeTensor permutations size mismatch, expected 3, found 2" + )] + fn invalid_metadata_permutations_len() { + VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(vec![1, 0]), None).unwrap(); + } + + #[test] + #[should_panic( + expected = "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3" + )] + fn invalid_metadata_permutations_values() { + VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(vec![4, 3, 2]), None) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "VariableShapeTensor uniform shapes size mismatch, expected 3, found 2" + )] + fn invalid_metadata_uniform_shapes() { + VariableShapeTensor::try_new(DataType::Float32, 3, None, None, Some(vec![None, Some(1)])) + .unwrap(); + } +} diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index 76bdff33545a..8f22be4460f4 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -199,6 +199,8 @@ pub trait ExtensionType: Sized { /// extension types, they should not be used for third-party extension /// types. /// + /// Extension names are case-sensitive. 
+ /// /// [`Field`]: crate::Field /// [`Field::metadata`]: crate::Field::metadata const NAME: &'static str; diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index 85c75f93087f..888106f25777 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -513,24 +513,7 @@ impl Field { /// - the construction of the extension type fails #[cfg(feature = "canonical-extension-types")] pub fn try_canonical_extension_type(&self) -> Result { - use crate::extension::{FixedShapeTensor, Json, Uuid}; - - // Canonical extension type names start with `arrow.` - match self.extension_type_name() { - // An extension type name with an `arrow.` prefix - Some(name) if name.starts_with("arrow.") => match name { - FixedShapeTensor::NAME => self.try_extension_type::().map(Into::into), - Json::NAME => self.try_extension_type::().map(Into::into), - Uuid::NAME => self.try_extension_type::().map(Into::into), - _ => Err(ArrowError::InvalidArgumentError(format!("Unsupported canonical extension type: {name}"))), - }, - // Name missing the expected prefix - Some(name) => Err(ArrowError::InvalidArgumentError(format!( - "Field extension type name mismatch, expected a name with an `arrow.` prefix, found {name}" - ))), - // Name missing - None => Err(ArrowError::InvalidArgumentError("Field extension type name missing".to_owned())), - } + CanonicalExtensionType::try_from(self) } /// Indicates whether this [`Field`] supports null values. 
From 38c72557319f0b179e44482365436385d64d5f1e Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Mon, 20 Jan 2025 17:49:29 +0100 Subject: [PATCH 12/21] Add missing type annotations --- arrow-array/src/array/list_view_array.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/arrow-array/src/array/list_view_array.rs b/arrow-array/src/array/list_view_array.rs index 7e52a6f3e457..d93ed6570be8 100644 --- a/arrow-array/src/array/list_view_array.rs +++ b/arrow-array/src/array/list_view_array.rs @@ -830,8 +830,8 @@ mod tests { .build() .unwrap(), ); - assert_eq!(string.value_offsets(), &[]); - assert_eq!(string.value_sizes(), &[]); + assert_eq!(string.value_offsets(), &[] as &[i32; 0]); + assert_eq!(string.value_sizes(), &[] as &[i32; 0]); let string = LargeListViewArray::from( ArrayData::builder(DataType::LargeListView(f)) @@ -841,8 +841,8 @@ mod tests { .unwrap(), ); assert_eq!(string.len(), 0); - assert_eq!(string.value_offsets(), &[]); - assert_eq!(string.value_sizes(), &[]); + assert_eq!(string.value_offsets(), &[] as &[i64; 0]); + assert_eq!(string.value_sizes(), &[] as &[i64; 0]); } #[test] From 1a21e962ec28a6e447b9d94ba97de2c74ecb7760 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Mon, 20 Jan 2025 19:40:19 +0100 Subject: [PATCH 13/21] Fix doc warning --- .../src/extension/canonical/variable_shape_tensor.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs index d71a568bea6f..ee4867a1485b 100644 --- a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -32,8 +32,9 @@ use crate::{extension::ExtensionType, ArrowError, DataType, Field}; /// - `data` is a List holding tensor elements (each list element is a single /// tensor). 
The List’s value type is the value type of the tensor, such as /// an integer or floating-point type. -/// - `shape` is a FixedSizeList[ndim] of the tensor shape where th -/// size of the list ndim is equal to the number of dimensions of the tensor. +/// - `shape` is a `FixedSizeList[ndim]` of the tensor shape where the +/// size of the list `ndim` is equal to the number of dimensions of the +/// tensor. /// /// Extension type parameters: /// `value_type`: the Arrow data type of individual tensor elements. From 069642fde6af993dfb94d6082bddaf0933740f8b Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Wed, 22 Jan 2025 16:02:32 +0100 Subject: [PATCH 14/21] Add the feature to the `arrow` crate and use underscores --- arrow-schema/Cargo.toml | 4 ++-- arrow-schema/src/extension/canonical/bool8.rs | 4 ++-- .../extension/canonical/fixed_shape_tensor.rs | 4 ++-- arrow-schema/src/extension/canonical/json.rs | 4 ++-- .../src/extension/canonical/opaque.rs | 4 ++-- arrow-schema/src/extension/canonical/uuid.rs | 4 ++-- .../canonical/variable_shape_tensor.rs | 4 ++-- arrow-schema/src/extension/mod.rs | 6 ++--- arrow-schema/src/field.rs | 4 ++-- arrow/Cargo.toml | 1 + arrow/README.md | 1 + parquet/Cargo.toml | 2 +- parquet/src/arrow/schema/mod.rs | 22 +++++++++---------- 13 files changed, 33 insertions(+), 31 deletions(-) diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml index a0481544a2d0..ffea42db6653 100644 --- a/arrow-schema/Cargo.toml +++ b/arrow-schema/Cargo.toml @@ -43,7 +43,7 @@ bitflags = { version = "2.0.0", default-features = false, optional = true } serde_json = { version = "1.0", optional = true } [features] -canonical-extension-types = ["dep:serde", "dep:serde_json"] +canonical_extension_types = ["dep:serde", "dep:serde_json"] # Enable ffi support ffi = ["bitflags"] serde = ["dep:serde"] @@ -57,4 +57,4 @@ criterion = { version = "0.5", default-features = false } [[bench]] name = "ffi" -harness = false \ No newline at end of file +harness = false 
diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs index acad27d35030..211950bffb48 100644 --- a/arrow-schema/src/extension/canonical/bool8.rs +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -73,7 +73,7 @@ impl ExtensionType for Bool8 { #[cfg(test)] mod tests { - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] use crate::extension::CanonicalExtensionType; use crate::{ extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, @@ -87,7 +87,7 @@ mod tests { let mut field = Field::new("", DataType::Int8, false); field.try_with_extension_type(Bool8)?; field.try_extension_type::()?; - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] assert_eq!( field.try_canonical_extension_type()?, CanonicalExtensionType::Bool8(Bool8) diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs index f1ce85b54a57..6fe94fba78aa 100644 --- a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs @@ -294,7 +294,7 @@ impl ExtensionType for FixedShapeTensor { #[cfg(test)] mod tests { - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] use crate::extension::CanonicalExtensionType; use crate::{ extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, @@ -322,7 +322,7 @@ mod tests { field.try_extension_type::()?, fixed_shape_tensor ); - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] assert_eq!( field.try_canonical_extension_type()?, CanonicalExtensionType::FixedShapeTensor(fixed_shape_tensor) diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index 693afc2b48ef..baebe49e59c9 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ 
b/arrow-schema/src/extension/canonical/json.rs @@ -99,7 +99,7 @@ impl ExtensionType for Json { #[cfg(test)] mod tests { - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] use crate::extension::CanonicalExtensionType; use crate::{ extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, @@ -133,7 +133,7 @@ mod tests { let mut field = Field::new("", DataType::Utf8View, false); field.try_with_extension_type(Json::default())?; field.try_extension_type::()?; - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] assert_eq!( field.try_canonical_extension_type()?, CanonicalExtensionType::Json(Json::default()) diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index 0d1baeb395a0..1db7265cfde7 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -132,7 +132,7 @@ impl ExtensionType for Opaque { #[cfg(test)] mod tests { - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] use crate::extension::CanonicalExtensionType; use crate::{ extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, @@ -147,7 +147,7 @@ mod tests { let mut field = Field::new("", DataType::Null, false); field.try_with_extension_type(opaque.clone())?; assert_eq!(field.try_extension_type::()?, opaque); - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] assert_eq!( field.try_canonical_extension_type()?, CanonicalExtensionType::Opaque(opaque) diff --git a/arrow-schema/src/extension/canonical/uuid.rs b/arrow-schema/src/extension/canonical/uuid.rs index c66581b14f74..8b2e71b7b5aa 100644 --- a/arrow-schema/src/extension/canonical/uuid.rs +++ b/arrow-schema/src/extension/canonical/uuid.rs @@ -77,7 +77,7 @@ impl ExtensionType for Uuid { #[cfg(test)] mod tests { - #[cfg(feature = "canonical-extension-types")] 
+ #[cfg(feature = "canonical_extension_types")] use crate::extension::CanonicalExtensionType; use crate::{ extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, @@ -91,7 +91,7 @@ mod tests { let mut field = Field::new("", DataType::FixedSizeBinary(16), false); field.try_with_extension_type(Uuid)?; field.try_extension_type::()?; - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] assert_eq!( field.try_canonical_extension_type()?, CanonicalExtensionType::Uuid(Uuid) diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs index ee4867a1485b..804591776b2f 100644 --- a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -339,7 +339,7 @@ impl ExtensionType for VariableShapeTensor { #[cfg(test)] mod tests { - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] use crate::extension::CanonicalExtensionType; use crate::{ extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, @@ -379,7 +379,7 @@ mod tests { field.try_extension_type::()?, variable_shape_tensor ); - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] assert_eq!( field.try_canonical_extension_type()?, CanonicalExtensionType::VariableShapeTensor(variable_shape_tensor) diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index 8f22be4460f4..c11b8cd1880a 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -17,9 +17,9 @@ //! Extension types. 
-#[cfg(feature = "canonical-extension-types")] +#[cfg(feature = "canonical_extension_types")] mod canonical; -#[cfg(feature = "canonical-extension-types")] +#[cfg(feature = "canonical_extension_types")] pub use canonical::*; use crate::{ArrowError, DataType}; @@ -39,7 +39,7 @@ pub const EXTENSION_TYPE_METADATA_KEY: &str = "ARROW:extension:metadata"; /// - [`EXTENSION_TYPE_METADATA_KEY`] /// /// Canonical extension types support in this crate requires the -/// `canonical-extension-types` feature. +/// `canonical_extension_types` feature. /// /// Extension types may or may not use the [`EXTENSION_TYPE_METADATA_KEY`] /// field. diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index 888106f25777..dbd671a62a3a 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -22,7 +22,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::datatype::DataType; -#[cfg(feature = "canonical-extension-types")] +#[cfg(feature = "canonical_extension_types")] use crate::extension::CanonicalExtensionType; use crate::schema::SchemaBuilder; use crate::{ @@ -511,7 +511,7 @@ impl Field { /// - this field does have a canonical extension type (mismatch or missing) /// - the canonical extension is not supported /// - the construction of the extension type fails - #[cfg(feature = "canonical-extension-types")] + #[cfg(feature = "canonical_extension_types")] pub fn try_canonical_extension_type(&self) -> Result { CanonicalExtensionType::try_from(self) } diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 76119ec4abb4..e901dbbb18fd 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -80,6 +80,7 @@ force_validate = ["arrow-array/force_validate", "arrow-data/force_validate"] # Enable ffi support ffi = ["arrow-schema/ffi", "arrow-data/ffi", "arrow-array/ffi"] chrono-tz = ["arrow-array/chrono-tz"] +canonical_extension_types = ["arrow-schema/canonical_extension_types"] [dev-dependencies] chrono = { workspace = true } diff --git a/arrow/README.md 
b/arrow/README.md index 79aefaae9053..64d9eb980e60 100644 --- a/arrow/README.md +++ b/arrow/README.md @@ -61,6 +61,7 @@ The `arrow` crate provides the following features which may be enabled in your ` - `chrono-tz` - support of parsing timezone using [chrono-tz](https://docs.rs/chrono-tz/0.6.0/chrono_tz/) - `ffi` - bindings for the Arrow C [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) - `pyarrow` - bindings for pyo3 to call arrow-rs from python +- `canonical_extension_types` - definitions for [canonical extension types](https://arrow.apache.org/docs/format/CanonicalExtensions.html#format-canonical-extensions) ## Arrow Feature Status diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index fd43f369acad..64a8a1f6ab6d 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -104,7 +104,7 @@ lz4 = ["lz4_flex"] # Enable arrow reader/writer APIs arrow = ["base64", "arrow-array", "arrow-buffer", "arrow-cast", "arrow-data", "arrow-schema", "arrow-select", "arrow-ipc"] # Enable support for arrow canonical extension types -arrow-canonical-extension-types = ["arrow-schema?/canonical-extension-types"] +arrow_canonical_extension_types = ["arrow-schema?/canonical-extension-types"] # Enable CLI tools cli = ["json", "base64", "clap", "arrow-csv", "serde"] # Enable JSON APIs diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index d7748537cb6f..8b3e92251bd1 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -23,7 +23,7 @@ use std::collections::HashMap; use std::sync::Arc; use arrow_ipc::writer; -#[cfg(feature = "arrow-canonical-extension-types")] +#[cfg(feature = "arrow_canonical_extension_types")] use arrow_schema::extension::{Json, Uuid}; use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; @@ -382,7 +382,7 @@ pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result Result ret.try_with_extension_type(Uuid)?, @@ -607,13 +607,13 @@ fn 
arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { .with_id(id) .with_length(*length) .with_logical_type( - #[cfg(feature = "arrow-canonical-extension-types")] + #[cfg(feature = "arrow_canonical_extension_types")] // If set, map arrow uuid extension type to parquet uuid logical type. field .try_extension_type::() .ok() .map(|_| LogicalType::Uuid), - #[cfg(not(feature = "arrow-canonical-extension-types"))] + #[cfg(not(feature = "arrow_canonical_extension_types"))] None, ) .build() @@ -650,7 +650,7 @@ fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { DataType::Utf8 | DataType::LargeUtf8 => { Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) .with_logical_type({ - #[cfg(feature = "arrow-canonical-extension-types")] + #[cfg(feature = "arrow_canonical_extension_types")] { // Use the Json logical type if the canonical Json // extension type is set on this field. @@ -658,7 +658,7 @@ fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { .try_extension_type::() .map_or(Some(LogicalType::String), |_| Some(LogicalType::Json)) } - #[cfg(not(feature = "arrow-canonical-extension-types"))] + #[cfg(not(feature = "arrow_canonical_extension_types"))] Some(LogicalType::String) }) .with_repetition(repetition) @@ -667,7 +667,7 @@ fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { } DataType::Utf8View => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) .with_logical_type({ - #[cfg(feature = "arrow-canonical-extension-types")] + #[cfg(feature = "arrow_canonical_extension_types")] { // Use the Json logical type if the canonical Json // extension type is set on this field. 
@@ -675,7 +675,7 @@ fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { .try_extension_type::() .map_or(Some(LogicalType::String), |_| Some(LogicalType::Json)) } - #[cfg(not(feature = "arrow-canonical-extension-types"))] + #[cfg(not(feature = "arrow_canonical_extension_types"))] Some(LogicalType::String) }) .with_repetition(repetition) @@ -2213,7 +2213,7 @@ mod tests { } #[test] - #[cfg(feature = "arrow-canonical-extension-types")] + #[cfg(feature = "arrow_canonical_extension_types")] fn arrow_uuid_to_parquet_uuid() -> Result<()> { let arrow_schema = Schema::new(vec![Field::new( "uuid", @@ -2237,7 +2237,7 @@ mod tests { } #[test] - #[cfg(feature = "arrow-canonical-extension-types")] + #[cfg(feature = "arrow_canonical_extension_types")] fn arrow_json_to_parquet_json() -> Result<()> { let arrow_schema = Schema::new(vec![ Field::new("json", DataType::Utf8, false).with_extension_type(Json::default()) From 85193447e61316a893b24e4adfe4b640d73abb72 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Wed, 22 Jan 2025 16:05:18 +0100 Subject: [PATCH 15/21] Update feature name in `parquet` crate --- parquet/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 64a8a1f6ab6d..00d4c5b750f8 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -104,7 +104,7 @@ lz4 = ["lz4_flex"] # Enable arrow reader/writer APIs arrow = ["base64", "arrow-array", "arrow-buffer", "arrow-cast", "arrow-data", "arrow-schema", "arrow-select", "arrow-ipc"] # Enable support for arrow canonical extension types -arrow_canonical_extension_types = ["arrow-schema?/canonical-extension-types"] +arrow_canonical_extension_types = ["arrow-schema?/canonical_extension_types"] # Enable CLI tools cli = ["json", "base64", "clap", "arrow-csv", "serde"] # Enable JSON APIs From 5fec56d8ff81d14902274bd45e9b41e634a9379e Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Wed, 22 Jan 2025 21:52:45 +0100 Subject: [PATCH 16/21] 
Add experimental warning to `extensions` module docs --- arrow-schema/src/extension/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index c11b8cd1880a..f63986f7326a 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -16,6 +16,8 @@ // under the License. //! Extension types. +//! +//!
This module is experimental. There might be breaking changes between minor releases.
#[cfg(feature = "canonical_extension_types")] mod canonical; From b78e6925d76b35a739341f9a86cde91e58074ecc Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Wed, 22 Jan 2025 22:35:37 +0100 Subject: [PATCH 17/21] Add a note about the associated metadata type --- arrow-schema/src/extension/mod.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index f63986f7326a..c5119873af0c 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -209,6 +209,12 @@ pub trait ExtensionType: Sized { /// The metadata type of this extension type. /// + /// Implementations can use strongly or loosely typed data structures here + /// depending on the complexity of the metadata. + /// + /// Implementations can also use `Self` here if the extension type can be + /// constructed directly from its metadata. + /// /// If an extension type defines no metadata it should use `()` to indicate /// this. type Metadata; From 29a94cbed185515f4eb8cad362f1d2c242fd8eac Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Fri, 24 Jan 2025 12:07:48 +0100 Subject: [PATCH 18/21] Fix `Json` canonical extension type empty string metadata --- arrow-schema/src/extension/canonical/json.rs | 44 ++++++++++---------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index baebe49e59c9..8fdf34fae90e 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ b/arrow-schema/src/extension/canonical/json.rs @@ -19,7 +19,7 @@ //! //! -use serde_json::Value; +use serde_json::{Map, Value}; use crate::{extension::ExtensionType, ArrowError, DataType}; @@ -42,14 +42,8 @@ use crate::{extension::ExtensionType, ArrowError, DataType}; pub struct Json(JsonMetadata); /// Extension type metadata for [`Json`].
-#[derive(Debug, Clone, PartialEq)] -pub struct JsonMetadata(Value); - -impl Default for JsonMetadata { - fn default() -> Self { - Self(Value::String(Default::default())) - } -} +#[derive(Debug, Default, Clone, PartialEq)] +pub struct JsonMetadata(Option>); impl ExtensionType for Json { const NAME: &'static str = "arrow.json"; @@ -61,7 +55,14 @@ impl ExtensionType for Json { } fn serialize_metadata(&self) -> Option { - Some(self.metadata().0.to_string()) + Some( + self.metadata() + .0 + .as_ref() + .map(serde_json::to_string) + .map(Result::unwrap) + .unwrap_or_else(|| "".to_owned()), + ) } fn deserialize_metadata(metadata: Option<&str>) -> Result { @@ -69,13 +70,16 @@ impl ExtensionType for Json { metadata .map_or_else( || Err(ArrowError::InvalidArgumentError(ERR.to_owned())), - |metadata| match metadata { - r#""""# => Ok(Value::String(Default::default())), - value => value - .parse::() - .ok() - .filter(|value| matches!(value.as_object(), Some(map) if map.is_empty())) - .ok_or_else(|| ArrowError::InvalidArgumentError(ERR.to_owned())), + |metadata| { + match metadata { + // Empty string + "" => Ok(None), + value => match serde_json::from_str::>(value) { + // JSON string with an empty object + Ok(map) if map.is_empty() => Ok(Some(map)), + _ => Err(ArrowError::InvalidArgumentError(ERR.to_owned())), + }, + } }, ) .map(JsonMetadata) @@ -116,14 +120,12 @@ mod tests { field.try_with_extension_type(Json::default())?; assert_eq!( field.metadata().get(EXTENSION_TYPE_METADATA_KEY), - Some(&r#""""#.to_owned()) + Some(&"".to_owned()) ); field.try_extension_type::()?; let mut field = Field::new("", DataType::LargeUtf8, false); - field.try_with_extension_type(Json(JsonMetadata(serde_json::Value::Object( - Map::default(), - ))))?; + field.try_with_extension_type(Json(JsonMetadata(Some(Map::default()))))?; assert_eq!( field.metadata().get(EXTENSION_TYPE_METADATA_KEY), Some(&"{}".to_owned()) From 757b041ee1a80735d1314dc905e299b6034535da Mon Sep 17 00:00:00 2001 From: Matthijs 
Brobbel Date: Fri, 24 Jan 2025 12:15:16 +0100 Subject: [PATCH 19/21] Simplify `Bool8::deserialize_metadata` --- arrow-schema/src/extension/canonical/bool8.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs index 211950bffb48..fdd25677ed0e 100644 --- a/arrow-schema/src/extension/canonical/bool8.rs +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -47,14 +47,13 @@ impl ExtensionType for Bool8 { } fn deserialize_metadata(metadata: Option<&str>) -> Result { - const ERR: &str = "Bool8 extension type expects an empty string as metadata"; - metadata.map_or_else( - || Err(ArrowError::InvalidArgumentError(ERR.to_owned())), - |value| match value { - "" => Ok(""), - _ => Err(ArrowError::InvalidArgumentError(ERR.to_owned())), - }, - ) + if metadata.is_some_and(str::is_empty) { + Ok("") + } else { + Err(ArrowError::InvalidArgumentError( + "Bool8 extension type expects an empty string as metadata".to_owned(), + )) + } } fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { From c6f0443c7cc6eff4d0765366166545d34ef91d69 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Fri, 24 Jan 2025 12:44:50 +0100 Subject: [PATCH 20/21] Use `Empty` instead of `serde_json::Map` in `JsonMetadata` --- arrow-schema/src/extension/canonical/json.rs | 31 ++++++++++++-------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index 8fdf34fae90e..0a8a1ae7e020 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ b/arrow-schema/src/extension/canonical/json.rs @@ -19,7 +19,7 @@ //! //! 
-use serde_json::{Map, Value}; +use serde::{Deserialize, Serialize}; use crate::{extension::ExtensionType, ArrowError, DataType}; @@ -41,9 +41,14 @@ use crate::{extension::ExtensionType, ArrowError, DataType}; #[derive(Debug, Clone, Default, PartialEq)] pub struct Json(JsonMetadata); +/// Empty object +#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +struct Empty {} + /// Extension type metadata for [`Json`]. #[derive(Debug, Default, Clone, PartialEq)] -pub struct JsonMetadata(Option>); +pub struct JsonMetadata(Option); impl ExtensionType for Json { const NAME: &'static str = "arrow.json"; @@ -74,11 +79,9 @@ impl ExtensionType for Json { match metadata { // Empty string "" => Ok(None), - value => match serde_json::from_str::>(value) { - // JSON string with an empty object - Ok(map) if map.is_empty() => Ok(Some(map)), - _ => Err(ArrowError::InvalidArgumentError(ERR.to_owned())), - }, + value => serde_json::from_str::(value) + .map(Option::Some) + .map_err(|_| ArrowError::InvalidArgumentError(ERR.to_owned())), } }, ) @@ -110,8 +113,6 @@ mod tests { Field, }; - use serde_json::Map; - use super::*; #[test] @@ -122,15 +123,21 @@ mod tests { field.metadata().get(EXTENSION_TYPE_METADATA_KEY), Some(&"".to_owned()) ); - field.try_extension_type::()?; + assert_eq!( + field.try_extension_type::()?, + Json(JsonMetadata(None)) + ); let mut field = Field::new("", DataType::LargeUtf8, false); - field.try_with_extension_type(Json(JsonMetadata(Some(Map::default()))))?; + field.try_with_extension_type(Json(JsonMetadata(Some(Empty {}))))?; assert_eq!( field.metadata().get(EXTENSION_TYPE_METADATA_KEY), Some(&"{}".to_owned()) ); - field.try_extension_type::()?; + assert_eq!( + field.try_extension_type::()?, + Json(JsonMetadata(Some(Empty {}))) + ); let mut field = Field::new("", DataType::Utf8View, false); field.try_with_extension_type(Json::default())?; From 75f56a497df774c6c2f3bc2c1315d847c452e3dd Mon Sep 17 00:00:00 2001 From: Matthijs 
Brobbel Date: Fri, 24 Jan 2025 12:47:27 +0100 Subject: [PATCH 21/21] Use `map_or` instead of `is_some_and` (msrv) --- arrow-schema/src/extension/canonical/bool8.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs index fdd25677ed0e..3f6c50cb3e5e 100644 --- a/arrow-schema/src/extension/canonical/bool8.rs +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -47,7 +47,7 @@ impl ExtensionType for Bool8 { } fn deserialize_metadata(metadata: Option<&str>) -> Result { - if metadata.is_some_and(str::is_empty) { + if metadata.map_or(false, str::is_empty) { Ok("") } else { Err(ArrowError::InvalidArgumentError(