Skip to content

Commit

Permalink
Add basic join support (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Aug 15, 2022
1 parent 2644b81 commit 3349464
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 34 deletions.
106 changes: 93 additions & 13 deletions src/consumer.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
use std::sync::Arc;

use async_recursion::async_recursion;

use datafusion::common::{DFField, DFSchema, DFSchemaRef};
use datafusion::logical_expr::LogicalPlan;
use datafusion::logical_plan::build_join_schema;
use datafusion::prelude::JoinType;
use datafusion::{
error::{DataFusionError, Result},
logical_plan::{Expr, Operator},
optimizer::utils::split_conjunction,
prelude::{Column, DataFrame, SessionContext},
scalar::ScalarValue,
};

use std::collections::HashMap;
use std::sync::Arc;
use substrait::protobuf::{
expression::{field_reference::ReferenceType::MaskedReference, literal::LiteralType, RexType},
expression::{
field_reference::ReferenceType::MaskedReference, literal::LiteralType, MaskExpression,
RexType,
},
function_argument::ArgType,
read_rel::ReadType,
rel::RelType,
Expand Down Expand Up @@ -57,7 +63,7 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel) -> Result<A
let input = from_substrait_rel(ctx, input).await?;
let mut exprs: Vec<Expr> = vec![];
for e in &p.expressions {
let x = from_substrait_rex(e, input.as_ref()).await?;
let x = from_substrait_rex(e, &input.schema()).await?;
exprs.push(x.as_ref().clone());
}
input.select(exprs)
Expand All @@ -71,7 +77,7 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel) -> Result<A
if let Some(input) = filter.input.as_ref() {
let input = from_substrait_rel(ctx, input).await?;
if let Some(condition) = filter.condition.as_ref() {
let expr = from_substrait_rex(condition, input.as_ref()).await?;
let expr = from_substrait_rex(condition, &input.schema()).await?;
input.filter(expr.as_ref().clone())
} else {
Err(DataFusionError::NotImplemented(
Expand All @@ -84,10 +90,83 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel) -> Result<A
))
}
}
Some(RelType::Join(join)) => {
let left = from_substrait_rel(ctx, &join.left.as_ref().unwrap()).await?;
let right = from_substrait_rel(ctx, &join.right.as_ref().unwrap()).await?;
let join_type = match join.r#type {
1 => JoinType::Inner,
2 => JoinType::Left,
3 => JoinType::Right,
4 => JoinType::Full,
5 => JoinType::Anti,
6 => JoinType::Semi,
_ => return Err(DataFusionError::Internal("invalid join type".to_string())),
};
let mut predicates = vec![];
let schema = build_join_schema(&left.schema(), &right.schema(), &JoinType::Inner)?;
let on = from_substrait_rex(&join.expression.as_ref().unwrap(), &schema).await?;
split_conjunction(&on, &mut predicates);
let pairs = predicates
.iter()
.map(|p| match p {
Expr::BinaryExpr {
left,
op: Operator::Eq,
right,
} => match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => Ok((l.flat_name(), r.flat_name())),
_ => {
return Err(DataFusionError::Internal(
"invalid join condition".to_string(),
))
}
},
_ => {
return Err(DataFusionError::Internal(
"invalid join condition".to_string(),
))
}
})
.collect::<Result<Vec<_>>>()?;
let left_cols: Vec<&str> = pairs.iter().map(|(l, _)| l.as_str()).collect();
let right_cols: Vec<&str> = pairs.iter().map(|(_, r)| r.as_str()).collect();
left.join(right, join_type, &left_cols, &right_cols, None)
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
let table_name: String = nt.names[0].clone();
ctx.table(&*table_name)
let t = ctx.table(&*table_name)?;
match &read.projection {
Some(MaskExpression { select, .. }) => match &select.as_ref() {
Some(projection) => {
let column_indices: Vec<usize> = projection
.struct_items
.iter()
.map(|item| item.field as usize)
.collect();
match t.to_logical_plan()? {
LogicalPlan::TableScan(scan) => {
let mut scan = scan.clone();
let fields: Vec<DFField> = column_indices
.iter()
.map(|i| scan.projected_schema.field(*i).clone())
.collect();
scan.projection = Some(column_indices);
scan.projected_schema = DFSchemaRef::new(
DFSchema::new_with_metadata(fields, HashMap::new())?,
);
let plan = LogicalPlan::TableScan(scan);
Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan)))
}
_ => Err(DataFusionError::Internal(
"unexpected plan for table".to_string(),
)),
}
}
_ => Ok(t),
},
_ => Ok(t),
}
}
_ => Err(DataFusionError::NotImplemented(
"Only NamedTable reads are supported".to_string(),
Expand All @@ -102,14 +181,13 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel) -> Result<A

/// Convert Substrait Rex to DataFusion Expr
#[async_recursion]
pub async fn from_substrait_rex(e: &Expression, input: &DataFrame) -> Result<Arc<Expr>> {
pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema) -> Result<Arc<Expr>> {
match &e.rex_type {
Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
Some(MaskedReference(mask)) => match &mask.select.as_ref() {
Some(x) if x.struct_items.len() == 1 => Ok(Arc::new(Expr::Column(Column {
relation: None,
name: input
.schema()
name: input_schema
.field(x.struct_items[0].field as usize)
.name()
.to_string(),
Expand All @@ -128,9 +206,11 @@ pub async fn from_substrait_rex(e: &Expression, input: &DataFrame) -> Result<Arc
match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
(Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
Ok(Arc::new(Expr::BinaryExpr {
left: Box::new(from_substrait_rex(l, input).await?.as_ref().clone()),
left: Box::new(from_substrait_rex(l, input_schema).await?.as_ref().clone()),
op,
right: Box::new(from_substrait_rex(r, input).await?.as_ref().clone()),
right: Box::new(
from_substrait_rex(r, input_schema).await?.as_ref().clone(),
),
}))
}
(l, r) => Err(DataFusionError::NotImplemented(format!(
Expand Down
48 changes: 44 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,59 @@ mod tests {
roundtrip("SELECT * FROM data WHERE d AND a > 1").await
}

#[tokio::test]
async fn roundtrip_inner_join() -> Result<()> {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
}

#[tokio::test]
async fn inner_join() -> Result<()> {
assert_expected_plan(
"SELECT data.a FROM data JOIN data2 ON data.a = data2.a",
"Projection: #data.a\
\n Inner Join: #data.a = #data2.a\
\n TableScan: data projection=[a]\
\n TableScan: data2 projection=[a]",
)
.await
}

async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.to_logical_plan()?;
let proto = to_substrait_rel(&plan)?;
let df = from_substrait_rel(&mut ctx, &proto).await?;
let plan2 = df.to_logical_plan()?;
let plan2str = format!("{:?}", plan2);
assert_eq!(expected_plan_str, &plan2str);
Ok(())
}
async fn roundtrip(sql: &str) -> Result<()> {
let mut ctx = SessionContext::new();
ctx.register_csv("data", "testdata/data.csv", CsvReadOptions::new())
.await?;
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.to_logical_plan()?;
let proto = to_substrait_rel(&plan)?;

// pretty print the protobuf struct
//println!("{:#?}", proto);

let df = from_substrait_rel(&mut ctx, &proto).await?;
let plan2 = df.to_logical_plan()?;
let plan2 = ctx.optimize(&plan2)?;
//println!("Roundtrip Plan:\n{:?}", plan2);

let plan1str = format!("{:?}", plan);
let plan2str = format!("{:?}", plan2);
assert_eq!(plan1str, plan2str);
Ok(())
}

async fn create_context() -> Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("data", "testdata/data.csv", CsvReadOptions::new())
.await?;
ctx.register_csv("data2", "testdata/data.csv", CsvReadOptions::new())
.await?;
Ok(ctx)
}
}
98 changes: 81 additions & 17 deletions src/producer.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use datafusion::{
error::{DataFusionError, Result},
logical_plan::{DFSchemaRef, Expr, LogicalPlan, Operator},
logical_plan::{DFSchemaRef, Expr, JoinConstraint, LogicalPlan, Operator},
prelude::JoinType,
scalar::ScalarValue,
};

use substrait::protobuf::{
expression::{
field_reference::ReferenceType,
Expand All @@ -14,8 +14,9 @@ use substrait::protobuf::{
function_argument::ArgType,
read_rel::{NamedTable, ReadType},
rel::RelType,
Expression, FilterRel, FunctionArgument, NamedStruct, ProjectRel, ReadRel, Rel,
Expression, FilterRel, FunctionArgument, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
};

/// Convert DataFusion LogicalPlan to Substrait Rel
pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
match plan {
Expand Down Expand Up @@ -87,6 +88,62 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
}))),
}))
}
LogicalPlan::Join(join) => {
let left = to_substrait_rel(join.left.as_ref())?;
let right = to_substrait_rel(join.right.as_ref())?;
let join_type = match join.join_type {
JoinType::Inner => 1,
JoinType::Left => 2,
JoinType::Right => 3,
JoinType::Full => 4,
JoinType::Anti => 5,
JoinType::Semi => 6,
};
// we only support basic joins so return an error for anything not yet supported
if join.null_equals_null {
return Err(DataFusionError::NotImplemented(
"join null_equals_null".to_string(),
));
}
if join.filter.is_some() {
return Err(DataFusionError::NotImplemented("join filter".to_string()));
}
match join.join_constraint {
JoinConstraint::On => {}
_ => {
return Err(DataFusionError::NotImplemented(
"join constraint".to_string(),
))
}
}
// map the left and right columns to binary expressions in the form `l = r`
let join_expression: Vec<Expr> = join
.on
.iter()
.map(|(l, r)| Expr::Column(l.clone()).eq(Expr::Column(r.clone())))
.collect();
// build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
let join_expression = join_expression
.into_iter()
.reduce(|acc: Expr, expr: Expr| acc.and(expr));
if let Some(e) = join_expression {
Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
common: None,
left: Some(left),
right: Some(right),
r#type: join_type,
expression: Some(Box::new(to_substrait_rex(&e, &join.schema)?)),
post_join_filter: None,
advanced_extension: None,
}))),
}))
} else {
Err(DataFusionError::NotImplemented(
"Empty join condition".to_string(),
))
}
}
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported operator: {:?}",
plan
Expand Down Expand Up @@ -126,20 +183,10 @@ pub fn operator_to_reference(op: Operator) -> u32 {
/// Convert DataFusion Expr to Substrait Rex
pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef) -> Result<Expression> {
match expr {
Expr::Column(col) => Ok(Expression {
rex_type: Some(RexType::Selection(Box::new(FieldReference {
reference_type: Some(ReferenceType::MaskedReference(MaskExpression {
select: Some(StructSelect {
struct_items: vec![StructItem {
field: schema.index_of_column_by_name(None, &col.name)? as i32,
child: None,
}],
}),
maintain_singular_struct: false,
})),
root_type: None,
}))),
}),
Expr::Column(col) => {
let index = schema.index_of_column(&col)?;
substrait_field_ref(index)
}
Expr::BinaryExpr { left, op, right } => {
let l = to_substrait_rex(left, schema)?;
let r = to_substrait_rex(right, schema)?;
Expand Down Expand Up @@ -195,3 +242,20 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef) -> Result<Expression>
))),
}
}

fn substrait_field_ref(index: usize) -> Result<Expression> {
Ok(Expression {
rex_type: Some(RexType::Selection(Box::new(FieldReference {
reference_type: Some(ReferenceType::MaskedReference(MaskExpression {
select: Some(StructSelect {
struct_items: vec![StructItem {
field: index as i32,
child: None,
}],
}),
maintain_singular_struct: false,
})),
root_type: None,
}))),
})
}

0 comments on commit 3349464

Please sign in to comment.