Skip to content

Commit

Permalink
Infer right-side nullability for LEFT join (#5748)
Browse files Browse the repository at this point in the history
* weaken schema check on mem table

* weaken schema check on mem table

* fix LEFT join column null inference

* fix LEFT join column null inference. fix test

* fix LEFT join column null inference. fix test

* fix LEFT join column null inference. rollback some code
  • Loading branch information
comphead authored Mar 29, 2023
1 parent 5c65924 commit f210cac
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 19 deletions.
7 changes: 7 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,13 @@ FROM t1
----
11 a 55

# test create table from query with LEFT join
statement ok
create table temp as
with t1 as (select 1 as col1, 'asd' as col2),
t2 as (select 1 as col3, 'sdf' as col4)
select col2, col4 from t1 left join t2 on col1 = col3

statement ok
drop table IF EXISTS t1;

Expand Down
36 changes: 30 additions & 6 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1041,20 +1041,44 @@ pub fn build_join_schema(
right: &DFSchema,
join_type: &JoinType,
) -> Result<DFSchema> {
let right_fields = right.fields();
let left_fields = left.fields();

let fields: Vec<DFField> = match join_type {
JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
let right_fields = right.fields().iter();
let left_fields = left.fields().iter();
JoinType::Inner | JoinType::Full | JoinType::Right => {
// left then right
left_fields.chain(right_fields).cloned().collect()
left_fields
.iter()
.chain(right_fields.iter())
.cloned()
.collect()
}
JoinType::Left => {
// left then right; right-side fields are marked nullable because left rows with no match produce NULLs for them
let right_fields_nullable: Vec<DFField> = right_fields
.iter()
.map(|f| {
let field = f.field().clone().with_nullable(true);
if let Some(q) = f.qualifier() {
DFField::from_qualified(q, field)
} else {
DFField::from(field)
}
})
.collect();
left_fields
.iter()
.chain(&right_fields_nullable)
.cloned()
.collect()
}
JoinType::LeftSemi | JoinType::LeftAnti => {
// Only use the left side for the schema
left.fields().clone()
left_fields.clone()
}
JoinType::RightSemi | JoinType::RightAnti => {
// Only use the right side for the schema
right.fields().clone()
right_fields.clone()
}
};

Expand Down
20 changes: 10 additions & 10 deletions datafusion/optimizer/src/extract_equijoin_predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ mod tests {
Some(col("t1.a").eq(col("t2.a"))),
)?
.build()?;
let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -199,7 +199,7 @@ mod tests {
Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
)?
.build()?;
let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -223,7 +223,7 @@ mod tests {
),
)?
.build()?;
let expected = "Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
let expected = "Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -250,7 +250,7 @@ mod tests {
),
)?
.build()?;
let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -277,7 +277,7 @@ mod tests {
),
)?
.build()?;
let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

Expand Down Expand Up @@ -314,9 +314,9 @@ mod tests {
),
)?
.build()?;
let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
\n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";

Expand Down Expand Up @@ -349,9 +349,9 @@ mod tests {
Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
)?
.build()?;
let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
\n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";

Expand Down Expand Up @@ -380,7 +380,7 @@ mod tests {
Some(filter),
)?
.build()?;
let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/push_down_projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ mod tests {
vec![
DFField::new(Some("test"), "a", DataType::UInt32, false),
DFField::new(Some("test"), "b", DataType::UInt32, false),
DFField::new(Some("test2"), "c1", DataType::UInt32, false),
DFField::new(Some("test2"), "c1", DataType::UInt32, true),
],
HashMap::new(),
)?,
Expand Down Expand Up @@ -776,7 +776,7 @@ mod tests {
vec![
DFField::new(Some("test"), "a", DataType::UInt32, false),
DFField::new(Some("test"), "b", DataType::UInt32, false),
DFField::new(Some("test2"), "c1", DataType::UInt32, false),
DFField::new(Some("test2"), "c1", DataType::UInt32, true),
],
HashMap::new(),
)?,
Expand Down Expand Up @@ -817,7 +817,7 @@ mod tests {
vec![
DFField::new(Some("test"), "a", DataType::UInt32, false),
DFField::new(Some("test"), "b", DataType::UInt32, false),
DFField::new(Some("test2"), "a", DataType::UInt32, false),
DFField::new(Some("test2"), "a", DataType::UInt32, true),
],
HashMap::new(),
)?,
Expand Down

0 comments on commit f210cac

Please sign in to comment.