Skip to content

Commit

Permalink
Parse prepared statements and bind portals (panoplyio#36)
Browse files Browse the repository at this point in the history
parse prepared statements and bind portals
  • Loading branch information
avivklas authored Jul 1, 2019
1 parent 38fa789 commit c6338ea
Show file tree
Hide file tree
Showing 8 changed files with 603 additions and 50 deletions.
18 changes: 18 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ func Unsupported(msg string, args ...interface{}) Err {
return &err{M: msg, C: "0A000", P: -1}
}

// InvalidSQLStatementName indicates that a referred statement name is
// unknown/missing to the server.
func InvalidSQLStatementName(stmtName string) Err {
msg := fmt.Sprintf("prepared statement \"%s\" does not exist", stmtName)
return &err{M: msg, C: "26000", P: -1}
}

// ProtocolViolation indicates that a provided typed message has an invalid value
func ProtocolViolation(msg string) Err {
return &err{M: msg, C: "08P01", P: -1}
}

// SyntaxError indicates that sent command is invalid
func SyntaxError(msg string, args ...interface{}) Err {
msg = fmt.Sprintf(msg, args...)
return &err{M: msg, C: "42601", P: -1, S: "ERROR"}
}

func fromErr(e error) *err {
err1, ok := e.(*err)
if ok {
Expand Down
36 changes: 35 additions & 1 deletion protocol/extend.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,40 @@
package protocol

import "github.com/jackc/pgx/pgproto3"
import (
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
nodes "github.com/lfittl/pg_query_go/nodes"
)

// ParseComplete is sent when backend parsed a prepared statement successfully
var ParseComplete = []byte{'1', 0, 0, 0, 4}

// BindComplete is sent when backend prepared a portal and finished planning the query
var BindComplete = []byte{'2', 0, 0, 0, 4}

// Describe message object types
const (
DescribeStatement = 'S'
DescribePortal = 'P'
)

// ParameterDescription is sent when backend received Describe message from frontend
// with ObjectType = 'S' - requesting to describe prepared statement with a provided name
func ParameterDescription(ps *nodes.PrepareStmt) (Message, error) {
res := []byte{'t'}
sp := len(res)
res = pgio.AppendInt32(res, -1)

res = pgio.AppendUint16(res, uint16(len(ps.Argtypes.Items)))
for _, v := range ps.Argtypes.Items {
res = pgio.AppendUint32(res, uint32(v.(nodes.TypeName).TypeOid))
}

pgio.SetInt32(res[sp:], int32(len(res[sp:])))

return Message(res), nil
}

// transaction represents a sequence of frontend and backend messages
// that apply only on commit. the purpose of transaction is to support
// extended query flow.
Expand All @@ -27,10 +54,17 @@ func (t *transaction) NextFrontendMessage() (msg pgproto3.FrontendMessage, err e

// Write writes the provided message into the transaction's outgoing messages buffer
func (t *transaction) Write(msg Message) error {
if t.hasError() {
return nil
}
t.out = append(t.out, msg)
return nil
}

func (t *transaction) hasError() bool {
return len(t.out) > 0 && t.out[len(t.out)-1].IsError()
}

func (t *transaction) flush() (err error) {
for len(t.out) > 0 {
err = t.transport.write(t.out[0])
Expand Down
21 changes: 21 additions & 0 deletions protocol/message.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package protocol

import (
"fmt"
"github.com/jackc/pgx/pgproto3"
)

// frontend message types
const (
Terminate = 'X'
Expand All @@ -21,6 +26,22 @@ func (m Message) Type() byte {
return b
}

// IsError determines if the message is an ErrorResponse
func (m Message) IsError() bool {
return m.Type() == 'E'
}

// ErrorResponse parses message of type error and returns an object describes it
func (m Message) ErrorResponse() (res *pgproto3.ErrorResponse, err error) {
if !m.IsError() {
err = fmt.Errorf("message is not an error message")
return
}
res = &pgproto3.ErrorResponse{}
err = res.Decode(m[4:])
return
}

// MessageReadWriter describes objects that handle client-server communication.
// Objects implementing this interface are used by logic operations to send Message
// objects to frontend and receive Message back from it
Expand Down
45 changes: 40 additions & 5 deletions protocol/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@ import (
"io"
)

// TransactionState is used as a return with every message read for commit and rollback implementation
type TransactionState int

const (
// TransactionUnknown is the default/unset value of this enum
TransactionUnknown TransactionState = iota
// NotInTransaction states that transaction is not active and operations should auto-commit
NotInTransaction
// InTransaction states that transaction is active and operations should not commit
InTransaction
// TransactionEnded states that the current transaction has finished and has to commit
TransactionEnded
// TransactionFailed states that the current transaction has failed and has to roll-back
TransactionFailed
)

// NewTransport creates a Transport
func NewTransport(rw io.ReadWriter) *Transport {
b, _ := pgproto3.NewBackend(rw, nil)
Expand Down Expand Up @@ -38,33 +54,52 @@ func (t *Transport) endTransaction() (err error) {
//
// NextFrontendMessage expects to be called only after a call to Handshake without an error response
// otherwise, an error is returned
func (t *Transport) NextFrontendMessage() (msg pgproto3.FrontendMessage, err error) {
if t.transaction != nil {
msg, err = t.transaction.NextFrontendMessage()
} else {
func (t *Transport) NextFrontendMessage() (msg pgproto3.FrontendMessage, ts TransactionState, err error) {
if t.transaction == nil {
// when not in transaction, client waits for ReadyForQuery before sending next message
err = t.Write(ReadyForQuery)
if err != nil {
return
}
msg, err = t.readFrontendMessage()
} else {
msg, err = t.transaction.NextFrontendMessage()
}
if err != nil {
return
}

ts, err = t.affectTransaction(msg)
return
}

func (t *Transport) affectTransaction(msg pgproto3.FrontendMessage) (ts TransactionState, err error) {
if t.transaction == nil {
switch msg.(type) {
case *pgproto3.Parse, *pgproto3.Bind, *pgproto3.Describe:
t.beginTransaction()
ts = InTransaction
default:
ts = NotInTransaction
}
} else {
if t.transaction.hasError() {
ts = TransactionFailed
}
switch msg.(type) {
case *pgproto3.Query, *pgproto3.Sync:
err = t.endTransaction()
if err != nil {
ts = TransactionFailed
} else if ts == TransactionUnknown {
ts = TransactionEnded
}
default:
if ts == TransactionUnknown {
ts = InTransaction
}
}
}

return
}

Expand Down
54 changes: 51 additions & 3 deletions protocol/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestTransport_Read(t *testing.T) {

msg := make(chan pgproto3.FrontendMessage)
go func() {
m, err := transport.NextFrontendMessage()
m, _, err := transport.NextFrontendMessage()
require.NoError(t, err)

msg <- m
Expand Down Expand Up @@ -76,7 +76,11 @@ func TestTransport_Read(t *testing.T) {

go func() {
for {
_, err := transport.NextFrontendMessage()
msg, ts, err := transport.NextFrontendMessage()
switch msg.(type) {
case *pgproto3.Parse, *pgproto3.Bind:
require.Equal(t, InTransaction, ts)
}
require.NoError(t, err)
}
}()
Expand All @@ -98,15 +102,19 @@ func TestTransport_Read(t *testing.T) {

go func() {
for {
m, err := transport.NextFrontendMessage()
m, ts, err := transport.NextFrontendMessage()
require.NoError(t, err)

err = nil
switch m.(type) {
case *pgproto3.Parse:
err = transport.Write(ParseComplete)
require.Equal(t, InTransaction, ts)
case *pgproto3.Bind:
err = transport.Write(BindComplete)
require.Equal(t, InTransaction, ts)
case *pgproto3.Sync:
require.Equal(t, TransactionEnded, ts)
}
require.NoError(t, err)
}
Expand All @@ -126,5 +134,45 @@ func TestTransport_Read(t *testing.T) {

require.Nil(t, transport.transaction, "expected protocol to end transaction")
})

t.Run("fails transaction", func(t *testing.T) {
f, b := net.Pipe()

transport := NewTransport(b)

go func() {
for {
m, ts, err := transport.NextFrontendMessage()
require.NoError(t, err)

err = nil
switch m.(type) {
case *pgproto3.Parse:
err = transport.Write(ParseComplete)
require.Equal(t, InTransaction, ts)
case *pgproto3.Bind:
require.Equal(t, InTransaction, ts)
err = transport.Write(ErrorResponse(fmt.Errorf("dosn't matter")))
case *pgproto3.Sync:
require.Equal(t, TransactionFailed, ts)
}
require.NoError(t, err)
}
}()

err := runStory(t, f, []pgstories.Step{
&pgstories.Response{BackendMessage: &pgproto3.ReadyForQuery{}},
&pgstories.Command{FrontendMessage: &pgproto3.Parse{}},
&pgstories.Command{FrontendMessage: &pgproto3.Bind{}},
&pgstories.Command{FrontendMessage: &pgproto3.Sync{}},
&pgstories.Response{BackendMessage: &pgproto3.ParseComplete{}},
&pgstories.Response{BackendMessage: &pgproto3.ErrorResponse{}},
&pgstories.Response{BackendMessage: &pgproto3.ReadyForQuery{}},
})

require.NoError(t, err)

require.Nil(t, transport.transaction, "expected protocol to end transaction")
})
})
}
11 changes: 10 additions & 1 deletion query.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,16 @@ func (q *query) Run(sess Session) error {
}

// determine if it's a query or command
switch stmt.(type) {
switch v := stmt.(type) {
case nodes.PrepareStmt:
s, ok := sess.(*session)
// only session implementation is capable of storing prepared stmts
if ok {
// we just store the statement and don't do anything
s.storePreparedStatement(&v)
} else {
return Unsupported("prepared statements")
}
case nodes.SelectStmt, nodes.VariableShowStmt:
err = q.Query(ctx, stmt)
default:
Expand Down
Loading

0 comments on commit c6338ea

Please sign in to comment.