Skip to content

Commit

Permalink
Merge pull request panoplyio#41 from panoplyio/revert-39-revert-34-ex…
Browse files Browse the repository at this point in the history
…tended-query-messages

Revert "Revert "Extended query message handling""
  • Loading branch information
aviaoh authored Jun 20, 2019
2 parents af4e1d6 + 0b7ecaf commit 8aa3249
Show file tree
Hide file tree
Showing 19 changed files with 1,104 additions and 306 deletions.
35 changes: 14 additions & 21 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/md5"
"crypto/rand"
"fmt"
"github.com/panoplyio/pgsrv/protocol"
)

const errExpectedPassword = "expected password response, got message type %q"
Expand All @@ -13,31 +14,23 @@ const errWrongPassword = "password does not match for user \"%s\""
// authenticator interface defines objects able to perform user authentication
// that happens at the very beginning of every session.
type authenticator interface {
// authenticate accepts a msgReadWriter instance and a map of args that describe
// authenticate accepts a protocol.MessageReadWriter instance and a map of args that describe
// the current session. It returns no error if the authentication succeeds,
// or an error if something fails.
//
// Authentication errors as well as welcome messages are sent by this function,
// so there is no need for the caller to send these. It is caller's responsibility
// though to terminate the session in case that an error is returned.
authenticate(rw msgReadWriter, args map[string]interface{}) error
authenticate(rw protocol.MessageReadWriter, args map[string]interface{}) error
}

// noPasswordAuthenticator responds with auth OK immediately.
type noPasswordAuthenticator struct{}

func (np *noPasswordAuthenticator) authenticate(rw msgReadWriter, args map[string]interface{}) error {
func (np *noPasswordAuthenticator) authenticate(rw protocol.MessageReadWriter, args map[string]interface{}) error {
return rw.Write(authOKMsg())
}

// messageReadWriter describes objects that handle client-server communication.
// Objects implementing this interface are used to send password requests to users,
// and receive their responses.
type msgReadWriter interface {
Write(m msg) error
Read() (msg, error)
}

// AuthType represents various types of authentication
type AuthType string

Expand Down Expand Up @@ -97,9 +90,9 @@ type clearTextAuthenticator struct {
pp PasswordProvider
}

func (a *clearTextAuthenticator) authenticate(rw msgReadWriter, args map[string]interface{}) error {
func (a *clearTextAuthenticator) authenticate(rw protocol.MessageReadWriter, args map[string]interface{}) error {
// AuthenticationClearText
passwordRequest := msg{
passwordRequest := protocol.Message{
'R',
0, 0, 0, 8, // length
0, 0, 0, 3, // clear text auth type
Expand All @@ -118,7 +111,7 @@ func (a *clearTextAuthenticator) authenticate(rw msgReadWriter, args map[string]
if m.Type() != 'p' {
err = fmt.Errorf(errExpectedPassword, m.Type())
err = WithSeverity(fromErr(err), fatalSeverity)
rw.Write(errMsg(err))
rw.Write(protocol.ErrorResponse(err))
return err
}

Expand All @@ -129,7 +122,7 @@ func (a *clearTextAuthenticator) authenticate(rw msgReadWriter, args map[string]
if !bytes.Equal(expectedPassword, actualPassword) {
err = fmt.Errorf(errWrongPassword, user)
err = WithSeverity(fromErr(err), fatalSeverity)
rw.Write(errMsg(err))
rw.Write(protocol.ErrorResponse(err))
return err
}

Expand All @@ -143,9 +136,9 @@ type md5Authenticator struct {
pp PasswordProvider
}

func (a *md5Authenticator) authenticate(rw msgReadWriter, args map[string]interface{}) error {
func (a *md5Authenticator) authenticate(rw protocol.MessageReadWriter, args map[string]interface{}) error {
// AuthenticationMD5Password
passwordRequest := msg{
passwordRequest := protocol.Message{
'R',
0, 0, 0, 12, // length
0, 0, 0, 5, // md5 auth type
Expand All @@ -166,7 +159,7 @@ func (a *md5Authenticator) authenticate(rw msgReadWriter, args map[string]interf
if m.Type() != 'p' {
err = fmt.Errorf(errExpectedPassword, m.Type())
err = WithSeverity(fromErr(err), fatalSeverity)
rw.Write(errMsg(err))
rw.Write(protocol.ErrorResponse(err))
return err
}

Expand All @@ -179,15 +172,15 @@ func (a *md5Authenticator) authenticate(rw msgReadWriter, args map[string]interf
if !bytes.Equal(expectedHash, actualHash) {
err = fmt.Errorf(errWrongPassword, user)
err = WithSeverity(fromErr(err), fatalSeverity)
rw.Write(errMsg(err))
rw.Write(protocol.ErrorResponse(err))
return err
}

return rw.Write(authOKMsg())
}

// authOKMsg returns a message that indicates that the client is now authenticated.
func authOKMsg() msg {
func authOKMsg() protocol.Message {
return []byte{'R', 0, 0, 0, 8, 0, 0, 0, 0}
}

Expand All @@ -200,7 +193,7 @@ func getRandomSalt() []byte {

// extractPassword extracts the password from a provided 'p' message.
// It assumes that the message is valid.
func extractPassword(m msg) []byte {
func extractPassword(m protocol.Message) []byte {
// password starts after the size (4 bytes) and lasts until null-terminator
return m[5 : len(m)-1]
}
Expand Down
45 changes: 23 additions & 22 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package pgsrv
import (
"bytes"
"crypto/md5"
"github.com/panoplyio/pgsrv/protocol"
"github.com/stretchr/testify/require"
"testing"
)

var authOKMessage = msg{'R', 0, 0, 0, 8, 0, 0, 0, 0}
var authOKMessage = protocol.Message{'R', 0, 0, 0, 8, 0, 0, 0, 0}
var fatalMarker = []byte{
83, 70, 65, 84, 65, 76,
}
Expand All @@ -20,7 +21,7 @@ func TestAuthOKMsg(t *testing.T) {
}

func TestNoPassword_authenticate(t *testing.T) {
rw := &mockMessageReadWriter{output: []msg{}}
rw := &mockMessageReadWriter{output: []protocol.Message{}}
args := map[string]interface{}{
"user": "this-is-user",
}
Expand All @@ -29,21 +30,21 @@ func TestNoPassword_authenticate(t *testing.T) {
err := np.authenticate(rw, args)

require.NoError(t, err)
require.Equal(t, []msg{authOKMessage}, rw.messages)
require.Equal(t, []protocol.Message{authOKMessage}, rw.messages)
}

func TestAuthenticationClearText_authenticate(t *testing.T) {
passwordRequest := msg{
passwordRequest := protocol.Message{
'R',
0, 0, 0, 8, // length
0, 0, 0, 3, // clear text auth type
}
passwordMessage := msg{
passwordMessage := protocol.Message{
'p',
0, 0, 0, 8,
109, 101, 104, 0, // 'meh'
}
rw := &mockMessageReadWriter{output: []msg{passwordMessage}}
rw := &mockMessageReadWriter{output: []protocol.Message{passwordMessage}}
args := map[string]interface{}{
"user": "this-is-user",
}
Expand All @@ -56,7 +57,7 @@ func TestAuthenticationClearText_authenticate(t *testing.T) {
err := a.authenticate(rw, args)

require.NoError(t, err)
expectedMessages := []msg{
expectedMessages := []protocol.Message{
passwordRequest,
authOKMessage,
}
Expand All @@ -75,7 +76,7 @@ func TestAuthenticationClearText_authenticate(t *testing.T) {

t.Run("invalid message type", func(t *testing.T) {
defer rw.Reset()
rw = &mockMessageReadWriter{output: []msg{
rw = &mockMessageReadWriter{output: []protocol.Message{
{'q', 0, 0, 0, 5, 1},
}}
err := a.authenticate(rw, args)
Expand All @@ -87,7 +88,7 @@ func TestAuthenticationClearText_authenticate(t *testing.T) {
}

func TestAuthenticationMD5_authenticate(t *testing.T) {
passwordRequest := msg{
passwordRequest := protocol.Message{
'R',
0, 0, 0, 12, // length
0, 0, 0, 5, // md5 auth type
Expand Down Expand Up @@ -125,7 +126,7 @@ func TestAuthenticationMD5_authenticate(t *testing.T) {

t.Run("invalid message type", func(t *testing.T) {
defer rw.Reset()
rw := &mockMessageReadWriter{output: []msg{
rw := &mockMessageReadWriter{output: []protocol.Message{
{'q', 0, 0, 0, 5, 1},
}}
err := a.authenticate(rw, args)
Expand Down Expand Up @@ -166,7 +167,7 @@ func TestGetRandomSalt(t *testing.T) {

func TestExtractPassword(t *testing.T) {
t.Run("regular password", func(t *testing.T) {
passwordMessage := msg{
passwordMessage := protocol.Message{
'p',
0, 0, 0, 9,
42, 42, 42, 42,
Expand All @@ -179,7 +180,7 @@ func TestExtractPassword(t *testing.T) {
})

t.Run("empty password", func(t *testing.T) {
passwordMessage := msg{
passwordMessage := protocol.Message{
'p',
0, 0, 0, 5,
0,
Expand All @@ -194,22 +195,22 @@ func TestExtractPassword(t *testing.T) {
// mockMessageReadWriter implements messageReadWriter and outputs the provided output
// message by message, looped.
type mockMessageReadWriter struct {
output []msg
output []protocol.Message
currentOutput int
messages []msg
messages []protocol.Message
}

func (rw *mockMessageReadWriter) Read() (msg, error) {
func (rw *mockMessageReadWriter) Read() (protocol.Message, error) {
return rw.output[rw.currentOutput%len(rw.output)], nil
}

func (rw *mockMessageReadWriter) Write(m msg) error {
func (rw *mockMessageReadWriter) Write(m protocol.Message) error {
rw.messages = append(rw.messages, m)
return nil
}

func (rw *mockMessageReadWriter) Reset() {
rw.messages = make([]msg, 0)
rw.messages = make([]protocol.Message, 0)
}

// mockMD5MessageReadWriter implements messageReadWriter and outputs password
Expand All @@ -218,11 +219,11 @@ type mockMD5MessageReadWriter struct {
user string
pass []byte
salt []byte
messages []msg
messages []protocol.Message
}

func (rw *mockMD5MessageReadWriter) Read() (msg, error) {
message := msg{
func (rw *mockMD5MessageReadWriter) Read() (protocol.Message, error) {
message := protocol.Message{
'p',
0, 0, 0, 25,
}
Expand All @@ -232,12 +233,12 @@ func (rw *mockMD5MessageReadWriter) Read() (msg, error) {
return message, nil
}

func (rw *mockMD5MessageReadWriter) Write(m msg) error {
func (rw *mockMD5MessageReadWriter) Write(m protocol.Message) error {
rw.salt = m[9:]
rw.messages = append(rw.messages, m)
return nil
}

func (rw *mockMD5MessageReadWriter) Reset() {
rw.messages = make([]msg, 0)
rw.messages = make([]protocol.Message, 0)
}
21 changes: 0 additions & 21 deletions msg.go

This file was deleted.

51 changes: 51 additions & 0 deletions protocol/extend.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package protocol

// ParseComplete is sent when backend parsed a prepared statement successfully
var ParseComplete = []byte{'1', 0, 0, 0, 4}

// BindComplete is sent when backend prepared a portal and finished planning the query
var BindComplete = []byte{'2', 0, 0, 0, 4}

// CreatesTransaction tells weather this is a frontend message that should start/continue a transaction
func (m *Message) CreatesTransaction() bool {
return m.Type() == Parse || m.Type() == Bind
}

// EndsTransaction tells weather this is a frontend message that should end the current transaction
func (m *Message) EndsTransaction() bool {
return m.Type() == Query || m.Type() == Sync
}

// transaction represents a sequence of frontend and backend messages
// that apply only on commit. the purpose of transaction is to support
// extended query flow.
type transaction struct {
transport *Transport
in []Message // TODO: asses if we need it after implementation of prepared statements and portals is done
out []Message // TODO: add size limit
}

// Read uses Transport to read the next message into the transaction's incoming messages buffer
func (t *transaction) Read() (msg Message, err error) {
if msg, err = t.transport.Read(); err == nil {
t.in = append(t.in, msg)
}
return
}

// Write writes the provided message into the transaction's outgoing messages buffer
func (t *transaction) Write(msg Message) error {
t.out = append(t.out, msg)
return nil
}

func (t *transaction) flush() (err error) {
for len(t.out) > 0 {
err = t.transport.write(t.out[0])
if err != nil {
break
}
t.out = t.out[1:]
}
return
}
41 changes: 41 additions & 0 deletions protocol/extend_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package protocol

import (
"bufio"
"bytes"
"github.com/stretchr/testify/require"
"testing"
)

func TestTransaction_Read(t *testing.T) {
buf := bytes.Buffer{}
comm := bufio.NewReadWriter(bufio.NewReader(&buf), bufio.NewWriter(&buf))
p := &Transport{W: comm, R: comm, initialized: true}
trans := &transaction{transport: p, in: []Message{}, out: []Message{}}

_, err := comm.Write([]byte{'P', 0, 0, 0, 4})
require.NoError(t, err)

err = comm.Flush()
require.NoError(t, err)

m, err := trans.Read()
require.NoError(t, err)
require.NotNil(t, m,
"expected to receive message from transaction. got nil")

require.Equalf(t, 1, len(trans.in),
"expected exactly 1 message in transaction incoming buffer. actual: %d", len(trans.in))

require.Equalf(t, byte('P'), trans.in[0].Type(),
"expected type of the only message in transaction incoming buffer to be 'P'. actual: %c", trans.in[0].Type())

require.Equalf(t, 0, len(trans.out),
"expected no message to exist in transaction's outgoing message buffer. actual buffer length: %d", len(trans.out))

err = trans.Write(CommandComplete(""))
require.NoError(t, err)

require.Equalf(t, 1, len(trans.out),
"expected exactly one message in transaction's outgoind message buffer. actual messages count: %d", len(trans.out))
}
Loading

0 comments on commit 8aa3249

Please sign in to comment.