Skip to content

Commit

Permalink
feat: add support for fixed list wildcard in type signature
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 committed Feb 21, 2024
1 parent cf11a70 commit 333ba83
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 6 deletions.
6 changes: 3 additions & 3 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,9 +518,9 @@ pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType {
/// Compute the number of dimensions in a list data type.
///
/// Every `List`, `LargeList` or `FixedSizeList` wrapper contributes one
/// dimension. Counting stops at the first element type that is not one of
/// these list variants, so a non-list type has zero dimensions.
pub fn list_ndims(data_type: &DataType) -> u64 {
    match data_type {
        // One arm handles all three list variants. A separate leading
        // `List | LargeList` arm would shadow the same alternatives here and
        // only duplicate the recursion.
        DataType::List(field)
        | DataType::LargeList(field)
        | DataType::FixedSizeList(field, _) => 1 + list_ndims(field.data_type()),
        _ => 0,
    }
}
Expand Down
5 changes: 5 additions & 0 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ use arrow::datatypes::DataType;
/// return results with this timezone.
pub const TIMEZONE_WILDCARD: &str = "+TZ";

/// Constant that is used as a placeholder for any valid fixed size list.
/// This is used where a function can accept a fixed size list type with any
/// valid length. It exists to avoid the need to enumerate all possible fixed size list lengths.
///
/// The placeholder value is `i32::MIN`. Type coercion compares the size of
/// the target `FixedSizeList` against this constant to decide whether the
/// wildcard rule applies.
///
/// NOTE(review): `i32::MIN` was presumably chosen because a negative length
/// can never describe a real list — confirm before relying on that.
pub const FIXED_SIZE_LIST_WILDCARD: i32 = i32::MIN;

/// A function's volatility, which defines the function's eligibility for certain optimizations
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum Volatility {
Expand Down
86 changes: 83 additions & 3 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use crate::signature::{ArrayFunctionSignature, TIMEZONE_WILDCARD};
use crate::signature::{
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
};
use crate::{Signature, TypeSignature};
use arrow::{
compute::can_cast_types,
Expand Down Expand Up @@ -372,13 +374,19 @@ fn coerced_from<'a>(
List(_) if matches!(type_from, FixedSizeList(_, _)) => Some(type_into.clone()),

// Only accept list and largelist with the same number of dimensions unless the type is Null.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this
List(_) | LargeList(_)
if datafusion_common::utils::base_type(type_from).eq(&Null)
|| list_ndims(type_from) == list_ndims(type_into) =>
{
Some(type_into.clone())
}
FixedSizeList(_, size)
if *size == FIXED_SIZE_LIST_WILDCARD
&& list_ndims(type_from) == list_ndims(type_into) =>
{
Some(type_into.clone())
}

Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => {
match type_from {
Expand Down Expand Up @@ -408,8 +416,10 @@ fn coerced_from<'a>(

#[cfg(test)]
mod tests {
use std::sync::Arc;

use super::*;
use arrow::datatypes::{DataType, TimeUnit};
use arrow::datatypes::{DataType, Field, TimeUnit};

#[test]
fn test_maybe_data_types() {
Expand Down Expand Up @@ -485,4 +495,74 @@ mod tests {

Ok(())
}

/// A wildcard-sized `FixedSizeList` target must accept any one-dimensional
/// list source and reject scalars and sources with a different nesting depth.
#[test]
fn test_fixed_list_wildcard_coerce() -> Result<()> {
    let item = Arc::new(Field::new("item", DataType::Int32, false));
    let wildcard = DataType::FixedSizeList(Arc::clone(&item), FIXED_SIZE_LIST_WILDCARD);

    // One-dimensional sources: every one of them coerces to the wildcard target.
    let accepted = [
        DataType::FixedSizeList(Arc::clone(&item), 2),
        DataType::FixedSizeList(Arc::clone(&item), 3),
        DataType::FixedSizeList(Arc::clone(&item), FIXED_SIZE_LIST_WILDCARD),
        DataType::List(Arc::clone(&item)),
    ];
    for source in &accepted {
        assert_eq!(coerced_from(&wildcard, source), Some(wildcard.clone()));
    }

    // Element field whose type is itself a wildcard fixed size list; wrapping
    // it in another list yields a two-dimensional type.
    let nested_item = Arc::new(Field::new(
        "item",
        DataType::FixedSizeList(Arc::clone(&item), FIXED_SIZE_LIST_WILDCARD),
        false,
    ));

    // Scalars and two-dimensional lists must not coerce to the
    // one-dimensional wildcard target.
    let rejected = [
        DataType::Int32,
        DataType::Boolean,
        DataType::FixedSizeList(Arc::clone(&nested_item), 1),
        DataType::List(Arc::clone(&nested_item)),
    ];
    for source in &rejected {
        assert_eq!(coerced_from(&wildcard, source), None);
    }

    // With matching depth on both sides the nested wildcard target is accepted.
    let nested_wildcard =
        DataType::FixedSizeList(Arc::clone(&nested_item), FIXED_SIZE_LIST_WILDCARD);
    let nested_list = DataType::List(Arc::clone(&nested_item));
    assert_eq!(
        coerced_from(&nested_wildcard, &nested_list),
        Some(nested_wildcard)
    );

    Ok(())
}
/// A `FixedSizeList` target with a concrete length must reject fixed size
/// lists of a different length, while still accepting the same length, the
/// wildcard length, and a plain `List` source.
#[test]
fn test_fixed_list_no_wildcard_coerce() -> Result<()> {
    let inner = Arc::new(Field::new("item", DataType::Int32, false));
    let type_into = DataType::FixedSizeList(inner.clone(), 1);

    // Concrete lengths that differ from the target length are rejected.
    let invalid_cases = vec![
        DataType::FixedSizeList(inner.clone(), 2),
        DataType::FixedSizeList(inner.clone(), 3),
        DataType::FixedSizeList(inner.clone(), 4),
    ];
    for case in invalid_cases {
        let out = coerced_from(&type_into, &case);
        assert_eq!(out, None);
    }

    // Same length, wildcard length and variable-length list are all accepted.
    let cases = vec![
        DataType::FixedSizeList(inner.clone(), 1),
        DataType::FixedSizeList(inner.clone(), FIXED_SIZE_LIST_WILDCARD),
        DataType::List(inner.clone()),
    ];
    for case in cases {
        let out = coerced_from(&type_into, &case);
        assert_eq!(out, Some(type_into.clone()));
    }

    // NOTE(review): scalar sources (`Int32`, `Boolean`) were previously bound
    // here as a second `invalid_cases` list but never asserted on. The dead
    // binding is dropped; add a checked loop if rejecting scalars is intended.

    Ok(())
}
}

0 comments on commit 333ba83

Please sign in to comment.