From 0fd26691fdad149d4e4c57e2b7850de312b6ca12 Mon Sep 17 00:00:00 2001 From: Tobias Brandt Date: Mon, 10 Jun 2024 15:46:15 +0800 Subject: [PATCH] Add upgrade request/response callbacks These callbacks are invoked during the handshake just before the upgrade request is sent and just after the response is received. --- codec/websocket/definitions.go | 21 +++++++++++++++++++++ codec/websocket/stream.go | 30 ++++++++++++++++++++++++++++++ codec/websocket/stream_test.go | 21 +++++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/codec/websocket/definitions.go b/codec/websocket/definitions.go index 3ab11c57..209197a5 100644 --- a/codec/websocket/definitions.go +++ b/codec/websocket/definitions.go @@ -2,6 +2,7 @@ package websocket import ( "net" + "net/http" "time" "github.com/talostrading/sonic" @@ -107,6 +108,8 @@ func (s StreamState) String() string { type AsyncMessageHandler = func(err error, n int, mt MessageType) type AsyncFrameHandler = func(err error, f *Frame) type ControlCallback = func(mt MessageType, payload []byte) +type UpgradeRequestCallback = func(req *http.Request) +type UpgradeResponseCallback = func(res *http.Response) type Header struct { Key string @@ -357,6 +360,24 @@ type Stream interface { // ControlCallback returns the control callback set with SetControlCallback. ControlCallback() ControlCallback + // SetUpgradeRequestCallback sets a function that will be invoked during the handshake + // just before the upgrade request is sent. + // + // The caller must not perform any operations on the stream in the provided callback. + SetUpgradeRequestCallback(callback UpgradeRequestCallback) + + // UpgradeRequestCallback returns the callback set with SetUpgradeRequestCallback. + UpgradeRequestCallback() UpgradeRequestCallback + + // SetUpgradeResponseCallback sets a function that will be invoked during the handshake + // just after the upgrade response is received. + // + // The caller must not perform any operations on the stream in the provided callback. + SetUpgradeResponseCallback(callback UpgradeResponseCallback) + + // UpgradeResponseCallback returns the callback set with SetUpgradeResponseCallback. + UpgradeResponseCallback() UpgradeResponseCallback + // SetMaxMessageSize sets the maximum size of a message that can be read // from or written to a peer. // - If a message exceeds the limit while reading, the connection is diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 3dc2e2a1..8cbc8c8b 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -81,6 +81,12 @@ type WebsocketStream struct { // Optional callback invoked when a control frame is received. ccb ControlCallback + // Optional callback invoked when an upgrade request is sent. + upReqCb UpgradeRequestCallback + + // Optional callback invoked when an upgrade response is received. + upResCb UpgradeResponseCallback + // Used to establish a TCP connection to the peer with a timeout. dialer *net.Dialer @@ -788,6 +794,10 @@ func (s *WebsocketStream) upgrade( } } + if s.upReqCb != nil { + s.upReqCb(req) + } + err = req.Write(stream) if err != nil { return err @@ -820,6 +830,10 @@ func (s *WebsocketStream) upgrade( } s.hb = s.hb[:0] + if s.upResCb != nil { + s.upResCb(res) + } + if !IsUpgradeRes(res) { return ErrCannotUpgrade } @@ -867,6 +881,22 @@ func (s *WebsocketStream) ControlCallback() ControlCallback { return s.ccb } +func (s *WebsocketStream) SetUpgradeRequestCallback(upReqCb UpgradeRequestCallback) { + s.upReqCb = upReqCb +} + +func (s *WebsocketStream) UpgradeRequestCallback() UpgradeRequestCallback { + return s.upReqCb +} + +func (s *WebsocketStream) SetUpgradeResponseCallback(upResCb UpgradeResponseCallback) { + s.upResCb = upResCb +} + +func (s *WebsocketStream) UpgradeResponseCallback() UpgradeResponseCallback { + return s.upResCb +} + func (s *WebsocketStream) SetMaxMessageSize(bytes int) { // This is just for checking against the length returned in the frame // header. The sizes of the buffers in which we read or write the messages diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index 239ff295..851db655 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "io" + "net/http" "testing" "time" @@ -174,6 +175,20 @@ func TestClientSuccessfulHandshake(t *testing.T) { t.Fatal(err) } + var upgReqCbCalled, upgResCbCalled bool + ws.SetUpgradeRequestCallback(func(req *http.Request) { + upgReqCbCalled = true + if val := req.Header.Get("Upgrade"); val != "websocket" { + t.Fatalf("invalid Upgrade header in request: given=%s expected=%s", val, "websocket") + } + }) + ws.SetUpgradeResponseCallback(func(res *http.Response) { + upgResCbCalled = true + if val := res.Header.Get("Upgrade"); val != "websocket" { + t.Fatalf("invalid Upgrade header in response: given=%s expected=%s", val, "websocket") + } + }) + assertState(t, ws, StateHandshake) ws.AsyncHandshake("ws://localhost:8080", func(err error) { @@ -181,6 +196,12 @@ func TestClientSuccessfulHandshake(t *testing.T) { assertState(t, ws, StateTerminated) } else { assertState(t, ws, StateActive) + if !upgReqCbCalled { + t.Fatal("upgrade request callback not invoked") + } + if !upgResCbCalled { + t.Fatal("upgrade response callback not invoked") + } } })