Skip to content

Commit

Permalink
Add upgrade request/response callbacks
Browse files Browse the repository at this point in the history
These callbacks are invoked during the handshake just before the
upgrade request is sent and just after the response is received.
  • Loading branch information
TobiasBrandt-Talos authored and ethanf committed Jun 18, 2024
1 parent 019d280 commit 0fd2669
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 0 deletions.
21 changes: 21 additions & 0 deletions codec/websocket/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package websocket

import (
"net"
"net/http"
"time"

"github.com/talostrading/sonic"
Expand Down Expand Up @@ -107,6 +108,8 @@ func (s StreamState) String() string {
type AsyncMessageHandler = func(err error, n int, mt MessageType)
type AsyncFrameHandler = func(err error, f *Frame)
type ControlCallback = func(mt MessageType, payload []byte)
type UpgradeRequestCallback = func(req *http.Request)
type UpgradeResponseCallback = func(res *http.Response)

type Header struct {
Key string
Expand Down Expand Up @@ -357,6 +360,24 @@ type Stream interface {
// ControlCallback returns the control callback set with SetControlCallback.
ControlCallback() ControlCallback

// SetUpgradeRequestCallback sets a function that will be invoked during the handshake
// just before the upgrade request is sent.
//
// The caller must not perform any operations on the stream in the provided callback.
SetUpgradeRequestCallback(callback UpgradeRequestCallback)

// UpgradeRequestCallback returns the callback set with SetUpgradeRequestCallback.
UpgradeRequestCallback() UpgradeRequestCallback

// SetUpgradeResponseCallback sets a function that will be invoked during the handshake
// just after the upgrade response is received.
//
// The caller must not perform any operations on the stream in the provided callback.
SetUpgradeResponseCallback(callback UpgradeResponseCallback)

// UpgradeResponseCallback returns the callback set with SetUpgradeResponseCallback.
UpgradeResponseCallback() UpgradeResponseCallback

// SetMaxMessageSize sets the maximum size of a message that can be read
// from or written to a peer.
// - If a message exceeds the limit while reading, the connection is
Expand Down
30 changes: 30 additions & 0 deletions codec/websocket/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ type WebsocketStream struct {
// Optional callback invoked when a control frame is received.
ccb ControlCallback

// Optional callback invoked when an upgrade request is sent.
upReqCb UpgradeRequestCallback

// Optional callback invoked when an upgrade response is received.
upResCb UpgradeResponseCallback

// Used to establish a TCP connection to the peer with a timeout.
dialer *net.Dialer

Expand Down Expand Up @@ -788,6 +794,10 @@ func (s *WebsocketStream) upgrade(
}
}

if s.upReqCb != nil {
s.upReqCb(req)
}

err = req.Write(stream)
if err != nil {
return err
Expand Down Expand Up @@ -820,6 +830,10 @@ func (s *WebsocketStream) upgrade(
}
s.hb = s.hb[:0]

if s.upResCb != nil {
s.upResCb(res)
}

if !IsUpgradeRes(res) {
return ErrCannotUpgrade
}
Expand Down Expand Up @@ -867,6 +881,22 @@ func (s *WebsocketStream) ControlCallback() ControlCallback {
return s.ccb
}

func (s *WebsocketStream) SetUpgradeRequestCallback(upReqCb UpgradeRequestCallback) {
s.upReqCb = upReqCb
}

func (s *WebsocketStream) UpgradeRequestCallback() UpgradeRequestCallback {
return s.upReqCb
}

func (s *WebsocketStream) SetUpgradeResponseCallback(upResCb UpgradeResponseCallback) {
s.upResCb = upResCb
}

func (s *WebsocketStream) UpgradeResponseCallback() UpgradeResponseCallback {
return s.upResCb
}

func (s *WebsocketStream) SetMaxMessageSize(bytes int) {
// This is just for checking against the length returned in the frame
// header. The sizes of the buffers in which we read or write the messages
Expand Down
21 changes: 21 additions & 0 deletions codec/websocket/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"io"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -174,13 +175,33 @@ func TestClientSuccessfulHandshake(t *testing.T) {
t.Fatal(err)
}

var upgReqCbCalled, upgResCbCalled bool
ws.SetUpgradeRequestCallback(func(req *http.Request) {
upgReqCbCalled = true
if val := req.Header.Get("Upgrade"); val != "websocket" {
t.Fatalf("invalid Upgrade header in request: given=%s expected=%s", val, "websocket")
}
})
ws.SetUpgradeResponseCallback(func(res *http.Response) {
upgResCbCalled = true
if val := res.Header.Get("Upgrade"); val != "websocket" {
t.Fatalf("invalid Upgrade header in response: given=%s expected=%s", val, "websocket")
}
})

assertState(t, ws, StateHandshake)

ws.AsyncHandshake("ws://localhost:8080", func(err error) {
if err != nil {
assertState(t, ws, StateTerminated)
} else {
assertState(t, ws, StateActive)
if !upgReqCbCalled {
t.Fatal("upgrade request callback not invoked")
}
if !upgResCbCalled {
t.Fatal("upgrade response callback not invoked")
}
}
})

Expand Down

0 comments on commit 0fd2669

Please sign in to comment.