diff --git a/protocol/extend.go b/protocol/extend.go index 4bfe5b6..ec5f8dd 100644 --- a/protocol/extend.go +++ b/protocol/extend.go @@ -8,16 +8,6 @@ var ParseComplete = []byte{'1', 0, 0, 0, 4} // BindComplete is sent when backend prepared a portal and finished planning the query var BindComplete = []byte{'2', 0, 0, 0, 4} -// CreatesTransaction tells weather this is a frontend message that should start/continue a transaction -func (m *Message) CreatesTransaction() bool { - return m.Type() == Parse || m.Type() == Bind -} - -// EndsTransaction tells weather this is a frontend message that should end the current transaction -func (m *Message) EndsTransaction() bool { - return m.Type() == Query || m.Type() == Sync -} - // transaction represents a sequence of frontend and backend messages // that apply only on commit. the purpose of transaction is to support // extended query flow. diff --git a/protocol/transport.go b/protocol/transport.go index d9902ed..e422ab2 100644 --- a/protocol/transport.go +++ b/protocol/transport.go @@ -9,56 +9,50 @@ import ( // NewTransport creates a protocol func NewTransport(r io.Reader, w io.Writer) *Transport { - backend, _ := pgproto3.NewBackend(r, nil) return &Transport{ - R: r, - W: w, - frontReader: backend, + w: w, + r: newReader(r), } } // Transport manages the underlying wire protocol between backend and frontend. type Transport struct { - R io.Reader - W io.Writer - frontReader *pgproto3.Backend + w io.Writer + r *reader initialized bool transaction *transaction } // StartUp handles the very first messages exchange between frontend and backend of new session -func (t *Transport) StartUp() (Message, error) { +func (t *Transport) StartUp() (msg Message, err error) { // read the initial connection startup message - raw, err := t.readBody() + msg, err = t.r.readRawMessage() if err != nil { return nil, err } - msg := Message(raw) - if msg.IsCancel() { return msg, nil } + // ssl request. see: SSLRequest in https://www.postgresql.org/docs/current/protocol-message-formats.html if msg.IsTLSRequest() { // currently we don't support TLS. - err := t.Write(TLSResponse(false)) + err = t.Write(TLSResponse(false)) if err != nil { return nil, err } - raw, err := t.readBody() + msg, err = t.r.readRawMessage() if err != nil { return nil, err } - msg = Message(raw) } v, err := msg.StartupVersion() if err != nil { return nil, err } - if v != "3.0" { return nil, fmt.Errorf("unsupported protocol version %s", v) } @@ -78,14 +72,14 @@ func (t *Transport) endTransaction() (err error) { return } -// NextMessage reads and returns a single message from the connection when available. +// NextFrontendMessage reads and returns a single message from the connection when available. // if within a transaction, the transaction will read from the connection, // otherwise a ReadyForQuery message will first be sent to the frontend and then reading // a single message from the connection will happen // -// NextMessage expects to be called only after a call to StartUp without an error response +// NextFrontendMessage expects to be called only after a call to StartUp without an error response // otherwise, an error is returned -func (t *Transport) NextMessage() (msg pgproto3.FrontendMessage, err error) { +func (t *Transport) NextFrontendMessage() (msg pgproto3.FrontendMessage, err error) { if !t.initialized { err = fmt.Errorf("transport not yet initialized") return @@ -98,7 +92,7 @@ func (t *Transport) NextMessage() (msg pgproto3.FrontendMessage, err error) { if err != nil { return } - msg, err = t.readFrontendMessage() + msg, err = t.r.readFrontendMessage() } if err != nil { return @@ -120,53 +114,61 @@ func (t *Transport) NextMessage() (msg pgproto3.FrontendMessage, err error) { } func (t *Transport) readFrontendMessage() (pgproto3.FrontendMessage, error) { - return t.frontReader.Receive() + return t.r.readFrontendMessage() } // Read reads and returns a single message from the connection. -func (t *Transport) Read() (Message, error) { - typeChar := make([]byte, 1) - +func (t *Transport) Read() (msg Message, err error) { if t.initialized { - // we've already started up, so all future messages are MUST start with - // a single-byte type identifier. - _, err := t.R.Read(typeChar) - if err != nil { - return nil, err - } + return t.r.readTypedMessage() } - // read the actual body of the message - msg, err := t.readBody() - if err != nil { - return nil, err + return t.r.readRawMessage() +} + +// Write writes the provided message to the client connection +func (t *Transport) Write(m Message) error { + if t.transaction != nil { + return t.transaction.Write(m) } + return t.write(m) +} - if typeChar[0] != 0 { +func (t *Transport) write(m Message) error { + _, err := t.w.Write(m) + return err +} - // we have a typed-message, prepend it to the message body by first - // creating a new message that's 1-byte longer than the body in order to - // make room in memory for the type byte - body := msg - msg = make([]byte, len(body)+1) +func newReader(r io.Reader) *reader { + return &reader{r: r} +} - // fixing the type byte at the beginning (position 0) of the new message - msg[0] = typeChar[0] +type reader struct { + r io.Reader + frontReader *pgproto3.Backend +} - // finally append the body to the new message, starting from position 1 - copy(msg[1:], body) +func (r *reader) readTypedMessage() (Message, error) { + msgType := Message(make([]byte, 1)) + _, err := r.r.Read(msgType) + if err != nil { + return nil, err } - return Message(msg), nil + body, err := r.readRawMessage() + if err != nil { + return nil, err + } + return append(msgType, body...), nil } -// readBody reads the body of the next message in the connection. The body is +// readRawMessage reads un-typed message in the connection. The message is // comprised of an Int32 body-length (N), inclusive of the length itself // followed by N-bytes of the actual body. -func (t *Transport) readBody() ([]byte, error) { +func (r *reader) readRawMessage() (Message, error) { // messages starts with an Int32 Length of message contents in bytes, // including self. lenBytes := make([]byte, 4) - _, err := io.ReadFull(t.R, lenBytes) + _, err := io.ReadFull(r.r, lenBytes) if err != nil { return nil, err } @@ -176,7 +178,7 @@ func (t *Transport) readBody() ([]byte, error) { // read the remaining bytes in the message msg := make([]byte, length) - _, err = io.ReadFull(t.R, msg[4:]) // keep 4 bytes for the length + _, err = io.ReadFull(r.r, msg[4:]) // keep 4 bytes for the length if err != nil { return nil, err } @@ -187,15 +189,13 @@ func (t *Transport) readBody() ([]byte, error) { return msg, nil } -// Write writes the provided message to the client connection -func (t *Transport) Write(m Message) error { - if t.transaction != nil { - return t.transaction.Write(m) +// readFrontendMessage reads and returns a single decoded typed message from the connection. +func (r *reader) readFrontendMessage() (msg pgproto3.FrontendMessage, err error) { + if r.frontReader == nil { + r.frontReader, err = pgproto3.NewBackend(r.r, nil) + if err != nil { + return + } } - return t.write(m) -} - -func (t *Transport) write(m Message) error { - _, err := t.W.Write(m) - return err + return r.frontReader.Receive() } diff --git a/protocol/transport_test.go b/protocol/transport_test.go index 5f33575..0179727 100644 --- a/protocol/transport_test.go +++ b/protocol/transport_test.go @@ -13,11 +13,11 @@ import ( "time" ) -func TestProtocol_StartUp(t *testing.T) { +func TestTransport_StartUp(t *testing.T) { t.Run("supported protocol version", func(t *testing.T) { buf := bytes.Buffer{} comm := bufio.NewReadWriter(bufio.NewReader(&buf), bufio.NewWriter(&buf)) - p := &Transport{W: comm, R: comm} + transport := NewTransport(comm, comm) _, err := comm.Write([]byte{ 0, 0, 0, 8, // length @@ -29,14 +29,15 @@ func TestProtocol_StartUp(t *testing.T) { err = comm.Flush() require.NoError(t, err) - _, err = p.StartUp() + _, err = transport.StartUp() require.NoError(t, err) + require.Equal(t, true, transport.initialized) }) t.Run("unsupported protocol version", func(t *testing.T) { buf := bytes.Buffer{} comm := bufio.NewReadWriter(bufio.NewReader(&buf), bufio.NewWriter(&buf)) - p := &Transport{W: comm, R: comm} + transport := NewTransport(comm, comm) _, err := comm.Write([]byte{ 0, 0, 0, 8, // length @@ -48,7 +49,7 @@ func TestProtocol_StartUp(t *testing.T) { err = comm.Flush() require.NoError(t, err) - _, err = p.StartUp() + _, err = transport.StartUp() require.Error(t, err, "expected error of unsupported version. got none") }) } @@ -78,7 +79,7 @@ func runStory(t *testing.T, conn io.ReadWriter, steps []pgstories.Step) error { return err } -func TestProtocol_Read(t *testing.T) { +func TestTransport_Read(t *testing.T) { t.Run("standard message flow", func(t *testing.T) { f, b := net.Pipe() @@ -90,7 +91,7 @@ func TestProtocol_Read(t *testing.T) { msg := make(chan pgproto3.FrontendMessage) go func() { - m, err := transport.NextMessage() + m, err := transport.NextFrontendMessage() require.NoError(t, err) msg <- m @@ -120,7 +121,7 @@ func TestProtocol_Read(t *testing.T) { go func() { for { - _, err := transport.NextMessage() + _, err := transport.NextFrontendMessage() require.NoError(t, err) } }() @@ -143,7 +144,7 @@ func TestProtocol_Read(t *testing.T) { go func() { for { - m, err := transport.NextMessage() + m, err := transport.NextFrontendMessage() require.NoError(t, err) err = nil diff --git a/session.go b/session.go index cdad00d..301fa0b 100644 --- a/session.go +++ b/session.go @@ -92,7 +92,7 @@ func (s *session) Serve() error { // query-cycle for { - msg, err := t.NextMessage() + msg, err := t.NextFrontendMessage() if err != nil { return err }