Skip to content

Commit

Permalink
simplify comparison from cast smaller int type to bigger int type
Browse files Browse the repository at this point in the history
  • Loading branch information
yujun777 committed Feb 24, 2025
1 parent 3f85bc4 commit 40e9006
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,16 @@
import org.apache.doris.nereids.trees.expressions.literal.NumericLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
Expand Down Expand Up @@ -103,7 +107,9 @@ public static Expression simplify(ComparisonPredicate cp) {
Expression result;

// process type coercion
if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) {
if (left.getDataType().isIntegerLikeType() && right.getDataType().isIntegerLikeType()) {
result = processIntegerLikeTypeCoercion(cp, left, right);
} else if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) {
result = processFloatLikeTypeCoercion(cp, left, right);
} else if (left.getDataType() instanceof DecimalV3Type && right.getDataType() instanceof DecimalV3Type) {
result = processDecimalV3TypeCoercion(cp, left, right);
Expand All @@ -121,6 +127,56 @@ public static Expression simplify(ComparisonPredicate cp) {
return result;
}

private static Expression processIntegerLikeTypeCoercion(ComparisonPredicate cp,
Expression left, Expression right) {
// Suppose a is integer type, for expression `a > 488120 + 10000`,
// since right type is big int (int plus int type is big int),
// then will have cast(a as bigint) > cast(488120 + 10000).
// After fold constant, will have cast(a as bigint) > big int(498120),
// since 498120 can represent as an int type, will rewrite as a > int(498120).
if (left instanceof Cast && left.getDataType().isIntegerLikeType()
&& ((Cast) left).child().getDataType().isIntegerLikeType()
&& right instanceof IntegerLikeLiteral) {
DataType castDataType = left.getDataType();
DataType childDataType = ((Cast) left).child().getDataType();
boolean castDataTypeWider = false;
for (DataType type : TypeCoercionUtils.NUMERIC_PRECEDENCE) {
if (type.equals(childDataType)) {
break;
}
if (type.equals(castDataType)) {
castDataTypeWider = true;
break;
}
}
if (castDataTypeWider) {
Optional<Pair<BigDecimal, BigDecimal>> minMaxOpt =
TypeCoercionUtils.getDataTypeMinMaxValue(childDataType);
if (minMaxOpt.isPresent()) {
BigDecimal childTypeMinValue = minMaxOpt.get().first;
BigDecimal childTypeMaxValue = minMaxOpt.get().second;
BigDecimal rightValue = ((IntegerLikeLiteral) right).getBigDecimalValue();
if (rightValue.compareTo(childTypeMinValue) >= 0 && rightValue.compareTo(childTypeMaxValue) <= 0) {
Expression newRight = null;
if (childDataType.equals(BigIntType.INSTANCE)) {
newRight = new BigIntLiteral(rightValue.longValue());
} else if (childDataType.equals(IntegerType.INSTANCE)) {
newRight = new IntegerLiteral(rightValue.intValue());
} else if (childDataType.equals(SmallIntType.INSTANCE)) {
newRight = new SmallIntLiteral(rightValue.shortValue());
} else if (childDataType.equals(TinyIntType.INSTANCE)) {
newRight = new TinyIntLiteral(rightValue.byteValue());
}
if (newRight != null) {
return cp.withChildren(((Cast) left).child(), newRight);
}
}
}
}
}
return cp;
}

private static Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) {
if (left instanceof Cast && right instanceof DateLiteral
&& ((Cast) left).getDataType().equals(right.getDataType())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public DecimalLiteral(DecimalV2Type dataType, BigDecimal value) {
}

@Override
protected BigDecimal getBigDecimalValue() {
public BigDecimal getBigDecimalValue() {
return value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public double getDouble() {
}

@Override
protected BigDecimal getBigDecimalValue() {
public BigDecimal getBigDecimalValue() {
return value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public DoubleLiteral(double value) {
}

@Override
protected BigDecimal getBigDecimalValue() {
public BigDecimal getBigDecimalValue() {
return new BigDecimal(String.valueOf(value));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public FloatLiteral(float value) {
}

@Override
protected BigDecimal getBigDecimalValue() {
public BigDecimal getBigDecimalValue() {
return new BigDecimal(String.valueOf(value));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public long getLongValue() {
}

@Override
protected BigDecimal getBigDecimalValue() {
public BigDecimal getBigDecimalValue() {
return new BigDecimal(getLongValue());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public double getDouble() {
}

@Override
protected BigDecimal getBigDecimalValue() {
public BigDecimal getBigDecimalValue() {
return new BigDecimal(value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ public int compareTo(ComparableLiteral other) {
+ this + " (" + dataType + ") vs " + other + " (" + ((Literal) other).dataType + ")");
}

protected abstract BigDecimal getBigDecimalValue();
public abstract BigDecimal getBigDecimalValue();
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,27 @@ void testSimplifyComparisonPredicateRule() {

}

@Test
void testIntCompIntLiteral() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
SimplifyCastRule.INSTANCE,
SimplifyComparisonPredicate.INSTANCE
)
));

Expression intSlot = new SlotReference("a", IntegerType.INSTANCE);
Expression smallIntSlot = new SlotReference("a", SmallIntType.INSTANCE);
Expression tinyIntSlot = new SlotReference("a", TinyIntType.INSTANCE);

assertRewrite(new LessThan(new Cast(intSlot, BigIntType.INSTANCE), new BigIntLiteral(10L)),
new LessThan(intSlot, new IntegerLiteral(10)));
assertRewrite(new LessThan(new Cast(smallIntSlot, BigIntType.INSTANCE), new BigIntLiteral(10L)),
new LessThan(smallIntSlot, new SmallIntLiteral((short) 10)));
assertRewrite(new LessThan(new Cast(tinyIntSlot, BigIntType.INSTANCE), new BigIntLiteral(10L)),
new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 10)));
}

@Test
void testDateTimeV2CmpDateTimeV2() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
Expand Down Expand Up @@ -861,8 +882,8 @@ void testTypeRangeLimit() {
Pair.of(new DoubleLiteral(-128.1), new DecimalV3Literal(new BigDecimal("-128.1")))),
ImmutableList.of(
Pair.of(new TinyIntLiteral((byte) -128), null),
Pair.of(new SmallIntLiteral((short) -128), null),
Pair.of(new IntegerLiteral(-128), null),
Pair.of(new SmallIntLiteral((short) -128), new TinyIntLiteral((byte) -128)),
Pair.of(new IntegerLiteral(-128), new TinyIntLiteral((byte) -128)),
Pair.of(new DecimalV3Literal(new BigDecimal("-128")), new TinyIntLiteral((byte) -128)),
Pair.of(new DoubleLiteral(-128.0), new TinyIntLiteral((byte) -128))),
ImmutableList.of(
Expand Down

0 comments on commit 40e9006

Please sign in to comment.