Skip to content

Commit

Permalink
refine implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME committed Feb 27, 2023
1 parent 7986488 commit 4b46b98
Showing 1 changed file with 105 additions and 183 deletions.
288 changes: 105 additions & 183 deletions src/frontend/src/utils/condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::expr::{
try_get_bool_constant, ExprDisplay, ExprImpl, ExprMutator, ExprRewriter, ExprType, ExprVisitor,
FunctionCall, InputRef,
};
use crate::utils::condition::execpetion_cast::{ResultForCmp, ResultForEq};

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Condition {
Expand Down Expand Up @@ -398,16 +399,35 @@ impl Condition {
for expr in group.clone() {
if let Some((input_ref, const_expr)) = expr.as_eq_const() {
assert_eq!(input_ref.index, order_column_ids[i]);
let always_false = Self::analyze_eq_const_expr(
expr,
&mut other_conds,
input_ref,
const_expr,
&mut eq_conds,
)?;
if always_false {
let new_expr = if let Ok(expr) = const_expr
.clone()
.cast_implicit(input_ref.data_type.clone())
{
expr
} else {
match self::execpetion_cast::cast_except_for_eq(
const_expr,
input_ref.data_type,
) {
Ok(ResultForEq::Success(expr)) => expr,
Ok(ResultForEq::NeverEqual) => {
return Ok(false_cond());
}
Err(_) => {
other_conds.push(expr);
continue;
}
}
};

let Some(new_cond) = new_expr.eval_row_const()? else {
// column = NULL, PK column never be NULL
return Ok(false_cond());
};
if Self::mutual_exclusive_with_eq_conds(&new_cond, &eq_conds) {
return Ok(false_cond());
}
eq_conds = vec![Some(new_cond)];
} else if let Some(input_ref) = expr.as_is_null() {
assert_eq!(input_ref.index, order_column_ids[i]);
if !eq_conds.is_empty() && eq_conds.into_iter().all(|l| l.is_some()) {
Expand Down Expand Up @@ -446,17 +466,54 @@ impl Condition {
eq_conds = scalars.into_iter().sorted().collect();
} else if let Some((input_ref, op, const_expr)) = expr.as_comparison_const() {
assert_eq!(input_ref.index, order_column_ids[i]);
let always_false = Self::analyze_cmp_const_expr(
expr,
&mut other_conds,
input_ref,
op,
const_expr,
&mut lb,
&mut ub,
)?;
if always_false {
let new_expr = if let Ok(expr) = const_expr
.clone()
.cast_implicit(input_ref.data_type.clone())
{
expr
} else {
match self::execpetion_cast::cast_except_for_cmp(
const_expr,
input_ref.data_type,
op,
) {
Ok(ResultForCmp::Success(expr)) => expr,
Ok(ResultForCmp::OutUpperBound) => {
if op == Type::GreaterThan || op == Type::GreaterThanOrEqual {
return Ok(false_cond());
}
continue;
}
Ok(ResultForCmp::OutLowerBound) => {
if op == Type::LessThan || op == Type::LessThanOrEqual {
return Ok(false_cond());
}
continue;
}
Err(_) => {
other_conds.push(expr);
continue;
}
}
};
let Some(value) = new_expr.eval_row_const()? else {
// column compare with NULL, PK column never be NULL
return Ok(false_cond());
};
match op {
ExprType::LessThan => {
ub.push((Bound::Excluded(value), expr));
}
ExprType::LessThanOrEqual => {
ub.push((Bound::Included(value), expr));
}
ExprType::GreaterThan => {
lb.push((Bound::Excluded(value), expr));
}
ExprType::GreaterThanOrEqual => {
lb.push((Bound::Included(value), expr));
}
_ => unreachable!(),
}
} else {
other_conds.push(expr);
Expand Down Expand Up @@ -533,46 +590,9 @@ impl Condition {
))
}

/// Analyze the expr like 'column1 = const'
///
/// return:
/// true indicate that this eq expr always be false
fn analyze_eq_const_expr(
ori_expr: ExprImpl,
other_conds: &mut Vec<ExprImpl>,
input_ref: InputRef,
const_expr: ExprImpl,
eq_conds: &mut Vec<Option<ScalarImpl>>,
) -> Result<bool> {
let expr = match const_expr
.clone()
.cast_implicit(input_ref.data_type.clone())
{
Ok(expr) => expr,
Err(_) => match self::mismatch::cast_mismatch_eq(const_expr, input_ref.data_type) {
Ok(Some(expr)) => expr,
Ok(None) => return Ok(true),
Err(_) => {
other_conds.push(ori_expr);
return Ok(false);
}
},
};

let Some(new_cond) = expr.eval_row_const()? else {
// column = NULL, PK column never be NULL
return Ok(true);
};
if Self::mutual_exclusive_with_eq_conds(&new_cond, eq_conds) {
return Ok(true);
}
*eq_conds = vec![Some(new_cond)];
Ok(false)
}

fn mutual_exclusive_with_eq_conds(
new_conds: &ScalarImpl,
eq_conds: &Vec<Option<ScalarImpl>>,
eq_conds: &[Option<ScalarImpl>],
) -> bool {
return !eq_conds.is_empty()
&& eq_conds.iter().all(|l| {
Expand All @@ -584,64 +604,6 @@ impl Condition {
});
}

/// Analyze the expr like 'column1 {<,>,>=,<=} const'
///
/// return:
/// true indicate that this expr always be false
fn analyze_cmp_const_expr(
ori_expr: ExprImpl,
other_conds: &mut Vec<ExprImpl>,
input_ref: InputRef,
op: Type,
const_expr: ExprImpl,
lb: &mut Vec<(Bound<ScalarImpl>, ExprImpl)>,
ub: &mut Vec<(Bound<ScalarImpl>, ExprImpl)>,
) -> Result<bool> {
match const_expr
.clone()
.cast_implicit(input_ref.data_type.clone())
{
Ok(expr) => {
let Some(value) = expr.eval_row_const()? else {
// column compare with NULL
return Ok(true);
};
match op {
ExprType::LessThan => {
ub.push((Bound::Excluded(value), ori_expr));
}
ExprType::LessThanOrEqual => {
ub.push((Bound::Included(value), ori_expr));
}
ExprType::GreaterThan => {
lb.push((Bound::Excluded(value), ori_expr));
}
ExprType::GreaterThanOrEqual => {
lb.push((Bound::Included(value), ori_expr));
}
_ => unreachable!(),
}
Ok(false)
}
Err(_) => {
match self::mismatch::analyze_cmp_mismatch_expr(
ori_expr.clone(),
input_ref.data_type,
op,
const_expr,
lb,
ub,
) {
Ok(res) => Ok(res),
Err(_) => {
other_conds.push(ori_expr);
Ok(false)
}
}
}
}
}

/// Split the condition expressions into `N` groups.
/// An expression `expr` is in the `i`-th group if `f(expr)==i`.
///
Expand Down Expand Up @@ -780,10 +742,8 @@ impl fmt::Debug for ConditionDisplay<'_> {
}
}

mod mismatch {
use std::ops::Bound;

use risingwave_common::types::{DataType, ScalarImpl};
mod execpetion_cast {
use risingwave_common::types::DataType;
use risingwave_pb::expr::expr_node::Type;

use crate::expr::{Expr, ExprImpl};
Expand All @@ -794,87 +754,49 @@ mod mismatch {
InRange(ExprImpl),
}

/// return None indicates that the expression is out of range
pub fn cast_mismatch_eq(
const_expr: ExprImpl,
target: DataType,
) -> Result<Option<ExprImpl>, ()> {
pub enum ResultForEq {
Success(ExprImpl),
NeverEqual,
}

pub enum ResultForCmp {
Success(ExprImpl),
OutUpperBound,
OutLowerBound,
}

pub fn cast_except_for_eq(const_expr: ExprImpl, target: DataType) -> Result<ResultForEq, ()> {
match (const_expr.return_type(), &target) {
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int32, DataType::Int16) => match shrink_int(const_expr, target)? {
ShrinkResult::OutUpperBound | ShrinkResult::OutLowerBound => Ok(None),
ShrinkResult::InRange(expr) => Ok(Some(expr)),
| (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
ShrinkResult::InRange(expr) => Ok(ResultForEq::Success(expr)),
ShrinkResult::OutUpperBound | ShrinkResult::OutLowerBound => {
Ok(ResultForEq::NeverEqual)
}
},
_ => Err(()),
}
}

/// return true indicates that the expression is always false.
pub fn analyze_cmp_mismatch_expr(
ori_expr: ExprImpl,
target: DataType,
op: Type,
pub fn cast_except_for_cmp(
const_expr: ExprImpl,
lb: &mut Vec<(Bound<ScalarImpl>, ExprImpl)>,
ub: &mut Vec<(Bound<ScalarImpl>, ExprImpl)>,
) -> Result<bool, ()> {
target: DataType,
_op: Type,
) -> Result<ResultForCmp, ()> {
match (const_expr.return_type(), &target) {
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int32, DataType::Int16) => {
analyze_cmp_mismatch_int(ori_expr, target, op, const_expr, lb, ub)
}
_ => Err(()),
}
}

/// return true indicates that the expression is always false.
fn analyze_cmp_mismatch_int(
ori_expr: ExprImpl,
target: DataType,
op: Type,
const_expr: ExprImpl,
lb: &mut Vec<(Bound<ScalarImpl>, ExprImpl)>,
ub: &mut Vec<(Bound<ScalarImpl>, ExprImpl)>,
) -> Result<bool, ()> {
match shrink_int(const_expr, target)? {
ShrinkResult::OutUpperBound => match op {
Type::LessThan | Type::LessThanOrEqual => Ok(false),
Type::GreaterThanOrEqual | Type::GreaterThan => Ok(true),
_ => unreachable!(),
| (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
ShrinkResult::InRange(expr) => Ok(ResultForCmp::Success(expr)),
ShrinkResult::OutUpperBound => Ok(ResultForCmp::OutUpperBound),
ShrinkResult::OutLowerBound => Ok(ResultForCmp::OutLowerBound),
},
ShrinkResult::OutLowerBound => match op {
Type::LessThan | Type::LessThanOrEqual => Ok(true),
Type::GreaterThanOrEqual | Type::GreaterThan => Ok(false),
_ => unreachable!(),
},
ShrinkResult::InRange(expr) => {
let Some(value) = expr.eval_row_const().map_err(|_|())? else {
// column compare with NULL
return Ok(true);
};
match op {
Type::LessThan => {
ub.push((Bound::Excluded(value), ori_expr));
}
Type::LessThanOrEqual => {
ub.push((Bound::Included(value), ori_expr));
}
Type::GreaterThan => {
lb.push((Bound::Excluded(value), ori_expr));
}
Type::GreaterThanOrEqual => {
lb.push((Bound::Included(value), ori_expr));
}
_ => unreachable!(),
}
Ok(false)
}
_ => Err(()),
}
}

fn shrink_int(const_expr: ExprImpl, target: DataType) -> Result<ShrinkResult, ()> {
fn shrink_integral(const_expr: ExprImpl, target: DataType) -> Result<ShrinkResult, ()> {
let (upper_bound, lowwer_bound) = match (const_expr.return_type(), &target) {
(DataType::Int64, DataType::Int32) => (i32::MAX as i64, i32::MIN as i64),
(DataType::Int64, DataType::Int16) | (DataType::Int32, DataType::Int16) => {
Expand Down

0 comments on commit 4b46b98

Please sign in to comment.