Skip to content

Commit

Permalink
[opt](Nereids) polish aggregate function signature matching (apache#3…
Browse files Browse the repository at this point in the history
…9352)

pick from master apache#39352

use double to match string
- stddev
- stddev_samp

use largeint to match string
- group_bit_and
- group_bit_or
- group_git_xor

use double to match decimalv3
- topn_weighted

optimize error message
- multi_distinct_sum
  • Loading branch information
morrySnow committed Aug 16, 2024
1 parent f0da2ff commit 2c5ee78
Show file tree
Hide file tree
Showing 14 changed files with 121 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ public class AvgWeighted extends AggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@
public class BitmapAgg extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BitmapType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(BigIntType.INSTANCE)
);
FunctionSignature.ret(BitmapType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(TinyIntType.INSTANCE)
);

public BitmapAgg(Expression arg0) {
super("bitmap_agg", arg0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ public CollectList(boolean distinct, Expression arg) {
super("collect_list", distinct, arg);
}

@Override
public FunctionSignature computeSignature(FunctionSignature signature) {
signature = signature.withReturnType(ArrayType.of(getArgumentType(0)));
return super.computeSignature(signature);
}

/**
* withDistinctAndChildren.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ public class GroupBitAnd extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE)
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ public class GroupBitOr extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE)
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ public class GroupBitXor extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE)
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,16 @@
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/** MultiDistinctSum */
public class MultiDistinctSum extends NullableAggregateFunction implements UnaryExpression,
ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(DoubleType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(LargeIntType.INSTANCE)
);

public MultiDistinctSum(Expression arg0) {
super("multi_distinct_sum", true, false, arg0);
}
Expand All @@ -57,8 +48,10 @@ public MultiDistinctSum(boolean distinct, boolean alwaysNullable, Expression arg

@Override
public void checkLegalityBeforeTypeCoercion() {
if (child().getDataType().isDateLikeType()) {
throw new AnalysisException("Sum in multi distinct functions do not support Date/Datetime type");
DataType argType = child().getDataType();
if ((!argType.isNumericType() && !argType.isBooleanType() && !argType.isNullType())
|| argType.isOnlyMetricType()) {
throw new AnalysisException("sum requires a numeric or boolean parameter: " + this.toSql());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ public class Stddev extends NullableAggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ public class StddevSamp extends AggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,25 @@ public class TopNWeighted extends AggregateFunction
implements ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(BooleanType.INSTANCE))
.args(BooleanType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(TinyIntType.INSTANCE))
.args(TinyIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(SmallIntType.INSTANCE))
.args(SmallIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(IntegerType.INSTANCE))
.args(IntegerType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE))
.args(BigIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE))
.args(LargeIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
.args(FloatType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
// three arguments
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(DoubleType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DecimalV2Type.CATALOG_DEFAULT))
.args(DecimalV2Type.CATALOG_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE))
.args(LargeIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE))
.args(BigIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(IntegerType.INSTANCE))
.args(IntegerType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(SmallIntType.INSTANCE))
.args(SmallIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(TinyIntType.INSTANCE))
.args(TinyIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BooleanType.INSTANCE))
.args(BooleanType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
.args(FloatType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateType.INSTANCE))
.args(DateType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateTimeType.INSTANCE))
Expand All @@ -78,31 +79,35 @@ public class TopNWeighted extends AggregateFunction
.args(DateV2Type.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateTimeV2Type.SYSTEM_DEFAULT))
.args(DateTimeV2Type.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(CharType.SYSTEM_DEFAULT))
.args(CharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(StringType.INSTANCE))
.args(StringType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BooleanType.INSTANCE))
.args(BooleanType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(TinyIntType.INSTANCE))
.args(TinyIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(SmallIntType.INSTANCE))
.args(SmallIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(IntegerType.INSTANCE))
.args(IntegerType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE))
.args(BigIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE))
.args(LargeIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
.args(FloatType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(VarcharType.SYSTEM_DEFAULT))
.args(VarcharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(CharType.SYSTEM_DEFAULT))
.args(CharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),

// four arguments
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(DoubleType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(DecimalV2Type.CATALOG_DEFAULT,
BigIntType.INSTANCE,
IntegerType.INSTANCE,
IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE))
.args(LargeIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE))
.args(BigIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(IntegerType.INSTANCE))
.args(IntegerType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(SmallIntType.INSTANCE))
.args(SmallIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(TinyIntType.INSTANCE))
.args(TinyIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BooleanType.INSTANCE))
.args(BooleanType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
.args(FloatType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateType.INSTANCE))
.args(DateType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateTimeType.INSTANCE))
Expand All @@ -114,10 +119,12 @@ public class TopNWeighted extends AggregateFunction
BigIntType.INSTANCE,
IntegerType.INSTANCE,
IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(CharType.SYSTEM_DEFAULT))
.args(CharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(StringType.INSTANCE))
.args(StringType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE)
.args(StringType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(VarcharType.SYSTEM_DEFAULT))
.args(VarcharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(CharType.SYSTEM_DEFAULT))
.args(CharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ public class Variance extends NullableAggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ public class VarianceSamp extends AggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
Expand Down
Loading

0 comments on commit 2c5ee78

Please sign in to comment.