diff --git a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index 2f6484ccfbf6..cdaff009d0db 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -66,6 +66,13 @@ var sendErrToClient = func(conn net.Conn, err error) { Code: pgCode, Message: msg, }).Encode(nil)) + } else { + // Return a generic "internal server error" message. + _, _ = conn.Write((&pgproto3.ErrorResponse{ + Severity: "FATAL", + Code: "08004", // rejected connection + Message: "internal server error", + }).Encode(nil)) } } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 4c6abd348c6d..46504479f8fa 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -195,12 +195,12 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn return nil } - // Note that the errors returned from this function are user-facing errors so - // we should be careful with the details that we want to expose. - backendStartupMsg, clusterName, tenID, err := clusterNameAndTenantFromParams(msg) + // NOTE: Errors returned from this function are user-facing errors so we + // should be careful with the details that we want to expose. + backendStartupMsg, clusterName, tenID, err := clusterNameAndTenantFromParams(ctx, msg) if err != nil { clientErr := &codeError{codeParamsRoutingFailed, err} - log.Errorf(ctx, "unable to extract cluster name and tenant id: %s", clientErr.Error()) + log.Errorf(ctx, "unable to extract cluster name and tenant id: %s", err.Error()) updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics) return clientErr } @@ -212,8 +212,8 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn ipAddr, _, err := net.SplitHostPort(conn.RemoteAddr().String()) if err != nil { - clientErr := &codeError{codeParamsRoutingFailed, err} - log.Errorf(ctx, "could not parse address: %v", clientErr.Error()) + clientErr := newErrorf(codeParamsRoutingFailed, "unexpected connection address") + log.Errorf(ctx, "could not parse address: %v", err.Error()) updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics) return clientErr } @@ -438,7 +438,7 @@ func (handler *proxyHandler) validateAccessAndThrottle( // Hence we need to rate limit. if err := handler.throttleService.LoginCheck(ipAddr, timeutil.Now()); err != nil { log.Errorf(ctx, "throttler refused connection: %v", err.Error()) - return newErrorf(codeProxyRefusedConnection, "Connection attempt throttled") + return newErrorf(codeProxyRefusedConnection, "connection attempt throttled") } } @@ -541,7 +541,7 @@ var reportFailureToDirectory = func( // through its command-line options, i.e. "-c NAME=VALUE", "-cNAME=VALUE", and // "--NAME=VALUE". func clusterNameAndTenantFromParams( - msg *pgproto3.StartupMessage, + ctx context.Context, msg *pgproto3.StartupMessage, ) (*pgproto3.StartupMessage, string, roachpb.TenantID, error) { clusterNameFromDB, databaseName, err := parseDatabaseParam(msg.Parameters["database"]) if err != nil { @@ -566,19 +566,28 @@ func clusterNameAndTenantFromParams( } sepIdx := strings.LastIndex(clusterNameFromDB, clusterTenantSep) + // Cluster name provided without a tenant ID in the end. if sepIdx == -1 || sepIdx == len(clusterNameFromDB)-1 { - return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name %s", clusterNameFromDB) + return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB) } - clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:] + clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:] if !clusterNameRegex.MatchString(clusterNameSansTenant) { return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB) } tenID, err := strconv.ParseUint(tenantIDStr, 10, 64) if err != nil { - return msg, "", roachpb.MaxTenantID, errors.Wrapf(err, "cannot parse %s as uint64", tenantIDStr) + // Log these non user-facing errors. + log.Errorf(ctx, "cannot parse tenant ID in %s: %v", clusterNameFromDB, err) + return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB) + } + + if tenID < roachpb.MinTenantID.ToUint64() { + // Log these non user-facing errors. + log.Errorf(ctx, "%s contains an invalid tenant ID", clusterNameFromDB) + return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB) } // Make and return a copy of the startup msg so the original is not modified. @@ -609,7 +618,7 @@ func parseDatabaseParam(databaseParam string) (clusterName, databaseName string, return "", "", nil } - parts := strings.SplitN(databaseParam, ".", 2) + parts := strings.Split(databaseParam, ".") // Database param provided without cluster name. if len(parts) <= 1 { diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index d520478faa21..3908d7768cd9 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -137,17 +137,24 @@ func TestFailedConnection(t *testing.T) { // TenantID rejected as malformed. te.TestConnectErr( - ctx, t, u+"?options=--cluster=dim&sslmode=require", - codeParamsRoutingFailed, "invalid cluster name", + ctx, t, u+"?options=--cluster=dimdog&sslmode=require", + codeParamsRoutingFailed, "invalid cluster name 'dimdog'", ) require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) - // No TenantID. + // No cluster name and TenantID. te.TestConnectErr( ctx, t, u+"?sslmode=require", - codeParamsRoutingFailed, "missing cluster name", + codeParamsRoutingFailed, "missing cluster name in connection string", ) require.Equal(t, int64(2), s.metrics.RoutingErrCount.Count()) + + // Bad TenantID. Ensure that we don't leak any parsing errors. + te.TestConnectErr( + ctx, t, u+"?options=--cluster=dim-dog-foo3&sslmode=require", + codeParamsRoutingFailed, "invalid cluster name 'dim-dog-foo3'", + ) + require.Equal(t, int64(3), s.metrics.RoutingErrCount.Count()) } func TestUnexpectedError(t *testing.T) { @@ -178,7 +185,7 @@ func TestUnexpectedError(t *testing.T) { // to the 5s connect_timeout for pgx.Connect to give up. start := timeutil.Now() _, err := pgx.Connect(ctx, u) - require.Error(t, err) + require.Contains(t, err.Error(), "internal server error") t.Log(err) elapsed := timeutil.Since(start) if elapsed >= 5*time.Second { @@ -342,31 +349,6 @@ func TestInsecureProxy(t *testing.T) { require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) } -func TestInsecureDoubleProxy(t *testing.T) { - defer leaktest.AfterTest(t)() - - ctx := context.Background() - te := newTester() - defer te.Close() - - sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true}) - sql.(*server.TestServer).PGServer().TestingSetTrustClientProvidedRemoteAddr(true) - defer sql.Stopper().Stop(ctx) - - // Test multiple proxies: proxyB -> proxyA -> tc - _, proxyA := newProxyServer(ctx, t, sql.Stopper(), - &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), Insecure: true}, - ) - _, proxyB := newProxyServer(ctx, t, sql.Stopper(), - &ProxyOptions{RoutingRule: proxyA, Insecure: true}, - ) - - url := fmt.Sprintf("postgres://root:admin@%s/dim-dog-28.dim-dog-29.defaultdb?sslmode=disable", proxyB) - te.TestConnect(ctx, t, url, func(conn *pgx.Conn) { - require.NoError(t, runTestQuery(ctx, conn)) - }) -} - func TestErroneousFrontend(t *testing.T) { defer leaktest.AfterTest(t)() @@ -387,9 +369,9 @@ func TestErroneousFrontend(t *testing.T) { url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) // Generic message here as the Frontend's error is not codeError and - // by default we don't pass back error's text. The startup message doesn't get - // processed in this case. - te.TestConnectErr(ctx, t, url, 0, "connection reset by peer|failed to receive message") + // by default we don't pass back error's text. The startup message doesn't + // get processed in this case. + te.TestConnectErr(ctx, t, url, 0, "internal server error") } func TestErroneousBackend(t *testing.T) { @@ -412,9 +394,9 @@ func TestErroneousBackend(t *testing.T) { url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=dim-dog-28", addr) // Generic message here as the Backend's error is not codeError and - // by default we don't pass back error's text. The startup message has already - // been processed. - te.TestConnectErr(ctx, t, url, 0, "failed to receive message") + // by default we don't pass back error's text. The startup message has + // already been processed. + te.TestConnectErr(ctx, t, url, 0, "internal server error") } func TestProxyRefuseConn(t *testing.T) { @@ -659,6 +641,211 @@ func TestDirectoryConnect(t *testing.T) { }) } +func TestClusterNameAndTenantFromParams(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + testCases := []struct { + name string + params map[string]string + expectedClusterName string + expectedTenantID uint64 + expectedParams map[string]string + expectedError string + }{ + { + name: "empty params", + params: map[string]string{}, + expectedError: "missing cluster name in connection string", + }, + { + name: "cluster name is not provided", + params: map[string]string{ + "database": "defaultdb", + "options": "--foo=bar", + }, + expectedError: "missing cluster name in connection string", + }, + { + name: "multiple similar cluster names", + params: map[string]string{ + "database": "happy-koala-7.defaultdb", + "options": "--cluster=happy-koala", + }, + expectedError: "multiple cluster names provided", + }, + { + name: "multiple different cluster names", + params: map[string]string{ + "database": "happy-koala-7.defaultdb", + "options": "--cluster=happy-tiger", + }, + expectedError: "multiple cluster names provided", + }, + { + name: "invalid cluster name in database param", + params: map[string]string{ + // Cluster names need to be between 6 to 20 alphanumeric characters. + "database": "short-0.defaultdb", + }, + expectedError: "invalid cluster name 'short-0'", + }, + { + name: "invalid cluster name in options param", + params: map[string]string{ + // Cluster names need to be between 6 to 20 alphanumeric characters. + "options": "--cluster=cockroachlabsdotcomfoobarbaz-0", + }, + expectedError: "invalid cluster name 'cockroachlabsdotcomfoobarbaz-0'", + }, + { + name: "invalid database param (1)", + params: map[string]string{"database": "."}, + expectedError: "invalid database param", + }, + { + name: "invalid database param (2)", + params: map[string]string{"database": "a."}, + expectedError: "invalid database param", + }, + { + name: "invalid database param (3)", + params: map[string]string{"database": ".b"}, + expectedError: "invalid database param", + }, + { + name: "invalid database param (4)", + params: map[string]string{"database": "a.b.c"}, + expectedError: "invalid database param", + }, + { + name: "multiple cluster flags", + params: map[string]string{ + "database": "hello-world.defaultdb", + "options": "--cluster=foobar --cluster=barbaz --cluster=testbaz", + }, + expectedError: "multiple cluster flags provided", + }, + { + name: "invalid cluster flag", + params: map[string]string{ + "database": "hello-world.defaultdb", + "options": "--cluster=", + }, + expectedError: "invalid cluster flag", + }, + { + name: "no tenant id", + params: map[string]string{"database": "happy2koala.defaultdb"}, + expectedError: "invalid cluster name 'happy2koala'", + }, + { + name: "missing tenant id", + params: map[string]string{"database": "happy2koala-.defaultdb"}, + expectedError: "invalid cluster name 'happy2koala-'", + }, + { + name: "missing cluster name", + params: map[string]string{"database": "-7.defaultdb"}, + expectedError: "invalid cluster name '-7'", + }, + { + name: "bad tenant id", + params: map[string]string{"database": "happy-koala-0-7a.defaultdb"}, + expectedError: "invalid cluster name 'happy-koala-0-7a'", + }, + { + name: "zero tenant id", + params: map[string]string{"database": "happy-koala-0.defaultdb"}, + expectedError: "invalid cluster name 'happy-koala-0'", + }, + { + name: "cluster name in database param", + params: map[string]string{ + "database": "happy-koala-7.defaultdb", + "foo": "bar", + }, + expectedClusterName: "happy-koala", + expectedTenantID: 7, + expectedParams: map[string]string{"database": "defaultdb", "foo": "bar"}, + }, + { + name: "valid cluster name with invalid arrangements", + params: map[string]string{ + "database": "defaultdb", + "options": "-c --cluster=happy-koala-7 -c -c -c", + }, + expectedClusterName: "happy-koala", + expectedTenantID: 7, + expectedParams: map[string]string{"database": "defaultdb"}, + }, + { + name: "short option: cluster name in options param", + params: map[string]string{ + "database": "defaultdb", + "options": "-ccluster=happy-koala-7", + }, + expectedClusterName: "happy-koala", + expectedTenantID: 7, + expectedParams: map[string]string{"database": "defaultdb"}, + }, + { + name: "short option with spaces: cluster name in options param", + params: map[string]string{ + "database": "defaultdb", + "options": "-c cluster=happy-koala-7", + }, + expectedClusterName: "happy-koala", + expectedTenantID: 7, + expectedParams: map[string]string{"database": "defaultdb"}, + }, + { + name: "long option: cluster name in options param", + params: map[string]string{ + "database": "defaultdb", + "options": "--cluster=happy-koala-7\t--foo=test", + }, + expectedClusterName: "happy-koala", + expectedTenantID: 7, + expectedParams: map[string]string{"database": "defaultdb"}, + }, + { + name: "leading 0s are ok", + params: map[string]string{"database": "happy-koala-0-07.defaultdb"}, + expectedClusterName: "happy-koala-0", + expectedTenantID: 7, + expectedParams: map[string]string{"database": "defaultdb"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + msg := &pgproto3.StartupMessage{Parameters: tc.params} + + originalParams := make(map[string]string) + for k, v := range msg.Parameters { + originalParams[k] = v + } + + outMsg, clusterName, tenantID, err := clusterNameAndTenantFromParams(ctx, msg) + if tc.expectedError == "" { + require.NoErrorf(t, err, "failed test case\n%+v", tc) + + // When expectedError is specified, we always have a valid expectedTenantID. + require.Equal(t, roachpb.MakeTenantID(tc.expectedTenantID), tenantID) + + require.Equal(t, tc.expectedClusterName, clusterName) + require.Equal(t, tc.expectedParams, outMsg.Parameters) + } else { + require.EqualErrorf(t, err, tc.expectedError, "failed test case\n%+v", tc) + } + + // Check that the original parameters were not modified. + require.Equal(t, originalParams, msg.Parameters) + }) + } +} + type tester struct { // mu synchronizes the authenticated and errToClient fields, since they // need to be set on background goroutines, and will cause race builds to