Skip to content

Commit

Permalink
ccl/sqlproxyccl: add server name indication (SNI) support
Browse files Browse the repository at this point in the history
Previously the proxy supported two ways of providing tenant id and
cluster name information. One was through the database name. The second
was through the options parameter. As part of the serverless routing
changes, we want to support routing with a tenant id passed through SNI.
This PR adds the ability to route using SNI info.

Release justification: Low risk, high reward changes to existing functionality

Release note: None
  • Loading branch information
darinpp committed Apr 13, 2022
1 parent 92bf06a commit 2cc6ee5
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 12 deletions.
16 changes: 7 additions & 9 deletions pkg/ccl/sqlproxyccl/frontend_admitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ import (

// FrontendAdmitInfo contains the result of FrontendAdmit call.
type FrontendAdmitInfo struct {
conn net.Conn
msg *pgproto3.StartupMessage
err error
conn net.Conn
msg *pgproto3.StartupMessage
err error
sniServerName string
}

// FrontendAdmit is the default implementation of a frontend admitter. It can
Expand All @@ -34,7 +35,6 @@ var FrontendAdmit = func(
// `conn` could be replaced by `conn` embedded in a `tls.Conn` connection,
// hence it's important to close `conn` rather than `proxyConn` since closing
// the latter will not call `Close` method of `tls.Conn`.
var sniServerName string

// Read first message from client.
m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
Expand All @@ -52,6 +52,8 @@ var FrontendAdmit = func(
return &FrontendAdmitInfo{conn: conn}
}

var sniServerName string

// If we have an incoming TLS Config, require that the client initiates with
// an SSLRequest message.
if incomingTLSConfig != nil {
Expand Down Expand Up @@ -84,10 +86,6 @@ var FrontendAdmit = func(
}

if startup, ok := m.(*pgproto3.StartupMessage); ok {
// Add the sniServerName (if used) as parameter
if sniServerName != "" {
startup.Parameters["sni-server"] = sniServerName
}
// This forwards the remote addr to the backend.
startup.Parameters[remoteAddrStartupParam] = conn.RemoteAddr().String()
// The client is blocked from using session revival tokens; only the proxy
Expand All @@ -102,7 +100,7 @@ var FrontendAdmit = func(
),
}
}
return &FrontendAdmitInfo{conn: conn, msg: startup}
return &FrontendAdmitInfo{conn: conn, msg: startup, sniServerName: sniServerName}
}

code := codeUnexpectedStartupMessage
Expand Down
2 changes: 2 additions & 0 deletions pkg/ccl/sqlproxyccl/frontend_admitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func TestFrontendAdmitWithClientSSLRequire(t *testing.T) {

go func() {
cfg, err := pgconn.ParseConfig("postgres://localhost?sslmode=require")
cfg.TLSConfig.ServerName = "test"
require.NoError(t, err)
require.NotNil(t, cfg)
cfg.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
Expand All @@ -96,6 +97,7 @@ func TestFrontendAdmitWithClientSSLRequire(t *testing.T) {
require.NotEqual(t, srv, fe.conn) // The connection was replaced by SSL
require.NotNil(t, fe.msg)
require.Contains(t, fe.msg.Parameters, remoteAddrStartupParam)
require.Equal(t, fe.sniServerName, "test")
}

// TestFrontendAdmitRequireEncryption sends StartupRequest when SSlRequest is
Expand Down
52 changes: 51 additions & 1 deletion pkg/ccl/sqlproxyccl/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,10 @@ func (handler *proxyHandler) setupIncomingCert(ctx context.Context) error {
// the connection parameters, and rewrites the database and options parameters,
// if necessary.
//
// We currently support embedding the cluster identifier in two ways:
// We currently support embedding the cluster identifier in three ways:
//
// - Through server name identification (SNI) when using TLS connections
// (e.g. serverless-101.5xj.gcp-us-central1.cockroachlabs.cloud)
//
// - Within the database param (e.g. "happy-koala-3.defaultdb")
//
Expand All @@ -528,8 +531,13 @@ func clusterNameAndTenantFromParams(
return fe.msg, "", roachpb.MaxTenantID, err
}

sniTenID, sniPresent := parseSNI(fe.sniServerName)

// No cluster identifiers were specified.
if clusterIdentifierDB == "" && clusterIdentifierOpt == "" {
if sniPresent {
return fe.msg, "", sniTenID, nil
}
err := errors.New("missing cluster identifier")
err = errors.WithHint(err, clusterIdentifierHint)
return fe.msg, "", roachpb.MaxTenantID, err
Expand Down Expand Up @@ -605,13 +613,52 @@ func clusterNameAndTenantFromParams(
}
}

// Cluster ID provided through one of options or database (or both and the
// info is matching). If sni has been provided as well - check for match.
if sniPresent && tenID != sniTenID.InternalValue {
err := errors.New("multiple different tenant IDs provided")
err = errors.WithHintf(err,
"Is '%d' (SNI) or '%d' (database/options) the identifier for the cluster that you're connecting to?",
sniTenID.InternalValue, tenID)
err = errors.WithHint(err, clusterIdentifierHint)
return fe.msg, "", roachpb.MaxTenantID, err
}

outMsg := &pgproto3.StartupMessage{
ProtocolVersion: fe.msg.ProtocolVersion,
Parameters: paramsOut,
}
return outMsg, clusterName, roachpb.MakeTenantID(tenID), nil
}

// parseSNI parses the sni server name parameter if provided and returns the
// extracted tenant id. If the extraction was successful the second parameter
// will be true. If not - false.
func parseSNI(sniServerName string) (roachpb.TenantID, bool) {
if sniServerName == "" {
return roachpb.MaxTenantID, false
}

// Try to obtain tenant ID from SNI
parts := strings.Split(sniServerName, ".")
if len(parts) == 0 {
return roachpb.MaxTenantID, false
}

hostname := parts[0]
hostnameParts := strings.Split(hostname, "-")
if len(hostnameParts) != 2 || !strings.EqualFold("serverless", hostnameParts[0]) {
return roachpb.MaxTenantID, false
}

tenID, err := strconv.ParseUint(hostnameParts[1], 10, 64)
if err != nil || tenID < roachpb.MinTenantID.ToUint64() {
return roachpb.MaxTenantID, false
}

return roachpb.MakeTenantID(tenID), true
}

// parseDatabaseParam parses the database parameter from the PG connection
// string, and tries to extract the cluster identifier if present. The cluster
// identifier should be embedded in the database parameter using the dot (".")
Expand Down Expand Up @@ -698,6 +745,9 @@ following methods:
Use "--cluster=<cluster identifier>" as the options parameter.
(e.g. options="--cluster=active-roach-42")
3) Host name:
Use secure connection to the host name assigned to your cluster.
For more details, please visit our docs site at:
https://www.cockroachlabs.com/docs/cockroachcloud/connect-to-a-serverless-cluster
`
Expand Down
46 changes: 44 additions & 2 deletions pkg/ccl/sqlproxyccl/proxy_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,20 +237,52 @@ func TestProxyAgainstSecureCRDB(t *testing.T) {
s, addr := newSecureProxyServer(
ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true},
)
_, port, err := net.SplitHostPort(addr)
require.NoError(t, err)

url := fmt.Sprintf("postgres://bob:wrong@%s/tenant-cluster-28.defaultdb?sslmode=require", addr)
te.TestConnectErr(ctx, t, url, 0, "failed SASL auth")

url = fmt.Sprintf("postgres://bob@%s/tenant-cluster-28.defaultdb?sslmode=require", addr)
te.TestConnectErr(ctx, t, url, 0, "failed SASL auth")

url = fmt.Sprintf("postgres://bob:[email protected]:%s/defaultdb?sslmode=require", port)
te.TestConnectErr(ctx, t, url, codeParamsRoutingFailed, "server error")

url = fmt.Sprintf("postgres://bob:[email protected]:%s/defaultdb?sslmode=require", port)
te.TestConnectErr(ctx, t, url, codeParamsRoutingFailed, "server error")

url = fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-28.defaultdb?sslmode=require", addr)
te.TestConnect(ctx, t, url, func(conn *pgx.Conn) {
require.Equal(t, int64(1), s.metrics.CurConnCount.Value())
require.NoError(t, runTestQuery(ctx, conn))
})
require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count())

// SNI provides tenant ID.
url = fmt.Sprintf("postgres://bob:[email protected]:%s/defaultdb?sslmode=require", port)
te.TestConnect(ctx, t, url, func(conn *pgx.Conn) {
require.Equal(t, int64(1), s.metrics.CurConnCount.Value())
require.NoError(t, runTestQuery(ctx, conn))
})

// SNI and database provide tenant IDs that match.
url = fmt.Sprintf(
"postgres://bob:[email protected]:%s/tenant-cluster-28.defaultdb?sslmode=require", port,
)
te.TestConnect(ctx, t, url, func(conn *pgx.Conn) {
require.Equal(t, int64(1), s.metrics.CurConnCount.Value())
require.NoError(t, runTestQuery(ctx, conn))
})

// SNI and database provide tenant IDs that don't match.
url = fmt.Sprintf(
"postgres://bob:[email protected]:%s/tenant-cluster-29.defaultdb?sslmode=require", port,
)
te.TestConnectErr(ctx, t, url, codeParamsRoutingFailed, "server error")

require.Equal(t, int64(3), s.metrics.SuccessfulConnCount.Count())
require.Equal(t, int64(2), s.metrics.AuthFailedCount.Count())
require.Equal(t, int64(3), s.metrics.RoutingErrCount.Count())
}

func TestProxyTLSConf(t *testing.T) {
Expand Down Expand Up @@ -1466,7 +1498,13 @@ func (te *tester) TestConnect(ctx context.Context, t *testing.T, url string, fn
t.Helper()
te.setAuthenticated(false)
te.setErrToClient(nil)
conn, err := pgx.Connect(ctx, url)
connConfig, err := pgx.ParseConfig(url)
require.NoError(t, err)
if !strings.EqualFold(connConfig.Host, "127.0.0.1") {
connConfig.TLSConfig.ServerName = connConfig.Host
connConfig.Host = "127.0.0.1"
}
conn, err := pgx.ConnectConfig(ctx, connConfig)
require.NoError(t, err)
fn(conn)
require.NoError(t, conn.Close(ctx))
Expand All @@ -1487,6 +1525,10 @@ func (te *tester) TestConnectErr(

// Prevent pgx from tying to connect to the `::1` ipv6 address for localhost.
cfg.LookupFunc = func(ctx context.Context, s string) ([]string, error) { return []string{s}, nil }
if !strings.EqualFold(cfg.Host, "127.0.0.1") && cfg.TLSConfig != nil {
cfg.TLSConfig.ServerName = cfg.Host
cfg.Host = "127.0.0.1"
}
conn, err := pgx.ConnectConfig(ctx, cfg)
if err == nil {
_ = conn.Close(ctx)
Expand Down

0 comments on commit 2cc6ee5

Please sign in to comment.