Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Feb 5, 2025
1 parent d536e96 commit 3693a97
Show file tree
Hide file tree
Showing 9 changed files with 389 additions and 133 deletions.
13 changes: 9 additions & 4 deletions crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,19 @@ impl Operator {
| Self::LtEq
| Self::Gt
| Self::GtEq
| Self::And
| Self::Or
| Self::Xor
| Self::EqValidity
| Self::NotEqValidity
)
}

/// Returns `true` if this operator is one of `And`, `Or` or `Xor`.
pub fn is_bitwise(&self) -> bool {
    matches!(self, Self::Xor | Self::Or | Self::And)
}

/// Returns `true` if either `is_bitwise` or `is_comparison` holds for this operator.
pub fn is_comparison_or_bitwise(&self) -> bool {
    self.is_bitwise() || self.is_comparison()
}

pub fn swap_operands(self) -> Self {
match self {
Operator::Eq => Operator::Eq,
Expand All @@ -465,6 +470,6 @@ impl Operator {
}

pub fn is_arithmetic(&self) -> bool {
!(self.is_comparison())
!(self.is_comparison_or_bitwise())
}
}
202 changes: 165 additions & 37 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,6 @@ pub fn resolve_join(
let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

// Not a closure to avoid borrow issues because we mutate expr_arena as well.
macro_rules! get_dtype {
($expr:expr, $schema:expr) => {
ctxt.expr_arena
.get($expr.node())
.get_type($schema, Context::Default, ctxt.expr_arena)
};
}
// # Resolve scalars
//
// Scalars need to be expanded. We translate them to temporary columns added with
Expand Down Expand Up @@ -234,6 +226,15 @@ pub fn resolve_join(
(schema_left, schema_right)
};

// Looks up `$expr`'s node in `ctxt.expr_arena` and calls `get_type` on it with `$schema`
// and `Context::Default`. The result is returned as-is, so callers apply `?` themselves.
//
// Not a closure to avoid borrow issues because we mutate expr_arena as well.
macro_rules! get_dtype {
($expr:expr, $schema:expr) => {
ctxt.expr_arena
.get($expr.node())
.get_type($schema, Context::Default, ctxt.expr_arena)
};
}

// # Cast lossless
//
// If we do a full join and keys are coalesced, the cast keys must be added up front.
Expand All @@ -249,15 +250,20 @@ pub fn resolve_join(
let rtype = get_dtype!(rnode, &schema_right)?;

if let Some(dtype) = get_numeric_upcast_supertype_lossless(&ltype, &rtype) {
// We use overflowing cast to allow better optimization as we are casting to a known
// lossless supertype.
//
// We have unique references to these nodes (they are created by this function),
// so we can mutate in-place without causing side effects somewhere else.
let casted_l = ctxt.expr_arena.add(AExpr::Cast {
expr: lnode.node(),
dtype: dtype.clone(),
options: CastOptions::Strict,
options: CastOptions::Overflowing,
});
let casted_r = ctxt.expr_arena.add(AExpr::Cast {
expr: rnode.node(),
dtype,
options: CastOptions::Strict,
options: CastOptions::Overflowing,
});

if key_cols_coalesced {
Expand Down Expand Up @@ -400,37 +406,12 @@ fn resolve_join_where(
let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt)
.map_err(|e| e.context(failed_here!(join left)))?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt
let schema_left = ctxt
.lp_arena
.get(input_right)
.get(input_left)
.schema(ctxt.lp_arena)
.into_owned();

for expr in &predicates {
fn all_in_schema(
schema: &Schema,
other: Option<&Schema>,
left: &Expr,
right: &Expr,
) -> bool {
let mut iter =
expr_to_leaf_column_names_iter(left).chain(expr_to_leaf_column_names_iter(right));
iter.all(|name| {
schema.contains(name.as_str()) && other.is_none_or(|s| !s.contains(name.as_str()))
})
}

let valid = expr.into_iter().all(|e| match e {
Expr::BinaryExpr { left, op, right } if op.is_comparison() => {
!(all_in_schema(&schema_left, None, left, right)
|| all_in_schema(&schema_right, Some(&schema_left), left, right))
},
_ => true,
});
polars_ensure!( valid, InvalidOperation: "'join_where' predicate only refers to columns from a single table")
}

let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::Cross;

Expand All @@ -444,9 +425,31 @@ fn resolve_join_where(
ctxt,
)?;

let mut ae_nodes_stack = Vec::new();

let schema_merged = ctxt
.lp_arena
.get(last_node)
.schema(ctxt.lp_arena)
.into_owned();
let schema_merged = schema_merged.as_ref();

for e in predicates {
let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?;

debug_assert!(ae_nodes_stack.is_empty());
ae_nodes_stack.clear();
ae_nodes_stack.push(predicate.node());

process_join_where_predicate(
&mut ae_nodes_stack,
0,
schema_left.as_ref(),
schema_merged,
ctxt.expr_arena,
&mut ExprOrigin::None,
)?;

ctxt.conversion_optimizer
.push_scratch(predicate.node(), ctxt.expr_arena);

Expand All @@ -464,3 +467,128 @@ fn resolve_join_where(

Ok((last_node, join_node))
}

/// Performs validation and type-coercion on join_where predicates.
///
/// Validates for all comparison / bitwise binary expressions and subexpressions, that:
/// 1. They reference columns from both sides.
/// 2. The dtypes of the LHS and RHS match. For comparison operators only, operands that
///    can be cast to a lossless numeric supertype are also accepted, and the necessary
///    casts are inserted in-place into `expr_arena`.
///
/// # Arguments
/// * `stack` - Work list of expression nodes. Nodes are popped until the stack length is
///   back at `binary_expr_stack_offset`; entries below that offset are left untouched.
/// * `binary_expr_stack_offset` - Stack length at which this invocation stops. The
///   top-level caller passes `0`; recursive calls pass the length before the operands of
///   the current binary expression were pushed.
/// * `schema_left` - Schema used to classify a column as originating from the left side.
/// * `schema_merged` - Schema used for dtype resolution and to classify a column that is
///   not in `schema_left` as originating from the right side.
/// * `expr_arena` - Arena holding the expressions; operand nodes may be replaced by casts.
/// * `column_origins` - Accumulator that is OR-ed with the origin of every column that is
///   visited. It is reset to `ExprOrigin::None` when a comparison / bitwise binary
///   expression is entered.
///
/// # Errors
/// * `ColumnNotFound` if a column is in neither `schema_left` nor `schema_merged`.
/// * `InvalidOperation` if a comparison / bitwise expression only refers to one side.
/// * `SchemaMismatch` if the operand dtypes differ and no lossless upcast is applied.
/// * Any error raised while resolving an operand dtype, with added context.
fn process_join_where_predicate(
    stack: &mut Vec<Node>,
    binary_expr_stack_offset: usize,
    schema_left: &Schema,
    schema_merged: &Schema,
    expr_arena: &mut Arena<AExpr>,
    column_origins: &mut ExprOrigin,
) -> PolarsResult<()> {
    while stack.len() > binary_expr_stack_offset {
        // The loop condition guarantees the stack is non-empty here.
        let ae_node = stack.pop().unwrap();
        // Cloned because the arena is mutated (recursion / cast insertion) while we
        // still need the destructured expression.
        let ae = expr_arena.get(ae_node).clone();

        match ae {
            AExpr::Column(ref name) => {
                // `schema_left` is checked first, so a name present in both schemas is
                // classified as coming from the left side.
                let origin = if schema_left.contains(name) {
                    ExprOrigin::Left
                } else if schema_merged.contains(name) {
                    ExprOrigin::Right
                } else {
                    polars_bail!(ColumnNotFound: "{}", name);
                };

                *column_origins |= origin;
            },
            AExpr::BinaryExpr {
                left: left_node,
                op,
                right: right_node,
            } if op.is_comparison_or_bitwise() => {
                {
                    let new_stack_offset = stack.len();
                    // Pushed in reverse so that the left operand is processed first.
                    stack.extend([right_node, left_node]);

                    // Reset `column_origins` to a `None` state. We will only have 2 possible return states from
                    // this point:
                    // * Ok(()), with column_origins @ ExprOrigin::Both
                    // * Err(_), in which case the value of column_origins doesn't matter.
                    *column_origins = ExprOrigin::None;

                    process_join_where_predicate(
                        stack,
                        new_stack_offset,
                        schema_left,
                        schema_merged,
                        expr_arena,
                        column_origins,
                    )?;

                    if *column_origins != ExprOrigin::Both {
                        polars_bail!(
                            InvalidOperation:
                            "'join_where' predicate only refers to columns from a single table: {}",
                            node_to_expr(ae_node, expr_arena),
                        )
                    }
                }

                // Look the operands up in the arena again (rather than reusing anything
                // read before the recursion) in case they were rewritten. Only a shared
                // borrow is needed for dtype resolution, so the expressions are not cloned.
                let resolve_dtype = |node: Node| -> PolarsResult<DataType> {
                    expr_arena
                        .get(node)
                        .to_dtype(schema_merged, Context::Default, expr_arena)
                        .map_err(|e| {
                            e.context(
                                format!(
                                    "could not resolve dtype of join_where predicate (expr: {})",
                                    node_to_expr(node, expr_arena),
                                )
                                .into(),
                            )
                        })
                };

                let dtype_left = resolve_dtype(left_node)?;
                let dtype_right = resolve_dtype(right_node)?;

                // Lossless upcasting is only applied to comparisons; bitwise operators
                // fall through to the strict dtype equality check below.
                if let Some(dtype) =
                    get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right)
                        .filter(|_| op.is_comparison())
                {
                    // We have unique references to these nodes (they are created by this function),
                    // so we can mutate in-place without causing side effects somewhere else.
                    //
                    // The original operand is copied to a fresh node, and the old node is
                    // replaced by a cast of that copy, so the parent keeps pointing at
                    // `left_node` / `right_node`.
                    let expr = expr_arena.add(expr_arena.get(left_node).clone());
                    expr_arena.replace(
                        left_node,
                        AExpr::Cast {
                            expr,
                            dtype: dtype.clone(),
                            options: CastOptions::Overflowing,
                        },
                    );

                    let expr = expr_arena.add(expr_arena.get(right_node).clone());
                    expr_arena.replace(
                        right_node,
                        AExpr::Cast {
                            expr,
                            dtype,
                            options: CastOptions::Overflowing,
                        },
                    );
                } else {
                    polars_ensure!(
                        dtype_left == dtype_right,
                        SchemaMismatch:
                        "datatypes of join_where comparison don't match - {} on left does not match {} on right \
                        (expr: {})",
                        dtype_left, dtype_right, node_to_expr(ae_node, expr_arena),
                    )
                }
            },
            // Any other expression: push its inputs so they get visited as well.
            ae => ae.inputs_rev(stack),
        }
    }

    Ok(())
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod ir_to_dsl;
feature = "json"
))]
mod scans;
mod stack_opt;
pub(crate) mod stack_opt;

use std::borrow::Cow;
use std::sync::{Arc, Mutex};
Expand Down
14 changes: 7 additions & 7 deletions crates/polars-plan/src/plans/conversion/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ macro_rules! unpack {
fn compares_cat_to_string(type_left: &DataType, type_right: &DataType, op: Operator) -> bool {
#[cfg(feature = "dtype-categorical")]
{
op.is_comparison()
op.is_comparison_or_bitwise()
&& matches_any_order!(
type_left,
type_right,
Expand Down Expand Up @@ -167,40 +167,40 @@ pub(super) fn process_binary(
match (&type_left, &type_right, op) {
#[cfg(not(feature = "dtype-categorical"))]
(DataType::String, dt, op) | (dt, DataType::String, op)
if op.is_comparison() && dt.is_primitive_numeric() =>
if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() =>
{
return Ok(None)
},
#[cfg(feature = "dtype-categorical")]
(String | Unknown(UnknownKind::Str) | Categorical(_, _), dt, op)
| (dt, Unknown(UnknownKind::Str) | String | Categorical(_, _), op)
if op.is_comparison() && dt.is_primitive_numeric() =>
if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() =>
{
return Ok(None)
},
#[cfg(feature = "dtype-categorical")]
(Unknown(UnknownKind::Str) | String | Enum(_, _), dt, op)
| (dt, Unknown(UnknownKind::Str) | String | Enum(_, _), op)
if op.is_comparison() && dt.is_primitive_numeric() =>
if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() =>
{
return Ok(None)
},
#[cfg(feature = "dtype-date")]
(Date, String | Unknown(UnknownKind::Str), op)
| (String | Unknown(UnknownKind::Str), Date, op)
if op.is_comparison() =>
if op.is_comparison_or_bitwise() =>
{
err_date_str_compare()?
},
#[cfg(feature = "dtype-datetime")]
(Datetime(_, _), String | Unknown(UnknownKind::Str), op)
| (String | Unknown(UnknownKind::Str), Datetime(_, _), op)
if op.is_comparison() =>
if op.is_comparison_or_bitwise() =>
{
err_date_str_compare()?
},
#[cfg(feature = "dtype-time")]
(Time | Unknown(UnknownKind::Str), String, op) if op.is_comparison() => {
(Time | Unknown(UnknownKind::Str), String, op) if op.is_comparison_or_bitwise() => {
err_date_str_compare()?
},
// structs can be arbitrarily nested, leave the complexity to the caller for now.
Expand Down
Loading

0 comments on commit 3693a97

Please sign in to comment.