Skip to content

Commit

Permalink
Add sort consumer and producer (#24)
Browse files Browse the repository at this point in the history
Add consumer

Add producer and test

Modified error string
  • Loading branch information
nseekhao authored Oct 14, 2022
1 parent 161b774 commit c8c8732
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 2 deletions.
52 changes: 52 additions & 0 deletions src/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use datafusion::{
prelude::{Column, DataFrame, SessionContext},
scalar::ScalarValue,
};
use substrait::protobuf::sort_field::{SortKind::*, SortDirection};
use std::collections::HashMap;
use std::sync::Arc;
use substrait::protobuf::{
Expand Down Expand Up @@ -107,6 +108,57 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel) -> Result<A
))
}
}
// Convert a Substrait sort relation into a sort over the converted input plan.
//
// Every malformed or unsupported sort field is reported as a
// `DataFusionError` through `?`; nothing in this arm panics on plan content.
Some(RelType::Sort(sort)) => {
    if let Some(input) = sort.input.as_ref() {
        let input = from_substrait_rel(ctx, input).await?;
        let mut sorts: Vec<Expr> = Vec::with_capacity(sort.sorts.len());
        for s in &sort.sorts {
            // A sort field without an expression is malformed input; surface it as an
            // error instead of unwrapping.
            let sort_expr = s.expr.as_ref().ok_or_else(|| {
                DataFusionError::NotImplemented(
                    "Sort field without an expression is invalid".to_string(),
                )
            })?;
            let expr = from_substrait_rex(sort_expr, &input.schema()).await?;
            let asc_nullfirst: Result<(bool, bool)> = match &s.sort_kind {
                // The direction arrives as a raw i32 from the serialized plan. Use the
                // checked conversion: transmuting an out-of-range integer into the enum
                // would be undefined behaviour.
                Some(Direction(d)) => match SortDirection::from_i32(*d) {
                    Some(SortDirection::AscNullsFirst) => Ok((true, true)),
                    Some(SortDirection::AscNullsLast) => Ok((true, false)),
                    Some(SortDirection::DescNullsFirst) => Ok((false, true)),
                    Some(SortDirection::DescNullsLast) => Ok((false, false)),
                    Some(SortDirection::Clustered) => Err(DataFusionError::NotImplemented(
                        "Sort with direction clustered is not yet supported".to_string(),
                    )),
                    Some(SortDirection::Unspecified) => Err(DataFusionError::NotImplemented(
                        "Unspecified sort direction is invalid".to_string(),
                    )),
                    None => Err(DataFusionError::NotImplemented(format!(
                        "Unsupported sort direction value {}",
                        d
                    ))),
                },
                Some(ComparisonFunctionReference(_)) => Err(DataFusionError::NotImplemented(
                    "Sort using comparison function reference is not supported".to_string(),
                )),
                None => Err(DataFusionError::NotImplemented(
                    "Sort without sort kind is invalid".to_string(),
                )),
            };
            // Propagate the error to the caller rather than panicking on it.
            let (asc, nulls_first) = asc_nullfirst?;
            sorts.push(Expr::Sort {
                expr: Box::new(expr.as_ref().clone()),
                asc,
                nulls_first,
            });
        }
        input.sort(sorts)
    } else {
        Err(DataFusionError::NotImplemented(
            "Sort without an input is not valid".to_string(),
        ))
    }
}
Some(RelType::Join(join)) => {
let left = from_substrait_rel(ctx, &join.left.as_ref().unwrap()).await?;
let right = from_substrait_rel(ctx, &join.right.as_ref().unwrap()).await?;
Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ mod tests {
roundtrip("SELECT * FROM data LIMIT 200 OFFSET 10").await
}

/// Passes a simple `ORDER BY` query to the `roundtrip` helper and returns its result.
#[tokio::test]
async fn select_with_sort() -> Result<()> {
    let sql = "SELECT a, b FROM data ORDER BY a";
    roundtrip(sql).await
}

#[tokio::test]
async fn roundtrip_inner_join() -> Result<()> {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
Expand Down
45 changes: 43 additions & 2 deletions src/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ use substrait::protobuf::{
mask_expression::{StructItem, StructSelect},
FieldReference, Literal, MaskExpression, RexType, ScalarFunction,
},
sort_field::{
SortDirection,
SortKind,
},
function_argument::ArgType,
read_rel::{NamedTable, ReadType},
rel::RelType,
Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, NamedStruct, ProjectRel, ReadRel,
Rel,
Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel, SortField, SortRel
};

/// Convert DataFusion LogicalPlan to Substrait Rel
Expand Down Expand Up @@ -105,6 +108,22 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
}))),
}))
}
// Lower a DataFusion `Sort` node into a Substrait `SortRel` wrapping the
// converted child plan and one `SortField` per sort expression.
LogicalPlan::Sort(sort) => {
    // The child plan is converted first, so a failure there is returned before
    // any sort expression is examined.
    let child = to_substrait_rel(sort.input.as_ref())?;
    let input_schema = sort.input.schema();
    let mut fields = Vec::with_capacity(sort.expr.len());
    for sort_expr in &sort.expr {
        // Stop at the first expression that cannot be converted.
        fields.push(substrait_sort_field(sort_expr, input_schema)?);
    }
    let sort_rel = SortRel {
        common: None,
        input: Some(child),
        sorts: fields,
        advanced_extension: None,
    };
    Ok(Box::new(Rel {
        rel_type: Some(RelType::Sort(Box::new(sort_rel))),
    }))
}
LogicalPlan::Join(join) => {
let left = to_substrait_rel(join.left.as_ref())?;
let right = to_substrait_rel(join.right.as_ref())?;
Expand Down Expand Up @@ -263,6 +282,28 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef) -> Result<Expression>
}
}

/// Convert a DataFusion sort expression into a Substrait `SortField`.
///
/// `expr` must be an `Expr::Sort`; its inner expression is converted with
/// `to_substrait_rex` against `schema`, and the `(asc, nulls_first)` pair is
/// mapped to the matching `SortDirection`.
///
/// # Errors
///
/// Returns `DataFusionError::NotImplemented` when `expr` is not an `Expr::Sort`,
/// and forwards any error from `to_substrait_rex`.
fn substrait_sort_field(expr: &Expr, schema: &DFSchemaRef) -> Result<SortField> {
    // Guard clause: anything other than a sort expression is rejected up front.
    let (inner, ascending, nulls_first) = match expr {
        Expr::Sort {
            expr,
            asc,
            nulls_first,
        } => (expr, *asc, *nulls_first),
        other => {
            return Err(DataFusionError::NotImplemented(format!(
                "Expecting sort expression but got {:?}",
                other
            )))
        }
    };

    let rex = to_substrait_rex(inner, schema)?;

    let direction = if ascending {
        if nulls_first {
            SortDirection::AscNullsFirst
        } else {
            SortDirection::AscNullsLast
        }
    } else if nulls_first {
        SortDirection::DescNullsFirst
    } else {
        SortDirection::DescNullsLast
    };

    // The `Direction` variant carries the enum as its integer value.
    Ok(SortField {
        expr: Some(rex),
        sort_kind: Some(SortKind::Direction(direction as i32)),
    })
}

fn substrait_field_ref(index: usize) -> Result<Expression> {
Ok(Expression {
rex_type: Some(RexType::Selection(Box::new(FieldReference {
Expand Down

0 comments on commit c8c8732

Please sign in to comment.