From ce75cd28b783da49dc17b6e968283027b96f16d7 Mon Sep 17 00:00:00 2001 From: Darin Peshev Date: Thu, 13 May 2021 13:28:58 -0700 Subject: [PATCH] ccl/sqlproxyccl: CC code migration to DB Previously the sql proxy code was in the CC repo. This was making the testing of the proxy against a live SQL server hard and was also requiring frequent cockroach repo bumps in case of changes. This moves all the code from the CC repo to the DB repo so now the proxy is part of the cockroach executable. More detailed list of changes: * The old, sample start-proxy code has been retired in favor of the code moving over from the CC repo. * The code that handles individual connections to the backend has been separated into a new ProxyHandler. Added tests for the proxy handler. * BackendConfig has been retired. * Using stop.Stopper to control the shutdown of the proxy. * Added a command under mt that can be used to run the test directory server. * Added proxy options to control idle timeout, rate limits, config options, use of directory server etc. * Added code to monitor and handle os signals (HUP, TERM, INT). * Integrated the cert manager so the certificates can be reloaded on external signal. * Fixed the SQL tenant process so now the idle timeout causes the stopper to quiesce and the process to terminate successfully. * Set up the logging for the new proxy. * Added a self-signed cert type to the cert manager to be used when testing secure connections without generating explicit key/cert files. * Moved the HookGlobal code from CC that can be used for temporary hooks during testing. 
Here is how to test end to end the proxy, SQL tenant and host server, using the test directory: ``` ./cockroach start-single-node --insecure --log="{sinks: {stderr: {filter: info}}}" ./cockroach mt test-directory --port 36257 --log="{sinks: {stderr: {filter: info}}}" ./cockroach mt start-proxy --directory=:36257 --listen-metrics=:8081 --log="{sinks: {stderr: {filter: info}}}" --insecure ./cockroach sql --url="postgresql://sqladmin@127.0.0.1:46257/dim-dog-2.defaultdb" --insecure ``` Release note: None --- pkg/BUILD.bazel | 1 + pkg/ccl/cliccl/BUILD.bazel | 5 - pkg/ccl/cliccl/mtproxy.go | 188 ----- pkg/ccl/sqlproxyccl/BUILD.bazel | 14 +- pkg/ccl/sqlproxyccl/authentication.go | 5 +- pkg/ccl/sqlproxyccl/authentication_test.go | 12 +- pkg/ccl/sqlproxyccl/backend_dialer.go | 2 +- pkg/ccl/sqlproxyccl/error.go | 6 +- pkg/ccl/sqlproxyccl/frontend_admitter.go | 17 +- pkg/ccl/sqlproxyccl/proxy.go | 273 ++----- pkg/ccl/sqlproxyccl/proxy_handler.go | 582 +++++++++++++++ pkg/ccl/sqlproxyccl/proxy_handler_test.go | 680 ++++++++++++++++++ pkg/ccl/sqlproxyccl/proxy_test.go | 671 ----------------- pkg/ccl/sqlproxyccl/server.go | 34 +- pkg/ccl/sqlproxyccl/server_test.go | 20 +- pkg/ccl/sqlproxyccl/tenant/BUILD.bazel | 30 +- .../sqlproxyccl/tenant/test_directory_svr.go | 5 +- pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel | 29 + .../directory_test.go | 37 +- .../{tenant => tenantdirsvr}/main_test.go | 2 +- pkg/cli/BUILD.bazel | 48 +- pkg/cli/mt.go | 7 +- pkg/cli/mt_proxy.go | 207 ++++++ pkg/cli/mt_start_sql.go | 8 +- pkg/cli/mt_test_directory.go | 63 ++ pkg/cli/start.go | 4 +- pkg/security/certmgr/BUILD.bazel | 3 + pkg/security/certmgr/cert.go | 7 +- pkg/security/certmgr/file_cert_test.go | 1 - pkg/security/certmgr/mocks_generated.go | 15 + pkg/security/certmgr/self_signed_cert.go | 103 +++ pkg/security/certmgr/self_signed_cert_test.go | 44 ++ pkg/testutils/BUILD.bazel | 1 + pkg/testutils/hook.go | 24 + 34 files changed, 1937 insertions(+), 1211 deletions(-) delete mode 100644 
pkg/ccl/cliccl/mtproxy.go create mode 100644 pkg/ccl/sqlproxyccl/proxy_handler.go create mode 100644 pkg/ccl/sqlproxyccl/proxy_handler_test.go delete mode 100644 pkg/ccl/sqlproxyccl/proxy_test.go create mode 100644 pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel rename pkg/ccl/sqlproxyccl/{tenant => tenantdirsvr}/directory_test.go (91%) rename pkg/ccl/sqlproxyccl/{tenant => tenantdirsvr}/main_test.go (98%) create mode 100644 pkg/cli/mt_proxy.go create mode 100644 pkg/cli/mt_test_directory.go create mode 100644 pkg/security/certmgr/self_signed_cert.go create mode 100644 pkg/security/certmgr/self_signed_cert_test.go create mode 100644 pkg/testutils/hook.go diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index ddc8deac4be6..89460e70ad44 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -27,6 +27,7 @@ ALL_TESTS = [ "//pkg/ccl/sqlproxyccl/cache:cache_test", "//pkg/ccl/sqlproxyccl/denylist:denylist_test", "//pkg/ccl/sqlproxyccl/tenant:tenant_test", + "//pkg/ccl/sqlproxyccl/tenantdirsvr:tenantdirsvr_test", "//pkg/ccl/sqlproxyccl:sqlproxyccl_test", "//pkg/ccl/storageccl/engineccl:engineccl_test", "//pkg/ccl/storageccl:storageccl_test", diff --git a/pkg/ccl/cliccl/BUILD.bazel b/pkg/ccl/cliccl/BUILD.bazel index ac7b4850565f..840b02e1b104 100644 --- a/pkg/ccl/cliccl/BUILD.bazel +++ b/pkg/ccl/cliccl/BUILD.bazel @@ -7,7 +7,6 @@ go_library( "debug.go", "debug_backup.go", "demo.go", - "mtproxy.go", "start.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/ccl/cliccl", @@ -19,7 +18,6 @@ go_library( "//pkg/ccl/backupccl", "//pkg/ccl/baseccl", "//pkg/ccl/cliccl/cliflagsccl", - "//pkg/ccl/sqlproxyccl", "//pkg/ccl/storageccl", "//pkg/ccl/storageccl/engineccl/enginepbccl:enginepbccl_go_proto", "//pkg/ccl/workloadccl/cliccl", @@ -53,12 +51,9 @@ go_library( "//pkg/util/timeutil/pgdate", "//pkg/util/uuid", "@com_github_cockroachdb_apd_v2//:apd", - "@com_github_cockroachdb_cmux//:cmux", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_errors//oserror", - 
"@com_github_jackc_pgproto3_v2//:pgproto3", "@com_github_spf13_cobra//:cobra", - "@org_golang_x_sync//errgroup", ], ) diff --git a/pkg/ccl/cliccl/mtproxy.go b/pkg/ccl/cliccl/mtproxy.go deleted file mode 100644 index e297f3188035..000000000000 --- a/pkg/ccl/cliccl/mtproxy.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2020 The Cockroach Authors. -// -// Licensed as a CockroachDB Enterprise file under the Cockroach Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt - -package cliccl - -import ( - "context" - "crypto/tls" - "io/ioutil" - "net" - "strings" - - "github.com/cockroachdb/cmux" - "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" - "github.com/cockroachdb/cockroach/pkg/cli" - "github.com/cockroachdb/errors" - "github.com/jackc/pgproto3/v2" - "github.com/spf13/cobra" - "golang.org/x/sync/errgroup" -) - -var sqlProxyListenAddr, sqlProxyTargetAddr string -var sqlProxyListenCert, sqlProxyListenKey string - -func init() { - startSQLProxyCmd := &cobra.Command{ - Use: "start-proxy ", - Short: "start-proxy host:port", - Long: `Starts a SQL proxy for testing purposes. - -This proxy provides very limited functionality. It accepts incoming connections -and relays them to the specified backend server after verifying that at least -one of the following holds: - -1. the supplied database name is prefixed with 'prancing-pony.'; this prefix - will then be removed for the connection to the backend server, and/or -2. the options parameter is 'prancing-pony'. - -Connections to the target address use TLS but do not identify the identity of -the peer, making them susceptible to MITM attacks. 
-`, - RunE: cli.MaybeDecorateGRPCError(runStartSQLProxy), - Args: cobra.NoArgs, - } - f := startSQLProxyCmd.Flags() - f.StringVarP(&sqlProxyListenCert, "listen-cert", "", "", "Certificate file to use for listener (auto-generate if empty)") - f.StringVarP(&sqlProxyListenKey, "listen-key", "", "", "Private key file to use for listener(auto-generate if empty)") - f.StringVarP(&sqlProxyListenAddr, "listen-addr", "", "127.0.0.1:46257", "Address for incoming connections") - f.StringVarP(&sqlProxyTargetAddr, "target-addr", "", "127.0.0.1:26257", "Address for outgoing connections") - cli.AddMTCommand(startSQLProxyCmd) -} - -func runStartSQLProxy(*cobra.Command, []string) error { - // openssl genrsa -out testserver.key 2048 - // openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt -days 3650 - // ^-- Enter * as Common Name below, rest can be empty. - certBytes := []byte(`-----BEGIN CERTIFICATE----- -MIICpDCCAYwCCQDWdkou+YTT/DANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls -b2NhbGhvc3QwHhcNMjAwNTIwMTQxMjIyWhcNMzAwNTE4MTQxMjIyWjAUMRIwEAYD -VQQDDAlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDG -H6V5TZjppRR61azexRtKLOnftVpO0+CPslHynbWwrJ6sxZIdglUWoCT3a/93tq0h -SaMWIxH+29wDXiICCTbr485h3Sov5Rq7kV/AwcLMOpdqjbN2PRBW95aq8rV3h/Ui -K3hu8OZjeh4DzhxsDWYwLG+1aUHnzpwDVIvXqiiKHVtT3WLDHRNUAuph9o/4Fao0 -m1KAzXvfbnNyMUWTAOUCIX2tlq79rEIAKOySCKDr07TuVrzCKcF5sbXkFXlmyFNl -KbmXRuD3UxghxLMmUZar7eZR84x6R/Rj5Dqyrs3nfl+/30Zk0pNe6naaKO39zqlR -rWQIqwSZrY1HwGGeVJFjAgMBAAEwDQYJKoZIhvcNAQELBQADggEBACluo7vP0kXd -uXD3joPKiMJ0FgZqeDtuSvBPfl0okqPN+bk/Huqu+FgxfChCs+2EcreGFxshjzuv -J58ogFq1YMB4pS4GlqarHE+UdliOobD+OyvX40w9lTJ2wI+v7kI79udFE+tyLIs6 -YkuzFd1nB0Zcf8QFzyPRTVXVpsWid3ZvARDakp4z7klPLnkfVrXo/ivlKqGF+Ymy -vJ/riLR01omTVi6W40cml/H4DAtG/XVsQeFXWpjUv97MWGRVYycmpCleVkK+uC2x -XAEi/UMoPhhJd6HEWG+56IkFFoN4lNtPuyal0vzOJCn70pgQx3yKh61RQcPrJMlD -m9qz1xbrzj8= ------END CERTIFICATE----- -`) - keyBytes := []byte(`-----BEGIN RSA PRIVATE KEY----- 
-MIIEowIBAAKCAQEAxh+leU2Y6aUUetWs3sUbSizp37VaTtPgj7JR8p21sKyerMWS -HYJVFqAk92v/d7atIUmjFiMR/tvcA14iAgk26+POYd0qL+Uau5FfwMHCzDqXao2z -dj0QVveWqvK1d4f1Iit4bvDmY3oeA84cbA1mMCxvtWlB586cA1SL16ooih1bU91i -wx0TVALqYfaP+BWqNJtSgM17325zcjFFkwDlAiF9rZau/axCACjskgig69O07la8 -winBebG15BV5ZshTZSm5l0bg91MYIcSzJlGWq+3mUfOMekf0Y+Q6sq7N535fv99G -ZNKTXup2mijt/c6pUa1kCKsEma2NR8BhnlSRYwIDAQABAoIBADFpiSKUyNNU2aO9 -EO1KaYD5bKbfmxNX4oTUK33/+WWD19stN0Dm1YPcEvwmUkOwKsPHksYdnwpaGSg5 -3O93DtyMJ1ffCfuB/0XSfvgbGxNGdacciiquFhoqi8g82idioC+SeenpaPxcY4n9 -aLdGLDtNidrL0qUWsXBfMLVr+cpgENPMiri31CGLNpfO1b4icdQjiltEn70To2Al -68Ptar60m/lJzf8QMFSf499/W3b7fLjGFK+Gzump94xAVMd7HhACf42ZpWRPe1Ih -lyHP6D0091cIRhGxIZrhLToSuySpf1A+C/rQYTqzEPv/a3b4Ja6poulpBppwJyDa -roC4KtkCgYEA/h0HRzekvNAg2cOV/Ag1RyE4AmdyCDubBHcPNJi9kI/RsA0RgurO -pr2oET0HWTENgE4e4hYQnlqUvTXYisvtvhigiCkcynpGoMJa5Y4St7J1QKdQtuwY -vcRqOGGSKl73biK79+BIV/6swWCkB+VzoGzKP8dY/XZHsI0FDdnia8UCgYEAx5gz -9qfzfiqOQP/GN6vGIzoldGxCDCHyyEZmvy0FiBlMJK36Qkjtz48eqYEXOUCX8Z5X -gB583iMv72Oz/wmefoIjnUd9uXyMqvhnYxG4vQhU4a83K4q4TPkNd7+sLiNqxIq2 -o2jT6BktOHE5OiICeFGMFOfHtsyV78JMsuzUEwcCgYAXU5LXdsQokPJzCwE5oYdC -gEoj7lsJZm9UeZlrupmsK4eUIZ755ZQSulYzPubtyRL0NDehiWT9JFODCu5Vz2KD -kL8rwJpj+9V/7Fdrux78veUFilZedE3RHbaidlJ0kUMlWQroNi5t5XL2TWjBUM7M -azAlqqcAnVr3WfqcyuN+AQKBgQCsz+xV6I7bMy9dudc+hlyUTZj2V3FMHeyeWM5H -QkzizLxvma7vy0MUDd/HdTzNVk74ZVdvV3ZXwvGS/Klw7TwsXrNFTwvdGKiWs2KY -lVR1XwxXJyTGb2IpSw3NG8iRXhroNw3xKCcpcvsDPo0E90NaN4jo5NG3RSWgpINR -+9mW6wKBgCze3gZB6AU0W/Fluql88ANx/+nqZhrsJyfrv87AkgDtt0o1tOj+emKR -Uuwb2FVdh76ZK0AVd3Jh3KJs4+hr2u9syHaa7UPKXTcZsFWlGwZuu6X5A+0SO0S2 -/ur8gv24YZJvV7OvPhw1SAuYL7MKMsfTW4TEKWTfkZWvm4YfZNmR ------END RSA PRIVATE KEY----- -`) - - if (sqlProxyListenKey == "") != (sqlProxyListenCert == "") { - return errors.New("must specify either both or neither of cert and key") - } - - if sqlProxyListenCert != "" { - var err error - certBytes, err = ioutil.ReadFile(sqlProxyListenCert) - if err != nil { - return err - } - } - if sqlProxyListenKey != "" { - var 
err error - keyBytes, err = ioutil.ReadFile(sqlProxyListenKey) - if err != nil { - return err - } - } - - cer, err := tls.X509KeyPair(certBytes, keyBytes) - if err != nil { - return err - } - - ln, err := net.Listen("tcp", sqlProxyListenAddr) - if err != nil { - return err - } - defer func() { _ = ln.Close() }() - - // Multiplex the listen address to give easy access to metrics from this - // command. - mux := cmux.New(ln) - httpLn := mux.Match(cmux.HTTP1Fast()) - proxyLn := mux.Match(cmux.Any()) - - outgoingConf := &tls.Config{ - InsecureSkipVerify: true, - } - server := sqlproxyccl.NewServer(sqlproxyccl.Options{ - FrontendAdmitter: func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return sqlproxyccl.FrontendAdmit( - incoming, - &tls.Config{ - Certificates: []tls.Certificate{cer}, - }, - ) - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - params := msg.Parameters - const magic = "prancing-pony" - if strings.HasPrefix(params["database"], magic+".") { - params["database"] = params["database"][len(magic)+1:] - } else if params["options"] != "--cluster="+magic { - return nil, errors.Errorf("client failed to pass '%s' via database or options", magic) - } - conn, err := sqlproxyccl.BackendDial(msg, sqlProxyTargetAddr, outgoingConf) - if err != nil { - return nil, err - } - return conn, nil - }, - }) - - group, ctx := errgroup.WithContext(context.Background()) - - group.Go(func() error { - return server.ServeHTTP(ctx, httpLn) - }) - - group.Go(func() error { - return server.Serve(proxyLn) - }) - - group.Go(func() error { - return mux.Serve() - }) - - return group.Wait() -} diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 4c229fbc6687..7f36f9e29b8d 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -11,20 +11,31 @@ go_library( "idle_disconnect_connection.go", "metrics.go", "proxy.go", + "proxy_handler.go", "server.go", ":gen-errorcode-stringer", # keep 
], importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl", visibility = ["//visibility:public"], deps = [ + "//pkg/ccl/sqlproxyccl/admitter", + "//pkg/ccl/sqlproxyccl/cache", + "//pkg/ccl/sqlproxyccl/denylist", + "//pkg/ccl/sqlproxyccl/tenant", + "//pkg/roachpb", + "//pkg/security/certmgr", + "//pkg/util", "//pkg/util/contextutil", "//pkg/util/httputil", "//pkg/util/log", "//pkg/util/metric", + "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_logtags//:logtags", "@com_github_jackc_pgproto3_v2//:pgproto3", + "@org_golang_google_grpc//:go_default_library", ], ) @@ -36,7 +47,7 @@ go_test( "frontend_admitter_test.go", "idle_disconnect_connection_test.go", "main_test.go", - "proxy_test.go", + "proxy_handler_test.go", "server_test.go", ], data = [ @@ -58,6 +69,7 @@ go_test( "//pkg/testutils/testcluster", "//pkg/util/leaktest", "//pkg/util/randutil", + "//pkg/util/stop", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", "@com_github_jackc_pgconn//:pgconn", diff --git a/pkg/ccl/sqlproxyccl/authentication.go b/pkg/ccl/sqlproxyccl/authentication.go index 4f09a93e5c69..bc9c4f69d0f2 100644 --- a/pkg/ccl/sqlproxyccl/authentication.go +++ b/pkg/ccl/sqlproxyccl/authentication.go @@ -14,7 +14,10 @@ import ( "github.com/jackc/pgproto3/v2" ) -func authenticate(clientConn, crdbConn net.Conn) error { +// Authenticate handles the startup of the pgwire protocol to the point where +// the connections is considered authenticated. If that doesn't happen, it +// returns an error. 
+var Authenticate = func(clientConn, crdbConn net.Conn) error { fe := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn) be := pgproto3.NewFrontend(pgproto3.NewChunkReader(crdbConn), crdbConn) diff --git a/pkg/ccl/sqlproxyccl/authentication_test.go b/pkg/ccl/sqlproxyccl/authentication_test.go index 1719f5ec9297..a5839f0ba027 100644 --- a/pkg/ccl/sqlproxyccl/authentication_test.go +++ b/pkg/ccl/sqlproxyccl/authentication_test.go @@ -33,7 +33,7 @@ func TestAuthenticateOK(t *testing.T) { require.Equal(t, beMsg, &pgproto3.ReadyForQuery{}) }() - require.NoError(t, authenticate(srv, cli)) + require.NoError(t, Authenticate(srv, cli)) } func TestAuthenticateClearText(t *testing.T) { @@ -75,7 +75,7 @@ func TestAuthenticateClearText(t *testing.T) { require.Equal(t, beMsg, &pgproto3.ReadyForQuery{}) }() - require.NoError(t, authenticate(srv, cli)) + require.NoError(t, Authenticate(srv, cli)) } func TestAuthenticateError(t *testing.T) { @@ -93,11 +93,11 @@ func TestAuthenticateError(t *testing.T) { require.Equal(t, beMsg, &pgproto3.ErrorResponse{Severity: "FATAL", Code: "foo"}) }() - err := authenticate(srv, cli) + err := Authenticate(srv, cli) require.Error(t, err) codeErr := (*CodeError)(nil) require.True(t, errors.As(err, &codeErr)) - require.Equal(t, CodeAuthFailed, codeErr.code) + require.Equal(t, CodeAuthFailed, codeErr.Code) } func TestAuthenticateUnexpectedMessage(t *testing.T) { @@ -115,9 +115,9 @@ func TestAuthenticateUnexpectedMessage(t *testing.T) { require.Equal(t, beMsg, &pgproto3.BindComplete{}) }() - err := authenticate(srv, cli) + err := Authenticate(srv, cli) require.Error(t, err) codeErr := (*CodeError)(nil) require.True(t, errors.As(err, &codeErr)) - require.Equal(t, CodeBackendDisconnected, codeErr.code) + require.Equal(t, CodeBackendDisconnected, codeErr.Code) } diff --git a/pkg/ccl/sqlproxyccl/backend_dialer.go b/pkg/ccl/sqlproxyccl/backend_dialer.go index 0e3aff5e4cbb..99485238d2d1 100644 --- a/pkg/ccl/sqlproxyccl/backend_dialer.go +++ 
b/pkg/ccl/sqlproxyccl/backend_dialer.go @@ -19,7 +19,7 @@ import ( // BackendDial is an example backend dialer that does a TCP/IP connection // to a backend, SSL and forwards the start message. -func BackendDial( +var BackendDial = func( msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, ) (net.Conn, error) { conn, err := net.Dial("tcp", outgoingAddress) diff --git a/pkg/ccl/sqlproxyccl/error.go b/pkg/ccl/sqlproxyccl/error.go index e6dcb3dd6e23..cc1ef2ad7eba 100644 --- a/pkg/ccl/sqlproxyccl/error.go +++ b/pkg/ccl/sqlproxyccl/error.go @@ -85,18 +85,18 @@ const ( // CodeError is combines an error with one of the above codes to ease // the processing of the errors. type CodeError struct { - code ErrorCode + Code ErrorCode err error } func (e *CodeError) Error() string { - return fmt.Sprintf("%s: %s", e.code, e.err) + return fmt.Sprintf("%s: %s", e.Code, e.err) } // NewErrorf returns a new CodeError out of the supplied args. func NewErrorf(code ErrorCode, format string, args ...interface{}) error { return &CodeError{ - code: code, + Code: code, err: errors.Errorf(format, args...), } } diff --git a/pkg/ccl/sqlproxyccl/frontend_admitter.go b/pkg/ccl/sqlproxyccl/frontend_admitter.go index 2f523f8231fa..c878eae04205 100644 --- a/pkg/ccl/sqlproxyccl/frontend_admitter.go +++ b/pkg/ccl/sqlproxyccl/frontend_admitter.go @@ -18,7 +18,10 @@ import ( // FrontendAdmit is the default implementation of a frontend admitter. It can // upgrade to an optional SSL connection, and will handle and verify // the startup message received from the PG SQL client. -func FrontendAdmit( +// The connection returned should never be nil in case of error. Depending +// on whether the error happened before the connection was upgraded to TLS or not +// it will either be the original or the TLS connection. 
+var FrontendAdmit = func( conn net.Conn, incomingTLSConfig *tls.Config, ) (net.Conn, *pgproto3.StartupMessage, error) { // `conn` could be replaced by `conn` embedded in a `tls.Conn` connection, @@ -30,7 +33,7 @@ func FrontendAdmit( if incomingTLSConfig != nil { m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage() if err != nil { - return nil, nil, NewErrorf(CodeClientReadFailed, "while receiving startup message") + return conn, nil, NewErrorf(CodeClientReadFailed, "while receiving startup message") } switch m.(type) { case *pgproto3.SSLRequest: @@ -38,15 +41,15 @@ func FrontendAdmit( // Ignore CancelRequest explicitly. We don't need to do this but it makes // testing easier by avoiding a call to sendErrToClient on this path // (which would confuse assertCtx). - return nil, nil, nil + return conn, nil, nil default: code := CodeUnexpectedInsecureStartupMessage - return nil, nil, NewErrorf(code, "unsupported startup message: %T", m) + return conn, nil, NewErrorf(code, "unsupported startup message: %T", m) } _, err = conn.Write([]byte{pgAcceptSSLRequest}) if err != nil { - return nil, nil, NewErrorf(CodeClientWriteFailed, "acking SSLRequest: %v", err) + return conn, nil, NewErrorf(CodeClientWriteFailed, "acking SSLRequest: %v", err) } cfg := incomingTLSConfig.Clone() @@ -60,11 +63,11 @@ func FrontendAdmit( m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage() if err != nil { - return nil, nil, NewErrorf(CodeClientReadFailed, "receiving post-TLS startup message: %v", err) + return conn, nil, NewErrorf(CodeClientReadFailed, "receiving post-TLS startup message: %v", err) } msg, ok := m.(*pgproto3.StartupMessage) if !ok { - return nil, nil, NewErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m) + return conn, nil, NewErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m) } // Add the sniServerName (if used) as parameter diff --git 
a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index 2ffc4dff17b2..562d9429ec44 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -9,11 +9,8 @@ package sqlproxyccl import ( - "context" - "crypto/tls" "io" "net" - "os" "github.com/cockroachdb/errors" "github.com/jackc/pgproto3/v2" @@ -24,76 +21,66 @@ const pgAcceptSSLRequest = 'S' // See https://www.postgresql.org/docs/9.1/protocol-message-formats.html. var pgSSLRequest = []int32{8, 80877103} -// BackendConfig contains the configuration of a backend connection that is -// being proxied. -// To be removed once all clients are migrated to use backend dialer. -type BackendConfig struct { - // The address to which the connection is forwarded. - OutgoingAddress string - // TLS settings to use when connecting to OutgoingAddress. - TLSConf *tls.Config - // Called after successfully connecting to OutgoingAddr. - OnConnectionSuccess func() - // KeepAliveLoop if provided controls the lifetime of the proxy connection. - // It will be run in its own goroutine when the connection is successfully - // opened. Returning from `KeepAliveLoop` will close the proxy connection. - // Note that non-nil error return values will be forwarded to the user and - // hence should explain the reason for terminating the connection. - // Most common use of KeepAliveLoop will be as an infinite loop that - // periodically checks if the connection should still be kept alive. Hence it - // may block indefinitely so it's prudent to use the provided context and - // return on context cancellation. - // See `TestProxyKeepAlive` for example usage. - KeepAliveLoop func(context.Context) error -} - -// Options are the options to the Proxy method. -type Options struct { - // Deprecated: construct FrontendAdmitter, passing this information in case - // that SSL is desired. 
- IncomingTLSConfig *tls.Config // config used for client -> proxy connection - - // BackendFromParams returns the config to use for the proxy -> backend - // connection. The TLS config is in it and it must have an appropriate - // ServerName for the remote backend. - // Deprecated: processing of the params now happens in the BackendDialer. - // This is only here to support OnSuccess and KeepAlive. - BackendConfigFromParams func( - params map[string]string, incomingConn *Conn, - ) (config *BackendConfig, clientErr error) - - // If set, consulted to modify the parameters set by the frontend before - // forwarding them to the backend during startup. - // Deprecated: include the code that modifies the request params - // in the backend dialer. - ModifyRequestParams func(map[string]string) - - // If set, consulted to decorate an error message to be sent to the client. - // The error passed to this method will contain no internal information. - OnSendErrToClient func(code ErrorCode, msg string) string - - // If set, will be called immediately after a new incoming connection - // is accepted. It can optionally negotiate SSL, provide admittance control or - // other types of frontend connection filtering. - FrontendAdmitter func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) - - // If set, will be used to establish and return connection to the backend. - // If not set, the old logic will be used. - // The argument is the startup message received from the frontend. It - // contains the protocol version and params sent by the client. - BackendDialer func(msg *pgproto3.StartupMessage) (net.Conn, error) +// UpdateMetricsForError updates the metrics relevant for the type of the +// error message. 
+func UpdateMetricsForError(metrics *Metrics, err error) { + if err == nil { + return + } + codeErr := (*CodeError)(nil) + if errors.As(err, &codeErr) { + switch codeErr.Code { + case CodeExpiredClientConnection: + metrics.ExpiredClientConnCount.Inc(1) + case CodeBackendDisconnected: + metrics.BackendDisconnectCount.Inc(1) + case CodeClientDisconnected: + metrics.ClientDisconnectCount.Inc(1) + case CodeIdleDisconnect: + metrics.IdleDisconnectCount.Inc(1) + case CodeProxyRefusedConnection: + metrics.RefusedConnCount.Inc(1) + metrics.BackendDownCount.Inc(1) + case CodeParamsRoutingFailed: + metrics.RoutingErrCount.Inc(1) + metrics.BackendDownCount.Inc(1) + case CodeBackendDown: + metrics.BackendDownCount.Inc(1) + case CodeAuthFailed: + metrics.AuthFailedCount.Inc(1) + } + } } -// Proxy takes an incoming client connection and relays it to a backend SQL -// server. -func (s *Server) Proxy(proxyConn *Conn) error { - sendErrToClient := func(conn net.Conn, code ErrorCode, msg string) { - if s.opts.OnSendErrToClient != nil { - msg = s.opts.OnSendErrToClient(code, msg) +// SendErrToClient will encode and pass back to the SQL client an error message. +// It can be called by the implementors of ProxyHandler to give more +// information to the end user in case of a problem. +var SendErrToClient = func(conn net.Conn, err error) { + if err == nil || conn == nil { + return + } + codeErr := (*CodeError)(nil) + if errors.As(err, &codeErr) { + var msg string + switch codeErr.Code { + // These are send as is. + case CodeExpiredClientConnection, + CodeBackendDown, + CodeParamsRoutingFailed, + CodeClientDisconnected, + CodeBackendDisconnected, + CodeAuthFailed, + CodeProxyRefusedConnection: + msg = codeErr.Error() + // The rest - the message send back is sanitized. 
+ case CodeIdleDisconnect: + msg = "terminating connection due to idle timeout" + case CodeUnexpectedInsecureStartupMessage: + msg = "server requires encryption" } var pgCode string - if code == CodeIdleDisconnect { + if codeErr.Code == CodeIdleDisconnect { pgCode = "57P01" // admin shutdown } else { pgCode = "08004" // rejected connection @@ -104,135 +91,13 @@ func (s *Server) Proxy(proxyConn *Conn) error { Message: msg, }).Encode(nil)) } +} - frontendAdmitter := s.opts.FrontendAdmitter - if frontendAdmitter == nil { - // Keep this until all clients are switched to provide FrontendAdmitter - // at what point we can also drop IncomingTLSConfig - frontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return FrontendAdmit(incoming, s.opts.IncomingTLSConfig) - } - } - - conn, msg, err := frontendAdmitter(proxyConn) - if err != nil { - var codeErr *CodeError - if ok := errors.As(err, &codeErr); ok && codeErr.code == CodeUnexpectedInsecureStartupMessage { - sendErrToClient( - proxyConn, // Do this on the TCP connection as it means denying SSL - CodeUnexpectedInsecureStartupMessage, - "server requires encryption", - ) - } else if ok { - sendErrToClient(proxyConn, codeErr.code, codeErr.Error()) - } else { - sendErrToClient(proxyConn, CodeClientDisconnected, err.Error()) - } - return err - } - - // This currently only happens for CancelRequest type of startup messages - // that we don't support - if conn == nil { - return nil - - } - defer func() { _ = conn.Close() }() - - backendDialer := s.opts.BackendDialer - var backendConfig *BackendConfig - if s.opts.BackendConfigFromParams != nil { - var clientErr error - backendConfig, clientErr = s.opts.BackendConfigFromParams(msg.Parameters, proxyConn) - if clientErr != nil { - var codeErr *CodeError - if !errors.As(clientErr, &codeErr) { - codeErr = &CodeError{ - code: CodeParamsRoutingFailed, - err: errors.Errorf("rejected by BackendConfigFromParams: %v", clientErr), - } - } - 
sendErrToClient(conn, codeErr.code, codeErr.Error()) - return codeErr - } - } - if backendDialer == nil { - // This we need to keep until all the clients are switched to provide BackendDialer. - // It constructs a backend dialer from the information provided via - // BackendConfigFromParams function. - backendDialer = func(msg *pgproto3.StartupMessage) (net.Conn, error) { - // We should be able to remove this when the all clients switch to - // backend dialer. - if s.opts.ModifyRequestParams != nil { - s.opts.ModifyRequestParams(msg.Parameters) - } - - crdbConn, err := BackendDial(msg, backendConfig.OutgoingAddress, backendConfig.TLSConf) - if err != nil { - if codeErr := (*CodeError)(nil); errors.As(err, &codeErr) { - sendErrToClient(conn, codeErr.code, codeErr.Error()) - } else { - sendErrToClient(conn, CodeBackendDisconnected, err.Error()) - } - return nil, err - } - - return crdbConn, nil - } - } - - crdbConn, err := backendDialer(msg) - if err != nil { - s.metrics.BackendDownCount.Inc(1) - var codeErr *CodeError - if !errors.As(err, &codeErr) { - codeErr = &CodeError{ - code: CodeBackendDown, - err: errors.Errorf("unable to reach backend SQL server: %v", err), - } - } - if codeErr.code == CodeProxyRefusedConnection { - s.metrics.RefusedConnCount.Inc(1) - } else if codeErr.code == CodeParamsRoutingFailed { - s.metrics.RoutingErrCount.Inc(1) - } - sendErrToClient(conn, codeErr.code, codeErr.Error()) - return codeErr - } - defer func() { _ = crdbConn.Close() }() - - if err := authenticate(conn, crdbConn); err != nil { - s.metrics.AuthFailedCount.Inc(1) - var codeErr *CodeError - if !errors.As(err, &codeErr) { - codeErr = &CodeError{ - code: CodeParamsRoutingFailed, - err: errors.Errorf("unrecognized auth failure"), - } - } - sendErrToClient(conn, codeErr.code, codeErr.Error()) - return codeErr - } - - s.metrics.SuccessfulConnCount.Inc(1) - - // These channels are buffered because we'll only consume one of them. 
+// ConnectionCopy does a bi-directional copy between the backend and frontend +// connections. It terminates when one of connections terminate. +func ConnectionCopy(crdbConn, conn net.Conn) error { errOutgoing := make(chan error, 1) errIncoming := make(chan error, 1) - errExpired := make(chan error, 1) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - if backendConfig != nil { - if backendConfig.OnConnectionSuccess != nil { - backendConfig.OnConnectionSuccess() - } - if backendConfig.KeepAliveLoop != nil { - go func() { - errExpired <- backendConfig.KeepAliveLoop(ctx) - }() - } - } go func() { _, err := io.Copy(crdbConn, conn) @@ -254,34 +119,18 @@ func (s *Server) Proxy(proxyConn *Conn) error { if err == nil { return nil } else if codeErr := (*CodeError)(nil); errors.As(err, &codeErr) && - codeErr.code == CodeExpiredClientConnection { - s.metrics.ExpiredClientConnCount.Inc(1) - sendErrToClient(conn, codeErr.code, codeErr.Error()) + codeErr.Code == CodeExpiredClientConnection { return codeErr - } else if errors.Is(err, os.ErrDeadlineExceeded) { - s.metrics.IdleDisconnectCount.Inc(1) - sendErrToClient(conn, CodeIdleDisconnect, "terminating connection due to idle timeout") + } else if ne := (net.Error)(nil); errors.As(err, &ne) && ne.Timeout() { return NewErrorf(CodeIdleDisconnect, "terminating connection due to idle timeout: %v", err) } else { - s.metrics.BackendDisconnectCount.Inc(1) - sendErrToClient(conn, CodeBackendDisconnected, "copying from target server to client") return NewErrorf(CodeBackendDisconnected, "copying from target server to client: %s", err) } case err := <-errOutgoing: // The incoming connection got closed. if err != nil { - s.metrics.ClientDisconnectCount.Inc(1) return NewErrorf(CodeClientDisconnected, "copying from target server to client: %v", err) } return nil - case err := <-errExpired: - if err != nil { - // The client connection expired. 
- s.metrics.ExpiredClientConnCount.Inc(1) - code := CodeExpiredClientConnection - sendErrToClient(conn, code, err.Error()) - return NewErrorf(code, "expired client conn: %v", err) - } - return nil } } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go new file mode 100644 index 000000000000..d8cd52697b1f --- /dev/null +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -0,0 +1,582 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "regexp" + "strconv" + "strings" + "time" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/admitter" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/cache" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/denylist" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/security/certmgr" + "github.com/cockroachdb/cockroach/pkg/util" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" + "github.com/jackc/pgproto3/v2" + "google.golang.org/grpc" +) + +var ( + // This assumes that whitespaces are used to separate command line args. + // Unlike the original spec, this does not handle escaping rules. + // + // See "options" in https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS. 
+ clusterNameLongOptionRE = regexp.MustCompile(`(?:-c\s*|--)cluster=([\S]*)`) + + // ClusterNameRegex restricts cluster names to have between 6 and 20 + // alphanumeric characters, with dashes allowed within the name (but not as a + // starting or ending character). + ClusterNameRegex = regexp.MustCompile("^[a-z0-9][a-z0-9-]{4,18}[a-z0-9]$") +) + +const ( + // Cluster identifier is in the form "clustername-. Tenant id is + // always in the end but the cluster name can also contain '-' or digits. + // For example: + // "foo-7-10" -> cluster name is "foo-7" and tenant id is 10. + clusterTenantSep = "-" + // TODO(spaskob): add ballpark estimate. + maxKnownConnCacheSize = 5e6 // 5 million. +) + +// ProxyOptions is the information needed to construct a new ProxyHandler. +type ProxyOptions struct { + Denylist string + ListenAddr string + ListenCert string + ListenKey string + MetricsAddress string + // SkipVerify if set will skip the identity verification of the + // backend. This is for testing only. + SkipVerify bool + // Insecure if set, will not use TLS for the backend connection. For testing. + Insecure bool + // Routing rule for constructing the backend address for each incoming + // connection. Optionally use '{{clusterName}}' + // which will be substituted with the cluster name. + RoutingRule string + // Directory is an optional {HOSTNAME}:{PORT} service that does the resolution + // from backend id to IP address. If specified - it will be used instead of the + // routing rule above. + DirectoryAddr string + RatelimitBaseDelay time.Duration + ValidateAccessInterval time.Duration + PollConfigInterval time.Duration + // IdleTimeout if set, will close connections that have been idle for that duration. + IdleTimeout time.Duration +} + +// ProxyHandler is the default implementation of a proxy handler. +type ProxyHandler struct { + ProxyOptions + + // stopper is used to do an orderly shutdown. 
+ stopper *stop.Stopper + + // IncomingTLSCert is the managed cert of the proxy endpoint to + // which clients connect. + IncomingCert certmgr.Cert + + // DenyListService provides access control + DenyListService denylist.Service + + // AdmitterService will do throttling of incoming connection requests. + AdmitterService admitter.Service + + // Directory is optional and if set, will be used to resolve + // backend id to IP addresses. + Directory *tenant.Directory + + // CertManger keeps up to date the certificates used. + CertManager *certmgr.CertManager + + //ConnCache is used to keep track of all current connections. + ConnCache cache.ConnCache + + // Called after successfully connecting to OutgoingAddr. + OnConnectionSuccess func() + + // KeepAliveLoop if provided controls the lifetime of the proxy connection. + // It will be run in its own goroutine when the connection is successfully + // opened. Returning from `KeepAliveLoop` will close the proxy connection. + // Note that non-nil error return values will be forwarded to the user and + // hence should explain the reason for terminating the connection. + // Most common use of KeepAliveLoop will be as an infinite loop that + // periodically checks if the connection should still be kept alive. Hence it + // may block indefinitely so it's prudent to use the provided context and + // return on context cancellation. + // See `TestProxyKeepAlive` for example usage. + KeepAliveLoop func(context.Context) error +} + +// NewProxyHandler will create a new proxy handler with configuration based on +// the provided options. 
+func NewProxyHandler( + ctx context.Context, stopper *stop.Stopper, options ProxyOptions, +) (*ProxyHandler, error) { + handler := ProxyHandler{ + stopper: stopper, + ProxyOptions: options, + CertManager: certmgr.NewCertManager(ctx), + } + + var err error + err = handler.setupIncomingCert() + if err != nil { + return nil, err + } + + handler.DenyListService, err = denylist.NewViperDenyListFromFile(ctx, options.Denylist, options.PollConfigInterval) + if err != nil { + return nil, err + } + + handler.AdmitterService = admitter.NewLocalService(admitter.WithBaseDelay(options.RatelimitBaseDelay)) + handler.ConnCache = cache.NewCappedConnCache(maxKnownConnCacheSize) + + if handler.DirectoryAddr != "" { + conn, err := grpc.Dial(handler.DirectoryAddr, grpc.WithInsecure()) + if err != nil { + return nil, err + } + // nolint:grpcconnclose + stopper.AddCloser(stop.CloserFn(func() { _ = conn.Close() /* nolint:grpcconnclose */ })) + client := tenant.NewDirectoryClient(conn) + handler.Directory, err = tenant.NewDirectory(ctx, stopper, client) + if err != nil { + return nil, err + } + } + + return &handler, nil +} + +// Handle is called by the proxy server to handle a single incoming client +// connection. +func (handler *ProxyHandler) Handle(ctx context.Context, metrics *Metrics, proxyConn *Conn) error { + conn, msg, err := FrontendAdmit(proxyConn, handler.IncomingTLSConfig()) + defer func() { _ = conn.Close() }() + if err != nil { + SendErrToClient(conn, err) + return err + } + + // This currently only happens for CancelRequest type of startup messages + // that we don't support + if msg == nil { + return nil + } + + // Note that the errors returned from this function are user-facing errors so + // we should be careful with the details that we want to expose. 
+ backendStartupMsg, clusterName, tenID, err := ClusterNameAndTenantFromParams(msg) + if err != nil { + clientErr := &CodeError{CodeParamsRoutingFailed, err} + log.Errorf(ctx, "unable to extract cluster name and tenant id: %s", clientErr.Error()) + SendErrToClient(conn, clientErr) + UpdateMetricsForError(metrics, clientErr) + return clientErr + } + // This forwards the remote addr to the backend. + backendStartupMsg.Parameters["crdb:remote_addr"] = conn.RemoteAddr().String() + + ctx = logtags.AddTag(ctx, "cluster", clusterName) + ctx = logtags.AddTag(ctx, "tenant", tenID) + + ipAddr, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + clientErr := &CodeError{CodeParamsRoutingFailed, err} + log.Errorf(ctx, "could not parse address: %v", clientErr.Error()) + SendErrToClient(conn, clientErr) + UpdateMetricsForError(metrics, clientErr) + return clientErr + } + + if err = handler.validateAccessAndAdmitConnection(ctx, tenID, ipAddr); err != nil { + SendErrToClient(conn, err) + UpdateMetricsForError(metrics, err) + return err + } + + var TLSConf *tls.Config + if !handler.Insecure { + TLSConf = &tls.Config{InsecureSkipVerify: handler.SkipVerify} + } + + var crdbConn net.Conn + var outgoingAddress string + for i := 1; i < 10; i++ { + outgoingAddress, err = handler.OutgoingAddress(ctx, clusterName, tenID) + if err != nil { + return err + } + + crdbConn, err = BackendDial(backendStartupMsg, outgoingAddress, TLSConf) + // If we get a backend down error and are using the directory - report the + // error to the directory and retry the connection. 
+ codeErr := (*CodeError)(nil) + if err != nil && + errors.As(err, &codeErr) && + codeErr.Code == CodeBackendDown && + handler.Directory != nil { + errReport := handler.Directory.ReportFailure(ctx, roachpb.MakeTenantID(tenID), outgoingAddress) + if errReport != nil { + return errReport + } + continue + } + break + } + + if err != nil { + UpdateMetricsForError(metrics, err) + SendErrToClient(conn, err) + return err + } + + crdbConn = IdleDisconnectOverlay(crdbConn, handler.IdleTimeout) + + defer func() { _ = crdbConn.Close() }() + + if err := Authenticate(conn, crdbConn); err != nil { + UpdateMetricsForError(metrics, err) + return errors.AssertionFailedf("unrecognized auth failure") + } + + metrics.SuccessfulConnCount.Inc(1) + + handler.ConnCache.Insert( + &cache.ConnKey{IPAddress: ipAddr, TenantID: roachpb.MakeTenantID(tenID)}, + ) + + errConnectionCopy := make(chan error, 1) + errExpired := make(chan error, 1) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + errExpired <- func(ctx context.Context) error { + t := timeutil.NewTimer() + defer t.Stop() + t.Reset(0) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + t.Read = true + if err := handler.validateAccess(ctx, tenID, ipAddr); err != nil { + return err + } + } + t.Reset(util.Jitter(handler.ValidateAccessInterval, 0.15)) + } + }(ctx) + }() + + log.Infof(ctx, "new connection") + connBegin := timeutil.Now() + defer func() { + log.Infof(ctx, "closing after %.2fs", timeutil.Since(connBegin).Seconds()) + }() + + go func() { + err := ConnectionCopy(crdbConn, conn) + errConnectionCopy <- err + }() + + select { + case err := <-errConnectionCopy: + UpdateMetricsForError(metrics, err) + SendErrToClient(conn, err) + return err + case err := <-errExpired: + if err != nil { + // The client connection expired. 
+ codeErr := NewErrorf( + CodeExpiredClientConnection, "expired client conn: %v", err, + ) + UpdateMetricsForError(metrics, codeErr) + SendErrToClient(conn, codeErr) + return codeErr + } + return nil + case <-handler.stopper.ShouldQuiesce(): + return nil + } + +} + +// OutgoingAddress resolves a tenant ID and a tenant cluster name to an IP of +// the backend. +func (handler *ProxyHandler) OutgoingAddress( + ctx context.Context, name string, tenID uint64, +) (string, error) { + if handler.Directory == nil { + addr := strings.ReplaceAll(handler.RoutingRule, "{{clusterName}}", name) + log.Infof(ctx, "backend %s resolved to %s", name, addr) + return addr, nil + } + + // This doesn't verify the name part of the tenant and relies just on the int id. + addr, err := handler.Directory.EnsureTenantIP(ctx, roachpb.MakeTenantID(tenID), "") + if err != nil { + return "", err + } + return addr, nil +} + +func (handler *ProxyHandler) validateAccessAndAdmitConnection( + ctx context.Context, tenID uint64, ipAddr string, +) error { + if err := handler.validateAccess(ctx, tenID, ipAddr); err != nil { + return err + } + + // Admit the connection + connKey := cache.ConnKey{IPAddress: ipAddr, TenantID: roachpb.MakeTenantID(tenID)} + if !handler.ConnCache.Exists(&connKey) { + // Unknown previous successful connections from this IP and tenant. + // Hence we need to rate limit. + if err := handler.AdmitterService.LoginCheck(ipAddr, timeutil.Now()); err != nil { + log.Errorf(ctx, "admitter refused connection: %v", err.Error()) + return NewErrorf(CodeProxyRefusedConnection, "Connection attempt throttled") + } + } + + return nil +} + +func (handler *ProxyHandler) validateAccess( + ctx context.Context, tenID uint64, ipAddr string, +) error { + // First validate against the deny list service + list := handler.DenyListService + if entry, err := list.Denied(fmt.Sprint(tenID)); err != nil { + // Log error but don't return since this could be transient. 
+ log.Errorf(ctx, "could not consult denied list for tenant: %d", tenID) + } else if entry != nil { + log.Errorf(ctx, "access denied for tenant: %d, reason: %s", tenID, entry.Reason) + return NewErrorf(CodeProxyRefusedConnection, "tenant %d %s", tenID, entry.Reason) + } + + if entry, err := list.Denied(ipAddr); err != nil { + // Log error but don't return since this could be transient. + log.Errorf(ctx, "could not consult denied list for IP address: %s", ipAddr) + } else if entry != nil { + log.Errorf(ctx, "access denied for IP address: %s, reason: %s", ipAddr, entry.Reason) + return NewErrorf(CodeProxyRefusedConnection, "IP address %s %s", ipAddr, entry.Reason) + } + + return nil +} + +// ClusterNameAndTenantFromParams extracts the cluster name from the connection +// parameters, and rewrites the database param, if necessary. We currently +// support embedding the cluster name in two ways: +// - Within the database param (e.g. "happy-koala.defaultdb") +// +// - Within the options param (e.g. "... --cluster=happy-koala ..."). +// PostgreSQL supports three different ways to set a run-time parameter +// through its command-line options, i.e. "-c NAME=VALUE", "-cNAME=VALUE", and +// "--NAME=VALUE". 
+func ClusterNameAndTenantFromParams( + msg *pgproto3.StartupMessage, +) (*pgproto3.StartupMessage, string, uint64, error) { + clusterNameFromDB, databaseName, err := parseDatabaseParam(msg.Parameters["database"]) + if err != nil { + return msg, "", 0, err + } + + clusterNameFromOpt, err := parseOptionsParam(msg.Parameters["options"]) + if err != nil { + return msg, "", 0, err + } + + if clusterNameFromDB == "" && clusterNameFromOpt == "" { + return msg, "", 0, errors.New("missing cluster name in connection string") + } + + if clusterNameFromDB != "" && clusterNameFromOpt != "" { + return msg, "", 0, errors.New("multiple cluster names provided") + } + + if clusterNameFromDB == "" { + clusterNameFromDB = clusterNameFromOpt + } + + sepIdx := strings.LastIndex(clusterNameFromDB, clusterTenantSep) + // Cluster name provided without a tenant ID in the end. + if sepIdx == -1 || sepIdx == len(clusterNameFromDB)-1 { + return msg, "", 0, errors.Errorf("invalid cluster name %s", clusterNameFromDB) + } + clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:] + + if !ClusterNameRegex.MatchString(clusterNameSansTenant) { + return msg, "", 0, errors.Errorf("invalid cluster name '%s'", clusterNameSansTenant) + } + + tenID, err := strconv.ParseUint(tenantIDStr, 10, 64) + if err != nil { + return msg, "", 0, errors.Wrapf(err, "cannot parse %s as uint64", tenantIDStr) + } + + // Make and return a copy of the startup msg so the original is not modified. + paramsOut := map[string]string{} + for key, value := range msg.Parameters { + if key == "database" { + paramsOut[key] = databaseName + } else if key != "options" { + paramsOut[key] = value + } + } + outMsg := &pgproto3.StartupMessage{ + ProtocolVersion: msg.ProtocolVersion, + Parameters: paramsOut, + } + + return outMsg, clusterNameFromDB, tenID, nil +} + +// parseDatabaseParam parses the database parameter from the PG connection +// string, and tries to extract the cluster name if present. 
The cluster +// name should be embedded in the database parameter using the dot (".") +// delimiter in the form of ".". This approach +// is safe because dots are not allowed in the database names themselves. +func parseDatabaseParam(databaseParam string) (clusterName, databaseName string, err error) { + // Database param is not provided. + if databaseParam == "" { + return "", "", nil + } + + parts := strings.SplitN(databaseParam, ".", 2) + + // Database param provided without cluster name. + if len(parts) <= 1 { + return "", databaseParam, nil + } + + clusterName, databaseName = parts[0], parts[1] + + // Ensure that the param is in the right format if the delimiter is provided. + if len(parts) > 2 || clusterName == "" || databaseName == "" { + return "", "", errors.New("invalid database param") + } + + return clusterName, databaseName, nil +} + +// parseOptionsParam parses the options parameter from the PG connection string, +// and tries to return the cluster name if present. Just like PostgreSQL, the +// sqlproxy supports three different ways to set a run-time parameter through +// its command-line options: +// -c NAME=VALUE (commonly used throughout documentation around PGOPTIONS) +// -cNAME=VALUE +// --NAME=VALUE +// +// CockroachDB currently does not support the options parameter, so the parsing +// logic is built on that assumption. If we do start supporting options in +// CockroachDB itself, then we should revisit this. +// +// Note that this parsing approach is not perfect as it allows a negative case +// like options="-c --cluster=happy-koala -c -c -c" to go through. To properly +// parse this, we need to traverse the string from left to right, and look at +// every single argument, but that involves quite a bit of work, so we'll punt +// for now. +func parseOptionsParam(optionsParam string) (string, error) { + // Only search up to 2 in case of large inputs. 
+ matches := clusterNameLongOptionRE.FindAllStringSubmatch(optionsParam, 2 /* n */) + if len(matches) == 0 { + return "", nil + } + + if len(matches) > 1 { + // Technically we could still allow requests to go through if all + // cluster names match, but we don't want to parse the entire string, so + // we will just error out if at least two cluster flags are provided. + return "", errors.New("multiple cluster flags provided") + } + + // Length of each match should always be 2 with the given regex, one for + // the full string, and the other for the cluster name. + if len(matches[0]) != 2 { + // We don't want to panic here. + return "", errors.New("internal server error") + } + + // Flag was provided, but value is NULL. + if len(matches[0][1]) == 0 { + return "", errors.New("invalid cluster flag") + } + + return matches[0][1], nil +} + +// IncomingTLSConfig gets back the current TLS config for the incoiming client +// connection endpoint. +func (handler *ProxyHandler) IncomingTLSConfig() *tls.Config { + if handler.IncomingCert == nil { + return nil + } + + cert := handler.IncomingCert.TLSCert() + if cert == nil { + return nil + } + + return &tls.Config{Certificates: []tls.Certificate{*cert}} +} + +// setupIncomingCert will setup a manged cert for the incoming connections. +// They can either be unencrypted (in case a cert and key names are empty), +// using self-signed, runtime generated cert (if cert is set to *) or +// using file based cert where the cert/key values refer to file names +// containing the information. +func (handler *ProxyHandler) setupIncomingCert() error { + if (handler.ListenKey == "") != (handler.ListenCert == "") { + return errors.New("must specify either both or neither of cert and key") + } + + if handler.ListenCert == "" { + return nil + } + + // TODO(darin): change the cert manager so it uses the stopper. 
+ ctx, _ := handler.stopper.WithCancelOnQuiesce(context.Background()) + certMgr := certmgr.NewCertManager(ctx) + var cert certmgr.Cert + if handler.ListenCert == "*" { + cert = certmgr.NewSelfSignedCert(0, 3, 0, 0) + } else if handler.ListenCert != "" { + cert = certmgr.NewFileCert(handler.ListenCert, handler.ListenKey) + } + cert.Reload(ctx) + err := cert.Err() + if err != nil { + return err + } + certMgr.ManageCert("client", cert) + handler.CertManager = certMgr + handler.IncomingCert = cert + + return nil +} diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go new file mode 100644 index 000000000000..ff1b406fc403 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -0,0 +1,680 @@ +// Copyright 2020 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "crypto/tls" + "fmt" + "io/ioutil" + "net" + "os" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/require" +) + +const FrontendError = "Frontend error!" +const BackendError = "Backend error!" 
+ +func HookBackendDial( + f func( + msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, + ) (net.Conn, error), +) func() { + return testutils.HookGlobal(&BackendDial, f) +} +func HookFrontendAdmit( + f func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error), +) func() { + return testutils.HookGlobal(&FrontendAdmit, f) +} +func HookSendErrToClient(f func(conn net.Conn, err error)) func() { + return testutils.HookGlobal(&SendErrToClient, f) +} +func HookAuthenticate(f func(clientConn, crdbConn net.Conn) error) func() { + return testutils.HookGlobal(&Authenticate, f) +} + +func newSecureProxyServer( + ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, +) (server *Server, addr string, cleanup func()) { + // Created via: + const _ = ` +openssl genrsa -out testserver.key 2048 +openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt \ + -days 3650 -config testserver_config.cnf +` + opts.ListenKey = "testserver.key" + opts.ListenCert = "testserver.crt" + + return newProxyServer(ctx, t, stopper, opts) +} + +func newProxyServer( + ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, +) (server *Server, addr string, done func()) { + handler, err := NewProxyHandler(ctx, stopper, *opts) + require.NoError(t, err) + + const listenAddress = "127.0.0.1:0" + // NB: ln closes before wg.Wait or we deadlock. 
+ ln, err := net.Listen("tcp", listenAddress) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + done = func() { + _ = ln.Close() + wg.Wait() + } + + server = NewServer(func(ctx context.Context, metrics *Metrics, proxyConn *Conn) error { + return handler.Handle(ctx, metrics, proxyConn) + }) + + go func() { + defer wg.Done() + _ = server.Serve(ctx, ln) + }() + + return server, ln.Addr().String(), done +} + +func runTestQuery(ctx context.Context, conn *pgx.Conn) error { + var n int + if err := conn.QueryRow(ctx, "SELECT $1::int", 1).Scan(&n); err != nil { + return err + } + if n != 1 { + return errors.Errorf("expected 1 got %d", n) + } + return nil +} + +type assertCtx struct { + emittedCode *ErrorCode +} + +func makeAssertCtx() assertCtx { + var emittedCode ErrorCode = -1 + return assertCtx{ + emittedCode: &emittedCode, + } +} + +func (ac *assertCtx) onSendErrToClient(code ErrorCode) { + *ac.emittedCode = code +} + +func (ac *assertCtx) assertConnectErr( + ctx context.Context, t *testing.T, prefix, suffix string, expCode ErrorCode, expErr string, +) { + t.Helper() + *ac.emittedCode = -1 + t.Run(suffix, func(t *testing.T) { + conn, err := pgx.Connect(ctx, prefix+suffix) + if err == nil { + _ = conn.Close(ctx) + } + require.Contains(t, err.Error(), expErr) + require.Equal(t, expCode, *ac.emittedCode) + + }) +} + +func TestLongDBName(t *testing.T) { + defer leaktest.AfterTest(t)() + + defer HookBackendDial(func(_ *pgproto3.StartupMessage, _ string, _ *tls.Config) (net.Conn, error) { + return nil, NewErrorf(CodeParamsRoutingFailed, "boom") + })() + + ac := makeAssertCtx() + sendErrToClient := SendErrToClient + defer HookSendErrToClient(func(conn net.Conn, err error) { + sendErrToClient(conn, err) + if codeErr, ok := err.(*CodeError); ok { + ac.onSendErrToClient(codeErr.Code) + } + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s, addr, done 
:= newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + defer done() + + longDB := strings.Repeat("x", 70) // 63 is limit + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=dim-dog-28", addr, longDB) + ac.assertConnectErr(ctx, t, pgurl, "" /* suffix */, CodeParamsRoutingFailed, "boom") + require.Equal(t, int64(1), s.Metrics.RoutingErrCount.Count()) +} + +func TestFailedConnection(t *testing.T) { + defer leaktest.AfterTest(t)() + + // TODO(asubiotto): consider using datadriven for these, especially if the + // proxy becomes more complex. + + var sendErrToClient = SendErrToClient + ac := makeAssertCtx() + defer HookSendErrToClient(func(conn net.Conn, err error) { + sendErrToClient(conn, err) + if codeErr, ok := err.(*CodeError); ok { + ac.onSendErrToClient(codeErr.Code) + } + })() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s, addr, done := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{RoutingRule: "undialable%$!@$"}) + defer done() + + _, p, err := net.SplitHostPort(addr) + require.NoError(t, err) + u := fmt.Sprintf("postgres://unused:unused@localhost:%s/", p) + // Valid connections, but no backend server running. + for _, sslmode := range []string{"require", "prefer"} { + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode="+sslmode, + CodeBackendDown, "unable to reach backend SQL server", + ) + } + + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode=verify-ca&sslrootcert=testserver.crt", + CodeBackendDown, "unable to reach backend SQL server", + ) + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode=verify-full&sslrootcert=testserver.crt", + CodeBackendDown, "unable to reach backend SQL server", + ) + require.Equal(t, int64(4), s.Metrics.BackendDownCount.Count()) + + // Unencrypted connections bounce. 
+ for _, sslmode := range []string{"disable", "allow"} { + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim-dog-28&sslmode="+sslmode, + CodeUnexpectedInsecureStartupMessage, "server requires encryption", + ) + } + + // TenantID rejected as malformed. + ac.assertConnectErr( + ctx, t, u, "?options=--cluster=dim&sslmode=require", + CodeParamsRoutingFailed, "invalid cluster name", + ) + + // No TenantID. + ac.assertConnectErr( + ctx, t, u, "?sslmode=require", + CodeParamsRoutingFailed, "missing cluster name", + ) + require.Equal(t, int64(2), s.Metrics.RoutingErrCount.Count()) +} + +func TestUnexpectedError(t *testing.T) { + defer leaktest.AfterTest(t)() + + // Set up a Server whose FrontendAdmitter function always errors with a + // non-CodeError error. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + defer HookFrontendAdmit( + func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + return conn, nil, errors.New("unexpected error") + })() + s := NewServer((&ProxyHandler{}).Handle) + const listenAddress = "127.0.0.1:0" + ln, err := net.Listen("tcp", listenAddress) + require.NoError(t, err) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _ = s.Serve(ctx, ln) + }() + defer func() { + _ = ln.Close() + wg.Wait() + }() + + u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", ln.Addr().String()) + + // Time how long it takes for pgx.Connect to return. If the proxy handles + // errors appropriately, pgx.Connect should return near immediately + // because the server should close the connection. If not, it may take up + // to the 5s connect_timeout for pgx.Connect to give up. 
+ start := timeutil.Now() + _, err = pgx.Connect(ctx, u) + require.Error(t, err) + t.Log(err) + elapsed := timeutil.Since(start) + if elapsed >= 5*time.Second { + t.Errorf("pgx.Connect took %s to error out", elapsed) + } +} + +func TestProxyAgainstSecureCRDB(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sql, db, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + defer sql.Stopper().Stop(ctx) + defer sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true)() + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) + + var connSuccess bool + authenticate := Authenticate + defer HookAuthenticate(func(clientConn, crdbConn net.Conn) error { + err := authenticate(clientConn, crdbConn) + connSuccess = err == nil + return err + })() + + s, addr, done := newSecureProxyServer( + ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, + ) + defer done() + + url := fmt.Sprintf("postgres://bob:wrong@%s/dim-dog-28.defaultdb?sslmode=require", addr) + _, err := pgx.Connect(ctx, url) + require.Regexp(t, "ERROR: password authentication failed for user bob", err) + + url = fmt.Sprintf("postgres://bob@%s/dim-dog-28.defaultdb?sslmode=require", addr) + _, err = pgx.Connect(ctx, url) + require.Regexp(t, "ERROR: password authentication failed for user bob", err) + + url = fmt.Sprintf("postgres://bob:builder@%s/dim-dog-28.defaultdb?sslmode=require", addr) + conn, err := pgx.Connect(ctx, url) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.True(t, connSuccess) + require.Equal(t, int64(1), s.Metrics.SuccessfulConnCount.Count()) + require.Equal(t, int64(2), s.Metrics.AuthFailedCount.Count()) + }() + + require.Equal(t, int64(1), s.Metrics.CurConnCount.Value()) + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestProxyTLSClose(t *testing.T) { + 
defer leaktest.AfterTest(t)() + // NB: The leaktest call is an important part of this test. We're + // verifying that no goroutines are leaked, despite calling Close an + // underlying TCP connection (rather than the TLSConn that wraps it). + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sql, db, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + defer sql.Stopper().Stop(ctx) + defer sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true)() + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) + + var proxyIncomingConn atomic.Value // *Conn + var connSuccess bool + frontendAdmit := FrontendAdmit + defer HookFrontendAdmit(func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + proxyIncomingConn.Store(conn) + return frontendAdmit(conn, incomingTLSConfig) + })() + authenticate := Authenticate + defer HookAuthenticate(func(clientConn, crdbConn net.Conn) error { + err := authenticate(clientConn, crdbConn) + connSuccess = err == nil + return err + })() + s, addr, done := newSecureProxyServer( + ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, + ) + defer done() + + url := fmt.Sprintf("postgres://bob:builder@%s/dim-dog-28.defaultdb?sslmode=require", addr) + conn, err := pgx.Connect(ctx, url) + require.NoError(t, err) + require.Equal(t, int64(1), s.Metrics.CurConnCount.Value()) + defer func() { + incomingConn, ok := proxyIncomingConn.Load().(*Conn) + require.True(t, ok) + require.NoError(t, incomingConn.Close()) + <-incomingConn.Done() // should immediately proceed + + require.True(t, connSuccess) + require.Equal(t, int64(1), s.Metrics.SuccessfulConnCount.Count()) + require.Equal(t, int64(0), s.Metrics.AuthFailedCount.Count()) + }() + + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestProxyModifyRequestParams(t *testing.T) { + defer leaktest.AfterTest(t)() + 
+ ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + defer sql.Stopper().Stop(ctx) + defer sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true)() + + outgoingTLSConfig, err := sql.RPCContext().GetClientTLSConfig() + require.NoError(t, err) + outgoingTLSConfig.InsecureSkipVerify = true + + backendDial := BackendDial + defer HookBackendDial(func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + params := msg.Parameters + authToken, ok := params["authToken"] + require.True(t, ok) + require.Equal(t, "abc123", authToken) + user, ok := params["user"] + require.True(t, ok) + require.Equal(t, "bogususer", user) + require.Contains(t, params, "user") + + // NB: This test will fail unless the user used between the proxy + // and the backend is changed to a user that actually exists. + delete(params, "authToken") + params["user"] = "root" + + return backendDial(msg, sql.ServingSQLAddr(), outgoingTLSConfig) + })() + + s, proxyAddr, done := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) + defer done() + + u := fmt.Sprintf("postgres://bogususer@%s/?sslmode=require&authToken=abc123&options=--cluster=dim-dog-28", proxyAddr) + conn, err := pgx.Connect(ctx, u) + require.NoError(t, err) + require.Equal(t, int64(1), s.Metrics.CurConnCount.Value()) + defer func() { + require.NoError(t, conn.Close(ctx)) + }() + + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestInsecureProxy(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sql, db, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + defer sql.Stopper().Stop(ctx) + defer sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true)() + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 
'builder'`) + + s, addr, cleanup := newProxyServer( + ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, + ) + defer cleanup() + + u := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable&options=--cluster=dim-dog-28", addr) + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + require.Regexp(t, "ERROR: password authentication failed for user bob", err) + + u = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + conn, err := pgx.Connect(ctx, u) + require.NoError(t, err) + + defer func() { + require.NoError(t, conn.Close(ctx)) + require.Equal(t, int64(1), s.Metrics.AuthFailedCount.Count()) + require.Equal(t, int64(1), s.Metrics.SuccessfulConnCount.Count()) + }() + + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestInsecureDoubleProxy(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true}) + defer sql.Stopper().Stop(ctx) + defer sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true)() + + // Test multiple proxies: proxyB -> proxyA -> tc + _, proxyA, cleanupA := newProxyServer(ctx, t, sql.Stopper(), + &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), Insecure: true}, + ) + defer cleanupA() + _, proxyB, cleanupB := newProxyServer(ctx, t, sql.Stopper(), + &ProxyOptions{RoutingRule: proxyA, Insecure: true}, + ) + defer cleanupB() + + u := fmt.Sprintf("postgres://root:admin@%s/dim-dog-28.dim-dog-29.defaultdb?sslmode=disable", proxyB) + conn, err := pgx.Connect(ctx, u) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close(ctx)) + }() + require.NoError(t, runTestQuery(ctx, conn)) +} + +func TestErroneousFrontend(t *testing.T) { + defer leaktest.AfterTest(t)() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + ctx, cancel := 
stopper.WithCancelOnQuiesce(context.Background()) + defer cancel() + + defer HookFrontendAdmit(func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + return conn, nil, errors.New(FrontendError) + })() + + _, addr, cleanup := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + defer cleanup() + + u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + // Generic message here as the Frontend's error is not CodeError and + // by default we don't pass back error's text. The startup message doesn't get + // processed in this case. + require.Regexp(t, "connection reset by peer", err) +} + +func TestErroneousBackend(t *testing.T) { + defer leaktest.AfterTest(t)() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, addr, cleanup := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + defer cleanup() + + defer HookBackendDial( + func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + return nil, errors.New(BackendError) + })() + + u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) + + _, err := pgx.Connect(ctx, u) + require.Error(t, err) + // Generic message here as the Backend's error is not CodeError and + // by default we don't pass back error's text. The startup message has already + // been processed. 
+ require.Regexp(t, "failed to receive message", err) +} + +func TestProxyRefuseConn(t *testing.T) { + defer leaktest.AfterTest(t)() + + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + defer HookBackendDial(func(_ *pgproto3.StartupMessage, _ string, _ *tls.Config) (net.Conn, error) { + return nil, NewErrorf(CodeProxyRefusedConnection, "too many attempts") + })() + + ac := makeAssertCtx() + sendErrToClient := SendErrToClient + defer HookSendErrToClient(func(conn net.Conn, err error) { + sendErrToClient(conn, err) + if codeErr, ok := err.(*CodeError); ok { + ac.onSendErrToClient(codeErr.Code) + } + })() + + s, addr, done := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + defer done() + + ac.assertConnectErr( + ctx, t, fmt.Sprintf("postgres://root:admin@%s/", addr), + "?sslmode=require&options=--cluster=dim-dog-28", + CodeProxyRefusedConnection, "too many attempts", + ) + require.Equal(t, int64(1), s.Metrics.RefusedConnCount.Count()) + require.Equal(t, int64(0), s.Metrics.SuccessfulConnCount.Count()) + require.Equal(t, int64(0), s.Metrics.AuthFailedCount.Count()) +} + +func TestProxyKeepAlive(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + defer sql.Stopper().Stop(ctx) + defer sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true)() + + denyList, err := ioutil.TempFile("", "*_denylist.yml") + require.NoError(t, err) + + outgoingTLSConfig, err := sql.RPCContext().GetClientTLSConfig() + require.NoError(t, err) + outgoingTLSConfig.InsecureSkipVerify = true + + backendDial := BackendDial + defer HookBackendDial(func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + time.AfterFunc(100*time.Millisecond, func() { + _, err := 
denyList.WriteString("127.0.0.1: test-denied") + require.NoError(t, err) + }) + return backendDial(msg, sql.ServingSQLAddr(), outgoingTLSConfig) + })() + + s, addr, done := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{Denylist: denyList.Name()}) + defer done() + defer func() { _ = os.Remove(denyList.Name()) }() + + url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require&options=--cluster=dim-dog-28", addr) + conn, err := pgx.Connect(context.Background(), url) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.Equal(t, int64(1), s.Metrics.ExpiredClientConnCount.Count()) + }() + + require.Eventuallyf( + t, + func() bool { + _, err = conn.Exec(context.Background(), "SELECT 1") + return err != nil && strings.Contains(err.Error(), "expired") + }, + time.Second, 5*time.Millisecond, + "unexpected error received: %v", err, + ) +} + +func TestProxyAgainstSecureCRDBWithIdleTimeout(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + defer sql.Stopper().Stop(ctx) + defer sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true)() + + outgoingTLSConfig, err := sql.RPCContext().GetClientTLSConfig() + require.NoError(t, err) + outgoingTLSConfig.InsecureSkipVerify = true + + idleTimeout, _ := time.ParseDuration("0.5s") + var connSuccess bool + frontendAdmit := FrontendAdmit + defer HookFrontendAdmit(func(conn net.Conn, incomingTLSConfig *tls.Config) (net.Conn, *pgproto3.StartupMessage, error) { + return frontendAdmit(conn, incomingTLSConfig) + })() + backendDial := BackendDial + defer HookBackendDial(func(msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config) (net.Conn, error) { + return backendDial(msg, sql.ServingSQLAddr(), outgoingTLSConfig) + })() + authenticate := Authenticate + defer 
HookAuthenticate(func(clientConn, crdbConn net.Conn) error { + err := authenticate(clientConn, crdbConn) + connSuccess = err == nil + return err + })() + s, addr, done := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{IdleTimeout: idleTimeout}) + defer done() + + url := fmt.Sprintf("postgres://root:admin@%s/?sslmode=require&options=--cluster=dim-dog-28", addr) + conn, err := pgx.Connect(ctx, url) + require.NoError(t, err) + require.Equal(t, int64(1), s.Metrics.CurConnCount.Value()) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.True(t, connSuccess) + require.Equal(t, int64(1), s.Metrics.SuccessfulConnCount.Count()) + }() + + var n int + err = conn.QueryRow(ctx, "SELECT $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + time.Sleep(idleTimeout * 2) + err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) + require.EqualError(t, err, "FATAL: terminating connection due to idle timeout (SQLSTATE 57P01)") +} diff --git a/pkg/ccl/sqlproxyccl/proxy_test.go b/pkg/ccl/sqlproxyccl/proxy_test.go deleted file mode 100644 index 1b81ea2cf61b..000000000000 --- a/pkg/ccl/sqlproxyccl/proxy_test.go +++ /dev/null @@ -1,671 +0,0 @@ -// Copyright 2020 The Cockroach Authors. -// -// Licensed as a CockroachDB Enterprise file under the Cockroach Community -// License (the "License"); you may not use this file except in compliance with -// the License. 
You may obtain a copy of the License at -// -// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt - -package sqlproxyccl - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/testutils" - "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" - "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" - "github.com/cockroachdb/cockroach/pkg/util/leaktest" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" - "github.com/cockroachdb/errors" - "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4" - "github.com/stretchr/testify/require" -) - -const FrontendError = "Frontend error!" -const BackendError = "Backend error!" - -func setupTestProxyWithCerts( - t *testing.T, opts *Options, -) (server *Server, addr string, done func()) { - // Created via: - const _ = ` -openssl genrsa -out testserver.key 2048 -openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt \ - -days 3650 -config testserver_config.cnf -` - cer, err := tls.LoadX509KeyPair("testserver.crt", "testserver.key") - require.NoError(t, err) - opts.FrontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return FrontendAdmit( - incoming, - &tls.Config{ - Certificates: []tls.Certificate{cer}, - ServerName: "localhost", - }, - ) - } - - const listenAddress = "127.0.0.1:0" - // NB: ln closes before wg.Wait or we deadlock. 
- ln, err := net.Listen("tcp", listenAddress) - require.NoError(t, err) - - var wg sync.WaitGroup - wg.Add(1) - - done = func() { - _ = ln.Close() - wg.Wait() - } - - server = NewServer(*opts) - - go func() { - defer wg.Done() - _ = server.Serve(ln) - }() - - return server, ln.Addr().String(), done -} - -func testingTenantIDFromDatabaseForAddr( - addr string, validTenant string, -) func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return func(msg *pgproto3.StartupMessage) (net.Conn, error) { - const dbKey = "database" - p := msg.Parameters - db, ok := p[dbKey] - if !ok { - return nil, NewErrorf( - CodeParamsRoutingFailed, "need to specify database", - ) - } - sl := strings.SplitN(db, "_", 2) - if len(sl) != 2 { - return nil, NewErrorf( - CodeParamsRoutingFailed, "malformed database name", - ) - } - db, tenantID := sl[0], sl[1] - - if tenantID != validTenant { - return nil, NewErrorf(CodeParamsRoutingFailed, "invalid tenantID") - } - - p[dbKey] = db - return BackendDial(msg, addr, &tls.Config{ - // NB: this would be false in production. 
- InsecureSkipVerify: true, - }) - } -} - -func runTestQuery(conn *pgx.Conn) error { - var n int - if err := conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n); err != nil { - return err - } - if n != 1 { - return errors.Errorf("expected 1 got %d", n) - } - return nil -} - -type assertCtx struct { - emittedCode *ErrorCode -} - -func makeAssertCtx() assertCtx { - var emittedCode ErrorCode = -1 - return assertCtx{ - emittedCode: &emittedCode, - } -} - -func (ac *assertCtx) onSendErrToClient(code ErrorCode, msg string) string { - *ac.emittedCode = code - return msg -} - -func (ac *assertCtx) assertConnectErr( - t *testing.T, prefix, suffix string, expCode ErrorCode, expErr string, -) { - t.Helper() - *ac.emittedCode = -1 - t.Run(suffix, func(t *testing.T) { - ctx := context.Background() - conn, err := pgx.Connect(ctx, prefix+suffix) - if err == nil { - _ = conn.Close(ctx) - } - require.Contains(t, err.Error(), expErr) - require.Equal(t, expCode, *ac.emittedCode) - - }) -} - -func TestLongDBName(t *testing.T) { - defer leaktest.AfterTest(t)() - - ac := makeAssertCtx() - - opts := Options{ - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return nil, NewErrorf(CodeParamsRoutingFailed, "boom") - }, - OnSendErrToClient: ac.onSendErrToClient, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - longDB := strings.Repeat("x", 70) // 63 is limit - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s", addr, longDB) - ac.assertConnectErr(t, pgurl, "" /* suffix */, CodeParamsRoutingFailed, "boom") - require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) -} - -func TestFailedConnection(t *testing.T) { - defer leaktest.AfterTest(t)() - - // TODO(asubiotto): consider using datadriven for these, especially if the - // proxy becomes more complex. 
- - ac := makeAssertCtx() - opts := Options{ - BackendDialer: testingTenantIDFromDatabaseForAddr( - "undialable%$!@$", "29", - ), - OnSendErrToClient: ac.onSendErrToClient, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - _, p, err := net.SplitHostPort(addr) - require.NoError(t, err) - u := fmt.Sprintf("postgres://unused:unused@localhost:%s/", p) - // Valid connections, but no backend server running. - for _, sslmode := range []string{"require", "prefer"} { - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode="+sslmode, - CodeBackendDown, "unable to reach backend SQL server", - ) - } - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode=verify-ca&sslrootcert=testserver.crt", - CodeBackendDown, "unable to reach backend SQL server", - ) - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode=verify-full&sslrootcert=testserver.crt", - CodeBackendDown, "unable to reach backend SQL server", - ) - require.Equal(t, int64(4), s.metrics.BackendDownCount.Count()) - - // Unencrypted connections bounce. - for _, sslmode := range []string{"disable", "allow"} { - ac.assertConnectErr( - t, u, "defaultdb_29?sslmode="+sslmode, - CodeUnexpectedInsecureStartupMessage, "server requires encryption", - ) - } - - // TenantID rejected by test hook. - ac.assertConnectErr( - t, u, "defaultdb_28?sslmode=require", - CodeParamsRoutingFailed, "invalid tenantID", - ) - - // No TenantID. - ac.assertConnectErr( - t, u, "defaultdb?sslmode=require", - CodeParamsRoutingFailed, "malformed database name", - ) - require.Equal(t, int64(2), s.metrics.RoutingErrCount.Count()) -} - -func TestUnexpectedError(t *testing.T) { - defer leaktest.AfterTest(t)() - - // Set up a Server whose FrontendAdmitter function always errors with a - // non-CodeError error. 
- ctx := context.Background() - s := NewServer(Options{ - FrontendAdmitter: func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return nil, nil, errors.New("unexpected error") - }, - }) - const listenAddress = "127.0.0.1:0" - ln, err := net.Listen("tcp", listenAddress) - require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - _ = s.Serve(ln) - }() - defer func() { - _ = ln.Close() - wg.Wait() - }() - - u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", ln.Addr().String()) - - // Time how long it takes for pgx.Connect to return. If the proxy handles - // errors appropriately, pgx.Connect should return near immediately - // because the server should close the connection. If not, it may take up - // to the 5s connect_timeout for pgx.Connect to give up. - start := timeutil.Now() - _, err = pgx.Connect(ctx, u) - require.Error(t, err) - t.Log(err) - elapsed := timeutil.Since(start) - if elapsed >= 5*time.Second { - t.Errorf("pgx.Connect took %s to error out", elapsed) - } -} - -func TestProxyAgainstSecureCRDB(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(tc.ServerConn(0)) - sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - - var connSuccess bool - opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{OnConnectionSuccess: func() { connSuccess = true }}, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), &tls.Config{InsecureSkipVerify: true}) - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=require", addr) - _, err := pgx.Connect(context.Background(), url) - 
require.True(t, testutils.IsError(err, "ERROR: password authentication failed for user bob")) - - url = fmt.Sprintf("postgres://bob@%s?sslmode=require", addr) - _, err = pgx.Connect(context.Background(), url) - require.True(t, testutils.IsError(err, "ERROR: password authentication failed for user bob")) - - url = fmt.Sprintf("postgres://bob:builder@%s?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - defer func() { - require.NoError(t, conn.Close(ctx)) - require.True(t, connSuccess) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - require.Equal(t, int64(2), s.metrics.AuthFailedCount.Count()) - }() - - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - require.NoError(t, runTestQuery(conn)) -} - -func TestProxyTLSClose(t *testing.T) { - defer leaktest.AfterTest(t)() - // NB: The leaktest call is an important part of this test. We're - // verifying that no goroutines are leaked, despite calling Close an - // underlying TCP connection (rather than the TLSConn that wraps it). 
- - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(tc.ServerConn(0)) - sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - - var proxyIncomingConn atomic.Value // *Conn - var connSuccess bool - opts := Options{ - BackendConfigFromParams: func(params map[string]string, conn *Conn) (*BackendConfig, error) { - proxyIncomingConn.Store(conn) - return &BackendConfig{ - OnConnectionSuccess: func() { connSuccess = true }, - }, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), &tls.Config{InsecureSkipVerify: true}) - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://bob:builder@%s/defaultdb_29?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - defer func() { - incomingConn := proxyIncomingConn.Load().(*Conn) - require.NoError(t, incomingConn.Close()) - <-incomingConn.Done() // should immediately proceed - - require.True(t, connSuccess) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) - }() - - require.NoError(t, runTestQuery(conn)) -} - -func TestProxyModifyRequestParams(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - opts := Options{ - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - params := msg.Parameters - require.EqualValues(t, map[string]string{ - "authToken": "abc123", - "user": "bogususer", - }, 
params) - - // NB: This test will fail unless the user used between the proxy - // and the backend is changed to a user that actually exists. - delete(params, "authToken") - params["user"] = "root" - - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) - }, - } - s, proxyAddr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - u := fmt.Sprintf("postgres://bogususer@%s/?sslmode=require&authToken=abc123", proxyAddr) - conn, err := pgx.Connect(ctx, u) - require.NoError(t, err) - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - defer func() { - require.NoError(t, conn.Close(ctx)) - }() - - require.NoError(t, runTestQuery(conn)) -} - -func newInsecureProxyServer( - t *testing.T, outgoingAddr string, outgoingTLSConfig *tls.Config, customOptions ...func(*Options), -) (server *Server, addr string, cleanup func()) { - op := Options{ - BackendDialer: func(message *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(message, outgoingAddr, outgoingTLSConfig) - }} - for _, opt := range customOptions { - opt(&op) - } - s := NewServer(op) - const listenAddress = "127.0.0.1:0" - ln, err := net.Listen("tcp", listenAddress) - require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - _ = s.Serve(ln) - }() - return s, ln.Addr().String(), func() { - _ = ln.Close() - wg.Wait() - } -} - -func TestInsecureProxy(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - sqlDB := sqlutils.MakeSQLRunner(tc.ServerConn(0)) - sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - - s, addr, cleanup := newInsecureProxyServer( - t, tc.Server(0).ServingSQLAddr(), &tls.Config{InsecureSkipVerify: true}, func(op *Options) {} /* custom options */) - defer cleanup() - - u := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable", addr) - _, err := pgx.Connect(context.Background(), 
u) - require.Error(t, err) - require.True(t, testutils.IsError(err, "ERROR: password authentication failed for user bob")) - - u = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable", addr) - conn, err := pgx.Connect(ctx, u) - require.NoError(t, err) - - defer func() { - require.NoError(t, conn.Close(ctx)) - require.Equal(t, int64(1), s.metrics.AuthFailedCount.Count()) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - }() - - require.NoError(t, runTestQuery(conn)) -} - -func TestInsecureDoubleProxy(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{ - ServerArgs: base.TestServerArgs{Insecure: true}, - }) - defer tc.Stopper().Stop(ctx) - - // Test multiple proxies: proxyB -> proxyA -> tc - _, proxyA, cleanupA := newInsecureProxyServer(t, tc.Server(0).ServingSQLAddr(), - nil /* tls config */, func(op *Options) {} /* custom server options */) - defer cleanupA() - _, proxyB, cleanupB := newInsecureProxyServer(t, proxyA, nil /* tls config */, func(op *Options) {} /* custom server options */) - defer cleanupB() - - u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable", proxyB) - conn, err := pgx.Connect(ctx, u) - require.NoError(t, err) - defer func() { - require.NoError(t, conn.Close(ctx)) - }() - require.NoError(t, runTestQuery(conn)) -} - -func TestErroneousFrontend(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - _, addr, cleanup := newInsecureProxyServer( - t, tc.Server(0).ServingSQLAddr(), nil, /* tls config */ - func(op *Options) { - op.FrontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) { - return nil, nil, errors.New(FrontendError) - } - }) - defer cleanup() - - u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable", 
addr) - - _, err := pgx.Connect(ctx, u) - require.Error(t, err) - require.True(t, testutils.IsError(err, FrontendError)) -} - -func TestErroneousBackend(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - _, addr, cleanup := newInsecureProxyServer( - t, tc.Server(0).ServingSQLAddr(), nil, /* tls config */ - func(op *Options) { - op.BackendDialer = func(message *pgproto3.StartupMessage) (net.Conn, error) { - return nil, errors.New(BackendError) - } - }) - defer cleanup() - - u := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable", addr) - - _, err := pgx.Connect(ctx, u) - require.Error(t, err) - require.True(t, testutils.IsError(err, BackendError)) -} - -func TestProxyRefuseConn(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - ac := makeAssertCtx() - opts := Options{ - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return nil, NewErrorf(CodeProxyRefusedConnection, "too many attempts") - }, - OnSendErrToClient: ac.onSendErrToClient, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - ac.assertConnectErr( - t, fmt.Sprintf("postgres://root:admin@%s/", addr), "defaultdb_29?sslmode=require", - CodeProxyRefusedConnection, "too many attempts", - ) - require.Equal(t, int64(1), s.metrics.RefusedConnCount.Count()) - require.Equal(t, int64(0), s.metrics.SuccessfulConnCount.Count()) - require.Equal(t, int64(0), s.metrics.AuthFailedCount.Count()) -} - -func TestProxyKeepAlive(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := 
serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - // Don't let connections last more than 100ms. - KeepAliveLoop: func(ctx context.Context) error { - t := timeutil.NewTimer() - t.Reset(100 * time.Millisecond) - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-t.C: - t.Read = true - return errors.New("expired") - } - } - }, - }, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - defer func() { - require.NoError(t, conn.Close(ctx)) - require.Equal(t, int64(1), s.metrics.ExpiredClientConnCount.Count()) - }() - - require.Eventuallyf( - t, - func() bool { - _, err = conn.Exec(context.Background(), "SELECT 1") - return err != nil && strings.Contains(err.Error(), "expired") - }, - time.Second, 5*time.Millisecond, - "unexpected error received: %v", err, - ) -} - -func TestProxyAgainstSecureCRDBWithIdleTimeout(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) - defer tc.Stopper().Stop(ctx) - - outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - require.NoError(t, err) - outgoingTLSConfig.InsecureSkipVerify = true - - idleTimeout, _ := time.ParseDuration("0.5s") - var connSuccess bool - opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ 
*Conn) (*BackendConfig, error) { - return &BackendConfig{ - OnConnectionSuccess: func() { connSuccess = true }, - }, nil - }, - BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { - conn, err := BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) - if err != nil { - return nil, err - } - return IdleDisconnectOverlay(conn, idleTimeout), nil - }, - } - s, addr, done := setupTestProxyWithCerts(t, &opts) - defer done() - - url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require", addr) - conn, err := pgx.Connect(context.Background(), url) - require.NoError(t, err) - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - defer func() { - require.NoError(t, conn.Close(ctx)) - require.True(t, connSuccess) - require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) - }() - - var n int - err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) - require.NoError(t, err) - require.EqualValues(t, 1, n) - time.Sleep(idleTimeout * 2) - err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) - require.EqualError(t, err, "FATAL: terminating connection due to idle timeout (SQLSTATE 57P01)") -} diff --git a/pkg/ccl/sqlproxyccl/server.go b/pkg/ccl/sqlproxyccl/server.go index 2dedd3b3bfbf..dc9c181cd0fb 100644 --- a/pkg/ccl/sqlproxyccl/server.go +++ b/pkg/ccl/sqlproxyccl/server.go @@ -19,17 +19,17 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/syncutil" - "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" ) // Server is a TCP server that proxies SQL connections to a // configurable backend. It may also run an HTTP server to expose a // health check and prometheus metrics. 
type Server struct { - opts *Options + connHandler func(ctx context.Context, metrics *Metrics, proxyConn *Conn) error mux *http.ServeMux - metrics *Metrics + Metrics *Metrics metricsRegistry *metric.Registry promMu syncutil.Mutex @@ -38,7 +38,9 @@ type Server struct { // NewServer constructs a new proxy server and provisions metrics and health // checks as well. -func NewServer(opts Options) *Server { +func NewServer( + connHandler func(ctx context.Context, metrics *Metrics, proxyConn *Conn) error, +) *Server { mux := http.NewServeMux() registry := metric.NewRegistry() @@ -48,9 +50,9 @@ func NewServer(opts Options) *Server { registry.AddMetricStruct(proxyMetrics) s := &Server{ - opts: &opts, + connHandler: connHandler, mux: mux, - metrics: &proxyMetrics, + Metrics: &proxyMetrics, metricsRegistry: registry, prometheusExporter: metric.MakePrometheusExporter(), } @@ -134,7 +136,12 @@ func (s *Server) ServeHTTP(ctx context.Context, ln net.Listener) error { // Serve serves a listener according to the Options given in NewServer(). // Incoming client connections are taken through the Postgres handshake and // relayed to the configured backend server. 
-func (s *Server) Serve(ln net.Listener) error { +func (s *Server) Serve(ctx context.Context, ln net.Listener) error { + go func() { + <-ctx.Done() + _ = ln.Close() + }() + for { origConn, err := ln.Accept() if err != nil { @@ -146,14 +153,13 @@ func (s *Server) Serve(ln net.Listener) error { go func() { defer func() { _ = conn.Close() }() - s.metrics.CurConnCount.Inc(1) - defer s.metrics.CurConnCount.Dec(1) - tBegin := timeutil.Now() + s.Metrics.CurConnCount.Inc(1) + defer s.Metrics.CurConnCount.Dec(1) remoteAddr := conn.RemoteAddr() - log.Infof(context.Background(), "handling client %s", remoteAddr) - err := s.Proxy(conn) - log.Infof(context.Background(), "client %s disconnected after %.2fs: %v", - remoteAddr, timeutil.Since(tBegin).Seconds(), err) + ctx := logtags.AddTag(ctx, "client", remoteAddr) // per-connection ctx; never overwrite the one shared by all goroutines + if err := s.connHandler(ctx, s.Metrics, conn); err != nil { + log.Infof(ctx, "connection error: %v", err) + } }() } } diff --git a/pkg/ccl/sqlproxyccl/server_test.go b/pkg/ccl/sqlproxyccl/server_test.go index f6b411695d2f..05070032d176 100644 --- a/pkg/ccl/sqlproxyccl/server_test.go +++ b/pkg/ccl/sqlproxyccl/server_test.go @@ -9,18 +9,27 @@ package sqlproxyccl import ( + "context" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/stretchr/testify/require" ) func TestHandleHealth(t *testing.T) { defer leaktest.AfterTest(t)() - proxyServer := NewServer(Options{}) + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + handler, err := NewProxyHandler(context.Background(), stopper, ProxyOptions{}) + require.NoError(t, err) + proxyServer := NewServer(func(ctx context.Context, metrics *Metrics, proxyConn *Conn) error { + return handler.Handle(ctx, metrics, proxyConn) + }) rw := httptest.NewRecorder() r := httptest.NewRequest("GET", "/_status/healthz/", nil) @@ -36,7 +45,14 @@ func TestHandleHealth(t *testing.T) { func TestHandleVars(t 
*testing.T) { defer leaktest.AfterTest(t)() - proxyServer := NewServer(Options{}) + stopper := stop.NewStopper() + defer stopper.Stop(context.Background()) + + handler, err := NewProxyHandler(context.Background(), stopper, ProxyOptions{}) + require.NoError(t, err) + proxyServer := NewServer(func(ctx context.Context, metrics *Metrics, proxyConn *Conn) error { + return handler.Handle(ctx, metrics, proxyConn) + }) rw := httptest.NewRecorder() r := httptest.NewRequest("GET", "/_status/vars/", nil) diff --git a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel index 509fb42f6adb..243234607fa9 100644 --- a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel @@ -1,5 +1,5 @@ load("@bazel_gomock//:gomock.bzl", "gomock") -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") load("@rules_proto//proto:defs.bzl", "proto_library") @@ -47,34 +47,6 @@ go_library( ], ) -go_test( - name = "tenant_test", - srcs = [ - "directory_test.go", - "main_test.go", - ], - embed = [":tenant"], - deps = [ - "//pkg/base", - "//pkg/ccl", - "//pkg/ccl/utilccl", - "//pkg/roachpb", - "//pkg/security", - "//pkg/security/securitytest", - "//pkg/server", - "//pkg/sql", - "//pkg/testutils/serverutils", - "//pkg/testutils/testcluster", - "//pkg/util/leaktest", - "//pkg/util/log", - "//pkg/util/randutil", - "//pkg/util/stop", - "@com_github_stretchr_testify//assert", - "@com_github_stretchr_testify//require", - "@org_golang_google_grpc//:go_default_library", - ], -) - gomock( name = "mocks_tenant", out = "mocks_generated.go", diff --git a/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go b/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go index d2e4dd7ec624..794f16094ee5 100644 --- a/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go +++ b/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go @@ -193,7 +193,7 @@ func (s 
*TestDirectoryServer) EnsureEndpoint( s.proc.Lock() defer s.proc.Unlock() - lst, err := s.listLocked(ctx, &ListEndpointsRequest{req.TenantID}) + lst, err := s.listLocked(ctx, &ListEndpointsRequest{TenantID: req.TenantID}) if err != nil { return nil, err } @@ -301,6 +301,7 @@ func (s *TestDirectoryServer) startTenantLocked( fmt.Sprintf("--sql-addr=%s", sql.Addr().String()), fmt.Sprintf("--http-addr=%s", http.Addr().String()), fmt.Sprintf("--tenant-id=%d", tenantID), + "--insecure", } if err = sql.Close(); err != nil { return nil, err @@ -331,7 +332,7 @@ func (s *TestDirectoryServer) startTenantLocked( _ = c.Process.Kill() s.deregisterInstance(tenantID, process.SQL) })) - err = process.Stopper.RunAsyncTask(ctx, "cmd-wait", func(ctx context.Context) { + err = s.stopper.RunAsyncTask(ctx, "cmd-wait", func(ctx context.Context) { if err := c.Wait(); err != nil { log.Infof(ctx, "finished %s with err %s", process.Cmd.Args, err) log.Infof(ctx, "output %s", b.Bytes()) diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel new file mode 100644 index 000000000000..820fa65b9b89 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel @@ -0,0 +1,29 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + +go_test( + name = "tenantdirsvr_test", + srcs = [ + "directory_test.go", + "main_test.go", + ], + deps = [ + "//pkg/base", + "//pkg/ccl", + "//pkg/ccl/sqlproxyccl/tenant", + "//pkg/ccl/utilccl", + "//pkg/roachpb", + "//pkg/security", + "//pkg/security/securitytest", + "//pkg/server", + "//pkg/sql", + "//pkg/testutils/serverutils", + "//pkg/testutils/testcluster", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/randutil", + "//pkg/util/stop", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@org_golang_google_grpc//:go_default_library", + ], +) diff --git a/pkg/ccl/sqlproxyccl/tenant/directory_test.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/directory_test.go similarity index 91% 
rename from pkg/ccl/sqlproxyccl/tenant/directory_test.go rename to pkg/ccl/sqlproxyccl/tenantdirsvr/directory_test.go index 5027281400af..2ffb2877a18a 100644 --- a/pkg/ccl/sqlproxyccl/tenant/directory_test.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/directory_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package tenant +package tenantdirsvr import ( "context" @@ -16,6 +16,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -89,7 +90,7 @@ func TestEndpointWatcher(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // Resume tenant again by a direct call to the directory server - _, err = tds.EnsureEndpoint(ctx, &EnsureEndpointRequest{tenantID.ToUint64()}) + _, err = tds.EnsureEndpoint(ctx, &tenant.EnsureEndpointRequest{tenantID.ToUint64()}) require.NoError(t, err) // Wait for background watcher to populate the initial endpoint. @@ -188,7 +189,7 @@ func TestResume(t *testing.T) { }(i) } - var processes map[net.Addr]*Process + var processes map[net.Addr]*tenant.Process // Eventually the tenant process will be resumed. require.Eventually(t, func() bool { processes = tds.Get(tenantID) @@ -212,7 +213,7 @@ func TestDeleteTenant(t *testing.T) { // Create the directory. ctx := context.Background() // Disable throttling for this test - tc, dir, tds := newTestDirectory(t, RefreshDelay(-1)) + tc, dir, tds := newTestDirectory(t, tenant.RefreshDelay(-1)) defer tc.Stopper().Stop(ctx) tenantID := roachpb.MakeTenantID(50) @@ -266,7 +267,7 @@ func TestRefreshThrottling(t *testing.T) { // Create the directory, but with extreme rate limiting so that directory // will never refresh. 
ctx := context.Background() - tc, dir, _ := newTestDirectory(t, RefreshDelay(60*time.Minute)) + tc, dir, _ := newTestDirectory(t, tenant.RefreshDelay(60*time.Minute)) defer tc.Stopper().Stop(ctx) // Create test tenant. @@ -325,10 +326,10 @@ func destroyTenant(tc serverutils.TestClusterInterface, id roachpb.TenantID) err func startTenant( ctx context.Context, srv serverutils.TestServerInterface, id uint64, -) (*Process, error) { +) (*tenant.Process, error) { log.TestingClearServerIdentifiers() - tenantStopper := NewSubStopper(srv.Stopper()) - tenant, err := srv.StartTenant( + tenantStopper := tenant.NewSubStopper(srv.Stopper()) + t, err := srv.StartTenant( ctx, base.TestTenantArgs{ Existing: true, @@ -339,18 +340,22 @@ func startTenant( if err != nil { return nil, err } - sqlAddr, err := net.ResolveTCPAddr("tcp", tenant.SQLAddr()) + sqlAddr, err := net.ResolveTCPAddr("tcp", t.SQLAddr()) if err != nil { return nil, err } - return &Process{SQL: sqlAddr, Stopper: tenantStopper}, nil + return &tenant.Process{SQL: sqlAddr, Stopper: tenantStopper}, nil } // Setup directory that uses a client connected to a test directory server // that manages tenants connected to a backing KV server. func newTestDirectory( - t *testing.T, opts ...DirOption, -) (tc serverutils.TestClusterInterface, directory *Directory, tds *TestDirectoryServer) { + t *testing.T, opts ...tenant.DirOption, +) ( + tc serverutils.TestClusterInterface, + directory *tenant.Directory, + tds *tenant.TestDirectoryServer, +) { tc = serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{ // We need to start the cluster insecure in order to not // care about TLS settings for the RPC client connection. 
@@ -359,8 +364,8 @@ func newTestDirectory( }, }) clusterStopper := tc.Stopper() - tds = NewTestDirectoryServer(clusterStopper) - tds.TenantStarterFunc = func(ctx context.Context, tenantID uint64) (*Process, error) { + tds = tenant.NewTestDirectoryServer(clusterStopper) + tds.TenantStarterFunc = func(ctx context.Context, tenantID uint64) (*tenant.Process, error) { t.Logf("starting tenant %d", tenantID) process, err := startTenant(ctx, tc.Server(0), tenantID) if err != nil { @@ -381,8 +386,8 @@ func newTestDirectory( require.NoError(t, err) // nolint:grpcconnclose clusterStopper.AddCloser(stop.CloserFn(func() { require.NoError(t, conn.Close() /* nolint:grpcconnclose */) })) - client := NewDirectoryClient(conn) - directory, err = NewDirectory(context.Background(), clusterStopper, client, opts...) + client := tenant.NewDirectoryClient(conn) + directory, err = tenant.NewDirectory(context.Background(), clusterStopper, client, opts...) require.NoError(t, err) return diff --git a/pkg/ccl/sqlproxyccl/tenant/main_test.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/main_test.go similarity index 98% rename from pkg/ccl/sqlproxyccl/tenant/main_test.go rename to pkg/ccl/sqlproxyccl/tenantdirsvr/main_test.go index 3a25676ae8dd..90199cefb604 100644 --- a/pkg/ccl/sqlproxyccl/tenant/main_test.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/main_test.go @@ -6,7 +6,7 @@ // // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt -package tenant +package tenantdirsvr import ( "os" diff --git a/pkg/cli/BUILD.bazel b/pkg/cli/BUILD.bazel index 4e597f4dd813..02de947c4076 100644 --- a/pkg/cli/BUILD.bazel +++ b/pkg/cli/BUILD.bazel @@ -37,7 +37,9 @@ go_library( "log_flags.go", "mt.go", "mt_cert.go", + "mt_proxy.go", "mt_start_sql.go", + "mt_test_directory.go", "node.go", "nodelocal.go", "quit.go", @@ -89,6 +91,8 @@ go_library( deps = [ "//pkg/base", "//pkg/build", + "//pkg/ccl/sqlproxyccl", + "//pkg/ccl/sqlproxyccl/tenant", "//pkg/cli/cliflags", "//pkg/cli/exit", "//pkg/cli/syncbench", @@ 
-216,49 +220,9 @@ go_library( "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", "@org_golang_x_sync//errgroup", + "@org_golang_x_sys//unix", "@org_golang_x_time//rate", - ] + select({ - "@io_bazel_rules_go//go/platform:aix": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:android": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:darwin": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:dragonfly": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:freebsd": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:illumos": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:ios": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:js": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:linux": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:netbsd": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:openbsd": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:plan9": [ - "@org_golang_x_sys//unix", - ], - "@io_bazel_rules_go//go/platform:solaris": [ - "@org_golang_x_sys//unix", - ], - "//conditions:default": [], - }), + ], ) go_test( diff --git a/pkg/cli/mt.go b/pkg/cli/mt.go index 6144e7a96b0b..2d0cd67c8532 100644 --- a/pkg/cli/mt.go +++ b/pkg/cli/mt.go @@ -12,14 +12,11 @@ package cli import "github.com/spf13/cobra" -// AddMTCommand adds a subcommand to `./cockroach mt`. 
-func AddMTCommand(cmd *cobra.Command) { - mtCmd.AddCommand(cmd) -} - func init() { cockroachCmd.AddCommand(mtCmd) mtCmd.AddCommand(mtStartSQLCmd) + mtCmd.AddCommand(mtStartSQLProxyCmd) + mtCmd.AddCommand(mtTestDirectorySvr) mtCertsCmd.AddCommand( mtCreateTenantClientCACertCmd, diff --git a/pkg/cli/mt_proxy.go b/pkg/cli/mt_proxy.go new file mode 100644 index 000000000000..88657b546300 --- /dev/null +++ b/pkg/cli/mt_proxy.go @@ -0,0 +1,207 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package cli + +import ( + "context" + "fmt" + "net" + "os" + "os/signal" + "time" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/log/severity" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/errors" + "github.com/spf13/cobra" + "golang.org/x/sys/unix" +) + +var proxyOpts sqlproxyccl.ProxyOptions + +var mtStartSQLProxyCmd = &cobra.Command{ + Use: "start-proxy ", + Short: "start-proxy host:port", + Long: `Starts a SQL proxy. + +This proxy accepts incoming connections and relays them to a backend server +determined by the arguments used. 
+`, + RunE: MaybeDecorateGRPCError(runStartSQLProxy), + Args: cobra.NoArgs, +} + +func init() { + f := mtStartSQLProxyCmd.Flags() + f.StringVar(&proxyOpts.Denylist, "denylist-file", "", + "Denylist file to limit access to IP addresses and tenant ids.") + f.StringVar(&proxyOpts.ListenAddr, "listen-addr", "127.0.0.1:46257", + "Listen address for incoming connections.") + f.StringVar(&proxyOpts.ListenCert, "listen-cert", "", + "File containing PEM-encoded x509 certificate for listen address.") + f.StringVar(&proxyOpts.ListenKey, "listen-key", "", + "File containing PEM-encoded x509 key for listen address.") + f.StringVar(&proxyOpts.MetricsAddress, "listen-metrics", "0.0.0.0:8080", + "Listen address for incoming connections.") + f.StringVar(&proxyOpts.RoutingRule, "routing-rule", "", + "Routing rule for incoming connections. Use '{{clusterName}}' for substitution.") + f.StringVar(&proxyOpts.DirectoryAddr, "directory", "", + "Directory address for resolving from backend id to IP.") + f.BoolVar(&proxyOpts.SkipVerify, "skip-verify", false, + "If true, skip identity verification of backend. For testing only.") + f.BoolVar(&proxyOpts.Insecure, "insecure", false, + "If true, use insecure connection to the backend.") + f.DurationVar(&proxyOpts.RatelimitBaseDelay, "ratelimit-base-delay", 50*time.Millisecond, + "Initial backoff after a failed login attempt. 
Set to 0 to disable rate limiting.") + f.DurationVar(&proxyOpts.ValidateAccessInterval, "validate-access-interval", 30*time.Second, + "Time interval between validation that current connections are still valid.") + f.DurationVar(&proxyOpts.PollConfigInterval, "poll-config-interval", 30*time.Second, + "Polling interval changes in config file.") + f.DurationVar(&proxyOpts.IdleTimeout, "idle-timeout", 0, + "Close connections idle for this duration.") +} + +func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) { + // Initialize logging, stopper and context that can be canceled + ctx, stopper, err := initLogging(cmd) + if err != nil { + return err + } + defer stopper.Stop(ctx) + + log.Infof(ctx, "New proxy with opts: %+v", proxyOpts) + + proxyLn, err := net.Listen("tcp", proxyOpts.ListenAddr) + if err != nil { + return err + } + stopper.AddCloser(stop.CloserFn(func() { _ = proxyLn.Close() })) + + metricsLn, err := net.Listen("tcp", proxyOpts.MetricsAddress) + if err != nil { + return err + } + stopper.AddCloser(stop.CloserFn(func() { _ = metricsLn.Close() })) + + handler, err := sqlproxyccl.NewProxyHandler(ctx, stopper, proxyOpts) + if err != nil { + return err + } + + server := sqlproxyccl.NewServer(handler.Handle) + + errChan := make(chan error, 1) + + if err := stopper.RunAsyncTask(ctx, "serve-http", func(ctx context.Context) { + log.Infof(ctx, "HTTP metrics server listening at %s", metricsLn.Addr()) + if err := server.ServeHTTP(ctx, metricsLn); err != nil { + errChan <- err + } + }); err != nil { + return err + } + + if err := stopper.RunAsyncTask(ctx, "serve-proxy", func(ctx context.Context) { + log.Infof(ctx, "proxy server listening at %s", proxyLn.Addr()) + if err := server.Serve(ctx, proxyLn); err != nil { + errChan <- err + } + }); err != nil { + return err + } + + return waitForSignals(ctx, stopper, errChan) +} + +func initLogging(cmd *cobra.Command) (ctx context.Context, stopper *stop.Stopper, err error) { + // Remove the default store, 
which avoids using it to set up logging. + // Instead, we'll default to logging to stderr unless --log-dir is + // specified. This makes sense since the standalone SQL server is + // at the time of writing stateless and may not be provisioned with + // suitable storage. + serverCfg.Stores.Specs = nil + serverCfg.ClusterName = "" + + ctx = context.Background() + stopper, err = setupAndInitializeLoggingAndProfiling(ctx, cmd, false /* isServerCmd */) + if err != nil { + return + } + ctx, _ = stopper.WithCancelOnQuiesce(ctx) + return +} + +func waitForSignals( + ctx context.Context, stopper *stop.Stopper, errChan chan error, +) (returnErr error) { + // Need to alias the signals if this has to run on non-unix OSes too. + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, unix.SIGINT, unix.SIGTERM) + + // Dump the stacks when QUIT is received. + quitSignalCh := make(chan os.Signal, 1) + signal.Notify(quitSignalCh, unix.SIGQUIT) + go func() { + for { + <-quitSignalCh + log.DumpStacks(context.Background()) + } + }() + + select { + case err := <-errChan: + log.StartAlwaysFlush() + return err + case <-stopper.ShouldQuiesce(): + // Stop has been requested through the stopper's Stop + <-stopper.IsStopped() + // StartAlwaysFlush both flushes and ensures that subsequent log + // writes are flushed too. 
+ log.StartAlwaysFlush() + case sig := <-signalCh: // INT or TERM + log.StartAlwaysFlush() // In case the caller follows up with KILL + log.Ops.Infof(ctx, "received signal '%s'", sig) + if sig == os.Interrupt { + returnErr = errors.New("interrupted") + } + go func() { + log.Infof(ctx, "server stopping") + stopper.Stop(ctx) + }() + case <-log.FatalChan(): + stopper.Stop(ctx) + select {} // Block and wait for logging go routine to shut down the process + } + + for { + select { + case sig := <-signalCh: + switch sig { + case os.Interrupt: // another SIGINT while shutting down; NOTE(review): comment used to say SIGTERM - confirm which signal is intended + log.Ops.Infof(ctx, "received additional signal '%s'; continuing graceful shutdown", sig) + continue + } + + log.Ops.Shoutf(ctx, severity.ERROR, + "received signal '%s' during shutdown, initiating hard shutdown", log.Safe(sig)) + panic("terminate") + case <-stopper.IsStopped(): + const msgDone = "server shutdown completed" + log.Ops.Infof(ctx, msgDone) + fmt.Fprintln(os.Stdout, msgDone) + } + break + } + + return returnErr +} diff --git a/pkg/cli/mt_start_sql.go b/pkg/cli/mt_start_sql.go index d7dd8c83e39b..9bc1e57b6729 100644 --- a/pkg/cli/mt_start_sql.go +++ b/pkg/cli/mt_start_sql.go @@ -136,6 +136,10 @@ func runStartSQL(cmd *cobra.Command, args []string) error { ch := make(chan os.Signal, 1) signal.Notify(ch, drainSignals...) - sig := <-ch - return errors.Newf("received signal %v", sig) + select { + case sig := <-ch: + return errors.Newf("received signal %v", sig) + case <-stopper.ShouldQuiesce(): + return nil + } } diff --git a/pkg/cli/mt_test_directory.go b/pkg/cli/mt_test_directory.go new file mode 100644 index 000000000000..61622b86f539 --- /dev/null +++ b/pkg/cli/mt_test_directory.go @@ -0,0 +1,63 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package cli + +import ( + "context" + "fmt" + "net" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/cli/cliflags" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/spf13/cobra" +) + +var port int + +func init() { + f := mtTestDirectorySvr.Flags() + ServerPort := cliflags.FlagInfo{ + Name: "port", + Description: `Port to listen on.`, + } + intFlag(f, &port, ServerPort) +} + +var mtTestDirectorySvr = &cobra.Command{ + Use: "test-directory", + Short: "Run a test directory service.", + Long: ` +Run a test directory service. +`, + Args: cobra.NoArgs, + RunE: MaybeDecorateGRPCError(runDirectorySvr), +} + +func runDirectorySvr(cmd *cobra.Command, args []string) (returnErr error) { + ctx := context.Background() + serverCfg.Stores.Specs = nil + + stopper, err := setupAndInitializeLoggingAndProfiling(ctx, cmd, false /* isServerCmd */) + if err != nil { + return err + } + defer stopper.Stop(ctx) + + tds := tenant.NewTestDirectoryServer(stopper) + + listenPort, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + return err + } + stopper.AddCloser(stop.CloserFn(func() { _ = listenPort.Close() })) + return tds.Serve(listenPort) +} diff --git a/pkg/cli/start.go b/pkg/cli/start.go index af46bc511aa5..a669d4064e80 100644 --- a/pkg/cli/start.go +++ b/pkg/cli/start.go @@ -118,7 +118,9 @@ var serverCmds = append(StartCmds, mtStartSQLCmd) // customLoggingSetupCmds lists the commands that call setupLogging() // after other types of configuration. 
-var customLoggingSetupCmds = append(serverCmds, debugCheckLogConfigCmd, demoCmd) +var customLoggingSetupCmds = append( + serverCmds, debugCheckLogConfigCmd, demoCmd, mtStartSQLProxyCmd, mtTestDirectorySvr, +) func initBlockProfile() { // Enable the block profile for a sample of mutex and channel operations. diff --git a/pkg/security/certmgr/BUILD.bazel b/pkg/security/certmgr/BUILD.bazel index a236e350159c..207dfba183b3 100644 --- a/pkg/security/certmgr/BUILD.bazel +++ b/pkg/security/certmgr/BUILD.bazel @@ -15,6 +15,7 @@ go_library( srcs = [ "cert_manager.go", "file_cert.go", + "self_signed_cert.go", ":mocks_certmgr", # keep ], embed = [":certlib"], # keep @@ -25,6 +26,7 @@ go_library( "//pkg/util/log/eventpb", "//pkg/util/syncutil", "//pkg/util/sysutil", + "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", "@com_github_golang_mock//gomock", # keep ], @@ -35,6 +37,7 @@ go_test( srcs = [ "cert_manager_test.go", "file_cert_test.go", + "self_signed_cert_test.go", ], embed = [":certmgr"], deps = [ diff --git a/pkg/security/certmgr/cert.go b/pkg/security/certmgr/cert.go index f39d04fb4ec7..8baa3e394e0a 100644 --- a/pkg/security/certmgr/cert.go +++ b/pkg/security/certmgr/cert.go @@ -10,7 +10,10 @@ package certmgr -import "context" +import ( + "context" + "crypto/tls" +) //go:generate mockgen -package=certmgr -destination=mocks_generated.go -source=cert.go . Cert @@ -43,4 +46,6 @@ type Cert interface { Err() error // ClearErr will clear the last reported err and allow reloads to resume. 
ClearErr() + // TLSCert will return the last loaded tls.Certificate + TLSCert() *tls.Certificate } diff --git a/pkg/security/certmgr/file_cert_test.go b/pkg/security/certmgr/file_cert_test.go index 2ee58858d4a3..24f4ac66f43c 100644 --- a/pkg/security/certmgr/file_cert_test.go +++ b/pkg/security/certmgr/file_cert_test.go @@ -58,7 +58,6 @@ func TestFileCert_Err(t *testing.T) { require.Nil(t, fc.Err()) fc.Reload(context.Background()) require.NotNil(t, fc.Err()) - require.NotNil(t, fc.Err()) fc.ClearErr() require.Nil(t, fc.Err()) } diff --git a/pkg/security/certmgr/mocks_generated.go b/pkg/security/certmgr/mocks_generated.go index 013d7ef63c24..097b2ba77bca 100644 --- a/pkg/security/certmgr/mocks_generated.go +++ b/pkg/security/certmgr/mocks_generated.go @@ -6,6 +6,7 @@ package certmgr import ( context "context" + tls "crypto/tls" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -71,3 +72,17 @@ func (mr *MockCertMockRecorder) Reload(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reload", reflect.TypeOf((*MockCert)(nil).Reload), ctx) } + +// TLSCert mocks base method. +func (m *MockCert) TLSCert() *tls.Certificate { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TLSCert") + ret0, _ := ret[0].(*tls.Certificate) + return ret0 +} + +// TLSCert indicates an expected call of TLSCert. +func (mr *MockCertMockRecorder) TLSCert() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TLSCert", reflect.TypeOf((*MockCert)(nil).TLSCert)) +} diff --git a/pkg/security/certmgr/self_signed_cert.go b/pkg/security/certmgr/self_signed_cert.go new file mode 100644 index 000000000000..0ba941f8f45e --- /dev/null +++ b/pkg/security/certmgr/self_signed_cert.go @@ -0,0 +1,103 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package certmgr + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "math/big" + "time" + + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" +) + +// Ensure that SelfSignedCert implements Cert. +var _ Cert = (*SelfSignedCert)(nil) + +// SelfSignedCert represents a single, self-signed certificate, generated at runtime. +type SelfSignedCert struct { + syncutil.Mutex + years, months, days int + secs time.Duration + err error + cert *tls.Certificate +} + +// NewSelfSignedCert will generate a new self-signed cert. +// A follow up Reload will regenerate the cert. +func NewSelfSignedCert(years, months, days int, secs time.Duration) *SelfSignedCert { + return &SelfSignedCert{years: years, months: months, days: days, secs: secs} +} + +// Reload will regenerate the self-signed cert. +func (ssc *SelfSignedCert) Reload(ctx context.Context) { + ssc.Lock() + defer ssc.Unlock() + + // There was a previous error that is not yet retrieved. + if ssc.err != nil { + return + } + + // Generate self signed cert for testing. 
+ // Use "openssl s_client -showcerts -starttls postgres -connect {HOSTNAME}:{PORT}" to + // inspect the certificate or save the certificate portion to a file and + // use sslmode=verify-full + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + ssc.err = errors.Wrapf(err, "could not generate key") + return + } + + from := timeutil.Now() + until := from.AddDate(ssc.years, ssc.months, ssc.days).Add(ssc.secs) + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: from, + NotAfter: until, + DNSNames: []string{"localhost"}, + IsCA: true, + } + cer, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv) + if err != nil { + ssc.err = errors.Wrapf(err, "could not create certificate") + return // keep the previously loaded cert; never publish one built from a failed generation + } + ssc.cert = &tls.Certificate{ + Certificate: [][]byte{cer}, + PrivateKey: priv, + } +} + +// Err will return the last error that occurred during reload or nil if the +// last reload was successful. +func (ssc *SelfSignedCert) Err() error { + ssc.Lock() + defer ssc.Unlock() + return ssc.err +} + +// ClearErr will clear the last err so the follow up Reload can execute. +func (ssc *SelfSignedCert) ClearErr() { + ssc.Lock() + defer ssc.Unlock() + ssc.err = nil +} + +// TLSCert returns the tls certificate if the cert generation was successful. +func (ssc *SelfSignedCert) TLSCert() *tls.Certificate { + return ssc.cert +} diff --git a/pkg/security/certmgr/self_signed_cert_test.go b/pkg/security/certmgr/self_signed_cert_test.go new file mode 100644 index 000000000000..9a75b58a80be --- /dev/null +++ b/pkg/security/certmgr/self_signed_cert_test.go @@ -0,0 +1,44 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package certmgr + +import ( + "context" + "crypto/x509" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSelfSignedCert_Err(t *testing.T) { + ssc := NewSelfSignedCert(-9999, 0, 0, 0) + require.NotNil(t, ssc) + require.Nil(t, ssc.Err()) + ssc.Reload(context.Background()) + require.Regexp(t, "cannot represent time as GeneralizedTime", ssc.Err()) + ssc.ClearErr() + require.Nil(t, ssc.Err()) +} + +func TestSelfSignedCert_TLSCert(t *testing.T) { + ssc := NewSelfSignedCert(1, 6, 3, 5*time.Hour) + require.NotNil(t, ssc) + require.Nil(t, ssc.Err()) + ssc.Reload(context.Background()) + require.Nil(t, ssc.Err()) + require.NotNil(t, ssc.TLSCert()) + require.Len(t, ssc.TLSCert().Certificate, 1) + cert, err := x509.ParseCertificate(ssc.TLSCert().Certificate[0]) + require.NoError(t, err) + expectedUntil := cert.NotBefore.AddDate(1, 6, 3).Add(5 * time.Hour) + require.Equal(t, expectedUntil, cert.NotAfter) +} diff --git a/pkg/testutils/BUILD.bazel b/pkg/testutils/BUILD.bazel index 99e6c098250d..26e956744c69 100644 --- a/pkg/testutils/BUILD.bazel +++ b/pkg/testutils/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "dir.go", "error.go", "files.go", + "hook.go", "keys.go", "net.go", "pprof.go", diff --git a/pkg/testutils/hook.go b/pkg/testutils/hook.go new file mode 100644 index 000000000000..340bbb4de7a0 --- /dev/null +++ b/pkg/testutils/hook.go @@ -0,0 +1,24 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package testutils + +import "reflect" + +// HookGlobal sets `*ptr = val` and returns a closure for restoring `*ptr` to +// its original value. A runtime panic will occur if `val` is not assignable to +// `*ptr`. +func HookGlobal(ptr, val interface{}) func() { + global := reflect.ValueOf(ptr).Elem() + orig := reflect.New(global.Type()).Elem() + orig.Set(global) + global.Set(reflect.ValueOf(val)) + return func() { global.Set(orig) } +}