Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sqlproxyccl: minor fixes and enhancements to the proxy handler and denylist #66412

Merged
merged 2 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/ccl/sqlproxyccl/denylist/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ func (dl *Denylist) Denied(entity DenyEntity) (*Entry, error) {
dl.mu.RLock()
defer dl.mu.RUnlock()

if ent, ok := dl.mu.entries[entity]; ok && !ent.Expiration.Before(dl.timeSource.Now()) {
if ent, ok := dl.mu.entries[entity]; ok &&
(ent.Expiration.IsZero() || !ent.Expiration.Before(dl.timeSource.Now())) {
return &Entry{ent.Reason}, nil
}
return nil, nil
Expand Down
146 changes: 81 additions & 65 deletions pkg/ccl/sqlproxyccl/denylist/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ func TestDenyListFileParsing(t *testing.T) {
cases := []struct {
t Type
expected string
}{{
IPAddrType, "ip",
}, {
ClusterType, "cluster",
}}
}{
{IPAddrType, "ip"},
{ClusterType, "cluster"},
}
for _, tc := range cases {
s, err := tc.t.MarshalYAML()
require.NoError(t, err)
Expand All @@ -44,26 +43,14 @@ func TestDenyListFileParsing(t *testing.T) {
cases := []struct {
raw string
expected Type
}{{
"ip", IPAddrType,
}, {
"IP", IPAddrType,
},
{
"Ip", IPAddrType,
},
{
"Cluster", ClusterType,
},
{
"cluster", ClusterType,
},
{
"CLUSTER", ClusterType,
},
{
"random text", UnknownType,
},
}{
{"ip", IPAddrType},
{"IP", IPAddrType},
{"Ip", IPAddrType},
{"Cluster", ClusterType},
{"cluster", ClusterType},
{"CLUSTER", ClusterType},
{"random text", UnknownType},
}
for _, tc := range cases {
var parsed Type
Expand All @@ -85,31 +72,34 @@ func TestDenyListFileParsing(t *testing.T) {
expected map[DenyEntity]*DenyEntry
}{
{"text: ", emptyMap},
{"random text\n\n\nmore random text",
emptyMap},
{"random text\n\n\nmore random text", emptyMap},
{defaultEmptyDenylistText, emptyMap},
{"SequenceNumber: 7", emptyMap},
{
// old denylist format, making sure it won't break new denylist code
// Old denylist format; making sure it won't break new denylist code.
`
SequenceNumber: 8
1.1.1.1: some reason
61: another reason`,
SequenceNumber: 8
1.1.1.1: some reason
61: another reason`,
emptyMap,
}, {
},
{
fmt.Sprintf(`
SequenceNumber: 9
denylist:
- entity: {"item":"1.2.3.4", "type": "ip"}
expiration: %s
reason: over quota
`, expirationTimeString),
map[DenyEntity]*DenyEntry{{"1.2.3.4", IPAddrType}: {
DenyEntity{"1.2.3.4", IPAddrType},
expirationTime,
"over quota",
SequenceNumber: 9
denylist:
- entity: {"item":"1.2.3.4", "type": "ip"}
expiration: %s
reason: over quota`,
expirationTimeString,
),
map[DenyEntity]*DenyEntry{
{"1.2.3.4", IPAddrType}: {
DenyEntity{"1.2.3.4", IPAddrType},
expirationTime,
"over quota",
},
},
}},
},
}

// use cancel to prevent leaked goroutines from file watches
Expand Down Expand Up @@ -162,57 +152,83 @@ func TestDenylistLogic(t *testing.T) {
outcome *Entry
}

// This is a time evolution of a denylist
// This is a time evolution of a denylist.
testCases := []struct {
input string
time time.Time
specs []denyIOSpec
}{
// Blocks IP address only.
{
fmt.Sprintf(`
SequenceNumber: 9
denylist:
- entity: {"item": "1.2.3.4", "type": "IP"}
expiration: %s
reason: over quota`, expirationTimeString),
SequenceNumber: 9
denylist:
- entity: {"item": "1.2.3.4", "type": "IP"}
expiration: %s
reason: over quota`,
expirationTimeString,
),
startTime.Add(10 * time.Second),
[]denyIOSpec{
{DenyEntity{"1.2.3.4", IPAddrType}, &Entry{"over quota"}},
{DenyEntity{"61", ClusterType}, nil},
{DenyEntity{"1.2.3.5", IPAddrType}, nil},
},
},
// Blocks both IP address and tenant cluster.
{
fmt.Sprintf(`
SequenceNumber: 10
denylist:
- entity: {"item": "1.2.3.4", "type": "IP"}
expiration: %s
reason: over quota
- entity: {"item": 61, "type": "Cluster"}
expiration: %s
reason: splunk pipeline`, expirationTimeString, expirationTimeString),
SequenceNumber: 10
denylist:
- entity: {"item": "1.2.3.4", "type": "IP"}
expiration: %s
reason: over quota
- entity: {"item": 61, "type": "Cluster"}
expiration: %s
reason: splunk pipeline`,
expirationTimeString,
expirationTimeString,
),
startTime.Add(20 * time.Second),
[]denyIOSpec{
{DenyEntity{"1.2.3.4", IPAddrType}, &Entry{"over quota"}},
{DenyEntity{"61", ClusterType}, &Entry{"splunk pipeline"}},
{DenyEntity{"1.2.3.5", IPAddrType}, nil},
}},
},
},
// Entry that has expired.
{
fmt.Sprintf(`
SequenceNumber: 11
denylist:
- entity: {"item": "1.2.3.4", "type": "ip"}
expiration: %s
reason: over quota`, expirationTimeString),
SequenceNumber: 11
denylist:
- entity: {"item": "1.2.3.4", "type": "ip"}
expiration: %s
reason: over quota`,
expirationTimeString,
),
futureTime,
[]denyIOSpec{
{DenyEntity{"1.2.3.4", IPAddrType}, nil},
{DenyEntity{"61", ClusterType}, nil},
{DenyEntity{"1.2.3.5", IPAddrType}, nil},
}},
},
},
// Entry without any expiration.
{
`
SequenceNumber: 11
denylist:
- entity: {"item": "1.2.3.4", "type": "ip"}
reason: over quota`,
futureTime,
[]denyIOSpec{
{DenyEntity{"1.2.3.4", IPAddrType}, &Entry{"over quota"}},
{DenyEntity{"61", ClusterType}, nil},
{DenyEntity{"1.2.3.5", IPAddrType}, nil},
},
},
}
// use cancel to prevent leaked goroutines from file watches
// Use cancel to prevent leaked goroutines from file watches.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tempDir := t.TempDir()
Expand Down
7 changes: 7 additions & 0 deletions pkg/ccl/sqlproxyccl/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ var sendErrToClient = func(conn net.Conn, err error) {
Code: pgCode,
Message: msg,
}).Encode(nil))
} else {
// Return a generic "internal server error" message.
_, _ = conn.Write((&pgproto3.ErrorResponse{
Severity: "FATAL",
Code: "08004", // rejected connection
Message: "internal server error",
}).Encode(nil))
}
}

Expand Down
33 changes: 21 additions & 12 deletions pkg/ccl/sqlproxyccl/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,12 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
return nil
}

// Note that the errors returned from this function are user-facing errors so
// we should be careful with the details that we want to expose.
backendStartupMsg, clusterName, tenID, err := clusterNameAndTenantFromParams(msg)
// NOTE: Errors returned from this function are user-facing errors so we
// should be careful with the details that we want to expose.
backendStartupMsg, clusterName, tenID, err := clusterNameAndTenantFromParams(ctx, msg)
if err != nil {
clientErr := &codeError{codeParamsRoutingFailed, err}
log.Errorf(ctx, "unable to extract cluster name and tenant id: %s", clientErr.Error())
log.Errorf(ctx, "unable to extract cluster name and tenant id: %s", err.Error())
updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics)
return clientErr
}
Expand All @@ -212,8 +212,8 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn

ipAddr, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
clientErr := &codeError{codeParamsRoutingFailed, err}
log.Errorf(ctx, "could not parse address: %v", clientErr.Error())
clientErr := newErrorf(codeParamsRoutingFailed, "unexpected connection address")
log.Errorf(ctx, "could not parse address: %v", err.Error())
updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics)
return clientErr
}
Expand Down Expand Up @@ -438,7 +438,7 @@ func (handler *proxyHandler) validateAccessAndThrottle(
// Hence we need to rate limit.
if err := handler.throttleService.LoginCheck(ipAddr, timeutil.Now()); err != nil {
log.Errorf(ctx, "throttler refused connection: %v", err.Error())
return newErrorf(codeProxyRefusedConnection, "Connection attempt throttled")
return newErrorf(codeProxyRefusedConnection, "connection attempt throttled")
}
}

Expand Down Expand Up @@ -541,7 +541,7 @@ var reportFailureToDirectory = func(
// through its command-line options, i.e. "-c NAME=VALUE", "-cNAME=VALUE", and
// "--NAME=VALUE".
func clusterNameAndTenantFromParams(
msg *pgproto3.StartupMessage,
ctx context.Context, msg *pgproto3.StartupMessage,
) (*pgproto3.StartupMessage, string, roachpb.TenantID, error) {
clusterNameFromDB, databaseName, err := parseDatabaseParam(msg.Parameters["database"])
if err != nil {
Expand All @@ -566,19 +566,28 @@ func clusterNameAndTenantFromParams(
}

sepIdx := strings.LastIndex(clusterNameFromDB, clusterTenantSep)

// Cluster name provided without a tenant ID in the end.
if sepIdx == -1 || sepIdx == len(clusterNameFromDB)-1 {
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name %s", clusterNameFromDB)
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB)
}
clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:]

clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:]
if !clusterNameRegex.MatchString(clusterNameSansTenant) {
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB)
}

tenID, err := strconv.ParseUint(tenantIDStr, 10, 64)
if err != nil {
return msg, "", roachpb.MaxTenantID, errors.Wrapf(err, "cannot parse %s as uint64", tenantIDStr)
// Log these non user-facing errors.
log.Errorf(ctx, "cannot parse tenant ID in %s: %v", clusterNameFromDB, err)
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB)
}

if tenID < roachpb.MinTenantID.ToUint64() {
// Log these non user-facing errors.
log.Errorf(ctx, "%s contains an invalid tenant ID", clusterNameFromDB)
return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB)
}

// Make and return a copy of the startup msg so the original is not modified.
Expand Down Expand Up @@ -609,7 +618,7 @@ func parseDatabaseParam(databaseParam string) (clusterName, databaseName string,
return "", "", nil
}

parts := strings.SplitN(databaseParam, ".", 2)
parts := strings.Split(databaseParam, ".")

// Database param provided without cluster name.
if len(parts) <= 1 {
Expand Down
Loading