diff --git a/distsql/distsql.go b/distsql/distsql.go index cb2bbfaef5ff4..3219098b646e9 100644 --- a/distsql/distsql.go +++ b/distsql/distsql.go @@ -16,6 +16,7 @@ package distsql import ( "context" + "strconv" "unsafe" "github.com/opentracing/opentracing-go" @@ -25,6 +26,7 @@ import ( "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/logutil" @@ -32,17 +34,18 @@ import ( "github.com/pingcap/tipb/go-tipb" "github.com/tikv/client-go/v2/tikvrpc/interceptor" "go.uber.org/zap" + "google.golang.org/grpc/metadata" ) // DispatchMPPTasks dispatches all tasks and returns an iterator. func DispatchMPPTasks(ctx context.Context, sctx sessionctx.Context, tasks []*kv.MPPDispatchRequest, fieldTypes []*types.FieldType, planIDs []int, rootID int, startTs uint64) (SelectResult, error) { ctx = WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx) _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] + ctx = SetTiFlashMaxThreadsInContext(ctx, sctx) resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks, allowTiFlashFallback, startTs) if resp == nil { return nil, errors.New("client returns nil response") } - encodeType := tipb.EncodeType_TypeDefault if canUseChunkRPC(sctx) { encodeType = tipb.EncodeType_TypeChunk @@ -97,6 +100,11 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie EventCb: eventCb, EnableCollectExecutionInfo: config.GetGlobalConfig().Instance.EnableCollectExecutionInfo, } + + if kvReq.StoreType == kv.TiFlash { + ctx = SetTiFlashMaxThreadsInContext(ctx, sctx) + } + resp := sctx.GetClient().Send(ctx, kvReq, sctx.GetSessionVars().KVVars, option) if resp == nil { return nil, errors.New("client returns nil response") @@ -141,6 +149,14 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie }, nil } +// SetTiFlashMaxThreadsInContext set the config TiFlash max threads in context. +func SetTiFlashMaxThreadsInContext(ctx context.Context, sctx sessionctx.Context) context.Context { + if sctx.GetSessionVars().TiFlashMaxThreads != -1 { + ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxTiFlashThreads, strconv.FormatInt(sctx.GetSessionVars().TiFlashMaxThreads, 10)) + } + return ctx +} + // SelectWithRuntimeStats sends a DAG request, returns SelectResult. // The difference from Select is that SelectWithRuntimeStats will set copPlanIDs into selectResult, // which can help selectResult to collect runtime stats. diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 431a51a221922..63afe1476e710 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -626,6 +626,11 @@ type SessionVars struct { // Note if you want to set `enforceMPPExecution` to `true`, you must set `allowMPPExecution` to `true` first. enforceMPPExecution bool + // TiFlashMaxThreads is the maximum number of threads to execute the request which is pushed down to tiflash. + // Default value is -1, means it will not be pushed down to tiflash. + // If the value is bigger than -1, it will be pushed down to tiflash and used to create db context in tiflash. + TiFlashMaxThreads int64 + // TiDBAllowAutoRandExplicitInsert indicates whether explicit insertion on auto_random column is allowed. AllowAutoRandExplicitInsert bool @@ -1305,6 +1310,7 @@ func NewSessionVars() *SessionVars { vars.allowMPPExecution = DefTiDBAllowMPPExecution vars.HashExchangeWithNewCollation = DefTiDBHashExchangeWithNewCollation vars.enforceMPPExecution = DefTiDBEnforceMPPExecution + vars.TiFlashMaxThreads = DefTiFlashMaxThreads vars.MPPStoreFailTTL = DefTiDBMPPStoreFailTTL enableChunkRPC := "0" diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index ca9b52493d53e..839775e71ba50 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -127,6 +127,10 @@ var defaultSysVars = []*SysVar{ s.enforceMPPExecution = TiDBOptOn(val) return nil }}, + {Scope: ScopeGlobal | ScopeSession, Name: TiDBMaxTiFlashThreads, Type: TypeInt, Value: strconv.Itoa(DefTiFlashMaxThreads), MinValue: -1, MaxValue: MaxConfigurableConcurrency, SetSession: func(s *SessionVars, val string) error { + s.TiFlashMaxThreads = TidbOptInt64(val, DefTiFlashMaxThreads) + return nil + }}, {Scope: ScopeSession, Name: TiDBSnapshot, Value: "", skipInit: true, SetSession: func(s *SessionVars, val string) error { err := setSnapshotTS(s, val) if err != nil { diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 0010b6f876c71..d852c47ea6870 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -323,6 +323,10 @@ const ( // Note if you want to set `tidb_enforce_mpp` to `true`, you must set `tidb_allow_mpp` to `true` first. TiDBEnforceMPPExecution = "tidb_enforce_mpp" + // TiDBMaxTiFlashThreads is the maximum number of threads to execute the request which is pushed down to tiflash. + // Default value is -1, means it will not be pushed down to tiflash. + // If the value is bigger than -1, it will be pushed down to tiflash and used to create db context in tiflash. + TiDBMaxTiFlashThreads = "tidb_max_tiflash_threads" // TiDBMPPStoreFailTTL is the unavailable time when a store is detected failed. During that time, tidb will not send any task to // TiFlash even though the failed TiFlash node has been recovered. TiDBMPPStoreFailTTL = "tidb_mpp_store_fail_ttl" @@ -756,6 +760,7 @@ const ( DefTiDBAllowMPPExecution = true DefTiDBHashExchangeWithNewCollation = true DefTiDBEnforceMPPExecution = false + DefTiFlashMaxThreads = -1 DefTiDBMPPStoreFailTTL = "60s" DefTiDBTxnMode = "" DefTiDBRowFormatV1 = 1