Skip to content

Commit

Permalink
sqlproxyccl: minor fixes around the proxy handler
Browse files Browse the repository at this point in the history
In cockroachdb#65164, we migrated the sqlproxy in the CC code to the DB repository, and
there were a few buglets:
- sqlproxy crashes when the tenant ID supplied in the connection string is 0
  because roachpb.MakeTenantID panics when the tenant ID is 0.
- sqlproxy leaks internal parsing errors to the client.

This patch hides internal parsing errors, and replaces them with friendly
user-facing errors (e.g. "invalid cluster name"). We also add a bounds check
to the parsed tenant ID so that the process does not crash on an invalid
tenant ID. More tests were added as well.

Release note: None
  • Loading branch information
jaylim-crl committed Jun 15, 2021
1 parent 7b778b3 commit 8cac2ea
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 48 deletions.
7 changes: 7 additions & 0 deletions pkg/ccl/sqlproxyccl/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ var sendErrToClient = func(conn net.Conn, err error) {
Code: pgCode,
Message: msg,
}).Encode(nil))
} else {
// Return a generic "internal server error" message.
_, _ = conn.Write((&pgproto3.ErrorResponse{
Severity: "FATAL",
Code: "08004", // rejected connection
Message: "internal server error",
}).Encode(nil))
}
}

Expand Down
33 changes: 21 additions & 12 deletions pkg/ccl/sqlproxyccl/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,12 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
return nil
}

// Note that the errors returned from this function are user-facing errors so
// we should be careful with the details that we want to expose.
backendStartupMsg, clusterName, tenID, err := clusterNameAndTenantFromParams(msg)
// NOTE: Errors returned from this function are user-facing errors so we
// should be careful with the details that we want to expose.
backendStartupMsg, clusterName, tenID, err := clusterNameAndTenantFromParams(ctx, msg)
if err != nil {
clientErr := &codeError{codeParamsRoutingFailed, err}
log.Errorf(ctx, "unable to extract cluster name and tenant id: %s", clientErr.Error())
log.Errorf(ctx, "unable to extract cluster name and tenant id: %s", err.Error())
updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics)
return clientErr
}
Expand All @@ -212,8 +212,8 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn

ipAddr, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
clientErr := &codeError{codeParamsRoutingFailed, err}
log.Errorf(ctx, "could not parse address: %v", clientErr.Error())
clientErr := newErrorf(codeParamsRoutingFailed, "unexpected connection address")
log.Errorf(ctx, "could not parse address: %v", err.Error())
updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics)
return clientErr
}
Expand Down Expand Up @@ -438,7 +438,7 @@ func (handler *proxyHandler) validateAccessAndThrottle(
// Hence we need to rate limit.
if err := handler.throttleService.LoginCheck(ipAddr, timeutil.Now()); err != nil {
log.Errorf(ctx, "throttler refused connection: %v", err.Error())
return newErrorf(codeProxyRefusedConnection, "Connection attempt throttled")
return newErrorf(codeProxyRefusedConnection, "connection attempt throttled")
}
}

Expand Down Expand Up @@ -541,7 +541,7 @@ var reportFailureToDirectory = func(
// through its command-line options, i.e. "-c NAME=VALUE", "-cNAME=VALUE", and
// "--NAME=VALUE".
func clusterNameAndTenantFromParams(
msg *pgproto3.StartupMessage,
ctx context.Context, msg *pgproto3.StartupMessage,
) (*pgproto3.StartupMessage, string, roachpb.TenantID, error) {
clusterNameFromDB, databaseName, err := parseDatabaseParam(msg.Parameters["database"])
if err != nil {
Expand All @@ -566,19 +566,28 @@ func clusterNameAndTenantFromParams(
}

sepIdx := strings.LastIndex(clusterNameFromDB, clusterTenantSep)

// Cluster name provided without a tenant ID in the end.
if sepIdx == -1 || sepIdx == len(clusterNameFromDB)-1 {
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name %s", clusterNameFromDB)
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB)
}
clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:]

clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:]
if !clusterNameRegex.MatchString(clusterNameSansTenant) {
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB)
}

tenID, err := strconv.ParseUint(tenantIDStr, 10, 64)
if err != nil {
return msg, "", roachpb.MaxTenantID, errors.Wrapf(err, "cannot parse %s as uint64", tenantIDStr)
// Log these non user-facing errors.
log.Errorf(ctx, "cannot parse tenant ID in %s: %v", clusterNameFromDB, err)
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB)
}

if tenID < roachpb.MinTenantID.ToUint64() {
// Log these non user-facing errors.
log.Errorf(ctx, "%s contains an invalid tenant ID", clusterNameFromDB)
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB)
}

// Make and return a copy of the startup msg so the original is not modified.
Expand Down Expand Up @@ -609,7 +618,7 @@ func parseDatabaseParam(databaseParam string) (clusterName, databaseName string,
return "", "", nil
}

parts := strings.SplitN(databaseParam, ".", 2)
parts := strings.Split(databaseParam, ".")

// Database param provided without cluster name.
if len(parts) <= 1 {
Expand Down
Loading

0 comments on commit 8cac2ea

Please sign in to comment.