From 1cc8cb4e76734f208083d37cb1d10e392c406a4b Mon Sep 17 00:00:00 2001 From: Darin Peshev Date: Thu, 13 May 2021 13:28:58 -0700 Subject: [PATCH] ccl/sqlproxyccl: CC code migration to DB Previsouly the sql proxy code was in the CC repo. This was making the testing of the proxy against a live SQL server hard and was also requiring a frequent cockroach repo bumps in case of changes. This moves all the code from the CC report to the DB repo so now the proxy is part of the cockroach executable. More detailed list of changed: * The old, sample star-proxy code has been retired in favor of the code moving over from the CC repo. * The code that handles individual connections to the backend has been separated into a new ProxyHandler. Added tests for the proxy handler. * BackendConfig has been retired. * Using stop.Stopper to control the shutdown of the proxy. * Added a command under mt that can be used to run the test directory server. * Added proxy options to control idle timeout, rate limits, config options, use of directory server etc. * Added code to monitor and handle os signals (HUP, TERM, INT). * Intergated the cert manager so the certificates can be reloaded on external signal. * Fixed the SQL tenant process so now the idle timeout causes the stopper to quiesce and the process to terminate successfuly. * Set up the logging for the new proxy. * Added a self-signed cert type to the cert manager to be used when testing secure connections witout generating explicit key/cert files. * Moved the HookGlobal code from CC that can be used for temporary hooks during testing. Here is how to test end to end the proxy, SQL tenant and host server, using the test directory: ``` # Start a host server in insecure mode. Tenants should already have been configured. ./cockroach start-single-node --insecure --log="{sinks: {stderr: {filter: info}}}" # Start a test directory server ./cockroach mt test-directory --port 36257 --log="{sinks: {stderr: {filter: info}}}" # Start a proxy server that uses the directory server ./cockroach mt start-proxy --directory=:36257 --listen-metrics=:8081 --log="{sinks: {stderr: {filter: info}}}" --insecure # Start a SQL client for one of the tenants. ./cockroach sql --url="postgresql://sqladmin@127.0.0.1:46257/dim-dog-2.defaultdb" --insecure ``` Release note: None --- .github/CODEOWNERS | 3 + pkg/BUILD.bazel | 4 +- pkg/ccl/cliccl/BUILD.bazel | 5 - pkg/ccl/cliccl/mtproxy.go | 188 ----- pkg/ccl/sqlproxyccl/BUILD.bazel | 22 +- pkg/ccl/sqlproxyccl/authentication.go | 27 +- pkg/ccl/sqlproxyccl/authentication_test.go | 8 +- pkg/ccl/sqlproxyccl/backend_dialer.go | 37 +- pkg/ccl/sqlproxyccl/error.go | 86 +- pkg/ccl/sqlproxyccl/errorcode_string.go | 46 +- pkg/ccl/sqlproxyccl/frontend_admitter.go | 21 +- pkg/ccl/sqlproxyccl/frontend_admitter_test.go | 16 +- .../{ => idle}/idle_disconnect_connection.go | 26 +- .../idle_disconnect_connection_test.go | 4 +- pkg/ccl/sqlproxyccl/metrics.go | 50 +- pkg/ccl/sqlproxyccl/proxy.go | 257 +----- pkg/ccl/sqlproxyccl/proxy_handler.go | 613 +++++++++++++++ pkg/ccl/sqlproxyccl/proxy_handler_test.go | 735 ++++++++++++++++++ pkg/ccl/sqlproxyccl/proxy_test.go | 671 ---------------- pkg/ccl/sqlproxyccl/server.go | 59 +- pkg/ccl/sqlproxyccl/server_test.go | 16 +- pkg/ccl/sqlproxyccl/tenant/BUILD.bazel | 30 +- .../sqlproxyccl/tenant/test_directory_svr.go | 6 +- pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel | 30 + .../directory_test.go | 42 +- .../{tenant => tenantdirsvr}/main_test.go | 2 +- .../{admitter => throttler}/BUILD.bazel | 19 + .../{admitter => throttler}/local.go | 6 +- .../{admitter => throttler}/local_test.go | 2 +- .../mocks_generated.go | 4 +- .../{admitter => throttler}/service.go | 8 +- pkg/cli/BUILD.bazel | 48 +- pkg/cli/cliflags/flags_mt.go | 70 ++ pkg/cli/context.go | 30 + pkg/cli/flags.go | 22 + pkg/cli/mt.go | 7 +- pkg/cli/mt_proxy.go | 172 ++++ pkg/cli/mt_start_sql.go | 8 +- pkg/cli/mt_test_directory.go | 53 ++ pkg/cli/start.go | 4 +- pkg/security/certmgr/BUILD.bazel | 3 + pkg/security/certmgr/cert.go | 7 +- pkg/security/certmgr/file_cert_test.go | 1 - pkg/security/certmgr/mocks_generated.go | 15 + pkg/security/certmgr/self_signed_cert.go | 103 +++ pkg/security/certmgr/self_signed_cert_test.go | 44 ++ pkg/server/server.go | 10 +- pkg/server/status.go | 2 +- pkg/testutils/BUILD.bazel | 1 + pkg/testutils/hook.go | 24 + .../lint/passes/fmtsafe/functions.go | 2 +- pkg/util/grpcutil/grpc_log.go | 2 +- pkg/util/log/fluent_client_test.go | 2 +- pkg/util/netutil/net.go | 8 +- 54 files changed, 2309 insertions(+), 1372 deletions(-) delete mode 100644 pkg/ccl/cliccl/mtproxy.go rename pkg/ccl/sqlproxyccl/{ => idle}/idle_disconnect_connection.go (74%) rename pkg/ccl/sqlproxyccl/{ => idle}/idle_disconnect_connection_test.go (97%) create mode 100644 pkg/ccl/sqlproxyccl/proxy_handler.go create mode 100644 pkg/ccl/sqlproxyccl/proxy_handler_test.go delete mode 100644 pkg/ccl/sqlproxyccl/proxy_test.go create mode 100644 pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel rename pkg/ccl/sqlproxyccl/{tenant => tenantdirsvr}/directory_test.go (90%) rename pkg/ccl/sqlproxyccl/{tenant => tenantdirsvr}/main_test.go (98%) rename pkg/ccl/sqlproxyccl/{admitter => throttler}/BUILD.bazel (69%) rename pkg/ccl/sqlproxyccl/{admitter => throttler}/local.go (96%) rename pkg/ccl/sqlproxyccl/{admitter => throttler}/local_test.go (99%) rename pkg/ccl/sqlproxyccl/{admitter => throttler}/mocks_generated.go (94%) rename pkg/ccl/sqlproxyccl/{admitter => throttler}/service.go (70%) create mode 100644 pkg/cli/mt_proxy.go create mode 100644 pkg/cli/mt_test_directory.go create mode 100644 pkg/security/certmgr/self_signed_cert.go create mode 100644 pkg/security/certmgr/self_signed_cert_test.go create mode 100644 pkg/testutils/hook.go diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index aa150a844f99..8e3b58d7874d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -45,6 +45,9 @@ /pkg/cli/import_test.go @cockroachdb/cli-prs @cockroachdb/bulk-prs /pkg/cli/sql*.go @cockroachdb/cli-prs @cockroachdb/sql-experience /pkg/cli/start*.go @cockroachdb/cli-prs @cockroachdb/server-prs +/pkg/cli/mt_proxy.go @cockroachdb/sqlproxy-prs @cockroachdb/server-prs +/pkg/cli/mt_start_sql.go @cockroachdb/sqlproxy-prs @cockroachdb/server-prs +/pkg/cli/mt_test_directory.go @cockroachdb/sqlproxy-prs @cockroachdb/server-prs /pkg/cli/connect*.go @cockroachdb/cli-prs @cockroachdb/server-prs /pkg/cli/init.go @cockroachdb/cli-prs @cockroachdb/server-prs diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 3b0d5fc60377..4f12e44b81d7 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -23,10 +23,10 @@ ALL_TESTS = [ "//pkg/ccl/oidcccl:oidcccl_test", "//pkg/ccl/partitionccl:partitionccl_test", "//pkg/ccl/serverccl:serverccl_test", - "//pkg/ccl/sqlproxyccl/admitter:admitter_test", "//pkg/ccl/sqlproxyccl/cache:cache_test", "//pkg/ccl/sqlproxyccl/denylist:denylist_test", - "//pkg/ccl/sqlproxyccl/tenant:tenant_test", + "//pkg/ccl/sqlproxyccl/tenantdirsvr:tenantdirsvr_test", + "//pkg/ccl/sqlproxyccl/throttler:admitter_test", "//pkg/ccl/sqlproxyccl:sqlproxyccl_test", "//pkg/ccl/storageccl/engineccl:engineccl_test", "//pkg/ccl/storageccl:storageccl_test", diff --git a/pkg/ccl/cliccl/BUILD.bazel b/pkg/ccl/cliccl/BUILD.bazel index ac7b4850565f..840b02e1b104 100644 --- a/pkg/ccl/cliccl/BUILD.bazel +++ b/pkg/ccl/cliccl/BUILD.bazel @@ -7,7 +7,6 @@ go_library( "debug.go", "debug_backup.go", "demo.go", - "mtproxy.go", "start.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/ccl/cliccl", @@ -19,7 +18,6 @@ go_library( "//pkg/ccl/backupccl", "//pkg/ccl/baseccl", "//pkg/ccl/cliccl/cliflagsccl", - "//pkg/ccl/sqlproxyccl", "//pkg/ccl/storageccl", "//pkg/ccl/storageccl/engineccl/enginepbccl:enginepbccl_go_proto", "//pkg/ccl/workloadccl/cliccl", @@ -53,12 +51,9 @@ go_library( "//pkg/util/timeutil/pgdate", "//pkg/util/uuid", "@com_github_cockroachdb_apd_v2//:apd", - "@com_github_cockroachdb_cmux//:cmux", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_errors//oserror", - "@com_github_jackc_pgproto3_v2//:pgproto3", "@com_github_spf13_cobra//:cobra", - "@org_golang_x_sync//errgroup", ], ) diff --git a/pkg/ccl/cliccl/mtproxy.go b/pkg/ccl/cliccl/mtproxy.go deleted file mode 100644 index e297f3188035..000000000000 --- a/pkg/ccl/cliccl/mtproxy.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2020 The Cockroach Authors. -// -// Licensed as a CockroachDB Enterprise file under the Cockroach Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt - -package cliccl - -import ( - "context" - "crypto/tls" - "io/ioutil" - "net" - "strings" - - "github.com/cockroachdb/cmux" - "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" - "github.com/cockroachdb/cockroach/pkg/cli" - "github.com/cockroachdb/errors" - "github.com/jackc/pgproto3/v2" - "github.com/spf13/cobra" - "golang.org/x/sync/errgroup" -) - -var sqlProxyListenAddr, sqlProxyTargetAddr string -var sqlProxyListenCert, sqlProxyListenKey string - -func init() { - startSQLProxyCmd := &cobra.Command{ - Use: "start-proxy ", - Short: "start-proxy host:port", - Long: `Starts a SQL proxy for testing purposes. - -This proxy provides very limited functionality. It accepts incoming connections -and relays them to the specified backend server after verifying that at least -one of the following holds: - -1. the supplied database name is prefixed with 'prancing-pony.'; this prefix - will then be removed for the connection to the backend server, and/or -2. the options parameter is 'prancing-pony'. - -Connections to the target address use TLS but do not identify the identity of -the peer, making them susceptible to MITM attacks. -`, - RunE: cli.MaybeDecorateGRPCError(runStartSQLProxy), - Args: cobra.NoArgs, - } - f := startSQLProxyCmd.Flags() - f.StringVarP(&sqlProxyListenCert, "listen-cert", "", "", "Certificate file to use for listener (auto-generate if empty)") - f.StringVarP(&sqlProxyListenKey, "listen-key", "", "", "Private key file to use for listener(auto-generate if empty)") - f.StringVarP(&sqlProxyListenAddr, "listen-addr", "", "127.0.0.1:46257", "Address for incoming connections") - f.StringVarP(&sqlProxyTargetAddr, "target-addr", "", "127.0.0.1:26257", "Address for outgoing connections") - cli.AddMTCommand(startSQLProxyCmd) -} - -func runStartSQLProxy(*cobra.Command, []string) error { - // openssl genrsa -out testserver.key 2048 - // openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt -days 3650 - // ^-- Enter * as Common Name below, rest can be empty. - certBytes := []byte(`-----BEGIN CERTIFICATE----- -MIICpDCCAYwCCQDWdkou+YTT/DANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls -b2NhbGhvc3QwHhcNMjAwNTIwMTQxMjIyWhcNMzAwNTE4MTQxMjIyWjAUMRIwEAYD -VQQDDAlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDG -H6V5TZjppRR61azexRtKLOnftVpO0+CPslHynbWwrJ6sxZIdglUWoCT3a/93tq0h -SaMWIxH+29wDXiICCTbr485h3Sov5Rq7kV/AwcLMOpdqjbN2PRBW95aq8rV3h/Ui -K3hu8OZjeh4DzhxsDWYwLG+1aUHnzpwDVIvXqiiKHVtT3WLDHRNUAuph9o/4Fao0 -m1KAzXvfbnNyMUWTAOUCIX2tlq79rEIAKOySCKDr07TuVrzCKcF5sbXkFXlmyFNl -KbmXRuD3UxghxLMmUZar7eZR84x6R/Rj5Dqyrs3nfl+/30Zk0pNe6naaKO39zqlR -rWQIqwSZrY1HwGGeVJFjAgMBAAEwDQYJKoZIhvcNAQELBQADggEBACluo7vP0kXd -uXD3joPKiMJ0FgZqeDtuSvBPfl0okqPN+bk/Huqu+FgxfChCs+2EcreGFxshjzuv -J58ogFq1YMB4pS4GlqarHE+UdliOobD+OyvX40w9lTJ2wI+v7kI79udFE+tyLIs6 -YkuzFd1nB0Zcf8QFzyPRTVXVpsWid3ZvARDakp4z7klPLnkfVrXo/ivlKqGF+Ymy -vJ/riLR01omTVi6W40cml/H4DAtG/XVsQeFXWpjUv97MWGRVYycmpCleVkK+uC2x -XAEi/UMoPhhJd6HEWG+56IkFFoN4lNtPuyal0vzOJCn70pgQx3yKh61RQcPrJMlD -m9qz1xbrzj8= ------END CERTIFICATE----- -`) - keyBytes := []byte(`-----BEGIN RSA PRIVATE KEY----- -MIIEowIBAAKCAQEAxh+leU2Y6aUUetWs3sUbSizp37VaTtPgj7JR8p21sKyerMWS -HYJVFqAk92v/d7atIUmjFiMR/tvcA14iAgk26+POYd0qL+Uau5FfwMHCzDqXao2z -dj0QVveWqvK1d4f1Iit4bvDmY3oeA84cbA1mMCxvtWlB586cA1SL16ooih1bU91i -wx0TVALqYfaP+BWqNJtSgM17325zcjFFkwDlAiF9rZau/axCACjskgig69O07la8 -winBebG15BV5ZshTZSm5l0bg91MYIcSzJlGWq+3mUfOMekf0Y+Q6sq7N535fv99G -ZNKTXup2mijt/c6pUa1kCKsEma2NR8BhnlSRYwIDAQABAoIBADFpiSKUyNNU2aO9 -EO1KaYD5bKbfmxNX4oTUK33/+WWD19stN0Dm1YPcEvwmUkOwKsPHksYdnwpaGSg5 -3O93DtyMJ1ffCfuB/0XSfvgbGxNGdacciiquFhoqi8g82idioC+SeenpaPxcY4n9 -aLdGLDtNidrL0qUWsXBfMLVr+cpgENPMiri31CGLNpfO1b4icdQjiltEn70To2Al -68Ptar60m/lJzf8QMFSf499/W3b7fLjGFK+Gzump94xAVMd7HhACf42ZpWRPe1Ih -lyHP6D0091cIRhGxIZrhLToSuySpf1A+C/rQYTqzEPv/a3b4Ja6poulpBppwJyDa -roC4KtkCgYEA/h0HRzekvNAg2cOV/Ag1RyE4AmdyCDubBHcPNJi9kI/RsA0RgurO -pr2oET0HWTENgE4e4hYQnlqUvTXYisvtvhigiCkcynpGoMJa5Y4St7J1QKdQtuwY -vcRqOGGSKl73biK79+BIV/6swWCkB+VzoGzKP8dY/XZHsI0FDdnia8UCgYEAx5gz -9qfzfiqOQP/GN6vGIzoldGxCDCHyyEZmvy0FiBlMJK36Qkjtz48eqYEXOUCX8Z5X -gB583iMv72Oz/wmefoIjnUd9uXyMqvhnYxG4vQhU4a83K4q4TPkNd7+sLiNqxIq2 -o2jT6BktOHE5OiICeFGMFOfHtsyV78JMsuzUEwcCgYAXU5LXdsQokPJzCwE5oYdC -gEoj7lsJZm9UeZlrupmsK4eUIZ755ZQSulYzPubtyRL0NDehiWT9JFODCu5Vz2KD -kL8rwJpj+9V/7Fdrux78veUFilZedE3RHbaidlJ0kUMlWQroNi5t5XL2TWjBUM7M -azAlqqcAnVr3WfqcyuN+AQKBgQCsz+xV6I7bMy9dudc+hlyUTZj2V3FMHeyeWM5H -QkzizLxvma7vy0MUDd/HdTzNVk74ZVdvV3ZXwvGS/Klw7TwsXrNFTwvdGKiWs2KY -lVR1XwxXJyTGb2IpSw3NG8iRXhroNw3xKCcpcvsDPo0E90NaN4jo5NG3RSWgpINR -+9mW6wKBgCze3gZB6AU0W/Fluql88ANx/+nqZhrsJyfrv87AkgDtt0o1tOj+emKR -Uuwb2FVdh76ZK0AVd3Jh3KJs4+hr2u9syHaa7UPKXTcZsFWlGwZuu6X5A+0SO0S2 -/ur8gv24YZJvV7OvPhw1SAuYL7MKMsfTW4TEKWTfkZWvm4YfZNmR ------END RSA PRIVATE KEY----- -`) - - if (sqlProxyListenKey == "") != (sqlProxyListenCert == "") { - return errors.New("must specify either both or neither of cert and key") - } - - if sqlProxyListenCert != "" { - var err error - certBytes, err = ioutil.ReadFile(sqlProxyListenCert) - if err != nil { - return err - } - } - if sqlProxyListenKey != "" { - var err error - keyBytes, err = ioutil.ReadFile(sqlProxyListenKey) - if err != nil { - return err - } - } - - cer, err := tls.X509KeyPair(certBytes, keyBytes) - if err != nil { - return err - } - - ln, err := net.Listen("tcp", sqlProxyListenAddr) - if err != nil { - return err - } - defer func() { _ = ln.Close() }() - - // Multiplex the listen address to give easy access to metrics from this - // command. - mux := cmux.New(ln) - httpLn := mux.Match(cmux.HTTP1Fast()) - proxyLn := mux.Match(cmux.Any()) - - outgoingConf := &tls.Config{ - InsecureSkipVerify: true, - } - server := sqlproxyccl.NewServer(sqlproxyccl.Options{ - FrontendAdmitter: func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return sqlproxyccl.FrontendAdmit( - incoming, - &tls.Config{ - Certificates: []tls.Certificate{cer}, - }, - ) - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - params := msg.Parameters - const magic = "prancing-pony" - if strings.HasPrefix(params["database"], magic+".") { - params["database"] = params["database"][len(magic)+1:] - } else if params["options"] != "--cluster="+magic { - return nil, errors.Errorf("client failed to pass '%s' via database or options", magic) - } - conn, err := sqlproxyccl.BackendDial(msg, sqlProxyTargetAddr, outgoingConf) - if err != nil { - return nil, err - } - return conn, nil - }, - }) - - group, ctx := errgroup.WithContext(context.Background()) - - group.Go(func() error { - return server.ServeHTTP(ctx, httpLn) - }) - - group.Go(func() error { - return server.Serve(proxyLn) - }) - - group.Go(func() error { - return mux.Serve() - }) - - return group.Wait() -} diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 4c229fbc6687..03af6c4ea7f9 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -11,20 +11,32 @@ go_library( "idle_disconnect_connection.go", "metrics.go", "proxy.go", + "proxy_handler.go", "server.go", ":gen-errorcode-stringer", # keep ], importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl", visibility = ["//visibility:public"], deps = [ + "//pkg/ccl/sqlproxyccl/cache", + "//pkg/ccl/sqlproxyccl/denylist", + "//pkg/ccl/sqlproxyccl/tenant", + "//pkg/ccl/sqlproxyccl/throttler", + "//pkg/roachpb", + "//pkg/security/certmgr", + "//pkg/util", "//pkg/util/contextutil", "//pkg/util/httputil", "//pkg/util/log", "//pkg/util/metric", + "//pkg/util/retry", + "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_logtags//:logtags", "@com_github_jackc_pgproto3_v2//:pgproto3", + "@org_golang_google_grpc//:go_default_library", ], ) @@ -36,7 +48,7 @@ go_test( "frontend_admitter_test.go", "idle_disconnect_connection_test.go", "main_test.go", - "proxy_test.go", + "proxy_handler_test.go", "server_test.go", ], data = [ @@ -48,16 +60,24 @@ go_test( tags = ["broken_in_bazel"], deps = [ "//pkg/base", + "//pkg/ccl/kvccl/kvtenantccl", + "//pkg/ccl/sqlproxyccl/tenant", "//pkg/ccl/utilccl", + "//pkg/roachpb", "//pkg/security", "//pkg/security/securitytest", "//pkg/server", + "//pkg/sql", + "//pkg/sql/pgwire", "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", "//pkg/util/leaktest", + "//pkg/util/log", "//pkg/util/randutil", + "//pkg/util/stop", + "//pkg/util/syncutil", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", "@com_github_jackc_pgconn//:pgconn", diff --git a/pkg/ccl/sqlproxyccl/authentication.go b/pkg/ccl/sqlproxyccl/authentication.go index 4f09a93e5c69..bee246503b47 100644 --- a/pkg/ccl/sqlproxyccl/authentication.go +++ b/pkg/ccl/sqlproxyccl/authentication.go @@ -14,7 +14,10 @@ import ( "github.com/jackc/pgproto3/v2" ) -func authenticate(clientConn, crdbConn net.Conn) error { +// authenticate handles the startup of the pgwire protocol to the point where +// the connections is considered authenticated. If that doesn't happen, it +// returns an error. +var authenticate = func(clientConn, crdbConn net.Conn) error { fe := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn) be := pgproto3.NewFrontend(pgproto3.NewChunkReader(crdbConn), crdbConn) @@ -26,13 +29,13 @@ func authenticate(clientConn, crdbConn net.Conn) error { // TODO(spaskob): in verbose mode, log these messages. backendMsg, err := be.Receive() if err != nil { - return NewErrorf(CodeBackendReadFailed, "unable to receive message from backend: %v", err) + return newErrorf(codeBackendReadFailed, "unable to receive message from backend: %v", err) } err = fe.Send(backendMsg) if err != nil { - return NewErrorf( - CodeClientWriteFailed, "unable to send message %v to client: %v", backendMsg, err, + return newErrorf( + codeClientWriteFailed, "unable to send message %v to client: %v", backendMsg, err, ) } @@ -40,12 +43,12 @@ func authenticate(clientConn, crdbConn net.Conn) error { switch tp := backendMsg.(type) { case *pgproto3.ReadyForQuery: // Server has authenticated the connection successfully and is ready to - // serve queries. + // Serve queries. return nil case *pgproto3.AuthenticationOk: // Server has authenticated the connection; keep reading messages until // `pgproto3.ReadyForQuery` is encountered which signifies that server - // is ready to serve queries. + // is ready to Serve queries. case *pgproto3.ParameterStatus: // Server sent status message; keep reading messages until // `pgproto3.ReadyForQuery` is encountered. @@ -55,7 +58,7 @@ func authenticate(clientConn, crdbConn net.Conn) error { case *pgproto3.ErrorResponse: // Server has rejected the authentication response from the client and // has closed the connection. - return NewErrorf(CodeAuthFailed, "authentication failed: %v", backendMsg) + return newErrorf(codeAuthFailed, "authentication failed: %v", backendMsg) case *pgproto3.AuthenticationCleartextPassword, *pgproto3.AuthenticationMD5Password, @@ -64,17 +67,17 @@ func authenticate(clientConn, crdbConn net.Conn) error { // Read the client response and forward it to server. fntMsg, err := fe.Receive() if err != nil { - return NewErrorf(CodeClientReadFailed, "unable to receive message from client: %v", err) + return newErrorf(codeClientReadFailed, "unable to receive message from client: %v", err) } err = be.Send(fntMsg) if err != nil { - return NewErrorf( - CodeBackendWriteFailed, "unable to send message %v to backend: %v", fntMsg, err, + return newErrorf( + codeBackendWriteFailed, "unable to send message %v to backend: %v", fntMsg, err, ) } default: - return NewErrorf(CodeBackendDisconnected, "received unexpected backend message type: %v", tp) + return newErrorf(codeBackendDisconnected, "received unexpected backend message type: %v", tp) } } - return NewErrorf(CodeBackendDisconnected, "authentication took more than %d iterations", i) + return newErrorf(codeBackendDisconnected, "authentication took more than %d iterations", i) } diff --git a/pkg/ccl/sqlproxyccl/authentication_test.go b/pkg/ccl/sqlproxyccl/authentication_test.go index 1719f5ec9297..12a1250b66a1 100644 --- a/pkg/ccl/sqlproxyccl/authentication_test.go +++ b/pkg/ccl/sqlproxyccl/authentication_test.go @@ -95,9 +95,9 @@ func TestAuthenticateError(t *testing.T) { err := authenticate(srv, cli) require.Error(t, err) - codeErr := (*CodeError)(nil) + codeErr := (*codeError)(nil) require.True(t, errors.As(err, &codeErr)) - require.Equal(t, CodeAuthFailed, codeErr.code) + require.Equal(t, codeAuthFailed, codeErr.code) } func TestAuthenticateUnexpectedMessage(t *testing.T) { @@ -117,7 +117,7 @@ func TestAuthenticateUnexpectedMessage(t *testing.T) { err := authenticate(srv, cli) require.Error(t, err) - codeErr := (*CodeError)(nil) + codeErr := (*codeError)(nil) require.True(t, errors.As(err, &codeErr)) - require.Equal(t, CodeBackendDisconnected, codeErr.code) + require.Equal(t, codeBackendDisconnected, codeErr.code) } diff --git a/pkg/ccl/sqlproxyccl/backend_dialer.go b/pkg/ccl/sqlproxyccl/backend_dialer.go index 0e3aff5e4cbb..70b37b760a47 100644 --- a/pkg/ccl/sqlproxyccl/backend_dialer.go +++ b/pkg/ccl/sqlproxyccl/backend_dialer.go @@ -17,33 +17,34 @@ import ( "github.com/jackc/pgproto3/v2" ) -// BackendDial is an example backend dialer that does a TCP/IP connection -// to a backend, SSL and forwards the start message. -func BackendDial( +// backendDial is an example backend dialer that does a TCP/IP connection +// to a backend, SSL and forwards the start message. It is defined as a variable +// so it can be redirected for testing. +var backendDial = func( msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, ) (net.Conn, error) { conn, err := net.Dial("tcp", outgoingAddress) if err != nil { - return nil, NewErrorf( - CodeBackendDown, "unable to reach backend SQL server: %v", err, + return nil, newErrorf( + codeBackendDown, "unable to reach backend SQL server: %v", err, ) } - conn, err = SSLOverlay(conn, tlsConfig) + conn, err = sslOverlay(conn, tlsConfig) if err != nil { return nil, err } - err = RelayStartupMsg(conn, msg) + err = relayStartupMsg(conn, msg) if err != nil { - return nil, NewErrorf( - CodeBackendDown, "relaying StartupMessage to target server %v: %v", + return nil, newErrorf( + codeBackendDown, "relaying StartupMessage to target server %v: %v", outgoingAddress, err) } return conn, nil } -// SSLOverlay attempts to upgrade the PG connection to use SSL +// sslOverlay attempts to upgrade the PG connection to use SSL // if a tls.Config is specified.. -func SSLOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { +func sslOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { if tlsConfig == nil { return conn, nil } @@ -51,20 +52,20 @@ func SSLOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { var err error // Send SSLRequest. if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil { - return nil, NewErrorf( - CodeBackendDown, "sending SSLRequest to target server: %v", err, + return nil, newErrorf( + codeBackendDown, "sending SSLRequest to target server: %v", err, ) } response := make([]byte, 1) if _, err = io.ReadFull(conn, response); err != nil { return nil, - NewErrorf(CodeBackendDown, "reading response to SSLRequest") + newErrorf(codeBackendDown, "reading response to SSLRequest") } if response[0] != pgAcceptSSLRequest { - return nil, NewErrorf( - CodeBackendRefusedTLS, "target server refused TLS connection", + return nil, newErrorf( + codeBackendRefusedTLS, "target server refused TLS connection", ) } @@ -72,8 +73,8 @@ func SSLOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { return tls.Client(conn, outCfg), nil } -// RelayStartupMsg forwards the start message on the backend connection. -func RelayStartupMsg(conn net.Conn, msg *pgproto3.StartupMessage) (err error) { +// relayStartupMsg forwards the start message on the backend connection. +func relayStartupMsg(conn net.Conn, msg *pgproto3.StartupMessage) (err error) { _, err = conn.Write(msg.Encode(nil)) return } diff --git a/pkg/ccl/sqlproxyccl/error.go b/pkg/ccl/sqlproxyccl/error.go index e6dcb3dd6e23..7df35e255324 100644 --- a/pkg/ccl/sqlproxyccl/error.go +++ b/pkg/ccl/sqlproxyccl/error.go @@ -14,88 +14,88 @@ import ( "github.com/cockroachdb/errors" ) -// ErrorCode classifies errors emitted by Proxy(). -//go:generate stringer -type=ErrorCode -type ErrorCode int +// errorCode classifies errors emitted by Proxy(). +//go:generate stringer -type=errorCode +type errorCode int const ( - _ ErrorCode = iota + _ errorCode = iota - // CodeAuthFailed indicates that client authentication attempt has failed and + // codeAuthFailed indicates that client authentication attempt has failed and // backend has closed the connection. - CodeAuthFailed + codeAuthFailed - // CodeBackendReadFailed indicates an error reading from backend connection. - CodeBackendReadFailed - // CodeBackendWriteFailed indicates an error writing to backend connection. - CodeBackendWriteFailed + // codeBackendReadFailed indicates an error reading from backend connection. + codeBackendReadFailed + // codeBackendWriteFailed indicates an error writing to backend connection. + codeBackendWriteFailed - // CodeClientReadFailed indicates an error reading from the client connection - CodeClientReadFailed - // CodeClientWriteFailed indicates an error writing to the client connection. - CodeClientWriteFailed + // codeClientReadFailed indicates an error reading from the client connection + codeClientReadFailed + // codeClientWriteFailed indicates an error writing to the client connection. + codeClientWriteFailed - // CodeUnexpectedInsecureStartupMessage indicates that the client sent a + // codeUnexpectedInsecureStartupMessage indicates that the client sent a // StartupMessage which was unexpected. Typically this means that an // SSLRequest was expected but the client attempted to go ahead without TLS, // or vice versa. - CodeUnexpectedInsecureStartupMessage + codeUnexpectedInsecureStartupMessage - // CodeSNIRoutingFailed indicates an error choosing a backend address based on + // codeSNIRoutingFailed indicates an error choosing a backend address based on // the client's SNI header. - CodeSNIRoutingFailed + codeSNIRoutingFailed - // CodeUnexpectedStartupMessage indicates an unexpected startup message + // codeUnexpectedStartupMessage indicates an unexpected startup message // received from the client after TLS negotiation. - CodeUnexpectedStartupMessage + codeUnexpectedStartupMessage - // CodeParamsRoutingFailed indicates an error choosing a backend address based + // codeParamsRoutingFailed indicates an error choosing a backend address based // on the client's session parameters. - CodeParamsRoutingFailed + codeParamsRoutingFailed - // CodeBackendDown indicates an error establishing or maintaining a connection + // codeBackendDown indicates an error establishing or maintaining a connection // to the backend SQL server. - CodeBackendDown + codeBackendDown - // CodeBackendRefusedTLS indicates that the backend SQL server refused a TLS- + // codeBackendRefusedTLS indicates that the backend SQL server refused a TLS- // enabled SQL connection. - CodeBackendRefusedTLS + codeBackendRefusedTLS - // CodeBackendDisconnected indicates that the backend disconnected (with a + // codeBackendDisconnected indicates that the backend disconnected (with a // connection error) while serving client traffic. - CodeBackendDisconnected + codeBackendDisconnected - // CodeClientDisconnected indicates that the client disconnected unexpectedly + // codeClientDisconnected indicates that the client disconnected unexpectedly // (with a connection error) while in a session with backend SQL server. - CodeClientDisconnected + codeClientDisconnected - // CodeProxyRefusedConnection indicates that the proxy refused the connection + // codeProxyRefusedConnection indicates that the proxy refused the connection // request due to high load or too many connection attempts. - CodeProxyRefusedConnection + codeProxyRefusedConnection - // CodeExpiredClientConnection indicates that proxy connection to the client + // codeExpiredClientConnection indicates that proxy connection to the client // has expired and should be closed. - CodeExpiredClientConnection + codeExpiredClientConnection - // CodeIdleDisconnect indicates that the connection was disconnected for + // codeIdleDisconnect indicates that the connection was disconnected for // being idle for longer than the specified timeout. - CodeIdleDisconnect + codeIdleDisconnect ) -// CodeError is combines an error with one of the above codes to ease +// codeError is combines an error with one of the above codes to ease // the processing of the errors. -type CodeError struct { - code ErrorCode +type codeError struct { + code errorCode err error } -func (e *CodeError) Error() string { +func (e *codeError) Error() string { return fmt.Sprintf("%s: %s", e.code, e.err) } -// NewErrorf returns a new CodeError out of the supplied args. -func NewErrorf(code ErrorCode, format string, args ...interface{}) error { - return &CodeError{ +// newErrorf returns a new codeError out of the supplied args. +func newErrorf(code errorCode, format string, args ...interface{}) error { + return &codeError{ code: code, err: errors.Errorf(format, args...), } diff --git a/pkg/ccl/sqlproxyccl/errorcode_string.go b/pkg/ccl/sqlproxyccl/errorcode_string.go index de1432ced7e4..04ecb4a68bf4 100644 --- a/pkg/ccl/sqlproxyccl/errorcode_string.go +++ b/pkg/ccl/sqlproxyccl/errorcode_string.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT. +// Code generated by "stringer -type=errorCode"; DO NOT EDIT. package sqlproxyccl @@ -8,32 +8,32 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[CodeAuthFailed-1] - _ = x[CodeBackendReadFailed-2] - _ = x[CodeBackendWriteFailed-3] - _ = x[CodeClientReadFailed-4] - _ = x[CodeClientWriteFailed-5] - _ = x[CodeUnexpectedInsecureStartupMessage-6] - _ = x[CodeSNIRoutingFailed-7] - _ = x[CodeUnexpectedStartupMessage-8] - _ = x[CodeParamsRoutingFailed-9] - _ = x[CodeBackendDown-10] - _ = x[CodeBackendRefusedTLS-11] - _ = x[CodeBackendDisconnected-12] - _ = x[CodeClientDisconnected-13] - _ = x[CodeProxyRefusedConnection-14] - _ = x[CodeExpiredClientConnection-15] - _ = x[CodeIdleDisconnect-16] + _ = x[codeAuthFailed-1] + _ = x[codeBackendReadFailed-2] + _ = x[codeBackendWriteFailed-3] + _ = x[codeClientReadFailed-4] + _ = x[codeClientWriteFailed-5] + _ = x[codeUnexpectedInsecureStartupMessage-6] + _ = x[codeSNIRoutingFailed-7] + _ = x[codeUnexpectedStartupMessage-8] + _ = x[codeParamsRoutingFailed-9] + _ = x[codeBackendDown-10] + _ = x[codeBackendRefusedTLS-11] + _ = x[codeBackendDisconnected-12] + _ = x[codeClientDisconnected-13] + _ = x[codeProxyRefusedConnection-14] + _ = x[codeExpiredClientConnection-15] + _ = x[codeIdleDisconnect-16] } -const _ErrorCode_name = "CodeAuthFailedCodeBackendReadFailedCodeBackendWriteFailedCodeClientReadFailedCodeClientWriteFailedCodeUnexpectedInsecureStartupMessageCodeSNIRoutingFailedCodeUnexpectedStartupMessageCodeParamsRoutingFailedCodeBackendDownCodeBackendRefusedTLSCodeBackendDisconnectedCodeClientDisconnectedCodeProxyRefusedConnectionCodeExpiredClientConnectionCodeIdleDisconnect" +const _errorCode_name = "codeAuthFailedcodeBackendReadFailedcodeBackendWriteFailedcodeClientReadFailedcodeClientWriteFailedcodeUnexpectedInsecureStartupMessagecodeSNIRoutingFailedcodeUnexpectedStartupMessagecodeParamsRoutingFailedcodeBackendDowncodeBackendRefusedTLScodeBackendDisconnectedcodeClientDisconnectedcodeProxyRefusedConnectioncodeExpiredClientConnectioncodeIdleDisconnect" -var _ErrorCode_index = [...]uint16{0, 14, 35, 57, 77, 98, 134, 154, 182, 205, 220, 241, 264, 286, 312, 339, 357} +var _errorCode_index = [...]uint16{0, 14, 35, 57, 77, 98, 134, 154, 182, 205, 220, 241, 264, 286, 312, 339, 357} -func (i ErrorCode) String() string { +func (i errorCode) String() string { i -= 1 - if i < 0 || i >= ErrorCode(len(_ErrorCode_index)-1) { - return "ErrorCode(" + strconv.FormatInt(int64(i+1), 10) + ")" + if i < 0 || i >= errorCode(len(_errorCode_index)-1) { + return "errorCode(" + strconv.FormatInt(int64(i+1), 10) + ")" } - return _ErrorCode_name[_ErrorCode_index[i]:_ErrorCode_index[i+1]] + return _errorCode_name[_errorCode_index[i]:_errorCode_index[i+1]] } diff --git a/pkg/ccl/sqlproxyccl/frontend_admitter.go b/pkg/ccl/sqlproxyccl/frontend_admitter.go index 2f523f8231fa..7811b1d5f323 100644 --- a/pkg/ccl/sqlproxyccl/frontend_admitter.go +++ b/pkg/ccl/sqlproxyccl/frontend_admitter.go @@ -15,10 +15,13 @@ import ( "github.com/jackc/pgproto3/v2" ) -// FrontendAdmit is the default implementation of a frontend admitter. It can +// frontendAdmit is the default implementation of a frontend admitter. It can // upgrade to an optional SSL connection, and will handle and verify // the startup message received from the PG SQL client. -func FrontendAdmit( +// The connection returned should never be nil in case of error. Depending +// on whether the error happened before the connection was upgraded to TLS or not +// it will either be the original or the TLS connection. +var frontendAdmit = func( conn net.Conn, incomingTLSConfig *tls.Config, ) (net.Conn, *pgproto3.StartupMessage, error) { // `conn` could be replaced by `conn` embedded in a `tls.Conn` connection, @@ -30,7 +33,7 @@ func FrontendAdmit( if incomingTLSConfig != nil { m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage() if err != nil { - return nil, nil, NewErrorf(CodeClientReadFailed, "while receiving startup message") + return conn, nil, newErrorf(codeClientReadFailed, "while receiving startup message") } switch m.(type) { case *pgproto3.SSLRequest: @@ -38,15 +41,15 @@ func FrontendAdmit( // Ignore CancelRequest explicitly. We don't need to do this but it makes // testing easier by avoiding a call to sendErrToClient on this path // (which would confuse assertCtx). - return nil, nil, nil + return conn, nil, nil default: - code := CodeUnexpectedInsecureStartupMessage - return nil, nil, NewErrorf(code, "unsupported startup message: %T", m) + code := codeUnexpectedInsecureStartupMessage + return conn, nil, newErrorf(code, "unsupported startup message: %T", m) } _, err = conn.Write([]byte{pgAcceptSSLRequest}) if err != nil { - return nil, nil, NewErrorf(CodeClientWriteFailed, "acking SSLRequest: %v", err) + return conn, nil, newErrorf(codeClientWriteFailed, "acking SSLRequest: %v", err) } cfg := incomingTLSConfig.Clone() @@ -60,11 +63,11 @@ func FrontendAdmit( m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage() if err != nil { - return nil, nil, NewErrorf(CodeClientReadFailed, "receiving post-TLS startup message: %v", err) + return conn, nil, newErrorf(codeClientReadFailed, "receiving post-TLS startup message: %v", err) } msg, ok := m.(*pgproto3.StartupMessage) if !ok { - return nil, nil, NewErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m) + return conn, nil, newErrorf(codeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m) } // Add the sniServerName (if used) as parameter diff --git a/pkg/ccl/sqlproxyccl/frontend_admitter_test.go b/pkg/ccl/sqlproxyccl/frontend_admitter_test.go index e4f3b50e6c8b..5d6fc04f2418 100644 --- a/pkg/ccl/sqlproxyccl/frontend_admitter_test.go +++ b/pkg/ccl/sqlproxyccl/frontend_admitter_test.go @@ -58,7 +58,7 @@ func TestFrontendAdmitWithClientSSLDisableAndCustomParam(t *testing.T) { fmt.Printf("Done\n") }() - frontendCon, msg, err := FrontendAdmit(srv, nil) + frontendCon, msg, err := frontendAdmit(srv, nil) require.NoError(t, err) require.Equal(t, srv, frontendCon) require.NotNil(t, msg) @@ -88,7 +88,7 @@ func TestFrontendAdmitWithClientSSLRequire(t *testing.T) { tlsConfig, err := tlsConfig() require.NoError(t, err) - frontendCon, msg, err := FrontendAdmit(srv, tlsConfig) + frontendCon, msg, err := frontendAdmit(srv, tlsConfig) require.NoError(t, err) require.NotEqual(t, srv, frontendCon) // The connection was replaced by SSL require.NotNil(t, msg) @@ -107,12 +107,12 @@ func TestFrontendAdmitWithCancel(t *testing.T) { require.NoError(t, err) }() - frontendCon, msg, err := FrontendAdmit(srv, nil) + frontendCon, msg, err := frontendAdmit(srv, nil) require.EqualError(t, err, - "CodeUnexpectedStartupMessage: "+ + "codeUnexpectedStartupMessage: "+ "unsupported post-TLS startup message: *pgproto3.CancelRequest", ) - require.Nil(t, frontendCon) + require.NotNil(t, frontendCon) require.Nil(t, msg) } @@ -139,11 +139,11 @@ func TestFrontendAdmitWithSSLAndCancel(t *testing.T) { tlsConfig, err := tlsConfig() require.NoError(t, err) - frontendCon, msg, err := FrontendAdmit(srv, tlsConfig) + frontendCon, msg, err := frontendAdmit(srv, tlsConfig) require.EqualError(t, err, - "CodeUnexpectedStartupMessage: "+ + "codeUnexpectedStartupMessage: "+ "unsupported post-TLS startup message: *pgproto3.CancelRequest", ) - require.Nil(t, frontendCon) + require.NotNil(t, frontendCon) require.Nil(t, msg) } diff --git a/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go b/pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection.go similarity index 74% rename from pkg/ccl/sqlproxyccl/idle_disconnect_connection.go rename to pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection.go index bc5fd1be9e41..711bad21c585 100644 --- a/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go +++ b/pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package sqlproxyccl +package idle import ( "net" @@ -17,10 +17,10 @@ import ( "github.com/cockroachdb/errors" ) -// IdleDisconnectConnection is a wrapper around net.Conn that disconnects if +// DisconnectConnection is a wrapper around net.Conn that disconnects if // connection is idle. The idle time is only counted while the client is // waiting, blocked on Read. -type IdleDisconnectConnection struct { +type DisconnectConnection struct { net.Conn timeout time.Duration mu struct { @@ -30,10 +30,10 @@ type IdleDisconnectConnection struct { } var errNotSupported = errors.Errorf( - "Not supported for IdleDisconnectConnection", + "Not supported for DisconnectConnection", ) -func (c *IdleDisconnectConnection) updateDeadline() error { +func (c *DisconnectConnection) updateDeadline() error { now := timeutil.Now() // If it has been more than 1% of the timeout duration - advance the deadline. c.mu.Lock() @@ -49,7 +49,7 @@ func (c *IdleDisconnectConnection) updateDeadline() error { } // Read reads data from the connection with timeout. -func (c *IdleDisconnectConnection) Read(b []byte) (n int, err error) { +func (c *DisconnectConnection) Read(b []byte) (n int, err error) { if err := c.updateDeadline(); err != nil { return 0, err } @@ -57,7 +57,7 @@ func (c *IdleDisconnectConnection) Read(b []byte) (n int, err error) { } // Write writes data to the connection and sets the read timeout. -func (c *IdleDisconnectConnection) Write(b []byte) (n int, err error) { +func (c *DisconnectConnection) Write(b []byte) (n int, err error) { // The Write for the connection is not blocking (or can block only temporary // in case of flow control). For idle connections, the Read will be the call // that will block and stay blocked until the backend doesn't send something. @@ -72,26 +72,26 @@ func (c *IdleDisconnectConnection) Write(b []byte) (n int, err error) { } // SetDeadline is unsupported as it will interfere with the reads. -func (c *IdleDisconnectConnection) SetDeadline(t time.Time) error { +func (c *DisconnectConnection) SetDeadline(t time.Time) error { return errNotSupported } // SetReadDeadline is unsupported as it will interfere with the reads. -func (c *IdleDisconnectConnection) SetReadDeadline(t time.Time) error { +func (c *DisconnectConnection) SetReadDeadline(t time.Time) error { return errNotSupported } // SetWriteDeadline is unsupported as it will interfere with the reads. -func (c *IdleDisconnectConnection) SetWriteDeadline(t time.Time) error { +func (c *DisconnectConnection) SetWriteDeadline(t time.Time) error { return errNotSupported } -// IdleDisconnectOverlay upgrades the connection to one that closes when +// DisconnectOverlay upgrades the connection to one that closes when // idle for more than timeout duration. Timeout of zero will turn off // the idle disconnect code. -func IdleDisconnectOverlay(conn net.Conn, timeout time.Duration) net.Conn { +func DisconnectOverlay(conn net.Conn, timeout time.Duration) net.Conn { if timeout != 0 { - return &IdleDisconnectConnection{Conn: conn, timeout: timeout} + return &DisconnectConnection{Conn: conn, timeout: timeout} } return conn } diff --git a/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go b/pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection_test.go similarity index 97% rename from pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go rename to pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection_test.go index 002901450914..edb074dbd7cb 100644 --- a/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go +++ b/pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package sqlproxyccl +package idle import ( "fmt" @@ -31,7 +31,7 @@ func setupServerWithIdleDisconnect(t testing.TB, timeout time.Duration) net.Addr t.Errorf("Error during accept: %v", err) } defer cServ.Close() - cServ = IdleDisconnectOverlay(cServ, timeout) + cServ = DisconnectOverlay(cServ, timeout) _, _ = io.Copy(cServ, cServ) }() return server.Addr() diff --git a/pkg/ccl/sqlproxyccl/metrics.go b/pkg/ccl/sqlproxyccl/metrics.go index c643a7b4d81a..bea2ff15d3de 100644 --- a/pkg/ccl/sqlproxyccl/metrics.go +++ b/pkg/ccl/sqlproxyccl/metrics.go @@ -8,11 +8,14 @@ package sqlproxyccl -import "github.com/cockroachdb/cockroach/pkg/util/metric" +import ( + "github.com/cockroachdb/cockroach/pkg/util/metric" + "github.com/cockroachdb/errors" +) -// Metrics contains pointers to the metrics for monitoring proxy +// metrics contains pointers to the metrics for monitoring proxy // operations. -type Metrics struct { +type metrics struct { BackendDisconnectCount *metric.Counter IdleDisconnectCount *metric.Counter BackendDownCount *metric.Counter @@ -26,9 +29,9 @@ type Metrics struct { } // MetricStruct implements the metrics.Struct interface. -func (Metrics) MetricStruct() {} +func (metrics) MetricStruct() {} -var _ metric.Struct = Metrics{} +var _ metric.Struct = metrics{} var ( metaCurConnCount = metric.Metadata{ @@ -93,9 +96,9 @@ var ( } ) -// MakeProxyMetrics instantiates the metrics holder for proxy monitoring. -func MakeProxyMetrics() Metrics { - return Metrics{ +// makeProxyMetrics instantiates the metrics holder for proxy monitoring. +func makeProxyMetrics() metrics { + return metrics{ BackendDisconnectCount: metric.NewCounter(metaBackendDisconnectCount), IdleDisconnectCount: metric.NewCounter(metaIdleDisconnectCount), BackendDownCount: metric.NewCounter(metaBackendDownCount), @@ -108,3 +111,34 @@ func MakeProxyMetrics() Metrics { ExpiredClientConnCount: metric.NewCounter(metaExpiredClientConnCount), } } + +// updateForError updates the metrics relevant for the type of the +// error message. +func (metrics *metrics) updateForError(err error) { + if err == nil { + return + } + codeErr := (*codeError)(nil) + if errors.As(err, &codeErr) { + switch codeErr.code { + case codeExpiredClientConnection: + metrics.ExpiredClientConnCount.Inc(1) + case codeBackendDisconnected: + metrics.BackendDisconnectCount.Inc(1) + case codeClientDisconnected: + metrics.ClientDisconnectCount.Inc(1) + case codeIdleDisconnect: + metrics.IdleDisconnectCount.Inc(1) + case codeProxyRefusedConnection: + metrics.RefusedConnCount.Inc(1) + metrics.BackendDownCount.Inc(1) + case codeParamsRoutingFailed: + metrics.RoutingErrCount.Inc(1) + metrics.BackendDownCount.Inc(1) + case codeBackendDown: + metrics.BackendDownCount.Inc(1) + case codeAuthFailed: + metrics.AuthFailedCount.Inc(1) + } + } +} diff --git a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index 2ffc4dff17b2..cdd275b73e61 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -9,11 +9,8 @@ package sqlproxyccl import ( - "context" - "crypto/tls" "io" "net" - "os" "github.com/cockroachdb/errors" "github.com/jackc/pgproto3/v2" @@ -24,76 +21,42 @@ const pgAcceptSSLRequest = 'S' // See https://www.postgresql.org/docs/9.1/protocol-message-formats.html. var pgSSLRequest = []int32{8, 80877103} -// BackendConfig contains the configuration of a backend connection that is -// being proxied. -// To be removed once all clients are migrated to use backend dialer. -type BackendConfig struct { - // The address to which the connection is forwarded. - OutgoingAddress string - // TLS settings to use when connecting to OutgoingAddress. - TLSConf *tls.Config - // Called after successfully connecting to OutgoingAddr. - OnConnectionSuccess func() - // KeepAliveLoop if provided controls the lifetime of the proxy connection. - // It will be run in its own goroutine when the connection is successfully - // opened. Returning from `KeepAliveLoop` will close the proxy connection. - // Note that non-nil error return values will be forwarded to the user and - // hence should explain the reason for terminating the connection. - // Most common use of KeepAliveLoop will be as an infinite loop that - // periodically checks if the connection should still be kept alive. Hence it - // may block indefinitely so it's prudent to use the provided context and - // return on context cancellation. - // See `TestProxyKeepAlive` for example usage. - KeepAliveLoop func(context.Context) error +// sendErrToClientAndUpdateMetrics simply combines the update of the metrics and +// the transmission of the err back to the client. +func updateMetricsAndSendErrToClient(err error, conn net.Conn, metrics *metrics) { + metrics.updateForError(err) + sendErrToClient(conn, err) } -// Options are the options to the Proxy method. -type Options struct { - // Deprecated: construct FrontendAdmitter, passing this information in case - // that SSL is desired. - IncomingTLSConfig *tls.Config // config used for client -> proxy connection - - // BackendFromParams returns the config to use for the proxy -> backend - // connection. The TLS config is in it and it must have an appropriate - // ServerName for the remote backend. - // Deprecated: processing of the params now happens in the BackendDialer. - // This is only here to support OnSuccess and KeepAlive. - BackendConfigFromParams func( - params map[string]string, incomingConn *Conn, - ) (config *BackendConfig, clientErr error) - - // If set, consulted to modify the parameters set by the frontend before - // forwarding them to the backend during startup. - // Deprecated: include the code that modifies the request params - // in the backend dialer. - ModifyRequestParams func(map[string]string) - - // If set, consulted to decorate an error message to be sent to the client. - // The error passed to this method will contain no internal information. - OnSendErrToClient func(code ErrorCode, msg string) string - - // If set, will be called immediately after a new incoming connection - // is accepted. It can optionally negotiate SSL, provide admittance control or - // other types of frontend connection filtering. - FrontendAdmitter func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) - - // If set, will be used to establish and return connection to the backend. - // If not set, the old logic will be used. - // The argument is the startup message received from the frontend. It - // contains the protocol version and params sent by the client. - BackendDialer func(msg *pgproto3.StartupMessage) (net.Conn, error) -} - -// Proxy takes an incoming client connection and relays it to a backend SQL -// server. -func (s *Server) Proxy(proxyConn *Conn) error { - sendErrToClient := func(conn net.Conn, code ErrorCode, msg string) { - if s.opts.OnSendErrToClient != nil { - msg = s.opts.OnSendErrToClient(code, msg) +// sendErrToClient will encode and pass back to the SQL client an error message. +// It can be called by the implementors of proxyHandler to give more +// information to the end user in case of a problem. +var sendErrToClient = func(conn net.Conn, err error) { + if err == nil || conn == nil { + return + } + codeErr := (*codeError)(nil) + if errors.As(err, &codeErr) { + var msg string + switch codeErr.code { + // These are send as is. + case codeExpiredClientConnection, + codeBackendDown, + codeParamsRoutingFailed, + codeClientDisconnected, + codeBackendDisconnected, + codeAuthFailed, + codeProxyRefusedConnection: + msg = codeErr.Error() + // The rest - the message send back is sanitized. + case codeIdleDisconnect: + msg = "terminating connection due to idle timeout" + case codeUnexpectedInsecureStartupMessage: + msg = "server requires encryption" } var pgCode string - if code == CodeIdleDisconnect { + if codeErr.code == codeIdleDisconnect { pgCode = "57P01" // admin shutdown } else { pgCode = "08004" // rejected connection @@ -104,135 +67,13 @@ func (s *Server) Proxy(proxyConn *Conn) error { Message: msg, }).Encode(nil)) } +} - frontendAdmitter := s.opts.FrontendAdmitter - if frontendAdmitter == nil { - // Keep this until all clients are switched to provide FrontendAdmitter - // at what point we can also drop IncomingTLSConfig - frontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return FrontendAdmit(incoming, s.opts.IncomingTLSConfig) - } - } - - conn, msg, err := frontendAdmitter(proxyConn) - if err != nil { - var codeErr *CodeError - if ok := errors.As(err, &codeErr); ok && codeErr.code == CodeUnexpectedInsecureStartupMessage { - sendErrToClient( - proxyConn, // Do this on the TCP connection as it means denying SSL - CodeUnexpectedInsecureStartupMessage, - "server requires encryption", - ) - } else if ok { - sendErrToClient(proxyConn, codeErr.code, codeErr.Error()) - } else { - sendErrToClient(proxyConn, CodeClientDisconnected, err.Error()) - } - return err - } - - // This currently only happens for CancelRequest type of startup messages - // that we don't support - if conn == nil { - return nil - - } - defer func() { _ = conn.Close() }() - - backendDialer := s.opts.BackendDialer - var backendConfig *BackendConfig - if s.opts.BackendConfigFromParams != nil { - var clientErr error - backendConfig, clientErr = s.opts.BackendConfigFromParams(msg.Parameters, proxyConn) - if clientErr != nil { - var codeErr *CodeError - if !errors.As(clientErr, &codeErr) { - codeErr = &CodeError{ - code: CodeParamsRoutingFailed, - err: errors.Errorf("rejected by BackendConfigFromParams: %v", clientErr), - } - } - sendErrToClient(conn, codeErr.code, codeErr.Error()) - return codeErr - } - } - if backendDialer == nil { - // This we need to keep until all the clients are switched to provide BackendDialer. - // It constructs a backend dialer from the information provided via - // BackendConfigFromParams function. - backendDialer = func(msg *pgproto3.StartupMessage) (net.Conn, error) { - // We should be able to remove this when the all clients switch to - // backend dialer. - if s.opts.ModifyRequestParams != nil { - s.opts.ModifyRequestParams(msg.Parameters) - } - - crdbConn, err := BackendDial(msg, backendConfig.OutgoingAddress, backendConfig.TLSConf) - if err != nil { - if codeErr := (*CodeError)(nil); errors.As(err, &codeErr) { - sendErrToClient(conn, codeErr.code, codeErr.Error()) - } else { - sendErrToClient(conn, CodeBackendDisconnected, err.Error()) - } - return nil, err - } - - return crdbConn, nil - } - } - - crdbConn, err := backendDialer(msg) - if err != nil { - s.metrics.BackendDownCount.Inc(1) - var codeErr *CodeError - if !errors.As(err, &codeErr) { - codeErr = &CodeError{ - code: CodeBackendDown, - err: errors.Errorf("unable to reach backend SQL server: %v", err), - } - } - if codeErr.code == CodeProxyRefusedConnection { - s.metrics.RefusedConnCount.Inc(1) - } else if codeErr.code == CodeParamsRoutingFailed { - s.metrics.RoutingErrCount.Inc(1) - } - sendErrToClient(conn, codeErr.code, codeErr.Error()) - return codeErr - } - defer func() { _ = crdbConn.Close() }() - - if err := authenticate(conn, crdbConn); err != nil { - s.metrics.AuthFailedCount.Inc(1) - var codeErr *CodeError - if !errors.As(err, &codeErr) { - codeErr = &CodeError{ - code: CodeParamsRoutingFailed, - err: errors.Errorf("unrecognized auth failure"), - } - } - sendErrToClient(conn, codeErr.code, codeErr.Error()) - return codeErr - } - - s.metrics.SuccessfulConnCount.Inc(1) - - // These channels are buffered because we'll only consume one of them. +// connectionCopy does a bi-directional copy between the backend and frontend +// connections. It terminates when one of connections terminate. +func connectionCopy(crdbConn, conn net.Conn) error { errOutgoing := make(chan error, 1) errIncoming := make(chan error, 1) - errExpired := make(chan error, 1) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - if backendConfig != nil { - if backendConfig.OnConnectionSuccess != nil { - backendConfig.OnConnectionSuccess() - } - if backendConfig.KeepAliveLoop != nil { - go func() { - errExpired <- backendConfig.KeepAliveLoop(ctx) - }() - } - } go func() { _, err := io.Copy(crdbConn, conn) @@ -253,34 +94,18 @@ func (s *Server) Proxy(proxyConn *Conn) error { case err := <-errIncoming: if err == nil { return nil - } else if codeErr := (*CodeError)(nil); errors.As(err, &codeErr) && - codeErr.code == CodeExpiredClientConnection { - s.metrics.ExpiredClientConnCount.Inc(1) - sendErrToClient(conn, codeErr.code, codeErr.Error()) + } else if codeErr := (*codeError)(nil); errors.As(err, &codeErr) && + codeErr.code == codeExpiredClientConnection { return codeErr - } else if errors.Is(err, os.ErrDeadlineExceeded) { - s.metrics.IdleDisconnectCount.Inc(1) - sendErrToClient(conn, CodeIdleDisconnect, "terminating connection due to idle timeout") - return NewErrorf(CodeIdleDisconnect, "terminating connection due to idle timeout: %v", err) + } else if ne := (net.Error)(nil); errors.As(err, &ne) && ne.Timeout() { + return newErrorf(codeIdleDisconnect, "terminating connection due to idle timeout: %v", err) } else { - s.metrics.BackendDisconnectCount.Inc(1) - sendErrToClient(conn, CodeBackendDisconnected, "copying from target server to client") - return NewErrorf(CodeBackendDisconnected, "copying from target server to client: %s", err) + return newErrorf(codeBackendDisconnected, "copying from target server to client: %s", err) } case err := <-errOutgoing: // The incoming connection got closed. if err != nil { - s.metrics.ClientDisconnectCount.Inc(1) - return NewErrorf(CodeClientDisconnected, "copying from target server to client: %v", err) - } - return nil - case err := <-errExpired: - if err != nil { - // The client connection expired. - s.metrics.ExpiredClientConnCount.Inc(1) - code := CodeExpiredClientConnection - sendErrToClient(conn, code, err.Error()) - return NewErrorf(code, "expired client conn: %v", err) + return newErrorf(codeClientDisconnected, "copying from target server to client: %v", err) } return nil } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go new file mode 100644 index 000000000000..b60e27f9015d --- /dev/null +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -0,0 +1,613 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "regexp" + "strconv" + "strings" + "time" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/cache" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/denylist" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/idle" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/security/certmgr" + "github.com/cockroachdb/cockroach/pkg/util" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/retry" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" + "github.com/jackc/pgproto3/v2" + "google.golang.org/grpc" +) + +var ( + // This assumes that whitespaces are used to separate command line args. + // Unlike the original spec, this does not handle escaping rules. + // + // See "options" in https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS. + clusterNameLongOptionRE = regexp.MustCompile(`(?:-c\s*|--)cluster=([\S]*)`) + + // clusterNameRegex restricts cluster names to have between 6 and 20 + // alphanumeric characters, with dashes allowed within the name (but not as a + // starting or ending character). + clusterNameRegex = regexp.MustCompile("^[a-z0-9][a-z0-9-]{4,18}[a-z0-9]$") +) + +const ( + // Cluster identifier is in the form "clustername-. Tenant id is + // always in the end but the cluster name can also contain '-' or digits. + // For example: + // "foo-7-10" -> cluster name is "foo-7" and tenant id is 10. + clusterTenantSep = "-" + // TODO(spaskob): add ballpark estimate. + maxKnownConnCacheSize = 5e6 // 5 million. +) + +// ProxyOptions is the information needed to construct a new proxyHandler. +type ProxyOptions struct { + // Denylist file to limit access to IP addresses and tenant ids. + Denylist string + // ListenAddr is the listen address for incoming connections. + ListenAddr string + // ListenCert is the file containing PEM-encoded x509 certificate for listen address. + // Set to "*" to auto-generate self-signed cert. + ListenCert string + // ListenKey is the file containing PEM-encoded x509 key for listen address. + // Set to "*" to auto-generate self-signed cert. + ListenKey string + // MetricsAddress is the listen address for incoming connections for metrics retrieval. + MetricsAddress string + // SkipVerify if set will skip the identity verification of the + // backend. This is for testing only. + SkipVerify bool + // Insecure if set, will not use TLS for the backend connection. For testing. + Insecure bool + // RoutingRule for constructing the backend address for each incoming + // connection. Optionally use '{{clusterName}}' + // which will be substituted with the cluster name. + RoutingRule string + // DirectoryAddr specified optional {HOSTNAME}:{PORT} for service that does the resolution + // from backend id to IP address. If specified - it will be used instead of the + // routing rule above. + DirectoryAddr string + // RatelimitBaseDelay is the initial backoff after a failed login attempt. + // Set to 0 to disable rate limiting. + RatelimitBaseDelay time.Duration + // ValidateAccessInterval is the time interval between validations, confirming + // that current connections are still valid. + ValidateAccessInterval time.Duration + // PollConfigInterval defines polling interval for pickup up changes in config file. + PollConfigInterval time.Duration + // IdleTimeout if set, will close connections that have been idle for that duration. + IdleTimeout time.Duration +} + +// proxyHandler is the default implementation of a proxy handler. +type proxyHandler struct { + ProxyOptions + + // metrics contains various counters reflecting the proxy operations. + metrics *metrics + + // stopper is used to do an orderly shutdown. + stopper *stop.Stopper + + // incomingCert is the managed cert of the proxy endpoint to + // which clients connect. + incomingCert certmgr.Cert + + // denyListService provides access control. + denyListService denylist.Service + + // throttleService will do throttling of incoming connection requests. + throttleService throttler.Service + + // directory is optional and if set, will be used to resolve + // backend id to IP addresses. + directory *tenant.Directory + + // CertManger keeps up to date the certificates used. + certManager *certmgr.CertManager + + //connCache is used to keep track of all current connections. + connCache cache.ConnCache +} + +// newProxyHandler will create a new proxy handler with configuration based on +// the provided options. +func newProxyHandler( + ctx context.Context, stopper *stop.Stopper, proxyMetrics *metrics, options ProxyOptions, +) (*proxyHandler, error) { + handler := proxyHandler{ + stopper: stopper, + metrics: proxyMetrics, + ProxyOptions: options, + certManager: certmgr.NewCertManager(ctx), + } + + var err error + err = handler.setupIncomingCert() + if err != nil { + return nil, err + } + + ctx, _ = stopper.WithCancelOnQuiesce(ctx) + handler.denyListService, err = denylist.NewViperDenyListFromFile(ctx, options.Denylist, options.PollConfigInterval) + if err != nil { + return nil, err + } + + handler.throttleService = throttler.NewLocalService(throttler.WithBaseDelay(options.RatelimitBaseDelay)) + handler.connCache = cache.NewCappedConnCache(maxKnownConnCacheSize) + + if handler.DirectoryAddr != "" { + conn, err := grpc.Dial(handler.DirectoryAddr, grpc.WithInsecure()) + if err != nil { + return nil, err + } + // nolint:grpcconnclose + stopper.AddCloser(stop.CloserFn(func() { _ = conn.Close() /* nolint:grpcconnclose */ })) + client := tenant.NewDirectoryClient(conn) + handler.directory, err = tenant.NewDirectory(ctx, stopper, client) + if err != nil { + return nil, err + } + } + + return &handler, nil +} + +// handle is called by the proxy server to handle a single incoming client +// connection. +func (handler *proxyHandler) handle(ctx context.Context, proxyConn *conn) error { + conn, msg, err := frontendAdmit(proxyConn, handler.IncomingTLSConfig()) + defer func() { _ = conn.Close() }() + if err != nil { + sendErrToClient(conn, err) + return err + } + + // This currently only happens for CancelRequest type of startup messages + // that we don't support + if msg == nil { + return nil + } + + // Note that the errors returned from this function are user-facing errors so + // we should be careful with the details that we want to expose. + backendStartupMsg, clusterName, tenID, err := clusterNameAndTenantFromParams(msg) + if err != nil { + clientErr := &codeError{codeParamsRoutingFailed, err} + log.Errorf(ctx, "unable to extract cluster name and tenant id: %s", clientErr.Error()) + updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics) + return clientErr + } + // This forwards the remote addr to the backend. + backendStartupMsg.Parameters["crdb:remote_addr"] = conn.RemoteAddr().String() + + ctx = logtags.AddTag(ctx, "cluster", clusterName) + ctx = logtags.AddTag(ctx, "tenant", tenID) + + ipAddr, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + clientErr := &codeError{codeParamsRoutingFailed, err} + log.Errorf(ctx, "could not parse address: %v", clientErr.Error()) + updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics) + return clientErr + } + + if err = handler.validateAccessAndThrottle(ctx, tenID, ipAddr); err != nil { + updateMetricsAndSendErrToClient(err, conn, handler.metrics) + return err + } + + var TLSConf *tls.Config + if !handler.Insecure { + TLSConf = &tls.Config{InsecureSkipVerify: handler.SkipVerify} + } + + var crdbConn net.Conn + var outgoingAddress string + retryOpts := retry.Options{InitialBackoff: 10 * time.Millisecond, MaxBackoff: time.Second} + everyMinute := log.Every(time.Second) + var outgoingAddressErrs, codeBackendDownErrs, reportFailureErrs int + for r := retry.StartWithCtx(ctx, retryOpts); r.Next(); { + shouldLog := everyMinute.ShouldLog() + outgoingAddress, err = handler.outgoingAddress(ctx, clusterName, tenID) + if err != nil { + outgoingAddressErrs++ + if shouldLog { + log.Ops.Errorf(ctx, + "outgoing address: %v (%d errors in the past minute)", + err, + outgoingAddressErrs, + ) + outgoingAddressErrs = 0 + } + continue + } + + crdbConn, err = backendDial(backendStartupMsg, outgoingAddress, TLSConf) + // If we get a backend down error and are using the directory - report the + // error to the directory and retry the connection. + codeErr := (*codeError)(nil) + if err != nil && + errors.As(err, &codeErr) && + codeErr.code == codeBackendDown && + handler.directory != nil { + codeBackendDownErrs++ + if shouldLog { + log.Ops.Errorf(ctx, + "backend dial: %v (%d errors in the past minute)", + err, + codeBackendDownErrs, + ) + codeBackendDownErrs = 0 + } + err = handler.directory.ReportFailure(ctx, roachpb.MakeTenantID(tenID), outgoingAddress) + if err != nil { + reportFailureErrs++ + if shouldLog { + log.Ops.Errorf(ctx, + "report failure: %v (%d errors in the past minute)", + err, + reportFailureErrs, + ) + reportFailureErrs = 0 + } + } + continue + } + break + } + + if err != nil { + updateMetricsAndSendErrToClient(err, conn, handler.metrics) + return err + } + + crdbConn = idle.DisconnectOverlay(crdbConn, handler.IdleTimeout) + + defer func() { _ = crdbConn.Close() }() + + if err := authenticate(conn, crdbConn); err != nil { + handler.metrics.updateForError(err) + return errors.AssertionFailedf("unrecognized auth failure") + } + + handler.metrics.SuccessfulConnCount.Inc(1) + + handler.connCache.Insert( + &cache.ConnKey{IPAddress: ipAddr, TenantID: roachpb.MakeTenantID(tenID)}, + ) + + errConnectionCopy := make(chan error, 1) + errExpired := make(chan error, 1) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // TODO(darinpp): starting a new go routine for every connection here is inefficient. + // Change to maintain a map of connections by IP/tenant and have single + // go routine that closes connections that are blocklisted. + go func() { + errExpired <- func(ctx context.Context) error { + t := timeutil.NewTimer() + defer t.Stop() + t.Reset(0) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + t.Read = true + if err := handler.validateAccess(ctx, tenID, ipAddr); err != nil { + return err + } + } + t.Reset(util.Jitter(handler.ValidateAccessInterval, 0.15)) + } + }(ctx) + }() + + log.Infof(ctx, "new connection") + connBegin := timeutil.Now() + defer func() { + log.Infof(ctx, "closing after %.2fs", timeutil.Since(connBegin).Seconds()) + }() + + go func() { + err := connectionCopy(crdbConn, conn) + errConnectionCopy <- err + }() + + select { + case err := <-errConnectionCopy: + updateMetricsAndSendErrToClient(err, conn, handler.metrics) + return err + case err := <-errExpired: + if err != nil { + // The client connection expired. + codeErr := newErrorf( + codeExpiredClientConnection, "expired client conn: %v", err, + ) + updateMetricsAndSendErrToClient(codeErr, conn, handler.metrics) + return codeErr + } + return nil + case <-handler.stopper.ShouldQuiesce(): + return nil + } + +} + +// outgoingAddress resolves a tenant ID and a tenant cluster name to an IP of +// the backend. +func (handler *proxyHandler) outgoingAddress( + ctx context.Context, name string, tenID uint64, +) (string, error) { + if handler.directory == nil { + addr := strings.ReplaceAll(handler.RoutingRule, "{{clusterName}}", name) + log.Infof(ctx, "backend %s resolved to '%s'", name, addr) + return addr, nil + } + + // This doesn't verify the name part of the tenant and relies just on the int id. + addr, err := handler.directory.EnsureTenantIP(ctx, roachpb.MakeTenantID(tenID), "") + if err != nil { + return "", err + } + return addr, nil +} + +func (handler *proxyHandler) validateAccessAndThrottle( + ctx context.Context, tenID uint64, ipAddr string, +) error { + if err := handler.validateAccess(ctx, tenID, ipAddr); err != nil { + return err + } + + // Admit the connection + connKey := cache.ConnKey{IPAddress: ipAddr, TenantID: roachpb.MakeTenantID(tenID)} + if !handler.connCache.Exists(&connKey) { + // Unknown previous successful connections from this IP and tenant. + // Hence we need to rate limit. + if err := handler.throttleService.LoginCheck(ipAddr, timeutil.Now()); err != nil { + log.Errorf(ctx, "throttler refused connection: %v", err.Error()) + return newErrorf(codeProxyRefusedConnection, "Connection attempt throttled") + } + } + + return nil +} + +func (handler *proxyHandler) validateAccess( + ctx context.Context, tenID uint64, ipAddr string, +) error { + // First validate against the deny list service + list := handler.denyListService + if entry, err := list.Denied(denylist.DenyEntity{Item: fmt.Sprint(tenID), Type: denylist.ClusterType}); err != nil { + // Log error but don't return since this could be transient. + log.Errorf(ctx, "could not consult denied list for tenant: %d", tenID) + } else if entry != nil { + log.Errorf(ctx, "access denied for tenant: %d, reason: %s", tenID, entry.Reason) + return newErrorf(codeProxyRefusedConnection, "tenant %d %s", tenID, entry.Reason) + } + + if entry, err := list.Denied(denylist.DenyEntity{Item: ipAddr, Type: denylist.IPAddrType}); err != nil { + // Log error but don't return since this could be transient. + log.Errorf(ctx, "could not consult denied list for IP address: %s", ipAddr) + } else if entry != nil { + log.Errorf(ctx, "access denied for IP address: %s, reason: %s", ipAddr, entry.Reason) + return newErrorf(codeProxyRefusedConnection, "IP address %s %s", ipAddr, entry.Reason) + } + + return nil +} + +// clusterNameAndTenantFromParams extracts the cluster name from the connection +// parameters, and rewrites the database param, if necessary. We currently +// support embedding the cluster name in two ways: +// - Within the database param (e.g. "happy-koala.defaultdb") +// +// - Within the options param (e.g. "... --cluster=happy-koala ..."). +// PostgreSQL supports three different ways to set a run-time parameter +// through its command-line options, i.e. "-c NAME=VALUE", "-cNAME=VALUE", and +// "--NAME=VALUE". +func clusterNameAndTenantFromParams( + msg *pgproto3.StartupMessage, +) (*pgproto3.StartupMessage, string, uint64, error) { + clusterNameFromDB, databaseName, err := parseDatabaseParam(msg.Parameters["database"]) + if err != nil { + return msg, "", 0, err + } + + clusterNameFromOpt, err := parseOptionsParam(msg.Parameters["options"]) + if err != nil { + return msg, "", 0, err + } + + if clusterNameFromDB == "" && clusterNameFromOpt == "" { + return msg, "", 0, errors.New("missing cluster name in connection string") + } + + if clusterNameFromDB != "" && clusterNameFromOpt != "" { + return msg, "", 0, errors.New("multiple cluster names provided") + } + + if clusterNameFromDB == "" { + clusterNameFromDB = clusterNameFromOpt + } + + sepIdx := strings.LastIndex(clusterNameFromDB, clusterTenantSep) + // Cluster name provided without a tenant ID in the end. + if sepIdx == -1 || sepIdx == len(clusterNameFromDB)-1 { + return msg, "", 0, errors.Errorf("invalid cluster name %s", clusterNameFromDB) + } + clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:] + + if !clusterNameRegex.MatchString(clusterNameSansTenant) { + return msg, "", 0, errors.Errorf("invalid cluster name '%s'", clusterNameSansTenant) + } + + tenID, err := strconv.ParseUint(tenantIDStr, 10, 64) + if err != nil { + return msg, "", 0, errors.Wrapf(err, "cannot parse %s as uint64", tenantIDStr) + } + + // Make and return a copy of the startup msg so the original is not modified. + paramsOut := map[string]string{} + for key, value := range msg.Parameters { + if key == "database" { + paramsOut[key] = databaseName + } else if key != "options" { + paramsOut[key] = value + } + } + outMsg := &pgproto3.StartupMessage{ + ProtocolVersion: msg.ProtocolVersion, + Parameters: paramsOut, + } + + return outMsg, clusterNameFromDB, tenID, nil +} + +// parseDatabaseParam parses the database parameter from the PG connection +// string, and tries to extract the cluster name if present. The cluster +// name should be embedded in the database parameter using the dot (".") +// delimiter in the form of ".". This approach +// is safe because dots are not allowed in the database names themselves. +func parseDatabaseParam(databaseParam string) (clusterName, databaseName string, err error) { + // Database param is not provided. + if databaseParam == "" { + return "", "", nil + } + + parts := strings.SplitN(databaseParam, ".", 2) + + // Database param provided without cluster name. + if len(parts) <= 1 { + return "", databaseParam, nil + } + + clusterName, databaseName = parts[0], parts[1] + + // Ensure that the param is in the right format if the delimiter is provided. + if len(parts) > 2 || clusterName == "" || databaseName == "" { + return "", "", errors.New("invalid database param") + } + + return clusterName, databaseName, nil +} + +// parseOptionsParam parses the options parameter from the PG connection string, +// and tries to return the cluster name if present. Just like PostgreSQL, the +// sqlproxy supports three different ways to set a run-time parameter through +// its command-line options: +// -c NAME=VALUE (commonly used throughout documentation around PGOPTIONS) +// -cNAME=VALUE +// --NAME=VALUE +// +// CockroachDB currently does not support the options parameter, so the parsing +// logic is built on that assumption. If we do start supporting options in +// CockroachDB itself, then we should revisit this. +// +// Note that this parsing approach is not perfect as it allows a negative case +// like options="-c --cluster=happy-koala -c -c -c" to go through. To properly +// parse this, we need to traverse the string from left to right, and look at +// every single argument, but that involves quite a bit of work, so we'll punt +// for now. +func parseOptionsParam(optionsParam string) (string, error) { + // Only search up to 2 in case of large inputs. + matches := clusterNameLongOptionRE.FindAllStringSubmatch(optionsParam, 2 /* n */) + if len(matches) == 0 { + return "", nil + } + + if len(matches) > 1 { + // Technically we could still allow requests to go through if all + // cluster names match, but we don't want to parse the entire string, so + // we will just error out if at least two cluster flags are provided. + return "", errors.New("multiple cluster flags provided") + } + + // Length of each match should always be 2 with the given regex, one for + // the full string, and the other for the cluster name. + if len(matches[0]) != 2 { + // We don't want to panic here. + return "", errors.New("internal server error") + } + + // Flag was provided, but value is NULL. + if len(matches[0][1]) == 0 { + return "", errors.New("invalid cluster flag") + } + + return matches[0][1], nil +} + +// IncomingTLSConfig gets back the current TLS config for the incoiming client +// connection endpoint. +func (handler *proxyHandler) IncomingTLSConfig() *tls.Config { + if handler.incomingCert == nil { + return nil + } + + cert := handler.incomingCert.TLSCert() + if cert == nil { + return nil + } + + return &tls.Config{Certificates: []tls.Certificate{*cert}} +} + +// setupIncomingCert will setup a manged cert for the incoming connections. +// They can either be unencrypted (in case a cert and key names are empty), +// using self-signed, runtime generated cert (if cert is set to *) or +// using file based cert where the cert/key values refer to file names +// containing the information. +func (handler *proxyHandler) setupIncomingCert() error { + if (handler.ListenKey == "") != (handler.ListenCert == "") { + return errors.New("must specify either both or neither of cert and key") + } + + if handler.ListenCert == "" { + return nil + } + + // TODO(darin): change the cert manager so it uses the stopper. + ctx, _ := handler.stopper.WithCancelOnQuiesce(context.Background()) + certMgr := certmgr.NewCertManager(ctx) + var cert certmgr.Cert + if handler.ListenCert == "*" { + cert = certmgr.NewSelfSignedCert(0, 3, 0, 0) + } else if handler.ListenCert != "" { + cert = certmgr.NewFileCert(handler.ListenCert, handler.ListenKey) + } + cert.Reload(ctx) + err := cert.Err() + if err != nil { + return err + } + certMgr.ManageCert("client", cert) + handler.certManager = certMgr + handler.incomingCert = cert + + return nil +} diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go new file mode 100644 index 000000000000..bdcbb90020dd --- /dev/null +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -0,0 +1,735 @@ +// Copyright 2020 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "crypto/tls" + "fmt" + "io/ioutil" + "net" + "os" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/require" +) + +// To ensure tenant startup code is included. +var _ = kvtenantccl.Connector{} + +const frontendError = "Frontend error!" +const backendError = "Backend error!" + +func hookBackendDial( + f func( + msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, + ) (net.Conn, error), +) func() { + return testutils.HookGlobal(&backendDial, f) +} +func hookFrontendAdmit( + f func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error), +) func() { + return testutils.HookGlobal(&frontendAdmit, f) +} +func hookSendErrToClient(f func(conn net.Conn, err error)) func() { + return testutils.HookGlobal(&sendErrToClient, f) +} +func hookAuthenticate(f func(clientConn, crdbConn net.Conn) error) func() { + return testutils.HookGlobal(&authenticate, f) +} + +func newSecureProxyServer( + ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, +) (server *Server, addr string) { + // Created via: + const _ = ` +openssl genrsa -out testserver.key 2048 +openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt \ + -days 3650 -config testserver_config.cnf +` + opts.ListenKey = "testserver.key" + opts.ListenCert = "testserver.crt" + + return newProxyServer(ctx, t, stopper, opts) +} + +func newProxyServer( + ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, +) (server *Server, addr string) { + const listenAddress = "127.0.0.1:0" + ln, err := net.Listen("tcp", listenAddress) + require.NoError(t, err) + require.NoError(t, + stopper.RunAsyncTask(ctx, "proxy-ln-close", func(ctx context.Context) { + <-stopper.ShouldQuiesce() + require.NoError(t, ln.Close()) + })) + + server, err = NewServer(ctx, stopper, *opts) + require.NoError(t, err) + + err = server.Stopper.RunAsyncTask(ctx, "proxy-server-serve", func(ctx context.Context) { + _ = server.Serve(ctx, ln) + }) + require.NoError(t, err) + + return server, ln.Addr().String() +} + +func runTestQuery(ctx context.Context, conn *pgx.Conn) error { + var n int + if err := conn.QueryRow(ctx, "SELECT $1::int", 1).Scan(&n); err != nil { + return err + } + if n != 1 { + return errors.Errorf("expected 1 got %d", n) + } + return nil +} + +type assertCtx struct { + emittedCode *errorCode +} + +func makeAssertCtx() assertCtx { + var emittedCode errorCode = -1 + return assertCtx{ + emittedCode: &emittedCode, + } +} + +func (ac *assertCtx) onSendErrToClient(code errorCode) { + *ac.emittedCode = code +} + +func (ac *assertCtx) assertConnectErr( + ctx context.Context, t *testing.T, prefix, suffix string, expCode errorCode, expErr string, +) { + t.Helper() + *ac.emittedCode = -1 + t.Run(suffix, func(t *testing.T) { + conn, err := pgx.Connect(ctx, prefix+suffix) + if err == nil { + _ = conn.Close(ctx) + } + require.Contains(t, err.Error(), expErr) + require.Equal(t, expCode, *ac.emittedCode) + }) +} + +func TestLongDBName(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + defer hookBackendDial(func(_ *pgproto3.StartupMessage, _ string, _ *tls.Config) (net.Conn, error) { + return nil, newErrorf(codeParamsRoutingFailed, "boom") + })() + + ac := makeAssertCtx() + originalSendErrToClient := sendErrToClient + defer hookSendErrToClient(func(conn net.Conn, err error) { + if codeErr, ok := err.(*codeError); ok { + ac.onSendErrToClient(codeErr.code) + } + originalSendErrToClient(conn, err) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + s, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + + longDB := strings.Repeat("x", 70) // 63 is limit + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=dim-dog-28", addr, longDB) + ac.assertConnectErr(ctx, t, pgurl, "" /* suffix */, codeParamsRoutingFailed, "boom") + require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) +} + +func TestFailedConnection(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + // TODO(asubiotto): consider using datadriven for these, especially if the + // proxy becomes more complex. + + var originalSendErrToClient = sendErrToClient + ac := makeAssertCtx() + defer hookSendErrToClient(func(conn net.Conn, err error) { + if codeErr, ok := err.(*codeError); ok { + ac.onSendErrToClient(codeErr.code) + } + originalSendErrToClient(conn, err) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + s, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{RoutingRule: "undialable%$!@$"}) + + _, p, err := net.SplitHostPort(addr) + require.NoError(t, err) + u := fmt.Sprintf("postgres://unused:unused@localhost:%s/", p) + // Valid connections, but no backend server running. + for _, sslmode := range []string{"require", "prefer"} { + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode="+sslmode, + codeBackendDown, "unable to reach backend SQL server", + ) + } + + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode=verify-ca&sslrootcert=testserver.crt", + codeBackendDown, "unable to reach backend SQL server", + ) + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode=verify-full&sslrootcert=testserver.crt", + codeBackendDown, "unable to reach backend SQL server", + ) + require.Equal(t, int64(4), s.metrics.BackendDownCount.Count()) + + // Unencrypted connections bounce. + for _, sslmode := range []string{"disable", "allow"} { + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode="+sslmode, + codeUnexpectedInsecureStartupMessage, "server requires encryption", + ) + } + require.Equal(t, int64(0), s.metrics.RoutingErrCount.Count()) + // TenantID rejected as malformed. + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim&sslmode=require", + codeParamsRoutingFailed, "invalid cluster name", + ) + require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) + // No TenantID. + ac.assertConnectErr( + ctx, t, u, "?sslmode=require", + codeParamsRoutingFailed, "missing cluster name", + ) + require.Equal(t, int64(2), s.metrics.RoutingErrCount.Count()) +} + +func TestUnexpectedError(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + // Set up a Server whose FrontendAdmitter function always errors with a + // non-codeError error. + defer hookFrontendAdmit( + func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + log.Infof(context.Background(), "frontendAdmit returning unexpected error") + return conn, nil, errors.New("unexpected error") + })() + + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + _, addr := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + + u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", addr) + + // Time how long it takes for pgx.Connect to return. If the proxy handles + // errors appropriately, pgx.Connect should return near immediately + // because the server should close the connection. If not, it may take up + // to the 5s connect_timeout for pgx.Connect to give up. + start := timeutil.Now() + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + t.Log(err) + elapsed := timeutil.Since(start) + if elapsed >= 5*time.Second { + t.Errorf("pgx.Connect took %s to error out", elapsed) + } +} + +func TestProxyAgainstSecureCRDB(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, db, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) + + var connSuccess bool + originalAuthenticate := authenticate + defer hookAuthenticate(func(clientConn, crdbConn net.Conn) error { + err := originalAuthenticate(clientConn, crdbConn) + connSuccess = err == nil + return err + })() + + defer sql.Stopper().Stop(ctx) + + s, addr := newSecureProxyServer( + ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, + ) + + url := fmt.Sprintf("postgres://bob:wrong@%s/dim-dog-28.defaultdb?sslmode=require", addr) + _, err := pgx.Connect(ctx, url) + require.Regexp(t, "ERROR: password authentication failed for user bob", err) + + url = fmt.Sprintf("postgres://bob@%s/dim-dog-28.defaultdb?sslmode=require", addr) + _, err = pgx.Connect(ctx, url) + require.Regexp(t, "ERROR: password authentication failed for user bob", err) + + url = fmt.Sprintf("postgres://bob:builder@%s/dim-dog-28.defaultdb?sslmode=require", addr) + conn, err := pgx.Connect(ctx, url) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.True(t, connSuccess) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + require.Equal(t, int64(2), s.metrics.AuthFailedCount.Count()) + }() + + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestProxyTLSClose(t *testing.T) { + defer leaktest.AfterTest(t)() + // NB: The leaktest call is an important part of this test. We're + // verifying that no goroutines are leaked, despite calling Close an + // underlying TCP connection (rather than the TLSConn that wraps it). + + ctx := context.Background() + + sql, db, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) + + var proxyIncomingConn atomic.Value // *conn + var connSuccess bool + frontendAdmit := frontendAdmit + defer hookFrontendAdmit(func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + proxyIncomingConn.Store(conn) + return frontendAdmit(conn, incomingTLSConfig) + })() + originalAuthenticate := authenticate + defer hookAuthenticate(func(clientConn, crdbConn net.Conn) error { + err := originalAuthenticate(clientConn, crdbConn) + connSuccess = err == nil + return err + })() + + defer sql.Stopper().Stop(ctx) + + s, addr := newSecureProxyServer( + ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, + ) + + url := fmt.Sprintf("postgres://bob:builder@%s/dim-dog-28.defaultdb?sslmode=require", addr) + c, err := pgx.Connect(ctx, url) + require.NoError(t, err) + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + defer func() { + incomingConn, ok := proxyIncomingConn.Load().(*conn) + require.True(t, ok) + require.NoError(t, incomingConn.Close()) + <-incomingConn.done() // should immediately proceed + + require.True(t, connSuccess) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) + }() + + require.NoError(t, runTestQuery(ctx, c)) +} + +func TestProxyModifyRequestParams(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + outgoingTLSConfig, err := sql.RPCContext().GetClientTLSConfig() + require.NoError(t, err) + proxyOutgoingTLSConfig := outgoingTLSConfig.Clone() + proxyOutgoingTLSConfig.InsecureSkipVerify = true + + backendDial := backendDial + defer hookBackendDial(func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + params := msg.Parameters + authToken, ok := params["authToken"] + require.True(t, ok) + require.Equal(t, "abc123", authToken) + user, ok := params["user"] + require.True(t, ok) + require.Equal(t, "bogususer", user) + require.Contains(t, params, "user") + + // NB: This test will fail unless the user used between the proxy + // and the backend is changed to a user that actually exists. + delete(params, "authToken") + params["user"] = "root" + + return backendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) + })() + + defer sql.Stopper().Stop(ctx) + + s, proxyAddr := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) + + u := fmt.Sprintf("postgres://bogususer@%s/?sslmode=require&authToken=abc123&options=--cluster=dim-dog-28", proxyAddr) + conn, err := pgx.Connect(ctx, u) + require.NoError(t, err) + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + defer func() { + require.NoError(t, conn.Close(ctx)) + }() + + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestInsecureProxy(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, db, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + defer sql.Stopper().Stop(ctx) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) + + s, addr := newProxyServer( + ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, + ) + + u := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable&options=--cluster=dim-dog-28", addr) + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + require.Regexp(t, "ERROR: password authentication failed for user bob", err) + + u = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + conn, err := pgx.Connect(ctx, u) + require.NoError(t, err) + + defer func() { + require.NoError(t, conn.Close(ctx)) + require.Equal(t, int64(1), s.metrics.AuthFailedCount.Count()) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + }() + + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestInsecureDoubleProxy(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true}) + defer sql.Stopper().Stop(ctx) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + // Test multiple proxies: proxyB -> proxyA -> tc + _, proxyA := newProxyServer(ctx, t, sql.Stopper(), + &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), Insecure: true}, + ) + _, proxyB := newProxyServer(ctx, t, sql.Stopper(), + &ProxyOptions{RoutingRule: proxyA, Insecure: true}, + ) + + u := fmt.Sprintf("postgres://root:admin@%s/dim-dog-28.dim-dog-29.defaultdb?sslmode=disable", proxyB) + conn, err := pgx.Connect(ctx, u) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close(ctx)) + }() + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestErroneousFrontend(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + defer hookFrontendAdmit(func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + return conn, nil, errors.New(frontendError) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + _, addr := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + + u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + // Generic message here as the Frontend's error is not codeError and + // by default we don't pass back error's text. The startup message doesn't get + // processed in this case. + require.Regexp(t, "connection reset by peer|failed to receive message", err) +} + +func TestErroneousBackend(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + defer hookBackendDial( + func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + return nil, errors.New(backendError) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + _, addr := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + + u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + // Generic message here as the Backend's error is not codeError and + // by default we don't pass back error's text. The startup message has already + // been processed. + require.Regexp(t, "failed to receive message", err) +} + +func TestProxyRefuseConn(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + defer hookBackendDial(func(_ *pgproto3.StartupMessage, _ string, _ *tls.Config) (net.Conn, error) { + return nil, newErrorf(codeProxyRefusedConnection, "too many attempts") + })() + + ac := makeAssertCtx() + originalSendErrToClient := sendErrToClient + defer hookSendErrToClient(func(conn net.Conn, err error) { + if codeErr, ok := err.(*codeError); ok { + ac.onSendErrToClient(codeErr.code) + } + originalSendErrToClient(conn, err) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + s, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + + ac.assertConnectErr( + ctx, t, fmt.Sprintf("postgres://root:admin@%s/", addr), + "?sslmode=require&options=--cluster=dim-dog-28", + codeProxyRefusedConnection, "too many attempts", + ) + require.Equal(t, int64(1), s.metrics.RefusedConnCount.Count()) + require.Equal(t, int64(0), s.metrics.SuccessfulConnCount.Count()) + require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) +} + +func TestDenylistUpdate(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + denyList, err := ioutil.TempFile("", "*_denylist.yml") + require.NoError(t, err) + + outgoingTLSConfig, err := sql.RPCContext().GetClientTLSConfig() + require.NoError(t, err) + proxyOutgoingTLSConfig := outgoingTLSConfig.Clone() + proxyOutgoingTLSConfig.InsecureSkipVerify = true + + backendDial := backendDial + defer hookBackendDial(func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + time.AfterFunc(100*time.Millisecond, func() { + _, err := denyList.WriteString("127.0.0.1: test-denied") + require.NoError(t, err) + }) + return backendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) + })() + + defer sql.Stopper().Stop(ctx) + + s, addr := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{Denylist: denyList.Name()}) + defer func() { _ = os.Remove(denyList.Name()) }() + + url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require&options=--cluster=dim-dog-28", addr) + conn, err := pgx.Connect(context.Background(), url) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.Equal(t, int64(1), s.metrics.ExpiredClientConnCount.Count()) + }() + + require.Eventuallyf( + t, + func() bool { + _, err = conn.Exec(context.Background(), "SELECT 1") + return err != nil && strings.Contains(err.Error(), "expired") + }, + time.Second, 5*time.Millisecond, + "unexpected error received: %v", err, + ) +} + +func TestProxyAgainstSecureCRDBWithIdleTimeout(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + outgoingTLSConfig, err := sql.RPCContext().GetClientTLSConfig() + require.NoError(t, err) + proxyOutgoingTLSConfig := outgoingTLSConfig.Clone() + proxyOutgoingTLSConfig.InsecureSkipVerify = true + + idleTimeout, _ := time.ParseDuration("0.5s") + var connSuccess bool + frontendAdmit := frontendAdmit + defer hookFrontendAdmit(func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + return frontendAdmit(conn, incomingTLSConfig) + })() + backendDial := backendDial + defer hookBackendDial(func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + return backendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) + })() + originalAuthenticate := authenticate + defer hookAuthenticate(func(clientConn, crdbConn net.Conn) error { + err := originalAuthenticate(clientConn, crdbConn) + connSuccess = err == nil + return err + })() + + defer sql.Stopper().Stop(ctx) + + s, addr := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{IdleTimeout: idleTimeout}) + + url := fmt.Sprintf("postgres://root:admin@%s/?sslmode=require&options=--cluster=dim-dog-28", addr) + conn, err := pgx.Connect(ctx, url) + require.NoError(t, err) + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.True(t, connSuccess) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + }() + + var n int + err = conn.QueryRow(ctx, "SELECT $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + time.Sleep(idleTimeout * 2) + err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) + require.EqualError(t, err, "FATAL: terminating connection due to idle timeout (SQLSTATE 57P01)") +} + +func newDirectoryServer( + ctx context.Context, t *testing.T, srv serverutils.TestServerInterface, addr *net.TCPAddr, +) (*stop.Stopper, *net.TCPAddr) { + tdsStopper := stop.NewStopper() + listener, err := net.ListenTCP("tcp", addr) + require.NoError(t, err) + tds := tenant.NewTestDirectoryServer(tdsStopper) + tds.TenantStarterFunc = func(ctx context.Context, tenantID uint64) (*tenant.Process, error) { + log.TestingClearServerIdentifiers() + tenantStopper := tenant.NewSubStopper(tdsStopper) + ten, err := srv.StartTenant(ctx, base.TestTenantArgs{ + Existing: true, + TenantID: roachpb.MakeTenantID(tenantID), + ForceInsecure: true, + Stopper: tenantStopper, + }) + require.NoError(t, err) + sqlAddr, err := net.ResolveTCPAddr("tcp", ten.SQLAddr()) + require.NoError(t, err) + ten.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + return &tenant.Process{SQL: sqlAddr, Stopper: tenantStopper}, nil + } + go func() { require.NoError(t, tds.Serve(listener)) }() + return tdsStopper, listener.Addr().(*net.TCPAddr) +} + +func TestDirectoryReconnect(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + // New test cluster + srv, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true}) + defer srv.Stopper().Stop(ctx) + + // Create tenant 28 + sqlConn := srv.InternalExecutor().(*sql.InternalExecutor) + _, err := sqlConn.Exec(ctx, "", nil, "SELECT crdb_internal.create_tenant(28)") + require.NoError(t, err) + + // New test directory server + stopper1, tdsAddr := newDirectoryServer(ctx, t, srv, &net.TCPAddr{}) + + // New proxy server using the directory + _, addr := newProxyServer( + ctx, t, srv.Stopper(), &ProxyOptions{DirectoryAddr: tdsAddr.String(), Insecure: true}, + ) + + // try to connect - should be successful. + url := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + _, err = pgx.Connect(ctx, url) + require.NoError(t, err) + + // Stop the directory server and the tenant + stopper1.Stop(ctx) + + var succeeded syncutil.AtomicBool + + stopper2, _ := newDirectoryServer(ctx, t, srv, tdsAddr) + defer stopper2.Stop(ctx) + + require.Eventually(t, func() bool { + // try to connect through the proxy again - should be successful. + url = fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + _, err = pgx.Connect(ctx, url) + require.NoError(t, err) + succeeded.Set(true) + return true + }, 1000*time.Second, 100*time.Millisecond) +} diff --git a/pkg/ccl/sqlproxyccl/proxy_test.go b/pkg/ccl/sqlproxyccl/proxy_test.go deleted file mode 100644 index 1b81ea2cf61b..000000000000 --- a/pkg/ccl/sqlproxyccl/proxy_test.go +++ /dev/null @@ -1,671 +0,0 @@ -// Copyright 2020 The Cockroach Authors. -// -// Licensed as a CockroachDB Enterprise file under the Cockroach Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt - -package sqlproxyccl - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/testutils" - "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" - "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" - "github.com/cockroachdb/cockroach/pkg/util/leaktest" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" - "github.com/cockroachdb/errors" - "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4" - "github.com/stretchr/testify/require" -) - -const FrontendError = "Frontend error!" -const BackendError = "Backend error!" - -func setupTestProxyWithCerts( - t *testing.T, opts *Options, -) (server *Server, addr string, done func()) { - // Created via: - const _ = ` -openssl genrsa -out testserver.key 2048 -openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt \ - -days 3650 -config testserver_config.cnf -` - cer, err := tls.LoadX509KeyPair("testserver.crt", "testserver.key") - require.NoError(t, err) - opts.FrontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return FrontendAdmit( - incoming, - &tls.Config{ - Certificates: []tls.Certificate{cer}, - ServerName: "localhost", - }, - ) - } - - const listenAddress = "127.0.0.1:0" - // NB: ln closes before wg.Wait or we deadlock. - ln, err := net.Listen("tcp", listenAddress) - require.NoError(t, err) - - var wg sync.WaitGroup - wg.Add(1) - - done = func() { - _ = ln.Close() - wg.Wait() - } - - server = NewServer(*opts) - - go func() { - defer wg.Done() - _ = server.Serve(ln) - }() - - return server, ln.Addr().String(), done -} - -func testingTenantIDFromDatabaseForAddr( - addr string, validTenant string, -) func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return func(msg *pgproto3.StartupMessage) (net.Conn, error) { - const dbKey = "database" - p := msg.Parameters - db, ok := p[dbKey] - if !ok { - return nil, NewErrorf( - CodeParamsRoutingFailed, "need to specify database", - ) - } - sl := strings.SplitN(db, "_", 2) - if len(sl) != 2 { - return nil, NewErrorf( - CodeParamsRoutingFailed, "malformed database name", - ) - } - db, tenantID := sl[0], sl[1] - - if tenantID != validTenant { - return nil, NewErrorf(CodeParamsRoutingFailed, "invalid tenantID") - } - - p[dbKey] = db - return BackendDial(msg, addr, &tls.Config{ - // NB: this would be false in production. - InsecureSkipVerify: true, - }) - } -} - -func runTestQuery(conn *pgx.Conn) error { - var n int - if err := conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n); err != nil { - return err - } - if n != 1 { - return errors.Errorf("expected 1 got %d", n) - } - return nil -} - -type assertCtx struct { - emittedCode *ErrorCode -} - -func makeAssertCtx() assertCtx { - var emittedCode ErrorCode = -1 - return assertCtx{ - emittedCode: &emittedCode, - } -} - -func (ac *assertCtx) onSendErrToClient(code ErrorCode, msg string) string { - *ac.emittedCode = code - return msg -} - -func (ac *assertCtx) assertConnectErr( - t *testing.T, prefix, suffix string, expCode ErrorCode, expErr string, -) { - t.Helper() - *ac.emittedCode = -1 - t.Run(suffix, func(t *testing.T) { - ctx := context.Background() - conn, err := pgx.Connect(ctx, prefix+suffix) - if err == nil { - _ = conn.Close(ctx) - } - require.Contains(t, err.Error(), expErr) - require.Equal(t, expCode, *ac.emittedCode) - - }) -} - -func TestLongDBName(t *testing.T) { - defer leaktest.AfterTest(t)() - - ac := makeAssertCtx() - - opts := Options{ - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return nil, NewErrorf(CodeParamsRoutingFailed, "boom") - }, - OnSendErrToClient: ac.onSendErrToClient, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - longDB := strings.Repeat("x", 70) // 63 is limit - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s", addr, longDB) - ac.assertConnectErr(t, pgurl, "" /* suffix */, CodeParamsRoutingFailed, "boom") - require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) -} - -func TestFailedConnection(t *testing.T) { - defer leaktest.AfterTest(t)() - - // TODO(asubiotto): consider using datadriven for these, especially if the - // proxy becomes more complex. - - ac := makeAssertCtx() - opts := Options{ - BackendDialer: testingTenantIDFromDatabaseForAddr( - "undialable%$!@$", "29", - ), - OnSendErrToClient: ac.onSendErrToClient, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - _, p, err := net.SplitHostPort(addr) - require.NoError(t, err) - u := fmt.Sprintf("postgres://unused:unused@localhost:%s/", p) - // Valid connections, but no backend server running. - for _, sslmode := range []string{"require", "prefer"} { - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode="+sslmode, - CodeBackendDown, "unable to reach backend SQL server", - ) - } - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode=verify-ca&sslrootcert=testserver.crt", - CodeBackendDown, "unable to reach backend SQL server", - ) - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode=verify-full&sslrootcert=testserver.crt", - CodeBackendDown, "unable to reach backend SQL server", - ) - require.Equal(t, int64(4), s.metrics.BackendDownCount.Count()) - - // Unencrypted connections bounce. - for _, sslmode := range []string{"disable", "allow"} { - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode="+sslmode, - CodeUnexpectedInsecureStartupMessage, "server requires encryption", - ) - } - - // TenantID rejected by test hook. - ac.assertConnectErr( - t, u, "defaultdb_28?sslmode=require", - CodeParamsRoutingFailed, "invalid tenantID", - ) - - // No TenantID. - ac.assertConnectErr( - t, u, "defaultdb?sslmode=require", - CodeParamsRoutingFailed, "malformed database name", - ) - require.Equal(t, int64(2), s.metrics.RoutingErrCount.Count()) -} - -func TestUnexpectedError(t *testing.T) { - defer leaktest.AfterTest(t)() - - // Set up a Server whose FrontendAdmitter function always errors with a - // non-CodeError error. - ctx := context.Background() - s := NewServer(Options{ - FrontendAdmitter: func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return nil, nil, errors.New("unexpected error") - }, - }) - const listenAddress = "127.0.0.1:0" - ln, err := net.Listen("tcp", listenAddress) - require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - _ = s.Serve(ln) - }() - defer func() { - _ = ln.Close() - wg.Wait() - }() - - u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", ln.Addr().String()) - - // Time how long it takes for pgx.Connect to return. If the proxy handles - // errors appropriately, pgx.Connect should return near immediately - // because the server should close the connection. If not, it may take up - // to the 5s connect_timeout for pgx.Connect to give up. - start := timeutil.Now() - _, err = pgx.Connect(ctx, u) - require.Error(t, err) - t.Log(err) - elapsed := timeutil.Since(start) - if elapsed >= 5*time.Second { - t.Errorf("pgx.Connect took %s to error out", elapsed) - } -} - -func TestProxyAgainstSecureCRDB(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(tc.ServerConn(0)) - sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - - var connSuccess bool - opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{OnConnectionSuccess: func() { connSuccess = true }}, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), &tls.Config{InsecureSkipVerify: true}) - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=require", addr) - _, err := pgx.Connect(context.Background(), url) - require.True(t, testutils.IsError(err, "ERROR: password authentication failed for user bob")) - - url = fmt.Sprintf("postgres://bob@%s?sslmode=require", addr) - _, err = pgx.Connect(context.Background(), url) - require.True(t, testutils.IsError(err, "ERROR: password authentication failed for user bob")) - - url = fmt.Sprintf("postgres://bob:builder@%s?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - defer func() { - require.NoError(t, conn.Close(ctx)) - require.True(t, connSuccess) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - require.Equal(t, int64(2), s.metrics.AuthFailedCount.Count()) - }() - - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - require.NoError(t, runTestQuery(conn)) -} - -func TestProxyTLSClose(t *testing.T) { - defer leaktest.AfterTest(t)() - // NB: The leaktest call is an important part of this test. We're - // verifying that no goroutines are leaked, despite calling Close an - // underlying TCP connection (rather than the TLSConn that wraps it). - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(tc.ServerConn(0)) - sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - - var proxyIncomingConn atomic.Value // *Conn - var connSuccess bool - opts := Options{ - BackendConfigFromParams: func(params map[string]string, conn *Conn) (*BackendConfig, error) { - proxyIncomingConn.Store(conn) - return &BackendConfig{ - OnConnectionSuccess: func() { connSuccess = true }, - }, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), &tls.Config{InsecureSkipVerify: true}) - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://bob:builder@%s/defaultdb_29?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - defer func() { - incomingConn := proxyIncomingConn.Load().(*Conn) - require.NoError(t, incomingConn.Close()) - <-incomingConn.Done() // should immediately proceed - - require.True(t, connSuccess) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) - }() - - require.NoError(t, runTestQuery(conn)) -} - -func TestProxyModifyRequestParams(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - opts := Options{ - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - params := msg.Parameters - require.EqualValues(t, map[string]string{ - "authToken": "abc123", - "user": "bogususer", - }, params) - - // NB: This test will fail unless the user used between the proxy - // and the backend is changed to a user that actually exists. - delete(params, "authToken") - params["user"] = "root" - - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) - }, - } - s, proxyAddr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - u := fmt.Sprintf("postgres://bogususer@%s/?sslmode=require&authToken=abc123", proxyAddr) - conn, err := pgx.Connect(ctx, u) - require.NoError(t, err) - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - defer func() { - require.NoError(t, conn.Close(ctx)) - }() - - require.NoError(t, runTestQuery(conn)) -} - -func newInsecureProxyServer( - t *testing.T, outgoingAddr string, outgoingTLSConfig *tls.Config, customOptions ...func(*Options), -) (server *Server, addr string, cleanup func()) { - op := Options{ - BackendDialer: func(message *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(message, outgoingAddr, outgoingTLSConfig) - }} - for _, opt := range customOptions { - opt(&op) - } - s := NewServer(op) - const listenAddress = "127.0.0.1:0" - ln, err := net.Listen("tcp", listenAddress) - require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - _ = s.Serve(ln) - }() - return s, ln.Addr().String(), func() { - _ = ln.Close() - wg.Wait() - } -} - -func TestInsecureProxy(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(tc.ServerConn(0)) - sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - - s, addr, cleanup := newInsecureProxyServer( - t, tc.Server(0).ServingSQLAddr(), &tls.Config{InsecureSkipVerify: true}, func(op *Options) {} /* custom options */) - defer cleanup() - - u := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable", addr) - _, err := pgx.Connect(context.Background(), u) - require.Error(t, err) - require.True(t, testutils.IsError(err, "ERROR: password authentication failed for user bob")) - - u = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable", addr) - conn, err := pgx.Connect(ctx, u) - require.NoError(t, err) - - defer func() { - require.NoError(t, conn.Close(ctx)) - require.Equal(t, int64(1), s.metrics.AuthFailedCount.Count()) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - }() - - require.NoError(t, runTestQuery(conn)) -} - -func TestInsecureDoubleProxy(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{ - ServerArgs: base.TestServerArgs{Insecure: true}, - }) - defer tc.Stopper().Stop(ctx) - - // Test multiple proxies: proxyB -> proxyA -> tc - _, proxyA, cleanupA := newInsecureProxyServer(t, tc.Server(0).ServingSQLAddr(), - nil /* tls config */, func(op *Options) {} /* custom server options */) - defer cleanupA() - _, proxyB, cleanupB := newInsecureProxyServer(t, proxyA, nil /* tls config */, func(op *Options) {} /* custom server options */) - defer cleanupB() - - u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable", proxyB) - conn, err := pgx.Connect(ctx, u) - require.NoError(t, err) - defer func() { - require.NoError(t, conn.Close(ctx)) - }() - require.NoError(t, runTestQuery(conn)) -} - -func TestErroneousFrontend(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - _, addr, cleanup := newInsecureProxyServer( - t, tc.Server(0).ServingSQLAddr(), nil, /* tls config */ - func(op *Options) { - op.FrontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return nil, nil, errors.New(FrontendError) - } - }) - defer cleanup() - - u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable", addr) - - _, err := pgx.Connect(ctx, u) - require.Error(t, err) - require.True(t, testutils.IsError(err, FrontendError)) -} - -func TestErroneousBackend(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - _, addr, cleanup := newInsecureProxyServer( - t, tc.Server(0).ServingSQLAddr(), nil, /* tls config */ - func(op *Options) { - op.BackendDialer = func(message *pgproto3.StartupMessage) (net.Conn, error) { - return nil, errors.New(BackendError) - } - }) - defer cleanup() - - u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable", addr) - - _, err := pgx.Connect(ctx, u) - require.Error(t, err) - require.True(t, testutils.IsError(err, BackendError)) -} - -func TestProxyRefuseConn(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - ac := makeAssertCtx() - opts := Options{ - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return nil, NewErrorf(CodeProxyRefusedConnection, "too many attempts") - }, - OnSendErrToClient: ac.onSendErrToClient, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - ac.assertConnectErr( - t, fmt.Sprintf("postgres://root:admin@%s/", addr), "defaultdb_29?sslmode=require", - CodeProxyRefusedConnection, "too many attempts", - ) - require.Equal(t, int64(1), s.metrics.RefusedConnCount.Count()) - require.Equal(t, int64(0), s.metrics.SuccessfulConnCount.Count()) - require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) -} - -func TestProxyKeepAlive(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - // Don't let connections last more than 100ms. - KeepAliveLoop: func(ctx context.Context) error { - t := timeutil.NewTimer() - t.Reset(100 * time.Millisecond) - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-t.C: - t.Read = true - return errors.New("expired") - } - } - }, - }, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - defer func() { - require.NoError(t, conn.Close(ctx)) - require.Equal(t, int64(1), s.metrics.ExpiredClientConnCount.Count()) - }() - - require.Eventuallyf( - t, - func() bool { - _, err = conn.Exec(context.Background(), "SELECT 1") - return err != nil && strings.Contains(err.Error(), "expired") - }, - time.Second, 5*time.Millisecond, - "unexpected error received: %v", err, - ) -} - -func TestProxyAgainstSecureCRDBWithIdleTimeout(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - idleTimeout, _ := time.ParseDuration("0.5s") - var connSuccess bool - opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - OnConnectionSuccess: func() { connSuccess = true }, - }, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - conn, err := BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) - if err != nil { - return nil, err - } - return IdleDisconnectOverlay(conn, idleTimeout), nil - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - defer func() { - require.NoError(t, conn.Close(ctx)) - require.True(t, connSuccess) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - }() - - var n int - err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) - require.NoError(t, err) - require.EqualValues(t, 1, n) - time.Sleep(idleTimeout * 2) - err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) - require.EqualError(t, err, "FATAL: terminating connection due to idle timeout (SQLSTATE 57P01)") -} diff --git a/pkg/ccl/sqlproxyccl/server.go b/pkg/ccl/sqlproxyccl/server.go index 2dedd3b3bfbf..fb5779d559b2 100644 --- a/pkg/ccl/sqlproxyccl/server.go +++ b/pkg/ccl/sqlproxyccl/server.go @@ -18,18 +18,24 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/httputil" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/metric" + "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" ) +// proxyConnHandler defines the signature of the function that handles each +// individual new incoming connection. +type proxyConnHandler func(ctx context.Context, proxyConn *conn) error + // Server is a TCP server that proxies SQL connections to a // configurable backend. It may also run an HTTP server to expose a // health check and prometheus metrics. type Server struct { - opts *Options + Stopper *stop.Stopper + connHandler proxyConnHandler mux *http.ServeMux - metrics *Metrics + metrics *metrics metricsRegistry *metric.Registry promMu syncutil.Mutex @@ -38,17 +44,22 @@ type Server struct { // NewServer constructs a new proxy server and provisions metrics and health // checks as well. -func NewServer(opts Options) *Server { +func NewServer(ctx context.Context, stopper *stop.Stopper, options ProxyOptions) (*Server, error) { + proxyMetrics := makeProxyMetrics() + handler, err := newProxyHandler(ctx, stopper, &proxyMetrics, options) + if err != nil { + return nil, err + } + mux := http.NewServeMux() registry := metric.NewRegistry() - proxyMetrics := MakeProxyMetrics() - - registry.AddMetricStruct(proxyMetrics) + registry.AddMetricStruct(&proxyMetrics) s := &Server{ - opts: &opts, + Stopper: stopper, + connHandler: handler.handle, mux: mux, metrics: &proxyMetrics, metricsRegistry: registry, @@ -60,7 +71,7 @@ func NewServer(opts Options) *Server { mux.HandleFunc("/_status/vars/", s.handleVars) mux.HandleFunc("/_status/healthz/", s.handleHealth) - return s + return s, nil } func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { @@ -134,32 +145,34 @@ func (s *Server) ServeHTTP(ctx context.Context, ln net.Listener) error { // Serve serves a listener according to the Options given in NewServer(). // Incoming client connections are taken through the Postgres handshake and // relayed to the configured backend server. -func (s *Server) Serve(ln net.Listener) error { +func (s *Server) Serve(ctx context.Context, ln net.Listener) error { for { origConn, err := ln.Accept() if err != nil { return err } - conn := &Conn{ + conn := &conn{ Conn: origConn, } - go func() { + err = s.Stopper.RunAsyncTask(ctx, "proxy-con-handler", func(ctx context.Context) { defer func() { _ = conn.Close() }() s.metrics.CurConnCount.Inc(1) defer s.metrics.CurConnCount.Dec(1) - tBegin := timeutil.Now() remoteAddr := conn.RemoteAddr() - log.Infof(context.Background(), "handling client %s", remoteAddr) - err := s.Proxy(conn) - log.Infof(context.Background(), "client %s disconnected after %.2fs: %v", - remoteAddr, timeutil.Since(tBegin).Seconds(), err) - }() + ctxWithTag := logtags.AddTag(ctx, "client", remoteAddr) + if err := s.connHandler(ctxWithTag, conn); err != nil { + log.Infof(ctxWithTag, "connection error: %v", err) + } + }) + if err != nil { + return err + } } } -// Conn is a SQL connection into the proxy. -type Conn struct { +// conn is a SQL connection into the proxy. +type conn struct { net.Conn mu struct { @@ -170,7 +183,7 @@ type Conn struct { } // Done returns a channel that's closed when the connection is closed. -func (c *Conn) Done() <-chan struct{} { +func (c *conn) done() <-chan struct{} { c.mu.Lock() defer c.mu.Unlock() if c.mu.closedCh == nil { @@ -184,8 +197,8 @@ func (c *Conn) Done() <-chan struct{} { // Close closes the connection. // Any blocked Read or Write operations will be unblocked and return errors. -// The connection's Done channel will be closed. -func (c *Conn) Close() error { +// The connection's Done channel will be closed. This overrides net.Conn.Close. +func (c *conn) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.mu.closed { diff --git a/pkg/ccl/sqlproxyccl/server_test.go b/pkg/ccl/sqlproxyccl/server_test.go index f6b411695d2f..c20491b07780 100644 --- a/pkg/ccl/sqlproxyccl/server_test.go +++ b/pkg/ccl/sqlproxyccl/server_test.go @@ -9,18 +9,25 @@ package sqlproxyccl import ( + "context" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/stretchr/testify/require" ) func TestHandleHealth(t *testing.T) { defer leaktest.AfterTest(t)() - proxyServer := NewServer(Options{}) + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + proxyServer, err := NewServer(ctx, stopper, ProxyOptions{}) + require.NoError(t, err) rw := httptest.NewRecorder() r := httptest.NewRequest("GET", "/_status/healthz/", nil) @@ -36,7 +43,12 @@ func TestHandleHealth(t *testing.T) { func TestHandleVars(t *testing.T) { defer leaktest.AfterTest(t)() - proxyServer := NewServer(Options{}) + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + proxyServer, err := NewServer(ctx, stopper, ProxyOptions{}) + require.NoError(t, err) rw := httptest.NewRecorder() r := httptest.NewRequest("GET", "/_status/vars/", nil) diff --git a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel index 509fb42f6adb..243234607fa9 100644 --- a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel @@ -1,5 +1,5 @@ load("@bazel_gomock//:gomock.bzl", "gomock") -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") load("@rules_proto//proto:defs.bzl", "proto_library") @@ -47,34 +47,6 @@ go_library( ], ) -go_test( - name = "tenant_test", - srcs = [ - "directory_test.go", - "main_test.go", - ], - embed = [":tenant"], - deps = [ - "//pkg/base", - "//pkg/ccl", - "//pkg/ccl/utilccl", - "//pkg/roachpb", - "//pkg/security", - "//pkg/security/securitytest", - "//pkg/server", - "//pkg/sql", - "//pkg/testutils/serverutils", - "//pkg/testutils/testcluster", - "//pkg/util/leaktest", - "//pkg/util/log", - "//pkg/util/randutil", - "//pkg/util/stop", - "@com_github_stretchr_testify//assert", - "@com_github_stretchr_testify//require", - "@org_golang_google_grpc//:go_default_library", - ], -) - gomock( name = "mocks_tenant", out = "mocks_generated.go", diff --git a/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go b/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go index d2e4dd7ec624..315f86b675ee 100644 --- a/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go +++ b/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go @@ -193,7 +193,7 @@ func (s *TestDirectoryServer) EnsureEndpoint( s.proc.Lock() defer s.proc.Unlock() - lst, err := s.listLocked(ctx, &ListEndpointsRequest{req.TenantID}) + lst, err := s.listLocked(ctx, &ListEndpointsRequest{TenantID: req.TenantID}) if err != nil { return nil, err } @@ -225,6 +225,7 @@ func NewTestDirectoryServer(stopper *stop.Stopper) *TestDirectoryServer { dir.TenantStarterFunc = dir.startTenantLocked dir.proc.processByAddrByTenantID = map[uint64]map[net.Addr]*Process{} dir.listen.eventListeners = list.New() + stopper.AddCloser(stop.CloserFn(func() { dir.grpcServer.GracefulStop() })) RegisterDirectoryServer(dir.grpcServer, dir) return dir } @@ -301,6 +302,7 @@ func (s *TestDirectoryServer) startTenantLocked( fmt.Sprintf("--sql-addr=%s", sql.Addr().String()), fmt.Sprintf("--http-addr=%s", http.Addr().String()), fmt.Sprintf("--tenant-id=%d", tenantID), + "--insecure", } if err = sql.Close(); err != nil { return nil, err @@ -331,7 +333,7 @@ func (s *TestDirectoryServer) startTenantLocked( _ = c.Process.Kill() s.deregisterInstance(tenantID, process.SQL) })) - err = process.Stopper.RunAsyncTask(ctx, "cmd-wait", func(ctx context.Context) { + err = s.stopper.RunAsyncTask(ctx, "cmd-wait", func(ctx context.Context) { if err := c.Wait(); err != nil { log.Infof(ctx, "finished %s with err %s", process.Cmd.Args, err) log.Infof(ctx, "output %s", b.Bytes()) diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel new file mode 100644 index 000000000000..475c9d728605 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel @@ -0,0 +1,30 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + +go_test( + name = "tenantdirsvr_test", + srcs = [ + "directory_test.go", + "main_test.go", + ], + deps = [ + "//pkg/base", + "//pkg/ccl", + "//pkg/ccl/kvccl/kvtenantccl", + "//pkg/ccl/sqlproxyccl/tenant", + "//pkg/ccl/utilccl", + "//pkg/roachpb", + "//pkg/security", + "//pkg/security/securitytest", + "//pkg/server", + "//pkg/sql", + "//pkg/testutils/serverutils", + "//pkg/testutils/testcluster", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/randutil", + "//pkg/util/stop", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@org_golang_google_grpc//:go_default_library", + ], +) diff --git a/pkg/ccl/sqlproxyccl/tenant/directory_test.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/directory_test.go similarity index 90% rename from pkg/ccl/sqlproxyccl/tenant/directory_test.go rename to pkg/ccl/sqlproxyccl/tenantdirsvr/directory_test.go index 5027281400af..b172e31530e2 100644 --- a/pkg/ccl/sqlproxyccl/tenant/directory_test.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/directory_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package tenant +package tenantdirsvr import ( "context" @@ -16,6 +16,8 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -27,6 +29,9 @@ import ( "google.golang.org/grpc" ) +// To ensure tenant startup code is included. +var _ = kvtenantccl.Connector{} + func TestDirectoryErrors(t *testing.T) { defer leaktest.AfterTest(t)() defer log.ScopeWithoutShowLogs(t).Close(t) @@ -89,7 +94,7 @@ func TestEndpointWatcher(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // Resume tenant again by a direct call to the directory server - _, err = tds.EnsureEndpoint(ctx, &EnsureEndpointRequest{tenantID.ToUint64()}) + _, err = tds.EnsureEndpoint(ctx, &tenant.EnsureEndpointRequest{tenantID.ToUint64()}) require.NoError(t, err) // Wait for background watcher to populate the initial endpoint. @@ -188,7 +193,7 @@ func TestResume(t *testing.T) { }(i) } - var processes map[net.Addr]*Process + var processes map[net.Addr]*tenant.Process // Eventually the tenant process will be resumed. require.Eventually(t, func() bool { processes = tds.Get(tenantID) @@ -212,7 +217,7 @@ func TestDeleteTenant(t *testing.T) { // Create the directory. ctx := context.Background() // Disable throttling for this test - tc, dir, tds := newTestDirectory(t, RefreshDelay(-1)) + tc, dir, tds := newTestDirectory(t, tenant.RefreshDelay(-1)) defer tc.Stopper().Stop(ctx) tenantID := roachpb.MakeTenantID(50) @@ -266,7 +271,7 @@ func TestRefreshThrottling(t *testing.T) { // Create the directory, but with extreme rate limiting so that directory // will never refresh. ctx := context.Background() - tc, dir, _ := newTestDirectory(t, RefreshDelay(60*time.Minute)) + tc, dir, _ := newTestDirectory(t, tenant.RefreshDelay(60*time.Minute)) defer tc.Stopper().Stop(ctx) // Create test tenant. @@ -325,10 +330,10 @@ func destroyTenant(tc serverutils.TestClusterInterface, id roachpb.TenantID) err func startTenant( ctx context.Context, srv serverutils.TestServerInterface, id uint64, -) (*Process, error) { +) (*tenant.Process, error) { log.TestingClearServerIdentifiers() - tenantStopper := NewSubStopper(srv.Stopper()) - tenant, err := srv.StartTenant( + tenantStopper := tenant.NewSubStopper(srv.Stopper()) + t, err := srv.StartTenant( ctx, base.TestTenantArgs{ Existing: true, @@ -339,18 +344,22 @@ func startTenant( if err != nil { return nil, err } - sqlAddr, err := net.ResolveTCPAddr("tcp", tenant.SQLAddr()) + sqlAddr, err := net.ResolveTCPAddr("tcp", t.SQLAddr()) if err != nil { return nil, err } - return &Process{SQL: sqlAddr, Stopper: tenantStopper}, nil + return &tenant.Process{SQL: sqlAddr, Stopper: tenantStopper}, nil } // Setup directory that uses a client connected to a test directory server // that manages tenants connected to a backing KV server. func newTestDirectory( - t *testing.T, opts ...DirOption, -) (tc serverutils.TestClusterInterface, directory *Directory, tds *TestDirectoryServer) { + t *testing.T, opts ...tenant.DirOption, +) ( + tc serverutils.TestClusterInterface, + directory *tenant.Directory, + tds *tenant.TestDirectoryServer, +) { tc = serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{ // We need to start the cluster insecure in order to not // care about TLS settings for the RPC client connection. @@ -359,8 +368,8 @@ func newTestDirectory( }, }) clusterStopper := tc.Stopper() - tds = NewTestDirectoryServer(clusterStopper) - tds.TenantStarterFunc = func(ctx context.Context, tenantID uint64) (*Process, error) { + tds = tenant.NewTestDirectoryServer(clusterStopper) + tds.TenantStarterFunc = func(ctx context.Context, tenantID uint64) (*tenant.Process, error) { t.Logf("starting tenant %d", tenantID) process, err := startTenant(ctx, tc.Server(0), tenantID) if err != nil { @@ -372,7 +381,6 @@ func newTestDirectory( listenPort, err := net.Listen("tcp", ":0") require.NoError(t, err) - clusterStopper.AddCloser(stop.CloserFn(func() { require.NoError(t, listenPort.Close()) })) go func() { _ = tds.Serve(listenPort) }() // Setup directory @@ -381,8 +389,8 @@ func newTestDirectory( require.NoError(t, err) // nolint:grpcconnclose clusterStopper.AddCloser(stop.CloserFn(func() { require.NoError(t, conn.Close() /* nolint:grpcconnclose */) })) - client := NewDirectoryClient(conn) - directory, err = NewDirectory(context.Background(), clusterStopper, client, opts...) + client := tenant.NewDirectoryClient(conn) + directory, err = tenant.NewDirectory(context.Background(), clusterStopper, client, opts...) require.NoError(t, err) return diff --git a/pkg/ccl/sqlproxyccl/tenant/main_test.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/main_test.go similarity index 98% rename from pkg/ccl/sqlproxyccl/tenant/main_test.go rename to pkg/ccl/sqlproxyccl/tenantdirsvr/main_test.go index 3a25676ae8dd..90199cefb604 100644 --- a/pkg/ccl/sqlproxyccl/tenant/main_test.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/main_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package tenant +package tenantdirsvr import ( "os" diff --git a/pkg/ccl/sqlproxyccl/admitter/BUILD.bazel b/pkg/ccl/sqlproxyccl/throttler/BUILD.bazel similarity index 69% rename from pkg/ccl/sqlproxyccl/admitter/BUILD.bazel rename to pkg/ccl/sqlproxyccl/throttler/BUILD.bazel index 2840f7d49664..1647e1188fbe 100644 --- a/pkg/ccl/sqlproxyccl/admitter/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/throttler/BUILD.bazel @@ -41,3 +41,22 @@ gomock( library = ":service", package = "admitter", ) + +go_library( + name = "throttler", + srcs = ["local.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler", + visibility = ["//visibility:public"], + deps = [ + "//pkg/util/log", + "//pkg/util/syncutil", + "@com_github_cockroachdb_errors//:errors", + ], +) + +go_test( + name = "throttler_test", + srcs = ["local_test.go"], + embed = [":throttler"], + deps = ["@com_github_stretchr_testify//require"], +) diff --git a/pkg/ccl/sqlproxyccl/admitter/local.go b/pkg/ccl/sqlproxyccl/throttler/local.go similarity index 96% rename from pkg/ccl/sqlproxyccl/admitter/local.go rename to pkg/ccl/sqlproxyccl/throttler/local.go index 24fb50d381b4..0c57b801a3a0 100644 --- a/pkg/ccl/sqlproxyccl/admitter/local.go +++ b/pkg/ccl/sqlproxyccl/throttler/local.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package admitter +package throttler import ( "context" @@ -37,7 +37,7 @@ type limiter struct { index int } -// localService is an admitter Service that manages state purely in local +// localService is an throttler Service that manages state purely in local // memory. Internally, it maintains a map from IP address to rate limiting info // for that address. In order to put a cap on memory usage, the map is capped // at a maximum size, at which point a random IP address will be evicted. @@ -81,7 +81,7 @@ func WithBaseDelay(d time.Duration) LocalOption { } } -// NewLocalService returns an admitter Service that manages state purely in +// NewLocalService returns an throttler Service that manages state purely in // local memory. func NewLocalService(opts ...LocalOption) Service { s := &localService{ diff --git a/pkg/ccl/sqlproxyccl/admitter/local_test.go b/pkg/ccl/sqlproxyccl/throttler/local_test.go similarity index 99% rename from pkg/ccl/sqlproxyccl/admitter/local_test.go rename to pkg/ccl/sqlproxyccl/throttler/local_test.go index 3caeccd414b9..22bde2acace0 100644 --- a/pkg/ccl/sqlproxyccl/admitter/local_test.go +++ b/pkg/ccl/sqlproxyccl/throttler/local_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package admitter +package throttler import ( "fmt" diff --git a/pkg/ccl/sqlproxyccl/admitter/mocks_generated.go b/pkg/ccl/sqlproxyccl/throttler/mocks_generated.go similarity index 94% rename from pkg/ccl/sqlproxyccl/admitter/mocks_generated.go rename to pkg/ccl/sqlproxyccl/throttler/mocks_generated.go index 65bcc3cff0a8..5499db346db9 100644 --- a/pkg/ccl/sqlproxyccl/admitter/mocks_generated.go +++ b/pkg/ccl/sqlproxyccl/throttler/mocks_generated.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: service.go -// Package admitter is a generated GoMock package. -package admitter +// Package throttler is a generated GoMock package. +package throttler import ( reflect "reflect" diff --git a/pkg/ccl/sqlproxyccl/admitter/service.go b/pkg/ccl/sqlproxyccl/throttler/service.go similarity index 70% rename from pkg/ccl/sqlproxyccl/admitter/service.go rename to pkg/ccl/sqlproxyccl/throttler/service.go index 766ecca28f27..bac2218ac7e9 100644 --- a/pkg/ccl/sqlproxyccl/admitter/service.go +++ b/pkg/ccl/sqlproxyccl/throttler/service.go @@ -6,14 +6,14 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -// Package admitter provides admission checks functionality. Rate limiting currently. -package admitter +// Package throttler provides admission checks functionality. Rate limiting currently. +package throttler import "time" -//go:generate mockgen -package=admitter -destination=mocks_generated.go -source=service.go . Service +//go:generate mockgen -package=throttler -destination=mocks_generated.go -source=service.go . Service -// Service provides the interface for performing admission checks before +// Service provides the interface for performing throttle checks before // allowing requests into the managed service system. type Service interface { // LoginCheck determines whether a login request should be allowed to diff --git a/pkg/cli/BUILD.bazel b/pkg/cli/BUILD.bazel index 20ff9ed85026..e770f7a64a4a 100644 --- a/pkg/cli/BUILD.bazel +++ b/pkg/cli/BUILD.bazel @@ -38,7 +38,9 @@ go_library( "log_flags.go", "mt.go", "mt_cert.go", + "mt_proxy.go", "mt_start_sql.go", + "mt_test_directory.go", "node.go", "nodelocal.go", "quit.go", @@ -91,6 +93,8 @@ go_library( deps = [ "//pkg/base", "//pkg/build", + "//pkg/ccl/sqlproxyccl", + "//pkg/ccl/sqlproxyccl/tenant", "//pkg/cli/cliflags", "//pkg/cli/exit", "//pkg/cli/syncbench", @@ -219,49 +223,9 @@ go_library( "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", "@org_golang_x_sync//errgroup", + "@org_golang_x_sys//unix", "@org_golang_x_time//rate", - ] + select({ - "@io_bazel_rules_go//go/platform:aix": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:android": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:darwin": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:dragonfly": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:freebsd": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:illumos": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:ios": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:js": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:linux": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:netbsd": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:openbsd": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:plan9": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:solaris": [ - "@org_golang_x_sys//unix", - ], - "//conditions:default": [], - }), + ], ) go_test( diff --git a/pkg/cli/cliflags/flags_mt.go b/pkg/cli/cliflags/flags_mt.go index b3bcc9845d54..c8dada8a8578 100644 --- a/pkg/cli/cliflags/flags_mt.go +++ b/pkg/cli/cliflags/flags_mt.go @@ -23,4 +23,74 @@ var ( EnvVar: "COCKROACH_KV_ADDRS", Description: `A comma-separated list of KV endpoints (load balancers allowed).`, } + + DenyList = FlagInfo{ + Name: "denylist-file", + Description: "Denylist file to limit access to IP addresses and tenant ids.", + } + + ProxyListenAddr = FlagInfo{ + Name: "listen-addr", + Description: "Listen address for incoming connections.", + } + + ListenCert = FlagInfo{ + Name: "listen-cert", + Description: "File containing PEM-encoded x509 certificate for listen address.", + } + + ListenKey = FlagInfo{ + Name: "listen-key", + Description: "File containing PEM-encoded x509 key for listen address.", + } + + ListenMetrics = FlagInfo{ + Name: "listen-metrics", + Description: "Listen address for incoming connections for metrics retrieval.", + } + + RoutingRule = FlagInfo{ + Name: "routing-rule", + Description: "Routing rule for incoming connections. Use '{{clusterName}}' for substitution.", + } + + DirectoryAddr = FlagInfo{ + Name: "directory", + Description: "Directory address of the service doing resolution from backend id to IP.", + } + + SkipVerify = FlagInfo{ + Name: "skip-verify", + Description: "If true, skip identity verification of backend. For testing only.", + } + + InsecureBackend = FlagInfo{ + Name: "insecure", + Description: "If true, use insecure connection to the backend.", + } + + RatelimitBaseDelay = FlagInfo{ + Name: "ratelimit-base-delay", + Description: "Initial backoff after a failed login attempt. Set to 0 to disable rate limiting.", + } + + ValidateAccessInterval = FlagInfo{ + Name: "validate-access-interval", + Description: "Time interval between validation that current connections are still valid.", + } + + PollConfigInterval = FlagInfo{ + Name: "poll-config-interval", + Description: "Polling interval changes in config file.", + } + + IdleTimeout = FlagInfo{ + Name: "idle-timeout", + Description: "Close connections idle for this duration.", + } + + TestDirectoryListenPort = FlagInfo{ + Name: "port", + Description: "Test directory server binds and listens on this port.", + } ) diff --git a/pkg/cli/context.go b/pkg/cli/context.go index 33983521d495..aa8f622dbee0 100644 --- a/pkg/cli/context.go +++ b/pkg/cli/context.go @@ -18,6 +18,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" "github.com/cockroachdb/cockroach/pkg/config/zonepb" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/settings" @@ -52,6 +53,8 @@ func initCLIDefaults() { setStmtDiagContextDefaults() setAuthContextDefaults() setImportContextDefaults() + setProxyContextDefaults() + setTestDirectorySvrContextDefaults() initPreFlagsDefaults() @@ -615,6 +618,33 @@ func setImportContextDefaults() { importCtx.rowLimit = 0 } +// proxyContext captures the command-line parameters of the `mt start-proxy` command. +var proxyContext sqlproxyccl.ProxyOptions + +func setProxyContextDefaults() { + proxyContext.Denylist = "" + proxyContext.ListenAddr = "127.0.0.1:46257" + proxyContext.ListenCert = "" + proxyContext.ListenKey = "" + proxyContext.MetricsAddress = "0.0.0.0:8080" + proxyContext.RoutingRule = "" + proxyContext.DirectoryAddr = "" + proxyContext.SkipVerify = false + proxyContext.Insecure = false + proxyContext.RatelimitBaseDelay = 50 * time.Millisecond + proxyContext.ValidateAccessInterval = 30 * time.Second + proxyContext.PollConfigInterval = 30 * time.Second + proxyContext.IdleTimeout = 0 +} + +var testDirectorySvrContext struct { + port int +} + +func setTestDirectorySvrContextDefaults() { + testDirectorySvrContext.port = 36257 +} + // GetServerCfgStores provides direct public access to the StoreSpecList inside // serverCfg. This is used by CCL code to populate some fields. // diff --git a/pkg/cli/flags.go b/pkg/cli/flags.go index c4d5c47b534d..d10e5bfb97c4 100644 --- a/pkg/cli/flags.go +++ b/pkg/cli/flags.go @@ -945,6 +945,28 @@ func init() { boolFlag(f, &serverCfg.ExternalIODirConfig.DisableImplicitCredentials, cliflags.ExternalIODisableImplicitCredentials) } + // Multi-tenancy proxy command flags. + { + f := mtStartSQLProxyCmd.Flags() + stringFlag(f, &proxyContext.Denylist, cliflags.DenyList) + stringFlag(f, &proxyContext.ListenAddr, cliflags.ProxyListenAddr) + stringFlag(f, &proxyContext.ListenCert, cliflags.ListenCert) + stringFlag(f, &proxyContext.ListenKey, cliflags.ListenKey) + stringFlag(f, &proxyContext.MetricsAddress, cliflags.ListenMetrics) + stringFlag(f, &proxyContext.RoutingRule, cliflags.RoutingRule) + stringFlag(f, &proxyContext.DirectoryAddr, cliflags.DirectoryAddr) + boolFlag(f, &proxyContext.SkipVerify, cliflags.SkipVerify) + boolFlag(f, &proxyContext.Insecure, cliflags.InsecureBackend) + durationFlag(f, &proxyContext.RatelimitBaseDelay, cliflags.RatelimitBaseDelay) + durationFlag(f, &proxyContext.ValidateAccessInterval, cliflags.ValidateAccessInterval) + durationFlag(f, &proxyContext.PollConfigInterval, cliflags.PollConfigInterval) + durationFlag(f, &proxyContext.IdleTimeout, cliflags.IdleTimeout) + } + // Multi-tenancy test directory command flags. + { + f := mtTestDirectorySvr.Flags() + intFlag(f, &testDirectorySvrContext.port, cliflags.TestDirectoryListenPort) + } } type tenantIDWrapper struct { diff --git a/pkg/cli/mt.go b/pkg/cli/mt.go index 6144e7a96b0b..2d0cd67c8532 100644 --- a/pkg/cli/mt.go +++ b/pkg/cli/mt.go @@ -12,14 +12,11 @@ package cli import "github.com/spf13/cobra" -// AddMTCommand adds a subcommand to `./cockroach mt`. -func AddMTCommand(cmd *cobra.Command) { - mtCmd.AddCommand(cmd) -} - func init() { cockroachCmd.AddCommand(mtCmd) mtCmd.AddCommand(mtStartSQLCmd) + mtCmd.AddCommand(mtStartSQLProxyCmd) + mtCmd.AddCommand(mtTestDirectorySvr) mtCertsCmd.AddCommand( mtCreateTenantClientCACertCmd, diff --git a/pkg/cli/mt_proxy.go b/pkg/cli/mt_proxy.go new file mode 100644 index 000000000000..4be4b500c72c --- /dev/null +++ b/pkg/cli/mt_proxy.go @@ -0,0 +1,172 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package cli + +import ( + "context" + "fmt" + "net" + "os" + "os/signal" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/log/severity" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/errors" + "github.com/spf13/cobra" + "golang.org/x/sys/unix" +) + +var mtStartSQLProxyCmd = &cobra.Command{ + Use: "start-proxy", + Short: "start a sql proxy", + Long: `Starts a SQL proxy. + +This proxy accepts incoming connections and relays them to a backend server +determined by the arguments used. +`, + RunE: MaybeDecorateGRPCError(runStartSQLProxy), + Args: cobra.NoArgs, +} + +func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) { + // Initialize logging, stopper and context that can be canceled + ctx, stopper, err := initLogging(cmd) + if err != nil { + return err + } + defer stopper.Stop(ctx) + + log.Infof(ctx, "New proxy with opts: %+v", proxyContext) + + proxyLn, err := net.Listen("tcp", proxyContext.ListenAddr) + if err != nil { + return err + } + stopper.AddCloser(stop.CloserFn(func() { _ = proxyLn.Close() })) + + metricsLn, err := net.Listen("tcp", proxyContext.MetricsAddress) + if err != nil { + return err + } + stopper.AddCloser(stop.CloserFn(func() { _ = metricsLn.Close() })) + + server, err := sqlproxyccl.NewServer(ctx, stopper, proxyContext) + if err != nil { + return err + } + + errChan := make(chan error, 1) + + if err := stopper.RunAsyncTask(ctx, "serve-http", func(ctx context.Context) { + log.Infof(ctx, "HTTP metrics server listening at %s", metricsLn.Addr()) + if err := server.ServeHTTP(ctx, metricsLn); err != nil { + errChan <- err + } + }); err != nil { + return err + } + + if err := stopper.RunAsyncTask(ctx, "serve-proxy", func(ctx context.Context) { + log.Infof(ctx, "proxy server listening at %s", proxyLn.Addr()) + if err := server.Serve(ctx, proxyLn); err != nil { + errChan <- err + } + }); err != nil { + return err + } + + return waitForSignals(ctx, stopper, errChan) +} + +func initLogging(cmd *cobra.Command) (ctx context.Context, stopper *stop.Stopper, err error) { + // Remove the default store, which avoids using it to set up logging. + // Instead, we'll default to logging to stderr unless --log-dir is + // specified. This makes sense since the standalone SQL server is + // at the time of writing stateless and may not be provisioned with + // suitable storage. + serverCfg.Stores.Specs = nil + serverCfg.ClusterName = "" + + ctx = context.Background() + stopper, err = setupAndInitializeLoggingAndProfiling(ctx, cmd, false /* isServerCmd */) + if err != nil { + return + } + ctx, _ = stopper.WithCancelOnQuiesce(ctx) + return ctx, stopper, err +} + +func waitForSignals( + ctx context.Context, stopper *stop.Stopper, errChan chan error, +) (returnErr error) { + // Need to alias the signals if this has to run on non-unix OSes too. + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, unix.SIGINT, unix.SIGTERM) + + // Dump the stacks when QUIT is received. + quitSignalCh := make(chan os.Signal, 1) + signal.Notify(quitSignalCh, unix.SIGQUIT) + go func() { + for { + <-quitSignalCh + log.DumpStacks(context.Background()) + } + }() + + select { + case err := <-errChan: + log.StartAlwaysFlush() + return err + case <-stopper.ShouldQuiesce(): + // Stop has been requested through the stopper's Stop + <-stopper.IsStopped() + // StartAlwaysFlush both flushes and ensures that subsequent log + // writes are flushed too. + log.StartAlwaysFlush() + case sig := <-signalCh: // INT or TERM + log.StartAlwaysFlush() // In case the caller follows up with KILL + log.Ops.Infof(ctx, "received signal '%s'", sig) + if sig == os.Interrupt { + returnErr = errors.New("interrupted") + } + go func() { + log.Infof(ctx, "server stopping") + stopper.Stop(ctx) + }() + case <-log.FatalChan(): + stopper.Stop(ctx) + select {} // Block and wait for logging go routine to shut down the process + } + + for { + select { + case sig := <-signalCh: + switch sig { + case os.Interrupt: // SIGTERM after SIGTERM + log.Ops.Infof(ctx, "received additional signal '%s'; continuing graceful shutdown", sig) + continue + } + + log.Ops.Shoutf(ctx, severity.ERROR, + "received signal '%s' during shutdown, initiating hard shutdown", log.Safe(sig)) + panic("terminate") + case <-stopper.IsStopped(): + const msgDone = "server shutdown completed" + log.Ops.Infof(ctx, msgDone) + fmt.Fprintln(os.Stdout, msgDone) + } + break + } + + return returnErr +} diff --git a/pkg/cli/mt_start_sql.go b/pkg/cli/mt_start_sql.go index d7dd8c83e39b..9bc1e57b6729 100644 --- a/pkg/cli/mt_start_sql.go +++ b/pkg/cli/mt_start_sql.go @@ -136,6 +136,10 @@ func runStartSQL(cmd *cobra.Command, args []string) error { ch := make(chan os.Signal, 1) signal.Notify(ch, drainSignals...) - sig := <-ch - return errors.Newf("received signal %v", sig) + select { + case sig := <-ch: + return errors.Newf("received signal %v", sig) + case <-stopper.ShouldQuiesce(): + return nil + } } diff --git a/pkg/cli/mt_test_directory.go b/pkg/cli/mt_test_directory.go new file mode 100644 index 000000000000..09e5f8541340 --- /dev/null +++ b/pkg/cli/mt_test_directory.go @@ -0,0 +1,53 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package cli + +import ( + "context" + "fmt" + "net" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/spf13/cobra" +) + +var mtTestDirectorySvr = &cobra.Command{ + Use: "test-directory", + Short: "Run a test directory service.", + Long: ` +Run a test directory service. +`, + Args: cobra.NoArgs, + RunE: MaybeDecorateGRPCError(runDirectorySvr), +} + +func runDirectorySvr(cmd *cobra.Command, args []string) (returnErr error) { + ctx := context.Background() + serverCfg.Stores.Specs = nil + + stopper, err := setupAndInitializeLoggingAndProfiling(ctx, cmd, false /* isServerCmd */) + if err != nil { + return err + } + defer stopper.Stop(ctx) + + tds := tenant.NewTestDirectoryServer(stopper) + + listenPort, err := net.Listen( + "tcp", fmt.Sprintf(":%d", testDirectorySvrContext.port), + ) + if err != nil { + return err + } + stopper.AddCloser(stop.CloserFn(func() { _ = listenPort.Close() })) + return tds.Serve(listenPort) +} diff --git a/pkg/cli/start.go b/pkg/cli/start.go index af46bc511aa5..a669d4064e80 100644 --- a/pkg/cli/start.go +++ b/pkg/cli/start.go @@ -118,7 +118,9 @@ var serverCmds = append(StartCmds, mtStartSQLCmd) // customLoggingSetupCmds lists the commands that call setupLogging() // after other types of configuration. -var customLoggingSetupCmds = append(serverCmds, debugCheckLogConfigCmd, demoCmd) +var customLoggingSetupCmds = append( + serverCmds, debugCheckLogConfigCmd, demoCmd, mtStartSQLProxyCmd, mtTestDirectorySvr, +) func initBlockProfile() { // Enable the block profile for a sample of mutex and channel operations. diff --git a/pkg/security/certmgr/BUILD.bazel b/pkg/security/certmgr/BUILD.bazel index a236e350159c..207dfba183b3 100644 --- a/pkg/security/certmgr/BUILD.bazel +++ b/pkg/security/certmgr/BUILD.bazel @@ -15,6 +15,7 @@ go_library( srcs = [ "cert_manager.go", "file_cert.go", + "self_signed_cert.go", ":mocks_certmgr", # keep ], embed = [":certlib"], # keep @@ -25,6 +26,7 @@ go_library( "//pkg/util/log/eventpb", "//pkg/util/syncutil", "//pkg/util/sysutil", + "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", "@com_github_golang_mock//gomock", # keep ], @@ -35,6 +37,7 @@ go_test( srcs = [ "cert_manager_test.go", "file_cert_test.go", + "self_signed_cert_test.go", ], embed = [":certmgr"], deps = [ diff --git a/pkg/security/certmgr/cert.go b/pkg/security/certmgr/cert.go index f39d04fb4ec7..8baa3e394e0a 100644 --- a/pkg/security/certmgr/cert.go +++ b/pkg/security/certmgr/cert.go @@ -10,7 +10,10 @@ package certmgr -import "context" +import ( + "context" + "crypto/tls" +) //go:generate mockgen -package=certmgr -destination=mocks_generated.go -source=cert.go . Cert @@ -43,4 +46,6 @@ type Cert interface { Err() error // ClearErr will clear the last reported err and allow reloads to resume. ClearErr() + // TLSCert will return the last loaded tls.Certificate + TLSCert() *tls.Certificate } diff --git a/pkg/security/certmgr/file_cert_test.go b/pkg/security/certmgr/file_cert_test.go index 2ee58858d4a3..24f4ac66f43c 100644 --- a/pkg/security/certmgr/file_cert_test.go +++ b/pkg/security/certmgr/file_cert_test.go @@ -58,7 +58,6 @@ func TestFileCert_Err(t *testing.T) { require.Nil(t, fc.Err()) fc.Reload(context.Background()) require.NotNil(t, fc.Err()) - require.NotNil(t, fc.Err()) fc.ClearErr() require.Nil(t, fc.Err()) } diff --git a/pkg/security/certmgr/mocks_generated.go b/pkg/security/certmgr/mocks_generated.go index 013d7ef63c24..097b2ba77bca 100644 --- a/pkg/security/certmgr/mocks_generated.go +++ b/pkg/security/certmgr/mocks_generated.go @@ -6,6 +6,7 @@ package certmgr import ( context "context" + tls "crypto/tls" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -71,3 +72,17 @@ func (mr *MockCertMockRecorder) Reload(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reload", reflect.TypeOf((*MockCert)(nil).Reload), ctx) } + +// TLSCert mocks base method. +func (m *MockCert) TLSCert() *tls.Certificate { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TLSCert") + ret0, _ := ret[0].(*tls.Certificate) + return ret0 +} + +// TLSCert indicates an expected call of TLSCert. +func (mr *MockCertMockRecorder) TLSCert() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TLSCert", reflect.TypeOf((*MockCert)(nil).TLSCert)) +} diff --git a/pkg/security/certmgr/self_signed_cert.go b/pkg/security/certmgr/self_signed_cert.go new file mode 100644 index 000000000000..0ba941f8f45e --- /dev/null +++ b/pkg/security/certmgr/self_signed_cert.go @@ -0,0 +1,103 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package certmgr + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "math/big" + "time" + + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" +) + +// Ensure that SelfSignedCert implements Cert. +var _ Cert = (*SelfSignedCert)(nil) + +// SelfSignedCert represents a single, self-signed certificate, generated at runtime. +type SelfSignedCert struct { + syncutil.Mutex + years, months, days int + secs time.Duration + err error + cert *tls.Certificate +} + +// NewSelfSignedCert will generate a new self-signed cert. +// A follow up Reload will regenerate the cert. +func NewSelfSignedCert(years, months, days int, secs time.Duration) *SelfSignedCert { + return &SelfSignedCert{years: years, months: months, days: days, secs: secs} +} + +// Reload will regenerate the self-signed cert. +func (ssc *SelfSignedCert) Reload(ctx context.Context) { + ssc.Lock() + defer ssc.Unlock() + + // There was a previous error that is not yet retrieved. + if ssc.err != nil { + return + } + + // Generate self signed cert for testing. + // Use "openssl s_client -showcerts -starttls postgres -connect {HOSTNAME}:{PORT}" to + // inspect the certificate or save the certificate portion to a file and + // use sslmode=verify-full + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + ssc.err = errors.Wrapf(err, "could not generate key") + return + } + + from := timeutil.Now() + until := from.AddDate(ssc.years, ssc.months, ssc.days).Add(ssc.secs) + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: from, + NotAfter: until, + DNSNames: []string{"localhost"}, + IsCA: true, + } + cer, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv) + if err != nil { + ssc.err = errors.Wrapf(err, "could not create certificate") + } + + ssc.cert = &tls.Certificate{ + Certificate: [][]byte{cer}, + PrivateKey: priv, + } +} + +// Err will return the last error that occurred during reload or nil if the +// last reload was successful. +func (ssc *SelfSignedCert) Err() error { + ssc.Lock() + defer ssc.Unlock() + return ssc.err +} + +// ClearErr will clear the last err so the follow up Reload can execute. +func (ssc *SelfSignedCert) ClearErr() { + ssc.Lock() + defer ssc.Unlock() + ssc.err = nil +} + +// TLSCert returns the tls certificate if the cert generation was successful. +func (ssc *SelfSignedCert) TLSCert() *tls.Certificate { + return ssc.cert +} diff --git a/pkg/security/certmgr/self_signed_cert_test.go b/pkg/security/certmgr/self_signed_cert_test.go new file mode 100644 index 000000000000..9a75b58a80be --- /dev/null +++ b/pkg/security/certmgr/self_signed_cert_test.go @@ -0,0 +1,44 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package certmgr + +import ( + "context" + "crypto/x509" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSelfSignedCert_Err(t *testing.T) { + ssc := NewSelfSignedCert(-9999, 0, 0, 0) + require.NotNil(t, ssc) + require.Nil(t, ssc.Err()) + ssc.Reload(context.Background()) + require.Regexp(t, "cannot represent time as GeneralizedTime", ssc.Err()) + ssc.ClearErr() + require.Nil(t, ssc.Err()) +} + +func TestSelfSignedCert_TLSCert(t *testing.T) { + ssc := NewSelfSignedCert(1, 6, 3, 5*time.Hour) + require.NotNil(t, ssc) + require.Nil(t, ssc.Err()) + ssc.Reload(context.Background()) + require.Nil(t, ssc.Err()) + require.NotNil(t, ssc.TLSCert()) + require.Len(t, ssc.TLSCert().Certificate, 1) + cert, err := x509.ParseCertificate(ssc.TLSCert().Certificate[0]) + require.NoError(t, err) + expectedUntil := cert.NotBefore.AddDate(1, 6, 3).Add(5 * time.Hour) + require.Equal(t, expectedUntil, cert.NotAfter) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 2ba695f17f0c..aa3a975c25bf 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1189,7 +1189,7 @@ func (s *Server) PreStart(ctx context.Context) error { filtered := s.cfg.FilterGossipBootstrapResolvers(ctx) // Set up the init server. We have to do this relatively early because we - // can't call RegisterInitServer() after `grpc.Serve`, which is called in + // can't call RegisterInitServer() after `grpc.serve`, which is called in // startRPCServer (and for the loopback grpc-gw connection). var initServer *initServer { @@ -1735,7 +1735,7 @@ func (s *Server) PreStart(ctx context.Context) error { } s.oidc = oidc - // Serve UI assets. + // serve UI assets. // // The authentication mux used here is created in "allow anonymous" mode so that the UI // assets are served up whether or not there is a session. If there is a session, the mux @@ -2121,7 +2121,7 @@ func (s *Server) startListenRPCAndSQL( // initialize) before we accept RPC requests. The caller // (Server.Start) will call this at the right moment. startRPCServer = func(ctx context.Context) { - // Serve the gRPC endpoint. + // serve the gRPC endpoint. _ = s.stopper.RunAsyncTask(workersCtx, "serve-grpc", func(context.Context) { netutil.FatalIfUnexpected(s.grpc.Serve(anyL)) }) @@ -2174,7 +2174,7 @@ func (s *Server) startServeUI( return err } - // Serve the plain HTTP (non-TLS) connection over clearL. + // serve the plain HTTP (non-TLS) connection over clearL. // This produces a HTTP redirect to the `https` URL for the path /, // handles the request normally (via s.ServeHTTP) for the path /health, // and produces 404 for anything else. @@ -2195,7 +2195,7 @@ func (s *Server) startServeUI( httpLn = tls.NewListener(tlsL, uiTLSConfig) } - // Serve the HTTP endpoint. This will be the original httpLn + // serve the HTTP endpoint. This will be the original httpLn // listening on --http-addr without TLS if uiTLSConfig was // nil, or overridden above if uiTLSConfig was not nil to come from // the TLS negotiation over the HTTP port. diff --git a/pkg/server/status.go b/pkg/server/status.go index 67553c2ec7c4..96c1b35ea9f2 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -894,7 +894,7 @@ func (s *statusServer) GetFiles( var dir string switch req.Type { - //TODO(ridwanmsharif): Serve logfiles so debug-zip can fetch them + //TODO(ridwanmsharif): serve logfiles so debug-zip can fetch them // intead of reading indididual entries. case serverpb.FileType_HEAP: // Requesting for saved Heap Profiles. dir = s.admin.server.cfg.HeapProfileDirName diff --git a/pkg/testutils/BUILD.bazel b/pkg/testutils/BUILD.bazel index 99e6c098250d..26e956744c69 100644 --- a/pkg/testutils/BUILD.bazel +++ b/pkg/testutils/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "dir.go", "error.go", "files.go", + "hook.go", "keys.go", "net.go", "pprof.go", diff --git a/pkg/testutils/hook.go b/pkg/testutils/hook.go new file mode 100644 index 000000000000..340bbb4de7a0 --- /dev/null +++ b/pkg/testutils/hook.go @@ -0,0 +1,24 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package testutils + +import "reflect" + +// HookGlobal sets `*ptr = val` and returns a closure for restoring `*ptr` to +// its original value. A runtime panic will occur if `val` is not assignable to +// `*ptr`. +func HookGlobal(ptr, val interface{}) func() { + global := reflect.ValueOf(ptr).Elem() + orig := reflect.New(global.Type()).Elem() + orig.Set(global) + global.Set(reflect.ValueOf(val)) + return func() { global.Set(orig) } +} diff --git a/pkg/testutils/lint/passes/fmtsafe/functions.go b/pkg/testutils/lint/passes/fmtsafe/functions.go index 286b1863b528..3f5c5e395e1d 100644 --- a/pkg/testutils/lint/passes/fmtsafe/functions.go +++ b/pkg/testutils/lint/passes/fmtsafe/functions.go @@ -200,7 +200,7 @@ var requireConstFmt = map[string]bool{ "github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate.inputErrorf": true, - "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl.NewErrorf": true, + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl.newErrorf": true, } func init() { diff --git a/pkg/util/grpcutil/grpc_log.go b/pkg/util/grpcutil/grpc_log.go index ab5096d37368..10a6cca35cf9 100644 --- a/pkg/util/grpcutil/grpc_log.go +++ b/pkg/util/grpcutil/grpc_log.go @@ -231,7 +231,7 @@ const ( // When a TCP probe simply opens and immediately closes the // connection, gRPC is unhappy that the TLS handshake did not // complete. We don't care. - incomingConnSpamReSrc = `^grpc: Server\.Serve failed to complete security handshake from "[^"]*": EOF` + incomingConnSpamReSrc = `^grpc: Server\.serve failed to complete security handshake from "[^"]*": EOF` ) var outgoingConnSpamRe = regexp.MustCompile(outgoingConnSpamReSrc) diff --git a/pkg/util/log/fluent_client_test.go b/pkg/util/log/fluent_client_test.go index f4feb056f34b..aea1165f6683 100644 --- a/pkg/util/log/fluent_client_test.go +++ b/pkg/util/log/fluent_client_test.go @@ -109,7 +109,7 @@ func servePseudoFluent(t *testing.T) (serverAddr string, cleanup func(), fluentD } t.Logf("got client: %v", conn.RemoteAddr()) - // Serve the connection. + // serve the connection. wg.Add(1) go func() { defer wg.Done() diff --git a/pkg/util/netutil/net.go b/pkg/util/netutil/net.go index ffdfbb09323d..1981c7f1d95c 100644 --- a/pkg/util/netutil/net.go +++ b/pkg/util/netutil/net.go @@ -72,7 +72,7 @@ type Server struct { // // It can serve two different purposes simultaneously: // -// - to serve as actual HTTP server, using the .Serve(net.Listener) method. +// - to serve as actual HTTP server, using the .serve(net.Listener) method. // - to serve as plain TCP server, using the .ServeWith(...) method. // // The latter is used e.g. to accept SQL client connections. @@ -103,7 +103,7 @@ func MakeServer(stopper *stop.Stopper, tlsConfig *tls.Config, handler http.Handl ctx := context.TODO() - // net/http.(*Server).Serve/http2.ConfigureServer are not thread safe with + // net/http.(*Server).serve/http2.ConfigureServer are not thread safe with // respect to net/http.(*Server).TLSConfig, so we call it synchronously here. if err := http2.ConfigureServer(server.Server, nil); err != nil { log.Fatalf(ctx, "%v", err) @@ -129,7 +129,7 @@ func MakeServer(stopper *stop.Stopper, tlsConfig *tls.Config, handler http.Handl func (s *Server) ServeWith( ctx context.Context, stopper *stop.Stopper, l net.Listener, serveConn func(net.Conn), ) error { - // Inspired by net/http.(*Server).Serve + // Inspired by net/http.(*Server).serve var tempDelay time.Duration // how long to sleep on accept failure for { rw, e := l.Accept() @@ -152,7 +152,7 @@ func (s *Server) ServeWith( tempDelay = 0 go func() { defer stopper.Recover(ctx) - s.Server.ConnState(rw, http.StateNew) // before Serve can return + s.Server.ConnState(rw, http.StateNew) // before serve can return serveConn(rw) s.Server.ConnState(rw, http.StateClosed) }()