Support Group By and Order By Column Positions (#205)
shehabgamin authored Sep 18, 2024
Parent: 9b4a524 · Commit: d618551
Showing 6 changed files with 95 additions and 25 deletions.
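In SQL, GROUP BY and ORDER BY may refer to select-list items by 1-based position, e.g. `SELECT id, sum(quantity) FROM dealer GROUP BY 1 ORDER BY 1`. This commit teaches sail's plan resolver to perform that substitution. A minimal standalone sketch of the core bounds check (the `resolve_position` helper and plain `String` projections are illustrative stand-ins, not the crate's API):

```rust
/// Map a 1-based column position from GROUP BY/ORDER BY onto the
/// select list, mirroring the bounds check this commit introduces.
/// `resolve_position` is a hypothetical helper, not sail's API.
fn resolve_position(position: i64, projections: &[String]) -> Result<&str, String> {
    let num_projections = projections.len() as i64;
    if position > 0 && position <= num_projections {
        // Positions are 1-based, so subtract one to index the slice.
        Ok(projections[(position - 1) as usize].as_str())
    } else {
        Err(format!(
            "Cannot resolve column position {position}. Valid positions are 1 to {num_projections}."
        ))
    }
}

fn main() {
    let projections = vec!["id".to_string(), "sum(quantity)".to_string()];
    assert_eq!(resolve_position(1, &projections), Ok("id"));
    assert!(resolve_position(0, &projections).is_err()); // positions start at 1
    assert!(resolve_position(3, &projections).is_err()); // past the select list
}
```

Both the ORDER BY path (`resolve_sort_order` in expression.rs) and the GROUP BY path (`resolve_expressions_positions` in plan.rs) below apply this same 1-based check and report out-of-range positions as plan errors.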
53 changes: 45 additions & 8 deletions crates/sail-plan/src/resolver/expression.rs
@@ -108,6 +108,7 @@ impl PlanResolver<'_> {
     pub(super) async fn resolve_sort_order(
         &self,
         sort: spec::SortOrder,
+        resolve_literals: bool,
         schema: &DFSchemaRef,
         state: &mut PlanResolverState,
     ) -> PlanResult<expr::Expr> {
@@ -128,22 +129,55 @@ impl PlanResolver<'_> {
             NullOrdering::NullsLast => false,
             NullOrdering::Unspecified => asc,
         };
-        Ok(expr::Expr::Sort(expr::Sort {
-            expr: Box::new(self.resolve_expression(*child, schema, state).await?),
-            asc,
-            nulls_first,
-        }))
+
+        match child.as_ref() {
+            spec::Expr::Literal(literal) if resolve_literals => {
+                let num_fields = schema.fields().len();
+                let position = match literal {
+                    spec::Literal::Integer(value) => *value as usize,
+                    spec::Literal::Long(value) => *value as usize,
+                    _ => {
+                        return Ok(expr::Expr::Sort(expr::Sort {
+                            expr: Box::new(self.resolve_expression(*child, schema, state).await?),
+                            asc,
+                            nulls_first,
+                        }))
+                    }
+                };
+                if position > 0 && position <= num_fields {
+                    Ok(expr::Expr::Sort(expr::Sort {
+                        expr: Box::new(expr::Expr::Column(Column::from(
+                            schema.qualified_field(position - 1),
+                        ))),
+                        asc,
+                        nulls_first,
+                    }))
+                } else {
+                    Err(PlanError::invalid(format!(
+                        "Cannot resolve column position {position}. Valid positions are 1 to {num_fields}."
+                    )))
+                }
+            }
+            _ => Ok(expr::Expr::Sort(expr::Sort {
+                expr: Box::new(self.resolve_expression(*child, schema, state).await?),
+                asc,
+                nulls_first,
+            })),
+        }
     }
 
     pub(super) async fn resolve_sort_orders(
         &self,
         sort: Vec<spec::SortOrder>,
+        resolve_literals: bool,
         schema: &DFSchemaRef,
         state: &mut PlanResolverState,
     ) -> PlanResult<Vec<expr::Expr>> {
         let mut results: Vec<expr::Expr> = Vec::with_capacity(sort.len());
         for s in sort {
-            let expr = self.resolve_sort_order(s, schema, state).await?;
+            let expr = self
+                .resolve_sort_order(s, resolve_literals, schema, state)
+                .await?;
             results.push(expr);
         }
         Ok(results)
@@ -800,7 +834,7 @@ impl PlanResolver<'_> {
         schema: &DFSchemaRef,
         state: &mut PlanResolverState,
     ) -> PlanResult<NamedExpr> {
-        let sort = self.resolve_sort_order(sort, schema, state).await?;
+        let sort = self.resolve_sort_order(sort, true, schema, state).await?;
         Ok(NamedExpr::new(vec![], sort))
     }

@@ -866,7 +900,10 @@ impl PlanResolver<'_> {
         let partition_by = self
             .resolve_expressions(partition_spec, schema, state)
             .await?;
-        let order_by = self.resolve_sort_orders(order_spec, schema, state).await?;
+        // Spark treats literals as constants in ORDER BY window definition
+        let order_by = self
+            .resolve_sort_orders(order_spec, false, schema, state)
+            .await?;
         let window_frame = if let Some(frame) = frame_spec {
             self.resolve_window_frame(frame, &order_by, schema)?
         } else {
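A note on the `resolve_literals` flag added above: the same literal means different things in different clauses. In a top-level ORDER BY, `ORDER BY 1` names a column position, while inside a window definition (`OVER (ORDER BY ...)`) Spark treats the literal as a constant, so the diff passes `false` there and `true` elsewhere. A rough sketch of that dispatch, with toy types standing in for the spec AST:

```rust
// Toy stand-ins for the spec AST; not sail's actual types.
enum SortKey {
    Literal(i64),
    Column(String),
}

/// Resolve a sort key either positionally (top-level ORDER BY) or
/// as-is (window ORDER BY, where Spark keeps literals as constants).
fn resolve_sort_key(
    key: SortKey,
    resolve_literals: bool,
    fields: &[String],
) -> Result<String, String> {
    match key {
        SortKey::Literal(position) if resolve_literals => {
            let num_fields = fields.len() as i64;
            if position > 0 && position <= num_fields {
                Ok(fields[(position - 1) as usize].clone())
            } else {
                Err(format!(
                    "Cannot resolve column position {position}. Valid positions are 1 to {num_fields}."
                ))
            }
        }
        // resolve_literals == false: keep the literal as a constant key.
        SortKey::Literal(value) => Ok(format!("lit({value})")),
        SortKey::Column(name) => Ok(name),
    }
}

fn main() {
    let fields = vec!["id".to_string(), "quantity".to_string()];
    // ORDER BY 2 in a query sorts by the second field...
    assert_eq!(resolve_sort_key(SortKey::Literal(2), true, &fields).unwrap(), "quantity");
    // ...but inside OVER (ORDER BY 2) the 2 stays a constant.
    assert_eq!(resolve_sort_key(SortKey::Literal(2), false, &fields).unwrap(), "lit(2)");
}
```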
48 changes: 41 additions & 7 deletions crates/sail-plan/src/resolver/plan.rs
@@ -843,7 +843,7 @@ impl PlanResolver<'_> {
     ) -> PlanResult<LogicalPlan> {
         let input = self.resolve_query_plan(input, state).await?;
         let schema = input.schema();
-        let expr = self.resolve_sort_orders(order, schema, state).await?;
+        let expr = self.resolve_sort_orders(order, true, schema, state).await?;
         if is_global {
             Ok(LogicalPlan::Sort(plan::Sort {
                 expr,
@@ -884,18 +884,20 @@ impl PlanResolver<'_> {
             having,
             with_grouping_expressions,
         } = aggregate;
+
         let input = self.resolve_query_plan(*input, state).await?;
         let schema = input.schema();
-        let grouping = self
-            .resolve_named_expressions(grouping, schema, state)
-            .await?;
         let projections = self
             .resolve_named_expressions(projections, schema, state)
             .await?;
+        let grouping = self
+            .resolve_named_expressions(grouping, schema, state)
+            .await?;
         let having = match having {
             Some(having) => Some(self.resolve_expression(having, schema, state).await?),
             None => None,
         };
+
         self.rewrite_aggregate(
             input,
             projections,
@@ -2163,7 +2165,7 @@ impl PlanResolver<'_> {
         let table_reference = self.resolve_table_reference(&table)?;
         let table_provider = self.ctx.table_provider(table_reference.clone()).await?;
         let schema = self
-            .resolve_table_schema(&table_reference, &table_provider, columns.iter().collect())
+            .resolve_table_schema(&table_reference, &table_provider, &columns)
             .await?;
         let df_schema = Arc::new(DFSchema::try_from_qualified_schema(
             table_reference.clone(),
@@ -2216,7 +2218,7 @@ impl PlanResolver<'_> {
         let table_reference = self.resolve_table_reference(&table)?;
         let table_provider = self.ctx.table_provider(table_reference.clone()).await?;
         let table_schema = self
-            .resolve_table_schema(&table_reference, &table_provider, vec![])
+            .resolve_table_schema(&table_reference, &table_provider, &[])
             .await?;
         let fields = table_schema
             .fields
@@ -2277,7 +2279,7 @@ impl PlanResolver<'_> {
         let table_reference = self.resolve_table_reference(&table)?;
         let table_provider = self.ctx.table_provider(table_reference.clone()).await?;
         let table_schema = self
-            .resolve_table_schema(&table_reference, &table_provider, vec![])
+            .resolve_table_schema(&table_reference, &table_provider, &[])
             .await?;
         let table_schema = Arc::new(DFSchema::try_from_qualified_schema(
             table_reference.clone(),
@@ -2501,6 +2503,37 @@ impl PlanResolver<'_> {
         Ok(())
     }
 
+    fn resolve_expressions_positions(
+        &self,
+        exprs: Vec<NamedExpr>,
+        projections: &[NamedExpr],
+    ) -> PlanResult<Vec<NamedExpr>> {
+        let num_projections = projections.len() as i64;
+        exprs
+            .into_iter()
+            .map(|named_expr| {
+                let NamedExpr { expr, .. } = &named_expr;
+                match expr {
+                    Expr::Literal(scalar_value) => {
+                        let position = match scalar_value {
+                            ScalarValue::Int32(Some(position)) => *position as i64,
+                            ScalarValue::Int64(Some(position)) => *position,
+                            _ => return Ok(named_expr),
+                        };
+                        if position > 0_i64 && position <= num_projections {
+                            Ok(projections[(position - 1) as usize].clone())
+                        } else {
+                            Err(PlanError::invalid(format!(
+                                "Cannot resolve column position {position}. Valid positions are 1 to {num_projections}."
+                            )))
+                        }
+                    }
+                    _ => Ok(named_expr),
+                }
+            })
+            .collect()
+    }
+
     fn rewrite_aggregate(
         &self,
         input: LogicalPlan,
@@ -2510,6 +2543,7 @@ impl PlanResolver<'_> {
         with_grouping_expressions: bool,
         state: &mut PlanResolverState,
     ) -> PlanResult<LogicalPlan> {
+        let grouping = self.resolve_expressions_positions(grouping, &projections)?;
         let mut aggregate_candidates = projections
             .iter()
             .map(|x| x.expr.clone())
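The reordering in the aggregate hunk above is load-bearing: projections are now resolved before grouping expressions, so that `rewrite_aggregate` can hand the finished projection list to `resolve_expressions_positions` and substitute positional literals with the projections they name. The same map-and-collect shape on plain data (`GroupKey` is a made-up stand-in for `NamedExpr`):

```rust
#[derive(Clone, Debug, PartialEq)]
enum GroupKey {
    Position(i64), // e.g. GROUP BY 2
    Name(String),  // e.g. GROUP BY id
}

/// Substitute positional group keys with the projection they point at,
/// leaving named keys untouched. Collecting an iterator of `Result`s
/// into `Result<Vec<_>, _>` stops at the first error, as in the commit.
fn substitute_positions(
    keys: Vec<GroupKey>,
    projections: &[String],
) -> Result<Vec<String>, String> {
    let num_projections = projections.len() as i64;
    keys.into_iter()
        .map(|key| match key {
            GroupKey::Position(p) if p > 0 && p <= num_projections => {
                Ok(projections[(p - 1) as usize].clone())
            }
            GroupKey::Position(p) => Err(format!(
                "Cannot resolve column position {p}. Valid positions are 1 to {num_projections}."
            )),
            GroupKey::Name(name) => Ok(name),
        })
        .collect()
}

fn main() {
    let projections = vec!["id".to_string(), "sum(quantity)".to_string()];
    let keys = vec![GroupKey::Position(1), GroupKey::Name("dealer".to_string())];
    assert_eq!(
        substitute_positions(keys, &projections).unwrap(),
        vec!["id".to_string(), "dealer".to_string()]
    );
}
```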
4 changes: 2 additions & 2 deletions crates/sail-plan/src/resolver/schema.rs
@@ -52,9 +52,9 @@ impl PlanResolver<'_> {
         &self,
         table_reference: &TableReference,
         table_provider: &Arc<dyn TableProvider>,
-        columns: Vec<&spec::Identifier>,
+        columns: &[spec::Identifier],
     ) -> PlanResult<SchemaRef> {
-        let columns: Vec<&str> = columns.into_iter().map(|c| c.into()).collect();
+        let columns: Vec<&str> = columns.iter().map(|c| c.into()).collect();
         let schema = table_provider.schema();
         if columns.is_empty() {
             Ok(schema)
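The `resolve_table_schema` signature change from `Vec<&spec::Identifier>` to `&[spec::Identifier]` means callers borrow instead of allocating: the call sites in plan.rs now pass `&columns` or `&[]` rather than building a temporary vector of references. A small before/after illustration of the calling convention (the `Identifier` wrapper here is hypothetical):

```rust
// Hypothetical stand-in for spec::Identifier.
struct Identifier(String);

// Before: forces every caller to allocate a Vec of references.
fn old_api(columns: Vec<&Identifier>) -> usize {
    columns.len()
}

// After: borrows any slice, so `&columns` and `&[]` both work.
fn new_api(columns: &[Identifier]) -> usize {
    columns.len()
}

fn main() {
    let columns = vec![Identifier("id".to_string())];
    assert_eq!(old_api(columns.iter().collect()), 1);
    assert_eq!(new_api(&columns), 1);
    assert_eq!(new_api(&[]), 0); // empty case needs no allocation
}
```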
10 changes: 5 additions & 5 deletions crates/sail-spark-connect/tests/gold_data/function/datetime.json
@@ -39,7 +39,7 @@
         }
       },
       "output": {
-        "failure": "not implemented: function: session_window"
+        "failure": "error in DataFusion: Error during planning: cannot resolve attribute: ObjectName([Identifier(\"session_window\"), Identifier(\"start\")])"
       }
     },
     {
@@ -82,7 +82,7 @@
         }
       },
       "output": {
-        "failure": "not implemented: function: session_window"
+        "failure": "error in DataFusion: Error during planning: cannot resolve attribute: ObjectName([Identifier(\"session_window\"), Identifier(\"start\")])"
       }
     },
     {
@@ -130,7 +130,7 @@
         }
       },
       "output": {
-        "failure": "not implemented: function: window"
+        "failure": "error in DataFusion: Error during planning: cannot resolve attribute: ObjectName([Identifier(\"window\")])"
       }
     },
     {
@@ -174,7 +174,7 @@
         }
       },
       "output": {
-        "failure": "not implemented: function: window"
+        "failure": "error in DataFusion: Error during planning: cannot resolve attribute: ObjectName([Identifier(\"window\"), Identifier(\"start\")])"
       }
     },
     {
@@ -216,7 +216,7 @@
         }
       },
       "output": {
-        "failure": "not implemented: function: window"
+        "failure": "error in DataFusion: Error during planning: cannot resolve attribute: ObjectName([Identifier(\"window\"), Identifier(\"start\")])"
       }
     },
     {
4 changes: 2 additions & 2 deletions crates/sail-sql/src/statement/delete.rs
@@ -33,8 +33,8 @@ pub(crate) fn delete_statement_to_plan(delete: ast::Delete) -> SqlResult<spec::P
     }
 
     let mut from = match from {
-        ast::FromTable::WithFromKeyword(v) => v,
-        ast::FromTable::WithoutKeyword(v) => v,
+        ast::FromTable::WithFromKeyword(tables) => tables,
+        ast::FromTable::WithoutKeyword(tables) => tables,
     };
     if from.len() != 1 {
         return Err(SqlError::invalid(format!(
1 change: 0 additions & 1 deletion python/pysail/tests/spark/test_group_by.py
@@ -47,7 +47,6 @@ def test_group_by(spark):
     assert_frame_equal(actual, expected)
 
 
-@pytest.mark.skip(reason="not implemented")
 def test_group_by_column_position(spark):
     actual = spark.sql("SELECT id, sum(quantity) FROM dealer GROUP BY 1 ORDER BY 1").toPandas()
     expected = pd.DataFrame(
