Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ArrayAgg schema mismatch issue #8055

Merged
merged 10 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,92 @@ mod tests {

use super::*;

/// Executes `df` and asserts that the schema derived from its logical plan
/// equals the schema of every record batch produced by execution.
///
/// The calling test fails with an explicit message if execution yields no
/// record batches, because there is then no physical schema to compare with.
///
/// # Errors
///
/// Returns any error produced while collecting the `DataFrame`.
async fn assert_logical_expr_schema_eq_physical_expr_schema(
    df: DataFrame,
) -> Result<()> {
    // Capture the logical schema before `collect`, which takes the DataFrame
    // by value.
    let logical_expr_schema = SchemaRef::from(df.schema().to_owned());
    let batches = df.collect().await?;

    // Indexing `batches[0]` directly would panic with an opaque
    // "index out of bounds" message; fail with a descriptive one instead.
    assert!(
        !batches.is_empty(),
        "query produced no record batches; cannot compare logical and physical schemas"
    );

    // Every batch must agree with the logical schema, not only the first.
    for batch in &batches {
        assert_eq!(logical_expr_schema, batch.schema());
    }
    Ok(())
}

/// Checks that `array_agg(... ORDER BY ...)` reports the same schema in the
/// logical plan as in the record batches produced by execution.
#[tokio::test]
async fn test_array_agg_ord_schema() -> Result<()> {
    let session = SessionContext::new();

    // Populate a small two-column table to aggregate over.
    let setup_sql = r#"
CREATE TABLE test_table (
"double_field" DOUBLE,
"string_field" VARCHAR
) AS VALUES
(1.0, 'a'),
(2.0, 'b'),
(3.0, 'c')
"#;
    session.sql(setup_sql).await?;

    // Ordered aggregation over both columns.
    let aggregate_sql = r#"SELECT
array_agg("double_field" ORDER BY "string_field") as "double_field",
array_agg("string_field" ORDER BY "string_field") as "string_field"
FROM test_table"#;
    let df = session.sql(aggregate_sql).await?;

    assert_logical_expr_schema_eq_physical_expr_schema(df).await
}

/// Checks that plain `array_agg` reports the same schema in the logical plan
/// as in the record batches produced by execution.
#[tokio::test]
async fn test_array_agg_schema() -> Result<()> {
    let session = SessionContext::new();

    // Populate a small two-column table to aggregate over.
    let setup_sql = r#"
CREATE TABLE test_table (
"double_field" DOUBLE,
"string_field" VARCHAR
) AS VALUES
(1.0, 'a'),
(2.0, 'b'),
(3.0, 'c')
"#;
    session.sql(setup_sql).await?;

    // Unordered, non-distinct aggregation over both columns.
    let aggregate_sql = r#"SELECT
array_agg("double_field") as "double_field",
array_agg("string_field") as "string_field"
FROM test_table"#;
    let df = session.sql(aggregate_sql).await?;

    assert_logical_expr_schema_eq_physical_expr_schema(df).await
}

/// Checks that `array_agg(distinct ...)` reports the same schema in the
/// logical plan as in the record batches produced by execution.
#[tokio::test]
async fn test_array_agg_distinct_schema() -> Result<()> {
    let session = SessionContext::new();

    // Populate a table whose columns contain repeated values.
    let setup_sql = r#"
CREATE TABLE test_table (
"double_field" DOUBLE,
"string_field" VARCHAR
) AS VALUES
(1.0, 'a'),
(2.0, 'b'),
(2.0, 'a')
"#;
    session.sql(setup_sql).await?;

    // Distinct aggregation over both columns.
    let aggregate_sql = r#"SELECT
array_agg(distinct "double_field") as "double_field",
array_agg(distinct "string_field") as "string_field"
FROM test_table"#;
    let df = session.sql(aggregate_sql).await?;

    assert_logical_expr_schema_eq_physical_expr_schema(df).await
}

#[tokio::test]
async fn select_columns() -> Result<()> {
// build plan using Table API
Expand Down
43 changes: 37 additions & 6 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ use std::sync::Arc;
/// ARRAY_AGG aggregate expression
///
/// Its output field is a list whose items have the input expression's data
/// type.
#[derive(Debug)]
pub struct ArrayAgg {
/// Name of the output column; used as the name of the field this aggregate
/// reports.
name: String,
/// The DataType for the input expression; also the item type of the
/// output list.
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// Whether the input expression can produce NULLs. This is reported as the
/// nullability of the output list field so that it agrees with the
/// logical plan's schema.
nullable: bool,
}

impl ArrayAgg {
Expand All @@ -45,11 +50,13 @@ impl ArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
nullable: bool,
) -> Self {
Self {
name: name.into(),
expr,
input_data_type: data_type,
expr,
nullable,
}
}
}
Expand All @@ -62,8 +69,9 @@ impl AggregateExpr for ArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
false,
self.nullable,
))
}

Expand All @@ -77,7 +85,7 @@ impl AggregateExpr for ArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), true),
false,
self.nullable,
)])
}

Expand Down Expand Up @@ -184,7 +192,6 @@ mod tests {
use super::*;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
use crate::generic_test_op;
use arrow::array::ArrayRef;
use arrow::array::Int32Array;
use arrow::datatypes::*;
Expand All @@ -195,6 +202,30 @@ mod tests {
use datafusion_common::DataFusionError;
use datafusion_common::Result;

/// Test helper: evaluates the aggregate `$OP` over a single-column record
/// batch built from `$ARRAY` and asserts that the result equals `$EXPECTED`.
///
/// The four-argument form takes the data type handed to `$OP::new` from
/// `$EXPECTED.data_type()`; the five-argument form takes it explicitly as
/// `$EXPECTED_DATATYPE`. The expansion evaluates to a
/// `Result<(), DataFusionError>` and uses `?`, so it must be invoked inside a
/// function returning a compatible `Result`.
macro_rules! test_op {
    ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => {
        test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type())
    };
    ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
        // One nullable input column named "a" holding the array under test.
        let input_schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]);
        let input_batch =
            RecordBatch::try_new(Arc::new(input_schema.clone()), vec![$ARRAY])?;

        // The aggregate is always constructed as nullable in these tests.
        let aggregate_expr = Arc::new(<$OP>::new(
            col("a", &input_schema)?,
            "bla".to_string(),
            $EXPECTED_DATATYPE,
            true,
        ));
        let actual = aggregate(&input_batch, aggregate_expr)?;
        // `$EXPECTED` is consumed here, after `$EXPECTED_DATATYPE` (which may
        // borrow it) has already been evaluated above.
        let expected = ScalarValue::from($EXPECTED);

        assert_eq!(expected, actual);

        Ok(()) as Result<(), DataFusionError>
    }};
}

#[test]
fn array_agg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
Expand All @@ -208,7 +239,7 @@ mod tests {
])]);
let list = ScalarValue::List(Arc::new(list));

generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
}

#[test]
Expand Down Expand Up @@ -264,7 +295,7 @@ mod tests {

let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();

generic_test_op!(
test_op!(
array,
DataType::List(Arc::new(Field::new_list(
"item",
Expand Down
13 changes: 10 additions & 3 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub struct DistinctArrayAgg {
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have NULLs
nullable: bool,
}

impl DistinctArrayAgg {
Expand All @@ -48,12 +50,14 @@ impl DistinctArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
nullable: bool,
) -> Self {
let name = name.into();
Self {
name,
expr,
input_data_type,
expr,
nullable,
}
}
}
Expand All @@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
false,
self.nullable,
))
}

Expand All @@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "distinct_array_agg"),
Field::new("item", self.input_data_type.clone(), true),
false,
self.nullable,
)])
}

Expand Down Expand Up @@ -238,6 +243,7 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
true,
));
let actual = aggregate(&batch, agg)?;

Expand All @@ -255,6 +261,7 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
true,
));

let mut accum1 = agg.create_accumulator()?;
Expand Down
20 changes: 15 additions & 5 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,17 @@ use itertools::izip;
/// and that can merge aggregations from multiple partitions.
#[derive(Debug)]
pub struct OrderSensitiveArrayAgg {
/// Column name
name: String,
/// The DataType for the input expression
input_data_type: DataType,
order_by_data_types: Vec<DataType>,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have NULLs
nullable: bool,
/// Ordering data types
order_by_data_types: Vec<DataType>,
/// Ordering requirement
ordering_req: LexOrdering,
}

Expand All @@ -61,13 +68,15 @@ impl OrderSensitiveArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
nullable: bool,
order_by_data_types: Vec<DataType>,
ordering_req: LexOrdering,
) -> Self {
Self {
name: name.into(),
expr,
input_data_type,
expr,
nullable,
order_by_data_types,
ordering_req,
}
Expand All @@ -82,8 +91,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
false,
self.nullable,
))
}

Expand All @@ -99,13 +109,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
let mut fields = vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), true),
false,
self.nullable, // This should be the same as field()
)];
let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
fields.push(Field::new_list(
format_state_name(&self.name, "array_agg_orderings"),
Field::new("item", DataType::Struct(Fields::from(orderings)), true),
false,
self.nullable,
));
Ok(fields)
}
Expand Down
18 changes: 12 additions & 6 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,16 @@ pub fn create_aggregate_expr(
),
(AggregateFunction::ArrayAgg, false) => {
let expr = input_phy_exprs[0].clone();
let nullable = expr.nullable(input_schema)?;

if ordering_req.is_empty() {
Arc::new(expressions::ArrayAgg::new(expr, name, data_type))
Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable))
} else {
Arc::new(expressions::OrderSensitiveArrayAgg::new(
expr,
name,
data_type,
nullable,
ordering_types,
ordering_req.to_vec(),
))
Expand All @@ -132,10 +135,13 @@ pub fn create_aggregate_expr(
"ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available"
);
}
let expr = input_phy_exprs[0].clone();
let is_expr_nullable = expr.nullable(input_schema)?;
Arc::new(expressions::DistinctArrayAgg::new(
input_phy_exprs[0].clone(),
expr,
name,
data_type,
is_expr_nullable,
))
}
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
Expand Down Expand Up @@ -432,8 +438,8 @@ mod tests {
assert_eq!(
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true,),
false,
Field::new("item", data_type.clone(), true),
true,
),
result_agg_phy_exprs.field().unwrap()
);
Expand Down Expand Up @@ -471,8 +477,8 @@ mod tests {
assert_eq!(
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true,),
false,
Field::new("item", data_type.clone(), true),
true,
),
result_agg_phy_exprs.field().unwrap()
);
Expand Down