From 48937312fdaabc1a10e11a248f84d5cfbb98f67a Mon Sep 17 00:00:00 2001
From: Jon Mease
Date: Fri, 12 Aug 2022 07:13:59 -0400
Subject: [PATCH] Add failing exact_median test

---
 datafusion/core/tests/dataframe.rs | 54 +++++++++++++++++++++++++++++-
 1 file changed, 53 insertions(+), 1 deletion(-)

diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index b25e83cb7eba..55f404c2e709 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::array::Float64Array;
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::{
     array::{Int32Array, StringArray},
@@ -31,7 +32,7 @@ use datafusion::logical_plan::{col, Expr};
 use datafusion::prelude::CsvReadOptions;
 use datafusion::{datasource::MemTable, prelude::JoinType};
 use datafusion_expr::expr::GroupingSet;
-use datafusion_expr::{avg, count, lit, sum};
+use datafusion_expr::{avg, count, lit, sum, AggregateFunction};
 
 #[tokio::test]
 async fn join() -> Result<()> {
@@ -424,3 +425,54 @@ async fn aggregates_table(ctx: &SessionContext) -> Result<Arc<DataFrame>> {
     )
     .await
 }
+
+/// Checks that `AggregateFunction::Median` yields the exact median of a group.
+///
+/// Column `b` holds `[10.0, 0.0, 20.0, 100.0]` for the single group `a = 1`.
+/// With an even number of rows the exact median is the mean of the two
+/// middle values of the sorted data, `(10.0 + 20.0) / 2 = 15.0`.
+#[tokio::test]
+async fn exact_median() -> Result<()> {
+    let schema = Schema::new(vec![
+        Field::new("a", DataType::Int32, true),
+        Field::new("b", DataType::Float64, true),
+    ]);
+
+    // A single group whose values are deliberately not in sorted order.
+    let batch = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![
+            Arc::new(Int32Array::from_slice(&[1, 1, 1, 1])),
+            Arc::new(Float64Array::from_slice(&[10.0, 0.0, 20.0, 100.0])),
+        ],
+    )?;
+
+    let ctx = SessionContext::new();
+    let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;
+    ctx.register_table("t", Arc::new(provider))?;
+
+    // Group by `a` and take the median of `b`, exposed as column `agg`.
+    let df = ctx.table("t")?.aggregate(
+        vec![col("a")],
+        vec![Expr::AggregateFunction {
+            fun: AggregateFunction::Median,
+            args: vec![col("b")],
+            distinct: false,
+        }
+        .alias("agg")],
+    )?;
+
+    let results = df.collect().await?;
+
+    #[rustfmt::skip]
+    let expected = vec![
+        "+---+------+",
+        "| a | agg  |",
+        "+---+------+",
+        "| 1 | 15.0 |",
+        "+---+------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}