diff --git a/session/session.go b/session/session.go index 63bc1c970fe08..8c449dc14faf2 100644 --- a/session/session.go +++ b/session/session.go @@ -2091,7 +2091,7 @@ func (s *session) ExecRestrictedSQL(ctx context.Context, opts []sqlexec.OptionFu metrics.SessionRestrictedSQLCounter.Inc() ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) - rs, err := se.ExecuteStmt(ctx, stmt) + rs, err := se.ExecuteInternalStmt(ctx, stmt) if err != nil { se.sessionVars.StmtCtx.AppendError(err) } @@ -2113,6 +2113,18 @@ func (s *session) ExecRestrictedSQL(ctx context.Context, opts []sqlexec.OptionFu }) } +// ExecuteInternalStmt execute internal stmt +func (s *session) ExecuteInternalStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { + origin := s.sessionVars.InRestrictedSQL + s.sessionVars.InRestrictedSQL = true + defer func() { + s.sessionVars.InRestrictedSQL = origin + // Restore the goroutine label by using the original ctx after execution is finished. + pprof.SetGoroutineLabels(ctx) + }() + return s.ExecuteStmt(ctx, stmtNode) +} + func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("session.ExecuteStmt", opentracing.ChildOf(span.Context()))