-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
sqlproxy: complete connection handshake
Currently the proxy receives the startup message from the server but does not actually check if it is of type `pgproto3.ReadyForQuery` which is signifies that the connection has been established and the server can start serve queries. Because of this we cannot recognize successful connections. Release note: none.
- Loading branch information
Spas Bojanov
committed
Jan 14, 2021
1 parent
b22e903
commit 24890c0
Showing
8 changed files
with
306 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright 2021 The Cockroach Authors. | ||
// | ||
// Licensed as a CockroachDB Enterprise file under the Cockroach Community | ||
// License (the "License"); you may not use this file except in compliance with | ||
// the License. You may obtain a copy of the License at | ||
// | ||
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt | ||
|
||
package sqlproxyccl | ||
|
||
import ( | ||
"net" | ||
|
||
"github.com/jackc/pgproto3/v2" | ||
) | ||
|
||
func authenticate(clientConn, crdbConn net.Conn) error { | ||
fe := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn) | ||
be := pgproto3.NewFrontend(pgproto3.NewChunkReader(crdbConn), crdbConn) | ||
|
||
// The auth step should require only a few back and forths so 20 iterations | ||
// should be enough. | ||
var i int | ||
for ; i < 20; i++ { | ||
// Read the server response and forward it to the client. | ||
// TODO(spaskob): in verbose mode, log these messages. | ||
backendMsg, err := be.Receive() | ||
if err != nil { | ||
return NewErrorf(CodeBackendReadFailed, "unable to receive message from backend: %v", err) | ||
} | ||
|
||
err = fe.Send(backendMsg) | ||
if err != nil { | ||
return NewErrorf( | ||
CodeClientWriteFailed, "unable to send message %v to client: %v", backendMsg, err, | ||
) | ||
} | ||
|
||
// Decide what to do based on the type of the server response. | ||
switch tp := backendMsg.(type) { | ||
case *pgproto3.ReadyForQuery: | ||
// Server has authenticated the connection successfully and is ready to | ||
// serve queries. | ||
return nil | ||
case *pgproto3.AuthenticationOk: | ||
// Server has authenticated the connection; keep reading messages until | ||
// `pgproto3.ReadyForQuery` is encountered which signifies that server | ||
// is ready to serve queries. | ||
case *pgproto3.ParameterStatus: | ||
// Server sent status message; keep reading messages until | ||
// `pgproto3.ReadyForQuery` is encountered. | ||
case *pgproto3.ErrorResponse: | ||
// Server has rejected the authentication response from the client and | ||
// has closed the connection. | ||
return NewErrorf(CodeAuthFailed, "authentication failed: %v", backendMsg) | ||
case | ||
*pgproto3.AuthenticationCleartextPassword, | ||
*pgproto3.AuthenticationMD5Password, | ||
*pgproto3.AuthenticationSASL: | ||
// The backend is requesting the user to authenticate. | ||
// Read the client response and forward it to server. | ||
fntMsg, err := fe.Receive() | ||
if err != nil { | ||
return NewErrorf(CodeClientReadFailed, "unable to receive message from client: %v", err) | ||
} | ||
err = be.Send(fntMsg) | ||
if err != nil { | ||
return NewErrorf( | ||
CodeBackendWriteFailed, "unable to send message %v to backend: %v", fntMsg, err, | ||
) | ||
} | ||
default: | ||
return NewErrorf(CodeBackendDisconnected, "received unexpected backend message type: %v", tp) | ||
} | ||
} | ||
return NewErrorf(CodeBackendDisconnected, "authentication took more than %d iterations", i) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// Copyright 2021 The Cockroach Authors. | ||
// | ||
// Licensed as a CockroachDB Enterprise file under the Cockroach Community | ||
// License (the "License"); you may not use this file except in compliance with | ||
// the License. You may obtain a copy of the License at | ||
// | ||
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt | ||
|
||
package sqlproxyccl | ||
|
||
import ( | ||
"net" | ||
"testing" | ||
|
||
"github.com/cockroachdb/cockroach/pkg/util/leaktest" | ||
"github.com/cockroachdb/errors" | ||
"github.com/jackc/pgproto3/v2" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestAuthenticateOK(t *testing.T) { | ||
defer leaktest.AfterTest(t)() | ||
|
||
cli, srv := net.Pipe() | ||
be := pgproto3.NewBackend(pgproto3.NewChunkReader(srv), srv) | ||
fe := pgproto3.NewFrontend(pgproto3.NewChunkReader(cli), cli) | ||
|
||
go func() { | ||
err := be.Send(&pgproto3.ReadyForQuery{}) | ||
require.NoError(t, err) | ||
beMsg, err := fe.Receive() | ||
require.NoError(t, err) | ||
require.Equal(t, beMsg, &pgproto3.ReadyForQuery{}) | ||
}() | ||
|
||
require.NoError(t, authenticate(srv, cli)) | ||
} | ||
|
||
func TestAuthenticateClearText(t *testing.T) { | ||
defer leaktest.AfterTest(t)() | ||
|
||
cli, srv := net.Pipe() | ||
be := pgproto3.NewBackend(pgproto3.NewChunkReader(srv), srv) | ||
fe := pgproto3.NewFrontend(pgproto3.NewChunkReader(cli), cli) | ||
|
||
go func() { | ||
err := be.Send(&pgproto3.AuthenticationCleartextPassword{}) | ||
require.NoError(t, err) | ||
beMsg, err := fe.Receive() | ||
require.NoError(t, err) | ||
require.Equal(t, beMsg, &pgproto3.AuthenticationCleartextPassword{}) | ||
|
||
err = fe.Send(&pgproto3.PasswordMessage{Password: "password"}) | ||
require.NoError(t, err) | ||
feMsg, err := be.Receive() | ||
require.NoError(t, err) | ||
require.Equal(t, feMsg, &pgproto3.PasswordMessage{Password: "password"}) | ||
|
||
err = be.Send(&pgproto3.AuthenticationOk{}) | ||
require.NoError(t, err) | ||
beMsg, err = fe.Receive() | ||
require.NoError(t, err) | ||
require.Equal(t, beMsg, &pgproto3.AuthenticationOk{}) | ||
|
||
err = be.Send(&pgproto3.ParameterStatus{Name: "Server Version", Value: "1.3"}) | ||
require.NoError(t, err) | ||
beMsg, err = fe.Receive() | ||
require.NoError(t, err) | ||
require.Equal(t, beMsg, &pgproto3.ParameterStatus{Name: "Server Version", Value: "1.3"}) | ||
|
||
err = be.Send(&pgproto3.ReadyForQuery{}) | ||
require.NoError(t, err) | ||
beMsg, err = fe.Receive() | ||
require.NoError(t, err) | ||
require.Equal(t, beMsg, &pgproto3.ReadyForQuery{}) | ||
}() | ||
|
||
require.NoError(t, authenticate(srv, cli)) | ||
} | ||
|
||
func TestAuthenticateError(t *testing.T) { | ||
defer leaktest.AfterTest(t)() | ||
|
||
cli, srv := net.Pipe() | ||
be := pgproto3.NewBackend(pgproto3.NewChunkReader(srv), srv) | ||
fe := pgproto3.NewFrontend(pgproto3.NewChunkReader(cli), cli) | ||
|
||
go func() { | ||
err := be.Send(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "foo"}) | ||
require.NoError(t, err) | ||
beMsg, err := fe.Receive() | ||
require.NoError(t, err) | ||
require.Equal(t, beMsg, &pgproto3.ErrorResponse{Severity: "FATAL", Code: "foo"}) | ||
}() | ||
|
||
err := authenticate(srv, cli) | ||
require.Error(t, err) | ||
codeErr := (*CodeError)(nil) | ||
require.True(t, errors.As(err, &codeErr)) | ||
require.Equal(t, CodeAuthFailed, codeErr.code) | ||
} | ||
|
||
func TestAuthenticateUnexpectedMessage(t *testing.T) { | ||
defer leaktest.AfterTest(t)() | ||
|
||
cli, srv := net.Pipe() | ||
be := pgproto3.NewBackend(pgproto3.NewChunkReader(srv), srv) | ||
fe := pgproto3.NewFrontend(pgproto3.NewChunkReader(cli), cli) | ||
|
||
go func() { | ||
err := be.Send(&pgproto3.BackendKeyData{}) | ||
require.NoError(t, err) | ||
beMsg, err := fe.Receive() | ||
require.NoError(t, err) | ||
require.Equal(t, beMsg, &pgproto3.BackendKeyData{}) | ||
}() | ||
|
||
err := authenticate(srv, cli) | ||
require.Error(t, err) | ||
codeErr := (*CodeError)(nil) | ||
require.True(t, errors.As(err, &codeErr)) | ||
require.Equal(t, CodeBackendDisconnected, codeErr.code) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.