Skip to content

Commit

Permalink
Merge pull request panoplyio#40 from avivklas/extended-query-messages
Browse files Browse the repository at this point in the history
fixed a bug in startup phase
  • Loading branch information
aviaoh authored Jun 20, 2019
2 parents e03085a + 3312c34 commit 0b7ecaf
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 26 deletions.
2 changes: 1 addition & 1 deletion protocol/extend.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type transaction struct {

// Read uses Transport to read the next message into the transaction's incoming messages buffer
func (t *transaction) Read() (msg Message, err error) {
if msg, err = t.transport.read(); err == nil {
if msg, err = t.transport.Read(); err == nil {
t.in = append(t.in, msg)
}
return
Expand Down
26 changes: 16 additions & 10 deletions protocol/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type Transport struct {
// StartUp handles the very first messages exchange between frontend and backend of new session
func (t *Transport) StartUp() (Message, error) {
// read the initial connection startup message
msg, err := t.read()
msg, err := t.Read()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -71,22 +71,27 @@ func (t *Transport) endTransaction() (err error) {
return
}

// Read reads and returns a single message from the connection.
// Read expects to be called only after a call to StartUp without an error response
// NextMessage reads and returns a single message from the connection when available.
// if within a transaction, the transaction will read from the connection,
// otherwise a ReadyForQuery message will first be sent to the frontend and then reading
// a single message from the connection will happen
//
// NextMessage expects to be called only after a call to StartUp without an error response
// otherwise, an error is returned
func (t *Transport) Read() (msg Message, err error) {
func (t *Transport) NextMessage() (msg Message, err error) {
if !t.initialized {
err = fmt.Errorf("transport not yet initialized")
return
}
if t.transaction != nil {
msg, err = t.transaction.Read()
} else {
if !t.initialized {
err = fmt.Errorf("transport not yet initialized")
return
}
// when not in transaction, client waits for ReadyForQuery before sending next message
err = t.Write(ReadyForQuery)
if err != nil {
return
}
msg, err = t.read()
msg, err = t.Read()
}
if err != nil {
return
Expand All @@ -101,7 +106,8 @@ func (t *Transport) Read() (msg Message, err error) {
return
}

func (t *Transport) read() (Message, error) {
// Read reads and returns a single message from the connection.
func (t *Transport) Read() (Message, error) {
typeChar := make([]byte, 1)

if t.initialized {
Expand Down
28 changes: 14 additions & 14 deletions protocol/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ func TestProtocol_Read(t *testing.T) {
frontend, err := pgproto3.NewFrontend(f, f)
require.NoError(t, err)

p := NewTransport(b, b)
p.initialized = true
transport := NewTransport(b, b)
transport.initialized = true

msg := make(chan Message)
go func() {
m, err := p.Read()
m, err := transport.NextMessage()
require.NoError(t, err)

msg <- m
Expand All @@ -106,19 +106,19 @@ func TestProtocol_Read(t *testing.T) {
res := <-msg
require.Equalf(t, byte('Q'), res.Type(), "expected protocol to identify sent message as type 'Q'. actual: %c", res.Type())

require.Nil(t, p.transaction, "expected protocol not to start transaction")
require.Nil(t, transport.transaction, "expected protocol not to start transaction")
})

t.Run("extended query message flow", func(t *testing.T) {
t.Run("starts transaction", func(t *testing.T) {
f, b := net.Pipe()

p := NewTransport(b, b)
p.initialized = true
transport := NewTransport(b, b)
transport.initialized = true

go func() {
for {
_, err := p.Read()
_, err := transport.NextMessage()
require.NoError(t, err)
}
}()
Expand All @@ -130,25 +130,25 @@ func TestProtocol_Read(t *testing.T) {
})
require.NoError(t, err)

require.NotNil(t, p.transaction, "expected protocol to start transaction")
require.NotNil(t, transport.transaction, "expected protocol to start transaction")
})

t.Run("ends transaction", func(t *testing.T) {
f, b := net.Pipe()

p := NewTransport(b, b)
p.initialized = true
transport := NewTransport(b, b)
transport.initialized = true

go func() {
for {
m, err := p.Read()
m, err := transport.NextMessage()
require.NoError(t, err)

switch m.Type() {
case Parse:
err = p.Write(ParseComplete)
err = transport.Write(ParseComplete)
case Bind:
err = p.Write(BindComplete)
err = transport.Write(BindComplete)
}
require.NoError(t, err)
}
Expand All @@ -166,7 +166,7 @@ func TestProtocol_Read(t *testing.T) {

require.NoError(t, err)

require.Nil(t, p.transaction, "expected protocol to end transaction")
require.Nil(t, transport.transaction, "expected protocol to end transaction")
})
})
}
2 changes: 1 addition & 1 deletion sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (s *session) Serve() error {

// query-cycle
for {
msg, err = t.Read()
msg, err = t.NextMessage()
if err != nil {
return err
}
Expand Down

0 comments on commit 0b7ecaf

Please sign in to comment.