Skip to content

Commit

Permalink
expression: introduce EvalContext to evaluate expression (pingcap#49416)
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Dec 13, 2023
1 parent 3fb6b98 commit e14f66f
Show file tree
Hide file tree
Showing 78 changed files with 1,556 additions and 1,544 deletions.
1 change: 1 addition & 0 deletions pkg/expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ go_library(
"constant.go",
"constant_fold.go",
"constant_propagation.go",
"context.go",
"distsql_builtin.go",
"errors.go",
"evaluator.go",
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/aggregation/agg_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (desc *baseFuncDesc) GetTiPBExpr(tryWindowDesc bool) (tp tipb.ExprType) {
}

// AggFuncToPBExpr converts aggregate function to pb.
func AggFuncToPBExpr(sctx sessionctx.Context, client kv.Client, aggFunc *AggFuncDesc, storeType kv.StoreType) (*tipb.Expr, error) {
func AggFuncToPBExpr(sctx expression.EvalContext, client kv.Client, aggFunc *AggFuncDesc, storeType kv.StoreType) (*tipb.Expr, error) {
pc := expression.NewPBConverter(client, sctx)
tp := aggFunc.GetTiPBExpr(false)
if !client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) {
Expand Down
10 changes: 5 additions & 5 deletions pkg/expression/aggregation/aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ type Aggregation interface {
GetResult(evalCtx *AggEvaluateContext) types.Datum

// CreateContext creates a new AggEvaluateContext for the aggregation function.
CreateContext(ctx sessionctx.Context) *AggEvaluateContext
CreateContext(ctx expression.EvalContext) *AggEvaluateContext

// ResetContext resets the content of the evaluate context.
ResetContext(ctx sessionctx.Context, evalCtx *AggEvaluateContext)
ResetContext(ctx expression.EvalContext, evalCtx *AggEvaluateContext)
}

// NewDistAggFunc creates new Aggregate function for mock tikv.
Expand Down Expand Up @@ -87,7 +87,7 @@ func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, ctx sessionctx

// AggEvaluateContext is used to store intermediate result when calculating aggregate functions.
type AggEvaluateContext struct {
Ctx sessionctx.Context
Ctx expression.EvalContext
DistinctChecker *distinctChecker
Count int64
Value types.Datum
Expand Down Expand Up @@ -127,15 +127,15 @@ func newAggFunc(funcName string, args []expression.Expression, hasDistinct bool)
}

// CreateContext implements Aggregation interface.
func (af *aggFunction) CreateContext(ctx sessionctx.Context) *AggEvaluateContext {
func (af *aggFunction) CreateContext(ctx expression.EvalContext) *AggEvaluateContext {
evalCtx := &AggEvaluateContext{Ctx: ctx}
if af.HasDistinct {
evalCtx.DistinctChecker = createDistinctChecker(ctx.GetSessionVars().StmtCtx)
}
return evalCtx
}

func (af *aggFunction) ResetContext(ctx sessionctx.Context, evalCtx *AggEvaluateContext) {
func (af *aggFunction) ResetContext(ctx expression.EvalContext, evalCtx *AggEvaluateContext) {
if af.HasDistinct {
evalCtx.DistinctChecker = createDistinctChecker(ctx.GetSessionVars().StmtCtx)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/aggregation/avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
package aggregation

import (
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -49,7 +49,7 @@ func (af *avgFunction) updateAvg(ctx types.Context, evalCtx *AggEvaluateContext,
return nil
}

func (af *avgFunction) ResetContext(ctx sessionctx.Context, evalCtx *AggEvaluateContext) {
func (af *avgFunction) ResetContext(ctx expression.EvalContext, evalCtx *AggEvaluateContext) {
if af.HasDistinct {
evalCtx.DistinctChecker = createDistinctChecker(ctx.GetSessionVars().StmtCtx)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/expression/aggregation/bit_and.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package aggregation
import (
"math"

"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand All @@ -27,13 +27,13 @@ type bitAndFunction struct {
aggFunction
}

func (bf *bitAndFunction) CreateContext(ctx sessionctx.Context) *AggEvaluateContext {
func (bf *bitAndFunction) CreateContext(ctx expression.EvalContext) *AggEvaluateContext {
evalCtx := bf.aggFunction.CreateContext(ctx)
evalCtx.Value.SetUint64(math.MaxUint64)
return evalCtx
}

func (*bitAndFunction) ResetContext(ctx sessionctx.Context, evalCtx *AggEvaluateContext) {
func (*bitAndFunction) ResetContext(ctx expression.EvalContext, evalCtx *AggEvaluateContext) {
evalCtx.Ctx = ctx
evalCtx.Value.SetUint64(math.MaxUint64)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/expression/aggregation/bit_or.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
package aggregation

import (
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand All @@ -25,13 +25,13 @@ type bitOrFunction struct {
aggFunction
}

func (bf *bitOrFunction) CreateContext(ctx sessionctx.Context) *AggEvaluateContext {
func (bf *bitOrFunction) CreateContext(ctx expression.EvalContext) *AggEvaluateContext {
evalCtx := bf.aggFunction.CreateContext(ctx)
evalCtx.Value.SetUint64(0)
return evalCtx
}

func (*bitOrFunction) ResetContext(ctx sessionctx.Context, evalCtx *AggEvaluateContext) {
func (*bitOrFunction) ResetContext(ctx expression.EvalContext, evalCtx *AggEvaluateContext) {
evalCtx.Ctx = ctx
evalCtx.Value.SetUint64(0)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/expression/aggregation/bit_xor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
package aggregation

import (
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand All @@ -25,13 +25,13 @@ type bitXorFunction struct {
aggFunction
}

func (bf *bitXorFunction) CreateContext(ctx sessionctx.Context) *AggEvaluateContext {
func (bf *bitXorFunction) CreateContext(ctx expression.EvalContext) *AggEvaluateContext {
evalCtx := bf.aggFunction.CreateContext(ctx)
evalCtx.Value.SetUint64(0)
return evalCtx
}

func (*bitXorFunction) ResetContext(ctx sessionctx.Context, evalCtx *AggEvaluateContext) {
func (*bitXorFunction) ResetContext(ctx expression.EvalContext, evalCtx *AggEvaluateContext) {
evalCtx.Ctx = ctx
evalCtx.Value.SetUint64(0)
}
Expand Down
5 changes: 2 additions & 3 deletions pkg/expression/aggregation/concat.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand All @@ -46,7 +45,7 @@ func (*concatFunction) writeValue(evalCtx *AggEvaluateContext, val types.Datum)
}
}

func (cf *concatFunction) initSeparator(ctx sessionctx.Context, row chunk.Row) error {
func (cf *concatFunction) initSeparator(ctx expression.EvalContext, row chunk.Row) error {
sepArg := cf.Args[len(cf.Args)-1]
sepDatum, err := sepArg.Eval(ctx, row)
if err != nil {
Expand Down Expand Up @@ -122,7 +121,7 @@ func (cf *concatFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Datum)
return d
}

func (cf *concatFunction) ResetContext(ctx sessionctx.Context, evalCtx *AggEvaluateContext) {
func (cf *concatFunction) ResetContext(ctx expression.EvalContext, evalCtx *AggEvaluateContext) {
if cf.HasDistinct {
evalCtx.DistinctChecker = createDistinctChecker(ctx.GetSessionVars().StmtCtx)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/aggregation/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
package aggregation

import (
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -61,7 +61,7 @@ func (cf *countFunction) Update(evalCtx *AggEvaluateContext, _ *stmtctx.Statemen
return nil
}

func (cf *countFunction) ResetContext(ctx sessionctx.Context, evalCtx *AggEvaluateContext) {
func (cf *countFunction) ResetContext(ctx expression.EvalContext, evalCtx *AggEvaluateContext) {
if cf.HasDistinct {
evalCtx.DistinctChecker = createDistinctChecker(ctx.GetSessionVars().StmtCtx)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/aggregation/first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package aggregation

import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -48,7 +48,7 @@ func (*firstRowFunction) GetResult(evalCtx *AggEvaluateContext) types.Datum {
return evalCtx.Value
}

func (*firstRowFunction) ResetContext(ctx sessionctx.Context, evalCtx *AggEvaluateContext) {
func (*firstRowFunction) ResetContext(ctx expression.EvalContext, evalCtx *AggEvaluateContext) {
evalCtx.Ctx = ctx
evalCtx.GotFirstRow = false
}
Expand Down
60 changes: 30 additions & 30 deletions pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,59 +307,59 @@ func (b *baseBuiltinFunc) getArgs() []Expression {
return b.args
}

func (*baseBuiltinFunc) vecEvalInt(sessionctx.Context, *chunk.Chunk, *chunk.Column) error {
func (*baseBuiltinFunc) vecEvalInt(EvalContext, *chunk.Chunk, *chunk.Column) error {
return errors.Errorf("baseBuiltinFunc.vecEvalInt() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) vecEvalReal(sessionctx.Context, *chunk.Chunk, *chunk.Column) error {
func (*baseBuiltinFunc) vecEvalReal(EvalContext, *chunk.Chunk, *chunk.Column) error {
return errors.Errorf("baseBuiltinFunc.vecEvalReal() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) vecEvalString(sessionctx.Context, *chunk.Chunk, *chunk.Column) error {
func (*baseBuiltinFunc) vecEvalString(EvalContext, *chunk.Chunk, *chunk.Column) error {
return errors.Errorf("baseBuiltinFunc.vecEvalString() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) vecEvalDecimal(sessionctx.Context, *chunk.Chunk, *chunk.Column) error {
func (*baseBuiltinFunc) vecEvalDecimal(EvalContext, *chunk.Chunk, *chunk.Column) error {
return errors.Errorf("baseBuiltinFunc.vecEvalDecimal() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) vecEvalTime(sessionctx.Context, *chunk.Chunk, *chunk.Column) error {
func (*baseBuiltinFunc) vecEvalTime(EvalContext, *chunk.Chunk, *chunk.Column) error {
return errors.Errorf("baseBuiltinFunc.vecEvalTime() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) vecEvalDuration(sessionctx.Context, *chunk.Chunk, *chunk.Column) error {
func (*baseBuiltinFunc) vecEvalDuration(EvalContext, *chunk.Chunk, *chunk.Column) error {
return errors.Errorf("baseBuiltinFunc.vecEvalDuration() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) vecEvalJSON(sessionctx.Context, *chunk.Chunk, *chunk.Column) error {
func (*baseBuiltinFunc) vecEvalJSON(EvalContext, *chunk.Chunk, *chunk.Column) error {
return errors.Errorf("baseBuiltinFunc.vecEvalJSON() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) evalInt(sessionctx.Context, chunk.Row) (int64, bool, error) {
func (*baseBuiltinFunc) evalInt(EvalContext, chunk.Row) (int64, bool, error) {
return 0, false, errors.Errorf("baseBuiltinFunc.evalInt() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) evalReal(sessionctx.Context, chunk.Row) (float64, bool, error) {
func (*baseBuiltinFunc) evalReal(EvalContext, chunk.Row) (float64, bool, error) {
return 0, false, errors.Errorf("baseBuiltinFunc.evalReal() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) evalString(sessionctx.Context, chunk.Row) (string, bool, error) {
func (*baseBuiltinFunc) evalString(EvalContext, chunk.Row) (string, bool, error) {
return "", false, errors.Errorf("baseBuiltinFunc.evalString() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) evalDecimal(sessionctx.Context, chunk.Row) (*types.MyDecimal, bool, error) {
func (*baseBuiltinFunc) evalDecimal(EvalContext, chunk.Row) (*types.MyDecimal, bool, error) {
return nil, false, errors.Errorf("baseBuiltinFunc.evalDecimal() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) evalTime(sessionctx.Context, chunk.Row) (types.Time, bool, error) {
func (*baseBuiltinFunc) evalTime(EvalContext, chunk.Row) (types.Time, bool, error) {
return types.ZeroTime, false, errors.Errorf("baseBuiltinFunc.evalTime() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) evalDuration(sessionctx.Context, chunk.Row) (types.Duration, bool, error) {
func (*baseBuiltinFunc) evalDuration(EvalContext, chunk.Row) (types.Duration, bool, error) {
return types.Duration{}, false, errors.Errorf("baseBuiltinFunc.evalDuration() should never be called, please contact the TiDB team for help")
}

func (*baseBuiltinFunc) evalJSON(sessionctx.Context, chunk.Row) (types.BinaryJSON, bool, error) {
func (*baseBuiltinFunc) evalJSON(EvalContext, chunk.Row) (types.BinaryJSON, bool, error) {
return types.BinaryJSON{}, false, errors.Errorf("baseBuiltinFunc.evalJSON() should never be called, please contact the TiDB team for help")
}

Expand Down Expand Up @@ -417,7 +417,7 @@ func (b *baseBuiltinFunc) getRetTp() *types.FieldType {
return b.tp
}

func (b *baseBuiltinFunc) equal(ctx sessionctx.Context, fun builtinFunc) bool {
func (b *baseBuiltinFunc) equal(ctx EvalContext, fun builtinFunc) bool {
funArgs := fun.getArgs()
if len(funArgs) != len(b.args) {
return false
Expand Down Expand Up @@ -484,25 +484,25 @@ type vecBuiltinFunc interface {
isChildrenVectorized() bool

// vecEvalInt evaluates this builtin function in a vectorized manner.
vecEvalInt(ctx sessionctx.Context, input *chunk.Chunk, result *chunk.Column) error
vecEvalInt(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error

// vecEvalReal evaluates this builtin function in a vectorized manner.
vecEvalReal(ctx sessionctx.Context, input *chunk.Chunk, result *chunk.Column) error
vecEvalReal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error

// vecEvalString evaluates this builtin function in a vectorized manner.
vecEvalString(ctx sessionctx.Context, input *chunk.Chunk, result *chunk.Column) error
vecEvalString(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error

// vecEvalDecimal evaluates this builtin function in a vectorized manner.
vecEvalDecimal(ctx sessionctx.Context, input *chunk.Chunk, result *chunk.Column) error
vecEvalDecimal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error

// vecEvalTime evaluates this builtin function in a vectorized manner.
vecEvalTime(ctx sessionctx.Context, input *chunk.Chunk, result *chunk.Column) error
vecEvalTime(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error

// vecEvalDuration evaluates this builtin function in a vectorized manner.
vecEvalDuration(ctx sessionctx.Context, input *chunk.Chunk, result *chunk.Column) error
vecEvalDuration(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error

// vecEvalJSON evaluates this builtin function in a vectorized manner.
vecEvalJSON(ctx sessionctx.Context, input *chunk.Chunk, result *chunk.Column) error
vecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error
}

// reverseBuiltinFunc evaluates the exactly one column value in the function when given a result for expression.
Expand All @@ -523,23 +523,23 @@ type builtinFunc interface {
reverseBuiltinFunc

// evalInt evaluates int result of builtinFunc by given row.
evalInt(ctx sessionctx.Context, row chunk.Row) (val int64, isNull bool, err error)
evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error)
// evalReal evaluates real representation of builtinFunc by given row.
evalReal(ctx sessionctx.Context, row chunk.Row) (val float64, isNull bool, err error)
evalReal(ctx EvalContext, row chunk.Row) (val float64, isNull bool, err error)
// evalString evaluates string representation of builtinFunc by given row.
evalString(ctx sessionctx.Context, row chunk.Row) (val string, isNull bool, err error)
evalString(ctx EvalContext, row chunk.Row) (val string, isNull bool, err error)
// evalDecimal evaluates decimal representation of builtinFunc by given row.
evalDecimal(ctx sessionctx.Context, row chunk.Row) (val *types.MyDecimal, isNull bool, err error)
evalDecimal(ctx EvalContext, row chunk.Row) (val *types.MyDecimal, isNull bool, err error)
// evalTime evaluates DATE/DATETIME/TIMESTAMP representation of builtinFunc by given row.
evalTime(ctx sessionctx.Context, row chunk.Row) (val types.Time, isNull bool, err error)
evalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error)
// evalDuration evaluates duration representation of builtinFunc by given row.
evalDuration(ctx sessionctx.Context, row chunk.Row) (val types.Duration, isNull bool, err error)
evalDuration(ctx EvalContext, row chunk.Row) (val types.Duration, isNull bool, err error)
// evalJSON evaluates JSON representation of builtinFunc by given row.
evalJSON(ctx sessionctx.Context, row chunk.Row) (val types.BinaryJSON, isNull bool, err error)
evalJSON(ctx EvalContext, row chunk.Row) (val types.BinaryJSON, isNull bool, err error)
// getArgs returns the arguments expressions.
getArgs() []Expression
// equal check if this function equals to another function.
equal(sessionctx.Context, builtinFunc) bool
equal(EvalContext, builtinFunc) bool
// getRetTp returns the return type of the built-in function.
getRetTp() *types.FieldType
// setPbCode sets pbCode for signature.
Expand Down
Loading

0 comments on commit e14f66f

Please sign in to comment.