From 7f8d3944f598ff3f4028b455efa22fb78abfb409 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 5 Mar 2024 17:15:34 +0800 Subject: [PATCH] server: start to listen after init stats complete (#51472) close pingcap/tidb#51473 --- br/pkg/mock/mock_cluster.go | 5 +- cmd/tidb-server/main.go | 6 -- pkg/domain/main_test.go | 1 + pkg/server/BUILD.bazel | 1 + .../handler/extractorhandler/extract_test.go | 6 +- .../handler/extractorhandler/main_test.go | 1 + pkg/server/handler/optimizor/main_test.go | 1 + .../handler/optimizor/optimize_trace_test.go | 5 +- .../handler/optimizor/plan_replayer_test.go | 2 +- .../optimizor/statistics_handler_test.go | 6 +- pkg/server/handler/tests/http_handler_test.go | 7 ++- pkg/server/handler/tests/main_test.go | 1 + pkg/server/http_status.go | 7 ++- pkg/server/server.go | 58 ++++++++++++++---- pkg/server/tests/commontest/main_test.go | 1 + pkg/server/tests/commontest/tidb_test.go | 38 +++++++----- pkg/server/tests/cursor/main_test.go | 1 + pkg/server/tests/main_test.go | 1 + pkg/server/tests/servertestkit/testkit.go | 16 ++--- pkg/server/tests/tls/main_test.go | 1 + pkg/server/tests/tls/tls_test.go | 61 +++++++++++-------- 21 files changed, 145 insertions(+), 81 deletions(-) diff --git a/br/pkg/mock/mock_cluster.go b/br/pkg/mock/mock_cluster.go index 4135abaa19ad0..6aa97498bee4f 100644 --- a/br/pkg/mock/mock_cluster.go +++ b/br/pkg/mock/mock_cluster.go @@ -88,6 +88,7 @@ func NewCluster() (*Cluster, error) { // Start runs a mock cluster. func (mock *Cluster) Start() error { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) mock.TiDBDriver = server.NewTiDBDriver(mock.Storage) cfg := config.NewConfig() // let tidb random select a port @@ -107,6 +108,7 @@ func (mock *Cluster) Start() error { panic(err1) } }() + <-server.RunInGoTestChan mock.DSN = waitUntilServerOnline("127.0.0.1", cfg.Status.StatusPort) return nil } @@ -181,7 +183,8 @@ func waitUntilServerOnline(addr string, statusPort uint) string { } if retry == retryTime { log.Panic("failed to connect HTTP status in every 10 ms", - zap.Int("retryTime", retryTime)) + zap.Int("retryTime", retryTime), + zap.String("url", statusURL)) } return strings.SplitAfter(dsn, "/")[0] } diff --git a/cmd/tidb-server/main.go b/cmd/tidb-server/main.go index fcadd42eba393..3092546bc7544 100644 --- a/cmd/tidb-server/main.go +++ b/cmd/tidb-server/main.go @@ -302,11 +302,6 @@ func main() { storage, dom := createStoreAndDomain(keyspaceName) svr := createServer(storage, dom) - // Register error API is not thread-safe, the caller MUST NOT register errors after initialization. - // To prevent misuse, set a flag to indicate that register new error will panic immediately. - // For regression of issue like https://github.com/pingcap/tidb/issues/28190 - terror.RegisterFinish() - exited := make(chan struct{}) signal.SetupSignalHandler(func() { svr.Close() @@ -317,7 +312,6 @@ func main() { close(exited) }) topsql.SetupTopSQL() - terror.MustNil(svr.Run(dom)) <-exited syncLog() diff --git a/pkg/domain/main_test.go b/pkg/domain/main_test.go index 9d83ca803a26d..db8c12f4f8028 100644 --- a/pkg/domain/main_test.go +++ b/pkg/domain/main_test.go @@ -24,6 +24,7 @@ import ( func TestMain(m *testing.M) { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() opts := []goleak.Option{ goleak.IgnoreTopFunction("github.com/golang/glog.(*fileSink).flushDaemon"), diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index 674fc95232c16..75c5ae8a8355e 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -113,6 +113,7 @@ go_library( "@com_github_pingcap_kvproto//pkg/diagnosticspb", "@com_github_pingcap_kvproto//pkg/mpp", "@com_github_pingcap_kvproto//pkg/tikvpb", + "@com_github_pingcap_log//:log", "@com_github_pingcap_sysutil//:sysutil", "@com_github_prometheus_client_golang//prometheus", "@com_github_prometheus_client_golang//prometheus/promhttp", diff --git a/pkg/server/handler/extractorhandler/extract_test.go b/pkg/server/handler/extractorhandler/extract_test.go index 2ba74a48843d4..4951b8cc0301f 100644 --- a/pkg/server/handler/extractorhandler/extract_test.go +++ b/pkg/server/handler/extractorhandler/extract_test.go @@ -59,13 +59,13 @@ func TestExtractHandler(t *testing.T) { dom, err := session.GetDomain(store) require.NoError(t, err) server.SetDomain(dom) - - client.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) - client.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) go func() { err := server.Run(nil) require.NoError(t, err) }() + <-server2.RunInGoTestChan + client.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + client.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) client.WaitUntilServerOnline() startTime := time.Now() time.Sleep(time.Second) diff --git a/pkg/server/handler/extractorhandler/main_test.go b/pkg/server/handler/extractorhandler/main_test.go index de1e0c3c2d2cb..a9a19564bcfd9 100644 --- a/pkg/server/handler/extractorhandler/main_test.go +++ b/pkg/server/handler/extractorhandler/main_test.go @@ -33,6 +33,7 @@ import ( func TestMain(m *testing.M) { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() topsqlstate.EnableTopSQL() unistore.CheckResourceTagForTopSQLInGoTest = true diff --git a/pkg/server/handler/optimizor/main_test.go b/pkg/server/handler/optimizor/main_test.go index 116fc9dba5b30..75598c3c3802a 100644 --- a/pkg/server/handler/optimizor/main_test.go +++ b/pkg/server/handler/optimizor/main_test.go @@ -33,6 +33,7 @@ import ( func TestMain(m *testing.M) { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() topsqlstate.EnableTopSQL() unistore.CheckResourceTagForTopSQLInGoTest = true diff --git a/pkg/server/handler/optimizor/optimize_trace_test.go b/pkg/server/handler/optimizor/optimize_trace_test.go index 6aac71902bf2d..b10cf19a68118 100644 --- a/pkg/server/handler/optimizor/optimize_trace_test.go +++ b/pkg/server/handler/optimizor/optimize_trace_test.go @@ -49,12 +49,13 @@ func TestDumpOptimizeTraceAPI(t *testing.T) { require.NoError(t, err) server.SetDomain(dom) - client.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) - client.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) go func() { err := server.Run(nil) require.NoError(t, err) }() + <-server2.RunInGoTestChan + client.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + client.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) client.WaitUntilServerOnline() statsHandler := optimizor.NewStatsHandler(dom) diff --git a/pkg/server/handler/optimizor/plan_replayer_test.go b/pkg/server/handler/optimizor/plan_replayer_test.go index 3547bf80b7db7..f1afd703bfb16 100644 --- a/pkg/server/handler/optimizor/plan_replayer_test.go +++ b/pkg/server/handler/optimizor/plan_replayer_test.go @@ -97,7 +97,7 @@ func prepareServerAndClientForTest(t *testing.T, store kv.Storage, dom *domain.D err := srv.Run(nil) require.NoError(t, err) }() - + <-server.RunInGoTestChan client.Port = testutil.GetPortFromTCPAddr(srv.ListenAddr()) client.StatusPort = testutil.GetPortFromTCPAddr(srv.StatusListenerAddr()) client.WaitUntilServerOnline() diff --git a/pkg/server/handler/optimizor/statistics_handler_test.go b/pkg/server/handler/optimizor/statistics_handler_test.go index c2e2e936a8820..bf977a75fd2f8 100644 --- a/pkg/server/handler/optimizor/statistics_handler_test.go +++ b/pkg/server/handler/optimizor/statistics_handler_test.go @@ -55,13 +55,13 @@ func TestDumpStatsAPI(t *testing.T) { dom, err := session.GetDomain(store) require.NoError(t, err) server.SetDomain(dom) - - client.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) - client.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) go func() { err := server.Run(nil) require.NoError(t, err) }() + <-server2.RunInGoTestChan + client.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + client.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) client.WaitUntilServerOnline() statsHandler := optimizor.NewStatsHandler(dom) diff --git a/pkg/server/handler/tests/http_handler_test.go b/pkg/server/handler/tests/http_handler_test.go index 217a988f410cd..b2726b7f5a848 100644 --- a/pkg/server/handler/tests/http_handler_test.go +++ b/pkg/server/handler/tests/http_handler_test.go @@ -463,17 +463,18 @@ func (ts *basicHTTPHandlerTestSuite) startServer(t *testing.T) { cfg.Port = 0 cfg.Status.StatusPort = 0 cfg.Status.ReportStatus = true - + server2.RunInGoTestChan = make(chan struct{}) server, err := server2.NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - ts.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) - ts.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) ts.server = server ts.server.SetDomain(ts.domain) go func() { err := server.Run(ts.domain) require.NoError(t, err) }() + <-server2.RunInGoTestChan + ts.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + ts.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) ts.WaitUntilServerOnline() do, err := session.GetDomain(ts.store) diff --git a/pkg/server/handler/tests/main_test.go b/pkg/server/handler/tests/main_test.go index c74b04527702a..1ee8ef712db61 100644 --- a/pkg/server/handler/tests/main_test.go +++ b/pkg/server/handler/tests/main_test.go @@ -33,6 +33,7 @@ import ( func TestMain(m *testing.M) { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() topsqlstate.EnableTopSQL() unistore.CheckResourceTagForTopSQLInGoTest = true diff --git a/pkg/server/http_status.go b/pkg/server/http_status.go index 72b527a7d27cf..4f6808061ab90 100644 --- a/pkg/server/http_status.go +++ b/pkg/server/http_status.go @@ -67,8 +67,13 @@ import ( const defaultStatusPort = 10080 -func (s *Server) startStatusHTTP() { +func (s *Server) startStatusHTTP() error { + err := s.initHTTPListener() + if err != nil { + return err + } go s.startHTTPServer() + return nil } func serveError(w http.ResponseWriter, status int, txt string) { diff --git a/pkg/server/server.go b/pkg/server/server.go index c7ca2dcf60384..d78db9c20faf9 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -49,6 +49,7 @@ import ( "github.com/blacktear23/go-proxyprotocol" "github.com/pingcap/errors" + "github.com/pingcap/log" autoid "github.com/pingcap/tidb/pkg/autoid_service" "github.com/pingcap/tidb/pkg/config" "github.com/pingcap/tidb/pkg/domain" @@ -83,6 +84,8 @@ var ( osVersion string // RunInGoTest represents whether we are run code in test. RunInGoTest bool + // RunInGoTestChan is used to control the RunInGoTest. + RunInGoTestChan chan struct{} ) func init() { @@ -289,7 +292,11 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { if s.tlsConfig != nil { s.capability |= mysql.ClientSSL } + variable.RegisterStatistics(s) + return s, nil +} +func (s *Server) initTiDBListener() (err error) { if s.cfg.Host != "" && (s.cfg.Port != 0 || RunInGoTest) { addr := net.JoinHostPort(s.cfg.Host, strconv.Itoa(int(s.cfg.Port))) tcpProto := "tcp" @@ -297,7 +304,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { tcpProto = "tcp4" } if s.listener, err = net.Listen(tcpProto, addr); err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } logutil.BgLogger().Info("server is running MySQL protocol", zap.String("addr", addr)) if RunInGoTest && s.cfg.Port == 0 { @@ -307,18 +314,18 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { if s.cfg.Socket != "" { if err := cleanupStaleSocket(s.cfg.Socket); err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } if s.socket, err = net.Listen("unix", s.cfg.Socket); err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } logutil.BgLogger().Info("server is running MySQL protocol", zap.String("socket", s.cfg.Socket)) } if s.socket == nil && s.listener == nil { err = errors.New("Server not configured to listen on either -socket or -host and -port") - return nil, errors.Trace(err) + return errors.Trace(err) } if s.cfg.ProxyProtocol.Networks != "" { @@ -330,7 +337,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { int(s.cfg.ProxyProtocol.HeaderTimeout), s.cfg.ProxyProtocol.Fallbackable) if err != nil { logutil.BgLogger().Error("ProxyProtocol networks parameter invalid") - return nil, errors.Trace(err) + return errors.Trace(err) } if s.listener != nil { s.listener = ppListener @@ -340,10 +347,13 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("socket", s.cfg.Socket)) } } + return nil +} +func (s *Server) initHTTPListener() (err error) { if s.cfg.Status.ReportStatus { if err = s.listenStatusHTTPServer(); err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } } @@ -364,10 +374,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { logutil.BgLogger().Error("Fail to load JWKS from the path", zap.String("jwks", s.cfg.Security.AuthTokenJWKS)) } } - - variable.RegisterStatistics(s) - - return s, nil + return } func cleanupStaleSocket(socket string) error { @@ -426,7 +433,11 @@ func (s *Server) Run(dom *domain.Domain) error { // Start HTTP API to report tidb info such as TPS. if s.cfg.Status.ReportStatus { - s.startStatusHTTP() + err := s.startStatusHTTP() + if err != nil { + log.Error("failed to create the server", zap.Error(err), zap.Stack("stack")) + return err + } } if config.GetGlobalConfig().Performance.ForceInitStats && dom != nil { <-dom.StatsHandle().InitStatsDone @@ -434,15 +445,38 @@ func (s *Server) Run(dom *domain.Domain) error { // If error should be reported and exit the server it can be sent on this // channel. Otherwise, end with sending a nil error to signal "done" errChan := make(chan error, 2) + err := s.initTiDBListener() + if err != nil { + log.Error("failed to create the server", zap.Error(err), zap.Stack("stack")) + return err + } + // Register error API is not thread-safe, the caller MUST NOT register errors after initialization. + // To prevent misuse, set a flag to indicate that register new error will panic immediately. + // For regression of issue like https://github.com/pingcap/tidb/issues/28190 + terror.RegisterFinish() go s.startNetworkListener(s.listener, false, errChan) go s.startNetworkListener(s.socket, true, errChan) - err := <-errChan + if RunInGoTest && !isClosed(RunInGoTestChan) { + close(RunInGoTestChan) + } + err = <-errChan if err != nil { return err } return <-errChan } +// isClosed is to check if the channel is closed +func isClosed(ch chan struct{}) bool { + select { + case <-ch: + return true + default: + } + + return false +} + func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool, errChan chan error) { if listener == nil { errChan <- nil diff --git a/pkg/server/tests/commontest/main_test.go b/pkg/server/tests/commontest/main_test.go index b2732e844bc05..1e7d0484adf68 100644 --- a/pkg/server/tests/commontest/main_test.go +++ b/pkg/server/tests/commontest/main_test.go @@ -33,6 +33,7 @@ import ( func TestMain(m *testing.M) { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() topsqlstate.EnableTopSQL() unistore.CheckResourceTagForTopSQLInGoTest = true diff --git a/pkg/server/tests/commontest/tidb_test.go b/pkg/server/tests/commontest/tidb_test.go index bb34da81e464b..3243728ff11cb 100644 --- a/pkg/server/tests/commontest/tidb_test.go +++ b/pkg/server/tests/commontest/tidb_test.go @@ -141,8 +141,9 @@ func TestStatusPort(t *testing.T) { cfg.Performance.TCPKeepAlive = true server, err := server2.NewServer(cfg, ts.Tidbdrv) + require.NoError(t, err) + err = server.Run(ts.Domain) require.Error(t, err) - require.Nil(t, server) } func TestMultiStatements(t *testing.T) { @@ -164,16 +165,16 @@ func TestSocketForwarding(t *testing.T) { cfg.Port = cli.Port os.Remove(cfg.Socket) cfg.Status.ReportStatus = false - + server2.RunInGoTestChan = make(chan struct{}) server, err := server2.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-server2.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) defer server.Close() cli.RunTestRegression(t, func(config *mysql.Config) { @@ -197,7 +198,7 @@ func TestSocket(t *testing.T) { cfg.Status.ReportStatus = false ts := servertestkit.CreateTidbTestSuite(t) - + server2.RunInGoTestChan = make(chan struct{}) server, err := server2.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) @@ -205,7 +206,7 @@ func TestSocket(t *testing.T) { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-server2.RunInGoTestChan defer server.Close() confFunc := func(config *mysql.Config) { @@ -232,15 +233,17 @@ func TestSocketAndIp(t *testing.T) { cfg.Status.ReportStatus = false ts := servertestkit.CreateTidbTestSuite(t) - + server2.RunInGoTestChan = make(chan struct{}) server, err := server2.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + go func() { err := server.Run(nil) require.NoError(t, err) }() + <-server2.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) cli.WaitUntilServerCanConnect() defer server.Close() @@ -397,7 +400,7 @@ func TestOnlySocket(t *testing.T) { cfg.Status.ReportStatus = false ts := servertestkit.CreateTidbTestSuite(t) - + server2.RunInGoTestChan = make(chan struct{}) server, err := server2.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) @@ -405,7 +408,7 @@ func TestOnlySocket(t *testing.T) { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-server2.RunInGoTestChan defer server.Close() require.Nil(t, server.Listener()) require.NotNil(t, server.Socket()) @@ -898,17 +901,18 @@ func TestGracefulShutdown(t *testing.T) { cfg.Status.StatusPort = 0 cfg.Status.ReportStatus = true cfg.Performance.TCPKeepAlive = true + server2.RunInGoTestChan = make(chan struct{}) server, err := server2.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) require.NotNil(t, server) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) - cli.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) + go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) - + <-server2.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + cli.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) resp, err := cli.FetchStatus("/status") // server is up require.NoError(t, err) require.Nil(t, resp.Body.Close()) @@ -2141,16 +2145,18 @@ func TestLocalhostClientMapping(t *testing.T) { cfg.Status.ReportStatus = false ts := servertestkit.CreateTidbTestSuite(t) - + server2.RunInGoTestChan = make(chan struct{}) server, err := server2.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + go func() { err := server.Run(nil) require.NoError(t, err) }() defer server.Close() + <-server2.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) cli.WaitUntilServerCanConnect() cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) diff --git a/pkg/server/tests/cursor/main_test.go b/pkg/server/tests/cursor/main_test.go index 139c2df93ac66..f16318debb3d3 100644 --- a/pkg/server/tests/cursor/main_test.go +++ b/pkg/server/tests/cursor/main_test.go @@ -33,6 +33,7 @@ import ( func TestMain(m *testing.M) { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() topsqlstate.EnableTopSQL() unistore.CheckResourceTagForTopSQLInGoTest = true diff --git a/pkg/server/tests/main_test.go b/pkg/server/tests/main_test.go index 43deac385aea2..32ab551e50f7d 100644 --- a/pkg/server/tests/main_test.go +++ b/pkg/server/tests/main_test.go @@ -33,6 +33,7 @@ import ( func TestMain(m *testing.M) { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() topsqlstate.EnableTopSQL() unistore.CheckResourceTagForTopSQLInGoTest = true diff --git a/pkg/server/tests/servertestkit/testkit.go b/pkg/server/tests/servertestkit/testkit.go index 9b880b0567235..725a16f49d0a0 100644 --- a/pkg/server/tests/servertestkit/testkit.go +++ b/pkg/server/tests/servertestkit/testkit.go @@ -23,7 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/config" "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/server" + srv "github.com/pingcap/tidb/pkg/server" "github.com/pingcap/tidb/pkg/server/internal/testserverclient" "github.com/pingcap/tidb/pkg/server/internal/testutil" "github.com/pingcap/tidb/pkg/server/internal/util" @@ -40,8 +40,8 @@ import ( // TidbTestSuite is a test suite for tidb type TidbTestSuite struct { *testserverclient.TestServerClient - Tidbdrv *server.TiDBDriver - Server *server.Server + Tidbdrv *srv.TiDBDriver + Server *srv.Server Domain *domain.Domain Store kv.Storage } @@ -68,12 +68,11 @@ func CreateTidbTestSuiteWithCfg(t *testing.T, cfg *config.Config) *TidbTestSuite require.NoError(t, err) ts.Domain, err = session.BootstrapSession(ts.Store) require.NoError(t, err) - ts.Tidbdrv = server.NewTiDBDriver(ts.Store) + ts.Tidbdrv = srv.NewTiDBDriver(ts.Store) - server, err := server.NewServer(cfg, ts.Tidbdrv) + server, err := srv.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) - ts.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) - ts.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) + ts.Server = server ts.Server.SetDomain(ts.Domain) ts.Domain.InfoSyncer().SetSessionManager(ts.Server) @@ -81,6 +80,9 @@ func CreateTidbTestSuiteWithCfg(t *testing.T, cfg *config.Config) *TidbTestSuite err := ts.Server.Run(nil) require.NoError(t, err) }() + <-srv.RunInGoTestChan + ts.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + ts.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) ts.WaitUntilServerOnline() t.Cleanup(func() { diff --git a/pkg/server/tests/tls/main_test.go b/pkg/server/tests/tls/main_test.go index f440b12f122fa..b0e4571ee3062 100644 --- a/pkg/server/tests/tls/main_test.go +++ b/pkg/server/tests/tls/main_test.go @@ -33,6 +33,7 @@ import ( func TestMain(m *testing.M) { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() topsqlstate.EnableTopSQL() unistore.CheckResourceTagForTopSQLInGoTest = true diff --git a/pkg/server/tests/tls/tls_test.go b/pkg/server/tests/tls/tls_test.go index 8cc24f8fa0592..7093610cecefb 100644 --- a/pkg/server/tests/tls/tls_test.go +++ b/pkg/server/tests/tls/tls_test.go @@ -31,7 +31,7 @@ import ( "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/server" + tidbserver "github.com/pingcap/tidb/pkg/server" "github.com/pingcap/tidb/pkg/server/internal/testserverclient" "github.com/pingcap/tidb/pkg/server/internal/testutil" util2 "github.com/pingcap/tidb/pkg/server/internal/util" @@ -210,16 +210,18 @@ func TestTLSVerify(t *testing.T) { SSLCert: fileName("server-cert.pem"), SSLKey: fileName("server-key.pem"), } - server, err := server.NewServer(cfg, ts.Tidbdrv) + tidbserver.RunInGoTestChan = make(chan struct{}) + server, err := tidbserver.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) defer server.Close() - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-tidbserver.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) // The client does not provide a certificate, the connection should succeed. err = cli.RunTestTLSConnection(t, nil) require.NoError(t, err) @@ -303,15 +305,16 @@ func TestTLSBasic(t *testing.T) { SSLCert: fileName("server-cert.pem"), SSLKey: fileName("server-key.pem"), } - server, err := server.NewServer(cfg, ts.Tidbdrv) + tidbserver.RunInGoTestChan = make(chan struct{}) + server, err := tidbserver.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-tidbserver.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) err = cli.RunTestTLSConnection(t, connOverrider) // We should establish connection successfully. require.NoError(t, err) cli.RunTestRegression(t, connOverrider, "TLSRegression") @@ -370,7 +373,7 @@ func TestErrorNoRollback(t *testing.T) { SSLCert: "wrong path", SSLKey: "wrong path", } - _, err = server.NewServer(cfg, ts.Tidbdrv) + _, err = tidbserver.NewServer(cfg, ts.Tidbdrv) require.Error(t, err) // test reload tls fail with/without "error no rollback option" @@ -379,16 +382,17 @@ func TestErrorNoRollback(t *testing.T) { SSLCert: "/tmp/server-cert-rollback.pem", SSLKey: "/tmp/server-key-rollback.pem", } - server, err := server.NewServer(cfg, ts.Tidbdrv) + tidbserver.RunInGoTestChan = make(chan struct{}) + server, err := tidbserver.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) go func() { err := server.Run(nil) require.NoError(t, err) }() defer server.Close() - time.Sleep(time.Millisecond * 100) + <-tidbserver.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) connOverrider := func(config *mysql.Config) { config.TLSConfig = "client-cert-rollback-test" } @@ -438,15 +442,16 @@ func TestReloadTLS(t *testing.T) { SSLCert: "/tmp/server-cert-reload.pem", SSLKey: "/tmp/server-key-reload.pem", } - server, err := server.NewServer(cfg, ts.Tidbdrv) + tidbserver.RunInGoTestChan = make(chan struct{}) + server, err := tidbserver.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-tidbserver.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) // The client provides a valid certificate. connOverrider := func(config *mysql.Config) { config.TLSConfig = "client-certificate-reload" @@ -531,16 +536,17 @@ func TestStatusAPIWithTLS(t *testing.T) { cfg.Security.ClusterSSLCA = fileName("ca-cert-2.pem") cfg.Security.ClusterSSLCert = fileName("server-cert-2.pem") cfg.Security.ClusterSSLKey = fileName("server-key-2.pem") - server, err := server.NewServer(cfg, ts.Tidbdrv) + tidbserver.RunInGoTestChan = make(chan struct{}) + server, err := tidbserver.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) - cli.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) + go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) - + <-tidbserver.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + cli.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) // https connection should work. ts.RunTestStatusAPI(t) @@ -588,15 +594,17 @@ func TestStatusAPIWithTLSCNCheck(t *testing.T) { cfg.Security.ClusterSSLCert = serverCertPath cfg.Security.ClusterSSLKey = serverKeyPath cfg.Security.ClusterVerifyCN = []string{"tidb-client-2"} - server, err := server.NewServer(cfg, ts.Tidbdrv) + tidbserver.RunInGoTestChan = make(chan struct{}) + server, err := tidbserver.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) - cli.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) go func() { err := server.Run(nil) require.NoError(t, err) }() + <-tidbserver.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + cli.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) defer server.Close() time.Sleep(time.Millisecond * 100) @@ -628,7 +636,7 @@ func TestInvalidTLS(t *testing.T) { SSLCert: "bogus-server-cert.pem", SSLKey: "bogus-server-key.pem", } - _, err := server.NewServer(cfg, ts.Tidbdrv) + _, err := tidbserver.NewServer(cfg, ts.Tidbdrv) require.Error(t, err) } @@ -645,17 +653,18 @@ func TestTLSAuto(t *testing.T) { cfg.Status.ReportStatus = false cfg.Security.AutoTLS = true cfg.Security.RSAKeySize = 528 // Reduces unittest runtime + tidbserver.RunInGoTestChan = make(chan struct{}) err := os.MkdirAll(cfg.TempStoragePath, 0700) require.NoError(t, err) - server, err := server.NewServer(cfg, ts.Tidbdrv) + server, err := tidbserver.NewServer(cfg, ts.Tidbdrv) require.NoError(t, err) server.SetDomain(ts.Domain) - cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-tidbserver.RunInGoTestChan + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) err = cli.RunTestTLSConnection(t, connOverrider) // Relying on automatically created TLS certificates require.NoError(t, err)