Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: Optimize code if else branch logic #28969

Merged
merged 14 commits into from
Nov 8, 2021
3 changes: 2 additions & 1 deletion ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,8 @@ func (d *ddl) doDDLJob(ctx sessionctx.Context, job *model.Job) error {
if err != nil {
logutil.BgLogger().Error("[ddl] get history DDL job failed, check again", zap.Error(err))
continue
} else if historyJob == nil {
}
if historyJob == nil {
logutil.BgLogger().Debug("[ddl] DDL job is not in history, maybe not run", zap.Int64("jobID", jobID))
continue
}
Expand Down
19 changes: 9 additions & 10 deletions distsql/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ func DispatchMPPTasks(ctx context.Context, sctx sessionctx.Context, tasks []*kv.
_, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash]
resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks, allowTiFlashFallback)
if resp == nil {
err := errors.New("client returns nil response")
return nil, err
return nil, errors.New("client returns nil response")
}

encodeType := tipb.EncodeType_TypeDefault
Expand Down Expand Up @@ -91,8 +90,7 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie
}
resp := sctx.GetClient().Send(ctx, kvReq, sctx.GetSessionVars().KVVars, sctx.GetSessionVars().StmtCtx.MemTracker, enabledRateLimitAction, eventCb)
if resp == nil {
err := errors.New("client returns nil response")
return nil, err
return nil, errors.New("client returns nil response")
}

label := metrics.LblGeneral
Expand Down Expand Up @@ -139,13 +137,14 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie
func SelectWithRuntimeStats(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request,
fieldTypes []*types.FieldType, fb *statistics.QueryFeedback, copPlanIDs []int, rootPlanID int) (SelectResult, error) {
sr, err := Select(ctx, sctx, kvReq, fieldTypes, fb)
if err == nil {
if selectResult, ok := sr.(*selectResult); ok {
selectResult.copPlanIDs = copPlanIDs
selectResult.rootPlanID = rootPlanID
}
if err != nil {
return nil, err
}
if selectResult, ok := sr.(*selectResult); ok {
selectResult.copPlanIDs = copPlanIDs
selectResult.rootPlanID = rootPlanID
}
return sr, err
return sr, nil
}

// Analyze do a analyze request.
Expand Down
85 changes: 36 additions & 49 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,7 @@ func (cc *clientConn) writeInitialHandshake(ctx context.Context) error {
data = append(data, 0)
// auth-plugin name
if cc.ctx == nil {
err := cc.openSession()
if err != nil {
if err := cc.openSession(); err != nil {
return err
}
}
Expand All @@ -374,15 +373,13 @@ func (cc *clientConn) writeInitialHandshake(ctx context.Context) error {

// Close the session to force this to be re-opened after we parse the response. This is needed
// to ensure we use the collation and client flags from the response for the session.
err = cc.ctx.Close()
if err != nil {
if err = cc.ctx.Close(); err != nil {
return err
}
cc.ctx = nil

data = append(data, 0)
err = cc.writePacket(data)
if err != nil {
if err = cc.writePacket(data); err != nil {
return err
}
return cc.flush(ctx)
Expand Down Expand Up @@ -1292,7 +1289,6 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
// ComProcessInfo, ComConnect, ComProcessKill, ComDebug
case mysql.ComPing:
return cc.writeOK(ctx)
// ComTime, ComDelayedInsert
case mysql.ComChangeUser:
return cc.handleChangeUser(ctx, data)
// ComBinlogDump, ComTableDump, ComConnectOut, ComRegisterSlave
Expand Down Expand Up @@ -1548,11 +1544,11 @@ func processStream(ctx context.Context, cc *clientConn, loadDataInfo *executor.L
}
if err != nil {
logutil.Logger(ctx).Error("load data process stream error", zap.Error(err))
} else {
err = loadDataInfo.EnqOneTask(ctx)
if err != nil {
logutil.Logger(ctx).Error("load data process stream error", zap.Error(err))
}
return
}
if err = loadDataInfo.EnqOneTask(ctx); err != nil {
logutil.Logger(ctx).Error("load data process stream error", zap.Error(err))
return
}
}

Expand Down Expand Up @@ -1769,7 +1765,7 @@ func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) {
var retryable bool
for i, stmt := range stmts {
if len(pointPlans) > 0 {
// Save the point plan in Session so we don't need to build the point plan again.
// Save the point plan in Session, so we don't need to build the point plan again.
cc.ctx.SetValue(plannercore.PointPlanKey, plannercore.PointPlanVal{Plan: pointPlans[i]})
}
retryable, err = cc.handleStmt(ctx, stmt, parserWarns, i == len(stmts)-1)
Expand Down Expand Up @@ -1897,7 +1893,7 @@ func (cc *clientConn) prefetchPointPlanKeys(ctx context.Context, stmts []ast.Stm
}

// The first return value indicates whether the call of handleStmt has no side effect and can be retried.
// Currently the first return value is used to fallback to TiKV when TiFlash is down.
// Currently, the first return value is used to fall back to TiKV when TiFlash is down.
func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns []stmtctx.SQLWarn, lastStmt bool) (bool, error) {
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{})
Expand All @@ -1914,37 +1910,33 @@ func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns [
return true, err
}

status := cc.ctx.Status()
if lastStmt {
cc.ctx.GetSessionVars().StmtCtx.AppendWarnings(warns)
}

status := cc.ctx.Status()
if !lastStmt {
} else {
status |= mysql.ServerMoreResultsExists
}

if rs != nil {
connStatus := atomic.LoadInt32(&cc.status)
if connStatus == connStatusShutdown {
if connStatus := atomic.LoadInt32(&cc.status); connStatus == connStatusShutdown {
return false, executor.ErrQueryInterrupted
}

retryable, err := cc.writeResultset(ctx, rs, false, status, 0)
if err != nil {
if retryable, err := cc.writeResultset(ctx, rs, false, status, 0); err != nil {
return retryable, err
}
} else {
handled, err := cc.handleQuerySpecial(ctx, status)
if handled {
execStmt := cc.ctx.Value(session.ExecStmtVarKey)
if execStmt != nil {
execStmt.(*executor.ExecStmt).FinishExecuteStmt(0, err, false)
}
}
if err != nil {
return false, err
return false, nil
}

handled, err := cc.handleQuerySpecial(ctx, status)
if handled {
if execStmt := cc.ctx.Value(session.ExecStmtVarKey); execStmt != nil {
execStmt.(*executor.ExecStmt).FinishExecuteStmt(0, err, false)
}
}
if err != nil {
return false, err
}

return false, nil
}

Expand Down Expand Up @@ -2045,13 +2037,13 @@ func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary b
}()
cc.initResultEncoder(ctx)
defer cc.rsEncoder.clean()
var err error
if mysql.HasCursorExistsFlag(serverStatus) {
err = cc.writeChunksWithFetchSize(ctx, rs, serverStatus, fetchSize)
} else {
retryable, err = cc.writeChunks(ctx, rs, binary, serverStatus)
if err := cc.writeChunksWithFetchSize(ctx, rs, serverStatus, fetchSize); err != nil {
return false, err
}
return false, cc.flush(ctx)
}
if err != nil {
if retryable, err := cc.writeChunks(ctx, rs, binary, serverStatus); err != nil {
return retryable, err
}

Expand Down Expand Up @@ -2109,8 +2101,7 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool
// We need to call Next before we get columns.
// Otherwise, we will get incorrect columns info.
columns := rs.Columns()
err = cc.writeColumnInfo(columns, serverStatus)
if err != nil {
if err = cc.writeColumnInfo(columns, serverStatus); err != nil {
return false, err
}
gotColumnInfo = true
Expand Down Expand Up @@ -2151,13 +2142,11 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool
// fetchSize, the desired number of rows to be fetched each time when client uses cursor.
func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet, serverStatus uint16, fetchSize int) error {
fetchedRows := rs.GetFetchedRows()

req := rs.NewChunk(cc.chunkAlloc)
// if fetchedRows is not enough, getting data from recordSet.
for len(fetchedRows) < fetchSize {
// if fetchedRows is not enough, getting data from recordSet.
req := rs.NewChunk(cc.chunkAlloc)
// Here server.tidbResultSet implements Next method.
err := rs.Next(ctx, req)
if err != nil {
if err := rs.Next(ctx, req); err != nil {
return err
}
rowCount := req.NumRows()
Expand Down Expand Up @@ -2255,12 +2244,10 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error {
dbName, _ := parseNullTermString(data)
cc.dbname = string(hack.String(dbName))

err := cc.ctx.Close()
if err != nil {
if err := cc.ctx.Close(); err != nil {
logutil.Logger(ctx).Debug("close old context failed", zap.Error(err))
}
err = cc.openSessionAndDoAuth(pass, "")
if err != nil {
if err := cc.openSessionAndDoAuth(pass, ""); err != nil {
return err
}
return cc.handleCommonConnectionReset(ctx)
Expand Down
83 changes: 41 additions & 42 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
if tlsConfig != nil {
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
logutil.BgLogger().Info("mysql protocol server secure connection is enabled", zap.Bool("client verification enabled", len(variable.GetSysVar("ssl_ca").Value) > 0))
logutil.BgLogger().Info("mysql protocol server secure connection is enabled",
zap.Bool("client verification enabled", len(variable.GetSysVar("ssl_ca").Value) > 0))
} else if cfg.Security.RequireSecureTransport {
return nil, errSecureTransportRequired.FastGenByArgs()
}
Expand All @@ -250,9 +251,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
}

if s.cfg.Socket != "" {

err := cleanupStaleSocket(s.cfg.Socket)
if err != nil {
if err := cleanupStaleSocket(s.cfg.Socket); err != nil {
return nil, errors.Trace(err)
}

Expand All @@ -272,24 +271,23 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
if proxyTarget == nil {
proxyTarget = s.socket
}
pplistener, err := proxyprotocol.NewListener(proxyTarget, s.cfg.ProxyProtocol.Networks,
ppListener, err := proxyprotocol.NewListener(proxyTarget, s.cfg.ProxyProtocol.Networks,
int(s.cfg.ProxyProtocol.HeaderTimeout))
if err != nil {
logutil.BgLogger().Error("ProxyProtocol networks parameter invalid")
return nil, errors.Trace(err)
}
if s.listener != nil {
s.listener = pplistener
s.listener = ppListener
logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("host", s.cfg.Host))
} else {
s.socket = pplistener
s.socket = ppListener
logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("socket", s.cfg.Socket))
}
}

if s.cfg.Status.ReportStatus {
err = s.listenStatusHTTPServer()
if err != nil {
if err = s.listenStatusHTTPServer(); err != nil {
return nil, errors.Trace(err)
}
}
Expand All @@ -304,25 +302,26 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {

func cleanupStaleSocket(socket string) error {
sockStat, err := os.Stat(socket)
if err == nil {
if sockStat.Mode().Type() != os.ModeSocket {
return fmt.Errorf(
"the specified socket file %s is a %s instead of a socket file",
socket, sockStat.Mode().String())
}
if err != nil {
return nil
}

_, err = net.Dial("unix", socket)
if err != nil {
logutil.BgLogger().Warn("Unix socket exists and is nonfunctional, removing it",
zap.String("socket", socket), zap.Error(err))
err = os.Remove(socket)
if err != nil {
return fmt.Errorf("failed to remove socket file %s", socket)
}
} else {
return fmt.Errorf("unix socket %s exists and is functional, not removing it", socket)
}
if sockStat.Mode().Type() != os.ModeSocket {
return fmt.Errorf(
"the specified socket file %s is a %s instead of a socket file",
socket, sockStat.Mode().String())
}

if _, err = net.Dial("unix", socket); err == nil {
return fmt.Errorf("unix socket %s exists and is functional, not removing it", socket)
}

logutil.BgLogger().Warn("Unix socket exists and is nonfunctional, removing it",
zap.String("socket", socket), zap.Error(err))
if err = os.Remove(socket); err != nil {
return fmt.Errorf("failed to remove socket file %s", socket)
}

return nil
}

Expand Down Expand Up @@ -363,7 +362,7 @@ func (s *Server) Run() error {
s.startStatusHTTP()
}
// If error should be reported and exit the server it can be sent on this
// channel. Otherwise end with sending a nil error to signal "done"
// channel. Otherwise, end with sending a nil error to signal "done"
errChan := make(chan error)
go s.startNetworkListener(s.listener, false, errChan)
go s.startNetworkListener(s.socket, true, errChan)
Expand Down Expand Up @@ -393,7 +392,7 @@ func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool,
}
}

// If we got PROXY protocol error, we should continue accept.
// If we got PROXY protocol error, we should continue to accept.
if proxyprotocol.IsProxyProtocolError(err) {
logutil.BgLogger().Error("PROXY protocol failed", zap.Error(err))
continue
Expand All @@ -406,7 +405,6 @@ func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool,

clientConn := s.newConn(conn)
if isUnixSocket {

uc, ok := conn.(*net.UnixConn)
if !ok {
logutil.BgLogger().Error("Expected UNIX socket, but got something else")
Expand All @@ -424,19 +422,20 @@ func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool,

err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
host, _, err := clientConn.PeerHost("")
if err != nil {
logutil.BgLogger().Error("get peer host failed", zap.Error(err))
terror.Log(clientConn.Close())
return errors.Trace(err)
}
err = authPlugin.OnConnectionEvent(context.Background(), plugin.PreAuth, &variable.ConnectionInfo{Host: host})
if err != nil {
logutil.BgLogger().Info("do connection event failed", zap.Error(err))
terror.Log(clientConn.Close())
return errors.Trace(err)
}
if authPlugin.OnConnectionEvent == nil {
return nil
}
host, _, err := clientConn.PeerHost("")
if err != nil {
logutil.BgLogger().Error("get peer host failed", zap.Error(err))
terror.Log(clientConn.Close())
return errors.Trace(err)
}
if err = authPlugin.OnConnectionEvent(context.Background(), plugin.PreAuth,
&variable.ConnectionInfo{Host: host}); err != nil {
logutil.BgLogger().Info("do connection event failed", zap.Error(err))
terror.Log(clientConn.Close())
return errors.Trace(err)
}
return nil
})
Expand Down
3 changes: 1 addition & 2 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1522,8 +1522,7 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex
}

s.PrepareTxnCtx(ctx)
err := s.loadCommonGlobalVariablesIfNeeded()
if err != nil {
if err := s.loadCommonGlobalVariablesIfNeeded(); err != nil {
return nil, err
}

Expand Down