Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add upgrade request/response callbacks #132

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions codec/websocket/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package websocket

import (
"net"
"net/http"
"time"

"github.com/talostrading/sonic"
Expand Down Expand Up @@ -107,6 +108,8 @@ func (s StreamState) String() string {
type AsyncMessageHandler = func(err error, n int, mt MessageType)
type AsyncFrameHandler = func(err error, f *Frame)
type ControlCallback = func(mt MessageType, payload []byte)
type UpgradeRequestCallback = func(req *http.Request)
type UpgradeResponseCallback = func(res *http.Response)

type Header struct {
Key string
Expand Down Expand Up @@ -357,6 +360,24 @@ type Stream interface {
// ControlCallback returns the control callback set with SetControlCallback.
ControlCallback() ControlCallback

// SetUpgradeRequestCallback sets a function that will be invoked during the handshake
// just before the upgrade request is sent.
//
// The caller must not perform any operations on the stream in the provided callback.
SetUpgradeRequestCallback(callback UpgradeRequestCallback)
Copy link
Collaborator

@sergiu128 sergiu128 Jun 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing this is here in order to provide the http.Request to the caller for potential modification... which for the handshake means adding more headers. If so, we already have this mechanism of adding extra headers to the handshake, see here (this is used in bullish iirc)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's for logging the request (at least for my use case).


// UpgradeRequestCallback returns the callback set with SetUpgradeRequestCallback.
UpgradeRequestCallback() UpgradeRequestCallback

// SetUpgradeResponseCallback sets a function that will be invoked during the handshake
// just after the upgrade response is received.
//
// The caller must not perform any operations on the stream in the provided callback.
SetUpgradeResponseCallback(callback UpgradeResponseCallback)

// UpgradeResponseCallback returns the callback set with SetUpgradeResponseCallback.
UpgradeResponseCallback() UpgradeResponseCallback

// SetMaxMessageSize sets the maximum size of a message that can be read
// from or written to a peer.
// - If a message exceeds the limit while reading, the connection is
Expand Down
30 changes: 30 additions & 0 deletions codec/websocket/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ type WebsocketStream struct {
// Optional callback invoked when a control frame is received.
ccb ControlCallback

// Optional callback invoked when an upgrade request is sent.
upReqCb UpgradeRequestCallback

// Optional callback invoked when an upgrade response is received.
upResCb UpgradeResponseCallback

// Used to establish a TCP connection to the peer with a timeout.
dialer *net.Dialer

Expand Down Expand Up @@ -788,6 +794,10 @@ func (s *WebsocketStream) upgrade(
}
}

if s.upReqCb != nil {
s.upReqCb(req)
}

err = req.Write(stream)
if err != nil {
return err
Expand Down Expand Up @@ -820,6 +830,10 @@ func (s *WebsocketStream) upgrade(
}
s.hb = s.hb[:0]

if s.upResCb != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be called after we assert that the handshake response is valid below? Or is there a use-case to snoop into a potentially invalid response?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latter. The main motivation for this is to be able to log the request/response so we can debug failed handshakes.

s.upResCb(res)
}

if !IsUpgradeRes(res) {
return ErrCannotUpgrade
}
Expand Down Expand Up @@ -867,6 +881,22 @@ func (s *WebsocketStream) ControlCallback() ControlCallback {
return s.ccb
}

func (s *WebsocketStream) SetUpgradeRequestCallback(upReqCb UpgradeRequestCallback) {
s.upReqCb = upReqCb
}

func (s *WebsocketStream) UpgradeRequestCallback() UpgradeRequestCallback {
return s.upReqCb
}

func (s *WebsocketStream) SetUpgradeResponseCallback(upResCb UpgradeResponseCallback) {
s.upResCb = upResCb
}

func (s *WebsocketStream) UpgradeResponseCallback() UpgradeResponseCallback {
return s.upResCb
}

func (s *WebsocketStream) SetMaxMessageSize(bytes int) {
// This is just for checking against the length returned in the frame
// header. The sizes of the buffers in which we read or write the messages
Expand Down
21 changes: 21 additions & 0 deletions codec/websocket/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"io"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -174,13 +175,33 @@ func TestClientSuccessfulHandshake(t *testing.T) {
t.Fatal(err)
}

var upgReqCbCalled, upgResCbCalled bool
ws.SetUpgradeRequestCallback(func(req *http.Request) {
upgReqCbCalled = true
if val := req.Header.Get("Upgrade"); val != "websocket" {
t.Fatalf("invalid Upgrade header in request: given=%s expected=%s", val, "websocket")
}
})
ws.SetUpgradeResponseCallback(func(res *http.Response) {
upgResCbCalled = true
if val := res.Header.Get("Upgrade"); val != "websocket" {
t.Fatalf("invalid Upgrade header in response: given=%s expected=%s", val, "websocket")
}
})

assertState(t, ws, StateHandshake)

ws.AsyncHandshake("ws://localhost:8080", func(err error) {
if err != nil {
assertState(t, ws, StateTerminated)
} else {
assertState(t, ws, StateActive)
if !upgReqCbCalled {
t.Fatal("upgrade request callback not invoked")
}
if !upgResCbCalled {
t.Fatal("upgrade response callback not invoked")
}
}
})

Expand Down
Loading