diff --git a/protocol/handshake.go b/protocol/handshake.go index fb8e203..2ecc7ed 100644 --- a/protocol/handshake.go +++ b/protocol/handshake.go @@ -43,7 +43,7 @@ func (h *Handshake) Init() (res Message, err error) { } // read the initial connection startup message - res, err = h.readRawMessage() + res, err = h.Read() if err != nil { return nil, err } @@ -60,7 +60,7 @@ func (h *Handshake) Init() (res Message, err error) { return nil, err } - res, err = h.readRawMessage() + res, err = h.Read() if err != nil { return nil, err } @@ -77,7 +77,6 @@ func (h *Handshake) Init() (res Message, err error) { h.passed = true return res, nil - } func (h *Handshake) readTypedMessage() (Message, error) { diff --git a/protocol/handshake_test.go b/protocol/handshake_test.go index e673c19..c03b357 100644 --- a/protocol/handshake_test.go +++ b/protocol/handshake_test.go @@ -7,7 +7,7 @@ import ( "testing" ) -func TestHandshake_Do(t *testing.T) { +func TestHandshake_Init(t *testing.T) { t.Run("supported protocol version", func(t *testing.T) { buf := bytes.Buffer{} comm := bufio.NewReadWriter(bufio.NewReader(&buf), bufio.NewWriter(&buf)) @@ -45,7 +45,7 @@ func TestHandshake_Do(t *testing.T) { require.Error(t, err, "expected error of unsupported version. got none") }) - t.Run("do twice returns an error", func(t *testing.T) { + t.Run("call init twice returns an error", func(t *testing.T) { buf := bytes.Buffer{} comm := bufio.NewReadWriter(bufio.NewReader(&buf), bufio.NewWriter(&buf)) handshake := NewHandshake(comm) diff --git a/protocol/transport.go b/protocol/transport.go index 2a23a37..6a2cb18 100644 --- a/protocol/transport.go +++ b/protocol/transport.go @@ -22,7 +22,7 @@ type Transport struct { } func (t *Transport) beginTransaction() { - t.transaction = &transaction{transport: t, in: []pgproto3.FrontendMessage{}, out: []Message{}} + t.transaction = &transaction{transport: t} } func (t *Transport) endTransaction() (err error) { diff --git a/session_test.go b/session_test.go index 4e1dcca..b6e54f3 100644 --- a/session_test.go +++ b/session_test.go @@ -271,6 +271,5 @@ func TestSession_startUp(t *testing.T) { _ = s.startUp() require.NoError(t, err) require.Equal(t, true, canceled) - }) }