From c05a2a0c11634b1b98aef2cec7a653dc2d5227ec Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Mon, 4 Nov 2024 17:40:14 +0400 Subject: [PATCH] feat: Add SQL support for `RIGHT JOIN`, fix an issue with wildcard aliasing --- crates/polars-sql/src/context.rs | 29 +++--- crates/polars-sql/src/sql_expr.rs | 33 ++++--- py-polars/tests/unit/sql/test_joins.py | 118 ++++++++++++++++++++++++- 3 files changed, 151 insertions(+), 29 deletions(-) diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 1a060545439c..c85e24608786 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -561,6 +561,7 @@ impl SQLContext { lf = match &join.join_operator { op @ (JoinOperator::FullOuter(constraint) | JoinOperator::LeftOuter(constraint) + | JoinOperator::RightOuter(constraint) | JoinOperator::Inner(constraint) | JoinOperator::LeftAnti(constraint) | JoinOperator::LeftSemi(constraint) @@ -585,6 +586,7 @@ impl SQLContext { match op { JoinOperator::FullOuter(_) => JoinType::Full, JoinOperator::LeftOuter(_) => JoinType::Left, + JoinOperator::RightOuter(_) => JoinType::Right, JoinOperator::Inner(_) => JoinType::Inner, #[cfg(feature = "semi_anti_join")] JoinOperator::LeftAnti(_) | JoinOperator::RightAnti(_) => JoinType::Anti, @@ -1414,14 +1416,14 @@ fn collect_compound_identifiers( right_name: &str, ) -> PolarsResult<(Vec, Vec)> { if left.len() == 2 && right.len() == 2 { - let (tbl_a, col_a) = (left[0].value.as_str(), left[1].value.as_str()); - let (tbl_b, col_b) = (right[0].value.as_str(), right[1].value.as_str()); + let (tbl_a, col_name_a) = (left[0].value.as_str(), left[1].value.as_str()); + let (tbl_b, col_name_b) = (right[0].value.as_str(), right[1].value.as_str()); // switch left/right operands if the caller has them in reverse if left_name == tbl_b || right_name == tbl_a { - Ok((vec![col(col_b)], vec![col(col_a)])) + Ok((vec![col(col_name_b)], vec![col(col_name_a)])) } else { - Ok((vec![col(col_a)], vec![col(col_b)])) + Ok((vec![col(col_name_a)], vec![col(col_name_b)])) } } else { polars_bail!(SQLInterface: "collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", left.len(), right.len()); @@ -1461,14 +1463,13 @@ fn process_join_on( ) -> PolarsResult<(Vec, Vec)> { if let SQLExpr::BinaryOp { left, op, right } = expression { match *op { - BinaryOperator::Eq => { - if let (SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) = - (left.as_ref(), right.as_ref()) - { + BinaryOperator::Eq => match (left.as_ref(), right.as_ref()) { + (SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) => { collect_compound_identifiers(left, right, &tbl_left.name, &tbl_right.name) - } else { - polars_bail!(SQLInterface: "JOIN clauses support '=' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right); - } + }, + _ => { + polars_bail!(SQLInterface: "only equi-join constraints (on identifiers) are currently supported; found lhs={:?}, rhs={:?}", left, right); + }, }, BinaryOperator::And => { let (mut left_i, mut right_i) = process_join_on(left, tbl_left, tbl_right)?; @@ -1479,13 +1480,13 @@ fn process_join_on( Ok((left_i, right_i)) }, _ => { - polars_bail!(SQLInterface: "JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op); + polars_bail!(SQLInterface: "only equi-join constraints (combined with 'AND') are currently supported; found op = '{:?}'", op); }, } } else if let SQLExpr::Nested(expr) = expression { process_join_on(expr, tbl_left, tbl_right) } else { - polars_bail!(SQLInterface: "JOIN clauses support '=' constraints combined with 'AND'; found expression = {:?}", expression); + polars_bail!(SQLInterface: "only equi-join constraints (combined with 'AND') are currently supported; found expression = {:?}", expression); } } @@ -1504,7 +1505,7 @@ fn process_join_constraint( } if op != &BinaryOperator::Eq { polars_bail!(SQLInterface: - "only equi-join constraints are supported; found '{:?}' op in\n{:?}", op, constraint) + "only equi-join constraints are currently supported; found '{:?}' op in\n{:?}", op, constraint) } match (left.as_ref(), right.as_ref()) { (SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) => { diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 5eb2bdd843b4..9e068efb6064 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -1156,6 +1156,24 @@ pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr { } } +fn resolve_column<'a>( + ctx: &'a mut SQLContext, + ident_root: &'a Ident, + name: &'a str, + dtype: &'a DataType, +) -> PolarsResult<(Expr, Option<&'a DataType>)> { + let resolved = ctx.resolve_name(&ident_root.value, name); + let resolved = resolved.as_str(); + Ok(( + if name != resolved { + col(resolved).alias(name) + } else { + col(name) + }, + Some(dtype), + )) +} + pub(crate) fn resolve_compound_identifier( ctx: &mut SQLContext, idents: &[Ident], @@ -1182,20 +1200,11 @@ pub(crate) fn resolve_compound_identifier( let name = &remaining_idents.next().unwrap().value; if lf.is_some() && name == "*" { return Ok(schema - .iter_names() - .map(|name| col(name.clone())) + .iter_names_and_dtypes() + .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).unwrap().0) .collect::>()); } else if let Some((_, name, dtype)) = schema.get_full(name) { - let resolved = ctx.resolve_name(&ident_root.value, name); - let resolved = resolved.as_str(); - Ok(( - if name != resolved { - col(resolved).alias(name.clone()) - } else { - col(name.clone()) - }, - Some(dtype), - )) + resolve_column(ctx, ident_root, name, dtype) } else if lf.is_none() { remaining_idents = idents.iter().skip(1); Ok(( diff --git a/py-polars/tests/unit/sql/test_joins.py b/py-polars/tests/unit/sql/test_joins.py index 43c00ed8b3d5..e4d9b37d83a5 100644 --- a/py-polars/tests/unit/sql/test_joins.py +++ b/py-polars/tests/unit/sql/test_joins.py @@ -2,6 +2,7 @@ from io import BytesIO from pathlib import Path +from typing import Any import pytest @@ -295,10 +296,11 @@ def test_join_misc_16255() -> None: ) def test_non_equi_joins(constraint: str) -> None: # no support (yet) for non equi-joins in polars joins + # TODO: integrate awareness of new IEJoin with ( pytest.raises( SQLInterfaceError, - match=r"only equi-join constraints are supported", + match=r"only equi-join constraints are currently supported", ), pl.SQLContext({"tbl": pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})}) as ctx, ): @@ -335,6 +337,109 @@ def test_implicit_joins() -> None: ) +@pytest.mark.parametrize( + ("query", "expected"), + [ + # INNER joins + ( + "SELECT df1.* FROM df1 INNER JOIN df2 USING (a)", + {"a": [1, 3], "b": ["x", "z"], "c": [100, 300]}, + ), + ( + "SELECT df2.* FROM df1 INNER JOIN df2 USING (a)", + {"a": [1, 3], "b": ["qq", "pp"], "c": [400, 500]}, + ), + ( + "SELECT df1.* FROM df2 INNER JOIN df1 USING (a)", + {"a": [1, 3], "b": ["x", "z"], "c": [100, 300]}, + ), + ( + "SELECT df2.* FROM df2 INNER JOIN df1 USING (a)", + {"a": [1, 3], "b": ["qq", "pp"], "c": [400, 500]}, + ), + # LEFT joins + ( + "SELECT df1.* FROM df1 LEFT JOIN df2 USING (a)", + {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]}, + ), + ( + "SELECT df2.* FROM df1 LEFT JOIN df2 USING (a)", + {"a": [1, 3, None], "b": ["qq", "pp", None], "c": [400, 500, None]}, + ), + ( + "SELECT df1.* FROM df2 LEFT JOIN df1 USING (a)", + {"a": [1, 3, None], "b": ["x", "z", None], "c": [100, 300, None]}, + ), + ( + "SELECT df2.* FROM df2 LEFT JOIN df1 USING (a)", + {"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]}, + ), + # RIGHT joins + ( + "SELECT df1.* FROM df1 RIGHT JOIN df2 USING (a)", + {"a": [1, 3, None], "b": ["x", "z", None], "c": [100, 300, None]}, + ), + ( + "SELECT df2.* FROM df1 RIGHT JOIN df2 USING (a)", + {"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]}, + ), + ( + "SELECT df1.* FROM df2 RIGHT JOIN df1 USING (a)", + {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]}, + ), + ( + "SELECT df2.* FROM df2 RIGHT JOIN df1 USING (a)", + {"a": [1, 3, None], "b": ["qq", "pp", None], "c": [400, 500, None]}, + ), + # FULL joins + ( + "SELECT df1.* FROM df1 FULL JOIN df2 USING (a)", + { + "a": [1, 2, 3, None], + "b": ["x", "y", "z", None], + "c": [100, 200, 300, None], + }, + ), + ( + "SELECT df2.* FROM df1 FULL JOIN df2 USING (a)", + { + "a": [1, 3, 4, None], + "b": ["qq", "pp", "oo", None], + "c": [400, 500, 600, None], + }, + ), + ( + "SELECT df1.* FROM df2 FULL JOIN df1 USING (a)", + { + "a": [1, 2, 3, None], + "b": ["x", "y", "z", None], + "c": [100, 200, 300, None], + }, + ), + ( + "SELECT df2.* FROM df2 FULL JOIN df1 USING (a)", + { + "a": [1, 3, 4, None], + "b": ["qq", "pp", "oo", None], + "c": [400, 500, 600, None], + }, + ), + ], +) +def test_wildcard_resolution_and_join_order( + query: str, expected: dict[str, Any] +) -> None: + df1 = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]}) # noqa: F841 + df2 = pl.DataFrame({"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]}) # noqa: F841 + + res = pl.sql(query).collect() + assert_frame_equal( + res, + pl.DataFrame(expected), + check_row_order=False, + ) + + def test_natural_joins_01() -> None: df1 = pl.DataFrame( { @@ -481,8 +586,15 @@ def test_natural_joins_02(cols_constraint: str, expect_data: list[tuple[int]]) - @pytest.mark.parametrize( "join_clause", [ - "df2 INNER JOIN df3 ON df2.CharacterID=df3.CharacterID", - "df2 INNER JOIN (df3 INNDER JOIN df4 ON df3.CharacterID=df4.CharacterID) ON df2.CharacterID=df3.CharacterID", + """ + df2 JOIN df3 ON + df2.CharacterID = df3.CharacterID + """, + """ + df2 INNER JOIN ( + df3 JOIN df4 ON df3.CharacterID = df4.CharacterID + ) ON df2.CharacterID = df3.CharacterID + """, ], ) def test_nested_join(join_clause: str) -> None: