Skip to content

Commit

Permalink
rearranged code for robustness and clearness
Browse files Browse the repository at this point in the history
  • Loading branch information
avivklas committed Jun 23, 2019
1 parent 41825bc commit 1bde08f
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 78 deletions.
10 changes: 0 additions & 10 deletions protocol/extend.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,6 @@ var ParseComplete = []byte{'1', 0, 0, 0, 4}
// BindComplete is sent when backend prepared a portal and finished planning the query
var BindComplete = []byte{'2', 0, 0, 0, 4}

// CreatesTransaction tells weather this is a frontend message that should start/continue a transaction
func (m *Message) CreatesTransaction() bool {
return m.Type() == Parse || m.Type() == Bind
}

// EndsTransaction tells weather this is a frontend message that should end the current transaction
func (m *Message) EndsTransaction() bool {
return m.Type() == Query || m.Type() == Sync
}

// transaction represents a sequence of frontend and backend messages
// that apply only on commit. the purpose of transaction is to support
// extended query flow.
Expand Down
116 changes: 58 additions & 58 deletions protocol/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,56 +9,50 @@ import (

// NewTransport creates a protocol
func NewTransport(r io.Reader, w io.Writer) *Transport {
backend, _ := pgproto3.NewBackend(r, nil)
return &Transport{
R: r,
W: w,
frontReader: backend,
w: w,
r: newReader(r),
}
}

// Transport manages the underlying wire protocol between backend and frontend.
type Transport struct {
R io.Reader
W io.Writer
frontReader *pgproto3.Backend
w io.Writer
r *reader
initialized bool
transaction *transaction
}

// StartUp handles the very first messages exchange between frontend and backend of new session
func (t *Transport) StartUp() (Message, error) {
func (t *Transport) StartUp() (msg Message, err error) {
// read the initial connection startup message
raw, err := t.readBody()
msg, err = t.r.readRawMessage()
if err != nil {
return nil, err
}

msg := Message(raw)

if msg.IsCancel() {
return msg, nil
}

// ssl request. see: SSLRequest in https://www.postgresql.org/docs/current/protocol-message-formats.html
if msg.IsTLSRequest() {
// currently we don't support TLS.
err := t.Write(TLSResponse(false))
err = t.Write(TLSResponse(false))
if err != nil {
return nil, err
}

raw, err := t.readBody()
msg, err = t.r.readRawMessage()
if err != nil {
return nil, err
}
msg = Message(raw)
}

v, err := msg.StartupVersion()
if err != nil {
return nil, err
}

if v != "3.0" {
return nil, fmt.Errorf("unsupported protocol version %s", v)
}
Expand All @@ -78,14 +72,14 @@ func (t *Transport) endTransaction() (err error) {
return
}

// NextMessage reads and returns a single message from the connection when available.
// NextFrontendMessage reads and returns a single message from the connection when available.
// if within a transaction, the transaction will read from the connection,
// otherwise a ReadyForQuery message will first be sent to the frontend and then reading
// a single message from the connection will happen
//
// NextMessage expects to be called only after a call to StartUp without an error response
// NextFrontendMessage expects to be called only after a call to StartUp without an error response
// otherwise, an error is returned
func (t *Transport) NextMessage() (msg pgproto3.FrontendMessage, err error) {
func (t *Transport) NextFrontendMessage() (msg pgproto3.FrontendMessage, err error) {
if !t.initialized {
err = fmt.Errorf("transport not yet initialized")
return
Expand All @@ -98,7 +92,7 @@ func (t *Transport) NextMessage() (msg pgproto3.FrontendMessage, err error) {
if err != nil {
return
}
msg, err = t.readFrontendMessage()
msg, err = t.r.readFrontendMessage()
}
if err != nil {
return
Expand All @@ -120,53 +114,61 @@ func (t *Transport) NextMessage() (msg pgproto3.FrontendMessage, err error) {
}

func (t *Transport) readFrontendMessage() (pgproto3.FrontendMessage, error) {
return t.frontReader.Receive()
return t.r.readFrontendMessage()
}

// Read reads and returns a single message from the connection.
func (t *Transport) Read() (Message, error) {
typeChar := make([]byte, 1)

func (t *Transport) Read() (msg Message, err error) {
if t.initialized {
// we've already started up, so all future messages are MUST start with
// a single-byte type identifier.
_, err := t.R.Read(typeChar)
if err != nil {
return nil, err
}
return t.r.readTypedMessage()
}
// read the actual body of the message
msg, err := t.readBody()
if err != nil {
return nil, err
return t.r.readRawMessage()
}

// Write writes the provided message to the client connection
func (t *Transport) Write(m Message) error {
if t.transaction != nil {
return t.transaction.Write(m)
}
return t.write(m)
}

if typeChar[0] != 0 {
func (t *Transport) write(m Message) error {
_, err := t.w.Write(m)
return err
}

// we have a typed-message, prepend it to the message body by first
// creating a new message that's 1-byte longer than the body in order to
// make room in memory for the type byte
body := msg
msg = make([]byte, len(body)+1)
func newReader(r io.Reader) *reader {
return &reader{r: r}
}

// fixing the type byte at the beginning (position 0) of the new message
msg[0] = typeChar[0]
type reader struct {
r io.Reader
frontReader *pgproto3.Backend
}

// finally append the body to the new message, starting from position 1
copy(msg[1:], body)
func (r *reader) readTypedMessage() (Message, error) {
msgType := Message(make([]byte, 1))
_, err := r.r.Read(msgType)
if err != nil {
return nil, err
}

return Message(msg), nil
body, err := r.readRawMessage()
if err != nil {
return nil, err
}
return append(msgType, body...), nil
}

// readBody reads the body of the next message in the connection. The body is
// readRawMessage reads un-typed message in the connection. The message is
// comprised of an Int32 body-length (N), inclusive of the length itself
// followed by N-bytes of the actual body.
func (t *Transport) readBody() ([]byte, error) {
func (r *reader) readRawMessage() (Message, error) {
// messages starts with an Int32 Length of message contents in bytes,
// including self.
lenBytes := make([]byte, 4)
_, err := io.ReadFull(t.R, lenBytes)
_, err := io.ReadFull(r.r, lenBytes)
if err != nil {
return nil, err
}
Expand All @@ -176,7 +178,7 @@ func (t *Transport) readBody() ([]byte, error) {

// read the remaining bytes in the message
msg := make([]byte, length)
_, err = io.ReadFull(t.R, msg[4:]) // keep 4 bytes for the length
_, err = io.ReadFull(r.r, msg[4:]) // keep 4 bytes for the length
if err != nil {
return nil, err
}
Expand All @@ -187,15 +189,13 @@ func (t *Transport) readBody() ([]byte, error) {
return msg, nil
}

// Write writes the provided message to the client connection
func (t *Transport) Write(m Message) error {
if t.transaction != nil {
return t.transaction.Write(m)
// readFrontendMessage reads and returns a single decoded typed message from the connection.
func (r *reader) readFrontendMessage() (msg pgproto3.FrontendMessage, err error) {
if r.frontReader == nil {
r.frontReader, err = pgproto3.NewBackend(r.r, nil)
if err != nil {
return
}
}
return t.write(m)
}

func (t *Transport) write(m Message) error {
_, err := t.W.Write(m)
return err
return r.frontReader.Receive()
}
19 changes: 10 additions & 9 deletions protocol/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import (
"time"
)

func TestProtocol_StartUp(t *testing.T) {
func TestTransport_StartUp(t *testing.T) {
t.Run("supported protocol version", func(t *testing.T) {
buf := bytes.Buffer{}
comm := bufio.NewReadWriter(bufio.NewReader(&buf), bufio.NewWriter(&buf))
p := &Transport{W: comm, R: comm}
transport := NewTransport(comm, comm)

_, err := comm.Write([]byte{
0, 0, 0, 8, // length
Expand All @@ -29,14 +29,15 @@ func TestProtocol_StartUp(t *testing.T) {
err = comm.Flush()
require.NoError(t, err)

_, err = p.StartUp()
_, err = transport.StartUp()
require.NoError(t, err)
require.Equal(t, true, transport.initialized)
})

t.Run("unsupported protocol version", func(t *testing.T) {
buf := bytes.Buffer{}
comm := bufio.NewReadWriter(bufio.NewReader(&buf), bufio.NewWriter(&buf))
p := &Transport{W: comm, R: comm}
transport := NewTransport(comm, comm)

_, err := comm.Write([]byte{
0, 0, 0, 8, // length
Expand All @@ -48,7 +49,7 @@ func TestProtocol_StartUp(t *testing.T) {
err = comm.Flush()
require.NoError(t, err)

_, err = p.StartUp()
_, err = transport.StartUp()
require.Error(t, err, "expected error of unsupported version. got none")
})
}
Expand Down Expand Up @@ -78,7 +79,7 @@ func runStory(t *testing.T, conn io.ReadWriter, steps []pgstories.Step) error {
return err
}

func TestProtocol_Read(t *testing.T) {
func TestTransport_Read(t *testing.T) {
t.Run("standard message flow", func(t *testing.T) {
f, b := net.Pipe()

Expand All @@ -90,7 +91,7 @@ func TestProtocol_Read(t *testing.T) {

msg := make(chan pgproto3.FrontendMessage)
go func() {
m, err := transport.NextMessage()
m, err := transport.NextFrontendMessage()
require.NoError(t, err)

msg <- m
Expand Down Expand Up @@ -120,7 +121,7 @@ func TestProtocol_Read(t *testing.T) {

go func() {
for {
_, err := transport.NextMessage()
_, err := transport.NextFrontendMessage()
require.NoError(t, err)
}
}()
Expand All @@ -143,7 +144,7 @@ func TestProtocol_Read(t *testing.T) {

go func() {
for {
m, err := transport.NextMessage()
m, err := transport.NextFrontendMessage()
require.NoError(t, err)

err = nil
Expand Down
2 changes: 1 addition & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (s *session) Serve() error {

// query-cycle
for {
msg, err := t.NextMessage()
msg, err := t.NextFrontendMessage()
if err != nil {
return err
}
Expand Down

0 comments on commit 1bde08f

Please sign in to comment.