diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index aa150a844f99..8e3b58d7874d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -45,6 +45,9 @@ /pkg/cli/import_test.go @cockroachdb/cli-prs @cockroachdb/bulk-prs /pkg/cli/sql*.go @cockroachdb/cli-prs @cockroachdb/sql-experience /pkg/cli/start*.go @cockroachdb/cli-prs @cockroachdb/server-prs +/pkg/cli/mt_proxy.go @cockroachdb/sqlproxy-prs @cockroachdb/server-prs +/pkg/cli/mt_start_sql.go @cockroachdb/sqlproxy-prs @cockroachdb/server-prs +/pkg/cli/mt_test_directory.go @cockroachdb/sqlproxy-prs @cockroachdb/server-prs /pkg/cli/connect*.go @cockroachdb/cli-prs @cockroachdb/server-prs /pkg/cli/init.go @cockroachdb/cli-prs @cockroachdb/server-prs diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 3b0d5fc60377..4f12e44b81d7 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -23,10 +23,10 @@ ALL_TESTS = [ "//pkg/ccl/oidcccl:oidcccl_test", "//pkg/ccl/partitionccl:partitionccl_test", "//pkg/ccl/serverccl:serverccl_test", - "//pkg/ccl/sqlproxyccl/admitter:admitter_test", "//pkg/ccl/sqlproxyccl/cache:cache_test", "//pkg/ccl/sqlproxyccl/denylist:denylist_test", - "//pkg/ccl/sqlproxyccl/tenant:tenant_test", + "//pkg/ccl/sqlproxyccl/tenantdirsvr:tenantdirsvr_test", + "//pkg/ccl/sqlproxyccl/throttler:admitter_test", "//pkg/ccl/sqlproxyccl:sqlproxyccl_test", "//pkg/ccl/storageccl/engineccl:engineccl_test", "//pkg/ccl/storageccl:storageccl_test", diff --git a/pkg/ccl/cliccl/BUILD.bazel b/pkg/ccl/cliccl/BUILD.bazel index ac7b4850565f..840b02e1b104 100644 --- a/pkg/ccl/cliccl/BUILD.bazel +++ b/pkg/ccl/cliccl/BUILD.bazel @@ -7,7 +7,6 @@ go_library( "debug.go", "debug_backup.go", "demo.go", - "mtproxy.go", "start.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/ccl/cliccl", @@ -19,7 +18,6 @@ go_library( "//pkg/ccl/backupccl", "//pkg/ccl/baseccl", "//pkg/ccl/cliccl/cliflagsccl", - "//pkg/ccl/sqlproxyccl", "//pkg/ccl/storageccl", "//pkg/ccl/storageccl/engineccl/enginepbccl:enginepbccl_go_proto", 
"//pkg/ccl/workloadccl/cliccl", @@ -53,12 +51,9 @@ go_library( "//pkg/util/timeutil/pgdate", "//pkg/util/uuid", "@com_github_cockroachdb_apd_v2//:apd", - "@com_github_cockroachdb_cmux//:cmux", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_errors//oserror", - "@com_github_jackc_pgproto3_v2//:pgproto3", "@com_github_spf13_cobra//:cobra", - "@org_golang_x_sync//errgroup", ], ) diff --git a/pkg/ccl/cliccl/mtproxy.go b/pkg/ccl/cliccl/mtproxy.go deleted file mode 100644 index e297f3188035..000000000000 --- a/pkg/ccl/cliccl/mtproxy.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2020 The Cockroach Authors. -// -// Licensed as a CockroachDB Enterprise file under the Cockroach Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt - -package cliccl - -import ( - "context" - "crypto/tls" - "io/ioutil" - "net" - "strings" - - "github.com/cockroachdb/cmux" - "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" - "github.com/cockroachdb/cockroach/pkg/cli" - "github.com/cockroachdb/errors" - "github.com/jackc/pgproto3/v2" - "github.com/spf13/cobra" - "golang.org/x/sync/errgroup" -) - -var sqlProxyListenAddr, sqlProxyTargetAddr string -var sqlProxyListenCert, sqlProxyListenKey string - -func init() { - startSQLProxyCmd := &cobra.Command{ - Use: "start-proxy ", - Short: "start-proxy host:port", - Long: `Starts a SQL proxy for testing purposes. - -This proxy provides very limited functionality. It accepts incoming connections -and relays them to the specified backend server after verifying that at least -one of the following holds: - -1. the supplied database name is prefixed with 'prancing-pony.'; this prefix - will then be removed for the connection to the backend server, and/or -2. the options parameter is 'prancing-pony'. 
- -Connections to the target address use TLS but do not identify the identity of -the peer, making them susceptible to MITM attacks. -`, - RunE: cli.MaybeDecorateGRPCError(runStartSQLProxy), - Args: cobra.NoArgs, - } - f := startSQLProxyCmd.Flags() - f.StringVarP(&sqlProxyListenCert, "listen-cert", "", "", "Certificate file to use for listener (auto-generate if empty)") - f.StringVarP(&sqlProxyListenKey, "listen-key", "", "", "Private key file to use for listener(auto-generate if empty)") - f.StringVarP(&sqlProxyListenAddr, "listen-addr", "", "127.0.0.1:46257", "Address for incoming connections") - f.StringVarP(&sqlProxyTargetAddr, "target-addr", "", "127.0.0.1:26257", "Address for outgoing connections") - cli.AddMTCommand(startSQLProxyCmd) -} - -func runStartSQLProxy(*cobra.Command, []string) error { - // openssl genrsa -out testserver.key 2048 - // openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt -days 3650 - // ^-- Enter * as Common Name below, rest can be empty. 
- certBytes := []byte(`-----BEGIN CERTIFICATE----- -MIICpDCCAYwCCQDWdkou+YTT/DANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls -b2NhbGhvc3QwHhcNMjAwNTIwMTQxMjIyWhcNMzAwNTE4MTQxMjIyWjAUMRIwEAYD -VQQDDAlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDG -H6V5TZjppRR61azexRtKLOnftVpO0+CPslHynbWwrJ6sxZIdglUWoCT3a/93tq0h -SaMWIxH+29wDXiICCTbr485h3Sov5Rq7kV/AwcLMOpdqjbN2PRBW95aq8rV3h/Ui -K3hu8OZjeh4DzhxsDWYwLG+1aUHnzpwDVIvXqiiKHVtT3WLDHRNUAuph9o/4Fao0 -m1KAzXvfbnNyMUWTAOUCIX2tlq79rEIAKOySCKDr07TuVrzCKcF5sbXkFXlmyFNl -KbmXRuD3UxghxLMmUZar7eZR84x6R/Rj5Dqyrs3nfl+/30Zk0pNe6naaKO39zqlR -rWQIqwSZrY1HwGGeVJFjAgMBAAEwDQYJKoZIhvcNAQELBQADggEBACluo7vP0kXd -uXD3joPKiMJ0FgZqeDtuSvBPfl0okqPN+bk/Huqu+FgxfChCs+2EcreGFxshjzuv -J58ogFq1YMB4pS4GlqarHE+UdliOobD+OyvX40w9lTJ2wI+v7kI79udFE+tyLIs6 -YkuzFd1nB0Zcf8QFzyPRTVXVpsWid3ZvARDakp4z7klPLnkfVrXo/ivlKqGF+Ymy -vJ/riLR01omTVi6W40cml/H4DAtG/XVsQeFXWpjUv97MWGRVYycmpCleVkK+uC2x -XAEi/UMoPhhJd6HEWG+56IkFFoN4lNtPuyal0vzOJCn70pgQx3yKh61RQcPrJMlD -m9qz1xbrzj8= ------END CERTIFICATE----- -`) - keyBytes := []byte(`-----BEGIN RSA PRIVATE KEY----- -MIIEowIBAAKCAQEAxh+leU2Y6aUUetWs3sUbSizp37VaTtPgj7JR8p21sKyerMWS -HYJVFqAk92v/d7atIUmjFiMR/tvcA14iAgk26+POYd0qL+Uau5FfwMHCzDqXao2z -dj0QVveWqvK1d4f1Iit4bvDmY3oeA84cbA1mMCxvtWlB586cA1SL16ooih1bU91i -wx0TVALqYfaP+BWqNJtSgM17325zcjFFkwDlAiF9rZau/axCACjskgig69O07la8 -winBebG15BV5ZshTZSm5l0bg91MYIcSzJlGWq+3mUfOMekf0Y+Q6sq7N535fv99G -ZNKTXup2mijt/c6pUa1kCKsEma2NR8BhnlSRYwIDAQABAoIBADFpiSKUyNNU2aO9 -EO1KaYD5bKbfmxNX4oTUK33/+WWD19stN0Dm1YPcEvwmUkOwKsPHksYdnwpaGSg5 -3O93DtyMJ1ffCfuB/0XSfvgbGxNGdacciiquFhoqi8g82idioC+SeenpaPxcY4n9 -aLdGLDtNidrL0qUWsXBfMLVr+cpgENPMiri31CGLNpfO1b4icdQjiltEn70To2Al -68Ptar60m/lJzf8QMFSf499/W3b7fLjGFK+Gzump94xAVMd7HhACf42ZpWRPe1Ih -lyHP6D0091cIRhGxIZrhLToSuySpf1A+C/rQYTqzEPv/a3b4Ja6poulpBppwJyDa -roC4KtkCgYEA/h0HRzekvNAg2cOV/Ag1RyE4AmdyCDubBHcPNJi9kI/RsA0RgurO -pr2oET0HWTENgE4e4hYQnlqUvTXYisvtvhigiCkcynpGoMJa5Y4St7J1QKdQtuwY -vcRqOGGSKl73biK79+BIV/6swWCkB+VzoGzKP8dY/XZHsI0FDdnia8UCgYEAx5gz 
-9qfzfiqOQP/GN6vGIzoldGxCDCHyyEZmvy0FiBlMJK36Qkjtz48eqYEXOUCX8Z5X -gB583iMv72Oz/wmefoIjnUd9uXyMqvhnYxG4vQhU4a83K4q4TPkNd7+sLiNqxIq2 -o2jT6BktOHE5OiICeFGMFOfHtsyV78JMsuzUEwcCgYAXU5LXdsQokPJzCwE5oYdC -gEoj7lsJZm9UeZlrupmsK4eUIZ755ZQSulYzPubtyRL0NDehiWT9JFODCu5Vz2KD -kL8rwJpj+9V/7Fdrux78veUFilZedE3RHbaidlJ0kUMlWQroNi5t5XL2TWjBUM7M -azAlqqcAnVr3WfqcyuN+AQKBgQCsz+xV6I7bMy9dudc+hlyUTZj2V3FMHeyeWM5H -QkzizLxvma7vy0MUDd/HdTzNVk74ZVdvV3ZXwvGS/Klw7TwsXrNFTwvdGKiWs2KY -lVR1XwxXJyTGb2IpSw3NG8iRXhroNw3xKCcpcvsDPo0E90NaN4jo5NG3RSWgpINR -+9mW6wKBgCze3gZB6AU0W/Fluql88ANx/+nqZhrsJyfrv87AkgDtt0o1tOj+emKR -Uuwb2FVdh76ZK0AVd3Jh3KJs4+hr2u9syHaa7UPKXTcZsFWlGwZuu6X5A+0SO0S2 -/ur8gv24YZJvV7OvPhw1SAuYL7MKMsfTW4TEKWTfkZWvm4YfZNmR ------END RSA PRIVATE KEY----- -`) - - if (sqlProxyListenKey == "") != (sqlProxyListenCert == "") { - return errors.New("must specify either both or neither of cert and key") - } - - if sqlProxyListenCert != "" { - var err error - certBytes, err = ioutil.ReadFile(sqlProxyListenCert) - if err != nil { - return err - } - } - if sqlProxyListenKey != "" { - var err error - keyBytes, err = ioutil.ReadFile(sqlProxyListenKey) - if err != nil { - return err - } - } - - cer, err := tls.X509KeyPair(certBytes, keyBytes) - if err != nil { - return err - } - - ln, err := net.Listen("tcp", sqlProxyListenAddr) - if err != nil { - return err - } - defer func() { _ = ln.Close() }() - - // Multiplex the listen address to give easy access to metrics from this - // command. 
- mux := cmux.New(ln) - httpLn := mux.Match(cmux.HTTP1Fast()) - proxyLn := mux.Match(cmux.Any()) - - outgoingConf := &tls.Config{ - InsecureSkipVerify: true, - } - server := sqlproxyccl.NewServer(sqlproxyccl.Options{ - FrontendAdmitter: func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return sqlproxyccl.FrontendAdmit( - incoming, - &tls.Config{ - Certificates: []tls.Certificate{cer}, - }, - ) - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - params := msg.Parameters - const magic = "prancing-pony" - if strings.HasPrefix(params["database"], magic+".") { - params["database"] = params["database"][len(magic)+1:] - } else if params["options"] != "--cluster="+magic { - return nil, errors.Errorf("client failed to pass '%s' via database or options", magic) - } - conn, err := sqlproxyccl.BackendDial(msg, sqlProxyTargetAddr, outgoingConf) - if err != nil { - return nil, err - } - return conn, nil - }, - }) - - group, ctx := errgroup.WithContext(context.Background()) - - group.Go(func() error { - return server.ServeHTTP(ctx, httpLn) - }) - - group.Go(func() error { - return server.Serve(proxyLn) - }) - - group.Go(func() error { - return mux.Serve() - }) - - return group.Wait() -} diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 4c229fbc6687..03af6c4ea7f9 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -11,20 +11,32 @@ go_library( "idle_disconnect_connection.go", "metrics.go", "proxy.go", + "proxy_handler.go", "server.go", ":gen-errorcode-stringer", # keep ], importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl", visibility = ["//visibility:public"], deps = [ + "//pkg/ccl/sqlproxyccl/cache", + "//pkg/ccl/sqlproxyccl/denylist", + "//pkg/ccl/sqlproxyccl/tenant", + "//pkg/ccl/sqlproxyccl/throttler", + "//pkg/roachpb", + "//pkg/security/certmgr", + "//pkg/util", "//pkg/util/contextutil", "//pkg/util/httputil", "//pkg/util/log", 
"//pkg/util/metric", + "//pkg/util/retry", + "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_logtags//:logtags", "@com_github_jackc_pgproto3_v2//:pgproto3", + "@org_golang_google_grpc//:go_default_library", ], ) @@ -36,7 +48,7 @@ go_test( "frontend_admitter_test.go", "idle_disconnect_connection_test.go", "main_test.go", - "proxy_test.go", + "proxy_handler_test.go", "server_test.go", ], data = [ @@ -48,16 +60,24 @@ go_test( tags = ["broken_in_bazel"], deps = [ "//pkg/base", + "//pkg/ccl/kvccl/kvtenantccl", + "//pkg/ccl/sqlproxyccl/tenant", "//pkg/ccl/utilccl", + "//pkg/roachpb", "//pkg/security", "//pkg/security/securitytest", "//pkg/server", + "//pkg/sql", + "//pkg/sql/pgwire", "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", "//pkg/util/leaktest", + "//pkg/util/log", "//pkg/util/randutil", + "//pkg/util/stop", + "//pkg/util/syncutil", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", "@com_github_jackc_pgconn//:pgconn", diff --git a/pkg/ccl/sqlproxyccl/authentication.go b/pkg/ccl/sqlproxyccl/authentication.go index 4f09a93e5c69..bee246503b47 100644 --- a/pkg/ccl/sqlproxyccl/authentication.go +++ b/pkg/ccl/sqlproxyccl/authentication.go @@ -14,7 +14,10 @@ import ( "github.com/jackc/pgproto3/v2" ) -func authenticate(clientConn, crdbConn net.Conn) error { +// authenticate handles the startup of the pgwire protocol to the point where +// the connections is considered authenticated. If that doesn't happen, it +// returns an error. +var authenticate = func(clientConn, crdbConn net.Conn) error { fe := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn) be := pgproto3.NewFrontend(pgproto3.NewChunkReader(crdbConn), crdbConn) @@ -26,13 +29,13 @@ func authenticate(clientConn, crdbConn net.Conn) error { // TODO(spaskob): in verbose mode, log these messages. 
backendMsg, err := be.Receive() if err != nil { - return NewErrorf(CodeBackendReadFailed, "unable to receive message from backend: %v", err) + return newErrorf(codeBackendReadFailed, "unable to receive message from backend: %v", err) } err = fe.Send(backendMsg) if err != nil { - return NewErrorf( - CodeClientWriteFailed, "unable to send message %v to client: %v", backendMsg, err, + return newErrorf( + codeClientWriteFailed, "unable to send message %v to client: %v", backendMsg, err, ) } @@ -40,12 +43,12 @@ func authenticate(clientConn, crdbConn net.Conn) error { switch tp := backendMsg.(type) { case *pgproto3.ReadyForQuery: // Server has authenticated the connection successfully and is ready to - // serve queries. + // Serve queries. return nil case *pgproto3.AuthenticationOk: // Server has authenticated the connection; keep reading messages until // `pgproto3.ReadyForQuery` is encountered which signifies that server - // is ready to serve queries. + // is ready to Serve queries. case *pgproto3.ParameterStatus: // Server sent status message; keep reading messages until // `pgproto3.ReadyForQuery` is encountered. @@ -55,7 +58,7 @@ func authenticate(clientConn, crdbConn net.Conn) error { case *pgproto3.ErrorResponse: // Server has rejected the authentication response from the client and // has closed the connection. - return NewErrorf(CodeAuthFailed, "authentication failed: %v", backendMsg) + return newErrorf(codeAuthFailed, "authentication failed: %v", backendMsg) case *pgproto3.AuthenticationCleartextPassword, *pgproto3.AuthenticationMD5Password, @@ -64,17 +67,17 @@ func authenticate(clientConn, crdbConn net.Conn) error { // Read the client response and forward it to server. 
fntMsg, err := fe.Receive() if err != nil { - return NewErrorf(CodeClientReadFailed, "unable to receive message from client: %v", err) + return newErrorf(codeClientReadFailed, "unable to receive message from client: %v", err) } err = be.Send(fntMsg) if err != nil { - return NewErrorf( - CodeBackendWriteFailed, "unable to send message %v to backend: %v", fntMsg, err, + return newErrorf( + codeBackendWriteFailed, "unable to send message %v to backend: %v", fntMsg, err, ) } default: - return NewErrorf(CodeBackendDisconnected, "received unexpected backend message type: %v", tp) + return newErrorf(codeBackendDisconnected, "received unexpected backend message type: %v", tp) } } - return NewErrorf(CodeBackendDisconnected, "authentication took more than %d iterations", i) + return newErrorf(codeBackendDisconnected, "authentication took more than %d iterations", i) } diff --git a/pkg/ccl/sqlproxyccl/authentication_test.go b/pkg/ccl/sqlproxyccl/authentication_test.go index 1719f5ec9297..12a1250b66a1 100644 --- a/pkg/ccl/sqlproxyccl/authentication_test.go +++ b/pkg/ccl/sqlproxyccl/authentication_test.go @@ -95,9 +95,9 @@ func TestAuthenticateError(t *testing.T) { err := authenticate(srv, cli) require.Error(t, err) - codeErr := (*CodeError)(nil) + codeErr := (*codeError)(nil) require.True(t, errors.As(err, &codeErr)) - require.Equal(t, CodeAuthFailed, codeErr.code) + require.Equal(t, codeAuthFailed, codeErr.code) } func TestAuthenticateUnexpectedMessage(t *testing.T) { @@ -117,7 +117,7 @@ func TestAuthenticateUnexpectedMessage(t *testing.T) { err := authenticate(srv, cli) require.Error(t, err) - codeErr := (*CodeError)(nil) + codeErr := (*codeError)(nil) require.True(t, errors.As(err, &codeErr)) - require.Equal(t, CodeBackendDisconnected, codeErr.code) + require.Equal(t, codeBackendDisconnected, codeErr.code) } diff --git a/pkg/ccl/sqlproxyccl/backend_dialer.go b/pkg/ccl/sqlproxyccl/backend_dialer.go index 0e3aff5e4cbb..70b37b760a47 100644 --- 
a/pkg/ccl/sqlproxyccl/backend_dialer.go +++ b/pkg/ccl/sqlproxyccl/backend_dialer.go @@ -17,33 +17,34 @@ import ( "github.com/jackc/pgproto3/v2" ) -// BackendDial is an example backend dialer that does a TCP/IP connection -// to a backend, SSL and forwards the start message. -func BackendDial( +// backendDial is an example backend dialer that does a TCP/IP connection +// to a backend, SSL and forwards the start message. It is defined as a variable +// so it can be redirected for testing. +var backendDial = func( msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, ) (net.Conn, error) { conn, err := net.Dial("tcp", outgoingAddress) if err != nil { - return nil, NewErrorf( - CodeBackendDown, "unable to reach backend SQL server: %v", err, + return nil, newErrorf( + codeBackendDown, "unable to reach backend SQL server: %v", err, ) } - conn, err = SSLOverlay(conn, tlsConfig) + conn, err = sslOverlay(conn, tlsConfig) if err != nil { return nil, err } - err = RelayStartupMsg(conn, msg) + err = relayStartupMsg(conn, msg) if err != nil { - return nil, NewErrorf( - CodeBackendDown, "relaying StartupMessage to target server %v: %v", + return nil, newErrorf( + codeBackendDown, "relaying StartupMessage to target server %v: %v", outgoingAddress, err) } return conn, nil } -// SSLOverlay attempts to upgrade the PG connection to use SSL +// sslOverlay attempts to upgrade the PG connection to use SSL // if a tls.Config is specified.. -func SSLOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { +func sslOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { if tlsConfig == nil { return conn, nil } @@ -51,20 +52,20 @@ func SSLOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { var err error // Send SSLRequest. 
if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil { - return nil, NewErrorf( - CodeBackendDown, "sending SSLRequest to target server: %v", err, + return nil, newErrorf( + codeBackendDown, "sending SSLRequest to target server: %v", err, ) } response := make([]byte, 1) if _, err = io.ReadFull(conn, response); err != nil { return nil, - NewErrorf(CodeBackendDown, "reading response to SSLRequest") + newErrorf(codeBackendDown, "reading response to SSLRequest") } if response[0] != pgAcceptSSLRequest { - return nil, NewErrorf( - CodeBackendRefusedTLS, "target server refused TLS connection", + return nil, newErrorf( + codeBackendRefusedTLS, "target server refused TLS connection", ) } @@ -72,8 +73,8 @@ func SSLOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { return tls.Client(conn, outCfg), nil } -// RelayStartupMsg forwards the start message on the backend connection. -func RelayStartupMsg(conn net.Conn, msg *pgproto3.StartupMessage) (err error) { +// relayStartupMsg forwards the start message on the backend connection. +func relayStartupMsg(conn net.Conn, msg *pgproto3.StartupMessage) (err error) { _, err = conn.Write(msg.Encode(nil)) return } diff --git a/pkg/ccl/sqlproxyccl/error.go b/pkg/ccl/sqlproxyccl/error.go index e6dcb3dd6e23..7df35e255324 100644 --- a/pkg/ccl/sqlproxyccl/error.go +++ b/pkg/ccl/sqlproxyccl/error.go @@ -14,88 +14,88 @@ import ( "github.com/cockroachdb/errors" ) -// ErrorCode classifies errors emitted by Proxy(). -//go:generate stringer -type=ErrorCode -type ErrorCode int +// errorCode classifies errors emitted by Proxy(). +//go:generate stringer -type=errorCode +type errorCode int const ( - _ ErrorCode = iota + _ errorCode = iota - // CodeAuthFailed indicates that client authentication attempt has failed and + // codeAuthFailed indicates that client authentication attempt has failed and // backend has closed the connection. 
- CodeAuthFailed + codeAuthFailed - // CodeBackendReadFailed indicates an error reading from backend connection. - CodeBackendReadFailed - // CodeBackendWriteFailed indicates an error writing to backend connection. - CodeBackendWriteFailed + // codeBackendReadFailed indicates an error reading from backend connection. + codeBackendReadFailed + // codeBackendWriteFailed indicates an error writing to backend connection. + codeBackendWriteFailed - // CodeClientReadFailed indicates an error reading from the client connection - CodeClientReadFailed - // CodeClientWriteFailed indicates an error writing to the client connection. - CodeClientWriteFailed + // codeClientReadFailed indicates an error reading from the client connection + codeClientReadFailed + // codeClientWriteFailed indicates an error writing to the client connection. + codeClientWriteFailed - // CodeUnexpectedInsecureStartupMessage indicates that the client sent a + // codeUnexpectedInsecureStartupMessage indicates that the client sent a // StartupMessage which was unexpected. Typically this means that an // SSLRequest was expected but the client attempted to go ahead without TLS, // or vice versa. - CodeUnexpectedInsecureStartupMessage + codeUnexpectedInsecureStartupMessage - // CodeSNIRoutingFailed indicates an error choosing a backend address based on + // codeSNIRoutingFailed indicates an error choosing a backend address based on // the client's SNI header. - CodeSNIRoutingFailed + codeSNIRoutingFailed - // CodeUnexpectedStartupMessage indicates an unexpected startup message + // codeUnexpectedStartupMessage indicates an unexpected startup message // received from the client after TLS negotiation. - CodeUnexpectedStartupMessage + codeUnexpectedStartupMessage - // CodeParamsRoutingFailed indicates an error choosing a backend address based + // codeParamsRoutingFailed indicates an error choosing a backend address based // on the client's session parameters. 
- CodeParamsRoutingFailed + codeParamsRoutingFailed - // CodeBackendDown indicates an error establishing or maintaining a connection + // codeBackendDown indicates an error establishing or maintaining a connection // to the backend SQL server. - CodeBackendDown + codeBackendDown - // CodeBackendRefusedTLS indicates that the backend SQL server refused a TLS- + // codeBackendRefusedTLS indicates that the backend SQL server refused a TLS- // enabled SQL connection. - CodeBackendRefusedTLS + codeBackendRefusedTLS - // CodeBackendDisconnected indicates that the backend disconnected (with a + // codeBackendDisconnected indicates that the backend disconnected (with a // connection error) while serving client traffic. - CodeBackendDisconnected + codeBackendDisconnected - // CodeClientDisconnected indicates that the client disconnected unexpectedly + // codeClientDisconnected indicates that the client disconnected unexpectedly // (with a connection error) while in a session with backend SQL server. - CodeClientDisconnected + codeClientDisconnected - // CodeProxyRefusedConnection indicates that the proxy refused the connection + // codeProxyRefusedConnection indicates that the proxy refused the connection // request due to high load or too many connection attempts. - CodeProxyRefusedConnection + codeProxyRefusedConnection - // CodeExpiredClientConnection indicates that proxy connection to the client + // codeExpiredClientConnection indicates that proxy connection to the client // has expired and should be closed. - CodeExpiredClientConnection + codeExpiredClientConnection - // CodeIdleDisconnect indicates that the connection was disconnected for + // codeIdleDisconnect indicates that the connection was disconnected for // being idle for longer than the specified timeout. 
- CodeIdleDisconnect + codeIdleDisconnect ) -// CodeError is combines an error with one of the above codes to ease +// codeError is combines an error with one of the above codes to ease // the processing of the errors. -type CodeError struct { - code ErrorCode +type codeError struct { + code errorCode err error } -func (e *CodeError) Error() string { +func (e *codeError) Error() string { return fmt.Sprintf("%s: %s", e.code, e.err) } -// NewErrorf returns a new CodeError out of the supplied args. -func NewErrorf(code ErrorCode, format string, args ...interface{}) error { - return &CodeError{ +// newErrorf returns a new codeError out of the supplied args. +func newErrorf(code errorCode, format string, args ...interface{}) error { + return &codeError{ code: code, err: errors.Errorf(format, args...), } diff --git a/pkg/ccl/sqlproxyccl/errorcode_string.go b/pkg/ccl/sqlproxyccl/errorcode_string.go index de1432ced7e4..04ecb4a68bf4 100644 --- a/pkg/ccl/sqlproxyccl/errorcode_string.go +++ b/pkg/ccl/sqlproxyccl/errorcode_string.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT. +// Code generated by "stringer -type=errorCode"; DO NOT EDIT. package sqlproxyccl @@ -8,32 +8,32 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. 
var x [1]struct{} - _ = x[CodeAuthFailed-1] - _ = x[CodeBackendReadFailed-2] - _ = x[CodeBackendWriteFailed-3] - _ = x[CodeClientReadFailed-4] - _ = x[CodeClientWriteFailed-5] - _ = x[CodeUnexpectedInsecureStartupMessage-6] - _ = x[CodeSNIRoutingFailed-7] - _ = x[CodeUnexpectedStartupMessage-8] - _ = x[CodeParamsRoutingFailed-9] - _ = x[CodeBackendDown-10] - _ = x[CodeBackendRefusedTLS-11] - _ = x[CodeBackendDisconnected-12] - _ = x[CodeClientDisconnected-13] - _ = x[CodeProxyRefusedConnection-14] - _ = x[CodeExpiredClientConnection-15] - _ = x[CodeIdleDisconnect-16] + _ = x[codeAuthFailed-1] + _ = x[codeBackendReadFailed-2] + _ = x[codeBackendWriteFailed-3] + _ = x[codeClientReadFailed-4] + _ = x[codeClientWriteFailed-5] + _ = x[codeUnexpectedInsecureStartupMessage-6] + _ = x[codeSNIRoutingFailed-7] + _ = x[codeUnexpectedStartupMessage-8] + _ = x[codeParamsRoutingFailed-9] + _ = x[codeBackendDown-10] + _ = x[codeBackendRefusedTLS-11] + _ = x[codeBackendDisconnected-12] + _ = x[codeClientDisconnected-13] + _ = x[codeProxyRefusedConnection-14] + _ = x[codeExpiredClientConnection-15] + _ = x[codeIdleDisconnect-16] } -const _ErrorCode_name = "CodeAuthFailedCodeBackendReadFailedCodeBackendWriteFailedCodeClientReadFailedCodeClientWriteFailedCodeUnexpectedInsecureStartupMessageCodeSNIRoutingFailedCodeUnexpectedStartupMessageCodeParamsRoutingFailedCodeBackendDownCodeBackendRefusedTLSCodeBackendDisconnectedCodeClientDisconnectedCodeProxyRefusedConnectionCodeExpiredClientConnectionCodeIdleDisconnect" +const _errorCode_name = "codeAuthFailedcodeBackendReadFailedcodeBackendWriteFailedcodeClientReadFailedcodeClientWriteFailedcodeUnexpectedInsecureStartupMessagecodeSNIRoutingFailedcodeUnexpectedStartupMessagecodeParamsRoutingFailedcodeBackendDowncodeBackendRefusedTLScodeBackendDisconnectedcodeClientDisconnectedcodeProxyRefusedConnectioncodeExpiredClientConnectioncodeIdleDisconnect" -var _ErrorCode_index = [...]uint16{0, 14, 35, 57, 77, 98, 134, 154, 182, 205, 220, 241, 264, 
286, 312, 339, 357} +var _errorCode_index = [...]uint16{0, 14, 35, 57, 77, 98, 134, 154, 182, 205, 220, 241, 264, 286, 312, 339, 357} -func (i ErrorCode) String() string { +func (i errorCode) String() string { i -= 1 - if i < 0 || i >= ErrorCode(len(_ErrorCode_index)-1) { - return "ErrorCode(" + strconv.FormatInt(int64(i+1), 10) + ")" + if i < 0 || i >= errorCode(len(_errorCode_index)-1) { + return "errorCode(" + strconv.FormatInt(int64(i+1), 10) + ")" } - return _ErrorCode_name[_ErrorCode_index[i]:_ErrorCode_index[i+1]] + return _errorCode_name[_errorCode_index[i]:_errorCode_index[i+1]] } diff --git a/pkg/ccl/sqlproxyccl/frontend_admitter.go b/pkg/ccl/sqlproxyccl/frontend_admitter.go index 2f523f8231fa..7811b1d5f323 100644 --- a/pkg/ccl/sqlproxyccl/frontend_admitter.go +++ b/pkg/ccl/sqlproxyccl/frontend_admitter.go @@ -15,10 +15,13 @@ import ( "github.com/jackc/pgproto3/v2" ) -// FrontendAdmit is the default implementation of a frontend admitter. It can +// frontendAdmit is the default implementation of a frontend admitter. It can // upgrade to an optional SSL connection, and will handle and verify // the startup message received from the PG SQL client. -func FrontendAdmit( +// The connection returned should never be nil in case of error. Depending +// on whether the error happened before the connection was upgraded to TLS or not +// it will either be the original or the TLS connection. 
+var frontendAdmit = func( conn net.Conn, incomingTLSConfig *tls.Config, ) (net.Conn, *pgproto3.StartupMessage, error) { // `conn` could be replaced by `conn` embedded in a `tls.Conn` connection, @@ -30,7 +33,7 @@ func FrontendAdmit( if incomingTLSConfig != nil { m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage() if err != nil { - return nil, nil, NewErrorf(CodeClientReadFailed, "while receiving startup message") + return conn, nil, newErrorf(codeClientReadFailed, "while receiving startup message") } switch m.(type) { case *pgproto3.SSLRequest: @@ -38,15 +41,15 @@ func FrontendAdmit( // Ignore CancelRequest explicitly. We don't need to do this but it makes // testing easier by avoiding a call to sendErrToClient on this path // (which would confuse assertCtx). - return nil, nil, nil + return conn, nil, nil default: - code := CodeUnexpectedInsecureStartupMessage - return nil, nil, NewErrorf(code, "unsupported startup message: %T", m) + code := codeUnexpectedInsecureStartupMessage + return conn, nil, newErrorf(code, "unsupported startup message: %T", m) } _, err = conn.Write([]byte{pgAcceptSSLRequest}) if err != nil { - return nil, nil, NewErrorf(CodeClientWriteFailed, "acking SSLRequest: %v", err) + return conn, nil, newErrorf(codeClientWriteFailed, "acking SSLRequest: %v", err) } cfg := incomingTLSConfig.Clone() @@ -60,11 +63,11 @@ func FrontendAdmit( m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage() if err != nil { - return nil, nil, NewErrorf(CodeClientReadFailed, "receiving post-TLS startup message: %v", err) + return conn, nil, newErrorf(codeClientReadFailed, "receiving post-TLS startup message: %v", err) } msg, ok := m.(*pgproto3.StartupMessage) if !ok { - return nil, nil, NewErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m) + return conn, nil, newErrorf(codeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m) } // Add the 
sniServerName (if used) as parameter diff --git a/pkg/ccl/sqlproxyccl/frontend_admitter_test.go b/pkg/ccl/sqlproxyccl/frontend_admitter_test.go index e4f3b50e6c8b..5d6fc04f2418 100644 --- a/pkg/ccl/sqlproxyccl/frontend_admitter_test.go +++ b/pkg/ccl/sqlproxyccl/frontend_admitter_test.go @@ -58,7 +58,7 @@ func TestFrontendAdmitWithClientSSLDisableAndCustomParam(t *testing.T) { fmt.Printf("Done\n") }() - frontendCon, msg, err := FrontendAdmit(srv, nil) + frontendCon, msg, err := frontendAdmit(srv, nil) require.NoError(t, err) require.Equal(t, srv, frontendCon) require.NotNil(t, msg) @@ -88,7 +88,7 @@ func TestFrontendAdmitWithClientSSLRequire(t *testing.T) { tlsConfig, err := tlsConfig() require.NoError(t, err) - frontendCon, msg, err := FrontendAdmit(srv, tlsConfig) + frontendCon, msg, err := frontendAdmit(srv, tlsConfig) require.NoError(t, err) require.NotEqual(t, srv, frontendCon) // The connection was replaced by SSL require.NotNil(t, msg) @@ -107,12 +107,12 @@ func TestFrontendAdmitWithCancel(t *testing.T) { require.NoError(t, err) }() - frontendCon, msg, err := FrontendAdmit(srv, nil) + frontendCon, msg, err := frontendAdmit(srv, nil) require.EqualError(t, err, - "CodeUnexpectedStartupMessage: "+ + "codeUnexpectedStartupMessage: "+ "unsupported post-TLS startup message: *pgproto3.CancelRequest", ) - require.Nil(t, frontendCon) + require.NotNil(t, frontendCon) require.Nil(t, msg) } @@ -139,11 +139,11 @@ func TestFrontendAdmitWithSSLAndCancel(t *testing.T) { tlsConfig, err := tlsConfig() require.NoError(t, err) - frontendCon, msg, err := FrontendAdmit(srv, tlsConfig) + frontendCon, msg, err := frontendAdmit(srv, tlsConfig) require.EqualError(t, err, - "CodeUnexpectedStartupMessage: "+ + "codeUnexpectedStartupMessage: "+ "unsupported post-TLS startup message: *pgproto3.CancelRequest", ) - require.Nil(t, frontendCon) + require.NotNil(t, frontendCon) require.Nil(t, msg) } diff --git a/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go 
b/pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection.go similarity index 74% rename from pkg/ccl/sqlproxyccl/idle_disconnect_connection.go rename to pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection.go index bc5fd1be9e41..711bad21c585 100644 --- a/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go +++ b/pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package sqlproxyccl +package idle import ( "net" @@ -17,10 +17,10 @@ import ( "github.com/cockroachdb/errors" ) -// IdleDisconnectConnection is a wrapper around net.Conn that disconnects if +// DisconnectConnection is a wrapper around net.Conn that disconnects if // connection is idle. The idle time is only counted while the client is // waiting, blocked on Read. -type IdleDisconnectConnection struct { +type DisconnectConnection struct { net.Conn timeout time.Duration mu struct { @@ -30,10 +30,10 @@ type IdleDisconnectConnection struct { } var errNotSupported = errors.Errorf( - "Not supported for IdleDisconnectConnection", + "Not supported for DisconnectConnection", ) -func (c *IdleDisconnectConnection) updateDeadline() error { +func (c *DisconnectConnection) updateDeadline() error { now := timeutil.Now() // If it has been more than 1% of the timeout duration - advance the deadline. c.mu.Lock() @@ -49,7 +49,7 @@ func (c *IdleDisconnectConnection) updateDeadline() error { } // Read reads data from the connection with timeout. -func (c *IdleDisconnectConnection) Read(b []byte) (n int, err error) { +func (c *DisconnectConnection) Read(b []byte) (n int, err error) { if err := c.updateDeadline(); err != nil { return 0, err } @@ -57,7 +57,7 @@ func (c *IdleDisconnectConnection) Read(b []byte) (n int, err error) { } // Write writes data to the connection and sets the read timeout. 
-func (c *IdleDisconnectConnection) Write(b []byte) (n int, err error) { +func (c *DisconnectConnection) Write(b []byte) (n int, err error) { // The Write for the connection is not blocking (or can block only temporary // in case of flow control). For idle connections, the Read will be the call // that will block and stay blocked until the backend doesn't send something. @@ -72,26 +72,26 @@ func (c *IdleDisconnectConnection) Write(b []byte) (n int, err error) { } // SetDeadline is unsupported as it will interfere with the reads. -func (c *IdleDisconnectConnection) SetDeadline(t time.Time) error { +func (c *DisconnectConnection) SetDeadline(t time.Time) error { return errNotSupported } // SetReadDeadline is unsupported as it will interfere with the reads. -func (c *IdleDisconnectConnection) SetReadDeadline(t time.Time) error { +func (c *DisconnectConnection) SetReadDeadline(t time.Time) error { return errNotSupported } // SetWriteDeadline is unsupported as it will interfere with the reads. -func (c *IdleDisconnectConnection) SetWriteDeadline(t time.Time) error { +func (c *DisconnectConnection) SetWriteDeadline(t time.Time) error { return errNotSupported } -// IdleDisconnectOverlay upgrades the connection to one that closes when +// DisconnectOverlay upgrades the connection to one that closes when // idle for more than timeout duration. Timeout of zero will turn off // the idle disconnect code. 
-func IdleDisconnectOverlay(conn net.Conn, timeout time.Duration) net.Conn { +func DisconnectOverlay(conn net.Conn, timeout time.Duration) net.Conn { if timeout != 0 { - return &IdleDisconnectConnection{Conn: conn, timeout: timeout} + return &DisconnectConnection{Conn: conn, timeout: timeout} } return conn } diff --git a/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go b/pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection_test.go similarity index 97% rename from pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go rename to pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection_test.go index 002901450914..edb074dbd7cb 100644 --- a/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go +++ b/pkg/ccl/sqlproxyccl/idle/idle_disconnect_connection_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package sqlproxyccl +package idle import ( "fmt" @@ -31,7 +31,7 @@ func setupServerWithIdleDisconnect(t testing.TB, timeout time.Duration) net.Addr t.Errorf("Error during accept: %v", err) } defer cServ.Close() - cServ = IdleDisconnectOverlay(cServ, timeout) + cServ = DisconnectOverlay(cServ, timeout) _, _ = io.Copy(cServ, cServ) }() return server.Addr() diff --git a/pkg/ccl/sqlproxyccl/metrics.go b/pkg/ccl/sqlproxyccl/metrics.go index c643a7b4d81a..bea2ff15d3de 100644 --- a/pkg/ccl/sqlproxyccl/metrics.go +++ b/pkg/ccl/sqlproxyccl/metrics.go @@ -8,11 +8,14 @@ package sqlproxyccl -import "github.com/cockroachdb/cockroach/pkg/util/metric" +import ( + "github.com/cockroachdb/cockroach/pkg/util/metric" + "github.com/cockroachdb/errors" +) -// Metrics contains pointers to the metrics for monitoring proxy +// metrics contains pointers to the metrics for monitoring proxy // operations. 
-type Metrics struct { +type metrics struct { BackendDisconnectCount *metric.Counter IdleDisconnectCount *metric.Counter BackendDownCount *metric.Counter @@ -26,9 +29,9 @@ type Metrics struct { } // MetricStruct implements the metrics.Struct interface. -func (Metrics) MetricStruct() {} +func (metrics) MetricStruct() {} -var _ metric.Struct = Metrics{} +var _ metric.Struct = metrics{} var ( metaCurConnCount = metric.Metadata{ @@ -93,9 +96,9 @@ var ( } ) -// MakeProxyMetrics instantiates the metrics holder for proxy monitoring. -func MakeProxyMetrics() Metrics { - return Metrics{ +// makeProxyMetrics instantiates the metrics holder for proxy monitoring. +func makeProxyMetrics() metrics { + return metrics{ BackendDisconnectCount: metric.NewCounter(metaBackendDisconnectCount), IdleDisconnectCount: metric.NewCounter(metaIdleDisconnectCount), BackendDownCount: metric.NewCounter(metaBackendDownCount), @@ -108,3 +111,34 @@ func MakeProxyMetrics() Metrics { ExpiredClientConnCount: metric.NewCounter(metaExpiredClientConnCount), } } + +// updateForError updates the metrics relevant for the type of the +// error message. 
+func (metrics *metrics) updateForError(err error) { + if err == nil { + return + } + codeErr := (*codeError)(nil) + if errors.As(err, &codeErr) { + switch codeErr.code { + case codeExpiredClientConnection: + metrics.ExpiredClientConnCount.Inc(1) + case codeBackendDisconnected: + metrics.BackendDisconnectCount.Inc(1) + case codeClientDisconnected: + metrics.ClientDisconnectCount.Inc(1) + case codeIdleDisconnect: + metrics.IdleDisconnectCount.Inc(1) + case codeProxyRefusedConnection: + metrics.RefusedConnCount.Inc(1) + metrics.BackendDownCount.Inc(1) + case codeParamsRoutingFailed: + metrics.RoutingErrCount.Inc(1) + metrics.BackendDownCount.Inc(1) + case codeBackendDown: + metrics.BackendDownCount.Inc(1) + case codeAuthFailed: + metrics.AuthFailedCount.Inc(1) + } + } +} diff --git a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index 2ffc4dff17b2..cdd275b73e61 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -9,11 +9,8 @@ package sqlproxyccl import ( - "context" - "crypto/tls" "io" "net" - "os" "github.com/cockroachdb/errors" "github.com/jackc/pgproto3/v2" @@ -24,76 +21,42 @@ const pgAcceptSSLRequest = 'S' // See https://www.postgresql.org/docs/9.1/protocol-message-formats.html. var pgSSLRequest = []int32{8, 80877103} -// BackendConfig contains the configuration of a backend connection that is -// being proxied. -// To be removed once all clients are migrated to use backend dialer. -type BackendConfig struct { - // The address to which the connection is forwarded. - OutgoingAddress string - // TLS settings to use when connecting to OutgoingAddress. - TLSConf *tls.Config - // Called after successfully connecting to OutgoingAddr. - OnConnectionSuccess func() - // KeepAliveLoop if provided controls the lifetime of the proxy connection. - // It will be run in its own goroutine when the connection is successfully - // opened. Returning from `KeepAliveLoop` will close the proxy connection. 
- // Note that non-nil error return values will be forwarded to the user and - // hence should explain the reason for terminating the connection. - // Most common use of KeepAliveLoop will be as an infinite loop that - // periodically checks if the connection should still be kept alive. Hence it - // may block indefinitely so it's prudent to use the provided context and - // return on context cancellation. - // See `TestProxyKeepAlive` for example usage. - KeepAliveLoop func(context.Context) error +// updateMetricsAndSendErrToClient simply combines the update of the metrics and +// the transmission of the err back to the client. +func updateMetricsAndSendErrToClient(err error, conn net.Conn, metrics *metrics) { + metrics.updateForError(err) + sendErrToClient(conn, err) } -// Options are the options to the Proxy method. -type Options struct { - // Deprecated: construct FrontendAdmitter, passing this information in case - // that SSL is desired. - IncomingTLSConfig *tls.Config // config used for client -> proxy connection - - // BackendFromParams returns the config to use for the proxy -> backend - // connection. The TLS config is in it and it must have an appropriate - // ServerName for the remote backend. - // Deprecated: processing of the params now happens in the BackendDialer. - // This is only here to support OnSuccess and KeepAlive. - BackendConfigFromParams func( - params map[string]string, incomingConn *Conn, - ) (config *BackendConfig, clientErr error) - - // If set, consulted to modify the parameters set by the frontend before - // forwarding them to the backend during startup. - // Deprecated: include the code that modifies the request params - // in the backend dialer. - ModifyRequestParams func(map[string]string) - - // If set, consulted to decorate an error message to be sent to the client. - // The error passed to this method will contain no internal information.
- OnSendErrToClient func(code ErrorCode, msg string) string - - // If set, will be called immediately after a new incoming connection - // is accepted. It can optionally negotiate SSL, provide admittance control or - // other types of frontend connection filtering. - FrontendAdmitter func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) - - // If set, will be used to establish and return connection to the backend. - // If not set, the old logic will be used. - // The argument is the startup message received from the frontend. It - // contains the protocol version and params sent by the client. - BackendDialer func(msg *pgproto3.StartupMessage) (net.Conn, error) -} - -// Proxy takes an incoming client connection and relays it to a backend SQL -// server. -func (s *Server) Proxy(proxyConn *Conn) error { - sendErrToClient := func(conn net.Conn, code ErrorCode, msg string) { - if s.opts.OnSendErrToClient != nil { - msg = s.opts.OnSendErrToClient(code, msg) +// sendErrToClient will encode and pass back to the SQL client an error message. +// It can be called by the implementors of proxyHandler to give more +// information to the end user in case of a problem. +var sendErrToClient = func(conn net.Conn, err error) { + if err == nil || conn == nil { + return + } + codeErr := (*codeError)(nil) + if errors.As(err, &codeErr) { + var msg string + switch codeErr.code { + // These are sent as is. + case codeExpiredClientConnection, + codeBackendDown, + codeParamsRoutingFailed, + codeClientDisconnected, + codeBackendDisconnected, + codeAuthFailed, + codeProxyRefusedConnection: + msg = codeErr.Error() + // The rest - the message sent back is sanitized.
+ case codeIdleDisconnect: + msg = "terminating connection due to idle timeout" + case codeUnexpectedInsecureStartupMessage: + msg = "server requires encryption" } var pgCode string - if code == CodeIdleDisconnect { + if codeErr.code == codeIdleDisconnect { pgCode = "57P01" // admin shutdown } else { pgCode = "08004" // rejected connection @@ -104,135 +67,13 @@ func (s *Server) Proxy(proxyConn *Conn) error { Message: msg, }).Encode(nil)) } +} - frontendAdmitter := s.opts.FrontendAdmitter - if frontendAdmitter == nil { - // Keep this until all clients are switched to provide FrontendAdmitter - // at what point we can also drop IncomingTLSConfig - frontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return FrontendAdmit(incoming, s.opts.IncomingTLSConfig) - } - } - - conn, msg, err := frontendAdmitter(proxyConn) - if err != nil { - var codeErr *CodeError - if ok := errors.As(err, &codeErr); ok && codeErr.code == CodeUnexpectedInsecureStartupMessage { - sendErrToClient( - proxyConn, // Do this on the TCP connection as it means denying SSL - CodeUnexpectedInsecureStartupMessage, - "server requires encryption", - ) - } else if ok { - sendErrToClient(proxyConn, codeErr.code, codeErr.Error()) - } else { - sendErrToClient(proxyConn, CodeClientDisconnected, err.Error()) - } - return err - } - - // This currently only happens for CancelRequest type of startup messages - // that we don't support - if conn == nil { - return nil - - } - defer func() { _ = conn.Close() }() - - backendDialer := s.opts.BackendDialer - var backendConfig *BackendConfig - if s.opts.BackendConfigFromParams != nil { - var clientErr error - backendConfig, clientErr = s.opts.BackendConfigFromParams(msg.Parameters, proxyConn) - if clientErr != nil { - var codeErr *CodeError - if !errors.As(clientErr, &codeErr) { - codeErr = &CodeError{ - code: CodeParamsRoutingFailed, - err: errors.Errorf("rejected by BackendConfigFromParams: %v", clientErr), - } - } - 
sendErrToClient(conn, codeErr.code, codeErr.Error()) - return codeErr - } - } - if backendDialer == nil { - // This we need to keep until all the clients are switched to provide BackendDialer. - // It constructs a backend dialer from the information provided via - // BackendConfigFromParams function. - backendDialer = func(msg *pgproto3.StartupMessage) (net.Conn, error) { - // We should be able to remove this when the all clients switch to - // backend dialer. - if s.opts.ModifyRequestParams != nil { - s.opts.ModifyRequestParams(msg.Parameters) - } - - crdbConn, err := BackendDial(msg, backendConfig.OutgoingAddress, backendConfig.TLSConf) - if err != nil { - if codeErr := (*CodeError)(nil); errors.As(err, &codeErr) { - sendErrToClient(conn, codeErr.code, codeErr.Error()) - } else { - sendErrToClient(conn, CodeBackendDisconnected, err.Error()) - } - return nil, err - } - - return crdbConn, nil - } - } - - crdbConn, err := backendDialer(msg) - if err != nil { - s.metrics.BackendDownCount.Inc(1) - var codeErr *CodeError - if !errors.As(err, &codeErr) { - codeErr = &CodeError{ - code: CodeBackendDown, - err: errors.Errorf("unable to reach backend SQL server: %v", err), - } - } - if codeErr.code == CodeProxyRefusedConnection { - s.metrics.RefusedConnCount.Inc(1) - } else if codeErr.code == CodeParamsRoutingFailed { - s.metrics.RoutingErrCount.Inc(1) - } - sendErrToClient(conn, codeErr.code, codeErr.Error()) - return codeErr - } - defer func() { _ = crdbConn.Close() }() - - if err := authenticate(conn, crdbConn); err != nil { - s.metrics.AuthFailedCount.Inc(1) - var codeErr *CodeError - if !errors.As(err, &codeErr) { - codeErr = &CodeError{ - code: CodeParamsRoutingFailed, - err: errors.Errorf("unrecognized auth failure"), - } - } - sendErrToClient(conn, codeErr.code, codeErr.Error()) - return codeErr - } - - s.metrics.SuccessfulConnCount.Inc(1) - - // These channels are buffered because we'll only consume one of them. 
+// connectionCopy does a bi-directional copy between the backend and frontend +// connections. It terminates when one of the connections terminates. +func connectionCopy(crdbConn, conn net.Conn) error { errOutgoing := make(chan error, 1) errIncoming := make(chan error, 1) - errExpired := make(chan error, 1) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - if backendConfig != nil { - if backendConfig.OnConnectionSuccess != nil { - backendConfig.OnConnectionSuccess() - } - if backendConfig.KeepAliveLoop != nil { - go func() { - errExpired <- backendConfig.KeepAliveLoop(ctx) - }() - } - } go func() { _, err := io.Copy(crdbConn, conn) @@ -253,34 +94,18 @@ func (s *Server) Proxy(proxyConn *Conn) error { case err := <-errIncoming: if err == nil { return nil - } else if codeErr := (*CodeError)(nil); errors.As(err, &codeErr) && - codeErr.code == CodeExpiredClientConnection { - s.metrics.ExpiredClientConnCount.Inc(1) - sendErrToClient(conn, codeErr.code, codeErr.Error()) + } else if codeErr := (*codeError)(nil); errors.As(err, &codeErr) && + codeErr.code == codeExpiredClientConnection { return codeErr - } else if errors.Is(err, os.ErrDeadlineExceeded) { - s.metrics.IdleDisconnectCount.Inc(1) - sendErrToClient(conn, CodeIdleDisconnect, "terminating connection due to idle timeout") - return NewErrorf(CodeIdleDisconnect, "terminating connection due to idle timeout: %v", err) + } else if ne := (net.Error)(nil); errors.As(err, &ne) && ne.Timeout() { + return newErrorf(codeIdleDisconnect, "terminating connection due to idle timeout: %v", err) } else { - s.metrics.BackendDisconnectCount.Inc(1) - sendErrToClient(conn, CodeBackendDisconnected, "copying from target server to client") - return NewErrorf(CodeBackendDisconnected, "copying from target server to client: %s", err) + return newErrorf(codeBackendDisconnected, "copying from target server to client: %s", err) } case err := <-errOutgoing: // The incoming connection got closed.
if err != nil { - s.metrics.ClientDisconnectCount.Inc(1) - return NewErrorf(CodeClientDisconnected, "copying from target server to client: %v", err) - } - return nil - case err := <-errExpired: - if err != nil { - // The client connection expired. - s.metrics.ExpiredClientConnCount.Inc(1) - code := CodeExpiredClientConnection - sendErrToClient(conn, code, err.Error()) - return NewErrorf(code, "expired client conn: %v", err) + return newErrorf(codeClientDisconnected, "copying from target server to client: %v", err) } return nil } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go new file mode 100644 index 000000000000..b60e27f9015d --- /dev/null +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -0,0 +1,613 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "regexp" + "strconv" + "strings" + "time" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/cache" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/denylist" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/idle" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/security/certmgr" + "github.com/cockroachdb/cockroach/pkg/util" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/retry" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" + "github.com/jackc/pgproto3/v2" 
+ "google.golang.org/grpc" +) + +var ( + // This assumes that whitespaces are used to separate command line args. + // Unlike the original spec, this does not handle escaping rules. + // + // See "options" in https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS. + clusterNameLongOptionRE = regexp.MustCompile(`(?:-c\s*|--)cluster=([\S]*)`) + + // clusterNameRegex restricts cluster names to have between 6 and 20 + // alphanumeric characters, with dashes allowed within the name (but not as a + // starting or ending character). + clusterNameRegex = regexp.MustCompile("^[a-z0-9][a-z0-9-]{4,18}[a-z0-9]$") +) + +const ( + // Cluster identifier is in the form "clustername-<tenant ID>". The tenant ID is + // always at the end, but the cluster name can also contain '-' or digits. + // For example: + // "foo-7-10" -> cluster name is "foo-7" and tenant id is 10. + clusterTenantSep = "-" + // TODO(spaskob): add ballpark estimate. + maxKnownConnCacheSize = 5e6 // 5 million. +) + +// ProxyOptions is the information needed to construct a new proxyHandler. +type ProxyOptions struct { + // Denylist file to limit access to IP addresses and tenant ids. + Denylist string + // ListenAddr is the listen address for incoming connections. + ListenAddr string + // ListenCert is the file containing PEM-encoded x509 certificate for listen address. + // Set to "*" to auto-generate self-signed cert. + ListenCert string + // ListenKey is the file containing PEM-encoded x509 key for listen address. + // Set to "*" to auto-generate self-signed cert. + ListenKey string + // MetricsAddress is the listen address for incoming connections for metrics retrieval. + MetricsAddress string + // SkipVerify if set will skip the identity verification of the + // backend. This is for testing only. + SkipVerify bool + // Insecure if set, will not use TLS for the backend connection. For testing. + Insecure bool + // RoutingRule for constructing the backend address for each incoming + // connection.
Optionally use '{{clusterName}}' + // which will be substituted with the cluster name. + RoutingRule string + // DirectoryAddr specifies an optional {HOSTNAME}:{PORT} for a service that does the resolution + // from backend id to IP address. If specified - it will be used instead of the + // routing rule above. + DirectoryAddr string + // RatelimitBaseDelay is the initial backoff after a failed login attempt. + // Set to 0 to disable rate limiting. + RatelimitBaseDelay time.Duration + // ValidateAccessInterval is the time interval between validations, confirming + // that current connections are still valid. + ValidateAccessInterval time.Duration + // PollConfigInterval defines the polling interval for picking up changes in the config file. + PollConfigInterval time.Duration + // IdleTimeout if set, will close connections that have been idle for that duration. + IdleTimeout time.Duration +} + +// proxyHandler is the default implementation of a proxy handler. +type proxyHandler struct { + ProxyOptions + + // metrics contains various counters reflecting the proxy operations. + metrics *metrics + + // stopper is used to do an orderly shutdown. + stopper *stop.Stopper + + // incomingCert is the managed cert of the proxy endpoint to + // which clients connect. + incomingCert certmgr.Cert + + // denyListService provides access control. + denyListService denylist.Service + + // throttleService will do throttling of incoming connection requests. + throttleService throttler.Service + + // directory is optional and if set, will be used to resolve + // backend id to IP addresses. + directory *tenant.Directory + + // certManager keeps the certificates in use up to date. + certManager *certmgr.CertManager + + // connCache is used to keep track of all current connections. + connCache cache.ConnCache +} + +// newProxyHandler will create a new proxy handler with configuration based on +// the provided options.
+func newProxyHandler( + ctx context.Context, stopper *stop.Stopper, proxyMetrics *metrics, options ProxyOptions, +) (*proxyHandler, error) { + handler := proxyHandler{ + stopper: stopper, + metrics: proxyMetrics, + ProxyOptions: options, + certManager: certmgr.NewCertManager(ctx), + } + + var err error + err = handler.setupIncomingCert() + if err != nil { + return nil, err + } + + ctx, _ = stopper.WithCancelOnQuiesce(ctx) + handler.denyListService, err = denylist.NewViperDenyListFromFile(ctx, options.Denylist, options.PollConfigInterval) + if err != nil { + return nil, err + } + + handler.throttleService = throttler.NewLocalService(throttler.WithBaseDelay(options.RatelimitBaseDelay)) + handler.connCache = cache.NewCappedConnCache(maxKnownConnCacheSize) + + if handler.DirectoryAddr != "" { + conn, err := grpc.Dial(handler.DirectoryAddr, grpc.WithInsecure()) + if err != nil { + return nil, err + } + // nolint:grpcconnclose + stopper.AddCloser(stop.CloserFn(func() { _ = conn.Close() /* nolint:grpcconnclose */ })) + client := tenant.NewDirectoryClient(conn) + handler.directory, err = tenant.NewDirectory(ctx, stopper, client) + if err != nil { + return nil, err + } + } + + return &handler, nil +} + +// handle is called by the proxy server to handle a single incoming client +// connection. +func (handler *proxyHandler) handle(ctx context.Context, proxyConn *conn) error { + conn, msg, err := frontendAdmit(proxyConn, handler.IncomingTLSConfig()) + defer func() { _ = conn.Close() }() + if err != nil { + sendErrToClient(conn, err) + return err + } + + // This currently only happens for CancelRequest type of startup messages + // that we don't support + if msg == nil { + return nil + } + + // Note that the errors returned from this function are user-facing errors so + // we should be careful with the details that we want to expose. 
+ backendStartupMsg, clusterName, tenID, err := clusterNameAndTenantFromParams(msg) + if err != nil { + clientErr := &codeError{codeParamsRoutingFailed, err} + log.Errorf(ctx, "unable to extract cluster name and tenant id: %s", clientErr.Error()) + updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics) + return clientErr + } + // This forwards the remote addr to the backend. + backendStartupMsg.Parameters["crdb:remote_addr"] = conn.RemoteAddr().String() + + ctx = logtags.AddTag(ctx, "cluster", clusterName) + ctx = logtags.AddTag(ctx, "tenant", tenID) + + ipAddr, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + clientErr := &codeError{codeParamsRoutingFailed, err} + log.Errorf(ctx, "could not parse address: %v", clientErr.Error()) + updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics) + return clientErr + } + + if err = handler.validateAccessAndThrottle(ctx, tenID, ipAddr); err != nil { + updateMetricsAndSendErrToClient(err, conn, handler.metrics) + return err + } + + var TLSConf *tls.Config + if !handler.Insecure { + TLSConf = &tls.Config{InsecureSkipVerify: handler.SkipVerify} + } + + var crdbConn net.Conn + var outgoingAddress string + retryOpts := retry.Options{InitialBackoff: 10 * time.Millisecond, MaxBackoff: time.Second} + everyMinute := log.Every(time.Second) + var outgoingAddressErrs, codeBackendDownErrs, reportFailureErrs int + for r := retry.StartWithCtx(ctx, retryOpts); r.Next(); { + shouldLog := everyMinute.ShouldLog() + outgoingAddress, err = handler.outgoingAddress(ctx, clusterName, tenID) + if err != nil { + outgoingAddressErrs++ + if shouldLog { + log.Ops.Errorf(ctx, + "outgoing address: %v (%d errors in the past minute)", + err, + outgoingAddressErrs, + ) + outgoingAddressErrs = 0 + } + continue + } + + crdbConn, err = backendDial(backendStartupMsg, outgoingAddress, TLSConf) + // If we get a backend down error and are using the directory - report the + // error to the directory and retry the 
connection. + codeErr := (*codeError)(nil) + if err != nil && + errors.As(err, &codeErr) && + codeErr.code == codeBackendDown && + handler.directory != nil { + codeBackendDownErrs++ + if shouldLog { + log.Ops.Errorf(ctx, + "backend dial: %v (%d errors in the past minute)", + err, + codeBackendDownErrs, + ) + codeBackendDownErrs = 0 + } + err = handler.directory.ReportFailure(ctx, roachpb.MakeTenantID(tenID), outgoingAddress) + if err != nil { + reportFailureErrs++ + if shouldLog { + log.Ops.Errorf(ctx, + "report failure: %v (%d errors in the past minute)", + err, + reportFailureErrs, + ) + reportFailureErrs = 0 + } + } + continue + } + break + } + + if err != nil { + updateMetricsAndSendErrToClient(err, conn, handler.metrics) + return err + } + + crdbConn = idle.DisconnectOverlay(crdbConn, handler.IdleTimeout) + + defer func() { _ = crdbConn.Close() }() + + if err := authenticate(conn, crdbConn); err != nil { + handler.metrics.updateForError(err) + return errors.AssertionFailedf("unrecognized auth failure") + } + + handler.metrics.SuccessfulConnCount.Inc(1) + + handler.connCache.Insert( + &cache.ConnKey{IPAddress: ipAddr, TenantID: roachpb.MakeTenantID(tenID)}, + ) + + errConnectionCopy := make(chan error, 1) + errExpired := make(chan error, 1) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // TODO(darinpp): starting a new go routine for every connection here is inefficient. + // Change to maintain a map of connections by IP/tenant and have single + // go routine that closes connections that are blocklisted. 
+ go func() { + errExpired <- func(ctx context.Context) error { + t := timeutil.NewTimer() + defer t.Stop() + t.Reset(0) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + t.Read = true + if err := handler.validateAccess(ctx, tenID, ipAddr); err != nil { + return err + } + } + t.Reset(util.Jitter(handler.ValidateAccessInterval, 0.15)) + } + }(ctx) + }() + + log.Infof(ctx, "new connection") + connBegin := timeutil.Now() + defer func() { + log.Infof(ctx, "closing after %.2fs", timeutil.Since(connBegin).Seconds()) + }() + + go func() { + err := connectionCopy(crdbConn, conn) + errConnectionCopy <- err + }() + + select { + case err := <-errConnectionCopy: + updateMetricsAndSendErrToClient(err, conn, handler.metrics) + return err + case err := <-errExpired: + if err != nil { + // The client connection expired. + codeErr := newErrorf( + codeExpiredClientConnection, "expired client conn: %v", err, + ) + updateMetricsAndSendErrToClient(codeErr, conn, handler.metrics) + return codeErr + } + return nil + case <-handler.stopper.ShouldQuiesce(): + return nil + } + +} + +// outgoingAddress resolves a tenant ID and a tenant cluster name to an IP of +// the backend. +func (handler *proxyHandler) outgoingAddress( + ctx context.Context, name string, tenID uint64, +) (string, error) { + if handler.directory == nil { + addr := strings.ReplaceAll(handler.RoutingRule, "{{clusterName}}", name) + log.Infof(ctx, "backend %s resolved to '%s'", name, addr) + return addr, nil + } + + // This doesn't verify the name part of the tenant and relies just on the int id. 
+ addr, err := handler.directory.EnsureTenantIP(ctx, roachpb.MakeTenantID(tenID), "") + if err != nil { + return "", err + } + return addr, nil +} + +func (handler *proxyHandler) validateAccessAndThrottle( + ctx context.Context, tenID uint64, ipAddr string, +) error { + if err := handler.validateAccess(ctx, tenID, ipAddr); err != nil { + return err + } + + // Admit the connection + connKey := cache.ConnKey{IPAddress: ipAddr, TenantID: roachpb.MakeTenantID(tenID)} + if !handler.connCache.Exists(&connKey) { + // Unknown previous successful connections from this IP and tenant. + // Hence we need to rate limit. + if err := handler.throttleService.LoginCheck(ipAddr, timeutil.Now()); err != nil { + log.Errorf(ctx, "throttler refused connection: %v", err.Error()) + return newErrorf(codeProxyRefusedConnection, "Connection attempt throttled") + } + } + + return nil +} + +func (handler *proxyHandler) validateAccess( + ctx context.Context, tenID uint64, ipAddr string, +) error { + // First validate against the deny list service + list := handler.denyListService + if entry, err := list.Denied(denylist.DenyEntity{Item: fmt.Sprint(tenID), Type: denylist.ClusterType}); err != nil { + // Log error but don't return since this could be transient. + log.Errorf(ctx, "could not consult denied list for tenant: %d", tenID) + } else if entry != nil { + log.Errorf(ctx, "access denied for tenant: %d, reason: %s", tenID, entry.Reason) + return newErrorf(codeProxyRefusedConnection, "tenant %d %s", tenID, entry.Reason) + } + + if entry, err := list.Denied(denylist.DenyEntity{Item: ipAddr, Type: denylist.IPAddrType}); err != nil { + // Log error but don't return since this could be transient. 
+ log.Errorf(ctx, "could not consult denied list for IP address: %s", ipAddr) + } else if entry != nil { + log.Errorf(ctx, "access denied for IP address: %s, reason: %s", ipAddr, entry.Reason) + return newErrorf(codeProxyRefusedConnection, "IP address %s %s", ipAddr, entry.Reason) + } + + return nil +} + +// clusterNameAndTenantFromParams extracts the cluster name from the connection +// parameters, and rewrites the database param, if necessary. We currently +// support embedding the cluster name in two ways: +// - Within the database param (e.g. "happy-koala.defaultdb") +// +// - Within the options param (e.g. "... --cluster=happy-koala ..."). +// PostgreSQL supports three different ways to set a run-time parameter +// through its command-line options, i.e. "-c NAME=VALUE", "-cNAME=VALUE", and +// "--NAME=VALUE". +func clusterNameAndTenantFromParams( + msg *pgproto3.StartupMessage, +) (*pgproto3.StartupMessage, string, uint64, error) { + clusterNameFromDB, databaseName, err := parseDatabaseParam(msg.Parameters["database"]) + if err != nil { + return msg, "", 0, err + } + + clusterNameFromOpt, err := parseOptionsParam(msg.Parameters["options"]) + if err != nil { + return msg, "", 0, err + } + + if clusterNameFromDB == "" && clusterNameFromOpt == "" { + return msg, "", 0, errors.New("missing cluster name in connection string") + } + + if clusterNameFromDB != "" && clusterNameFromOpt != "" { + return msg, "", 0, errors.New("multiple cluster names provided") + } + + if clusterNameFromDB == "" { + clusterNameFromDB = clusterNameFromOpt + } + + sepIdx := strings.LastIndex(clusterNameFromDB, clusterTenantSep) + // Cluster name provided without a tenant ID in the end. 
+ if sepIdx == -1 || sepIdx == len(clusterNameFromDB)-1 { + return msg, "", 0, errors.Errorf("invalid cluster name %s", clusterNameFromDB) + } + clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:] + + if !clusterNameRegex.MatchString(clusterNameSansTenant) { + return msg, "", 0, errors.Errorf("invalid cluster name '%s'", clusterNameSansTenant) + } + + tenID, err := strconv.ParseUint(tenantIDStr, 10, 64) + if err != nil { + return msg, "", 0, errors.Wrapf(err, "cannot parse %s as uint64", tenantIDStr) + } + + // Make and return a copy of the startup msg so the original is not modified. + paramsOut := map[string]string{} + for key, value := range msg.Parameters { + if key == "database" { + paramsOut[key] = databaseName + } else if key != "options" { + paramsOut[key] = value + } + } + outMsg := &pgproto3.StartupMessage{ + ProtocolVersion: msg.ProtocolVersion, + Parameters: paramsOut, + } + + return outMsg, clusterNameFromDB, tenID, nil +} + +// parseDatabaseParam parses the database parameter from the PG connection +// string, and tries to extract the cluster name if present. The cluster +// name should be embedded in the database parameter using the dot (".") +// delimiter in the form of "<cluster name>.<database name>". This approach +// is safe because dots are not allowed in the database names themselves. +func parseDatabaseParam(databaseParam string) (clusterName, databaseName string, err error) { + // Database param is not provided. + if databaseParam == "" { + return "", "", nil + } + + parts := strings.SplitN(databaseParam, ".", 2) + + // Database param provided without cluster name. + if len(parts) <= 1 { + return "", databaseParam, nil + } + + clusterName, databaseName = parts[0], parts[1] + + // Ensure that the param is in the right format if the delimiter is provided. 
+ if len(parts) > 2 || clusterName == "" || databaseName == "" { + return "", "", errors.New("invalid database param") + } + + return clusterName, databaseName, nil +} + +// parseOptionsParam parses the options parameter from the PG connection string, +// and tries to return the cluster name if present. Just like PostgreSQL, the +// sqlproxy supports three different ways to set a run-time parameter through +// its command-line options: +// -c NAME=VALUE (commonly used throughout documentation around PGOPTIONS) +// -cNAME=VALUE +// --NAME=VALUE +// +// CockroachDB currently does not support the options parameter, so the parsing +// logic is built on that assumption. If we do start supporting options in +// CockroachDB itself, then we should revisit this. +// +// Note that this parsing approach is not perfect as it allows a negative case +// like options="-c --cluster=happy-koala -c -c -c" to go through. To properly +// parse this, we need to traverse the string from left to right, and look at +// every single argument, but that involves quite a bit of work, so we'll punt +// for now. +func parseOptionsParam(optionsParam string) (string, error) { + // Only search up to 2 in case of large inputs. + matches := clusterNameLongOptionRE.FindAllStringSubmatch(optionsParam, 2 /* n */) + if len(matches) == 0 { + return "", nil + } + + if len(matches) > 1 { + // Technically we could still allow requests to go through if all + // cluster names match, but we don't want to parse the entire string, so + // we will just error out if at least two cluster flags are provided. + return "", errors.New("multiple cluster flags provided") + } + + // Length of each match should always be 2 with the given regex, one for + // the full string, and the other for the cluster name. + if len(matches[0]) != 2 { + // We don't want to panic here. + return "", errors.New("internal server error") + } + + // Flag was provided, but value is NULL. 
+ if len(matches[0][1]) == 0 { + return "", errors.New("invalid cluster flag") + } + + return matches[0][1], nil +} + +// IncomingTLSConfig gets back the current TLS config for the incoming client +// connection endpoint. +func (handler *proxyHandler) IncomingTLSConfig() *tls.Config { + if handler.incomingCert == nil { + return nil + } + + cert := handler.incomingCert.TLSCert() + if cert == nil { + return nil + } + + return &tls.Config{Certificates: []tls.Certificate{*cert}} +} + +// setupIncomingCert will set up a managed cert for the incoming connections. +// They can either be unencrypted (in case a cert and key names are empty), +// using self-signed, runtime generated cert (if cert is set to *) or +// using file based cert where the cert/key values refer to file names +// containing the information. +func (handler *proxyHandler) setupIncomingCert() error { + if (handler.ListenKey == "") != (handler.ListenCert == "") { + return errors.New("must specify either both or neither of cert and key") + } + + if handler.ListenCert == "" { + return nil + } + + // TODO(darin): change the cert manager so it uses the stopper. + ctx, _ := handler.stopper.WithCancelOnQuiesce(context.Background()) + certMgr := certmgr.NewCertManager(ctx) + var cert certmgr.Cert + if handler.ListenCert == "*" { + cert = certmgr.NewSelfSignedCert(0, 3, 0, 0) + } else if handler.ListenCert != "" { + cert = certmgr.NewFileCert(handler.ListenCert, handler.ListenKey) + } + cert.Reload(ctx) + err := cert.Err() + if err != nil { + return err + } + certMgr.ManageCert("client", cert) + handler.certManager = certMgr + handler.incomingCert = cert + + return nil +} diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go new file mode 100644 index 000000000000..bdcbb90020dd --- /dev/null +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -0,0 +1,735 @@ +// Copyright 2020 The Cockroach Authors. 
+// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "crypto/tls" + "fmt" + "io/ioutil" + "net" + "os" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/require" +) + +// To ensure tenant startup code is included. +var _ = kvtenantccl.Connector{} + +const frontendError = "Frontend error!" +const backendError = "Backend error!" 
+ +func hookBackendDial( + f func( + msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, + ) (net.Conn, error), +) func() { + return testutils.HookGlobal(&backendDial, f) +} +func hookFrontendAdmit( + f func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error), +) func() { + return testutils.HookGlobal(&frontendAdmit, f) +} +func hookSendErrToClient(f func(conn net.Conn, err error)) func() { + return testutils.HookGlobal(&sendErrToClient, f) +} +func hookAuthenticate(f func(clientConn, crdbConn net.Conn) error) func() { + return testutils.HookGlobal(&authenticate, f) +} + +func newSecureProxyServer( + ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, +) (server *Server, addr string) { + // Created via: + const _ = ` +openssl genrsa -out testserver.key 2048 +openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt \ + -days 3650 -config testserver_config.cnf +` + opts.ListenKey = "testserver.key" + opts.ListenCert = "testserver.crt" + + return newProxyServer(ctx, t, stopper, opts) +} + +func newProxyServer( + ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, +) (server *Server, addr string) { + const listenAddress = "127.0.0.1:0" + ln, err := net.Listen("tcp", listenAddress) + require.NoError(t, err) + require.NoError(t, + stopper.RunAsyncTask(ctx, "proxy-ln-close", func(ctx context.Context) { + <-stopper.ShouldQuiesce() + require.NoError(t, ln.Close()) + })) + + server, err = NewServer(ctx, stopper, *opts) + require.NoError(t, err) + + err = server.Stopper.RunAsyncTask(ctx, "proxy-server-serve", func(ctx context.Context) { + _ = server.Serve(ctx, ln) + }) + require.NoError(t, err) + + return server, ln.Addr().String() +} + +func runTestQuery(ctx context.Context, conn *pgx.Conn) error { + var n int + if err := conn.QueryRow(ctx, "SELECT $1::int", 1).Scan(&n); err != nil { + return err + } + if n != 1 { + return errors.Errorf("expected 1 
got %d", n) + } + return nil +} + +type assertCtx struct { + emittedCode *errorCode +} + +func makeAssertCtx() assertCtx { + var emittedCode errorCode = -1 + return assertCtx{ + emittedCode: &emittedCode, + } +} + +func (ac *assertCtx) onSendErrToClient(code errorCode) { + *ac.emittedCode = code +} + +func (ac *assertCtx) assertConnectErr( + ctx context.Context, t *testing.T, prefix, suffix string, expCode errorCode, expErr string, +) { + t.Helper() + *ac.emittedCode = -1 + t.Run(suffix, func(t *testing.T) { + conn, err := pgx.Connect(ctx, prefix+suffix) + if err == nil { + _ = conn.Close(ctx) + } + require.Contains(t, err.Error(), expErr) + require.Equal(t, expCode, *ac.emittedCode) + }) +} + +func TestLongDBName(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + defer hookBackendDial(func(_ *pgproto3.StartupMessage, _ string, _ *tls.Config) (net.Conn, error) { + return nil, newErrorf(codeParamsRoutingFailed, "boom") + })() + + ac := makeAssertCtx() + originalSendErrToClient := sendErrToClient + defer hookSendErrToClient(func(conn net.Conn, err error) { + if codeErr, ok := err.(*codeError); ok { + ac.onSendErrToClient(codeErr.code) + } + originalSendErrToClient(conn, err) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + s, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + + longDB := strings.Repeat("x", 70) // 63 is limit + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=dim-dog-28", addr, longDB) + ac.assertConnectErr(ctx, t, pgurl, "" /* suffix */, codeParamsRoutingFailed, "boom") + require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) +} + +func TestFailedConnection(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + // TODO(asubiotto): consider using datadriven for these, especially if the + // proxy becomes more complex. 
+ + var originalSendErrToClient = sendErrToClient + ac := makeAssertCtx() + defer hookSendErrToClient(func(conn net.Conn, err error) { + if codeErr, ok := err.(*codeError); ok { + ac.onSendErrToClient(codeErr.code) + } + originalSendErrToClient(conn, err) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + s, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{RoutingRule: "undialable%$!@$"}) + + _, p, err := net.SplitHostPort(addr) + require.NoError(t, err) + u := fmt.Sprintf("postgres://unused:unused@localhost:%s/", p) + // Valid connections, but no backend server running. + for _, sslmode := range []string{"require", "prefer"} { + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode="+sslmode, + codeBackendDown, "unable to reach backend SQL server", + ) + } + + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode=verify-ca&sslrootcert=testserver.crt", + codeBackendDown, "unable to reach backend SQL server", + ) + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode=verify-full&sslrootcert=testserver.crt", + codeBackendDown, "unable to reach backend SQL server", + ) + require.Equal(t, int64(4), s.metrics.BackendDownCount.Count()) + + // Unencrypted connections bounce. + for _, sslmode := range []string{"disable", "allow"} { + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode="+sslmode, + codeUnexpectedInsecureStartupMessage, "server requires encryption", + ) + } + require.Equal(t, int64(0), s.metrics.RoutingErrCount.Count()) + // TenantID rejected as malformed. + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim&sslmode=require", + codeParamsRoutingFailed, "invalid cluster name", + ) + require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) + // No TenantID. 
+ ac.assertConnectErr( + ctx, t, u, "?sslmode=require", + codeParamsRoutingFailed, "missing cluster name", + ) + require.Equal(t, int64(2), s.metrics.RoutingErrCount.Count()) +} + +func TestUnexpectedError(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + // Set up a Server whose FrontendAdmitter function always errors with a + // non-codeError error. + defer hookFrontendAdmit( + func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + log.Infof(context.Background(), "frontendAdmit returning unexpected error") + return conn, nil, errors.New("unexpected error") + })() + + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + _, addr := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + + u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", addr) + + // Time how long it takes for pgx.Connect to return. If the proxy handles + // errors appropriately, pgx.Connect should return near immediately + // because the server should close the connection. If not, it may take up + // to the 5s connect_timeout for pgx.Connect to give up. 
+ start := timeutil.Now() + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + t.Log(err) + elapsed := timeutil.Since(start) + if elapsed >= 5*time.Second { + t.Errorf("pgx.Connect took %s to error out", elapsed) + } +} + +func TestProxyAgainstSecureCRDB(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, db, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) + + var connSuccess bool + originalAuthenticate := authenticate + defer hookAuthenticate(func(clientConn, crdbConn net.Conn) error { + err := originalAuthenticate(clientConn, crdbConn) + connSuccess = err == nil + return err + })() + + defer sql.Stopper().Stop(ctx) + + s, addr := newSecureProxyServer( + ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, + ) + + url := fmt.Sprintf("postgres://bob:wrong@%s/dim-dog-28.defaultdb?sslmode=require", addr) + _, err := pgx.Connect(ctx, url) + require.Regexp(t, "ERROR: password authentication failed for user bob", err) + + url = fmt.Sprintf("postgres://bob@%s/dim-dog-28.defaultdb?sslmode=require", addr) + _, err = pgx.Connect(ctx, url) + require.Regexp(t, "ERROR: password authentication failed for user bob", err) + + url = fmt.Sprintf("postgres://bob:builder@%s/dim-dog-28.defaultdb?sslmode=require", addr) + conn, err := pgx.Connect(ctx, url) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.True(t, connSuccess) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + require.Equal(t, int64(2), s.metrics.AuthFailedCount.Count()) + }() + + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestProxyTLSClose(t *testing.T) { + defer leaktest.AfterTest(t)() + // NB: The leaktest 
call is an important part of this test. We're + // verifying that no goroutines are leaked, despite calling Close on an + // underlying TCP connection (rather than the TLSConn that wraps it). + + ctx := context.Background() + + sql, db, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) + + var proxyIncomingConn atomic.Value // *conn + var connSuccess bool + frontendAdmit := frontendAdmit + defer hookFrontendAdmit(func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + proxyIncomingConn.Store(conn) + return frontendAdmit(conn, incomingTLSConfig) + })() + originalAuthenticate := authenticate + defer hookAuthenticate(func(clientConn, crdbConn net.Conn) error { + err := originalAuthenticate(clientConn, crdbConn) + connSuccess = err == nil + return err + })() + + defer sql.Stopper().Stop(ctx) + + s, addr := newSecureProxyServer( + ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, + ) + + url := fmt.Sprintf("postgres://bob:builder@%s/dim-dog-28.defaultdb?sslmode=require", addr) + c, err := pgx.Connect(ctx, url) + require.NoError(t, err) + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + defer func() { + incomingConn, ok := proxyIncomingConn.Load().(*conn) + require.True(t, ok) + require.NoError(t, incomingConn.Close()) + <-incomingConn.done() // should immediately proceed + + require.True(t, connSuccess) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) + }() + + require.NoError(t, runTestQuery(ctx, c)) +} + +func TestProxyModifyRequestParams(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: 
false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + outgoingTLSConfig, err := sql.RPCContext().GetClientTLSConfig() + require.NoError(t, err) + proxyOutgoingTLSConfig := outgoingTLSConfig.Clone() + proxyOutgoingTLSConfig.InsecureSkipVerify = true + + backendDial := backendDial + defer hookBackendDial(func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + params := msg.Parameters + authToken, ok := params["authToken"] + require.True(t, ok) + require.Equal(t, "abc123", authToken) + user, ok := params["user"] + require.True(t, ok) + require.Equal(t, "bogususer", user) + require.Contains(t, params, "user") + + // NB: This test will fail unless the user used between the proxy + // and the backend is changed to a user that actually exists. + delete(params, "authToken") + params["user"] = "root" + + return backendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) + })() + + defer sql.Stopper().Stop(ctx) + + s, proxyAddr := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) + + u := fmt.Sprintf("postgres://bogususer@%s/?sslmode=require&authToken=abc123&options=--cluster=dim-dog-28", proxyAddr) + conn, err := pgx.Connect(ctx, u) + require.NoError(t, err) + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + defer func() { + require.NoError(t, conn.Close(ctx)) + }() + + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestInsecureProxy(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, db, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + defer sql.Stopper().Stop(ctx) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) + + s, addr := newProxyServer( + ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, + ) + + u := 
fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable&options=--cluster=dim-dog-28", addr) + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + require.Regexp(t, "ERROR: password authentication failed for user bob", err) + + u = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + conn, err := pgx.Connect(ctx, u) + require.NoError(t, err) + + defer func() { + require.NoError(t, conn.Close(ctx)) + require.Equal(t, int64(1), s.metrics.AuthFailedCount.Count()) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + }() + + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestInsecureDoubleProxy(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true}) + defer sql.Stopper().Stop(ctx) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + // Test multiple proxies: proxyB -> proxyA -> tc + _, proxyA := newProxyServer(ctx, t, sql.Stopper(), + &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), Insecure: true}, + ) + _, proxyB := newProxyServer(ctx, t, sql.Stopper(), + &ProxyOptions{RoutingRule: proxyA, Insecure: true}, + ) + + u := fmt.Sprintf("postgres://root:admin@%s/dim-dog-28.dim-dog-29.defaultdb?sslmode=disable", proxyB) + conn, err := pgx.Connect(ctx, u) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close(ctx)) + }() + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestErroneousFrontend(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + defer hookFrontendAdmit(func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + return conn, nil, errors.New(frontendError) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + _, addr := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + + u := 
fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + // Generic message here as the Frontend's error is not codeError and + // by default we don't pass back error's text. The startup message doesn't get + // processed in this case. + require.Regexp(t, "connection reset by peer|failed to receive message", err) +} + +func TestErroneousBackend(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + defer hookBackendDial( + func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + return nil, errors.New(backendError) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + _, addr := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + + u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + // Generic message here as the Backend's error is not codeError and + // by default we don't pass back error's text. The startup message has already + // been processed. 
+ require.Regexp(t, "failed to receive message", err) +} + +func TestProxyRefuseConn(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + defer hookBackendDial(func(_ *pgproto3.StartupMessage, _ string, _ *tls.Config) (net.Conn, error) { + return nil, newErrorf(codeProxyRefusedConnection, "too many attempts") + })() + + ac := makeAssertCtx() + originalSendErrToClient := sendErrToClient + defer hookSendErrToClient(func(conn net.Conn, err error) { + if codeErr, ok := err.(*codeError); ok { + ac.onSendErrToClient(codeErr.code) + } + originalSendErrToClient(conn, err) + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + s, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + + ac.assertConnectErr( + ctx, t, fmt.Sprintf("postgres://root:admin@%s/", addr), + "?sslmode=require&options=--cluster=dim-dog-28", + codeProxyRefusedConnection, "too many attempts", + ) + require.Equal(t, int64(1), s.metrics.RefusedConnCount.Count()) + require.Equal(t, int64(0), s.metrics.SuccessfulConnCount.Count()) + require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) +} + +func TestDenylistUpdate(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + denyList, err := ioutil.TempFile("", "*_denylist.yml") + require.NoError(t, err) + + outgoingTLSConfig, err := sql.RPCContext().GetClientTLSConfig() + require.NoError(t, err) + proxyOutgoingTLSConfig := outgoingTLSConfig.Clone() + proxyOutgoingTLSConfig.InsecureSkipVerify = true + + backendDial := backendDial + defer hookBackendDial(func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + time.AfterFunc(100*time.Millisecond, func() { + _, err := denyList.WriteString("127.0.0.1: test-denied") + require.NoError(t, err) + }) + 
return backendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) + })() + + defer sql.Stopper().Stop(ctx) + + s, addr := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{Denylist: denyList.Name()}) + defer func() { _ = os.Remove(denyList.Name()) }() + + url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require&options=--cluster=dim-dog-28", addr) + conn, err := pgx.Connect(context.Background(), url) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.Equal(t, int64(1), s.metrics.ExpiredClientConnCount.Count()) + }() + + require.Eventuallyf( + t, + func() bool { + _, err = conn.Exec(context.Background(), "SELECT 1") + return err != nil && strings.Contains(err.Error(), "expired") + }, + time.Second, 5*time.Millisecond, + "unexpected error received: %v", err, + ) +} + +func TestProxyAgainstSecureCRDBWithIdleTimeout(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) + + outgoingTLSConfig, err := sql.RPCContext().GetClientTLSConfig() + require.NoError(t, err) + proxyOutgoingTLSConfig := outgoingTLSConfig.Clone() + proxyOutgoingTLSConfig.InsecureSkipVerify = true + + idleTimeout, _ := time.ParseDuration("0.5s") + var connSuccess bool + frontendAdmit := frontendAdmit + defer hookFrontendAdmit(func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + return frontendAdmit(conn, incomingTLSConfig) + })() + backendDial := backendDial + defer hookBackendDial(func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + return backendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) + })() + originalAuthenticate := authenticate + defer hookAuthenticate(func(clientConn, crdbConn net.Conn) error { + err := 
originalAuthenticate(clientConn, crdbConn) + connSuccess = err == nil + return err + })() + + defer sql.Stopper().Stop(ctx) + + s, addr := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{IdleTimeout: idleTimeout}) + + url := fmt.Sprintf("postgres://root:admin@%s/?sslmode=require&options=--cluster=dim-dog-28", addr) + conn, err := pgx.Connect(ctx, url) + require.NoError(t, err) + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.True(t, connSuccess) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + }() + + var n int + err = conn.QueryRow(ctx, "SELECT $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + time.Sleep(idleTimeout * 2) + err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) + require.EqualError(t, err, "FATAL: terminating connection due to idle timeout (SQLSTATE 57P01)") +} + +func newDirectoryServer( + ctx context.Context, t *testing.T, srv serverutils.TestServerInterface, addr *net.TCPAddr, +) (*stop.Stopper, *net.TCPAddr) { + tdsStopper := stop.NewStopper() + listener, err := net.ListenTCP("tcp", addr) + require.NoError(t, err) + tds := tenant.NewTestDirectoryServer(tdsStopper) + tds.TenantStarterFunc = func(ctx context.Context, tenantID uint64) (*tenant.Process, error) { + log.TestingClearServerIdentifiers() + tenantStopper := tenant.NewSubStopper(tdsStopper) + ten, err := srv.StartTenant(ctx, base.TestTenantArgs{ + Existing: true, + TenantID: roachpb.MakeTenantID(tenantID), + ForceInsecure: true, + Stopper: tenantStopper, + }) + require.NoError(t, err) + sqlAddr, err := net.ResolveTCPAddr("tcp", ten.SQLAddr()) + require.NoError(t, err) + ten.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + return &tenant.Process{SQL: sqlAddr, Stopper: tenantStopper}, nil + } + go func() { require.NoError(t, tds.Serve(listener)) }() + return tdsStopper, listener.Addr().(*net.TCPAddr) +} + 
+func TestDirectoryReconnect(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + // New test cluster + srv, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true}) + defer srv.Stopper().Stop(ctx) + + // Create tenant 28 + sqlConn := srv.InternalExecutor().(*sql.InternalExecutor) + _, err := sqlConn.Exec(ctx, "", nil, "SELECT crdb_internal.create_tenant(28)") + require.NoError(t, err) + + // New test directory server + stopper1, tdsAddr := newDirectoryServer(ctx, t, srv, &net.TCPAddr{}) + + // New proxy server using the directory + _, addr := newProxyServer( + ctx, t, srv.Stopper(), &ProxyOptions{DirectoryAddr: tdsAddr.String(), Insecure: true}, + ) + + // try to connect - should be successful. + url := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + _, err = pgx.Connect(ctx, url) + require.NoError(t, err) + + // Stop the directory server and the tenant + stopper1.Stop(ctx) + + var succeeded syncutil.AtomicBool + + stopper2, _ := newDirectoryServer(ctx, t, srv, tdsAddr) + defer stopper2.Stop(ctx) + + require.Eventually(t, func() bool { + // try to connect through the proxy again - should be successful. + url = fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + _, err = pgx.Connect(ctx, url) + require.NoError(t, err) + succeeded.Set(true) + return true + }, 1000*time.Second, 100*time.Millisecond) +} diff --git a/pkg/ccl/sqlproxyccl/proxy_test.go b/pkg/ccl/sqlproxyccl/proxy_test.go deleted file mode 100644 index 1b81ea2cf61b..000000000000 --- a/pkg/ccl/sqlproxyccl/proxy_test.go +++ /dev/null @@ -1,671 +0,0 @@ -// Copyright 2020 The Cockroach Authors. -// -// Licensed as a CockroachDB Enterprise file under the Cockroach Community -// License (the "License"); you may not use this file except in compliance with -// the License. 
You may obtain a copy of the License at -// -// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt - -package sqlproxyccl - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/testutils" - "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" - "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" - "github.com/cockroachdb/cockroach/pkg/util/leaktest" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" - "github.com/cockroachdb/errors" - "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4" - "github.com/stretchr/testify/require" -) - -const FrontendError = "Frontend error!" -const BackendError = "Backend error!" - -func setupTestProxyWithCerts( - t *testing.T, opts *Options, -) (server *Server, addr string, done func()) { - // Created via: - const _ = ` -openssl genrsa -out testserver.key 2048 -openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt \ - -days 3650 -config testserver_config.cnf -` - cer, err := tls.LoadX509KeyPair("testserver.crt", "testserver.key") - require.NoError(t, err) - opts.FrontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return FrontendAdmit( - incoming, - &tls.Config{ - Certificates: []tls.Certificate{cer}, - ServerName: "localhost", - }, - ) - } - - const listenAddress = "127.0.0.1:0" - // NB: ln closes before wg.Wait or we deadlock. 
- ln, err := net.Listen("tcp", listenAddress) - require.NoError(t, err) - - var wg sync.WaitGroup - wg.Add(1) - - done = func() { - _ = ln.Close() - wg.Wait() - } - - server = NewServer(*opts) - - go func() { - defer wg.Done() - _ = server.Serve(ln) - }() - - return server, ln.Addr().String(), done -} - -func testingTenantIDFromDatabaseForAddr( - addr string, validTenant string, -) func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return func(msg *pgproto3.StartupMessage) (net.Conn, error) { - const dbKey = "database" - p := msg.Parameters - db, ok := p[dbKey] - if !ok { - return nil, NewErrorf( - CodeParamsRoutingFailed, "need to specify database", - ) - } - sl := strings.SplitN(db, "_", 2) - if len(sl) != 2 { - return nil, NewErrorf( - CodeParamsRoutingFailed, "malformed database name", - ) - } - db, tenantID := sl[0], sl[1] - - if tenantID != validTenant { - return nil, NewErrorf(CodeParamsRoutingFailed, "invalid tenantID") - } - - p[dbKey] = db - return BackendDial(msg, addr, &tls.Config{ - // NB: this would be false in production. 
- InsecureSkipVerify: true, - }) - } -} - -func runTestQuery(conn *pgx.Conn) error { - var n int - if err := conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n); err != nil { - return err - } - if n != 1 { - return errors.Errorf("expected 1 got %d", n) - } - return nil -} - -type assertCtx struct { - emittedCode *ErrorCode -} - -func makeAssertCtx() assertCtx { - var emittedCode ErrorCode = -1 - return assertCtx{ - emittedCode: &emittedCode, - } -} - -func (ac *assertCtx) onSendErrToClient(code ErrorCode, msg string) string { - *ac.emittedCode = code - return msg -} - -func (ac *assertCtx) assertConnectErr( - t *testing.T, prefix, suffix string, expCode ErrorCode, expErr string, -) { - t.Helper() - *ac.emittedCode = -1 - t.Run(suffix, func(t *testing.T) { - ctx := context.Background() - conn, err := pgx.Connect(ctx, prefix+suffix) - if err == nil { - _ = conn.Close(ctx) - } - require.Contains(t, err.Error(), expErr) - require.Equal(t, expCode, *ac.emittedCode) - - }) -} - -func TestLongDBName(t *testing.T) { - defer leaktest.AfterTest(t)() - - ac := makeAssertCtx() - - opts := Options{ - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return nil, NewErrorf(CodeParamsRoutingFailed, "boom") - }, - OnSendErrToClient: ac.onSendErrToClient, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - longDB := strings.Repeat("x", 70) // 63 is limit - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s", addr, longDB) - ac.assertConnectErr(t, pgurl, "" /* suffix */, CodeParamsRoutingFailed, "boom") - require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) -} - -func TestFailedConnection(t *testing.T) { - defer leaktest.AfterTest(t)() - - // TODO(asubiotto): consider using datadriven for these, especially if the - // proxy becomes more complex. 
- - ac := makeAssertCtx() - opts := Options{ - BackendDialer: testingTenantIDFromDatabaseForAddr( - "undialable%$!@$", "29", - ), - OnSendErrToClient: ac.onSendErrToClient, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - _, p, err := net.SplitHostPort(addr) - require.NoError(t, err) - u := fmt.Sprintf("postgres://unused:unused@localhost:%s/", p) - // Valid connections, but no backend server running. - for _, sslmode := range []string{"require", "prefer"} { - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode="+sslmode, - CodeBackendDown, "unable to reach backend SQL server", - ) - } - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode=verify-ca&sslrootcert=testserver.crt", - CodeBackendDown, "unable to reach backend SQL server", - ) - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode=verify-full&sslrootcert=testserver.crt", - CodeBackendDown, "unable to reach backend SQL server", - ) - require.Equal(t, int64(4), s.metrics.BackendDownCount.Count()) - - // Unencrypted connections bounce. - for _, sslmode := range []string{"disable", "allow"} { - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode="+sslmode, - CodeUnexpectedInsecureStartupMessage, "server requires encryption", - ) - } - - // TenantID rejected by test hook. - ac.assertConnectErr( - t, u, "defaultdb_28?sslmode=require", - CodeParamsRoutingFailed, "invalid tenantID", - ) - - // No TenantID. - ac.assertConnectErr( - t, u, "defaultdb?sslmode=require", - CodeParamsRoutingFailed, "malformed database name", - ) - require.Equal(t, int64(2), s.metrics.RoutingErrCount.Count()) -} - -func TestUnexpectedError(t *testing.T) { - defer leaktest.AfterTest(t)() - - // Set up a Server whose FrontendAdmitter function always errors with a - // non-CodeError error. 
- ctx := context.Background() - s := NewServer(Options{ - FrontendAdmitter: func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return nil, nil, errors.New("unexpected error") - }, - }) - const listenAddress = "127.0.0.1:0" - ln, err := net.Listen("tcp", listenAddress) - require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - _ = s.Serve(ln) - }() - defer func() { - _ = ln.Close() - wg.Wait() - }() - - u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", ln.Addr().String()) - - // Time how long it takes for pgx.Connect to return. If the proxy handles - // errors appropriately, pgx.Connect should return near immediately - // because the server should close the connection. If not, it may take up - // to the 5s connect_timeout for pgx.Connect to give up. - start := timeutil.Now() - _, err = pgx.Connect(ctx, u) - require.Error(t, err) - t.Log(err) - elapsed := timeutil.Since(start) - if elapsed >= 5*time.Second { - t.Errorf("pgx.Connect took %s to error out", elapsed) - } -} - -func TestProxyAgainstSecureCRDB(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(tc.ServerConn(0)) - sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - - var connSuccess bool - opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{OnConnectionSuccess: func() { connSuccess = true }}, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), &tls.Config{InsecureSkipVerify: true}) - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=require", addr) - _, err := pgx.Connect(context.Background(), url) - 
require.True(t, testutils.IsError(err, "ERROR: password authentication failed for user bob")) - - url = fmt.Sprintf("postgres://bob@%s?sslmode=require", addr) - _, err = pgx.Connect(context.Background(), url) - require.True(t, testutils.IsError(err, "ERROR: password authentication failed for user bob")) - - url = fmt.Sprintf("postgres://bob:builder@%s?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - defer func() { - require.NoError(t, conn.Close(ctx)) - require.True(t, connSuccess) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - require.Equal(t, int64(2), s.metrics.AuthFailedCount.Count()) - }() - - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - require.NoError(t, runTestQuery(conn)) -} - -func TestProxyTLSClose(t *testing.T) { - defer leaktest.AfterTest(t)() - // NB: The leaktest call is an important part of this test. We're - // verifying that no goroutines are leaked, despite calling Close an - // underlying TCP connection (rather than the TLSConn that wraps it). 
- - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(tc.ServerConn(0)) - sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - - var proxyIncomingConn atomic.Value // *Conn - var connSuccess bool - opts := Options{ - BackendConfigFromParams: func(params map[string]string, conn *Conn) (*BackendConfig, error) { - proxyIncomingConn.Store(conn) - return &BackendConfig{ - OnConnectionSuccess: func() { connSuccess = true }, - }, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), &tls.Config{InsecureSkipVerify: true}) - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://bob:builder@%s/defaultdb_29?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - defer func() { - incomingConn := proxyIncomingConn.Load().(*Conn) - require.NoError(t, incomingConn.Close()) - <-incomingConn.Done() // should immediately proceed - - require.True(t, connSuccess) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) - }() - - require.NoError(t, runTestQuery(conn)) -} - -func TestProxyModifyRequestParams(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - opts := Options{ - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - params := msg.Parameters - require.EqualValues(t, map[string]string{ - "authToken": "abc123", - "user": "bogususer", - }, 
params) - - // NB: This test will fail unless the user used between the proxy - // and the backend is changed to a user that actually exists. - delete(params, "authToken") - params["user"] = "root" - - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) - }, - } - s, proxyAddr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - u := fmt.Sprintf("postgres://bogususer@%s/?sslmode=require&authToken=abc123", proxyAddr) - conn, err := pgx.Connect(ctx, u) - require.NoError(t, err) - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - defer func() { - require.NoError(t, conn.Close(ctx)) - }() - - require.NoError(t, runTestQuery(conn)) -} - -func newInsecureProxyServer( - t *testing.T, outgoingAddr string, outgoingTLSConfig *tls.Config, customOptions ...func(*Options), -) (server *Server, addr string, cleanup func()) { - op := Options{ - BackendDialer: func(message *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(message, outgoingAddr, outgoingTLSConfig) - }} - for _, opt := range customOptions { - opt(&op) - } - s := NewServer(op) - const listenAddress = "127.0.0.1:0" - ln, err := net.Listen("tcp", listenAddress) - require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - _ = s.Serve(ln) - }() - return s, ln.Addr().String(), func() { - _ = ln.Close() - wg.Wait() - } -} - -func TestInsecureProxy(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(tc.ServerConn(0)) - sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - - s, addr, cleanup := newInsecureProxyServer( - t, tc.Server(0).ServingSQLAddr(), &tls.Config{InsecureSkipVerify: true}, func(op *Options) {} /* custom options */) - defer cleanup() - - u := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable", addr) - _, err := pgx.Connect(context.Background(), 
u) - require.Error(t, err) - require.True(t, testutils.IsError(err, "ERROR: password authentication failed for user bob")) - - u = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable", addr) - conn, err := pgx.Connect(ctx, u) - require.NoError(t, err) - - defer func() { - require.NoError(t, conn.Close(ctx)) - require.Equal(t, int64(1), s.metrics.AuthFailedCount.Count()) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - }() - - require.NoError(t, runTestQuery(conn)) -} - -func TestInsecureDoubleProxy(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{ - ServerArgs: base.TestServerArgs{Insecure: true}, - }) - defer tc.Stopper().Stop(ctx) - - // Test multiple proxies: proxyB -> proxyA -> tc - _, proxyA, cleanupA := newInsecureProxyServer(t, tc.Server(0).ServingSQLAddr(), - nil /* tls config */, func(op *Options) {} /* custom server options */) - defer cleanupA() - _, proxyB, cleanupB := newInsecureProxyServer(t, proxyA, nil /* tls config */, func(op *Options) {} /* custom server options */) - defer cleanupB() - - u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable", proxyB) - conn, err := pgx.Connect(ctx, u) - require.NoError(t, err) - defer func() { - require.NoError(t, conn.Close(ctx)) - }() - require.NoError(t, runTestQuery(conn)) -} - -func TestErroneousFrontend(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - _, addr, cleanup := newInsecureProxyServer( - t, tc.Server(0).ServingSQLAddr(), nil, /* tls config */ - func(op *Options) { - op.FrontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return nil, nil, errors.New(FrontendError) - } - }) - defer cleanup() - - u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable", 
addr) - - _, err := pgx.Connect(ctx, u) - require.Error(t, err) - require.True(t, testutils.IsError(err, FrontendError)) -} - -func TestErroneousBackend(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - _, addr, cleanup := newInsecureProxyServer( - t, tc.Server(0).ServingSQLAddr(), nil, /* tls config */ - func(op *Options) { - op.BackendDialer = func(message *pgproto3.StartupMessage) (net.Conn, error) { - return nil, errors.New(BackendError) - } - }) - defer cleanup() - - u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable", addr) - - _, err := pgx.Connect(ctx, u) - require.Error(t, err) - require.True(t, testutils.IsError(err, BackendError)) -} - -func TestProxyRefuseConn(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - ac := makeAssertCtx() - opts := Options{ - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return nil, NewErrorf(CodeProxyRefusedConnection, "too many attempts") - }, - OnSendErrToClient: ac.onSendErrToClient, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - ac.assertConnectErr( - t, fmt.Sprintf("postgres://root:admin@%s/", addr), "defaultdb_29?sslmode=require", - CodeProxyRefusedConnection, "too many attempts", - ) - require.Equal(t, int64(1), s.metrics.RefusedConnCount.Count()) - require.Equal(t, int64(0), s.metrics.SuccessfulConnCount.Count()) - require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) -} - -func TestProxyKeepAlive(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := 
serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - // Don't let connections last more than 100ms. - KeepAliveLoop: func(ctx context.Context) error { - t := timeutil.NewTimer() - t.Reset(100 * time.Millisecond) - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-t.C: - t.Read = true - return errors.New("expired") - } - } - }, - }, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - defer func() { - require.NoError(t, conn.Close(ctx)) - require.Equal(t, int64(1), s.metrics.ExpiredClientConnCount.Count()) - }() - - require.Eventuallyf( - t, - func() bool { - _, err = conn.Exec(context.Background(), "SELECT 1") - return err != nil && strings.Contains(err.Error(), "expired") - }, - time.Second, 5*time.Millisecond, - "unexpected error received: %v", err, - ) -} - -func TestProxyAgainstSecureCRDBWithIdleTimeout(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - idleTimeout, _ := time.ParseDuration("0.5s") - var connSuccess bool - opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ 
*Conn) (*BackendConfig, error) { - return &BackendConfig{ - OnConnectionSuccess: func() { connSuccess = true }, - }, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - conn, err := BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) - if err != nil { - return nil, err - } - return IdleDisconnectOverlay(conn, idleTimeout), nil - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - defer func() { - require.NoError(t, conn.Close(ctx)) - require.True(t, connSuccess) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - }() - - var n int - err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) - require.NoError(t, err) - require.EqualValues(t, 1, n) - time.Sleep(idleTimeout * 2) - err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) - require.EqualError(t, err, "FATAL: terminating connection due to idle timeout (SQLSTATE 57P01)") -} diff --git a/pkg/ccl/sqlproxyccl/server.go b/pkg/ccl/sqlproxyccl/server.go index 2dedd3b3bfbf..fb5779d559b2 100644 --- a/pkg/ccl/sqlproxyccl/server.go +++ b/pkg/ccl/sqlproxyccl/server.go @@ -18,18 +18,24 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/httputil" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/metric" + "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" ) +// proxyConnHandler defines the signature of the function that handles each +// individual new incoming connection. 
+type proxyConnHandler func(ctx context.Context, proxyConn *conn) error + // Server is a TCP server that proxies SQL connections to a // configurable backend. It may also run an HTTP server to expose a // health check and prometheus metrics. type Server struct { - opts *Options + Stopper *stop.Stopper + connHandler proxyConnHandler mux *http.ServeMux - metrics *Metrics + metrics *metrics metricsRegistry *metric.Registry promMu syncutil.Mutex @@ -38,17 +44,22 @@ type Server struct { // NewServer constructs a new proxy server and provisions metrics and health // checks as well. -func NewServer(opts Options) *Server { +func NewServer(ctx context.Context, stopper *stop.Stopper, options ProxyOptions) (*Server, error) { + proxyMetrics := makeProxyMetrics() + handler, err := newProxyHandler(ctx, stopper, &proxyMetrics, options) + if err != nil { + return nil, err + } + mux := http.NewServeMux() registry := metric.NewRegistry() - proxyMetrics := MakeProxyMetrics() - - registry.AddMetricStruct(proxyMetrics) + registry.AddMetricStruct(&proxyMetrics) s := &Server{ - opts: &opts, + Stopper: stopper, + connHandler: handler.handle, mux: mux, metrics: &proxyMetrics, metricsRegistry: registry, @@ -60,7 +71,7 @@ func NewServer(opts Options) *Server { mux.HandleFunc("/_status/vars/", s.handleVars) mux.HandleFunc("/_status/healthz/", s.handleHealth) - return s + return s, nil } func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { @@ -134,32 +145,34 @@ func (s *Server) ServeHTTP(ctx context.Context, ln net.Listener) error { // Serve serves a listener according to the Options given in NewServer(). // Incoming client connections are taken through the Postgres handshake and // relayed to the configured backend server. 
-func (s *Server) Serve(ln net.Listener) error { +func (s *Server) Serve(ctx context.Context, ln net.Listener) error { for { origConn, err := ln.Accept() if err != nil { return err } - conn := &Conn{ + conn := &conn{ Conn: origConn, } - go func() { + err = s.Stopper.RunAsyncTask(ctx, "proxy-con-handler", func(ctx context.Context) { defer func() { _ = conn.Close() }() s.metrics.CurConnCount.Inc(1) defer s.metrics.CurConnCount.Dec(1) - tBegin := timeutil.Now() remoteAddr := conn.RemoteAddr() - log.Infof(context.Background(), "handling client %s", remoteAddr) - err := s.Proxy(conn) - log.Infof(context.Background(), "client %s disconnected after %.2fs: %v", - remoteAddr, timeutil.Since(tBegin).Seconds(), err) - }() + ctxWithTag := logtags.AddTag(ctx, "client", remoteAddr) + if err := s.connHandler(ctxWithTag, conn); err != nil { + log.Infof(ctxWithTag, "connection error: %v", err) + } + }) + if err != nil { + return err + } } } -// Conn is a SQL connection into the proxy. -type Conn struct { +// conn is a SQL connection into the proxy. +type conn struct { net.Conn mu struct { @@ -170,7 +183,7 @@ type Conn struct { } // Done returns a channel that's closed when the connection is closed. -func (c *Conn) Done() <-chan struct{} { +func (c *conn) done() <-chan struct{} { c.mu.Lock() defer c.mu.Unlock() if c.mu.closedCh == nil { @@ -184,8 +197,8 @@ func (c *Conn) Done() <-chan struct{} { // Close closes the connection. // Any blocked Read or Write operations will be unblocked and return errors. -// The connection's Done channel will be closed. -func (c *Conn) Close() error { +// The connection's Done channel will be closed. This overrides net.Conn.Close. 
+func (c *conn) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.mu.closed { diff --git a/pkg/ccl/sqlproxyccl/server_test.go b/pkg/ccl/sqlproxyccl/server_test.go index f6b411695d2f..c20491b07780 100644 --- a/pkg/ccl/sqlproxyccl/server_test.go +++ b/pkg/ccl/sqlproxyccl/server_test.go @@ -9,18 +9,25 @@ package sqlproxyccl import ( + "context" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/stretchr/testify/require" ) func TestHandleHealth(t *testing.T) { defer leaktest.AfterTest(t)() - proxyServer := NewServer(Options{}) + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + proxyServer, err := NewServer(ctx, stopper, ProxyOptions{}) + require.NoError(t, err) rw := httptest.NewRecorder() r := httptest.NewRequest("GET", "/_status/healthz/", nil) @@ -36,7 +43,12 @@ func TestHandleHealth(t *testing.T) { func TestHandleVars(t *testing.T) { defer leaktest.AfterTest(t)() - proxyServer := NewServer(Options{}) + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + proxyServer, err := NewServer(ctx, stopper, ProxyOptions{}) + require.NoError(t, err) rw := httptest.NewRecorder() r := httptest.NewRequest("GET", "/_status/vars/", nil) diff --git a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel index 509fb42f6adb..243234607fa9 100644 --- a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel @@ -1,5 +1,5 @@ load("@bazel_gomock//:gomock.bzl", "gomock") -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") load("@rules_proto//proto:defs.bzl", "proto_library") @@ -47,34 +47,6 @@ go_library( ], ) -go_test( - name = "tenant_test", - srcs = [ - "directory_test.go", - "main_test.go", - ], - embed = 
[":tenant"], - deps = [ - "//pkg/base", - "//pkg/ccl", - "//pkg/ccl/utilccl", - "//pkg/roachpb", - "//pkg/security", - "//pkg/security/securitytest", - "//pkg/server", - "//pkg/sql", - "//pkg/testutils/serverutils", - "//pkg/testutils/testcluster", - "//pkg/util/leaktest", - "//pkg/util/log", - "//pkg/util/randutil", - "//pkg/util/stop", - "@com_github_stretchr_testify//assert", - "@com_github_stretchr_testify//require", - "@org_golang_google_grpc//:go_default_library", - ], -) - gomock( name = "mocks_tenant", out = "mocks_generated.go", diff --git a/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go b/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go index d2e4dd7ec624..315f86b675ee 100644 --- a/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go +++ b/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go @@ -193,7 +193,7 @@ func (s *TestDirectoryServer) EnsureEndpoint( s.proc.Lock() defer s.proc.Unlock() - lst, err := s.listLocked(ctx, &ListEndpointsRequest{req.TenantID}) + lst, err := s.listLocked(ctx, &ListEndpointsRequest{TenantID: req.TenantID}) if err != nil { return nil, err } @@ -225,6 +225,7 @@ func NewTestDirectoryServer(stopper *stop.Stopper) *TestDirectoryServer { dir.TenantStarterFunc = dir.startTenantLocked dir.proc.processByAddrByTenantID = map[uint64]map[net.Addr]*Process{} dir.listen.eventListeners = list.New() + stopper.AddCloser(stop.CloserFn(func() { dir.grpcServer.GracefulStop() })) RegisterDirectoryServer(dir.grpcServer, dir) return dir } @@ -301,6 +302,7 @@ func (s *TestDirectoryServer) startTenantLocked( fmt.Sprintf("--sql-addr=%s", sql.Addr().String()), fmt.Sprintf("--http-addr=%s", http.Addr().String()), fmt.Sprintf("--tenant-id=%d", tenantID), + "--insecure", } if err = sql.Close(); err != nil { return nil, err @@ -331,7 +333,7 @@ func (s *TestDirectoryServer) startTenantLocked( _ = c.Process.Kill() s.deregisterInstance(tenantID, process.SQL) })) - err = process.Stopper.RunAsyncTask(ctx, "cmd-wait", func(ctx context.Context) { + err = 
s.stopper.RunAsyncTask(ctx, "cmd-wait", func(ctx context.Context) { if err := c.Wait(); err != nil { log.Infof(ctx, "finished %s with err %s", process.Cmd.Args, err) log.Infof(ctx, "output %s", b.Bytes()) diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel new file mode 100644 index 000000000000..475c9d728605 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel @@ -0,0 +1,30 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + +go_test( + name = "tenantdirsvr_test", + srcs = [ + "directory_test.go", + "main_test.go", + ], + deps = [ + "//pkg/base", + "//pkg/ccl", + "//pkg/ccl/kvccl/kvtenantccl", + "//pkg/ccl/sqlproxyccl/tenant", + "//pkg/ccl/utilccl", + "//pkg/roachpb", + "//pkg/security", + "//pkg/security/securitytest", + "//pkg/server", + "//pkg/sql", + "//pkg/testutils/serverutils", + "//pkg/testutils/testcluster", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/randutil", + "//pkg/util/stop", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@org_golang_google_grpc//:go_default_library", + ], +) diff --git a/pkg/ccl/sqlproxyccl/tenant/directory_test.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/directory_test.go similarity index 90% rename from pkg/ccl/sqlproxyccl/tenant/directory_test.go rename to pkg/ccl/sqlproxyccl/tenantdirsvr/directory_test.go index 5027281400af..b172e31530e2 100644 --- a/pkg/ccl/sqlproxyccl/tenant/directory_test.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/directory_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package tenant +package tenantdirsvr import ( "context" @@ -16,6 +16,8 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/sql" 
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -27,6 +29,9 @@ import ( "google.golang.org/grpc" ) +// To ensure tenant startup code is included. +var _ = kvtenantccl.Connector{} + func TestDirectoryErrors(t *testing.T) { defer leaktest.AfterTest(t)() defer log.ScopeWithoutShowLogs(t).Close(t) @@ -89,7 +94,7 @@ func TestEndpointWatcher(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // Resume tenant again by a direct call to the directory server - _, err = tds.EnsureEndpoint(ctx, &EnsureEndpointRequest{tenantID.ToUint64()}) + _, err = tds.EnsureEndpoint(ctx, &tenant.EnsureEndpointRequest{tenantID.ToUint64()}) require.NoError(t, err) // Wait for background watcher to populate the initial endpoint. @@ -188,7 +193,7 @@ func TestResume(t *testing.T) { }(i) } - var processes map[net.Addr]*Process + var processes map[net.Addr]*tenant.Process // Eventually the tenant process will be resumed. require.Eventually(t, func() bool { processes = tds.Get(tenantID) @@ -212,7 +217,7 @@ func TestDeleteTenant(t *testing.T) { // Create the directory. ctx := context.Background() // Disable throttling for this test - tc, dir, tds := newTestDirectory(t, RefreshDelay(-1)) + tc, dir, tds := newTestDirectory(t, tenant.RefreshDelay(-1)) defer tc.Stopper().Stop(ctx) tenantID := roachpb.MakeTenantID(50) @@ -266,7 +271,7 @@ func TestRefreshThrottling(t *testing.T) { // Create the directory, but with extreme rate limiting so that directory // will never refresh. ctx := context.Background() - tc, dir, _ := newTestDirectory(t, RefreshDelay(60*time.Minute)) + tc, dir, _ := newTestDirectory(t, tenant.RefreshDelay(60*time.Minute)) defer tc.Stopper().Stop(ctx) // Create test tenant. 
@@ -325,10 +330,10 @@ func destroyTenant(tc serverutils.TestClusterInterface, id roachpb.TenantID) err func startTenant( ctx context.Context, srv serverutils.TestServerInterface, id uint64, -) (*Process, error) { +) (*tenant.Process, error) { log.TestingClearServerIdentifiers() - tenantStopper := NewSubStopper(srv.Stopper()) - tenant, err := srv.StartTenant( + tenantStopper := tenant.NewSubStopper(srv.Stopper()) + t, err := srv.StartTenant( ctx, base.TestTenantArgs{ Existing: true, @@ -339,18 +344,22 @@ func startTenant( if err != nil { return nil, err } - sqlAddr, err := net.ResolveTCPAddr("tcp", tenant.SQLAddr()) + sqlAddr, err := net.ResolveTCPAddr("tcp", t.SQLAddr()) if err != nil { return nil, err } - return &Process{SQL: sqlAddr, Stopper: tenantStopper}, nil + return &tenant.Process{SQL: sqlAddr, Stopper: tenantStopper}, nil } // Setup directory that uses a client connected to a test directory server // that manages tenants connected to a backing KV server. func newTestDirectory( - t *testing.T, opts ...DirOption, -) (tc serverutils.TestClusterInterface, directory *Directory, tds *TestDirectoryServer) { + t *testing.T, opts ...tenant.DirOption, +) ( + tc serverutils.TestClusterInterface, + directory *tenant.Directory, + tds *tenant.TestDirectoryServer, +) { tc = serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{ // We need to start the cluster insecure in order to not // care about TLS settings for the RPC client connection. 
@@ -359,8 +368,8 @@ func newTestDirectory( }, }) clusterStopper := tc.Stopper() - tds = NewTestDirectoryServer(clusterStopper) - tds.TenantStarterFunc = func(ctx context.Context, tenantID uint64) (*Process, error) { + tds = tenant.NewTestDirectoryServer(clusterStopper) + tds.TenantStarterFunc = func(ctx context.Context, tenantID uint64) (*tenant.Process, error) { t.Logf("starting tenant %d", tenantID) process, err := startTenant(ctx, tc.Server(0), tenantID) if err != nil { @@ -372,7 +381,6 @@ func newTestDirectory( listenPort, err := net.Listen("tcp", ":0") require.NoError(t, err) - clusterStopper.AddCloser(stop.CloserFn(func() { require.NoError(t, listenPort.Close()) })) go func() { _ = tds.Serve(listenPort) }() // Setup directory @@ -381,8 +389,8 @@ func newTestDirectory( require.NoError(t, err) // nolint:grpcconnclose clusterStopper.AddCloser(stop.CloserFn(func() { require.NoError(t, conn.Close() /* nolint:grpcconnclose */) })) - client := NewDirectoryClient(conn) - directory, err = NewDirectory(context.Background(), clusterStopper, client, opts...) + client := tenant.NewDirectoryClient(conn) + directory, err = tenant.NewDirectory(context.Background(), clusterStopper, client, opts...) 
require.NoError(t, err) return diff --git a/pkg/ccl/sqlproxyccl/tenant/main_test.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/main_test.go similarity index 98% rename from pkg/ccl/sqlproxyccl/tenant/main_test.go rename to pkg/ccl/sqlproxyccl/tenantdirsvr/main_test.go index 3a25676ae8dd..90199cefb604 100644 --- a/pkg/ccl/sqlproxyccl/tenant/main_test.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/main_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package tenant +package tenantdirsvr import ( "os" diff --git a/pkg/ccl/sqlproxyccl/admitter/BUILD.bazel b/pkg/ccl/sqlproxyccl/throttler/BUILD.bazel similarity index 69% rename from pkg/ccl/sqlproxyccl/admitter/BUILD.bazel rename to pkg/ccl/sqlproxyccl/throttler/BUILD.bazel index 2840f7d49664..1647e1188fbe 100644 --- a/pkg/ccl/sqlproxyccl/admitter/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/throttler/BUILD.bazel @@ -41,3 +41,22 @@ gomock( library = ":service", package = "admitter", ) + +go_library( + name = "throttler", + srcs = ["local.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler", + visibility = ["//visibility:public"], + deps = [ + "//pkg/util/log", + "//pkg/util/syncutil", + "@com_github_cockroachdb_errors//:errors", + ], +) + +go_test( + name = "throttler_test", + srcs = ["local_test.go"], + embed = [":throttler"], + deps = ["@com_github_stretchr_testify//require"], +) diff --git a/pkg/ccl/sqlproxyccl/admitter/local.go b/pkg/ccl/sqlproxyccl/throttler/local.go similarity index 96% rename from pkg/ccl/sqlproxyccl/admitter/local.go rename to pkg/ccl/sqlproxyccl/throttler/local.go index 24fb50d381b4..0c57b801a3a0 100644 --- a/pkg/ccl/sqlproxyccl/admitter/local.go +++ b/pkg/ccl/sqlproxyccl/throttler/local.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package admitter +package throttler import ( "context" @@ -37,7 +37,7 @@ type limiter struct { index int } -// localService is an admitter Service that 
manages state purely in local +// localService is a throttler Service that manages state purely in local // memory. Internally, it maintains a map from IP address to rate limiting info // for that address. In order to put a cap on memory usage, the map is capped // at a maximum size, at which point a random IP address will be evicted. @@ -81,7 +81,7 @@ func WithBaseDelay(d time.Duration) LocalOption { } } -// NewLocalService returns an admitter Service that manages state purely in +// NewLocalService returns a throttler Service that manages state purely in // local memory. func NewLocalService(opts ...LocalOption) Service { s := &localService{ diff --git a/pkg/ccl/sqlproxyccl/admitter/local_test.go b/pkg/ccl/sqlproxyccl/throttler/local_test.go similarity index 99% rename from pkg/ccl/sqlproxyccl/admitter/local_test.go rename to pkg/ccl/sqlproxyccl/throttler/local_test.go index 3caeccd414b9..22bde2acace0 100644 --- a/pkg/ccl/sqlproxyccl/admitter/local_test.go +++ b/pkg/ccl/sqlproxyccl/throttler/local_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package admitter +package throttler import ( "fmt" diff --git a/pkg/ccl/sqlproxyccl/admitter/mocks_generated.go b/pkg/ccl/sqlproxyccl/throttler/mocks_generated.go similarity index 94% rename from pkg/ccl/sqlproxyccl/admitter/mocks_generated.go rename to pkg/ccl/sqlproxyccl/throttler/mocks_generated.go index 65bcc3cff0a8..5499db346db9 100644 --- a/pkg/ccl/sqlproxyccl/admitter/mocks_generated.go +++ b/pkg/ccl/sqlproxyccl/throttler/mocks_generated.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: service.go -// Package admitter is a generated GoMock package. -package admitter +// Package throttler is a generated GoMock package.
+package throttler import ( reflect "reflect" diff --git a/pkg/ccl/sqlproxyccl/admitter/service.go b/pkg/ccl/sqlproxyccl/throttler/service.go similarity index 70% rename from pkg/ccl/sqlproxyccl/admitter/service.go rename to pkg/ccl/sqlproxyccl/throttler/service.go index 766ecca28f27..bac2218ac7e9 100644 --- a/pkg/ccl/sqlproxyccl/admitter/service.go +++ b/pkg/ccl/sqlproxyccl/throttler/service.go @@ -6,14 +6,14 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -// Package admitter provides admission checks functionality. Rate limiting currently. -package admitter +// Package throttler provides admission checks functionality. Rate limiting currently. +package throttler import "time" -//go:generate mockgen -package=admitter -destination=mocks_generated.go -source=service.go . Service +//go:generate mockgen -package=throttler -destination=mocks_generated.go -source=service.go . Service -// Service provides the interface for performing admission checks before +// Service provides the interface for performing throttle checks before // allowing requests into the managed service system. 
type Service interface { // LoginCheck determines whether a login request should be allowed to diff --git a/pkg/cli/BUILD.bazel b/pkg/cli/BUILD.bazel index 20ff9ed85026..e770f7a64a4a 100644 --- a/pkg/cli/BUILD.bazel +++ b/pkg/cli/BUILD.bazel @@ -38,7 +38,9 @@ go_library( "log_flags.go", "mt.go", "mt_cert.go", + "mt_proxy.go", "mt_start_sql.go", + "mt_test_directory.go", "node.go", "nodelocal.go", "quit.go", @@ -91,6 +93,8 @@ go_library( deps = [ "//pkg/base", "//pkg/build", + "//pkg/ccl/sqlproxyccl", + "//pkg/ccl/sqlproxyccl/tenant", "//pkg/cli/cliflags", "//pkg/cli/exit", "//pkg/cli/syncbench", @@ -219,49 +223,9 @@ go_library( "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", "@org_golang_x_sync//errgroup", + "@org_golang_x_sys//unix", "@org_golang_x_time//rate", - ] + select({ - "@io_bazel_rules_go//go/platform:aix": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:android": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:darwin": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:dragonfly": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:freebsd": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:illumos": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:ios": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:js": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:linux": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:netbsd": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:openbsd": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:plan9": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:solaris": [ - "@org_golang_x_sys//unix", - ], - "//conditions:default": [], - }), + ], ) go_test( diff --git a/pkg/cli/cliflags/flags_mt.go b/pkg/cli/cliflags/flags_mt.go index 
b3bcc9845d54..c8dada8a8578 100644 --- a/pkg/cli/cliflags/flags_mt.go +++ b/pkg/cli/cliflags/flags_mt.go @@ -23,4 +23,74 @@ var ( EnvVar: "COCKROACH_KV_ADDRS", Description: `A comma-separated list of KV endpoints (load balancers allowed).`, } + + DenyList = FlagInfo{ + Name: "denylist-file", + Description: "Denylist file to limit access to IP addresses and tenant ids.", + } + + ProxyListenAddr = FlagInfo{ + Name: "listen-addr", + Description: "Listen address for incoming connections.", + } + + ListenCert = FlagInfo{ + Name: "listen-cert", + Description: "File containing PEM-encoded x509 certificate for listen address.", + } + + ListenKey = FlagInfo{ + Name: "listen-key", + Description: "File containing PEM-encoded x509 key for listen address.", + } + + ListenMetrics = FlagInfo{ + Name: "listen-metrics", + Description: "Listen address for incoming connections for metrics retrieval.", + } + + RoutingRule = FlagInfo{ + Name: "routing-rule", + Description: "Routing rule for incoming connections. Use '{{clusterName}}' for substitution.", + } + + DirectoryAddr = FlagInfo{ + Name: "directory", + Description: "Directory address of the service doing resolution from backend id to IP.", + } + + SkipVerify = FlagInfo{ + Name: "skip-verify", + Description: "If true, skip identity verification of backend. For testing only.", + } + + InsecureBackend = FlagInfo{ + Name: "insecure", + Description: "If true, use insecure connection to the backend.", + } + + RatelimitBaseDelay = FlagInfo{ + Name: "ratelimit-base-delay", + Description: "Initial backoff after a failed login attempt. 
Set to 0 to disable rate limiting.", + } + + ValidateAccessInterval = FlagInfo{ + Name: "validate-access-interval", + Description: "Time interval between validation that current connections are still valid.", + } + + PollConfigInterval = FlagInfo{ + Name: "poll-config-interval", + Description: "Polling interval changes in config file.", + } + + IdleTimeout = FlagInfo{ + Name: "idle-timeout", + Description: "Close connections idle for this duration.", + } + + TestDirectoryListenPort = FlagInfo{ + Name: "port", + Description: "Test directory server binds and listens on this port.", + } ) diff --git a/pkg/cli/context.go b/pkg/cli/context.go index 33983521d495..aa8f622dbee0 100644 --- a/pkg/cli/context.go +++ b/pkg/cli/context.go @@ -18,6 +18,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" "github.com/cockroachdb/cockroach/pkg/config/zonepb" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/settings" @@ -52,6 +53,8 @@ func initCLIDefaults() { setStmtDiagContextDefaults() setAuthContextDefaults() setImportContextDefaults() + setProxyContextDefaults() + setTestDirectorySvrContextDefaults() initPreFlagsDefaults() @@ -615,6 +618,33 @@ func setImportContextDefaults() { importCtx.rowLimit = 0 } +// proxyContext captures the command-line parameters of the `mt start-proxy` command. 
+var proxyContext sqlproxyccl.ProxyOptions + +func setProxyContextDefaults() { + proxyContext.Denylist = "" + proxyContext.ListenAddr = "127.0.0.1:46257" + proxyContext.ListenCert = "" + proxyContext.ListenKey = "" + proxyContext.MetricsAddress = "0.0.0.0:8080" + proxyContext.RoutingRule = "" + proxyContext.DirectoryAddr = "" + proxyContext.SkipVerify = false + proxyContext.Insecure = false + proxyContext.RatelimitBaseDelay = 50 * time.Millisecond + proxyContext.ValidateAccessInterval = 30 * time.Second + proxyContext.PollConfigInterval = 30 * time.Second + proxyContext.IdleTimeout = 0 +} + +var testDirectorySvrContext struct { + port int +} + +func setTestDirectorySvrContextDefaults() { + testDirectorySvrContext.port = 36257 +} + // GetServerCfgStores provides direct public access to the StoreSpecList inside // serverCfg. This is used by CCL code to populate some fields. // diff --git a/pkg/cli/flags.go b/pkg/cli/flags.go index c4d5c47b534d..d10e5bfb97c4 100644 --- a/pkg/cli/flags.go +++ b/pkg/cli/flags.go @@ -945,6 +945,28 @@ func init() { boolFlag(f, &serverCfg.ExternalIODirConfig.DisableImplicitCredentials, cliflags.ExternalIODisableImplicitCredentials) } + // Multi-tenancy proxy command flags. 
+ { + f := mtStartSQLProxyCmd.Flags() + stringFlag(f, &proxyContext.Denylist, cliflags.DenyList) + stringFlag(f, &proxyContext.ListenAddr, cliflags.ProxyListenAddr) + stringFlag(f, &proxyContext.ListenCert, cliflags.ListenCert) + stringFlag(f, &proxyContext.ListenKey, cliflags.ListenKey) + stringFlag(f, &proxyContext.MetricsAddress, cliflags.ListenMetrics) + stringFlag(f, &proxyContext.RoutingRule, cliflags.RoutingRule) + stringFlag(f, &proxyContext.DirectoryAddr, cliflags.DirectoryAddr) + boolFlag(f, &proxyContext.SkipVerify, cliflags.SkipVerify) + boolFlag(f, &proxyContext.Insecure, cliflags.InsecureBackend) + durationFlag(f, &proxyContext.RatelimitBaseDelay, cliflags.RatelimitBaseDelay) + durationFlag(f, &proxyContext.ValidateAccessInterval, cliflags.ValidateAccessInterval) + durationFlag(f, &proxyContext.PollConfigInterval, cliflags.PollConfigInterval) + durationFlag(f, &proxyContext.IdleTimeout, cliflags.IdleTimeout) + } + // Multi-tenancy test directory command flags. + { + f := mtTestDirectorySvr.Flags() + intFlag(f, &testDirectorySvrContext.port, cliflags.TestDirectoryListenPort) + } } type tenantIDWrapper struct { diff --git a/pkg/cli/mt.go b/pkg/cli/mt.go index 6144e7a96b0b..2d0cd67c8532 100644 --- a/pkg/cli/mt.go +++ b/pkg/cli/mt.go @@ -12,14 +12,11 @@ package cli import "github.com/spf13/cobra" -// AddMTCommand adds a subcommand to `./cockroach mt`. -func AddMTCommand(cmd *cobra.Command) { - mtCmd.AddCommand(cmd) -} - func init() { cockroachCmd.AddCommand(mtCmd) mtCmd.AddCommand(mtStartSQLCmd) + mtCmd.AddCommand(mtStartSQLProxyCmd) + mtCmd.AddCommand(mtTestDirectorySvr) mtCertsCmd.AddCommand( mtCreateTenantClientCACertCmd, diff --git a/pkg/cli/mt_proxy.go b/pkg/cli/mt_proxy.go new file mode 100644 index 000000000000..4be4b500c72c --- /dev/null +++ b/pkg/cli/mt_proxy.go @@ -0,0 +1,172 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package cli + +import ( + "context" + "fmt" + "net" + "os" + "os/signal" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/log/severity" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/errors" + "github.com/spf13/cobra" + "golang.org/x/sys/unix" +) + +var mtStartSQLProxyCmd = &cobra.Command{ + Use: "start-proxy", + Short: "start a sql proxy", + Long: `Starts a SQL proxy. + +This proxy accepts incoming connections and relays them to a backend server +determined by the arguments used. +`, + RunE: MaybeDecorateGRPCError(runStartSQLProxy), + Args: cobra.NoArgs, +} + +func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) { + // Initialize logging, stopper and context that can be canceled + ctx, stopper, err := initLogging(cmd) + if err != nil { + return err + } + defer stopper.Stop(ctx) + + log.Infof(ctx, "New proxy with opts: %+v", proxyContext) + + proxyLn, err := net.Listen("tcp", proxyContext.ListenAddr) + if err != nil { + return err + } + stopper.AddCloser(stop.CloserFn(func() { _ = proxyLn.Close() })) + + metricsLn, err := net.Listen("tcp", proxyContext.MetricsAddress) + if err != nil { + return err + } + stopper.AddCloser(stop.CloserFn(func() { _ = metricsLn.Close() })) + + server, err := sqlproxyccl.NewServer(ctx, stopper, proxyContext) + if err != nil { + return err + } + + errChan := make(chan error, 1) + + if err := stopper.RunAsyncTask(ctx, "serve-http", func(ctx context.Context) { + log.Infof(ctx, "HTTP metrics server listening at %s", metricsLn.Addr()) + if err := server.ServeHTTP(ctx, metricsLn); err != nil { + errChan <- err + } + }); err != nil { + return err + } + + if err 
:= stopper.RunAsyncTask(ctx, "serve-proxy", func(ctx context.Context) { + log.Infof(ctx, "proxy server listening at %s", proxyLn.Addr()) + if err := server.Serve(ctx, proxyLn); err != nil { + errChan <- err + } + }); err != nil { + return err + } + + return waitForSignals(ctx, stopper, errChan) +} + +func initLogging(cmd *cobra.Command) (ctx context.Context, stopper *stop.Stopper, err error) { + // Remove the default store, which avoids using it to set up logging. + // Instead, we'll default to logging to stderr unless --log-dir is + // specified. This makes sense since the standalone SQL server is + // at the time of writing stateless and may not be provisioned with + // suitable storage. + serverCfg.Stores.Specs = nil + serverCfg.ClusterName = "" + + ctx = context.Background() + stopper, err = setupAndInitializeLoggingAndProfiling(ctx, cmd, false /* isServerCmd */) + if err != nil { + return + } + ctx, _ = stopper.WithCancelOnQuiesce(ctx) + return ctx, stopper, err +} + +func waitForSignals( + ctx context.Context, stopper *stop.Stopper, errChan chan error, +) (returnErr error) { + // Need to alias the signals if this has to run on non-unix OSes too. + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, unix.SIGINT, unix.SIGTERM) + + // Dump the stacks when QUIT is received. + quitSignalCh := make(chan os.Signal, 1) + signal.Notify(quitSignalCh, unix.SIGQUIT) + go func() { + for { + <-quitSignalCh + log.DumpStacks(context.Background()) + } + }() + + select { + case err := <-errChan: + log.StartAlwaysFlush() + return err + case <-stopper.ShouldQuiesce(): + // Stop has been requested through the stopper's Stop + <-stopper.IsStopped() + // StartAlwaysFlush both flushes and ensures that subsequent log + // writes are flushed too. 
+ log.StartAlwaysFlush() + case sig := <-signalCh: // INT or TERM + log.StartAlwaysFlush() // In case the caller follows up with KILL + log.Ops.Infof(ctx, "received signal '%s'", sig) + if sig == os.Interrupt { + returnErr = errors.New("interrupted") + } + go func() { + log.Infof(ctx, "server stopping") + stopper.Stop(ctx) + }() + case <-log.FatalChan(): + stopper.Stop(ctx) + select {} // Block and wait for logging go routine to shut down the process + } + + for { + select { + case sig := <-signalCh: + switch sig { + case os.Interrupt: // SIGTERM after SIGTERM + log.Ops.Infof(ctx, "received additional signal '%s'; continuing graceful shutdown", sig) + continue + } + + log.Ops.Shoutf(ctx, severity.ERROR, + "received signal '%s' during shutdown, initiating hard shutdown", log.Safe(sig)) + panic("terminate") + case <-stopper.IsStopped(): + const msgDone = "server shutdown completed" + log.Ops.Infof(ctx, msgDone) + fmt.Fprintln(os.Stdout, msgDone) + } + break + } + + return returnErr +} diff --git a/pkg/cli/mt_start_sql.go b/pkg/cli/mt_start_sql.go index d7dd8c83e39b..9bc1e57b6729 100644 --- a/pkg/cli/mt_start_sql.go +++ b/pkg/cli/mt_start_sql.go @@ -136,6 +136,10 @@ func runStartSQL(cmd *cobra.Command, args []string) error { ch := make(chan os.Signal, 1) signal.Notify(ch, drainSignals...) - sig := <-ch - return errors.Newf("received signal %v", sig) + select { + case sig := <-ch: + return errors.Newf("received signal %v", sig) + case <-stopper.ShouldQuiesce(): + return nil + } } diff --git a/pkg/cli/mt_test_directory.go b/pkg/cli/mt_test_directory.go new file mode 100644 index 000000000000..09e5f8541340 --- /dev/null +++ b/pkg/cli/mt_test_directory.go @@ -0,0 +1,53 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package cli + +import ( + "context" + "fmt" + "net" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/spf13/cobra" +) + +var mtTestDirectorySvr = &cobra.Command{ + Use: "test-directory", + Short: "Run a test directory service.", + Long: ` +Run a test directory service. +`, + Args: cobra.NoArgs, + RunE: MaybeDecorateGRPCError(runDirectorySvr), +} + +func runDirectorySvr(cmd *cobra.Command, args []string) (returnErr error) { + ctx := context.Background() + serverCfg.Stores.Specs = nil + + stopper, err := setupAndInitializeLoggingAndProfiling(ctx, cmd, false /* isServerCmd */) + if err != nil { + return err + } + defer stopper.Stop(ctx) + + tds := tenant.NewTestDirectoryServer(stopper) + + listenPort, err := net.Listen( + "tcp", fmt.Sprintf(":%d", testDirectorySvrContext.port), + ) + if err != nil { + return err + } + stopper.AddCloser(stop.CloserFn(func() { _ = listenPort.Close() })) + return tds.Serve(listenPort) +} diff --git a/pkg/cli/start.go b/pkg/cli/start.go index af46bc511aa5..a669d4064e80 100644 --- a/pkg/cli/start.go +++ b/pkg/cli/start.go @@ -118,7 +118,9 @@ var serverCmds = append(StartCmds, mtStartSQLCmd) // customLoggingSetupCmds lists the commands that call setupLogging() // after other types of configuration. -var customLoggingSetupCmds = append(serverCmds, debugCheckLogConfigCmd, demoCmd) +var customLoggingSetupCmds = append( + serverCmds, debugCheckLogConfigCmd, demoCmd, mtStartSQLProxyCmd, mtTestDirectorySvr, +) func initBlockProfile() { // Enable the block profile for a sample of mutex and channel operations. 
diff --git a/pkg/security/certmgr/BUILD.bazel b/pkg/security/certmgr/BUILD.bazel index a236e350159c..207dfba183b3 100644 --- a/pkg/security/certmgr/BUILD.bazel +++ b/pkg/security/certmgr/BUILD.bazel @@ -15,6 +15,7 @@ go_library( srcs = [ "cert_manager.go", "file_cert.go", + "self_signed_cert.go", ":mocks_certmgr", # keep ], embed = [":certlib"], # keep @@ -25,6 +26,7 @@ go_library( "//pkg/util/log/eventpb", "//pkg/util/syncutil", "//pkg/util/sysutil", + "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", "@com_github_golang_mock//gomock", # keep ], @@ -35,6 +37,7 @@ go_test( srcs = [ "cert_manager_test.go", "file_cert_test.go", + "self_signed_cert_test.go", ], embed = [":certmgr"], deps = [ diff --git a/pkg/security/certmgr/cert.go b/pkg/security/certmgr/cert.go index f39d04fb4ec7..8baa3e394e0a 100644 --- a/pkg/security/certmgr/cert.go +++ b/pkg/security/certmgr/cert.go @@ -10,7 +10,10 @@ package certmgr -import "context" +import ( + "context" + "crypto/tls" +) //go:generate mockgen -package=certmgr -destination=mocks_generated.go -source=cert.go . Cert @@ -43,4 +46,6 @@ type Cert interface { Err() error // ClearErr will clear the last reported err and allow reloads to resume. 
ClearErr() + // TLSCert will return the last loaded tls.Certificate + TLSCert() *tls.Certificate } diff --git a/pkg/security/certmgr/file_cert_test.go b/pkg/security/certmgr/file_cert_test.go index 2ee58858d4a3..24f4ac66f43c 100644 --- a/pkg/security/certmgr/file_cert_test.go +++ b/pkg/security/certmgr/file_cert_test.go @@ -58,7 +58,6 @@ func TestFileCert_Err(t *testing.T) { require.Nil(t, fc.Err()) fc.Reload(context.Background()) require.NotNil(t, fc.Err()) - require.NotNil(t, fc.Err()) fc.ClearErr() require.Nil(t, fc.Err()) } diff --git a/pkg/security/certmgr/mocks_generated.go b/pkg/security/certmgr/mocks_generated.go index 013d7ef63c24..097b2ba77bca 100644 --- a/pkg/security/certmgr/mocks_generated.go +++ b/pkg/security/certmgr/mocks_generated.go @@ -6,6 +6,7 @@ package certmgr import ( context "context" + tls "crypto/tls" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -71,3 +72,17 @@ func (mr *MockCertMockRecorder) Reload(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reload", reflect.TypeOf((*MockCert)(nil).Reload), ctx) } + +// TLSCert mocks base method. +func (m *MockCert) TLSCert() *tls.Certificate { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TLSCert") + ret0, _ := ret[0].(*tls.Certificate) + return ret0 +} + +// TLSCert indicates an expected call of TLSCert. +func (mr *MockCertMockRecorder) TLSCert() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TLSCert", reflect.TypeOf((*MockCert)(nil).TLSCert)) +} diff --git a/pkg/security/certmgr/self_signed_cert.go b/pkg/security/certmgr/self_signed_cert.go new file mode 100644 index 000000000000..0ba941f8f45e --- /dev/null +++ b/pkg/security/certmgr/self_signed_cert.go @@ -0,0 +1,103 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package certmgr + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "math/big" + "time" + + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" +) + +// Ensure that SelfSignedCert implements Cert. +var _ Cert = (*SelfSignedCert)(nil) + +// SelfSignedCert represents a single, self-signed certificate, generated at runtime. +type SelfSignedCert struct { + syncutil.Mutex + years, months, days int + secs time.Duration + err error + cert *tls.Certificate +} + +// NewSelfSignedCert will generate a new self-signed cert. +// A follow up Reload will regenerate the cert. +func NewSelfSignedCert(years, months, days int, secs time.Duration) *SelfSignedCert { + return &SelfSignedCert{years: years, months: months, days: days, secs: secs} +} + +// Reload will regenerate the self-signed cert. +func (ssc *SelfSignedCert) Reload(ctx context.Context) { + ssc.Lock() + defer ssc.Unlock() + + // There was a previous error that is not yet retrieved. + if ssc.err != nil { + return + } + + // Generate self signed cert for testing. 
+ // Use "openssl s_client -showcerts -starttls postgres -connect {HOSTNAME}:{PORT}" to + // inspect the certificate or save the certificate portion to a file and + // use sslmode=verify-full + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + ssc.err = errors.Wrapf(err, "could not generate key") + return + } + + from := timeutil.Now() + until := from.AddDate(ssc.years, ssc.months, ssc.days).Add(ssc.secs) + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: from, + NotAfter: until, + DNSNames: []string{"localhost"}, + IsCA: true, + } + cer, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv) + if err != nil { + ssc.err = errors.Wrapf(err, "could not create certificate") + return + } + ssc.cert = &tls.Certificate{ + Certificate: [][]byte{cer}, + PrivateKey: priv, + } +} + +// Err will return the last error that occurred during reload or nil if the +// last reload was successful. +func (ssc *SelfSignedCert) Err() error { + ssc.Lock() + defer ssc.Unlock() + return ssc.err +} + +// ClearErr will clear the last err so the follow up Reload can execute. +func (ssc *SelfSignedCert) ClearErr() { + ssc.Lock() + defer ssc.Unlock() + ssc.err = nil +} + +// TLSCert returns the tls certificate if the cert generation was successful. +func (ssc *SelfSignedCert) TLSCert() *tls.Certificate { + return ssc.cert +} diff --git a/pkg/security/certmgr/self_signed_cert_test.go b/pkg/security/certmgr/self_signed_cert_test.go new file mode 100644 index 000000000000..9a75b58a80be --- /dev/null +++ b/pkg/security/certmgr/self_signed_cert_test.go @@ -0,0 +1,44 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt.
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package certmgr + +import ( + "context" + "crypto/x509" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSelfSignedCert_Err(t *testing.T) { + ssc := NewSelfSignedCert(-9999, 0, 0, 0) + require.NotNil(t, ssc) + require.Nil(t, ssc.Err()) + ssc.Reload(context.Background()) + require.Regexp(t, "cannot represent time as GeneralizedTime", ssc.Err()) + ssc.ClearErr() + require.Nil(t, ssc.Err()) +} + +func TestSelfSignedCert_TLSCert(t *testing.T) { + ssc := NewSelfSignedCert(1, 6, 3, 5*time.Hour) + require.NotNil(t, ssc) + require.Nil(t, ssc.Err()) + ssc.Reload(context.Background()) + require.Nil(t, ssc.Err()) + require.NotNil(t, ssc.TLSCert()) + require.Len(t, ssc.TLSCert().Certificate, 1) + cert, err := x509.ParseCertificate(ssc.TLSCert().Certificate[0]) + require.NoError(t, err) + expectedUntil := cert.NotBefore.AddDate(1, 6, 3).Add(5 * time.Hour) + require.Equal(t, expectedUntil, cert.NotAfter) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 2ba695f17f0c..aa3a975c25bf 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1189,7 +1189,7 @@ func (s *Server) PreStart(ctx context.Context) error { filtered := s.cfg.FilterGossipBootstrapResolvers(ctx) // Set up the init server. We have to do this relatively early because we - // can't call RegisterInitServer() after `grpc.Serve`, which is called in + // can't call RegisterInitServer() after `grpc.Serve`, which is called in // startRPCServer (and for the loopback grpc-gw connection). var initServer *initServer { @@ -1735,7 +1735,7 @@ func (s *Server) PreStart(ctx context.Context) error { } s.oidc = oidc - // Serve UI assets. + // Serve UI assets.
// // The authentication mux used here is created in "allow anonymous" mode so that the UI // assets are served up whether or not there is a session. If there is a session, the mux @@ -2121,7 +2121,7 @@ func (s *Server) startListenRPCAndSQL( // initialize) before we accept RPC requests. The caller // (Server.Start) will call this at the right moment. startRPCServer = func(ctx context.Context) { - // Serve the gRPC endpoint. + // Serve the gRPC endpoint. _ = s.stopper.RunAsyncTask(workersCtx, "serve-grpc", func(context.Context) { netutil.FatalIfUnexpected(s.grpc.Serve(anyL)) }) @@ -2174,7 +2174,7 @@ func (s *Server) startServeUI( return err } - // Serve the plain HTTP (non-TLS) connection over clearL. + // Serve the plain HTTP (non-TLS) connection over clearL. // This produces a HTTP redirect to the `https` URL for the path /, // handles the request normally (via s.ServeHTTP) for the path /health, // and produces 404 for anything else. @@ -2195,7 +2195,7 @@ func (s *Server) startServeUI( httpLn = tls.NewListener(tlsL, uiTLSConfig) } - // Serve the HTTP endpoint. This will be the original httpLn + // Serve the HTTP endpoint. This will be the original httpLn // listening on --http-addr without TLS if uiTLSConfig was // nil, or overridden above if uiTLSConfig was not nil to come from // the TLS negotiation over the HTTP port. diff --git a/pkg/server/status.go b/pkg/server/status.go index 67553c2ec7c4..96c1b35ea9f2 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -894,7 +894,7 @@ func (s *statusServer) GetFiles( var dir string switch req.Type { - //TODO(ridwanmsharif): Serve logfiles so debug-zip can fetch them + //TODO(ridwanmsharif): Serve logfiles so debug-zip can fetch them // intead of reading indididual entries. case serverpb.FileType_HEAP: // Requesting for saved Heap Profiles.
dir = s.admin.server.cfg.HeapProfileDirName diff --git a/pkg/testutils/BUILD.bazel b/pkg/testutils/BUILD.bazel index 99e6c098250d..26e956744c69 100644 --- a/pkg/testutils/BUILD.bazel +++ b/pkg/testutils/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "dir.go", "error.go", "files.go", + "hook.go", "keys.go", "net.go", "pprof.go", diff --git a/pkg/testutils/hook.go b/pkg/testutils/hook.go new file mode 100644 index 000000000000..340bbb4de7a0 --- /dev/null +++ b/pkg/testutils/hook.go @@ -0,0 +1,24 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package testutils + +import "reflect" + +// HookGlobal sets `*ptr = val` and returns a closure for restoring `*ptr` to +// its original value. A runtime panic will occur if `val` is not assignable to +// `*ptr`. 
// HookGlobal sets `*ptr = val` and returns a closure for restoring `*ptr` to
// its original value. A runtime panic will occur if `val` is not assignable to
// `*ptr`.
//
// An untyped nil `val` is accepted when `*ptr` has a nil-able kind (chan,
// func, interface, map, pointer, slice or unsafe pointer); in that case `*ptr`
// is set to the zero value of its type. For any other kind a nil `val` still
// panics, since nil is not assignable to such a variable.
func HookGlobal(ptr, val interface{}) func() {
	global := reflect.ValueOf(ptr).Elem()

	// Snapshot the current value into a fresh, independent Value so that the
	// returned closure can restore it later.
	orig := reflect.New(global.Type()).Elem()
	orig.Set(global)

	next := reflect.ValueOf(val)
	if !next.IsValid() {
		// reflect.ValueOf(nil) yields the invalid Value, which Set rejects
		// unconditionally. Substitute a typed zero value when nil is a legal
		// value for the target.
		switch global.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
			reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
			next = reflect.Zero(global.Type())
		}
	}
	global.Set(next)

	return func() { global.Set(orig) }
}
wg.Add(1) go func() { defer wg.Done() diff --git a/pkg/util/netutil/net.go b/pkg/util/netutil/net.go index ffdfbb09323d..1981c7f1d95c 100644 --- a/pkg/util/netutil/net.go +++ b/pkg/util/netutil/net.go @@ -72,7 +72,7 @@ type Server struct { // // It can serve two different purposes simultaneously: // -// - to serve as actual HTTP server, using the .Serve(net.Listener) method. +// - to serve as actual HTTP server, using the .Serve(net.Listener) method. // - to serve as plain TCP server, using the .ServeWith(...) method. // // The latter is used e.g. to accept SQL client connections. @@ -103,7 +103,7 @@ func MakeServer(stopper *stop.Stopper, tlsConfig *tls.Config, handler http.Handl ctx := context.TODO() - // net/http.(*Server).Serve/http2.ConfigureServer are not thread safe with + // net/http.(*Server).Serve/http2.ConfigureServer are not thread safe with // respect to net/http.(*Server).TLSConfig, so we call it synchronously here. if err := http2.ConfigureServer(server.Server, nil); err != nil { log.Fatalf(ctx, "%v", err) @@ -129,7 +129,7 @@ func MakeServer(stopper *stop.Stopper, tlsConfig *tls.Config, handler http.Handl func (s *Server) ServeWith( ctx context.Context, stopper *stop.Stopper, l net.Listener, serveConn func(net.Conn), ) error { - // Inspired by net/http.(*Server).Serve + // Inspired by net/http.(*Server).Serve var tempDelay time.Duration // how long to sleep on accept failure for { rw, e := l.Accept() @@ -152,7 +152,7 @@ func (s *Server) ServeWith( tempDelay = 0 go func() { defer stopper.Recover(ctx) - s.Server.ConnState(rw, http.StateNew) // before Serve can return + s.Server.ConnState(rw, http.StateNew) // before Serve can return serveConn(rw) s.Server.ConnState(rw, http.StateClosed) }()