From afad11c34037fee0c7a7c6baa4646d0b4f18ac63 Mon Sep 17 00:00:00 2001 From: Tobias Brandt <107411619+TobiasBrandt-Talos@users.noreply.github.com> Date: Tue, 20 Aug 2024 17:01:28 +0800 Subject: [PATCH 01/35] Sync with upstream (#2) * Add upgrade request/response callbacks These callbacks are invoked during the handshake just before the upgrade request is sent and just after the response is received. * Add websocket rtt example * Add docs --- codec/websocket/definitions.go | 21 ++++++ codec/websocket/stream.go | 30 +++++++++ codec/websocket/stream_test.go | 21 ++++++ docs/architecture.md | 116 +++++++++++++++++++++++++++++++++ examples/ws-latency/main.go | 106 ++++++++++++++++++++++++++++++ 5 files changed, 294 insertions(+) create mode 100644 docs/architecture.md create mode 100644 examples/ws-latency/main.go diff --git a/codec/websocket/definitions.go b/codec/websocket/definitions.go index 3ab11c57..209197a5 100644 --- a/codec/websocket/definitions.go +++ b/codec/websocket/definitions.go @@ -2,6 +2,7 @@ package websocket import ( "net" + "net/http" "time" "github.com/talostrading/sonic" @@ -107,6 +108,8 @@ func (s StreamState) String() string { type AsyncMessageHandler = func(err error, n int, mt MessageType) type AsyncFrameHandler = func(err error, f *Frame) type ControlCallback = func(mt MessageType, payload []byte) +type UpgradeRequestCallback = func(req *http.Request) +type UpgradeResponseCallback = func(res *http.Response) type Header struct { Key string @@ -357,6 +360,24 @@ type Stream interface { // ControlCallback returns the control callback set with SetControlCallback. ControlCallback() ControlCallback + // SetUpgradeRequestCallback sets a function that will be invoked during the handshake + // just before the upgrade request is sent. + // + // The caller must not perform any operations on the stream in the provided callback. + SetUpgradeRequestCallback(callback UpgradeRequestCallback) + + // UpgradeRequestCallback returns the callback set with SetUpgradeRequestCallback. + UpgradeRequestCallback() UpgradeRequestCallback + + // SetUpgradeResponseCallback sets a function that will be invoked during the handshake + // just after the upgrade response is received. + // + // The caller must not perform any operations on the stream in the provided callback. + SetUpgradeResponseCallback(callback UpgradeResponseCallback) + + // UpgradeResponseCallback returns the callback set with SetUpgradeResponseCallback. + UpgradeResponseCallback() UpgradeResponseCallback + // SetMaxMessageSize sets the maximum size of a message that can be read // from or written to a peer. // - If a message exceeds the limit while reading, the connection is diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 3dc2e2a1..8cbc8c8b 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -81,6 +81,12 @@ type WebsocketStream struct { // Optional callback invoked when a control frame is received. ccb ControlCallback + // Optional callback invoked when an upgrade request is sent. + upReqCb UpgradeRequestCallback + + // Optional callback invoked when an upgrade response is received. + upResCb UpgradeResponseCallback + // Used to establish a TCP connection to the peer with a timeout. dialer *net.Dialer @@ -788,6 +794,10 @@ func (s *WebsocketStream) upgrade( } } + if s.upReqCb != nil { + s.upReqCb(req) + } + err = req.Write(stream) if err != nil { return err @@ -820,6 +830,10 @@ func (s *WebsocketStream) upgrade( } s.hb = s.hb[:0] + if s.upResCb != nil { + s.upResCb(res) + } + if !IsUpgradeRes(res) { return ErrCannotUpgrade } @@ -867,6 +881,22 @@ func (s *WebsocketStream) ControlCallback() ControlCallback { return s.ccb } +func (s *WebsocketStream) SetUpgradeRequestCallback(upReqCb UpgradeRequestCallback) { + s.upReqCb = upReqCb +} + +func (s *WebsocketStream) UpgradeRequestCallback() UpgradeRequestCallback { + return s.upReqCb +} + +func (s *WebsocketStream) SetUpgradeResponseCallback(upResCb UpgradeResponseCallback) { + s.upResCb = upResCb +} + +func (s *WebsocketStream) UpgradeResponseCallback() UpgradeResponseCallback { + return s.upResCb +} + func (s *WebsocketStream) SetMaxMessageSize(bytes int) { // This is just for checking against the length returned in the frame // header. The sizes of the buffers in which we read or write the messages diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index 239ff295..851db655 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "io" + "net/http" "testing" "time" @@ -174,6 +175,20 @@ func TestClientSuccessfulHandshake(t *testing.T) { t.Fatal(err) } + var upgReqCbCalled, upgResCbCalled bool + ws.SetUpgradeRequestCallback(func(req *http.Request) { + upgReqCbCalled = true + if val := req.Header.Get("Upgrade"); val != "websocket" { + t.Fatalf("invalid Upgrade header in request: given=%s expected=%s", val, "websocket") + } + }) + ws.SetUpgradeResponseCallback(func(res *http.Response) { + upgResCbCalled = true + if val := res.Header.Get("Upgrade"); val != "websocket" { + t.Fatalf("invalid Upgrade header in response: given=%s expected=%s", val, "websocket") + } + }) + assertState(t, ws, StateHandshake) ws.AsyncHandshake("ws://localhost:8080", func(err error) { @@ -181,6 +196,12 @@ func TestClientSuccessfulHandshake(t *testing.T) { assertState(t, ws, StateTerminated) } else { assertState(t, ws, StateActive) + if !upgReqCbCalled { + t.Fatal("upgrade request callback not invoked") + } + if !upgResCbCalled { + t.Fatal("upgrade response callback not invoked") + } } }) diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 00000000..d404e81d --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,116 @@ +Sonic follows the proactor model of asynchronous completion. For any asynchronous operation, such as a read/write from/to a socket, the user provides a callback which will be invoked at some point in the future, when the operation completes, either successfully, or through an error. In some cases, the completion handler will be invoked immediately as the operation can complete within the asynchronous call. + +For example: +```go +b := make([]byte, 4096) // buffer to read into +conn.AsyncRead(b, func(err error, n int) { + // if err == nil, then n is the number of bytes read + // if err != nil, then the read failed and n >= 0 +}) +``` + +Note that the user can invoke another asynchronous operation in the provided callback: +```go +var onRead func(error, int) +onRead = func(err error, n int) { + conn.AsyncRead(b, onRead) +} +conn.AsyncRead(b, onRead) // the first read +``` + +Since the first read can complete immediately, `onRead` can get called before the first `AsyncRead` returns. This also holds for the following `AsyncRead`s done in the callback. In other words, if multiple reads can complete immediately, we keep growing the call stack as we never get the change to return from an `AsyncRead`. The solution is to limit the number of consecutive immediate reads - if we hit the limit, we asynchronously schedule the next read, hence popping the call frames. For example, see `conn.go`. + +# The central IO component + +The `IO struct` (also called the IO context) is the main entity which provides users with the ability to schedule asynchronous operations. There must be only one `IO` per goroutine. Multiple `IO`s can exist within a program. + +Any construct which must provide asynchronous operations, such as `conn.go`, takes an instance of an IO context. + +Since the IO context is the central pillar of Sonic, every program's main function will end with a call to run the IO. For example: + +```go +ioc.Run() // runs until the program is terminated, yielding the CPU between async operations + +for { ioc.PollOne() } // runs until the program is terminated, polling for IO, thus never yielding the CPU between IO operations +``` + +Under the hood, the IO context uses a platform specific event notification system to get informed when an IO operation can be completed (`epoll` for Linux, `kqueue` for BSD/macOS). In short, this event notification systems allow us to register a socket for read/write events. Calling `epoll` after registering a socket might return an event on that socket, telling us whether it's ready to be read. + +# Networking constructs + +## Transports + +- TCP client/server: `conn.go`, `listen_conn.go` +- connection oriented UDP peer: `packet.go` and `listen_packet.go` +- UDP multicast: see the `multicast` package, which supports IPv4 and source IP filtering. + +## Codecs + +Codecs are meant to sit on top of stream based protocols. Currently, the only stream-based transport in Sonic is a TCP connection. + +Stream transports do not give any guarantee on the number of bytes read. A single socket read might end up with 0 or 100 bytes. This is in contrast with packet based transports, like UDP(/multicast), in which every read returns a complete packet. If your buffer is big enough, you'll read the whole packet. If your packet is smaller than a packet, you'll read as much as possible from the packet and the rest will be discarded. + +To help with writing application layer protocols on top of stream based transports, such as WebSocket, Sonic provides users with the notion of a `codec` and `codec stream`. + +A `codec` is simply an interface with two functions: +- `Encode(Item, *ByteBuffer) error`: encodes the given `Item` into the `ByteBuffer` +- `Decode(*ByteBuffer) (Item, error)`: decodes and returns the `Item` from the `ByteBuffer`, or, if there are not enough bytes, returns `nil, sonicerrors.ErrNeedMore`. + +A `codec stream` takes a user defined `codec` and implements read functions that return an `Item` and write functions that write an `Item`. A `codec stream` automatically handles short reads as notified by the `codec` through `sonicerrors.ErrNeedMore`. + +See `codec/frame.go` for a sample codec. + +## Buffers + +Sonic offers three buffer types: +- `ByteBuffer`: a FIFO like buffer on which `codec`s are based. Not extremely efficient as data is copied on some operations needed by `codec`s +- `MirroredBuffer`: a zero-copy `ByteBuffer` - this was recently added and needs a bit more code to fully replace the `ByteBuffer` +- `BipBuffer`: a zero-copy FIFO buffer best suited for packet based communication + +### BipBuffer in Packet Transports + +The purpose of the `BipBuffer` is to offer an efficient way to store packets in the event of loss, while the missing packets are replayed. + +Say each packet is 1 byte and carries a sequence number. Packet loss occurs when the next received sequence number does not follow the previous one. For example, if we receive `1 2 4` then we miss packet `3`. We must buffer `4` and everything that follows it until we replay packet `3` (through a TCP feed, for example). + +To store packets until the missing ones are replayed, we use a `BipBuffer`. Say +`| | | | |` is a bip buffer which can hold 4 bytes. Reading `1 2` goes as follows: +``` +|1| | | | +|2| | | | +``` + +Reading `4 5 6` then results in buffering each packet +``` +|4|5|6|x| --> the next read is done in x +``` + +If we somehow get 3 again, we can then iterate over the `BipBuffer` to process `4 5 6`. All subsequent reads will be done in the place of `x` as long as they're in sequence. If `x` is out of sequence, then it's stored there and we read the next packet in the first byte slot. + +If the `BipBuffer` is not full, then it acts like a circular buffer: if the next packet does not fit at the end of the buffer and there's enough space at the beginning, then it will be put there. If not, then an error will be returned - the user can then decide what to do in the event of a full buffer. + +Packet loss is the most common reliability issue in exchange communications. Packet re-ordering happens rarely, if ever. As such, we can use the zero-copy nature of a `BipBuffer` to efficiently store out of sequence packets until we replay the missing ones. + +#### Efficiently replaying packets + +Say we have the packets `1 3 5 7` - we miss `2 4 6`. In a program, we will process `1` and then queue `3 5 7` until `2` is received - when that happens, we are left with `5 7` at which point we wait for `4` etc. + +An efficient program will issue replay requests as gaps happen, and not when they're encountered. That means that in our example above, we issue a replay request for `2` after `3`, for `4` after `5` and for `6` after `7`, as they're read from the network. The innefficient alternative is to issue replay requests as the packets are read from the `BipBuffer`. + +This can be achieved by keeping track of two sequence numbers: +- `present`: the last valid sequence number read from the network. In the case of `1 3 5 7`, this will be `1`. +- `future`: the sequence number assuming all missing packets are delivered. In the case of `1 3 5 7`, this will be `7` + +The flow is as follows: +``` +read 1: present=1 future=1 +read 3: present=1 future=3, replay 2 +read 5: present=1 future=5, replay 4 +read 7: present=1 future=7, replay 6 + +2 is replayed: present=3 future=7 +4 is replayed: present=5 future=7 +6 is replayed: present=7 future=7 +``` + +We know we don't miss anything if `present == future`. We know we miss some packets if `present < future`. We can never have `present > future`. We know whether we need to replay something after reading a packet by comparing its sequence number with `future`. Say we read `seq`. We then request packets from `future + 1` to `seq - 1` if `seq > future`. diff --git a/examples/ws-latency/main.go b/examples/ws-latency/main.go new file mode 100644 index 00000000..89c7057b --- /dev/null +++ b/examples/ws-latency/main.go @@ -0,0 +1,106 @@ +package main + +import ( + "crypto/tls" + "encoding/binary" + "flag" + "log" + "runtime" + "runtime/debug" + "time" + + "github.com/talostrading/sonic" + "github.com/talostrading/sonic/codec/websocket" + "github.com/talostrading/sonic/util" +) + +// Sends a websocket ping with time t1 in the payload. Per the WebSocket protocol, the peer must reply with a pong +// immediately, with a payload equal to the ping's. On a pong, time t2 is taken again and time t1 is read from the payload. +// t2 - t1 is then printed to the console - this is the application layer RTT between the peers. + +// example: go run main.go -v=false -addr="wss://stream.binance.com:9443/ws" -s="{\"id\":1,\"method\":\"SUBSCRIBE\",\"params\":[\"bnbbtc@depth\"]}" -n 2 + +var ( + addr = flag.String("addr", "wss://stream.binance.com:9443/ws", "address") + n = flag.Int("n", 1, "number of websocket streams") + verbose = flag.Bool("v", false, "if true, print every inbound message") + subMsg = flag.String("s", "", "subscription message") // optional, if empty, we simply start reading directly +) + +func main() { + flag.Parse() + + runtime.LockOSThread() + debug.SetGCPercent(-1) + + ioc := sonic.MustIO() + defer ioc.Close() + + for i := 0; i < *n; i++ { + stream, err := websocket.NewWebsocketStream(ioc, &tls.Config{}, websocket.RoleClient) + if err != nil { + panic(err) + } + + log.Println("websocket", i, "connecting to", *addr) + stream.AsyncHandshake(*addr, func(err error) { + if err != nil { + panic(err) + } + + log.Println("websocket", i, "connected to", *addr) + + first := true + + b := make([]byte, 4096) + var onRead func(error, int, websocket.MessageType) + onRead = func(err error, n int, _ websocket.MessageType) { + if err != nil { + panic(err) + } + if *verbose || first { + if first { + log.Println("websocket", i, "read first message", string(b)) + first = false + } else { + log.Println("websocket", i, "read", string(b)) + } + } + stream.AsyncNextMessage(b, onRead) + } + + if *subMsg != "" { + stream.AsyncWrite([]byte(*subMsg), websocket.TypeText, func(err error) { + if err != nil { + panic(err) + } + stream.AsyncNextMessage(b, onRead) + }) + } else { + stream.AsyncNextMessage(b, onRead) + } + + pingPayload := make([]byte, 48) + pingTimer, err := sonic.NewTimer(ioc) + pingTimer.ScheduleRepeating(5*time.Second, func() { + binary.LittleEndian.PutUint64(pingPayload, uint64(util.GetMonoTimeNanos())) + stream.AsyncWrite(pingPayload, websocket.TypePing, func(err error) { + if err != nil { + panic(err) + } + }) + }) + + stream.SetControlCallback(func(mt websocket.MessageType, b []byte) { + if mt == websocket.TypePong { + diff := util.GetMonoTimeNanos() - int64(binary.LittleEndian.Uint64(b)) + log.Println("websocket", i, "RTT:", time.Duration(diff)*time.Nanosecond) + } + }) + }) + } + + for { + ioc.PollOne() + } +} From 4e1ba280f5cecd791ed684ed942cb49f6e19b60c Mon Sep 17 00:00:00 2001 From: sergiu128 <57708198+sergiu128@users.noreply.github.com> Date: Tue, 24 Sep 2024 08:45:27 +0300 Subject: [PATCH 02/35] IO unset read/write (#3) * Update gosec build step * IO can unset reads/writes * Close() unsets reads+writes and deregisters the IO slot * AsyncAdapter exposes its IO Slot --- .github/workflows/gosec.yml | 4 ++-- async_adapter.go | 7 ++++++- file.go | 4 ++-- io.go | 25 +++++++++++++++++++++++++ io_test.go | 32 ++++++++++++++++++++++++++++++++ listen_conn.go | 3 ++- multicast/peer.go | 2 ++ packet.go | 2 ++ 8 files changed, 73 insertions(+), 6 deletions(-) diff --git a/.github/workflows/gosec.yml b/.github/workflows/gosec.yml index 866c315f..d9aaf365 100644 --- a/.github/workflows/gosec.yml +++ b/.github/workflows/gosec.yml @@ -15,10 +15,10 @@ jobs: - name: Checkout Source uses: actions/checkout@v3 - name: Run Gosec Security Scanner - uses: securego/gosec@v2.15.0 + uses: securego/gosec@v2.21.0 with: args: '-no-fail -fmt=sarif -out=results.sarif -exclude-dir=examples -exclude-dir=stress_test -exclude-dir=other -exclude-dir=docs -exclude-dir=tests -exclude-dir=benchmark ./...' - name: Upload SARIF file - uses: github/codeql-action/upload-sarif@v2 + uses: github/codeql-action/upload-sarif@v3 with: sarif_file: results.sarif diff --git a/async_adapter.go b/async_adapter.go index ca87ec41..e3439860 100644 --- a/async_adapter.go +++ b/async_adapter.go @@ -202,7 +202,8 @@ func (a *AsyncAdapter) Close() error { return io.EOF } - _ = a.ioc.poller.Del(&a.slot) + _ = a.ioc.UnsetReadWrite(&a.slot) + a.ioc.Deregister(&a.slot) return syscall.Close(a.slot.Fd) } @@ -245,3 +246,7 @@ func (a *AsyncAdapter) cancelWrites() { func (a *AsyncAdapter) RawFd() int { return a.slot.Fd } + +func (a *AsyncAdapter) Slot() *internal.Slot { + return &a.slot +} diff --git a/file.go b/file.go index 610454d1..36543050 100644 --- a/file.go +++ b/file.go @@ -232,10 +232,10 @@ func (f *file) Close() error { return io.EOF } - err := f.ioc.poller.Del(&f.slot) - if err != nil { + if err := f.ioc.UnsetReadWrite(&f.slot); err != nil { return err } + f.ioc.Deregister(&f.slot) return syscall.Close(f.slot.Fd) } diff --git a/io.go b/io.go index 3f33901f..fd6f0172 100644 --- a/io.go +++ b/io.go @@ -78,14 +78,39 @@ func (ioc *IO) Deregister(slot *internal.Slot) { } } +// SetRead tells the kernel to notify us when reads can be made on the provided IO slot. If successful, this call must +// be succeeded by Register(slot). +// +// It is safe to call this method multiple times. func (ioc *IO) SetRead(slot *internal.Slot) error { return ioc.poller.SetRead(slot) } +// UnsetRead tells the kernel to not notify us anymore when reads can be made on the provided IO slot. Since the +// underlying platform-specific poller already unsets a read before dispatching it, callers must only use this method +// they want to cancel a currently-scheduled read. For example, when an error occurs outside of an AsyncRead call and +// the underlying file descriptor must be closed. In that case, this call must be succeeded by Deregister(slot). +// +// It is safe to call this method multiple times. +func (ioc *IO) UnsetRead(slot *internal.Slot) error { + return ioc.poller.DelRead(slot) +} + +// Like SetRead but for writes. func (ioc *IO) SetWrite(slot *internal.Slot) error { return ioc.poller.SetWrite(slot) } +// Like UnsetRead but for writes. +func (ioc *IO) UnsetWrite(slot *internal.Slot) error { + return ioc.poller.DelWrite(slot) +} + +// UnsetRead and UnsetWrite in a single call. +func (ioc *IO) UnsetReadWrite(slot *internal.Slot) error { + return ioc.poller.Del(slot) +} + // Run runs the event processing loop. func (ioc *IO) Run() error { for { diff --git a/io_test.go b/io_test.go index 014b59a4..1d1f41fa 100644 --- a/io_test.go +++ b/io_test.go @@ -316,6 +316,38 @@ func TestIOPending(t *testing.T) { } } +func TestSetUnsetRead(t *testing.T) { + ioc := MustIO() + defer ioc.Close() + + pipe, err := internal.NewPipe() + if err != nil { + t.Fatal(err) + } + + if pipe.ReadFd() != pipe.Slot().Fd { + t.Fatal("pipe must be identified by its read end file descriptor") + } + + for i := 0; i < 10; i++ { + if err := ioc.SetRead(pipe.Slot()); err != nil { + t.Fatal(err) + } + } + + for i := 0; i < 100; i++ { + if err := ioc.UnsetRead(pipe.Slot()); err != nil { + t.Fatal(err) + } + } + + for i := 0; i < 100; i++ { + if err := ioc.UnsetReadWrite(pipe.Slot()); err != nil { + t.Fatal(err) + } + } +} + func BenchmarkPollOne(b *testing.B) { ioc := MustIO() defer ioc.Close() diff --git a/listen_conn.go b/listen_conn.go index 4c001234..0b1e39db 100644 --- a/listen_conn.go +++ b/listen_conn.go @@ -113,7 +113,8 @@ func (l *listener) accept() (Conn, error) { } func (l *listener) Close() error { - _ = l.ioc.poller.Del(&l.slot) + _ = l.ioc.UnsetReadWrite(&l.slot) + l.ioc.Deregister(&l.slot) return syscall.Close(l.slot.Fd) } diff --git a/multicast/peer.go b/multicast/peer.go index fa2276b5..1428f975 100644 --- a/multicast/peer.go +++ b/multicast/peer.go @@ -676,6 +676,8 @@ func (p *UDPPeer) LocalAddr() *net.UDPAddr { func (p *UDPPeer) Close() error { if !p.closed { p.closed = true + _ = p.ioc.UnsetReadWrite(&p.slot) + p.ioc.Deregister(&p.slot) return p.socket.Close() } return nil diff --git a/packet.go b/packet.go index 3494f8b4..8608cf9b 100644 --- a/packet.go +++ b/packet.go @@ -195,6 +195,8 @@ func (c *packetConn) getWriteHandler(b []byte, to net.Addr, cb AsyncWriteCallbac func (c *packetConn) Close() error { atomic.StoreUint32(&c.closed, 1) + _ = c.ioc.UnsetReadWrite(&c.slot) + c.ioc.Deregister(&c.slot) return syscall.Close(c.slot.Fd) } From c31fc01650d09532bafe6632db86482264700c2a Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 23 Sep 2024 12:10:44 +0300 Subject: [PATCH 03/35] Unified CodecConn No need for Blocking and Nonblocking variants anymore as all reads and writes are dispatched through the underlying ByteBuffer which in turns dispatches calls to a specific Stream interface. That Stream interface is ultimately responsible for dispatching reads/writes in a blocking or non-blocking way. --- codec.go | 155 ++++--------------------------- codec/frame/stream_test.go | 2 +- codec/websocket/stream.go | 13 +-- codec_test.go | 16 ++-- stress_test/codec/nonblocking.go | 2 +- 5 files changed, 30 insertions(+), 158 deletions(-) diff --git a/codec.go b/codec.go index 70f6bc19..1134a861 100644 --- a/codec.go +++ b/codec.go @@ -2,9 +2,6 @@ package sonic import ( "errors" - "fmt" - - "github.com/talostrading/sonic/internal" "github.com/talostrading/sonic/sonicerrors" ) @@ -40,26 +37,9 @@ type Codec[Enc, Dec any] interface { Decoder[Dec] } -type CodecConn[Enc, Dec any] interface { - AsyncReadNext(func(error, Dec)) - ReadNext() (Dec, error) - - AsyncWriteNext(Enc, AsyncCallback) - WriteNext(Enc) (int, error) - - NextLayer() Stream - - Close() error -} - -var ( - _ CodecConn[any, any] = &BlockingCodecConn[any, any]{} - _ CodecConn[any, any] = &NonblockingCodecConn[any, any]{} -) - -// BlockingCodecConn handles the decoding/encoding of bytes funneled through a +// CodecConn handles the decoding/encoding of bytes funneled through a // provided blocking file descriptor. -type BlockingCodecConn[Enc, Dec any] struct { +type CodecConn[Enc, Dec any] struct { stream Stream codec Codec[Enc, Dec] src *ByteBuffer @@ -69,14 +49,14 @@ type BlockingCodecConn[Enc, Dec any] struct { emptyDec Dec } -func NewBlockingCodecConn[Enc, Dec any]( +func NewCodecConn[Enc, Dec any]( stream Stream, codec Codec[Enc, Dec], src, dst *ByteBuffer, -) (*BlockingCodecConn[Enc, Dec], error) { +) (*CodecConn[Enc, Dec], error) { // Works on both blocking and nonblocking fds. - c := &BlockingCodecConn[Enc, Dec]{ + c := &CodecConn[Enc, Dec]{ stream: stream, codec: codec, src: src, @@ -85,26 +65,22 @@ func NewBlockingCodecConn[Enc, Dec any]( return c, nil } -func (c *BlockingCodecConn[Enc, Dec]) AsyncReadNext(cb func(error, Dec)) { +func (c *CodecConn[Enc, Dec]) AsyncReadNext(cb func(error, Dec)) { item, err := c.codec.Decode(c.src) if errors.Is(err, sonicerrors.ErrNeedMore) { - c.scheduleAsyncRead(cb) + c.src.AsyncReadFrom(c.stream, func(err error, _ int) { + if err != nil { + cb(err, c.emptyDec) + } else { + c.AsyncReadNext(cb) + } + }) } else { cb(err, item) } } -func (c *BlockingCodecConn[Enc, Dec]) scheduleAsyncRead(cb func(error, Dec)) { - c.src.AsyncReadFrom(c.stream, func(err error, _ int) { - if err != nil { - cb(err, c.emptyDec) - } else { - c.AsyncReadNext(cb) - } - }) -} - -func (c *BlockingCodecConn[Enc, Dec]) ReadNext() (Dec, error) { +func (c *CodecConn[Enc, Dec]) ReadNext() (Dec, error) { for { item, err := c.codec.Decode(c.src) if err == nil { @@ -122,7 +98,7 @@ func (c *BlockingCodecConn[Enc, Dec]) ReadNext() (Dec, error) { } } -func (c *BlockingCodecConn[Enc, Dec]) WriteNext(item Enc) (n int, err error) { +func (c *CodecConn[Enc, Dec]) WriteNext(item Enc) (n int, err error) { err = c.codec.Encode(item, c.dst) if err == nil { var nn int64 @@ -132,7 +108,7 @@ func (c *BlockingCodecConn[Enc, Dec]) WriteNext(item Enc) (n int, err error) { return } -func (c *BlockingCodecConn[Enc, Dec]) AsyncWriteNext(item Enc, cb AsyncCallback) { +func (c *CodecConn[Enc, Dec]) AsyncWriteNext(item Enc, cb AsyncCallback) { err := c.codec.Encode(item, c.dst) if err == nil { c.dst.AsyncWriteTo(c.stream, cb) @@ -141,105 +117,10 @@ func (c *BlockingCodecConn[Enc, Dec]) AsyncWriteNext(item Enc, cb AsyncCallback) } } -func (c *BlockingCodecConn[Enc, Dec]) NextLayer() Stream { - return c.stream -} - -func (c *BlockingCodecConn[Enc, Dec]) Close() error { - return c.stream.Close() -} - -type NonblockingCodecConn[Enc, Dec any] struct { - stream Stream - codec Codec[Enc, Dec] - src *ByteBuffer - dst *ByteBuffer - - dispatched int - - emptyEnc Enc - emptyDec Dec -} - -func NewNonblockingCodecConn[Enc, Dec any]( - stream Stream, - codec Codec[Enc, Dec], - src, dst *ByteBuffer, -) (*NonblockingCodecConn[Enc, Dec], error) { - nonblocking, err := internal.IsNonblocking(stream.RawFd()) - if err != nil { - return nil, err - } - if !nonblocking { - return nil, fmt.Errorf("the provided Stream is blocking") - } - - c := &NonblockingCodecConn[Enc, Dec]{ - stream: stream, - codec: codec, - src: src, - dst: dst, - } - return c, nil -} - -func (c *NonblockingCodecConn[Enc, Dec]) AsyncReadNext(cb func(error, Dec)) { - item, err := c.codec.Decode(c.src) - if errors.Is(err, sonicerrors.ErrNeedMore) { - c.src.AsyncReadFrom(c.stream, func(err error, _ int) { - if err != nil { - cb(err, c.emptyDec) - } else { - c.AsyncReadNext(cb) - } - }) - } else { - cb(err, item) - } -} - -func (c *NonblockingCodecConn[Enc, Dec]) ReadNext() (Dec, error) { - for { - item, err := c.codec.Decode(c.src) - if err == nil { - return item, nil - } - - if err != sonicerrors.ErrNeedMore { - return c.emptyDec, err - } - - _, err = c.src.ReadFrom(c.stream) - if err != nil { - return c.emptyDec, err - } - } -} - -func (c *NonblockingCodecConn[Enc, Dec]) AsyncWriteNext(item Enc, cb AsyncCallback) { - if err := c.codec.Encode(item, c.dst); err != nil { - cb(err, 0) - return - } - - // write everything into `dst` - c.dst.AsyncWriteTo(c.stream, cb) -} - -func (c *NonblockingCodecConn[Enc, Dec]) WriteNext(item Enc) (n int, err error) { - err = c.codec.Encode(item, c.dst) - if err == nil { - var nn int64 - nn, err = c.dst.WriteTo(c.stream) - n = int(nn) - } - return -} - -func (c *NonblockingCodecConn[Enc, Dec]) NextLayer() Stream { +func (c *CodecConn[Enc, Dec]) NextLayer() Stream { return c.stream } -func (c *NonblockingCodecConn[Enc, Dec]) Close() error { +func (c *CodecConn[Enc, Dec]) Close() error { return c.stream.Close() } diff --git a/codec/frame/stream_test.go b/codec/frame/stream_test.go index 0b216d33..b85fbf57 100644 --- a/codec/frame/stream_test.go +++ b/codec/frame/stream_test.go @@ -93,7 +93,7 @@ func runClient(port int, t *testing.T) { src := sonic.NewByteBuffer() dst := sonic.NewByteBuffer() codec := NewCodec(src) - codecConn, err := sonic.NewNonblockingCodecConn[[]byte, []byte]( + codecConn, err := sonic.NewCodecConn[[]byte, []byte]( conn, codec, src, dst, ) if err != nil { diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 8cbc8c8b..b9691f34 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -53,7 +53,7 @@ type WebsocketStream struct { conn net.Conn // Codec stream wrapping the underlying transport stream. - cs *sonic.BlockingCodecConn[*Frame, *Frame] + cs *sonic.CodecConn[*Frame, *Frame] // Websocket role: client or server. role Role @@ -129,8 +129,7 @@ func (s *WebsocketStream) init(stream sonic.Stream) (err error) { s.stream = stream codec := NewFrameCodec(s.src, s.dst) - s.cs, err = sonic.NewBlockingCodecConn[*Frame, *Frame]( - stream, codec, s.src, s.dst) + s.cs, err = sonic.NewCodecConn[*Frame, *Frame](stream, codec, s.src, s.dst) return } @@ -195,14 +194,6 @@ func (s *WebsocketStream) nextFrame() (f *Frame, err error) { } func (s *WebsocketStream) AsyncNextFrame(cb AsyncFrameHandler) { - // TODO for later: here we flush first since we might need to reply - // to ping/pong/close immediately, and only after that we try to - // async read. - // - // I think we can just flush asynchronously while reading asynchronously at - // the same time. I'm pretty sure this will work with a BlockingCodecConn. - // - // Not entirely sure about a NonblockingCodecStream. s.AsyncFlush(func(err error) { if errors.Is(err, ErrMessageTooBig) { s.AsyncClose(CloseGoingAway, "payload too big", func(err error) {}) diff --git a/codec_test.go b/codec_test.go index fd08b3ab..d82987a5 100644 --- a/codec_test.go +++ b/codec_test.go @@ -113,7 +113,7 @@ func setupCodecTestWriter() chan struct{} { return mark } -func TestNonblockingCodecConnAsyncReadNext(t *testing.T) { +func TestCodecConnAsyncReadNext(t *testing.T) { mark := setupCodecTestWriter() defer func() { <-mark /* wait for the listener to close*/ }() <-mark // wait for the listener to open @@ -129,7 +129,7 @@ func TestNonblockingCodecConnAsyncReadNext(t *testing.T) { src := NewByteBuffer() dst := NewByteBuffer() - codecConn, err := NewNonblockingCodecConn[TestItem, TestItem]( + codecConn, err := NewCodecConn[TestItem, TestItem]( conn, &TestCodec{}, src, dst) if err != nil { t.Fatal(err) @@ -164,7 +164,7 @@ func TestNonblockingCodecConnAsyncReadNext(t *testing.T) { } } -func TestNonblockingCodecConnReadNext(t *testing.T) { +func TestCodecConnReadNext(t *testing.T) { mark := setupCodecTestWriter() defer func() { <-mark /* wait for the listener to close*/ }() <-mark // wait for the listener to open @@ -180,7 +180,7 @@ func TestNonblockingCodecConnReadNext(t *testing.T) { src := NewByteBuffer() dst := NewByteBuffer() - codecConn, err := NewNonblockingCodecConn[TestItem, TestItem]( + codecConn, err := NewCodecConn[TestItem, TestItem]( conn, &TestCodec{}, src, dst) if err != nil { t.Fatal(err) @@ -283,7 +283,7 @@ func setupCodecTestReader() chan struct{} { return mark } -func TestNonblockingCodecConnAsyncWriteNext(t *testing.T) { +func TestCodecConnAsyncWriteNext(t *testing.T) { mark := setupCodecTestReader() defer func() { <-mark /* wait for the listener to close*/ }() <-mark // wait for the listener to open @@ -298,7 +298,7 @@ func TestNonblockingCodecConnAsyncWriteNext(t *testing.T) { src := NewByteBuffer() dst := NewByteBuffer() - codecConn, err := NewNonblockingCodecConn[TestItem, TestItem]( + codecConn, err := NewCodecConn[TestItem, TestItem]( conn, &TestCodec{}, src, dst) if err != nil { t.Fatal(err) @@ -334,7 +334,7 @@ func TestNonblockingCodecConnAsyncWriteNext(t *testing.T) { } } -func TestNonblockingCodecConnWriteNext(t *testing.T) { +func TestCodecConnWriteNext(t *testing.T) { mark := setupCodecTestReader() defer func() { <-mark /* wait for the listener to close*/ }() <-mark // wait for the listener to open @@ -350,7 +350,7 @@ func TestNonblockingCodecConnWriteNext(t *testing.T) { src := NewByteBuffer() dst := NewByteBuffer() - codecConn, err := NewNonblockingCodecConn[TestItem, TestItem]( + codecConn, err := NewCodecConn[TestItem, TestItem]( conn, &TestCodec{}, src, dst) if err != nil { t.Fatal(err) diff --git a/stress_test/codec/nonblocking.go b/stress_test/codec/nonblocking.go index 74f9a098..727c9e81 100644 --- a/stress_test/codec/nonblocking.go +++ b/stress_test/codec/nonblocking.go @@ -31,7 +31,7 @@ func main() { } defer conn.Close() log.Print("client connected, starting to read") - cc, err := sonic.NewNonblockingCodecConn[[]byte, []byte](conn, codec, src, + cc, err := sonic.NewCodecConn[[]byte, []byte](conn, codec, src, dst) if err != nil { panic(err) From f9370ef159909dac48562ae7760731883c1fa353 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 23 Sep 2024 12:56:18 +0300 Subject: [PATCH 04/35] Cleanup comments for constructs used by codec.go --- codec.go | 29 +++++------------- codec_test.go | 2 +- conn.go | 10 +++--- definitions.go | 83 ++++++++++++-------------------------------------- file.go | 25 +++++++-------- 5 files changed, 45 insertions(+), 104 deletions(-) diff --git a/codec.go b/codec.go index 1134a861..7f5c4420 100644 --- a/codec.go +++ b/codec.go @@ -6,39 +6,28 @@ import ( ) type Encoder[Item any] interface { - // Encode encodes the given item into the `dst` byte stream. + // Encode the given `Item` into the given buffer. // // Implementations should: - // - Commit the bytes into the read area of `dst`. - // - ensure `dst` is big enough to hold the serialized item by - // calling dst.Reserve(...) + // - Commit() the bytes into the given buffer if the encoding is successful. + // - Ensure the given buffer is big enough to hold the serialized `Item`s by calling `Reserve(...)`. Encode(item Item, dst *ByteBuffer) error } type Decoder[Item any] interface { - // Decode decodes the given stream into an `Item`. - // - // An implementation of Codec takes a byte stream that has already - // been buffered in `src` and decodes the data into a stream of - // `Item` objects. - // - // Implementations should return an empty Item and ErrNeedMore if - // there are not enough bytes to decode into an Item. + // Decode the next `Item`, if any, from the provided buffer. If there are not enough bytes to decode an `Item`, + // implementations should return an empty `Item` along with `ErrNeedMore`. `CodecConn` will then know to read more + // bytes before calling `Decode(...)` again. Decode(src *ByteBuffer) (Item, error) } -// Codec defines a generic interface through which one can encode/decode -// a raw stream of bytes. -// -// Implementations are optionally able to track their state which enables -// writing both stateful and stateless parsers. +// Codec groups together and Encoder and a Decoder for a CodecConn. type Codec[Enc, Dec any] interface { Encoder[Enc] Decoder[Dec] } -// CodecConn handles the decoding/encoding of bytes funneled through a -// provided blocking file descriptor. +// CodecConn reads/writes `Item`s through the provided `Codec`. For an example, see `codec/frame.go`. type CodecConn[Enc, Dec any] struct { stream Stream codec Codec[Enc, Dec] @@ -54,8 +43,6 @@ func NewCodecConn[Enc, Dec any]( codec Codec[Enc, Dec], src, dst *ByteBuffer, ) (*CodecConn[Enc, Dec], error) { - // Works on both blocking and nonblocking fds. - c := &CodecConn[Enc, Dec]{ stream: stream, codec: codec, diff --git a/codec_test.go b/codec_test.go index d82987a5..a4df042a 100644 --- a/codec_test.go +++ b/codec_test.go @@ -23,7 +23,7 @@ func (t *TestCodec) Encode(item TestItem, dst *ByteBuffer) error { n, err := dst.Write(item.V[:]) dst.Commit(n) if err != nil { - dst.Consume(n) // TODO not really happy about this (same for websocket) + dst.Consume(n) return err } return nil diff --git a/conn.go b/conn.go index 2290f4db..8749b669 100644 --- a/conn.go +++ b/conn.go @@ -1,7 +1,6 @@ package sonic import ( - "fmt" "net" "time" @@ -19,7 +18,7 @@ type conn struct { remoteAddr net.Addr } -// Dial establishes a stream based connection to the specified address. +// Dial establishes a stream based connection to the specified address. It is similar to `net.Dial`. // // Data can be sent or received only from the specified address for all networks: tcp, udp and unix domain sockets. func Dial( @@ -30,6 +29,7 @@ func Dial( return DialTimeout(ioc, network, addr, 10*time.Second, opts...) } +// `DialTimeout` is like `Dial` but with a timeout. func DialTimeout( ioc *IO, network, addr string, timeout time.Duration, @@ -64,13 +64,13 @@ func (c *conn) RemoteAddr() net.Addr { } func (c *conn) SetDeadline(t time.Time) error { - return fmt.Errorf("not supported") + panic("not implemented") } func (c *conn) SetReadDeadline(t time.Time) error { - return fmt.Errorf("not supported") + panic("not implemented") } func (c *conn) SetWriteDeadline(t time.Time) error { - return fmt.Errorf("not supported") + panic("not implemented") } func (c *conn) RawFd() int { diff --git a/definitions.go b/definitions.go index f9e0f9c6..ca851318 100644 --- a/definitions.go +++ b/definitions.go @@ -6,60 +6,47 @@ import ( ) const ( - // MaxCallbackDispatch is the maximum number of callbacks which can be - // placed onto the stack for immediate invocation. + // MaxCallbackDispatch is the maximum number of callbacks that can exist on a stack-frame when asynchronous + // operations can be completed immediately. MaxCallbackDispatch int = 32 ) -// TODO this is quite a mess right now. Properly define what a Conn, Stream, -// PacketConn is and what should the async adapter return, as more than TCP -// streams can be "async-adapted". - type AsyncCallback func(error, int) type AcceptCallback func(error, Conn) type AcceptPacketCallback func(error, PacketConn) // AsyncReader is the interface that wraps the AsyncRead and AsyncReadAll methods. type AsyncReader interface { - // AsyncRead reads up to len(b) bytes into b asynchronously. - // - // This call should not block. The provided completion handler is called - // in the following cases: - // - a read of n bytes completes - // - an error occurs + // AsyncRead reads up to `len(b)` bytes into `b` asynchronously. // - // Callers should always process the n > 0 bytes returned before considering - // the error err. + // This call should not block. The provided completion handler is called in the following cases: + // - a read of up to `n` bytes completes + // - an error occurs // - // Implementation of AsyncRead are discouraged from invoking the handler - // with a zero byte count with a nil error, except when len(b) == 0. - // Callers should treat a return of 0 and nil as indicating that nothing - // happened; in particular, it does not indicate EOF. + // Callers should always process the returned bytes before considering the error err. Implementations of AsyncRead + // are discouraged from invoking the handler with 0 bytes and a nil error. // - // Implementations must not retain b. Ownership of b must be retained by the caller, - // which must guarantee that it remains valid until the handler is called. + // Ownership of the byte slice must be retained by callers, which must guarantee that it remains valid until the + // callback is invoked. AsyncRead(b []byte, cb AsyncCallback) - // AsyncReadAll reads len(b) bytes into b asynchronously. + // AsyncReadAll reads exactly `len(b)` bytes into b asynchronously. AsyncReadAll(b []byte, cb AsyncCallback) } -// AsyncWriter is the interface that wraps the AsyncRead and AsyncReadAll methods. +// AsyncWriter is the interface that wraps the AsyncWrite and AsyncWriteAll methods. type AsyncWriter interface { - // AsyncWrite writes up to len(b) bytes into the underlying data stream asynchronously. + // AsyncWrite writes up to `len(b)` bytes from `b` asynchronously. // - // This call should not block. The provided completion handler is called in the following cases: - // - a write of n bytes completes + // This call does not block. The provided completion handler is called in the following cases: + // - a write of up to `n` bytes completes // - an error occurs // - // AsyncWrite must provide a non-nil error if it writes n < len(b) bytes. - // - // Implementations must not retain b. Ownership of b must be retained by the caller, - // which must guarantee that it remains valid until the handler is called. - // AsyncWrite must not modify b, even temporarily. + // Ownership of the byte slice must be retained by callers, which must guarantee that it remains valid until the + // callback is invoked. This function does not modify the given byte slice, even temporarily. AsyncWrite(b []byte, cb AsyncCallback) - // AsyncWriteAll writes len(b) bytes into the underlying data stream asynchronously. + // AsyncWriteAll writes exactly `len(b)` bytes into the underlying data stream asynchronously. AsyncWriteAll(b []byte, cb AsyncCallback) } @@ -95,9 +82,8 @@ type AsyncCanceller interface { Cancel() } -// Stream represents a full-duplex connection between two processes, -// where data represented as bytes may be received reliably in the same order -// they were written. +// Stream represents a full-duplex connection between two processes, where data represented as bytes may be received +// reliably in the same order they were written. type Stream interface { RawFd() int @@ -180,32 +166,3 @@ type Listener interface { RawFd() int } - -// UDPMulticastClient defines a UDP multicast client that can read data from one or multiple multicast groups, -// optionally filtering packets on the source IP. -type UDPMulticastClient interface { - Join(multicastAddr *net.UDPAddr) error - JoinSource(multicastAddr, sourceAddr *net.UDPAddr) error - - Leave(multicastAddr *net.UDPAddr) error - LeaveSource(multicastAddr, sourceAddr *net.UDPAddr) error - - BlockSource(multicastAddr, sourceAddr *net.UDPAddr) error - UnblockSource(multicastAddr, sourceAddr *net.UDPAddr) error - - // ReadFrom and AsyncReadFrom read a partial or complete datagram into the provided buffer. When the datagram - // is smaller than the passed buffer, only that much data is returned; when it is bigger, the packet is truncated - // to fit into the buffer. The truncated bytes are lost. - // - // It is the responsibility of the caller to ensure the passed buffer can hold an entire datagram. A rule of thumb - // is to have the buffer size equal to the network interface's MTU. - ReadFrom([]byte) (n int, addr net.Addr, err error) - AsyncReadFrom([]byte, AsyncReadCallbackPacket) - - RawFd() int - Interface() *net.Interface - LocalAddr() *net.UDPAddr - - Close() error - Closed() bool -} diff --git a/file.go b/file.go index 36543050..87812c02 100644 --- a/file.go +++ b/file.go @@ -17,11 +17,14 @@ type file struct { slot internal.Slot closed uint32 - // dispatched tracks how callback are currently on the stack. - // If the fd has a lot of data to read/write and the caller nests - // read/write calls then we might overflow the stack. In order to not do that - // we limit the number of dispatched reads to MaxCallbackDispatch. - // If we hit that limit, we schedule an async read/write which results in clearing the stack. + // Tracks how many callbacks are on the current stack-frame. This prevents stack-overflows in cases where + // asynchronous operations can be completed immediately. + // + // For example, an asynchronous read might be completed immediately. In that case, the callback is invoked which in + // turn might call `AsyncRead` again. That asynchronous read might again be completed immediately and so on. In this + // case, all subsequent read callbacks are placed on the same stack-frame. We count these callbacks with + // `dispatched`. If we hit `MaxCallbackDispatch`, then the stack-frame is popped - asynchronous reads are scheduled + // to be completed on the next poll cycle, even if they can be completed immediately. dispatched int } @@ -178,21 +181,15 @@ func (f *file) asyncWriteNow(b []byte, writtenBytes int, writeAll bool, cb Async n, err := f.Write(b[writtenBytes:]) writtenBytes += n - // f is a nonblocking fd so if err == ErrWouldBlock - // then we need to schedule an async write. - if err == nil && !(writeAll && writtenBytes != len(b)) { - // If writeAll == true then wrote fully without errors. - // If writeAll == false then wrote some without errors. - // We are done. + // If writeAll == true then we wrote fully without errors. + // If writeAll == false then we wrote some without errors. cb(nil, writtenBytes) return } - // handles (writeAll == false) and (writeAll == true && writtenBytes != len(b)). + // Handles (writeAll == false) and (writeAll == true && writtenBytes != len(b)). if err == sonicerrors.ErrWouldBlock { - // If writeAll == true then wrote some without errors. - // We schedule an asynchronous write. f.scheduleWrite(b, writtenBytes, writeAll, cb) } else { cb(err, writtenBytes) From a6e1de665536c88b8306bcdb16ee4859da3b361e Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Sun, 27 Oct 2024 16:20:54 +0100 Subject: [PATCH 05/35] [websocket] Cleanup the WebSocket codec, no logic changes --- codec/websocket/frame.go | 209 ++++++++++--------------- codec/websocket/frame_codec.go | 4 +- codec/websocket/frame_codec_test.go | 10 +- codec/websocket/frame_test.go | 10 +- codec/websocket/rfc6455.go | 234 ++++++++++++---------------- codec/websocket/stream.go | 42 ++--- codec/websocket/stream_test.go | 40 ++--- codec/websocket/test_main.go | 4 +- examples/websocket/server.go | 5 - 9 files changed, 236 insertions(+), 322 deletions(-) delete mode 100644 examples/websocket/server.go diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index 38bdc98c..3cb0f457 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -2,7 +2,6 @@ package websocket import ( "encoding/binary" - "fmt" "io" "sync" @@ -17,6 +16,11 @@ type Frame struct { payload []byte } +var ( + _ io.ReaderFrom = &Frame{} + _ io.WriterTo = &Frame{} +) + func NewFrame() *Frame { f := &Frame{ header: make([]byte, 10), @@ -42,17 +46,7 @@ func (f *Frame) ExtraHeaderLen() (n int) { return } -// PayloadLenType returns the payload length as indicated in the fixed -// size header. It is always less than 127 as per the WebSocket protocol. -// The actual payload size can be retrieved by calling PayloadLen. -func (f *Frame) PayloadLenType() int { - return int(f.header[1] & 127) -} - -// PayloadLen returns the actual payload length which can either be -// in the header if the length type is 125 or less, in the next 2 bytes if the -// length type is 126 or in the next 8 bytes if the length type is 127. -func (f *Frame) PayloadLen() int { +func (f *Frame) PayloadLength() int { length := uint64(f.header[1] & 127) switch length { case 126: @@ -63,7 +57,7 @@ func (f *Frame) PayloadLen() int { return int(length) } -func (f *Frame) SetPayloadLen() (bytes int) { +func (f *Frame) SetPayloadLength() (bytes int) { n := len(f.payload) switch { @@ -83,128 +77,49 @@ func (f *Frame) SetPayloadLen() (bytes int) { return } -func (f *Frame) ReadFrom(r io.Reader) (nt int64, err error) { - var n int - n, err = io.ReadFull(r, f.header[:2]) - nt += int64(n) - - if err == nil { - m := f.ExtraHeaderLen() - if m > 0 { - n, err = io.ReadFull(r, f.header[2:m+2]) - nt += int64(n) - } - - if err == nil && f.IsMasked() { - n, err = io.ReadFull(r, f.mask[:4]) - nt += int64(n) - } - - if err == nil { - if pn := f.PayloadLen(); pn > 0 { - if pn > MaxMessageSize { - err = ErrPayloadTooBig - } else { - f.payload = util.ExtendSlice(f.payload, pn) - n, err = io.ReadFull(r, f.payload[:pn]) - nt += int64(n) - } - } else if pn == 0 { - f.payload = f.payload[:0] - } - } - } - - return -} - -func (f *Frame) WriteTo(w io.Writer) (n int64, err error) { - var nn int - - nn, err = w.Write(f.header[:2+f.SetPayloadLen()]) - n += int64(nn) - - if err == nil { - if f.IsMasked() { - nn, err = w.Write(f.mask[:]) - n += int64(nn) - } - - if err == nil && f.PayloadLen() > 0 { - nn, err = w.Write(f.payload[:f.PayloadLen()]) - n += int64(nn) - } - } - - return -} - -func (f *Frame) IsFin() bool { - return f.header[0]&finBit != 0 +// An unfragmented message consists of a single frame with the FIN bit set and an opcode other than 0. +// +// A fragmented message consists of a single frame with the FIN bit clear and an opcode other than 0, followed by zero +// or more frames with the FIN bit clear and the opcode set to 0, and terminated by a single frame with the FIN bit set +// and an opcode of 0. +func (f *Frame) IsFIN() bool { + return f.header[0]&bitFIN != 0 } func (f *Frame) IsRSV1() bool { - return f.header[0]&rsv1Bit != 0 + return f.header[0]&bitRSV1 != 0 } func (f *Frame) IsRSV2() bool { - return f.header[0]&rsv2Bit != 0 + return f.header[0]&bitRSV2 != 0 } func (f *Frame) IsRSV3() bool { - return f.header[0]&rsv3Bit != 0 + return f.header[0]&bitRSV3 != 0 } func (f *Frame) Opcode() Opcode { return Opcode(f.header[0] & 15) } -func (f *Frame) IsContinuation() bool { - return f.Opcode() == OpcodeContinuation -} - -func (f *Frame) IsText() bool { - return f.Opcode() == OpcodeText -} - -func (f *Frame) IsBinary() bool { - return f.Opcode() == OpcodeBinary -} - -func (f *Frame) IsClose() bool { - return f.Opcode() == OpcodeClose -} - -func (f *Frame) IsPing() bool { - return f.Opcode() == OpcodePing -} - -func (f *Frame) IsPong() bool { - return f.Opcode() == OpcodePong -} - -func (f *Frame) IsControl() bool { - return f.IsClose() || f.IsPing() || f.IsPong() -} - func (f *Frame) IsMasked() bool { - return f.header[1]&maskBit != 0 + return f.header[1]&bitIsMasked != 0 } -func (f *Frame) SetFin() { - f.header[0] |= finBit +func (f *Frame) SetFIN() { + f.header[0] |= bitFIN } func (f *Frame) SetRSV1() { - f.header[0] |= rsv1Bit + f.header[0] |= bitRSV1 } func (f *Frame) SetRSV2() { - f.header[0] |= rsv2Bit + f.header[0] |= bitRSV2 } func (f *Frame) SetRSV3() { - f.header[0] |= rsv3Bit + f.header[0] |= bitRSV3 } func (f *Frame) SetOpcode(c Opcode) { @@ -250,7 +165,7 @@ func (f *Frame) Payload() []byte { } func (f *Frame) Mask() { - f.header[1] |= maskBit + f.header[1] |= bitIsMasked GenMask(f.mask[:]) if len(f.payload) > 0 { Mask(f.mask[:], f.payload) @@ -262,25 +177,63 @@ func (f *Frame) Unmask() { key := f.MaskKey() Mask(key, f.payload) } - f.header[1] ^= maskBit -} - -func (f *Frame) String() string { - return fmt.Sprintf(` -FIN: %v -RSV1: %v -RSV2: %v -RSV3: %v -OPCODE: %d -MASK: %v -LENGTH: %d -MASK-KEY: %v -PAYLOAD: %v`, - - f.IsFin(), f.IsRSV1(), f.IsRSV2(), f.IsRSV3(), - f.Opcode(), f.IsMasked(), f.PayloadLen(), - f.MaskKey(), f.Payload(), - ) + f.header[1] ^= bitIsMasked +} + +func (f *Frame) ReadFrom(r io.Reader) (nt int64, err error) { + var n int + n, err = io.ReadFull(r, f.header[:2]) + nt += int64(n) + + if err == nil { + m := f.ExtraHeaderLen() + if m > 0 { + n, err = io.ReadFull(r, f.header[2:m+2]) + nt += int64(n) + } + + if err == nil && f.IsMasked() { + n, err = io.ReadFull(r, f.mask[:4]) + nt += int64(n) + } + + if err == nil { + if pn := f.PayloadLength(); pn > 0 { + if pn > MaxMessageSize { + err = ErrPayloadTooBig + } else { + f.payload = util.ExtendSlice(f.payload, pn) + n, err = io.ReadFull(r, f.payload[:pn]) + nt += int64(n) + } + } else if pn == 0 { + f.payload = f.payload[:0] + } + } + } + + return +} + +func (f *Frame) WriteTo(w io.Writer) (n int64, err error) { + var nn int + + nn, err = w.Write(f.header[:2+f.SetPayloadLength()]) + n += int64(nn) + + if err == nil { + if f.IsMasked() { + nn, err = w.Write(f.mask[:]) + n += int64(nn) + } + + if err == nil && f.PayloadLength() > 0 { + nn, err = w.Write(f.payload[:f.PayloadLength()]) + n += int64(nn) + } + } + + return } var framePool = sync.Pool{ diff --git a/codec/websocket/frame_codec.go b/codec/websocket/frame_codec.go index af72bb41..d3c240bc 100644 --- a/codec/websocket/frame_codec.go +++ b/codec/websocket/frame_codec.go @@ -83,7 +83,7 @@ func (c *FrameCodec) Decode(src *sonic.ByteBuffer) (*Frame, error) { } // check payload length - npayload := c.decodeFrame.PayloadLen() + npayload := c.decodeFrame.PayloadLength() if npayload > MaxMessageSize { return nil, ErrPayloadOverMaxSize } @@ -109,7 +109,7 @@ func (c *FrameCodec) Decode(src *sonic.ByteBuffer) (*Frame, error) { func (c *FrameCodec) Encode(fr *Frame, dst *sonic.ByteBuffer) error { // Make sure there is enough space in the buffer to hold the serialized // frame. - dst.Reserve(fr.PayloadLen()) + dst.Reserve(fr.PayloadLength()) n, err := fr.WriteTo(dst) dst.Commit(int(n)) diff --git a/codec/websocket/frame_codec_test.go b/codec/websocket/frame_codec_test.go index e6773d1a..b5a59e66 100644 --- a/codec/websocket/frame_codec_test.go +++ b/codec/websocket/frame_codec_test.go @@ -44,7 +44,7 @@ func TestDecodeExactlyOneFrame(t *testing.T) { t.Fatal("should have gotten a frame") } - if !(f.IsFin() && f.IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { + if !(f.IsFIN() && f.Opcode().IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { t.Fatal("corrupt frame") } @@ -72,7 +72,7 @@ func TestDecodeOneAndShortFrame(t *testing.T) { t.Fatal("should have gotten a frame") } - if !(f.IsFin() && f.IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { + if !(f.IsFIN() && f.Opcode().IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { t.Fatal("corrupt frame") } @@ -106,7 +106,7 @@ func TestDecodeTwoFrames(t *testing.T) { if f == nil { t.Fatal("should have gotten a frame") } - if !(f.IsFin() && f.IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { + if !(f.IsFIN() && f.Opcode().IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { t.Fatal("corrupt frame") } if !codec.decodeReset { @@ -127,8 +127,8 @@ func TestDecodeTwoFrames(t *testing.T) { if f == nil { t.Fatal("should have gotten a frame") } - if !(f.IsFin() && - f.IsText() && + if !(f.IsFIN() && + f.Opcode().IsText() && bytes.Equal(f.Payload(), []byte{0x01, 0x02, 0x03, 0x04, 0x05})) { t.Fatal("corrupt frame") } diff --git a/codec/websocket/frame_test.go b/codec/websocket/frame_test.go index ac0cd315..b307a161 100644 --- a/codec/websocket/frame_test.go +++ b/codec/websocket/frame_test.go @@ -66,7 +66,7 @@ func TestWriteFrame(t *testing.T) { f := AcquireFrame() defer ReleaseFrame(f) - f.SetFin() + f.SetFIN() f.SetPayload(payload) f.SetText() @@ -111,7 +111,7 @@ func TestSameFrameWriteRead(t *testing.T) { if n != 7 { t.Fatalf("frame is corrupt") } - if !(f.IsFin() && f.IsText() && f.PayloadLen() == 5 && bytes.Equal(f.Payload(), payload)) { + if !(f.IsFIN() && f.Opcode().IsText() && f.PayloadLength() == 5 && bytes.Equal(f.Payload(), payload)) { t.Fatalf("invalid frame") } @@ -138,15 +138,15 @@ func TestSameFrameWriteRead(t *testing.T) { } func checkFrame(t *testing.T, f *Frame, c, fin bool, payload []byte) { - if c && !f.IsContinuation() { + if c && !f.Opcode().IsContinuation() { t.Fatal("expected continuation") } - if fin && !f.IsFin() { + if fin && !f.IsFIN() { t.Fatal("expected FIN") } - if given, expected := len(payload), f.PayloadLen(); given != expected { + if given, expected := len(payload), f.PayloadLength(); given != expected { t.Fatalf("invalid payload length; given=%d expected=%d", given, expected) } diff --git a/codec/websocket/rfc6455.go b/codec/websocket/rfc6455.go index 05546ac7..16898451 100644 --- a/codec/websocket/rfc6455.go +++ b/codec/websocket/rfc6455.go @@ -1,3 +1,4 @@ +// Based on https://datatracker.ietf.org/doc/html/rfc6455. package websocket import ( @@ -6,126 +7,157 @@ import ( "strings" ) -// NOTE: A fragmented message consists of a single frame with the FIN bit clear -// and an opcode other than 0, followed by zero or more frames with the FIN bit -// clear and the opcode set to 0, and terminated by a single frame with the FIN -// bit set and an opcode of 0. +// --------------------------------------------------- +// Framing ------------------------------------------- +// --------------------------------------------------- -// GUID is used when constructing the Sec-WebSocket-Accept key based on -// Sec-WebSocket-Key. +const ( + bitFIN = byte(1 << 7) + bitRSV1 = byte(1 << 6) + bitRSV2 = byte(1 << 5) + bitRSV3 = byte(1 << 4) + bitIsMasked = byte(1 << 7) +) + +const MaxControlFramePayloadLength = 125 + +type Opcode uint8 + +const ( + OpcodeContinuation Opcode = 0 + OpcodeText Opcode = 1 + OpcodeBinary Opcode = 2 + OpcodeClose Opcode = 8 + OpcodePing Opcode = 9 + OpcodePong Opcode = 10 +) + +func (c Opcode) IsContinuation() bool { return c == OpcodeContinuation } +func (c Opcode) IsText() bool { return c == OpcodeText } +func (c Opcode) IsBinary() bool { return c == OpcodeBinary } +func (c Opcode) IsClose() bool { return c == OpcodeClose } +func (c Opcode) IsPing() bool { return c == OpcodePing } +func (c Opcode) IsPong() bool { return c == OpcodePong } + +func (c Opcode) IsReserved() bool { + return c != OpcodeContinuation && + c != OpcodeText && + c != OpcodeBinary && + c != OpcodeClose && + c != OpcodePing && + c != OpcodePong +} + +func (c Opcode) IsControl() bool { + return c.IsPing() || c.IsPong() || c.IsClose() +} + +func (c Opcode) String() string { + switch c { + case OpcodeContinuation: + return "continuation" + case OpcodeText: + return "text" + case OpcodeBinary: + return "binary" + case OpcodeClose: + return "close" + case OpcodePing: + return "ping" + case OpcodePong: + return "pong" + default: + return "unknown" + } +} + +// --------------------------------------------------- +// Handshake ----------------------------------------- +// --------------------------------------------------- + +// Used when constructing the server's Sec-WebSocket-Accept key based on the client's Sec-WebSocket-Key. var GUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") -// IsUpgradeReq returns true if the HTTP request is a valid WebSocket upgrade. func IsUpgradeReq(req *http.Request) bool { return strings.EqualFold(req.Header.Get("Upgrade"), "websocket") } -// IsUpgradeReq returns true if the HTTP response is a valid WebSocket upgrade. func IsUpgradeRes(res *http.Response) bool { return res.StatusCode == 101 && strings.EqualFold(res.Header.Get("Upgrade"), "websocket") } -const ( - finBit = byte(1 << 7) - rsv1Bit = byte(1 << 6) - rsv2Bit = byte(1 << 5) - rsv3Bit = byte(1 << 4) - maskBit = byte(1 << 7) -) - -// The max size of the ping/pong control frame payload. -const MaxControlFramePayloadSize = 125 +// --------------------------------------------------- +// Closing ------------------------------------------- +// --------------------------------------------------- -// The type representing the reason string in a close frame. -type ReasonString [123]byte - -// Close status codes. -// -// These codes accompany close frames. +// Close status codes that accompany close frames. type CloseCode uint16 const ( - // CloseNormal signifies normal closure; the connection successfully - // completed whatever purpose for which it was created. + // CloseNormal signifies normal closure; the connection successfully completed whatever purpose for which it was + // created. CloseNormal CloseCode = 1000 - // GoingaAway means endpoint is going away, either because of a - // server failure or because the browser is navigating away from - // the page that opened the connection. + // GoingAway means endpoint is going away, either because of a server failure or because the browser is navigating + // away from the page that opened the connection. CloseGoingAway CloseCode = 1001 - // CloseProtocolError means the endpoint is terminating the connection - // due to a protocol error. + // CloseProtocolError means the endpoint is terminating the connection due to a protocol error. CloseProtocolError CloseCode = 1002 - // CloseUnknownData means the connection is being terminated because - // the endpoint received data of a type it cannot accept (for example, - // a text-only endpoint received binary data). + // CloseUnknownData means the connection is being terminated because the endpoint received data of a type it cannot + // accept (for example, a text-only endpoint received binary data). CloseUnknownData CloseCode = 1003 - // CloseBadPayload means the endpoint is terminating the connection because - // a message was received that contained inconsistent data - // (e.g., non-UTF-8 data within a text message). + // CloseBadPayload means the endpoint is terminating the connection because a message was received that contained + // inconsistent data (e.g., non-UTF-8 data within a text message). CloseBadPayload CloseCode = 1007 - // ClosePolicyError means the endpoint is terminating the connection because - // it received a message that violates its policy. This is a generic status - // code, used when codes 1003 and 1009 are not suitable. + // ClosePolicyError means the endpoint is terminating the connection because it received a message that violates its + // policy. This is a generic status code, used when codes 1003 and 1009 are not suitable. ClosePolicyError CloseCode = 1008 - // CloseTooBig means the endpoint is terminating the connection because a - // data frame was received that is too large. + // CloseTooBig means the endpoint is terminating the connection because a data frame was received that is too large. CloseTooBig CloseCode = 1009 - // CloseNeedsExtension means the client is terminating the connection - // because it expected the server to negotiate one or more extensions, but - // the server didn't. + // CloseNeedsExtension means the client is terminating the connection because it expected the server to negotiate + // one or more extensions, but the server didn't. CloseNeedsExtension CloseCode = 1010 - // CloseInternalError means the server is terminating the connection because - // it encountered an unexpected condition that prevented it from fulfilling - // the request. + // CloseInternalError means the server is terminating the connection because it encountered an unexpected condition + // that prevented it from fulfilling the request. CloseInternalError CloseCode = 1011 - // CloseServiceRestart means the server is terminating the connection - // because it is restarting. + // CloseServiceRestart means the server is terminating the connection because it is restarting. CloseServiceRestart CloseCode = 1012 - // CloseTryAgainLater means the server is terminating the connection due to - // a temporary condition, e.g. it is overloaded and is casting off some of - // its clients. + // CloseTryAgainLater means the server is terminating the connection due to a temporary condition, e.g. it is + // overloaded and is casting off some of its clients. CloseTryAgainLater CloseCode = 1013 // ------------------------------------- // The following are illegal on the wire // ------------------------------------- - // CloseNone is used internally to mean "no error" - // This code is reserved and may not be sent. + // CloseNone is used internally to mean "no error" This code is reserved and may not be sent. CloseNone CloseCode = 0 - // CloseNoStatus means no status code was provided in the close frame sent - // by the peer, even though one was expected. - // This code is reserved for internal use and may not be sent in-between - // peers. + // CloseNoStatus means no status code was provided in the close frame sent by the peer, even though one was + // expected. This code is reserved for internal use and may not be sent in-between peers. CloseNoStatus CloseCode = 1005 - // CloseAbnormal means the connection was closed without receiving a close - // frame. - // This code is reserved and may not be sent. + // CloseAbnormal means the connection was closed without receiving a close frame. This code is reserved and may not + // be sent. CloseAbnormal CloseCode = 1006 - // CloseReserved1 is reserved for future use by the WebSocket standard. - // This code is reserved and may not be sent. + // CloseReserved1 is reserved for future use by the WebSocket standard. This code is reserved and may not be sent. CloseReserved1 CloseCode = 1004 - // CloseReserved2 is reserved for future use by the WebSocket standard. - // This code is reserved and may not be sent. + // CloseReserved2 is reserved for future use by the WebSocket standard. This code is reserved and may not be sent. CloseReserved2 CloseCode = 1014 - // CloseReserved3 is reserved for future use by the WebSocket standard. - // This code is reserved and may not be sent. + // CloseReserved3 is reserved for future use by the WebSocket standard. This code is reserved and may not be sent. CloseReserved3 CloseCode = 1015 ) @@ -138,69 +170,3 @@ func EncodeCloseCode(cc CloseCode) []byte { func DecodeCloseCode(b []byte) CloseCode { return CloseCode(binary.BigEndian.Uint16(b[:2])) } - -type Opcode uint8 - -// No `iota` here for clarity. -const ( - OpcodeContinuation Opcode = 0x00 - OpcodeText Opcode = 0x01 - OpcodeBinary Opcode = 0x02 - OpcodeRsv3 Opcode = 0x03 - OpcodeRsv4 Opcode = 0x04 - OpcodeRsv5 Opcode = 0x05 - OpcodeRsv6 Opcode = 0x06 - OpcodeRsv7 Opcode = 0x07 - OpcodeClose Opcode = 0x08 - OpcodePing Opcode = 0x09 - OpcodePong Opcode = 0x0A - OpcodeCrsvb Opcode = 0x0B - OpcodeCrsvc Opcode = 0x0C - OpcodeCrsvd Opcode = 0x0D - OpcodeCrsve Opcode = 0x0E - OpcodeCrsvf Opcode = 0x0F -) - -func IsReserved(op Opcode) bool { - return (op >= OpcodeRsv3 && op <= OpcodeRsv7) || - (op >= OpcodeCrsvb && op <= OpcodeCrsvf) -} - -func (c Opcode) String() string { - switch c { - case OpcodeContinuation: - return "continuation" - case OpcodeText: - return "text" - case OpcodeBinary: - return "binary" - case OpcodeRsv3: - return "rsv3" - case OpcodeRsv4: - return "rsv4" - case OpcodeRsv5: - return "rsv5" - case OpcodeRsv6: - return "rsv6" - case OpcodeRsv7: - return "rsv7" - case OpcodeClose: - return "close" - case OpcodePing: - return "ping" - case OpcodePong: - return "pong" - case OpcodeCrsvb: - return "crsvb" - case OpcodeCrsvc: - return "crsvc" - case OpcodeCrsvd: - return "crsvd" - case OpcodeCrsve: - return "crsve" - case OpcodeCrsvf: - return "crsvf" - default: - return "unknown" - } -} diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index b9691f34..4b789b55 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -241,7 +241,7 @@ func (s *WebsocketStream) NextMessage( return mt, readBytes, err } - if f.IsControl() { + if f.Opcode().IsControl() { if s.ccb != nil { s.ccb(MessageType(f.Opcode()), f.payload) } @@ -253,7 +253,7 @@ func (s *WebsocketStream) NextMessage( n := copy(b[readBytes:], f.Payload()) readBytes += n - if readBytes > MaxMessageSize || n != f.PayloadLen() { + if readBytes > MaxMessageSize || n != f.PayloadLength() { err = ErrMessageTooBig _ = s.Close(CloseGoingAway, "payload too big") break @@ -262,19 +262,19 @@ func (s *WebsocketStream) NextMessage( // verify continuation if !continuation { // this is the first frame of the series - continuation = !f.IsFin() //nolint:ineffassign - if f.IsContinuation() { + continuation = !f.IsFIN() //nolint:ineffassign + if f.Opcode().IsContinuation() { err = ErrUnexpectedContinuation } } else { // we are past the first frame of the series - continuation = !f.IsFin() //nolint:ineffassign - if !f.IsContinuation() { + continuation = !f.IsFIN() //nolint:ineffassign + if !f.Opcode().IsContinuation() { err = ErrExpectedContinuation } } - continuation = !f.IsFin() + continuation = !f.IsFIN() if err != nil || !continuation { break @@ -300,7 +300,7 @@ func (s *WebsocketStream) asyncNextMessage( if err != nil { cb(err, readBytes, mt) } else { - if f.IsControl() { + if f.Opcode().IsControl() { if s.ccb != nil { s.ccb(MessageType(f.Opcode()), f.payload) } @@ -314,7 +314,7 @@ func (s *WebsocketStream) asyncNextMessage( n := copy(b[readBytes:], f.Payload()) readBytes += n - if readBytes > MaxMessageSize || n != f.PayloadLen() { + if readBytes > MaxMessageSize || n != f.PayloadLength() { err = ErrMessageTooBig s.AsyncClose( CloseGoingAway, @@ -328,14 +328,14 @@ func (s *WebsocketStream) asyncNextMessage( // verify continuation if !continuation { // this is the first frame of the series - continuation = !f.IsFin() - if f.IsContinuation() { + continuation = !f.IsFIN() + if f.Opcode().IsContinuation() { err = ErrUnexpectedContinuation } } else { // we are past the first frame of the series - continuation = !f.IsFin() - if !f.IsContinuation() { + continuation = !f.IsFIN() + if !f.Opcode().IsContinuation() { err = ErrExpectedContinuation } } @@ -354,7 +354,7 @@ func (s *WebsocketStream) handleFrame(f *Frame) (err error) { err = s.verifyFrame(f) if err == nil { - if f.IsControl() { + if f.Opcode().IsControl() { err = s.handleControlFrame(f) } else { err = s.handleDataFrame(f) @@ -386,11 +386,11 @@ func (s *WebsocketStream) verifyFrame(f *Frame) error { } func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { - if !f.IsFin() { + if !f.IsFIN() { return ErrInvalidControlFrame } - if f.PayloadLenType() > MaxControlFramePayloadSize { + if f.PayloadLength() > MaxControlFramePayloadLength { return ErrControlFrameTooBig } @@ -398,7 +398,7 @@ func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { case OpcodePing: if s.state == StateActive { pongFrame := AcquireFrame() - pongFrame.SetFin() + pongFrame.SetFIN() pongFrame.SetPong() pongFrame.SetPayload(f.payload) if s.role == RoleClient { @@ -430,7 +430,7 @@ func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { } func (s *WebsocketStream) handleDataFrame(f *Frame) error { - if IsReserved(f.Opcode()) { + if f.Opcode().IsReserved() { return ErrReservedOpcode } return nil @@ -443,7 +443,7 @@ func (s *WebsocketStream) Write(b []byte, mt MessageType) error { if s.state == StateActive { f := AcquireFrame() - f.SetFin() + f.SetFIN() f.SetOpcode(Opcode(mt)) f.SetPayload(b) @@ -476,7 +476,7 @@ func (s *WebsocketStream) AsyncWrite( if s.state == StateActive { f := AcquireFrame() - f.SetFin() + f.SetFIN() f.SetOpcode(Opcode(mt)) f.SetPayload(b) @@ -544,7 +544,7 @@ func (s *WebsocketStream) Close(cc CloseCode, reason string) error { func (s *WebsocketStream) prepareClose(payload []byte) { closeFrame := AcquireFrame() - closeFrame.SetFin() + closeFrame.SetFIN() closeFrame.SetClose() closeFrame.SetPayload(payload) if s.role == RoleClient { diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index 851db655..a9f364c3 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -551,7 +551,7 @@ func TestClientReadPingFrame(t *testing.T) { } reply := ws.pending[0] - if !(reply.IsPong() && reply.IsMasked()) { + if !(reply.Opcode().IsPong() && reply.IsMasked()) { t.Fatal("invalid pong reply") } @@ -604,7 +604,7 @@ func TestClientAsyncReadPingFrame(t *testing.T) { } reply := ws.pending[0] - if !(reply.IsPong() && reply.IsMasked()) { + if !(reply.Opcode().IsPong() && reply.IsMasked()) { t.Fatal("invalid pong reply") } @@ -876,7 +876,7 @@ func TestClientWriteFrame(t *testing.T) { f := AcquireFrame() defer ReleaseFrame(f) - f.SetFin() + f.SetFIN() f.SetText() f.SetPayload([]byte{1, 2, 3, 4, 5}) @@ -894,7 +894,7 @@ func TestClientWriteFrame(t *testing.T) { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsText()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsText()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } @@ -923,7 +923,7 @@ func TestClientAsyncWriteFrame(t *testing.T) { f := AcquireFrame() defer ReleaseFrame(f) - f.SetFin() + f.SetFIN() f.SetText() f.SetPayload([]byte{1, 2, 3, 4, 5}) @@ -945,7 +945,7 @@ func TestClientAsyncWriteFrame(t *testing.T) { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsText()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsText()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } @@ -991,7 +991,7 @@ func TestClientWrite(t *testing.T) { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsText()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsText()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } @@ -1032,7 +1032,7 @@ func TestClientAsyncWrite(t *testing.T) { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsText()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsText()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } @@ -1074,13 +1074,13 @@ func TestClientClose(t *testing.T) { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsClose()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsClose()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } f.Unmask() - if f.PayloadLen() != 5 { + if f.PayloadLength() != 5 { t.Fatal("wrong message in close frame") } @@ -1121,13 +1121,13 @@ func TestClientAsyncClose(t *testing.T) { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsClose()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsClose()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } f.Unmask() - if f.PayloadLen() != 5 { + if f.PayloadLength() != 5 { t.Fatal("wrong message in close frame") } @@ -1167,7 +1167,7 @@ func TestClientCloseHandshakeWeStart(t *testing.T) { assertState(t, ws, StateClosedByUs) - serverReply.SetFin() + serverReply.SetFIN() serverReply.SetClose() serverReply.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) _, err = serverReply.WriteTo(ws.src) @@ -1180,7 +1180,7 @@ func TestClientCloseHandshakeWeStart(t *testing.T) { t.Fatal(err) } - if !(reply.IsFin() && reply.IsClose()) { + if !(reply.IsFIN() && reply.Opcode().IsClose()) { t.Fatal("wrong close reply") } @@ -1218,7 +1218,7 @@ func TestClientAsyncCloseHandshakeWeStart(t *testing.T) { serverReply := AcquireFrame() defer ReleaseFrame(serverReply) - serverReply.SetFin() + serverReply.SetFIN() serverReply.SetClose() serverReply.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) _, err = serverReply.WriteTo(ws.src) @@ -1231,7 +1231,7 @@ func TestClientAsyncCloseHandshakeWeStart(t *testing.T) { t.Fatal(err) } - if !(reply.IsFin() && reply.IsClose()) { + if !(reply.IsFIN() && reply.Opcode().IsClose()) { t.Fatal("wrong close reply") } @@ -1264,7 +1264,7 @@ func TestClientCloseHandshakePeerStarts(t *testing.T) { serverClose := AcquireFrame() defer ReleaseFrame(serverClose) - serverClose.SetFin() + serverClose.SetFIN() serverClose.SetClose() serverClose.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1280,7 +1280,7 @@ func TestClientCloseHandshakePeerStarts(t *testing.T) { t.Fatal(err) } - if !recv.IsClose() { + if !recv.Opcode().IsClose() { t.Fatal("should have received close") } @@ -1311,7 +1311,7 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { serverClose := AcquireFrame() defer ReleaseFrame(serverClose) - serverClose.SetFin() + serverClose.SetFIN() serverClose.SetClose() serverClose.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1327,7 +1327,7 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { t.Fatal(err) } - if !recv.IsClose() { + if !recv.Opcode().IsClose() { t.Fatal("should have received close") } diff --git a/codec/websocket/test_main.go b/codec/websocket/test_main.go index f11ba6f9..eba0f8c9 100644 --- a/codec/websocket/test_main.go +++ b/codec/websocket/test_main.go @@ -78,7 +78,7 @@ func (s *MockServer) Write(b []byte) error { fr.SetText() fr.SetPayload(b) - fr.SetFin() + fr.SetFIN() _, err := fr.WriteTo(s.conn) return err @@ -96,7 +96,7 @@ func (s *MockServer) Read(b []byte) (n int, err error) { fr.Unmask() copy(b, fr.Payload()) - n = fr.PayloadLen() + n = fr.PayloadLength() } return n, err } diff --git a/examples/websocket/server.go b/examples/websocket/server.go deleted file mode 100644 index ea61f64b..00000000 --- a/examples/websocket/server.go +++ /dev/null @@ -1,5 +0,0 @@ -package main - -func main() { - panic("implement me") -} From 2396655f0a3e0b490c1680c4bbc5619466a92b29 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Sun, 27 Oct 2024 18:04:03 +0100 Subject: [PATCH 06/35] [websocket] Simplify the WebSocket codec The Frame is now a []byte slice and not a struct with 3 slices: header, mask, payload. SetFIN, SetOpcode etc employ the builder pattern to make a frame. --- codec/websocket/definitions.go | 6 +- codec/websocket/frame.go | 331 ++++++++++++++++------------ codec/websocket/frame_codec.go | 103 ++++----- codec/websocket/frame_codec_test.go | 10 +- codec/websocket/frame_test.go | 49 ++-- codec/websocket/rfc6455.go | 36 ++- codec/websocket/stream.go | 145 ++++++------ codec/websocket/stream_test.go | 85 +++---- codec/websocket/test_main.go | 26 +-- 9 files changed, 412 insertions(+), 379 deletions(-) diff --git a/codec/websocket/definitions.go b/codec/websocket/definitions.go index 209197a5..09123b2f 100644 --- a/codec/websocket/definitions.go +++ b/codec/websocket/definitions.go @@ -31,7 +31,7 @@ func (r Role) String() string { } } -type MessageType uint8 +type MessageType byte const ( TypeText = MessageType(OpcodeText) @@ -106,7 +106,7 @@ func (s StreamState) String() string { } type AsyncMessageHandler = func(err error, n int, mt MessageType) -type AsyncFrameHandler = func(err error, f *Frame) +type AsyncFrameHandler = func(err error, f Frame) type ControlCallback = func(mt MessageType, payload []byte) type UpgradeRequestCallback = func(req *http.Request) type UpgradeResponseCallback = func(res *http.Response) @@ -179,7 +179,7 @@ type Stream interface { // - an error occurs when reading/decoding the message bytes from the // underlying stream // - a frame is successfully read from the underlying stream - NextFrame() (*Frame, error) + NextFrame() (Frame, error) // AsyncNextMessage reads the payload of the next message into the supplied // buffer asynchronously. Message fragmentation is automatically handled by diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index 3cb0f457..ae2f99d0 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -3,78 +3,72 @@ package websocket import ( "encoding/binary" "io" - "sync" "github.com/talostrading/sonic/util" ) -var zeroBytes = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} +var zeroBytes [MaxFrameHeaderLengthInBytes]byte -type Frame struct { - header []byte - mask []byte - payload []byte +func init() { + for i := 0; i < len(zeroBytes); i++ { + zeroBytes[i] = 0 + } } +type Frame []byte + var ( _ io.ReaderFrom = &Frame{} _ io.WriterTo = &Frame{} ) -func NewFrame() *Frame { - f := &Frame{ - header: make([]byte, 10), - mask: make([]byte, 4), - payload: make([]byte, 0, 1024), - } - return f +func newFrame() Frame { + return make([]byte, MaxFrameHeaderLengthInBytes) } -func (f *Frame) Reset() { - copy(f.header, zeroBytes) - copy(f.mask, zeroBytes) - f.payload = f.payload[:0] +func (f Frame) Reset() { + copy(f, zeroBytes[:]) } -func (f *Frame) ExtraHeaderLen() (n int) { - switch f.header[1] & 127 { - case 127: - n = 8 - case 126: - n = 2 +func (f Frame) ExtendedPayloadLengthBytes() int { + if v := f[1] & bitmaskPayloadLength; v == 127 { + return 8 + } else if v == 126 { + return 2 } - return + return 0 } -func (f *Frame) PayloadLength() int { - length := uint64(f.header[1] & 127) - switch length { - case 126: - length = uint64(binary.BigEndian.Uint16(f.header[2:])) - case 127: - length = binary.BigEndian.Uint64(f.header[2:]) +func (f Frame) PayloadLength() int { + if length := f[1] & bitmaskPayloadLength; length == 127 { + return int(binary.BigEndian.Uint64(f[frameHeaderLength : frameHeaderLength+8])) + } else if length == 126 { + return int(binary.BigEndian.Uint16(f[frameHeaderLength : frameHeaderLength+2])) + } else { + return int(length) } - return int(length) } -func (f *Frame) SetPayloadLength() (bytes int) { - n := len(f.payload) +func (f Frame) clearPayloadLength() { + f[1] &= (1 << 7) +} + +func (f *Frame) SetPayloadLength(n int) *Frame { + f.clearPayloadLength() - switch { - case n > 65535: // more than two bytes needed for extra length - bytes = 8 //nolint:ineffassign - f.header[1] |= uint8(127) - binary.BigEndian.PutUint64(f.header[2:], uint64(n)) - case n > 125: - bytes = 2 //nolint:ineffassign - f.header[1] |= uint8(126) - binary.BigEndian.PutUint16(f.header[2:], uint16(n)) - return 2 - default: - bytes = 0 //nolint:ineffassign - f.header[1] |= uint8(n) + if n > (1<<16 - 1) { + // does not fit in 2 bytes, so take 8 as extended payload length + (*f)[1] |= 127 + binary.BigEndian.PutUint64((*f)[2:], uint64(n)) + } else if n > 125 { + // fits in 2 bytes as extended payload length + (*f)[1] |= 126 + binary.BigEndian.PutUint16((*f)[2:], uint16(n)) + } else { + // can be encoded in the 7 bits of the payload length, no extended payload length taken + (*f)[1] |= byte(n) } - return + return f } // An unfragmented message consists of a single frame with the FIN bit set and an opcode other than 0. @@ -82,171 +76,216 @@ func (f *Frame) SetPayloadLength() (bytes int) { // A fragmented message consists of a single frame with the FIN bit clear and an opcode other than 0, followed by zero // or more frames with the FIN bit clear and the opcode set to 0, and terminated by a single frame with the FIN bit set // and an opcode of 0. -func (f *Frame) IsFIN() bool { - return f.header[0]&bitFIN != 0 -} - -func (f *Frame) IsRSV1() bool { - return f.header[0]&bitRSV1 != 0 -} - -func (f *Frame) IsRSV2() bool { - return f.header[0]&bitRSV2 != 0 +func (f Frame) IsFIN() bool { + return f[0]&bitFIN != 0 } -func (f *Frame) IsRSV3() bool { - return f.header[0]&bitRSV3 != 0 +func (f Frame) Opcode() Opcode { + return Opcode(f[0] & bitmaskOpcode) } -func (f *Frame) Opcode() Opcode { - return Opcode(f.header[0] & 15) +func (f Frame) IsMasked() bool { + return f[1]&bitIsMasked != 0 } -func (f *Frame) IsMasked() bool { - return f.header[1]&bitIsMasked != 0 +func (f *Frame) SetIsMasked() *Frame { + (*f)[1] |= bitIsMasked + return f } -func (f *Frame) SetFIN() { - f.header[0] |= bitFIN +func (f *Frame) UnsetIsMasked() *Frame { + (*f)[1] ^= bitIsMasked + return f } -func (f *Frame) SetRSV1() { - f.header[0] |= bitRSV1 +func (f Frame) MaskBytes() int { + if f.IsMasked() { + return frameMaskLength + } + return 0 } -func (f *Frame) SetRSV2() { - f.header[0] |= bitRSV2 +func (f *Frame) SetFIN() *Frame { + (*f)[0] |= bitFIN + return f } -func (f *Frame) SetRSV3() { - f.header[0] |= bitRSV3 +func (f Frame) clearOpcode() { + f[0] &= bitmaskOpcode << 4 } -func (f *Frame) SetOpcode(c Opcode) { - c &= 15 - f.header[0] &= 15 << 4 - f.header[0] |= uint8(c) +func (f *Frame) SetOpcode(c Opcode) *Frame { + c &= Opcode(bitmaskOpcode) + f.clearOpcode() + (*f)[0] |= byte(c) + return f } -func (f *Frame) SetContinuation() { +func (f *Frame) SetContinuation() *Frame { f.SetOpcode(OpcodeContinuation) + return f } -func (f *Frame) SetText() { +func (f *Frame) SetText() *Frame { f.SetOpcode(OpcodeText) + return f } -func (f *Frame) SetBinary() { +func (f *Frame) SetBinary() *Frame { f.SetOpcode(OpcodeBinary) + return f } -func (f *Frame) SetClose() { +func (f *Frame) SetClose() *Frame { f.SetOpcode(OpcodeClose) + return f } -func (f *Frame) SetPing() { +func (f *Frame) SetPing() *Frame { f.SetOpcode(OpcodePing) + return f } -func (f *Frame) SetPong() { +func (f *Frame) SetPong() *Frame { f.SetOpcode(OpcodePong) + return f } -func (f *Frame) SetPayload(b []byte) { - f.payload = append(f.payload[:0], b...) +func (f Frame) extendedPayloadLengthStartIndex() int { + return frameHeaderLength } -func (f *Frame) MaskKey() []byte { - return f.mask[:] +func (f Frame) ExtendedPayloadLength() []byte { + if bytes := f.ExtendedPayloadLengthBytes(); bytes > 0 { + b := f[frameHeaderLength:] + return b[:bytes] + } + return nil } -func (f *Frame) Payload() []byte { - return f.payload +func (f Frame) Header() []byte { + return f[:frameHeaderLength] } -func (f *Frame) Mask() { - f.header[1] |= bitIsMasked - GenMask(f.mask[:]) - if len(f.payload) > 0 { - Mask(f.mask[:], f.payload) +func (f *Frame) maskStartIndex() int { + return frameHeaderLength + f.ExtendedPayloadLengthBytes() +} + +func (f Frame) MaskKey() []byte { + if f.IsMasked() { + mask := f[f.maskStartIndex():] + return mask[:frameMaskLength] } + return nil } -func (f *Frame) Unmask() { - if len(f.payload) > 0 { - key := f.MaskKey() - Mask(key, f.payload) +func (f Frame) payloadStartIndex() int { + return frameHeaderLength + f.ExtendedPayloadLengthBytes() + f.MaskBytes() +} + +func (f *Frame) fitPayload() ([]byte, error) { + length := f.PayloadLength() + if length <= 0 { + return nil, nil + } else if length > MaxMessageSize { + return nil, ErrPayloadTooBig } - f.header[1] ^= bitIsMasked + + *f = util.ExtendSlice(*f, f.payloadStartIndex()+length) + b := (*f)[f.payloadStartIndex():] + return b[:length], nil +} + +func (f *Frame) SetPayload(b []byte) *Frame { + *f = util.ExtendSlice(*f, f.payloadStartIndex()+len(b)) + payload := f.Payload() + copy(payload, b) + f.SetPayloadLength(len(payload)) + return f } -func (f *Frame) ReadFrom(r io.Reader) (nt int64, err error) { - var n int - n, err = io.ReadFull(r, f.header[:2]) - nt += int64(n) +func (f Frame) Payload() []byte { + return f[f.payloadStartIndex():] +} - if err == nil { - m := f.ExtraHeaderLen() - if m > 0 { - n, err = io.ReadFull(r, f.header[2:m+2]) - nt += int64(n) - } +func (f *Frame) Mask() { + f.SetIsMasked() - if err == nil && f.IsMasked() { - n, err = io.ReadFull(r, f.mask[:4]) - nt += int64(n) - } + var ( + mask = f.MaskKey() + payload = f.Payload() + ) - if err == nil { - if pn := f.PayloadLength(); pn > 0 { - if pn > MaxMessageSize { - err = ErrPayloadTooBig - } else { - f.payload = util.ExtendSlice(f.payload, pn) - n, err = io.ReadFull(r, f.payload[:pn]) - nt += int64(n) - } - } else if pn == 0 { - f.payload = f.payload[:0] - } - } + if len(payload) > 0 { + GenMask(mask) + Mask(mask, payload) } +} - return +func (f *Frame) Unmask() { + if f.IsMasked() { + var ( + mask = f.MaskKey() + payload = f.Payload() + ) + Mask(mask, payload) + // Does not unset the IsMasked bit in order to not mess up the offset at which the payload is found. + } } -func (f *Frame) WriteTo(w io.Writer) (n int64, err error) { +func (f *Frame) ReadFrom(r io.Reader) (n int64, err error) { var nn int - nn, err = w.Write(f.header[:2+f.SetPayloadLength()]) + // read the header + nn, err = io.ReadFull(r, f.Header()) n += int64(nn) + if err != nil { + return + } - if err == nil { - if f.IsMasked() { - nn, err = w.Write(f.mask[:]) - n += int64(nn) + // read the extended payload length, if any + if b := f.ExtendedPayloadLength(); b != nil { + nn, err = io.ReadFull(r, b) + n += int64(nn) + if err != nil { + return } + } - if err == nil && f.PayloadLength() > 0 { - nn, err = w.Write(f.payload[:f.PayloadLength()]) - n += int64(nn) + // read the mask, if any + if f.IsMasked() { + nn, err = io.ReadFull(r, f.MaskKey()) + n += int64(nn) + if err != nil { + return } } + // read the payload, if any + b, err := f.fitPayload() + if err != nil { + return + } + nn, err = io.ReadFull(r, b) + n += int64(nn) + if err != nil { + return + } + return } -var framePool = sync.Pool{ - New: func() interface{} { - return NewFrame() - }, -} +func (f Frame) WriteTo(w io.Writer) (int64, error) { + f.SetPayloadLength(len(f.Payload())) -func AcquireFrame() *Frame { - return framePool.Get().(*Frame) -} + written := 0 + for written < len(f) { + n, err := w.Write(f[written:]) + written += n + if err != nil { + return int64(n), err + } + } -func ReleaseFrame(f *Frame) { - f.Reset() - framePool.Put(f) + return int64(written), nil } diff --git a/codec/websocket/frame_codec.go b/codec/websocket/frame_codec.go index d3c240bc..2ab08e1f 100644 --- a/codec/websocket/frame_codec.go +++ b/codec/websocket/frame_codec.go @@ -6,26 +6,24 @@ import ( "github.com/talostrading/sonic" ) -var _ sonic.Codec[*Frame, *Frame] = &FrameCodec{} +var _ sonic.Codec[Frame, Frame] = &FrameCodec{} var ( ErrPartialPayload = errors.New("partial payload") ) -// FrameCodec is a stateful streaming parser handling the encoding -// and decoding of WebSocket frames. +// FrameCodec is a stateful streaming parser handling the encoding and decoding of WebSocket `Frame`s. type FrameCodec struct { src *sonic.ByteBuffer // buffer we decode from dst *sonic.ByteBuffer // buffer we encode to - decodeFrame *Frame // frame we decode into - decodeBytes int // number of bytes of the last successfully decoded frame - decodeReset bool // true if we must reset the state on the next decode + decodeFrame Frame // frame we decode into + decodeReset bool // true if we must reset the state on the next decode } func NewFrameCodec(src, dst *sonic.ByteBuffer) *FrameCodec { return &FrameCodec{ - decodeFrame: NewFrame(), + decodeFrame: newFrame(), src: src, dst: dst, } @@ -34,84 +32,75 @@ func NewFrameCodec(src, dst *sonic.ByteBuffer) *FrameCodec { func (c *FrameCodec) resetDecode() { if c.decodeReset { c.decodeReset = false - c.src.Consume(c.decodeBytes) - c.decodeBytes = 0 + c.src.Consume(len(c.decodeFrame)) + c.decodeFrame = nil } } -// Decode decodes the raw bytes from `src` into a frame. +// Decode decodes the raw bytes from `src` into a `Frame`. // -// Three things can happen while decoding a raw stream of bytes into a frame: -// 1. There are not enough bytes to construct a frame with. +// Two things can happen while decoding a raw stream of bytes into a frame: // -// In this case, a nil frame and ErrNeedMore are returned. The caller -// should perform another read into `src` later. +// 1. There are not enough bytes to construct a frame with: in this case, a `nil` `Frame` and `ErrNeedMore` are +// returned. The caller should perform another read into `src` later. // -// 2. `src` contains the bytes of one frame. -// -// In this case we try to decode the frame. An appropriate error is returned -// if the frame is corrupt. -// -// 3. `src` contains the bytes of more than one frame. -// -// In this case we try to decode the first frame. The rest of the bytes stay -// in `src`. An appropriate error is returned if the frame is corrupt. -func (c *FrameCodec) Decode(src *sonic.ByteBuffer) (*Frame, error) { +// 2. `src` contains at least the bytes of one `Frame`: we decode the next `Frame` and leave the remainder bytes +// composing a partial `Frame` or a set of `Frame`s in the `src` buffer. +func (c *FrameCodec) Decode(src *sonic.ByteBuffer) (Frame, error) { c.resetDecode() - n := 2 - err := src.PrepareRead(n) - if err != nil { + // read the mandatory header + readSoFar := frameHeaderLength + if err := src.PrepareRead(readSoFar); err != nil { + c.decodeFrame = nil return nil, err } - c.decodeFrame.header = src.Data()[:n] + c.decodeFrame = src.Data()[:readSoFar] - // read extra header length - n += c.decodeFrame.ExtraHeaderLen() - if err := src.PrepareRead(n); err != nil { + // read the extended payload length (0, 2 or 8 bytes) and check if within bounds + readSoFar += c.decodeFrame.ExtendedPayloadLengthBytes() + if err := src.PrepareRead(readSoFar); err != nil { + c.decodeFrame = nil return nil, err } - c.decodeFrame.header = src.Data()[:n] + c.decodeFrame = src.Data()[:readSoFar] + + payloadLength := c.decodeFrame.PayloadLength() + if payloadLength > MaxMessageSize { + c.decodeFrame = nil + return nil, ErrPayloadOverMaxSize + } // read mask if any if c.decodeFrame.IsMasked() { - n += 4 - if err := src.PrepareRead(n); err != nil { + readSoFar += frameMaskLength + if err := src.PrepareRead(readSoFar); err != nil { + c.decodeFrame = nil return nil, err } - c.decodeFrame.mask = src.Data()[n-4 : n] + c.decodeFrame = src.Data()[:readSoFar] } - // check payload length - npayload := c.decodeFrame.PayloadLength() - if npayload > MaxMessageSize { - return nil, ErrPayloadOverMaxSize - } - - // prepare to read the payload - n += npayload - if err := src.PrepareRead(n); err != nil { - // the payload might be too big for our buffer so we must allocate - // enough for the next Decode call to succeed - src.Reserve(npayload) + // read the payload; if that succeeds, we have a full frame in `src` - the decoding was successful and we can return + // the frame + readSoFar += payloadLength + if err := src.PrepareRead(readSoFar); err != nil { + src.Reserve(payloadLength) // payload is incomplete; reserve enough space for the remainder to fit in the buffer + c.decodeFrame = nil return nil, err } - - // at this point, we have a full frame in src - c.decodeFrame.payload = src.Data()[n-npayload : n] - c.decodeBytes = n + c.decodeFrame = src.Data()[:readSoFar] c.decodeReset = true return c.decodeFrame, nil } -// Encode encodes the frame and place the raw bytes into `dst`. -func (c *FrameCodec) Encode(fr *Frame, dst *sonic.ByteBuffer) error { - // Make sure there is enough space in the buffer to hold the serialized - // frame. - dst.Reserve(fr.PayloadLength()) +// Encode encodes the `Frame` into `dst`. +func (c *FrameCodec) Encode(frame Frame, dst *sonic.ByteBuffer) error { + // ensure the destination buffer can hold the serialized frame // TODO this can be improved + dst.Reserve(frame.PayloadLength() + MaxFrameHeaderLengthInBytes) - n, err := fr.WriteTo(dst) + n, err := frame.WriteTo(dst) dst.Commit(int(n)) if err != nil { dst.Consume(int(n)) diff --git a/codec/websocket/frame_codec_test.go b/codec/websocket/frame_codec_test.go index b5a59e66..8fa3f138 100644 --- a/codec/websocket/frame_codec_test.go +++ b/codec/websocket/frame_codec_test.go @@ -25,7 +25,7 @@ func TestDecodeShortFrame(t *testing.T) { if codec.decodeReset { t.Fatal("should not reset decoder state") } - if codec.decodeBytes != 0 { + if len(codec.decodeFrame) != 0 { t.Fatal("should have not decoded any bytes") } } @@ -51,7 +51,7 @@ func TestDecodeExactlyOneFrame(t *testing.T) { if !codec.decodeReset { t.Fatal("should reset decoder state") } - if codec.decodeBytes != 3 { + if len(codec.decodeFrame) != 3 { t.Fatal("should have decoded 3 bytes") } } @@ -80,7 +80,7 @@ func TestDecodeOneAndShortFrame(t *testing.T) { t.Fatal("should reset decoder state") } - if codec.decodeBytes != 3 { + if len(codec.decodeFrame) != 3 { t.Fatal("should have decoded 3 bytes") } } @@ -112,7 +112,7 @@ func TestDecodeTwoFrames(t *testing.T) { if !codec.decodeReset { t.Fatal("should have reset decoder state") } - if codec.decodeBytes != 3 { + if len(codec.decodeFrame) != 3 { t.Fatal("should have decoded 3 bytes") } if src.ReadLen() != 3 { @@ -135,7 +135,7 @@ func TestDecodeTwoFrames(t *testing.T) { if !codec.decodeReset { t.Fatal("should have reset decoder state") } - if codec.decodeBytes != 7 { + if len(codec.decodeFrame) != 7 { t.Fatal("should have decoded 3 bytes") } if src.ReadLen() != 7 { diff --git a/codec/websocket/frame_test.go b/codec/websocket/frame_test.go index b307a161..62b61ae1 100644 --- a/codec/websocket/frame_test.go +++ b/codec/websocket/frame_test.go @@ -9,13 +9,13 @@ import ( "github.com/talostrading/sonic" ) -func TestUnder125Frame(t *testing.T) { - raw := []byte{0x81, 5} // fin=1 opcode=1 (text) payload_len=5 +func TestUnder126Frame(t *testing.T) { + var ( + f = newFrame() + raw = []byte{0x81, 5} // fin=1 opcode=1 (text) payload_len=5 + ) raw = append(raw, genRandBytes(5)...) - f := AcquireFrame() - defer ReleaseFrame(f) - buf := bufio.NewReader(bytes.NewBuffer(raw)) _, err := f.ReadFrom(buf) @@ -27,12 +27,12 @@ func TestUnder125Frame(t *testing.T) { } func Test126Frame(t *testing.T) { - raw := []byte{0x81, 126, 0, 200} + var ( + f = newFrame() + raw = []byte{0x81, 126, 0, 200} + ) raw = append(raw, genRandBytes(200)...) - f := AcquireFrame() - defer ReleaseFrame(f) - buf := bufio.NewReader(bytes.NewBuffer(raw)) _, err := f.ReadFrom(buf) @@ -44,12 +44,12 @@ func Test126Frame(t *testing.T) { } func Test127Frame(t *testing.T) { - raw := []byte{0x81, 127, 0, 0, 0, 0, 0, 0x01, 0xFF, 0xFF} + var ( + f = newFrame() + raw = []byte{0x81, 127, 0, 0, 0, 0, 0, 0x01, 0xFF, 0xFF} + ) raw = append(raw, genRandBytes(131071)...) - f := AcquireFrame() - defer ReleaseFrame(f) - buf := bufio.NewReader(bytes.NewBuffer(raw)) _, err := f.ReadFrom(buf) @@ -61,10 +61,10 @@ func Test127Frame(t *testing.T) { } func TestWriteFrame(t *testing.T) { - payload := []byte("heloo") - - f := AcquireFrame() - defer ReleaseFrame(f) + var ( + f = newFrame() + payload = []byte("heloo") + ) f.SetFIN() f.SetPayload(payload) @@ -92,14 +92,13 @@ func TestWriteFrame(t *testing.T) { } func TestSameFrameWriteRead(t *testing.T) { - // deserialize - f := AcquireFrame() - defer ReleaseFrame(f) - - header := []byte{0x81, 5} - payload := genRandBytes(5) + var ( + header = []byte{0x81, 5} + payload = genRandBytes(5) + buf = sonic.NewByteBuffer() + f = newFrame() + ) - buf := sonic.NewByteBuffer() buf.Write(header) buf.Write(payload) buf.Commit(7) @@ -137,7 +136,7 @@ func TestSameFrameWriteRead(t *testing.T) { } } -func checkFrame(t *testing.T, f *Frame, c, fin bool, payload []byte) { +func checkFrame(t *testing.T, f Frame, c, fin bool, payload []byte) { if c && !f.Opcode().IsContinuation() { t.Fatal("expected continuation") } diff --git a/codec/websocket/rfc6455.go b/codec/websocket/rfc6455.go index 16898451..9f270c5b 100644 --- a/codec/websocket/rfc6455.go +++ b/codec/websocket/rfc6455.go @@ -1,4 +1,4 @@ -// Based on https://datatracker.ietf.org/doc/html/rfc6455. +// Based on https://datatracker.ietf.org/doc/html/rfc6455 package websocket import ( @@ -10,18 +10,38 @@ import ( // --------------------------------------------------- // Framing ------------------------------------------- // --------------------------------------------------- +// Based on https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 const ( - bitFIN = byte(1 << 7) - bitRSV1 = byte(1 << 6) - bitRSV2 = byte(1 << 5) - bitRSV3 = byte(1 << 4) - bitIsMasked = byte(1 << 7) + MaxControlFramePayloadLength = 125 + MaxFrameHeaderLengthInBytes = 14 // 14 bytes max for the header of a frame i.e. everything without the payload ) -const MaxControlFramePayloadLength = 125 +const ( + // Mandatory, 2 bytes: + // byte 1: |fin(1)|rsv1(1)|rsv2(1)|rsv3(1)|opcode(4)| + // byte 2: |is masked(1)|payload length(7)| + frameHeaderLength = 2 + + bitFIN = byte(1 << 7) + bitmaskOpcode = byte(1<<4 - 1) + + bitIsMasked = byte(1 << 7) + bitmaskPayloadLength = byte(1<<7 - 1) + + // Optional, max 8 bytes. If |payload length(7)| above is <= 125, then 0: the payload length is in + // |payload length(7)|. If 126, then the payload length is in the following 2 bytes. If 127, in the following 8 + // bytes. + + // Optional, max 4 bytes. If |is masked(1)| above is set, then the following 4 bytes are the mask. Otherwise, the + // frame is not masked and the mask is not included. + // + // All frames sent from the client to the server are masked by a 32-bit value. Frames sent from the server to the + // client are unmasked. + frameMaskLength = 4 +) -type Opcode uint8 +type Opcode byte const ( OpcodeContinuation Opcode = 0 diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 4b789b55..f5334304 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -28,6 +28,7 @@ import ( "net/http" "net/http/httputil" "net/url" + "sync" "syscall" "time" @@ -53,7 +54,7 @@ type WebsocketStream struct { conn net.Conn // Codec stream wrapping the underlying transport stream. - cs *sonic.CodecConn[*Frame, *Frame] + cs *sonic.CodecConn[Frame, Frame] // Websocket role: client or server. role Role @@ -74,9 +75,8 @@ type WebsocketStream struct { // handshake is over. hb []byte - // Contains frames waiting to be sent to the peer. - // Is emptied by AsyncFlush or Flush. - pending []*Frame + // Contains frames waiting to be sent to the peer. Is emptied by AsyncFlush or Flush. + pendingFrames []*Frame // Optional callback invoked when a control frame is received. ccb ControlCallback @@ -92,6 +92,8 @@ type WebsocketStream struct { // The size of the currently read message. messageSize int + + framePool sync.Pool } func NewWebsocketStream( @@ -112,6 +114,12 @@ func NewWebsocketStream( dialer: &net.Dialer{ Timeout: DialTimeout, }, + framePool: sync.Pool{ + New: func() interface{} { + frame := newFrame() + return &frame + }, + }, } s.src.Reserve(4096) @@ -120,6 +128,21 @@ func NewWebsocketStream( return s, nil } +func (s *WebsocketStream) AcquireFrame() *Frame { + f := s.framePool.Get().(*Frame) + if s.role == RoleClient { + // This just reserves the 4 bytes needed for the mask in order to encode the payload correctly, since it follows + // the mask in byte order. The actual mask is set after `f.SetPayload` in `prepareWrite`. + f.SetIsMasked() + } + return f +} + +func (s *WebsocketStream) releaseFrame(f *Frame) { + f.Reset() + s.framePool.Put(f) +} + // init is run when we transition into StateActive which happens // after a successful handshake. func (s *WebsocketStream) init(stream sonic.Stream) (err error) { @@ -129,7 +152,7 @@ func (s *WebsocketStream) init(stream sonic.Stream) (err error) { s.stream = stream codec := NewFrameCodec(s.src, s.dst) - s.cs, err = sonic.NewCodecConn[*Frame, *Frame](stream, codec, s.src, s.dst) + s.cs, err = sonic.NewCodecConn[Frame, Frame](stream, codec, s.src, s.dst) return } @@ -162,7 +185,7 @@ func (s *WebsocketStream) canRead() bool { return s.state == StateActive || s.state == StateClosedByUs } -func (s *WebsocketStream) NextFrame() (f *Frame, err error) { +func (s *WebsocketStream) NextFrame() (f Frame, err error) { err = s.Flush() if errors.Is(err, ErrMessageTooBig) { @@ -185,7 +208,7 @@ func (s *WebsocketStream) NextFrame() (f *Frame, err error) { return } -func (s *WebsocketStream) nextFrame() (f *Frame, err error) { +func (s *WebsocketStream) nextFrame() (f Frame, err error) { f, err = s.cs.ReadNext() if err == nil { err = s.handleFrame(f) @@ -215,7 +238,7 @@ func (s *WebsocketStream) AsyncNextFrame(cb AsyncFrameHandler) { } func (s *WebsocketStream) asyncNextFrame(cb AsyncFrameHandler) { - s.cs.AsyncReadNext(func(err error, f *Frame) { + s.cs.AsyncReadNext(func(err error, f Frame) { if err == nil { err = s.handleFrame(f) } else if err == io.EOF { @@ -229,7 +252,7 @@ func (s *WebsocketStream) NextMessage( b []byte, ) (mt MessageType, readBytes int, err error) { var ( - f *Frame + f Frame continuation = false ) @@ -243,7 +266,7 @@ func (s *WebsocketStream) NextMessage( if f.Opcode().IsControl() { if s.ccb != nil { - s.ccb(MessageType(f.Opcode()), f.payload) + s.ccb(MessageType(f.Opcode()), f.Payload()) } } else { if mt == TypeNone { @@ -296,13 +319,13 @@ func (s *WebsocketStream) asyncNextMessage( mt MessageType, cb AsyncMessageHandler, ) { - s.AsyncNextFrame(func(err error, f *Frame) { + s.AsyncNextFrame(func(err error, f Frame) { if err != nil { cb(err, readBytes, mt) } else { if f.Opcode().IsControl() { if s.ccb != nil { - s.ccb(MessageType(f.Opcode()), f.payload) + s.ccb(MessageType(f.Opcode()), f.Payload()) } s.asyncNextMessage(b, readBytes, continuation, mt, cb) @@ -350,7 +373,7 @@ func (s *WebsocketStream) asyncNextMessage( }) } -func (s *WebsocketStream) handleFrame(f *Frame) (err error) { +func (s *WebsocketStream) handleFrame(f Frame) (err error) { err = s.verifyFrame(f) if err == nil { @@ -369,11 +392,7 @@ func (s *WebsocketStream) handleFrame(f *Frame) (err error) { return err } -func (s *WebsocketStream) verifyFrame(f *Frame) error { - if f.IsRSV1() || f.IsRSV2() || f.IsRSV3() { - return ErrNonZeroReservedBits - } - +func (s *WebsocketStream) verifyFrame(f Frame) error { if s.role == RoleClient && f.IsMasked() { return ErrMaskedFramesFromServer } @@ -385,7 +404,7 @@ func (s *WebsocketStream) verifyFrame(f *Frame) error { return nil } -func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { +func (s *WebsocketStream) handleControlFrame(f Frame) (err error) { if !f.IsFIN() { return ErrInvalidControlFrame } @@ -397,14 +416,11 @@ func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { switch f.Opcode() { case OpcodePing: if s.state == StateActive { - pongFrame := AcquireFrame() - pongFrame.SetFIN() - pongFrame.SetPong() - pongFrame.SetPayload(f.payload) - if s.role == RoleClient { - pongFrame.Mask() - } - s.pending = append(s.pending, pongFrame) + pongFrame := s.AcquireFrame(). + SetFIN(). + SetPong(). + SetPayload(f.Payload()) + s.prepareWrite(pongFrame) } case OpcodePong: case OpcodeClose: @@ -413,7 +429,7 @@ func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { panic("unreachable") case StateActive: s.state = StateClosedByPeer - s.prepareClose(f.payload) + s.prepareClose(f.Payload()) case StateClosedByPeer, StateCloseAcked: // ignore case StateClosedByUs: @@ -429,7 +445,7 @@ func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { return } -func (s *WebsocketStream) handleDataFrame(f *Frame) error { +func (s *WebsocketStream) handleDataFrame(f Frame) error { if f.Opcode().IsReserved() { return ErrReservedOpcode } @@ -442,11 +458,11 @@ func (s *WebsocketStream) Write(b []byte, mt MessageType) error { } if s.state == StateActive { - f := AcquireFrame() - f.SetFIN() - f.SetOpcode(Opcode(mt)) - f.SetPayload(b) - + // reserve space for mask if client + f := s.AcquireFrame(). + SetFIN(). + SetOpcode(Opcode(mt)). + SetPayload(b) s.prepareWrite(f) return s.Flush() } @@ -459,7 +475,7 @@ func (s *WebsocketStream) WriteFrame(f *Frame) error { s.prepareWrite(f) return s.Flush() } else { - ReleaseFrame(f) + s.releaseFrame(f) return sonicerrors.ErrCancelled } } @@ -475,11 +491,10 @@ func (s *WebsocketStream) AsyncWrite( } if s.state == StateActive { - f := AcquireFrame() - f.SetFIN() - f.SetOpcode(Opcode(mt)) - f.SetPayload(b) - + f := s.AcquireFrame(). + SetFIN(). + SetOpcode(Opcode(mt)). + SetPayload(b) s.prepareWrite(f) s.AsyncFlush(cb) } else { @@ -492,24 +507,16 @@ func (s *WebsocketStream) AsyncWriteFrame(f *Frame, cb func(err error)) { s.prepareWrite(f) s.AsyncFlush(cb) } else { - ReleaseFrame(f) + s.releaseFrame(f) cb(sonicerrors.ErrCancelled) } } func (s *WebsocketStream) prepareWrite(f *Frame) { - switch s.role { - case RoleClient: - if !f.IsMasked() { - f.Mask() - } - case RoleServer: - if f.IsMasked() { - f.Unmask() - } + if s.role == RoleClient { + f.Mask() } - - s.pending = append(s.pending, f) + s.pendingFrames = append(s.pendingFrames, f) } func (s *WebsocketStream) AsyncClose( @@ -543,41 +550,37 @@ func (s *WebsocketStream) Close(cc CloseCode, reason string) error { } func (s *WebsocketStream) prepareClose(payload []byte) { - closeFrame := AcquireFrame() - closeFrame.SetFIN() - closeFrame.SetClose() - closeFrame.SetPayload(payload) - if s.role == RoleClient { - closeFrame.Mask() - } - - s.pending = append(s.pending, closeFrame) + closeFrame := s.AcquireFrame(). + SetFIN(). + SetClose(). + SetPayload(payload) + s.prepareWrite(closeFrame) } func (s *WebsocketStream) Flush() (err error) { flushed := 0 - for i := 0; i < len(s.pending); i++ { - _, err = s.cs.WriteNext(s.pending[i]) + for i := 0; i < len(s.pendingFrames); i++ { + _, err = s.cs.WriteNext(*s.pendingFrames[i]) if err != nil { break } - ReleaseFrame(s.pending[i]) + s.releaseFrame(s.pendingFrames[i]) flushed++ } - s.pending = s.pending[flushed:] + s.pendingFrames = s.pendingFrames[flushed:] return } func (s *WebsocketStream) AsyncFlush(cb func(err error)) { - if len(s.pending) == 0 { + if len(s.pendingFrames) == 0 { cb(nil) } else { - sent := s.pending[0] - s.pending = s.pending[1:] + sent := s.pendingFrames[0] + s.pendingFrames = s.pendingFrames[1:] - s.cs.AsyncWriteNext(sent, func(err error, _ int) { - ReleaseFrame(sent) + s.cs.AsyncWriteNext(*sent, func(err error, _ int) { + s.releaseFrame(sent) if err != nil { cb(err) @@ -589,7 +592,7 @@ func (s *WebsocketStream) AsyncFlush(cb func(err error)) { } func (s *WebsocketStream) Pending() int { - return len(s.pending) + return len(s.pendingFrames) } func (s *WebsocketStream) State() StreamState { diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index a9f364c3..ac376085 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -463,10 +463,10 @@ func TestClientReadCorruptControlFrame(t *testing.T) { assertState(t, ws, StateClosedByUs) - closeFrame := ws.pending[0] + closeFrame := ws.pendingFrames[0] closeFrame.Unmask() - cc, _ := DecodeCloseFramePayload(ws.pending[0].payload) + cc, _ := DecodeCloseFramePayload(ws.pendingFrames[0].Payload()) if cc != CloseProtocolError { t.Fatal("should have closed with protocol error") } @@ -507,10 +507,10 @@ func TestClientAsyncReadCorruptControlFrame(t *testing.T) { assertState(t, ws, StateClosedByUs) - closeFrame := ws.pending[0] + closeFrame := ws.pendingFrames[0] closeFrame.Unmask() - cc, _ := DecodeCloseFramePayload(ws.pending[0].payload) + cc, _ := DecodeCloseFramePayload(ws.pendingFrames[0].Payload()) if cc != CloseProtocolError { t.Fatal("should have closed with protocol error") } @@ -550,7 +550,7 @@ func TestClientReadPingFrame(t *testing.T) { t.Fatal("should have a pending pong") } - reply := ws.pending[0] + reply := ws.pendingFrames[0] if !(reply.Opcode().IsPong() && reply.IsMasked()) { t.Fatal("invalid pong reply") } @@ -603,7 +603,7 @@ func TestClientAsyncReadPingFrame(t *testing.T) { t.Fatal("should have a pending pong") } - reply := ws.pending[0] + reply := ws.pendingFrames[0] if !(reply.Opcode().IsPong() && reply.IsMasked()) { t.Fatal("invalid pong reply") } @@ -764,13 +764,13 @@ func TestClientReadCloseFrame(t *testing.T) { t.Fatal("should have one pending operation") } - reply := ws.pending[0] + reply := ws.pendingFrames[0] if !reply.IsMasked() { t.Fatal("reply should be masked") } reply.Unmask() - cc, reason := DecodeCloseFramePayload(reply.payload) + cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("invalid close frame reply") } @@ -781,7 +781,7 @@ func TestClientReadCloseFrame(t *testing.T) { b := make([]byte, 128) _, _, err = ws.NextMessage(b) - if len(ws.pending) > 0 { + if len(ws.pendingFrames) > 0 { t.Fatal("should have flushed") } @@ -827,13 +827,13 @@ func TestClientAsyncReadCloseFrame(t *testing.T) { t.Fatal("should have one pending operation") } - reply := ws.pending[0] + reply := ws.pendingFrames[0] if !reply.IsMasked() { t.Fatal("reply should be masked") } reply.Unmask() - cc, reason := DecodeCloseFramePayload(reply.payload) + cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("invalid close frame reply") } @@ -874,20 +874,22 @@ func TestClientWriteFrame(t *testing.T) { ws.state = StateActive ws.init(mock) - f := AcquireFrame() - defer ReleaseFrame(f) + f := ws.AcquireFrame() f.SetFIN() f.SetText() f.SetPayload([]byte{1, 2, 3, 4, 5}) + if len(*f) != 2 /*mandatory header*/ +4 /*mask since it's written by a client*/ +5 /*payload length*/ { + t.Fatal("invalid frame length") + } + err = ws.WriteFrame(f) if err != nil { t.Fatal(err) } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) + f := Frame(make([]byte, 2+4+5)) _, err = f.ReadFrom(mock.b) if err != nil { @@ -921,8 +923,7 @@ func TestClientAsyncWriteFrame(t *testing.T) { ws.state = StateActive ws.init(mock) - f := AcquireFrame() - defer ReleaseFrame(f) + f := ws.AcquireFrame() f.SetFIN() f.SetText() f.SetPayload([]byte{1, 2, 3, 4, 5}) @@ -937,8 +938,7 @@ func TestClientAsyncWriteFrame(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) + f := newFrame() _, err = f.ReadFrom(mock.b) if err != nil { @@ -983,8 +983,7 @@ func TestClientWrite(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) + f := newFrame() _, err = f.ReadFrom(mock.b) if err != nil { @@ -1024,9 +1023,7 @@ func TestClientAsyncWrite(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) - + f := newFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) @@ -1066,9 +1063,7 @@ func TestClientClose(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) - + f := newFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) @@ -1084,7 +1079,7 @@ func TestClientClose(t *testing.T) { t.Fatal("wrong message in close frame") } - cc, reason := DecodeCloseFramePayload(f.payload) + cc, reason := DecodeCloseFramePayload(f.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("wrong close frame payload") } @@ -1113,9 +1108,7 @@ func TestClientAsyncClose(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) - + f := newFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) @@ -1131,7 +1124,7 @@ func TestClientAsyncClose(t *testing.T) { t.Fatal("wrong message in close frame") } - cc, reason := DecodeCloseFramePayload(f.payload) + cc, reason := DecodeCloseFramePayload(f.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("wrong close frame payload") } @@ -1160,13 +1153,11 @@ func TestClientCloseHandshakeWeStart(t *testing.T) { if err != nil { t.Fatal(err) } else { - mock.b.Commit(mock.b.WriteLen()) - - serverReply := AcquireFrame() - defer ReleaseFrame(serverReply) - assertState(t, ws, StateClosedByUs) + mock.b.Commit(mock.b.WriteLen()) + + serverReply := newFrame() serverReply.SetFIN() serverReply.SetClose() serverReply.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1184,7 +1175,7 @@ func TestClientCloseHandshakeWeStart(t *testing.T) { t.Fatal("wrong close reply") } - cc, reason := DecodeCloseFramePayload(reply.payload) + cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("wrong close frame payload reply") } @@ -1215,9 +1206,7 @@ func TestClientAsyncCloseHandshakeWeStart(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - serverReply := AcquireFrame() - defer ReleaseFrame(serverReply) - + serverReply := newFrame() serverReply.SetFIN() serverReply.SetClose() serverReply.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1235,7 +1224,7 @@ func TestClientAsyncCloseHandshakeWeStart(t *testing.T) { t.Fatal("wrong close reply") } - cc, reason := DecodeCloseFramePayload(reply.payload) + cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("wrong close frame payload reply") } @@ -1262,8 +1251,7 @@ func TestClientCloseHandshakePeerStarts(t *testing.T) { ws.state = StateActive ws.init(mock) - serverClose := AcquireFrame() - defer ReleaseFrame(serverClose) + serverClose := newFrame() serverClose.SetFIN() serverClose.SetClose() serverClose.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1286,7 +1274,7 @@ func TestClientCloseHandshakePeerStarts(t *testing.T) { assertState(t, ws, StateClosedByPeer) - cc, reason := DecodeCloseFramePayload(recv.payload) + cc, reason := DecodeCloseFramePayload(recv.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("peer close frame payload is corrupt") } @@ -1309,8 +1297,7 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { ws.state = StateActive ws.init(mock) - serverClose := AcquireFrame() - defer ReleaseFrame(serverClose) + serverClose := newFrame() serverClose.SetFIN() serverClose.SetClose() serverClose.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1322,7 +1309,7 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { ws.src.Commit(int(nn)) - ws.AsyncNextFrame(func(err error, recv *Frame) { + ws.AsyncNextFrame(func(err error, recv Frame) { if err != nil { t.Fatal(err) } @@ -1333,7 +1320,7 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { assertState(t, ws, StateClosedByPeer) - cc, reason := DecodeCloseFramePayload(recv.payload) + cc, reason := DecodeCloseFramePayload(recv.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("peer close frame payload is corrupt") } diff --git a/codec/websocket/test_main.go b/codec/websocket/test_main.go index eba0f8c9..4ddc5c32 100644 --- a/codec/websocket/test_main.go +++ b/codec/websocket/test_main.go @@ -73,30 +73,26 @@ func (s *MockServer) Accept(addr string) (err error) { } func (s *MockServer) Write(b []byte) error { - fr := AcquireFrame() - defer ReleaseFrame(fr) - - fr.SetText() - fr.SetPayload(b) - fr.SetFIN() - - _, err := fr.WriteTo(s.conn) + f := newFrame() + f.SetText() + f.SetPayload(b) + f.SetFIN() + _, err := f.WriteTo(s.conn) return err } func (s *MockServer) Read(b []byte) (n int, err error) { - fr := AcquireFrame() - defer ReleaseFrame(fr) + f := newFrame() - _, err = fr.ReadFrom(s.conn) + _, err = f.ReadFrom(s.conn) if err == nil { - if !fr.IsMasked() { + if !f.IsMasked() { return 0, fmt.Errorf("client frames should be masked") } - fr.Unmask() - copy(b, fr.Payload()) - n = fr.PayloadLength() + f.Unmask() + copy(b, f.Payload()) + n = f.PayloadLength() } return n, err } From 22a0332a7be68040927ce0c619775fe843f11903 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 12:05:56 +0100 Subject: [PATCH 07/35] [websocket] Get rid of Websocket Stream interface --- codec/websocket/definitions.go | 286 +------------------------- codec/websocket/frame.go | 3 +- codec/websocket/frame_codec.go | 2 +- codec/websocket/frame_test.go | 10 +- codec/websocket/stream.go | 358 ++++++++++++++++++++------------- codec/websocket/stream_test.go | 20 +- codec/websocket/test_main.go | 4 +- 7 files changed, 247 insertions(+), 436 deletions(-) diff --git a/codec/websocket/definitions.go b/codec/websocket/definitions.go index 09123b2f..be8a0080 100644 --- a/codec/websocket/definitions.go +++ b/codec/websocket/definitions.go @@ -1,16 +1,14 @@ package websocket import ( - "net" "net/http" "time" - - "github.com/talostrading/sonic" ) var ( MaxMessageSize = 1024 * 512 // the maximum size of a message CloseTimeout = 5 * time.Second + DialTimeout = 5 * time.Second ) type Role uint8 @@ -69,20 +67,16 @@ const ( // Intermediate state. Connection is active, can read/write/close. StateActive - // Intermediate state. We initiated the closing handshake and - // are waiting for a reply from the peer. + // Intermediate state. We initiated the closing handshake and are waiting for a reply from the peer. StateClosedByUs - // Terminal state. The peer initiated the closing handshake, - // we received a close frame and immediately replied. + // Terminal state. The peer initiated the closing handshake, we received a close frame and immediately replied. StateClosedByPeer - // Terminal state. The peer replied to our closing handshake. - // Can only end up here from StateClosedByUs. + // Terminal state. The peer replied to our closing handshake. Can only end up here from StateClosedByUs. StateCloseAcked - // Terminal state. The connection is closed or some error - // occurred which rendered the stream unusable. + // Terminal state. The connection is closed or some error occurred which rendered the stream unusable. StateTerminated ) @@ -124,273 +118,3 @@ func ExtraHeader(canonicalKey bool, key string, values ...string) Header { CanonicalKey: canonicalKey, } } - -// Stream is an interface for representing a stateful WebSocket connection -// on the server or client side. -// -// The interface uses the layered stream model. A WebSocket stream object -// contains another stream object, called the "next layer", which it uses to -// perform IO. -// -// Implementations handle the replies to control frames. Before closing the -// stream, it is important to call Flush or AsyncFlush in order to write any -// pending control frame replies to the underlying stream. -type Stream interface { - // NextLayer returns the underlying stream object. - // - // The returned object is constructed by the Stream and maintained - // throughout its entire lifetime. All reads and writes will go through the - // next layer. - NextLayer() sonic.Stream - - // SupportsDeflate returns true if Deflate compression is supported. - // - // https://datatracker.ietf.org/doc/html/rfc7692 - SupportsDeflate() bool - - // SupportsUTF8 returns true if UTF8 validity checks are supported. - // - // Implementations should not do UTF8 checking by default. Callers - // should be able to turn it on when instantiating the Stream. - SupportsUTF8() bool - - // NextMessage reads the payload of the next message into the supplied - // buffer. Message fragmentation is automatically handled by the - // implementation. - // - // This call first flushes any pending control frames to the underlying - // stream. - // - // This call blocks until one of the following conditions is true: - // - an error occurs while flushing the pending operations - // - an error occurs when reading/decoding the message bytes from the - // underlying stream - // - the payload of the message is successfully read into the supplied - // buffer - NextMessage([]byte) (mt MessageType, n int, err error) - - // NextFrame reads and returns the next frame. - // - // This call first flushes any pending control frames to the underlying - // stream. - // - // This call blocks until one of the following conditions is true: - // - an error occurs while flushing the pending operations - // - an error occurs when reading/decoding the message bytes from the - // underlying stream - // - a frame is successfully read from the underlying stream - NextFrame() (Frame, error) - - // AsyncNextMessage reads the payload of the next message into the supplied - // buffer asynchronously. Message fragmentation is automatically handled by - // the implementation. - // - // This call first flushes any pending control frames to the underlying - // stream asynchronously. - // - // This call does not block. The provided completion handler is invoked when - // one of the following happens: - // - an error occurs while flushing the pending operations - // - an error occurs when reading/decoding the message bytes from the - // underlying stream - // - the payload of the message is successfully read into the supplied - // buffer - AsyncNextMessage([]byte, AsyncMessageHandler) - - // AsyncNextFrame reads and returns the next frame asynchronously. - // - // This call first flushes any pending control frames to the underlying - // stream asynchronously. - // - // This call does not block. The provided completion handler is invoked when - // one of the following happens: - // - an error occurs while flushing the pending operations - // - an error occurs when reading/decoding the message bytes from the - // underlying stream - // - a frame is successfully read from the underlying stream - AsyncNextFrame(AsyncFrameHandler) - - // WriteFrame writes the supplied frame to the underlying stream. - // - // This call first flushes any pending control frames to the underlying - // stream. - // - // This call blocks until one of the following conditions is true: - // - an error occurs while flushing the pending operations - // - an error occurs during the write - // - the frame is successfully written to the underlying stream - WriteFrame(fr *Frame) error - - // AsyncWriteFrame writes the supplied frame to the underlying stream - // asynchronously. - // - // This call first flushes any pending control frames to the underlying - // stream asynchronously. - // - // This call does not block. The provided completion handler is invoked when - // one of the following happens: - // - an error occurs while flushing the pending operations - // - an error occurs during the write - // - the frame is successfully written to the underlying stream - AsyncWriteFrame(fr *Frame, cb func(err error)) - - // Write writes the supplied buffer as a single message with the given type - // to the underlying stream. - // - // This call first flushes any pending control frames to the underlying - // stream. - // - // This call blocks until one of the following conditions is true: - // - an error occurs while flushing the pending operations - // - an error occurs during the write - // - the message is successfully written to the underlying stream - // - // The message will be written as a single frame. Fragmentation should be - // handled by the caller through multiple calls to AsyncWriteFrame. - Write(b []byte, mt MessageType) error - - // AsyncWrite writes the supplied buffer as a single message with the given - // type to the underlying stream asynchronously. - // - // This call first flushes any pending control frames to the underlying - // stream asynchronously. - // - // This call does not block. The provided completion handler is invoked when - // one of the following happens: - // - an error occurs while flushing the pending operations - // - an error occurs during the write - // - the message is successfully written to the underlying stream - // - // The message will be written as a single frame. Fragmentation should be - // handled by the caller through multiple calls to AsyncWriteFrame. - AsyncWrite(b []byte, mt MessageType, cb func(err error)) - - // Flush writes any pending control frames to the underlying stream. - // - // This call blocks. - Flush() error - - // Flush writes any pending control frames to the underlying stream - // asynchronously. - // - // This call does not block. - AsyncFlush(cb func(err error)) - - // Pending returns the number of currently pending operations. - Pending() int - - // State returns the state of the WebSocket connection. - State() StreamState - - // Handshake performs the WebSocket handshake in the client role. - // - // The call blocks until one of the following conditions is true: - // - the request is sent and the response is received - // - an error occurs - Handshake(addr string, extraHeaders ...Header) error - - // AsyncHandshake performs the WebSocket handshake asynchronously in the - // client role. - // - // This call does not block. The provided completion handler is called when - // the request is sent and the response is received or when an error occurs. - // - // Regardless of whether the asynchronous operation completes immediately - // or not, the handler will not be invoked from within this function. - // Invocation of the handler will be performed in a manner equivalent to - // using sonic.Post(...). - AsyncHandshake(addr string, cb func(error), extraHeaders ...Header) - - // Accept performs the handshake in the server role. - // - // The call blocks until one of the following conditions is true: - // - the request is sent and the response is received - // - an error occurs - Accept() error - - // AsyncAccept performs the handshake asynchronously in the server role. - // - // This call does not block. The provided completion handler is called when - // the request is send and the response is received or when an error occurs. - // - // Regardless of whether the asynchronous operation completes immediately - // or not, the handler will not be invoked from within this function. - // Invocation of the handler will be performed in a manner equivalent to - // using sonic.Post(...). - AsyncAccept(func(error)) - - // AsyncClose sends a websocket close control frame asynchronously. - // - // This function is used to send a close frame which begins the WebSocket - // closing handshake. The session ends when both ends of the connection - // have sent and received a close frame. - // - // The handler is called if one of the following conditions is true: - // - the close frame is written - // - an error occurs - // - // After beginning the closing handshake, the program should not write - // further message data, pings, or pongs. Instead, the program should - // continue reading message data until an error occurs. - AsyncClose(cc CloseCode, reason string, cb func(err error)) - - // Close sends a websocket close control frame asynchronously. - // - // This function is used to send a close frame which begins the WebSocket - // closing handshake. The session ends when both ends of the connection - // have sent and received a close frame. - // - // The call blocks until one of the following conditions is true: - // - the close frame is written - // - an error occurs - // - // After beginning the closing handshake, the program should not write - // further message data, pings, or pongs. Instead, the program should - // continue reading message data until an error occurs. - Close(cc CloseCode, reason string) error - - // SetControlCallback sets a function that will be invoked when a - // Ping/Pong/Close is received while reading a message. This callback is - // not invoked when AsyncNextFrame or NextFrame are called. - // - // The caller must not perform any operations on the stream in the provided - // callback. - SetControlCallback(ControlCallback) - - // ControlCallback returns the control callback set with SetControlCallback. - ControlCallback() ControlCallback - - // SetUpgradeRequestCallback sets a function that will be invoked during the handshake - // just before the upgrade request is sent. - // - // The caller must not perform any operations on the stream in the provided callback. - SetUpgradeRequestCallback(callback UpgradeRequestCallback) - - // UpgradeRequestCallback returns the callback set with SetUpgradeRequestCallback. - UpgradeRequestCallback() UpgradeRequestCallback - - // SetUpgradeResponseCallback sets a function that will be invoked during the handshake - // just after the upgrade response is received. - // - // The caller must not perform any operations on the stream in the provided callback. - SetUpgradeResponseCallback(callback UpgradeResponseCallback) - - // UpgradeResponseCallback returns the callback set with SetUpgradeResponseCallback. - UpgradeResponseCallback() UpgradeResponseCallback - - // SetMaxMessageSize sets the maximum size of a message that can be read - // from or written to a peer. - // - If a message exceeds the limit while reading, the connection is - // closed abnormally. - // - If a message exceeds the limit while writing, the operation is - // cancelled. - SetMaxMessageSize(bytes int) - - RemoteAddr() net.Addr - - LocalAddr() net.Addr - - RawFd() int - - CloseNextLayer() error -} diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index ae2f99d0..6400e096 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -22,7 +22,8 @@ var ( _ io.WriterTo = &Frame{} ) -func newFrame() Frame { +// NOTE use stream.AcquireFrame() instead of NewFrame if you intend to write this frame onto a WebSocket stream. +func NewFrame() Frame { return make([]byte, MaxFrameHeaderLengthInBytes) } diff --git a/codec/websocket/frame_codec.go b/codec/websocket/frame_codec.go index 2ab08e1f..8eb2edbd 100644 --- a/codec/websocket/frame_codec.go +++ b/codec/websocket/frame_codec.go @@ -23,7 +23,7 @@ type FrameCodec struct { func NewFrameCodec(src, dst *sonic.ByteBuffer) *FrameCodec { return &FrameCodec{ - decodeFrame: newFrame(), + decodeFrame: NewFrame(), src: src, dst: dst, } diff --git a/codec/websocket/frame_test.go b/codec/websocket/frame_test.go index 62b61ae1..49480d0d 100644 --- a/codec/websocket/frame_test.go +++ b/codec/websocket/frame_test.go @@ -11,7 +11,7 @@ import ( func TestUnder126Frame(t *testing.T) { var ( - f = newFrame() + f = NewFrame() raw = []byte{0x81, 5} // fin=1 opcode=1 (text) payload_len=5 ) raw = append(raw, genRandBytes(5)...) @@ -28,7 +28,7 @@ func TestUnder126Frame(t *testing.T) { func Test126Frame(t *testing.T) { var ( - f = newFrame() + f = NewFrame() raw = []byte{0x81, 126, 0, 200} ) raw = append(raw, genRandBytes(200)...) @@ -45,7 +45,7 @@ func Test126Frame(t *testing.T) { func Test127Frame(t *testing.T) { var ( - f = newFrame() + f = NewFrame() raw = []byte{0x81, 127, 0, 0, 0, 0, 0, 0x01, 0xFF, 0xFF} ) raw = append(raw, genRandBytes(131071)...) @@ -62,7 +62,7 @@ func Test127Frame(t *testing.T) { func TestWriteFrame(t *testing.T) { var ( - f = newFrame() + f = NewFrame() payload = []byte("heloo") ) @@ -96,7 +96,7 @@ func TestSameFrameWriteRead(t *testing.T) { header = []byte{0x81, 5} payload = genRandBytes(5) buf = sonic.NewByteBuffer() - f = newFrame() + f = NewFrame() ) buf.Write(header) diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index f5334304..342bae00 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -30,18 +30,12 @@ import ( "net/url" "sync" "syscall" - "time" "github.com/talostrading/sonic" "github.com/talostrading/sonic/sonicerrors" "github.com/talostrading/sonic/sonicopts" ) -var ( - _ Stream = &WebsocketStream{} - DialTimeout = 5 * time.Second -) - type WebsocketStream struct { // async operations executor. ioc *sonic.IO @@ -54,7 +48,7 @@ type WebsocketStream struct { conn net.Conn // Codec stream wrapping the underlying transport stream. - cs *sonic.CodecConn[Frame, Frame] + codecConn *sonic.CodecConn[Frame, Frame] // Websocket role: client or server. role Role @@ -73,19 +67,19 @@ type WebsocketStream struct { // Contains the handshake response. Is emptied after the // handshake is over. - hb []byte + handshakeBuffer []byte // Contains frames waiting to be sent to the peer. Is emptied by AsyncFlush or Flush. pendingFrames []*Frame // Optional callback invoked when a control frame is received. - ccb ControlCallback + controlCallback ControlCallback // Optional callback invoked when an upgrade request is sent. - upReqCb UpgradeRequestCallback + upgradeRequestCallback UpgradeRequestCallback // Optional callback invoked when an upgrade response is received. - upResCb UpgradeResponseCallback + upgradeResponseCallback UpgradeResponseCallback // Used to establish a TCP connection to the peer with a timeout. dialer *net.Dialer @@ -96,11 +90,7 @@ type WebsocketStream struct { framePool sync.Pool } -func NewWebsocketStream( - ioc *sonic.IO, - tls *tls.Config, - role Role, -) (s *WebsocketStream, err error) { +func NewWebsocketStream(ioc *sonic.IO, tls *tls.Config, role Role) (s *WebsocketStream, err error) { s = &WebsocketStream{ ioc: ioc, tls: tls, @@ -109,14 +99,14 @@ func NewWebsocketStream( dst: sonic.NewByteBuffer(), state: StateHandshake, /* #nosec G401 */ - hasher: sha1.New(), - hb: make([]byte, 1024), + hasher: sha1.New(), + handshakeBuffer: make([]byte, 1024), dialer: &net.Dialer{ Timeout: DialTimeout, }, framePool: sync.Pool{ New: func() interface{} { - frame := newFrame() + frame := NewFrame() return &frame }, }, @@ -128,6 +118,8 @@ func NewWebsocketStream( return s, nil } +// The recommended way to create a `Frame`. This function takes care to allocate enough bytes to encode the WebSocket +// header and apply client side masking if the `Stream`'s role is `RoleClient`. func (s *WebsocketStream) AcquireFrame() *Frame { f := s.framePool.Get().(*Frame) if s.role == RoleClient { @@ -143,6 +135,12 @@ func (s *WebsocketStream) releaseFrame(f *Frame) { s.framePool.Put(f) } +// This stream is either a client or a server. The main differences are in how the opening/closing handshake is done and +// in the fact that payloads sent by the client to the server are masked. +func (s *WebsocketStream) Role() Role { + return s.role +} + // init is run when we transition into StateActive which happens // after a successful handshake. func (s *WebsocketStream) init(stream sonic.Stream) (err error) { @@ -152,12 +150,12 @@ func (s *WebsocketStream) init(stream sonic.Stream) (err error) { s.stream = stream codec := NewFrameCodec(s.src, s.dst) - s.cs, err = sonic.NewCodecConn[Frame, Frame](stream, codec, s.src, s.dst) + s.codecConn, err = sonic.NewCodecConn[Frame, Frame](stream, codec, s.src, s.dst) return } func (s *WebsocketStream) reset() { - s.hb = s.hb[:cap(s.hb)] + s.handshakeBuffer = s.handshakeBuffer[:cap(s.handshakeBuffer)] s.state = StateHandshake s.stream = nil s.conn = nil @@ -165,9 +163,10 @@ func (s *WebsocketStream) reset() { s.dst.Reset() } +// Returns the stream through which IO is done. func (s *WebsocketStream) NextLayer() sonic.Stream { - if s.cs != nil { - return s.cs.NextLayer() + if s.codecConn != nil { + return s.codecConn.NextLayer() } return nil } @@ -185,6 +184,14 @@ func (s *WebsocketStream) canRead() bool { return s.state == StateActive || s.state == StateClosedByUs } +// NextFrame reads and returns the next frame. +// +// This call first flushes any pending control frames to the underlying stream. +// +// This call blocks until one of the following conditions is true: +// - an error occurs while flushing the pending control frames +// - an error occurs when reading/decoding the message bytes from the underlying stream +// - a frame is successfully read from the underlying stream func (s *WebsocketStream) NextFrame() (f Frame, err error) { err = s.Flush() @@ -209,18 +216,26 @@ func (s *WebsocketStream) NextFrame() (f Frame, err error) { } func (s *WebsocketStream) nextFrame() (f Frame, err error) { - f, err = s.cs.ReadNext() + f, err = s.codecConn.ReadNext() if err == nil { err = s.handleFrame(f) } return } -func (s *WebsocketStream) AsyncNextFrame(cb AsyncFrameHandler) { +// AsyncNextFrame reads and returns the next frame asynchronously. +// +// This call first flushes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. The provided callback is invoked when one of the following happens: +// - an error occurs while flushing the pending control frames +// - an error occurs when reading/decoding the message bytes from the underlying stream +// - a frame is successfully read from the underlying stream +func (s *WebsocketStream) AsyncNextFrame(callback AsyncFrameHandler) { s.AsyncFlush(func(err error) { if errors.Is(err, ErrMessageTooBig) { s.AsyncClose(CloseGoingAway, "payload too big", func(err error) {}) - cb(ErrMessageTooBig, nil) + callback(ErrMessageTooBig, nil) return } @@ -229,28 +244,35 @@ func (s *WebsocketStream) AsyncNextFrame(cb AsyncFrameHandler) { } if err == nil { - s.asyncNextFrame(cb) + s.asyncNextFrame(callback) } else { s.state = StateTerminated - cb(err, nil) + callback(err, nil) } }) } -func (s *WebsocketStream) asyncNextFrame(cb AsyncFrameHandler) { - s.cs.AsyncReadNext(func(err error, f Frame) { +func (s *WebsocketStream) asyncNextFrame(callback AsyncFrameHandler) { + s.codecConn.AsyncReadNext(func(err error, f Frame) { if err == nil { err = s.handleFrame(f) } else if err == io.EOF { s.state = StateTerminated } - cb(err, f) + callback(err, f) }) } -func (s *WebsocketStream) NextMessage( - b []byte, -) (mt MessageType, readBytes int, err error) { +// NextMessage reads the payload of the next message into the supplied buffer. Message fragmentation is automatically +// handled by the implementation. +// +// This call first flushes any pending control frames to the underlying stream. +// +// This call blocks until one of the following conditions is true: +// - an error occurs while flushing the pending control frames +// - an error occurs when reading/decoding the message from the underlying stream +// - the payload of the message is successfully read into the supplied buffer, after all message fragments are read +func (s *WebsocketStream) NextMessage(b []byte) (mt MessageType, readBytes int, err error) { var ( f Frame continuation = false @@ -265,8 +287,8 @@ func (s *WebsocketStream) NextMessage( } if f.Opcode().IsControl() { - if s.ccb != nil { - s.ccb(MessageType(f.Opcode()), f.Payload()) + if s.controlCallback != nil { + s.controlCallback(MessageType(f.Opcode()), f.Payload()) } } else { if mt == TypeNone { @@ -308,30 +330,39 @@ func (s *WebsocketStream) NextMessage( return } -func (s *WebsocketStream) AsyncNextMessage(b []byte, cb AsyncMessageHandler) { - s.asyncNextMessage(b, 0, false, TypeNone, cb) +// AsyncNextMessage reads the payload of the next message into the supplied buffer asynchronously. Message fragmentation +// is automatically handled by the implementation. +// +// This call first flushes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. The provided callback is invoked when one of the following happens: +// - an error occurs while flushing the pending control frames +// - an error occurs when reading/decoding the message bytes from the underlying stream +// - the payload of the message is successfully read into the supplied buffer, after all message fragments are read +func (s *WebsocketStream) AsyncNextMessage(b []byte, callback AsyncMessageHandler) { + s.asyncNextMessage(b, 0, false, TypeNone, callback) } func (s *WebsocketStream) asyncNextMessage( b []byte, readBytes int, continuation bool, - mt MessageType, - cb AsyncMessageHandler, + messageType MessageType, + callback AsyncMessageHandler, ) { s.AsyncNextFrame(func(err error, f Frame) { if err != nil { - cb(err, readBytes, mt) + callback(err, readBytes, messageType) } else { if f.Opcode().IsControl() { - if s.ccb != nil { - s.ccb(MessageType(f.Opcode()), f.Payload()) + if s.controlCallback != nil { + s.controlCallback(MessageType(f.Opcode()), f.Payload()) } - s.asyncNextMessage(b, readBytes, continuation, mt, cb) + s.asyncNextMessage(b, readBytes, continuation, messageType, callback) } else { - if mt == TypeNone { - mt = MessageType(f.Opcode()) + if messageType == TypeNone { + messageType = MessageType(f.Opcode()) } n := copy(b[readBytes:], f.Payload()) @@ -344,7 +375,7 @@ func (s *WebsocketStream) asyncNextMessage( "payload too big", func(err error) {}, ) - cb(err, readBytes, mt) + callback(err, readBytes, messageType) return } @@ -364,9 +395,9 @@ func (s *WebsocketStream) asyncNextMessage( } if err != nil || !continuation { - cb(err, readBytes, mt) + callback(err, readBytes, messageType) } else { - s.asyncNextMessage(b, readBytes, continuation, mt, cb) + s.asyncNextMessage(b, readBytes, continuation, messageType, callback) } } } @@ -452,6 +483,14 @@ func (s *WebsocketStream) handleDataFrame(f Frame) error { return nil } +// Write writes the supplied buffer as a single message with the given type to the underlying stream. +// +// This call first flushes any pending control frames to the underlying stream. +// +// This call blocks until one of the following conditions is true: +// - an error occurs while flushing the pending control frames +// - an error occurs during the write +// - the message is successfully written to the underlying stream func (s *WebsocketStream) Write(b []byte, mt MessageType) error { if len(b) > MaxMessageSize { return ErrMessageTooBig @@ -470,6 +509,14 @@ func (s *WebsocketStream) Write(b []byte, mt MessageType) error { return sonicerrors.ErrCancelled } +// WriteFrame writes the supplied frame to the underlying stream. +// +// This call first flushes any pending control frames to the underlying stream. +// +// This call blocks until one of the following conditions is true: +// - an error occurs while flushing the pending control frames +// - an error occurs during the write +// - the frame is successfully written to the underlying stream func (s *WebsocketStream) WriteFrame(f *Frame) error { if s.state == StateActive { s.prepareWrite(f) @@ -480,35 +527,48 @@ func (s *WebsocketStream) WriteFrame(f *Frame) error { } } -func (s *WebsocketStream) AsyncWrite( - b []byte, - mt MessageType, - cb func(err error), -) { +// AsyncWrite writes the supplied buffer as a single message with the given type to the underlying stream +// asynchronously. +// +// This call first flushes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. The provided callback is invoked when one of the following happens: +// - an error occurs while flushing the pending control frames +// - an error occurs during the write +// - the message is successfully written to the underlying stream +func (s *WebsocketStream) AsyncWrite(b []byte, messageType MessageType, callback func(err error)) { if len(b) > MaxMessageSize { - cb(ErrMessageTooBig) + callback(ErrMessageTooBig) return } if s.state == StateActive { f := s.AcquireFrame(). SetFIN(). - SetOpcode(Opcode(mt)). + SetOpcode(Opcode(messageType)). SetPayload(b) s.prepareWrite(f) - s.AsyncFlush(cb) + s.AsyncFlush(callback) } else { - cb(sonicerrors.ErrCancelled) + callback(sonicerrors.ErrCancelled) } } -func (s *WebsocketStream) AsyncWriteFrame(f *Frame, cb func(err error)) { +// AsyncWriteFrame writes the supplied frame to the underlying stream asynchronously. +// +// This call first flushes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. The provided callback is invoked when one of the following happens: +// - an error occurs while flushing the pending control frames +// - an error occurs during the write +// - the frame is successfully written to the underlying stream +func (s *WebsocketStream) AsyncWriteFrame(f *Frame, callback func(err error)) { if s.state == StateActive { s.prepareWrite(f) - s.AsyncFlush(cb) + s.AsyncFlush(callback) } else { s.releaseFrame(f) - cb(sonicerrors.ErrCancelled) + callback(sonicerrors.ErrCancelled) } } @@ -519,23 +579,43 @@ func (s *WebsocketStream) prepareWrite(f *Frame) { s.pendingFrames = append(s.pendingFrames, f) } -func (s *WebsocketStream) AsyncClose( - cc CloseCode, - reason string, - cb func(err error), -) { +// AsyncClose sends a websocket close control frame asynchronously. +// +// This function is used to send a close frame which begins the WebSocket closing handshake. The session ends when both +// ends of the connection have sent and received a close frame. +// +// The callback is called if one of the following conditions is true: +// - the close frame is written +// - an error occurs +// +// After beginning the closing handshake, the program should not write further messages, pings, pongs or close +// frames. Instead, the program should continue reading messages until the closing handshake is complete or an error +// occurs. +func (s *WebsocketStream) AsyncClose(closeCode CloseCode, reason string, callback func(err error)) { switch s.state { case StateActive: s.state = StateClosedByUs - s.prepareClose(EncodeCloseFramePayload(cc, reason)) - s.AsyncFlush(cb) + s.prepareClose(EncodeCloseFramePayload(closeCode, reason)) + s.AsyncFlush(callback) case StateClosedByUs, StateHandshake: - cb(sonicerrors.ErrCancelled) + callback(sonicerrors.ErrCancelled) default: - cb(io.EOF) - } -} - + callback(io.EOF) + } +} + +// Close sends a websocket close control frame asynchronously. +// +// This function is used to send a close frame which begins the WebSocket closing handshake. The session ends when both +// ends of the connection have sent and received a close frame. +// +// The call blocks until one of the following conditions is true: +// - the close frame is written +// - an error occurs +// +// After beginning the closing handshake, the program should not write further messages, pings, pongs or close +// frames. Instead, the program should continue reading messages until the closing handshake is complete or an error +// occurs. func (s *WebsocketStream) Close(cc CloseCode, reason string) error { switch s.state { case StateActive: @@ -557,10 +637,13 @@ func (s *WebsocketStream) prepareClose(payload []byte) { s.prepareWrite(closeFrame) } +// Flush writes any pending control frames to the underlying stream. +// +// This call blocks. func (s *WebsocketStream) Flush() (err error) { flushed := 0 for i := 0; i < len(s.pendingFrames); i++ { - _, err = s.cs.WriteNext(*s.pendingFrames[i]) + _, err = s.codecConn.WriteNext(*s.pendingFrames[i]) if err != nil { break } @@ -572,25 +655,29 @@ func (s *WebsocketStream) Flush() (err error) { return } -func (s *WebsocketStream) AsyncFlush(cb func(err error)) { +// Flush writes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. +func (s *WebsocketStream) AsyncFlush(callback func(err error)) { if len(s.pendingFrames) == 0 { - cb(nil) + callback(nil) } else { sent := s.pendingFrames[0] s.pendingFrames = s.pendingFrames[1:] - s.cs.AsyncWriteNext(*sent, func(err error, _ int) { + s.codecConn.AsyncWriteNext(*sent, func(err error, _ int) { s.releaseFrame(sent) if err != nil { - cb(err) + callback(err) } else { - s.AsyncFlush(cb) + s.AsyncFlush(callback) } }) } } +// Pending returns the number of currently pending control frames waiting to be flushed. func (s *WebsocketStream) Pending() int { return len(s.pendingFrames) } @@ -599,10 +686,12 @@ func (s *WebsocketStream) State() StreamState { return s.state } -func (s *WebsocketStream) Handshake( - addr string, - extraHeaders ...Header, -) (err error) { +// Handshake performs the client handshake. This call blocks. +// +// The call blocks until one of the following conditions is true: +// - the HTTP1.1 request is sent and the response is received +// - an error occurs +func (s *WebsocketStream) Handshake(addr string, extraHeaders ...Header) (err error) { if s.role != RoleClient { return ErrWrongHandshakeRole } @@ -629,13 +718,13 @@ func (s *WebsocketStream) Handshake( return } -func (s *WebsocketStream) AsyncHandshake( - addr string, - cb func(error), - extraHeaders ...Header, -) { +// AsyncHandshake performs the WebSocket handshake asynchronously in the client role. +// +// This call does not block. The provided callback is called when the request is sent and the response is +// received or when an error occurs. +func (s *WebsocketStream) AsyncHandshake(addr string, callback func(error), extraHeaders ...Header) { if s.role != RoleClient { - cb(ErrWrongHandshakeRole) + callback(ErrWrongHandshakeRole) return } @@ -653,26 +742,22 @@ func (s *WebsocketStream) AsyncHandshake( s.state = StateActive err = s.init(stream) } - cb(err) + callback(err) }) }) }() } -func (s *WebsocketStream) handshake( - addr string, - headers []Header, - cb func(err error, stream sonic.Stream), -) { +func (s *WebsocketStream) handshake(addr string, headers []Header, callback func(err error, stream sonic.Stream)) { url, err := s.resolve(addr) if err != nil { - cb(err, nil) + callback(err, nil) } else { s.dial(url, func(err error, stream sonic.Stream) { if err == nil { err = s.upgrade(url, stream, headers) } - cb(err, stream) + callback(err, stream) }) } } @@ -693,10 +778,7 @@ func (s *WebsocketStream) resolve(addr string) (url *url.URL, err error) { return } -func (s *WebsocketStream) dial( - url *url.URL, - cb func(err error, stream sonic.Stream), -) { +func (s *WebsocketStream) dial(url *url.URL, callback func(err error, stream sonic.Stream)) { var ( err error sc syscall.Conn @@ -750,18 +832,14 @@ func (s *WebsocketStream) dial( // condition on the io context. sonic.NewAsyncAdapter( s.ioc, sc, s.conn, func(err error, stream *sonic.AsyncAdapter) { - cb(err, stream) + callback(err, stream) }, sonicopts.NoDelay(true)) } else { - cb(err, nil) + callback(err, nil) } } -func (s *WebsocketStream) upgrade( - uri *url.URL, - stream sonic.Stream, - headers []Header, -) error { +func (s *WebsocketStream) upgrade(uri *url.URL, stream sonic.Stream, headers []Header) error { req, err := http.NewRequest("GET", uri.String(), nil) if err != nil { return err @@ -788,8 +866,8 @@ func (s *WebsocketStream) upgrade( } } - if s.upReqCb != nil { - s.upReqCb(req) + if s.upgradeRequestCallback != nil { + s.upgradeRequestCallback(req) } err = req.Write(stream) @@ -797,13 +875,13 @@ func (s *WebsocketStream) upgrade( return err } - s.hb = s.hb[:cap(s.hb)] - n, err := stream.Read(s.hb) + s.handshakeBuffer = s.handshakeBuffer[:cap(s.handshakeBuffer)] + n, err := stream.Read(s.handshakeBuffer) if err != nil { return err } - s.hb = s.hb[:n] - rd := bytes.NewReader(s.hb) + s.handshakeBuffer = s.handshakeBuffer[:n] + rd := bytes.NewReader(s.handshakeBuffer) res, err := http.ReadResponse(bufio.NewReader(rd), req) if err != nil { return err @@ -815,17 +893,17 @@ func (s *WebsocketStream) upgrade( } resLen := len(rawRes) - extra := len(s.hb) - resLen + extra := len(s.handshakeBuffer) - resLen if extra > 0 { // we got some frames as well with the handshake so we can put // them in src for later decoding before clearing the handshake // buffer - _, _ = s.src.Write(s.hb[resLen:]) + _, _ = s.src.Write(s.handshakeBuffer[resLen:]) } - s.hb = s.hb[:0] + s.handshakeBuffer = s.handshakeBuffer[:0] - if s.upResCb != nil { - s.upResCb(res) + if s.upgradeResponseCallback != nil { + s.upgradeResponseCallback(res) } if !IsUpgradeRes(res) { @@ -859,38 +937,46 @@ func (s *WebsocketStream) makeHandshakeKey() (req, res string) { return } -func (s *WebsocketStream) Accept() error { - panic("implement me") -} - -func (s *WebsocketStream) AsyncAccept(func(error)) { - panic("implement me") -} - -func (s *WebsocketStream) SetControlCallback(ccb ControlCallback) { - s.ccb = ccb +// SetControlCallback sets a function that will be invoked when a Ping/Pong/Close is received while reading a +// message. This callback is only invoked when reading complete messages, not frames. +// +// The caller must not perform any operations on the stream in the provided callback. +func (s *WebsocketStream) SetControlCallback(controlCallback ControlCallback) { + s.controlCallback = controlCallback } func (s *WebsocketStream) ControlCallback() ControlCallback { - return s.ccb + return s.controlCallback } -func (s *WebsocketStream) SetUpgradeRequestCallback(upReqCb UpgradeRequestCallback) { - s.upReqCb = upReqCb +// SetUpgradeRequestCallback sets a function that will be invoked during the handshake just before the upgrade request +// is sent. +// +// The caller must not perform any operations on the stream in the provided callback. +func (s *WebsocketStream) SetUpgradeRequestCallback(upgradeRequestCallback UpgradeRequestCallback) { + s.upgradeRequestCallback = upgradeRequestCallback } func (s *WebsocketStream) UpgradeRequestCallback() UpgradeRequestCallback { - return s.upReqCb + return s.upgradeRequestCallback } -func (s *WebsocketStream) SetUpgradeResponseCallback(upResCb UpgradeResponseCallback) { - s.upResCb = upResCb +// SetUpgradeResponseCallback sets a function that will be invoked during the handshake just after the upgrade response +// is received. +// +// The caller must not perform any operations on the stream in the provided callback. +func (s *WebsocketStream) SetUpgradeResponseCallback(upgradeResponseCallback UpgradeResponseCallback) { + s.upgradeResponseCallback = upgradeResponseCallback } func (s *WebsocketStream) UpgradeResponseCallback() UpgradeResponseCallback { - return s.upResCb + return s.upgradeResponseCallback } +// SetMaxMessageSize sets the maximum size of a message that can be read from or written to a peer. +// +// - If a message exceeds the limit while reading, the connection is closed abnormally. +// - If a message exceeds the limit while writing, the operation is cancelled. func (s *WebsocketStream) SetMaxMessageSize(bytes int) { // This is just for checking against the length returned in the frame // header. The sizes of the buffers in which we read or write the messages diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index ac376085..ccd0d44c 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -12,7 +12,7 @@ import ( "github.com/talostrading/sonic" ) -func assertState(t *testing.T, ws Stream, expected StreamState) { +func assertState(t *testing.T, ws *WebsocketStream, expected StreamState) { if ws.State() != expected { t.Fatalf("wrong state: given=%s expected=%s ", ws.State(), expected) } @@ -938,7 +938,7 @@ func TestClientAsyncWriteFrame(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := newFrame() + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { @@ -983,7 +983,7 @@ func TestClientWrite(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := newFrame() + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { @@ -1023,7 +1023,7 @@ func TestClientAsyncWrite(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := newFrame() + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) @@ -1063,7 +1063,7 @@ func TestClientClose(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := newFrame() + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) @@ -1108,7 +1108,7 @@ func TestClientAsyncClose(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := newFrame() + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) @@ -1157,7 +1157,7 @@ func TestClientCloseHandshakeWeStart(t *testing.T) { mock.b.Commit(mock.b.WriteLen()) - serverReply := newFrame() + serverReply := NewFrame() serverReply.SetFIN() serverReply.SetClose() serverReply.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1206,7 +1206,7 @@ func TestClientAsyncCloseHandshakeWeStart(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - serverReply := newFrame() + serverReply := NewFrame() serverReply.SetFIN() serverReply.SetClose() serverReply.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1251,7 +1251,7 @@ func TestClientCloseHandshakePeerStarts(t *testing.T) { ws.state = StateActive ws.init(mock) - serverClose := newFrame() + serverClose := NewFrame() serverClose.SetFIN() serverClose.SetClose() serverClose.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1297,7 +1297,7 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { ws.state = StateActive ws.init(mock) - serverClose := newFrame() + serverClose := NewFrame() serverClose.SetFIN() serverClose.SetClose() serverClose.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) diff --git a/codec/websocket/test_main.go b/codec/websocket/test_main.go index 4ddc5c32..961a85bd 100644 --- a/codec/websocket/test_main.go +++ b/codec/websocket/test_main.go @@ -73,7 +73,7 @@ func (s *MockServer) Accept(addr string) (err error) { } func (s *MockServer) Write(b []byte) error { - f := newFrame() + f := NewFrame() f.SetText() f.SetPayload(b) f.SetFIN() @@ -82,7 +82,7 @@ func (s *MockServer) Write(b []byte) error { } func (s *MockServer) Read(b []byte) (n int, err error) { - f := newFrame() + f := NewFrame() _, err = f.ReadFrom(s.conn) if err == nil { From baf4a8056e50c2a8123a4e2940421d2a5968a4d2 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 12:08:20 +0100 Subject: [PATCH 08/35] [websocket] Note on improving frame serialization --- codec/websocket/frame_codec.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codec/websocket/frame_codec.go b/codec/websocket/frame_codec.go index 8eb2edbd..45d11be9 100644 --- a/codec/websocket/frame_codec.go +++ b/codec/websocket/frame_codec.go @@ -97,7 +97,9 @@ func (c *FrameCodec) Decode(src *sonic.ByteBuffer) (Frame, error) { // Encode encodes the `Frame` into `dst`. func (c *FrameCodec) Encode(frame Frame, dst *sonic.ByteBuffer) error { - // ensure the destination buffer can hold the serialized frame // TODO this can be improved + // TODO this can be improved: we can serialize directly in the buffer with zero-copy semantics + + // ensure the destination buffer can hold the serialized frame dst.Reserve(frame.PayloadLength() + MaxFrameHeaderLengthInBytes) n, err := frame.WriteTo(dst) From cf3af70e1dd64a993de079ab5cae66edc26bec2a Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 12:10:08 +0100 Subject: [PATCH 09/35] [websocket] Delete unused field from WebsocketStream struct --- codec/websocket/stream.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 342bae00..5f888d13 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -37,7 +37,6 @@ import ( ) type WebsocketStream struct { - // async operations executor. ioc *sonic.IO // User provided TLS config; nil if we don't use TLS @@ -84,9 +83,6 @@ type WebsocketStream struct { // Used to establish a TCP connection to the peer with a timeout. dialer *net.Dialer - // The size of the currently read message. - messageSize int - framePool sync.Pool } From d40d4ec67be157bdd9b3da5bae5e90ffe50d45b5 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 12:19:22 +0100 Subject: [PATCH 10/35] [websocket] MaxMessageSize is now set per stream And not globally, per all streams. --- codec/websocket/definitions.go | 8 ++++---- codec/websocket/frame.go | 2 -- codec/websocket/frame_codec.go | 16 +++++++++------- codec/websocket/frame_codec_test.go | 8 ++++---- codec/websocket/stream.go | 23 ++++++++++++++++------- 5 files changed, 33 insertions(+), 24 deletions(-) diff --git a/codec/websocket/definitions.go b/codec/websocket/definitions.go index be8a0080..b9564ac6 100644 --- a/codec/websocket/definitions.go +++ b/codec/websocket/definitions.go @@ -5,10 +5,10 @@ import ( "time" ) -var ( - MaxMessageSize = 1024 * 512 // the maximum size of a message - CloseTimeout = 5 * time.Second - DialTimeout = 5 * time.Second +const ( + DefaultMaxMessageSize = 1024 * 512 + CloseTimeout = 5 * time.Second + DialTimeout = 5 * time.Second ) type Role uint8 diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index 6400e096..e9efaac0 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -188,8 +188,6 @@ func (f *Frame) fitPayload() ([]byte, error) { length := f.PayloadLength() if length <= 0 { return nil, nil - } else if length > MaxMessageSize { - return nil, ErrPayloadTooBig } *f = util.ExtendSlice(*f, f.payloadStartIndex()+length) diff --git a/codec/websocket/frame_codec.go b/codec/websocket/frame_codec.go index 45d11be9..9c5c7a11 100644 --- a/codec/websocket/frame_codec.go +++ b/codec/websocket/frame_codec.go @@ -14,18 +14,20 @@ var ( // FrameCodec is a stateful streaming parser handling the encoding and decoding of WebSocket `Frame`s. type FrameCodec struct { - src *sonic.ByteBuffer // buffer we decode from - dst *sonic.ByteBuffer // buffer we encode to + src *sonic.ByteBuffer // buffer we decode from + dst *sonic.ByteBuffer // buffer we encode to + maxMessageSize int decodeFrame Frame // frame we decode into decodeReset bool // true if we must reset the state on the next decode } -func NewFrameCodec(src, dst *sonic.ByteBuffer) *FrameCodec { +func NewFrameCodec(src, dst *sonic.ByteBuffer, maxMessageSize int) *FrameCodec { return &FrameCodec{ - decodeFrame: NewFrame(), - src: src, - dst: dst, + decodeFrame: NewFrame(), + src: src, + dst: dst, + maxMessageSize: maxMessageSize, } } @@ -66,7 +68,7 @@ func (c *FrameCodec) Decode(src *sonic.ByteBuffer) (Frame, error) { c.decodeFrame = src.Data()[:readSoFar] payloadLength := c.decodeFrame.PayloadLength() - if payloadLength > MaxMessageSize { + if payloadLength > c.maxMessageSize { c.decodeFrame = nil return nil, ErrPayloadOverMaxSize } diff --git a/codec/websocket/frame_codec_test.go b/codec/websocket/frame_codec_test.go index 8fa3f138..3ee14605 100644 --- a/codec/websocket/frame_codec_test.go +++ b/codec/websocket/frame_codec_test.go @@ -13,7 +13,7 @@ func TestDecodeShortFrame(t *testing.T) { src := sonic.NewByteBuffer() src.Write([]byte{0x81, 1}) // fin=1 opcode=1 (text) payload_len=1 - codec := NewFrameCodec(src, nil) + codec := NewFrameCodec(src, nil, DefaultMaxMessageSize) f, err := codec.Decode(src) if !errors.Is(err, sonicerrors.ErrNeedMore) { @@ -34,7 +34,7 @@ func TestDecodeExactlyOneFrame(t *testing.T) { src := sonic.NewByteBuffer() src.Write([]byte{0x81, 1, 0xFF}) // fin=1 opcode=1 (text) payload_len=1 - codec := NewFrameCodec(src, nil) + codec := NewFrameCodec(src, nil, DefaultMaxMessageSize) f, err := codec.Decode(src) if err != nil { @@ -62,7 +62,7 @@ func TestDecodeOneAndShortFrame(t *testing.T) { // fin=1 opcode=1 (text) payload_len=1 src.Write([]byte{0x81, 1, 0xFF, 0xFF, 0xFF, 0xFF}) - codec := NewFrameCodec(src, nil) + codec := NewFrameCodec(src, nil, DefaultMaxMessageSize) f, err := codec.Decode(src) if err != nil { @@ -92,7 +92,7 @@ func TestDecodeTwoFrames(t *testing.T) { 0x81, 5, 0x01, 0x02, 0x03, 0x04, 0x05, // second complete frame 0x81, 10}) // third short frame - codec := NewFrameCodec(src, nil) + codec := NewFrameCodec(src, nil, DefaultMaxMessageSize) if src.WriteLen() != 12 { t.Fatal("should have 12 bytes in the write area") diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 5f888d13..96821cb8 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -47,6 +47,7 @@ type WebsocketStream struct { conn net.Conn // Codec stream wrapping the underlying transport stream. + codec *FrameCodec codecConn *sonic.CodecConn[Frame, Frame] // Websocket role: client or server. @@ -84,6 +85,8 @@ type WebsocketStream struct { dialer *net.Dialer framePool sync.Pool + + maxMessageSize int } func NewWebsocketStream(ioc *sonic.IO, tls *tls.Config, role Role) (s *WebsocketStream, err error) { @@ -106,6 +109,7 @@ func NewWebsocketStream(ioc *sonic.IO, tls *tls.Config, role Role) (s *Websocket return &frame }, }, + maxMessageSize: DefaultMaxMessageSize, } s.src.Reserve(4096) @@ -145,8 +149,8 @@ func (s *WebsocketStream) init(stream sonic.Stream) (err error) { } s.stream = stream - codec := NewFrameCodec(s.src, s.dst) - s.codecConn, err = sonic.NewCodecConn[Frame, Frame](stream, codec, s.src, s.dst) + s.codec = NewFrameCodec(s.src, s.dst, s.maxMessageSize) + s.codecConn, err = sonic.NewCodecConn[Frame, Frame](stream, s.codec, s.src, s.dst) return } @@ -294,7 +298,7 @@ func (s *WebsocketStream) NextMessage(b []byte) (mt MessageType, readBytes int, n := copy(b[readBytes:], f.Payload()) readBytes += n - if readBytes > MaxMessageSize || n != f.PayloadLength() { + if readBytes > s.maxMessageSize || n != f.PayloadLength() { err = ErrMessageTooBig _ = s.Close(CloseGoingAway, "payload too big") break @@ -364,7 +368,7 @@ func (s *WebsocketStream) asyncNextMessage( n := copy(b[readBytes:], f.Payload()) readBytes += n - if readBytes > MaxMessageSize || n != f.PayloadLength() { + if readBytes > s.maxMessageSize || n != f.PayloadLength() { err = ErrMessageTooBig s.AsyncClose( CloseGoingAway, @@ -488,7 +492,7 @@ func (s *WebsocketStream) handleDataFrame(f Frame) error { // - an error occurs during the write // - the message is successfully written to the underlying stream func (s *WebsocketStream) Write(b []byte, mt MessageType) error { - if len(b) > MaxMessageSize { + if len(b) > s.maxMessageSize { return ErrMessageTooBig } @@ -533,7 +537,7 @@ func (s *WebsocketStream) WriteFrame(f *Frame) error { // - an error occurs during the write // - the message is successfully written to the underlying stream func (s *WebsocketStream) AsyncWrite(b []byte, messageType MessageType, callback func(err error)) { - if len(b) > MaxMessageSize { + if len(b) > s.maxMessageSize { callback(ErrMessageTooBig) return } @@ -977,7 +981,12 @@ func (s *WebsocketStream) SetMaxMessageSize(bytes int) { // This is just for checking against the length returned in the frame // header. The sizes of the buffers in which we read or write the messages // are dynamically adjusted in frame_codec. - MaxMessageSize = bytes + s.maxMessageSize = bytes + s.codec.maxMessageSize = bytes +} + +func (s *WebsocketStream) MaxMessageSize() int { + return s.maxMessageSize } func (s *WebsocketStream) RemoteAddr() net.Addr { From 3600984edb0cc8fb6d7ff59ef2efb15e1be7634a Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 12:23:43 +0100 Subject: [PATCH 11/35] [websocket] Prefer "callback" to "handler" --- codec/websocket/definitions.go | 6 +++--- codec/websocket/stream.go | 35 +++++++++++++++++----------------- examples/binance/main.go | 2 +- examples/okex/main.go | 2 +- tests/autobahn/client.go | 2 +- 5 files changed, 24 insertions(+), 23 deletions(-) diff --git a/codec/websocket/definitions.go b/codec/websocket/definitions.go index b9564ac6..59ef9b5c 100644 --- a/codec/websocket/definitions.go +++ b/codec/websocket/definitions.go @@ -99,9 +99,9 @@ func (s StreamState) String() string { } } -type AsyncMessageHandler = func(err error, n int, mt MessageType) -type AsyncFrameHandler = func(err error, f Frame) -type ControlCallback = func(mt MessageType, payload []byte) +type AsyncMessageCallback = func(err error, n int, messageType MessageType) +type AsyncFrameCallback = func(err error, f Frame) +type ControlCallback = func(messageType MessageType, payload []byte) type UpgradeRequestCallback = func(req *http.Request) type UpgradeResponseCallback = func(res *http.Response) diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 96821cb8..3ad7db38 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -65,8 +65,7 @@ type WebsocketStream struct { // Buffer for stream writes. dst *sonic.ByteBuffer - // Contains the handshake response. Is emptied after the - // handshake is over. + // Contains the handshake response. Is emptied after the handshake is over. handshakeBuffer []byte // Contains frames waiting to be sent to the peer. Is emptied by AsyncFlush or Flush. @@ -231,7 +230,7 @@ func (s *WebsocketStream) nextFrame() (f Frame, err error) { // - an error occurs while flushing the pending control frames // - an error occurs when reading/decoding the message bytes from the underlying stream // - a frame is successfully read from the underlying stream -func (s *WebsocketStream) AsyncNextFrame(callback AsyncFrameHandler) { +func (s *WebsocketStream) AsyncNextFrame(callback AsyncFrameCallback) { s.AsyncFlush(func(err error) { if errors.Is(err, ErrMessageTooBig) { s.AsyncClose(CloseGoingAway, "payload too big", func(err error) {}) @@ -252,7 +251,7 @@ func (s *WebsocketStream) AsyncNextFrame(callback AsyncFrameHandler) { }) } -func (s *WebsocketStream) asyncNextFrame(callback AsyncFrameHandler) { +func (s *WebsocketStream) asyncNextFrame(callback AsyncFrameCallback) { s.codecConn.AsyncReadNext(func(err error, f Frame) { if err == nil { err = s.handleFrame(f) @@ -272,18 +271,17 @@ func (s *WebsocketStream) asyncNextFrame(callback AsyncFrameHandler) { // - an error occurs while flushing the pending control frames // - an error occurs when reading/decoding the message from the underlying stream // - the payload of the message is successfully read into the supplied buffer, after all message fragments are read -func (s *WebsocketStream) NextMessage(b []byte) (mt MessageType, readBytes int, err error) { +func (s *WebsocketStream) NextMessage(b []byte) (messageType MessageType, readBytes int, err error) { var ( f Frame continuation = false ) - - mt = TypeNone + messageType = TypeNone for { f, err = s.NextFrame() if err != nil { - return mt, readBytes, err + return messageType, readBytes, err } if f.Opcode().IsControl() { @@ -291,8 +289,8 @@ func (s *WebsocketStream) NextMessage(b []byte) (mt MessageType, readBytes int, s.controlCallback(MessageType(f.Opcode()), f.Payload()) } } else { - if mt == TypeNone { - mt = MessageType(f.Opcode()) + if messageType == TypeNone { + messageType = MessageType(f.Opcode()) } n := copy(b[readBytes:], f.Payload()) @@ -339,7 +337,7 @@ func (s *WebsocketStream) NextMessage(b []byte) (mt MessageType, readBytes int, // - an error occurs while flushing the pending control frames // - an error occurs when reading/decoding the message bytes from the underlying stream // - the payload of the message is successfully read into the supplied buffer, after all message fragments are read -func (s *WebsocketStream) AsyncNextMessage(b []byte, callback AsyncMessageHandler) { +func (s *WebsocketStream) AsyncNextMessage(b []byte, callback AsyncMessageCallback) { s.asyncNextMessage(b, 0, false, TypeNone, callback) } @@ -348,7 +346,7 @@ func (s *WebsocketStream) asyncNextMessage( readBytes int, continuation bool, messageType MessageType, - callback AsyncMessageHandler, + callback AsyncMessageCallback, ) { s.AsyncNextFrame(func(err error, f Frame) { if err != nil { @@ -491,7 +489,7 @@ func (s *WebsocketStream) handleDataFrame(f Frame) error { // - an error occurs while flushing the pending control frames // - an error occurs during the write // - the message is successfully written to the underlying stream -func (s *WebsocketStream) Write(b []byte, mt MessageType) error { +func (s *WebsocketStream) Write(b []byte, messageType MessageType) error { if len(b) > s.maxMessageSize { return ErrMessageTooBig } @@ -500,7 +498,7 @@ func (s *WebsocketStream) Write(b []byte, mt MessageType) error { // reserve space for mask if client f := s.AcquireFrame(). SetFIN(). - SetOpcode(Opcode(mt)). + SetOpcode(Opcode(messageType)). SetPayload(b) s.prepareWrite(f) return s.Flush() @@ -691,6 +689,8 @@ func (s *WebsocketStream) State() StreamState { // The call blocks until one of the following conditions is true: // - the HTTP1.1 request is sent and the response is received // - an error occurs +// +// Extra headers should be generated by calling `ExtraHeader(...)`. func (s *WebsocketStream) Handshake(addr string, extraHeaders ...Header) (err error) { if s.role != RoleClient { return ErrWrongHandshakeRole @@ -722,6 +722,8 @@ func (s *WebsocketStream) Handshake(addr string, extraHeaders ...Header) (err er // // This call does not block. The provided callback is called when the request is sent and the response is // received or when an error occurs. +// +// Extra headers should be generated by calling `ExtraHeader(...)`. func (s *WebsocketStream) AsyncHandshake(addr string, callback func(error), extraHeaders ...Header) { if s.role != RoleClient { callback(ErrWrongHandshakeRole) @@ -780,9 +782,8 @@ func (s *WebsocketStream) resolve(addr string) (url *url.URL, err error) { func (s *WebsocketStream) dial(url *url.URL, callback func(err error, stream sonic.Stream)) { var ( - err error - sc syscall.Conn - + err error + sc syscall.Conn port = url.Port() ) diff --git a/examples/binance/main.go b/examples/binance/main.go index c0537f65..93e4343f 100644 --- a/examples/binance/main.go +++ b/examples/binance/main.go @@ -44,7 +44,7 @@ func onWrite(err error, stream websocket.Stream) { } func readLoop(stream websocket.Stream) { - var onRead websocket.AsyncMessageHandler + var onRead websocket.AsyncMessageCallback onRead = func(err error, n int, _ websocket.MessageType) { if err != nil { panic(err) diff --git a/examples/okex/main.go b/examples/okex/main.go index d9dc726f..f0a6b06d 100644 --- a/examples/okex/main.go +++ b/examples/okex/main.go @@ -48,7 +48,7 @@ func onWrite(err error, stream websocket.Stream) { } func readLoop(stream websocket.Stream) { - var onRead websocket.AsyncMessageHandler + var onRead websocket.AsyncMessageCallback onRead = func(err error, n int, _ websocket.MessageType) { if err != nil { panic(err) diff --git a/tests/autobahn/client.go b/tests/autobahn/client.go index 3c11b3b4..bdf1781d 100644 --- a/tests/autobahn/client.go +++ b/tests/autobahn/client.go @@ -92,7 +92,7 @@ func runTest(i int) { b := make([]byte, 1024*1024) - var onAsyncRead websocket.AsyncMessageHandler + var onAsyncRead websocket.AsyncMessageCallback onAsyncRead = func(err error, n int, mt websocket.MessageType) { if err != nil { From d2a784a1b36db0269e1da806c0679dc1679bb434 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 12:25:08 +0100 Subject: [PATCH 12/35] [websocket] Rename WebsocketStream to Stream Will be refered to as websocket.Stream. --- codec/websocket/stream.go | 108 ++++++++++++++++----------------- codec/websocket/stream_test.go | 2 +- 2 files changed, 55 insertions(+), 55 deletions(-) diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 3ad7db38..f2593f7c 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -36,7 +36,7 @@ import ( "github.com/talostrading/sonic/sonicopts" ) -type WebsocketStream struct { +type Stream struct { ioc *sonic.IO // User provided TLS config; nil if we don't use TLS @@ -88,8 +88,8 @@ type WebsocketStream struct { maxMessageSize int } -func NewWebsocketStream(ioc *sonic.IO, tls *tls.Config, role Role) (s *WebsocketStream, err error) { - s = &WebsocketStream{ +func NewWebsocketStream(ioc *sonic.IO, tls *tls.Config, role Role) (s *Stream, err error) { + s = &Stream{ ioc: ioc, tls: tls, role: role, @@ -119,7 +119,7 @@ func NewWebsocketStream(ioc *sonic.IO, tls *tls.Config, role Role) (s *Websocket // The recommended way to create a `Frame`. This function takes care to allocate enough bytes to encode the WebSocket // header and apply client side masking if the `Stream`'s role is `RoleClient`. -func (s *WebsocketStream) AcquireFrame() *Frame { +func (s *Stream) AcquireFrame() *Frame { f := s.framePool.Get().(*Frame) if s.role == RoleClient { // This just reserves the 4 bytes needed for the mask in order to encode the payload correctly, since it follows @@ -129,20 +129,20 @@ func (s *WebsocketStream) AcquireFrame() *Frame { return f } -func (s *WebsocketStream) releaseFrame(f *Frame) { +func (s *Stream) releaseFrame(f *Frame) { f.Reset() s.framePool.Put(f) } // This stream is either a client or a server. The main differences are in how the opening/closing handshake is done and // in the fact that payloads sent by the client to the server are masked. -func (s *WebsocketStream) Role() Role { +func (s *Stream) Role() Role { return s.role } // init is run when we transition into StateActive which happens // after a successful handshake. -func (s *WebsocketStream) init(stream sonic.Stream) (err error) { +func (s *Stream) init(stream sonic.Stream) (err error) { if s.state != StateActive { return fmt.Errorf("stream must be in StateActive") } @@ -153,7 +153,7 @@ func (s *WebsocketStream) init(stream sonic.Stream) (err error) { return } -func (s *WebsocketStream) reset() { +func (s *Stream) reset() { s.handshakeBuffer = s.handshakeBuffer[:cap(s.handshakeBuffer)] s.state = StateHandshake s.stream = nil @@ -163,22 +163,22 @@ func (s *WebsocketStream) reset() { } // Returns the stream through which IO is done. -func (s *WebsocketStream) NextLayer() sonic.Stream { +func (s *Stream) NextLayer() sonic.Stream { if s.codecConn != nil { return s.codecConn.NextLayer() } return nil } -func (s *WebsocketStream) SupportsUTF8() bool { +func (s *Stream) SupportsUTF8() bool { return false } -func (s *WebsocketStream) SupportsDeflate() bool { +func (s *Stream) SupportsDeflate() bool { return false } -func (s *WebsocketStream) canRead() bool { +func (s *Stream) canRead() bool { // we can only read from non-terminal states after a successful handshake return s.state == StateActive || s.state == StateClosedByUs } @@ -191,7 +191,7 @@ func (s *WebsocketStream) canRead() bool { // - an error occurs while flushing the pending control frames // - an error occurs when reading/decoding the message bytes from the underlying stream // - a frame is successfully read from the underlying stream -func (s *WebsocketStream) NextFrame() (f Frame, err error) { +func (s *Stream) NextFrame() (f Frame, err error) { err = s.Flush() if errors.Is(err, ErrMessageTooBig) { @@ -214,7 +214,7 @@ func (s *WebsocketStream) NextFrame() (f Frame, err error) { return } -func (s *WebsocketStream) nextFrame() (f Frame, err error) { +func (s *Stream) nextFrame() (f Frame, err error) { f, err = s.codecConn.ReadNext() if err == nil { err = s.handleFrame(f) @@ -230,7 +230,7 @@ func (s *WebsocketStream) nextFrame() (f Frame, err error) { // - an error occurs while flushing the pending control frames // - an error occurs when reading/decoding the message bytes from the underlying stream // - a frame is successfully read from the underlying stream -func (s *WebsocketStream) AsyncNextFrame(callback AsyncFrameCallback) { +func (s *Stream) AsyncNextFrame(callback AsyncFrameCallback) { s.AsyncFlush(func(err error) { if errors.Is(err, ErrMessageTooBig) { s.AsyncClose(CloseGoingAway, "payload too big", func(err error) {}) @@ -251,7 +251,7 @@ func (s *WebsocketStream) AsyncNextFrame(callback AsyncFrameCallback) { }) } -func (s *WebsocketStream) asyncNextFrame(callback AsyncFrameCallback) { +func (s *Stream) asyncNextFrame(callback AsyncFrameCallback) { s.codecConn.AsyncReadNext(func(err error, f Frame) { if err == nil { err = s.handleFrame(f) @@ -271,7 +271,7 @@ func (s *WebsocketStream) asyncNextFrame(callback AsyncFrameCallback) { // - an error occurs while flushing the pending control frames // - an error occurs when reading/decoding the message from the underlying stream // - the payload of the message is successfully read into the supplied buffer, after all message fragments are read -func (s *WebsocketStream) NextMessage(b []byte) (messageType MessageType, readBytes int, err error) { +func (s *Stream) NextMessage(b []byte) (messageType MessageType, readBytes int, err error) { var ( f Frame continuation = false @@ -337,11 +337,11 @@ func (s *WebsocketStream) NextMessage(b []byte) (messageType MessageType, readBy // - an error occurs while flushing the pending control frames // - an error occurs when reading/decoding the message bytes from the underlying stream // - the payload of the message is successfully read into the supplied buffer, after all message fragments are read -func (s *WebsocketStream) AsyncNextMessage(b []byte, callback AsyncMessageCallback) { +func (s *Stream) AsyncNextMessage(b []byte, callback AsyncMessageCallback) { s.asyncNextMessage(b, 0, false, TypeNone, callback) } -func (s *WebsocketStream) asyncNextMessage( +func (s *Stream) asyncNextMessage( b []byte, readBytes int, continuation bool, @@ -402,7 +402,7 @@ func (s *WebsocketStream) asyncNextMessage( }) } -func (s *WebsocketStream) handleFrame(f Frame) (err error) { +func (s *Stream) handleFrame(f Frame) (err error) { err = s.verifyFrame(f) if err == nil { @@ -421,7 +421,7 @@ func (s *WebsocketStream) handleFrame(f Frame) (err error) { return err } -func (s *WebsocketStream) verifyFrame(f Frame) error { +func (s *Stream) verifyFrame(f Frame) error { if s.role == RoleClient && f.IsMasked() { return ErrMaskedFramesFromServer } @@ -433,7 +433,7 @@ func (s *WebsocketStream) verifyFrame(f Frame) error { return nil } -func (s *WebsocketStream) handleControlFrame(f Frame) (err error) { +func (s *Stream) handleControlFrame(f Frame) (err error) { if !f.IsFIN() { return ErrInvalidControlFrame } @@ -474,7 +474,7 @@ func (s *WebsocketStream) handleControlFrame(f Frame) (err error) { return } -func (s *WebsocketStream) handleDataFrame(f Frame) error { +func (s *Stream) handleDataFrame(f Frame) error { if f.Opcode().IsReserved() { return ErrReservedOpcode } @@ -489,7 +489,7 @@ func (s *WebsocketStream) handleDataFrame(f Frame) error { // - an error occurs while flushing the pending control frames // - an error occurs during the write // - the message is successfully written to the underlying stream -func (s *WebsocketStream) Write(b []byte, messageType MessageType) error { +func (s *Stream) Write(b []byte, messageType MessageType) error { if len(b) > s.maxMessageSize { return ErrMessageTooBig } @@ -515,7 +515,7 @@ func (s *WebsocketStream) Write(b []byte, messageType MessageType) error { // - an error occurs while flushing the pending control frames // - an error occurs during the write // - the frame is successfully written to the underlying stream -func (s *WebsocketStream) WriteFrame(f *Frame) error { +func (s *Stream) WriteFrame(f *Frame) error { if s.state == StateActive { s.prepareWrite(f) return s.Flush() @@ -534,7 +534,7 @@ func (s *WebsocketStream) WriteFrame(f *Frame) error { // - an error occurs while flushing the pending control frames // - an error occurs during the write // - the message is successfully written to the underlying stream -func (s *WebsocketStream) AsyncWrite(b []byte, messageType MessageType, callback func(err error)) { +func (s *Stream) AsyncWrite(b []byte, messageType MessageType, callback func(err error)) { if len(b) > s.maxMessageSize { callback(ErrMessageTooBig) return @@ -560,7 +560,7 @@ func (s *WebsocketStream) AsyncWrite(b []byte, messageType MessageType, callback // - an error occurs while flushing the pending control frames // - an error occurs during the write // - the frame is successfully written to the underlying stream -func (s *WebsocketStream) AsyncWriteFrame(f *Frame, callback func(err error)) { +func (s *Stream) AsyncWriteFrame(f *Frame, callback func(err error)) { if s.state == StateActive { s.prepareWrite(f) s.AsyncFlush(callback) @@ -570,7 +570,7 @@ func (s *WebsocketStream) AsyncWriteFrame(f *Frame, callback func(err error)) { } } -func (s *WebsocketStream) prepareWrite(f *Frame) { +func (s *Stream) prepareWrite(f *Frame) { if s.role == RoleClient { f.Mask() } @@ -589,7 +589,7 @@ func (s *WebsocketStream) prepareWrite(f *Frame) { // After beginning the closing handshake, the program should not write further messages, pings, pongs or close // frames. Instead, the program should continue reading messages until the closing handshake is complete or an error // occurs. -func (s *WebsocketStream) AsyncClose(closeCode CloseCode, reason string, callback func(err error)) { +func (s *Stream) AsyncClose(closeCode CloseCode, reason string, callback func(err error)) { switch s.state { case StateActive: s.state = StateClosedByUs @@ -614,7 +614,7 @@ func (s *WebsocketStream) AsyncClose(closeCode CloseCode, reason string, callbac // After beginning the closing handshake, the program should not write further messages, pings, pongs or close // frames. Instead, the program should continue reading messages until the closing handshake is complete or an error // occurs. -func (s *WebsocketStream) Close(cc CloseCode, reason string) error { +func (s *Stream) Close(cc CloseCode, reason string) error { switch s.state { case StateActive: s.state = StateClosedByUs @@ -627,7 +627,7 @@ func (s *WebsocketStream) Close(cc CloseCode, reason string) error { } } -func (s *WebsocketStream) prepareClose(payload []byte) { +func (s *Stream) prepareClose(payload []byte) { closeFrame := s.AcquireFrame(). SetFIN(). SetClose(). @@ -638,7 +638,7 @@ func (s *WebsocketStream) prepareClose(payload []byte) { // Flush writes any pending control frames to the underlying stream. // // This call blocks. -func (s *WebsocketStream) Flush() (err error) { +func (s *Stream) Flush() (err error) { flushed := 0 for i := 0; i < len(s.pendingFrames); i++ { _, err = s.codecConn.WriteNext(*s.pendingFrames[i]) @@ -656,7 +656,7 @@ func (s *WebsocketStream) Flush() (err error) { // Flush writes any pending control frames to the underlying stream asynchronously. // // This call does not block. -func (s *WebsocketStream) AsyncFlush(callback func(err error)) { +func (s *Stream) AsyncFlush(callback func(err error)) { if len(s.pendingFrames) == 0 { callback(nil) } else { @@ -676,11 +676,11 @@ func (s *WebsocketStream) AsyncFlush(callback func(err error)) { } // Pending returns the number of currently pending control frames waiting to be flushed. -func (s *WebsocketStream) Pending() int { +func (s *Stream) Pending() int { return len(s.pendingFrames) } -func (s *WebsocketStream) State() StreamState { +func (s *Stream) State() StreamState { return s.state } @@ -691,7 +691,7 @@ func (s *WebsocketStream) State() StreamState { // - an error occurs // // Extra headers should be generated by calling `ExtraHeader(...)`. -func (s *WebsocketStream) Handshake(addr string, extraHeaders ...Header) (err error) { +func (s *Stream) Handshake(addr string, extraHeaders ...Header) (err error) { if s.role != RoleClient { return ErrWrongHandshakeRole } @@ -724,7 +724,7 @@ func (s *WebsocketStream) Handshake(addr string, extraHeaders ...Header) (err er // received or when an error occurs. // // Extra headers should be generated by calling `ExtraHeader(...)`. -func (s *WebsocketStream) AsyncHandshake(addr string, callback func(error), extraHeaders ...Header) { +func (s *Stream) AsyncHandshake(addr string, callback func(error), extraHeaders ...Header) { if s.role != RoleClient { callback(ErrWrongHandshakeRole) return @@ -750,7 +750,7 @@ func (s *WebsocketStream) AsyncHandshake(addr string, callback func(error), extr }() } -func (s *WebsocketStream) handshake(addr string, headers []Header, callback func(err error, stream sonic.Stream)) { +func (s *Stream) handshake(addr string, headers []Header, callback func(err error, stream sonic.Stream)) { url, err := s.resolve(addr) if err != nil { callback(err, nil) @@ -764,7 +764,7 @@ func (s *WebsocketStream) handshake(addr string, headers []Header, callback func } } -func (s *WebsocketStream) resolve(addr string) (url *url.URL, err error) { +func (s *Stream) resolve(addr string) (url *url.URL, err error) { url, err = url.Parse(addr) if err == nil { switch url.Scheme { @@ -780,7 +780,7 @@ func (s *WebsocketStream) resolve(addr string) (url *url.URL, err error) { return } -func (s *WebsocketStream) dial(url *url.URL, callback func(err error, stream sonic.Stream)) { +func (s *Stream) dial(url *url.URL, callback func(err error, stream sonic.Stream)) { var ( err error sc syscall.Conn @@ -840,7 +840,7 @@ func (s *WebsocketStream) dial(url *url.URL, callback func(err error, stream son } } -func (s *WebsocketStream) upgrade(uri *url.URL, stream sonic.Stream, headers []Header) error { +func (s *Stream) upgrade(uri *url.URL, stream sonic.Stream, headers []Header) error { req, err := http.NewRequest("GET", uri.String(), nil) if err != nil { return err @@ -920,7 +920,7 @@ func (s *WebsocketStream) upgrade(uri *url.URL, stream sonic.Stream, headers []H // makeHandshakeKey generates the key of Sec-WebSocket-Key header as well as the // expected response present in Sec-WebSocket-Accept header. -func (s *WebsocketStream) makeHandshakeKey() (req, res string) { +func (s *Stream) makeHandshakeKey() (req, res string) { // request b := make([]byte, 16) _, _ = rand.Read(b) @@ -942,11 +942,11 @@ func (s *WebsocketStream) makeHandshakeKey() (req, res string) { // message. This callback is only invoked when reading complete messages, not frames. // // The caller must not perform any operations on the stream in the provided callback. -func (s *WebsocketStream) SetControlCallback(controlCallback ControlCallback) { +func (s *Stream) SetControlCallback(controlCallback ControlCallback) { s.controlCallback = controlCallback } -func (s *WebsocketStream) ControlCallback() ControlCallback { +func (s *Stream) ControlCallback() ControlCallback { return s.controlCallback } @@ -954,11 +954,11 @@ func (s *WebsocketStream) ControlCallback() ControlCallback { // is sent. // // The caller must not perform any operations on the stream in the provided callback. -func (s *WebsocketStream) SetUpgradeRequestCallback(upgradeRequestCallback UpgradeRequestCallback) { +func (s *Stream) SetUpgradeRequestCallback(upgradeRequestCallback UpgradeRequestCallback) { s.upgradeRequestCallback = upgradeRequestCallback } -func (s *WebsocketStream) UpgradeRequestCallback() UpgradeRequestCallback { +func (s *Stream) UpgradeRequestCallback() UpgradeRequestCallback { return s.upgradeRequestCallback } @@ -966,11 +966,11 @@ func (s *WebsocketStream) UpgradeRequestCallback() UpgradeRequestCallback { // is received. // // The caller must not perform any operations on the stream in the provided callback. -func (s *WebsocketStream) SetUpgradeResponseCallback(upgradeResponseCallback UpgradeResponseCallback) { +func (s *Stream) SetUpgradeResponseCallback(upgradeResponseCallback UpgradeResponseCallback) { s.upgradeResponseCallback = upgradeResponseCallback } -func (s *WebsocketStream) UpgradeResponseCallback() UpgradeResponseCallback { +func (s *Stream) UpgradeResponseCallback() UpgradeResponseCallback { return s.upgradeResponseCallback } @@ -978,7 +978,7 @@ func (s *WebsocketStream) UpgradeResponseCallback() UpgradeResponseCallback { // // - If a message exceeds the limit while reading, the connection is closed abnormally. // - If a message exceeds the limit while writing, the operation is cancelled. -func (s *WebsocketStream) SetMaxMessageSize(bytes int) { +func (s *Stream) SetMaxMessageSize(bytes int) { // This is just for checking against the length returned in the frame // header. The sizes of the buffers in which we read or write the messages // are dynamically adjusted in frame_codec. @@ -986,26 +986,26 @@ func (s *WebsocketStream) SetMaxMessageSize(bytes int) { s.codec.maxMessageSize = bytes } -func (s *WebsocketStream) MaxMessageSize() int { +func (s *Stream) MaxMessageSize() int { return s.maxMessageSize } -func (s *WebsocketStream) RemoteAddr() net.Addr { +func (s *Stream) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *WebsocketStream) LocalAddr() net.Addr { +func (s *Stream) LocalAddr() net.Addr { return s.conn.LocalAddr() } -func (s *WebsocketStream) RawFd() int { +func (s *Stream) RawFd() int { if s.NextLayer() != nil { return s.NextLayer().(*sonic.AsyncAdapter).RawFd() } return -1 } -func (s *WebsocketStream) CloseNextLayer() (err error) { +func (s *Stream) CloseNextLayer() (err error) { if s.conn != nil { err = s.conn.Close() s.conn = nil diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index ccd0d44c..fc52bcd1 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -12,7 +12,7 @@ import ( "github.com/talostrading/sonic" ) -func assertState(t *testing.T, ws *WebsocketStream, expected StreamState) { +func assertState(t *testing.T, ws *Stream, expected StreamState) { if ws.State() != expected { t.Fatalf("wrong state: given=%s expected=%s ", ws.State(), expected) } From b8c7a0dcb8796e2c2b3bff81eda01203b70b5d23 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 12:33:31 +0100 Subject: [PATCH 13/35] [websocket] Fix examples --- examples/binance/main.go | 68 +++++++++++++----------------------- examples/okex/main.go | 68 +++++++++++++----------------------- examples/websocket/client.go | 27 ++++++-------- 3 files changed, 61 insertions(+), 102 deletions(-) diff --git a/examples/binance/main.go b/examples/binance/main.go index 93e4343f..7d2be8f2 100644 --- a/examples/binance/main.go +++ b/examples/binance/main.go @@ -17,48 +17,6 @@ var subscriptionMessage = []byte( } `) -var b = make([]byte, 512*1024) // contains websocket payloads - -func run(stream websocket.Stream) { - stream.AsyncHandshake("wss://stream.binance.com:9443/ws", func(err error) { - onHandshake(err, stream) - }) -} - -func onHandshake(err error, stream websocket.Stream) { - if err != nil { - panic(err) - } else { - stream.AsyncWrite(subscriptionMessage, websocket.TypeText, func(err error) { - onWrite(err, stream) - }) - } -} - -func onWrite(err error, stream websocket.Stream) { - if err != nil { - panic(err) - } else { - readLoop(stream) - } -} - -func readLoop(stream websocket.Stream) { - var onRead websocket.AsyncMessageCallback - onRead = func(err error, n int, _ websocket.MessageType) { - if err != nil { - panic(err) - } else { - b = b[:n] - fmt.Println(string(b)) - b = b[:cap(b)] - - stream.AsyncNextMessage(b, onRead) - } - } - stream.AsyncNextMessage(b, onRead) -} - func main() { ioc := sonic.MustIO() defer ioc.Close() @@ -68,7 +26,31 @@ func main() { panic(err) } - run(stream) + stream.AsyncHandshake("wss://stream.binance.com:9443/ws", func(err error) { + if err != nil { + panic(err) + } + + stream.AsyncWrite(subscriptionMessage, websocket.TypeText, func(err error) { + if err != nil { + panic(err) + } + + var ( + b [1024 * 512]byte + onRead websocket.AsyncMessageCallback + ) + onRead = func(err error, n int, _ websocket.MessageType) { + if err != nil { + panic(err) + } + + fmt.Println(string(b[:n])) + stream.AsyncNextMessage(b[:], onRead) + } + stream.AsyncNextMessage(b[:], onRead) + }) + }) ioc.Run() } diff --git a/examples/okex/main.go b/examples/okex/main.go index f0a6b06d..36dfb8c0 100644 --- a/examples/okex/main.go +++ b/examples/okex/main.go @@ -21,48 +21,6 @@ var subscriptionMessage = []byte( } `) -var b = make([]byte, 512*1024) // contains websocket payloads - -func run(stream websocket.Stream) { - stream.AsyncHandshake("wss://ws.okx.com:8443/ws/v5/public", func(err error) { - onHandshake(err, stream) - }) -} - -func onHandshake(err error, stream websocket.Stream) { - if err != nil { - panic(err) - } else { - stream.AsyncWrite(subscriptionMessage, websocket.TypeText, func(err error) { - onWrite(err, stream) - }) - } -} - -func onWrite(err error, stream websocket.Stream) { - if err != nil { - panic(err) - } else { - readLoop(stream) - } -} - -func readLoop(stream websocket.Stream) { - var onRead websocket.AsyncMessageCallback - onRead = func(err error, n int, _ websocket.MessageType) { - if err != nil { - panic(err) - } else { - b = b[:n] - fmt.Println(string(b)) - b = b[:cap(b)] - - stream.AsyncNextMessage(b, onRead) - } - } - stream.AsyncNextMessage(b, onRead) -} - func main() { ioc := sonic.MustIO() defer ioc.Close() @@ -72,7 +30,31 @@ func main() { panic(err) } - run(stream) + stream.AsyncHandshake("wss://ws.okx.com:8443/ws/v5/public", func(err error) { + if err != nil { + panic(err) + } + + stream.AsyncWrite(subscriptionMessage, websocket.TypeText, func(err error) { + if err != nil { + panic(err) + } + + var ( + b [1024 * 512]byte + onRead websocket.AsyncMessageCallback + ) + onRead = func(err error, n int, _ websocket.MessageType) { + if err != nil { + panic(err) + } + + fmt.Println(string(b[:n])) + stream.AsyncNextMessage(b[:], onRead) + } + stream.AsyncNextMessage(b[:], onRead) + }) + }) ioc.Run() } diff --git a/examples/websocket/client.go b/examples/websocket/client.go index db666ce2..83a413da 100644 --- a/examples/websocket/client.go +++ b/examples/websocket/client.go @@ -19,26 +19,21 @@ func main() { client.AsyncHandshake("ws://localhost:8080", func(err error) { if err != nil { panic(err) - } else { - client.AsyncWrite([]byte("hello"), websocket.TypeText, func(err error) { + } + client.AsyncWrite([]byte("hello"), websocket.TypeText, func(err error) { + if err != nil { + panic(err) + } + + var b [128]byte + client.AsyncNextMessage(b[:], func(err error, n int, mt websocket.MessageType) { if err != nil { panic(err) - } else { - b := make([]byte, 128) - client.AsyncNextMessage(b, func(err error, n int, mt websocket.MessageType) { - if err != nil { - panic(err) - } else { - b = b[:n] - fmt.Println("read", n, "bytes", string(b), err) - } - }) } + fmt.Println("read", n, "bytes", string(b[:n]), err) }) - } + }) }) - for { - ioc.RunOneFor(0) // poll - } + ioc.Run() } From c37f7308e7f63a8ce132b5a5a942bb5f516af5e2 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 16:40:39 +0100 Subject: [PATCH 14/35] Add stretchr/testify dependency for assert.* --- go.mod | 4 ++++ go.sum | 3 +++ 2 files changed, 7 insertions(+) diff --git a/go.mod b/go.mod index 9a213554..dc64bd84 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,16 @@ go 1.20 require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 github.com/felixge/fgprof v0.9.3 + github.com/stretchr/testify v1.8.0 github.com/valyala/bytebufferpool v1.0.0 golang.org/x/sys v0.11.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/go-cmp v0.5.8 // indirect github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 5db7805e..86601b83 100644 --- a/go.sum +++ b/go.sum @@ -24,7 +24,9 @@ github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1: github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -72,6 +74,7 @@ gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From 30645fa642cb5b0ee00f7331e41402dd86870e79 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 16:48:25 +0100 Subject: [PATCH 15/35] [websocket] Set the payload length before copying the payload Simplifies stream.WriteTo --- codec/websocket/frame.go | 8 +-- codec/websocket/frame_codec.go | 1 - codec/websocket/stream_test.go | 109 +++++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 5 deletions(-) diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index e9efaac0..b7502206 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -54,7 +54,7 @@ func (f Frame) clearPayloadLength() { f[1] &= (1 << 7) } -func (f *Frame) SetPayloadLength(n int) *Frame { +func (f *Frame) setPayloadLength(n int) *Frame { f.clearPayloadLength() if n > (1<<16 - 1) { @@ -196,10 +196,12 @@ func (f *Frame) fitPayload() ([]byte, error) { } func (f *Frame) SetPayload(b []byte) *Frame { + f.setPayloadLength(len(b)) // set the length as it's used by `payloadStartIndex`. + *f = util.ExtendSlice(*f, f.payloadStartIndex()+len(b)) payload := f.Payload() copy(payload, b) - f.SetPayloadLength(len(payload)) + return f } @@ -275,8 +277,6 @@ func (f *Frame) ReadFrom(r io.Reader) (n int64, err error) { } func (f Frame) WriteTo(w io.Writer) (int64, error) { - f.SetPayloadLength(len(f.Payload())) - written := 0 for written < len(f) { n, err := w.Write(f[written:]) diff --git a/codec/websocket/frame_codec.go b/codec/websocket/frame_codec.go index 9c5c7a11..e21b77f4 100644 --- a/codec/websocket/frame_codec.go +++ b/codec/websocket/frame_codec.go @@ -2,7 +2,6 @@ package websocket import ( "errors" - "github.com/talostrading/sonic" ) diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index fc52bcd1..911b609a 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/talostrading/sonic" ) @@ -18,6 +19,114 @@ func assertState(t *testing.T, ws *Stream, expected StreamState) { } } +func TestClientSendMessageWithPayload126(t *testing.T) { + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + frame := NewFrame() + frame. + SetFIN(). + SetText(). + SetPayload(make([]byte, 126)) + assert.Equal(126, frame.PayloadLength()) + assert.Equal(2, frame.ExtendedPayloadLengthBytes()) + + frame.WriteTo(srv.conn) + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + if err != nil { + t.Fatal(err) + } + assert.True(f.IsFIN()) + assert.True(f.Opcode().IsText()) + assert.Equal(126, f.PayloadLength()) + assert.Equal(2, f.ExtendedPayloadLengthBytes()) + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + +func TestClientSendMessageWithPayload127(t *testing.T) { + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + frame := NewFrame() + frame. + SetFIN(). + SetText(). + SetPayload(make([]byte, 1<<16+10 /* it won't fit in 2 bytes */)) + assert.Equal(1<<16+10, frame.PayloadLength()) + assert.Equal(8, frame.ExtendedPayloadLengthBytes()) + + frame.WriteTo(srv.conn) + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + if err != nil { + t.Fatal(err) + } + assert.True(f.IsFIN()) + assert.True(f.Opcode().IsText()) + assert.Equal(1<<16+10, f.PayloadLength()) + assert.Equal(8, f.ExtendedPayloadLengthBytes()) + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + func TestClientReconnectOnFailedRead(t *testing.T) { go func() { for i := 0; i < 10; i++ { From 5c737fbf6b4abf39ba82d09174ac929a8aae19a4 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 17:07:36 +0100 Subject: [PATCH 16/35] [websocket] Ensure invalid control frames close the connection With a ProtocolError. The close frame must be explicitly flushed by callers, for now. --- codec/websocket/stream.go | 1 + codec/websocket/stream_test.go | 70 ++++++++++++++++++++++++++++++++++ tests/autobahn/client.go | 1 + 3 files changed, 72 insertions(+) diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index f2593f7c..ed9fa273 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -415,6 +415,7 @@ func (s *Stream) handleFrame(f Frame) (err error) { if err != nil { s.state = StateClosedByUs + // TODO consider flushing the close s.prepareClose(EncodeCloseFramePayload(CloseProtocolError, "")) } diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index 911b609a..be15ec0e 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -19,6 +19,76 @@ func assertState(t *testing.T, ws *Stream, expected StreamState) { } } +func TestClientSendPingWithInvalidPayload(t *testing.T) { + // Per the protocol, pings cannot have payloads larger than 125. We send a ping with 125. The client should close + // the connection immediately with 1002/Protocol Error. + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + // This ping has an invalid payload size of 126, which should trigger a close with reason 1002. + { + frame := NewFrame() + frame. + SetFIN(). + SetPing(). + SetPayload(make([]byte, 126)) + assert.Equal(126, frame.PayloadLength()) + assert.Equal(2, frame.ExtendedPayloadLengthBytes()) + + frame.WriteTo(srv.conn) + } + + // Ensure we get the close. + { + frame := NewFrame() + frame.ReadFrom(srv.conn) + + assert.True(frame.Opcode().IsClose()) + assert.True(frame.IsMasked()) // client to server frames are masked + frame.Unmask() + + closeCode, reason := DecodeCloseFramePayload(frame.Payload()) + assert.Equal(CloseProtocolError, closeCode) + assert.Empty(reason) + } + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + assert.NotNil(err) + assert.Equal(ErrControlFrameTooBig, err) + assert.Equal(1, ws.Pending()) + ws.Flush() + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + func TestClientSendMessageWithPayload126(t *testing.T) { assert := assert.New(t) diff --git a/tests/autobahn/client.go b/tests/autobahn/client.go index bdf1781d..d7bc5177 100644 --- a/tests/autobahn/client.go +++ b/tests/autobahn/client.go @@ -96,6 +96,7 @@ func runTest(i int) { onAsyncRead = func(err error, n int, mt websocket.MessageType) { if err != nil { + s.Flush() done = true } else { b = b[:n] From b7d1d6d0b68dd8de679247ca100442a452d000a5 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 17:21:49 +0100 Subject: [PATCH 17/35] [websocket] Ensure close codes are echoed --- codec/websocket/stream_test.go | 63 ++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index be15ec0e..f0f1ce48 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -19,6 +19,69 @@ func assertState(t *testing.T, ws *Stream, expected StreamState) { } } +func TestClientEchoCloseCode(t *testing.T) { + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + { + frame := NewFrame() + frame. + SetFIN(). + SetClose(). + SetPayload(EncodeCloseFramePayload(CloseNormal, "something")) + + frame.WriteTo(srv.conn) + } + + { + frame := NewFrame() + frame.ReadFrom(srv.conn) + + assert.True(frame.Opcode().IsClose()) + assert.True(frame.IsMasked()) // client to server frames are masked + frame.Unmask() + + closeCode, reason := DecodeCloseFramePayload(frame.Payload()) + assert.Equal(CloseNormal, closeCode) + assert.Equal(reason, "something") + } + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + assert.Nil(err) + assert.Equal(1, ws.Pending()) + ws.Flush() + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + func TestClientSendPingWithInvalidPayload(t *testing.T) { // Per the protocol, pings cannot have payloads larger than 125. We send a ping with 125. The client should close // the connection immediately with 1002/Protocol Error. From e14971f88015107469b3131860a50b519db73b04 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 18:19:22 +0100 Subject: [PATCH 18/35] [websocket] Validate CloseCodes before responding to a Close frame --- codec/websocket/rfc6455.go | 23 ++++++++++++ codec/websocket/stream.go | 15 +++++++- codec/websocket/stream_test.go | 67 ++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 1 deletion(-) diff --git a/codec/websocket/rfc6455.go b/codec/websocket/rfc6455.go index 9f270c5b..f525056e 100644 --- a/codec/websocket/rfc6455.go +++ b/codec/websocket/rfc6455.go @@ -179,8 +179,31 @@ const ( // CloseReserved3 is reserved for future use by the WebSocket standard. This code is reserved and may not be sent. CloseReserved3 CloseCode = 1015 + + // CloseReserved4 is reserved for future use by the WebSocket standard. This code is reserved and may not be sent. + CloseReserved4 CloseCode = 1016 + + CloseReservedForFuture CloseCode = 1004 ) +func ValidCloseCode(closeCode CloseCode) bool { + if closeCode == CloseNormal || + closeCode == CloseGoingAway || + closeCode == CloseProtocolError || + closeCode == CloseUnknownData || + closeCode == CloseBadPayload || + closeCode == ClosePolicyError || + closeCode == CloseTooBig || + closeCode == CloseNeedsExtension || + closeCode == CloseInternalError || + closeCode == CloseServiceRestart || + closeCode == CloseTryAgainLater || + (closeCode >= 3000 && closeCode <= 4999) { + return true + } + return false +} + func EncodeCloseCode(cc CloseCode) []byte { b := make([]byte, 2) binary.BigEndian.PutUint16(b, uint16(cc)) diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index ed9fa273..d012cf2b 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -459,7 +459,20 @@ func (s *Stream) handleControlFrame(f Frame) (err error) { panic("unreachable") case StateActive: s.state = StateClosedByPeer - s.prepareClose(f.Payload()) + + payload := f.Payload() + if len(payload) >= 2 { + closeCode := DecodeCloseCode(payload) + if !ValidCloseCode(closeCode) { + s.prepareClose(EncodeCloseCode(CloseProtocolError)) + } else { + s.prepareClose(f.Payload()) + } + } else if len(payload) > 0 { + s.prepareClose(EncodeCloseCode(CloseProtocolError)) + } else { + s.prepareClose(EncodeCloseCode(CloseNormal)) + } case StateClosedByPeer, StateCloseAcked: // ignore case StateClosedByUs: diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index f0f1ce48..877a0925 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -19,6 +19,73 @@ func assertState(t *testing.T, ws *Stream, expected StreamState) { } } +func TestClientServerSendsInvalidCloseCode(t *testing.T) { + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + { + frame := NewFrame() + + closeCode := CloseReserved1 + assert.False(ValidCloseCode(closeCode)) + + frame. + SetFIN(). + SetClose(). + SetPayload(EncodeCloseFramePayload(closeCode, "something")) + + frame.WriteTo(srv.conn) + } + + { + frame := NewFrame() + frame.ReadFrom(srv.conn) + + assert.True(frame.Opcode().IsClose()) + assert.True(frame.IsMasked()) // client to server frames are masked + frame.Unmask() + + closeCode, reason := DecodeCloseFramePayload(frame.Payload()) + assert.Equal(CloseProtocolError, closeCode) + assert.Equal(reason, "") + } + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + assert.Nil(err) + assert.Equal(1, ws.Pending()) + ws.Flush() + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + func TestClientEchoCloseCode(t *testing.T) { assert := assert.New(t) From 7a0a6b43c17ec0070f9c4be238c7d5e17ea80483 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 18:25:45 +0100 Subject: [PATCH 19/35] [websocket] Handle reserved bits The connection should fail with 1002/Protocol error if any of the RSV1/2/3 bits are set. --- codec/websocket/frame.go | 27 +++++++++++++++++++++++++++ codec/websocket/rfc6455.go | 3 +++ codec/websocket/stream.go | 4 ++++ 3 files changed, 34 insertions(+) diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index b7502206..451eee19 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -81,6 +81,18 @@ func (f Frame) IsFIN() bool { return f[0]&bitFIN != 0 } +func (f Frame) IsRSV1() bool { + return f[0]&bitRSV1 != 0 +} + +func (f Frame) IsRSV2() bool { + return f[0]&bitRSV2 != 0 +} + +func (f Frame) IsRSV3() bool { + return f[0]&bitRSV3 != 0 +} + func (f Frame) Opcode() Opcode { return Opcode(f[0] & bitmaskOpcode) } @@ -111,6 +123,21 @@ func (f *Frame) SetFIN() *Frame { return f } +func (f *Frame) SetRSV1() *Frame { + (*f)[0] |= bitRSV1 + return f +} + +func (f *Frame) SetRSV2() *Frame { + (*f)[0] |= bitRSV2 + return f +} + +func (f *Frame) SetRSV3() *Frame { + (*f)[0] |= bitRSV3 + return f +} + func (f Frame) clearOpcode() { f[0] &= bitmaskOpcode << 4 } diff --git a/codec/websocket/rfc6455.go b/codec/websocket/rfc6455.go index f525056e..0c887bff 100644 --- a/codec/websocket/rfc6455.go +++ b/codec/websocket/rfc6455.go @@ -24,6 +24,9 @@ const ( frameHeaderLength = 2 bitFIN = byte(1 << 7) + bitRSV1 = byte(1 << 6) + bitRSV2 = byte(1 << 5) + bitRSV3 = byte(1 << 4) bitmaskOpcode = byte(1<<4 - 1) bitIsMasked = byte(1 << 7) diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index d012cf2b..4ca0f5d1 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -423,6 +423,10 @@ func (s *Stream) handleFrame(f Frame) (err error) { } func (s *Stream) verifyFrame(f Frame) error { + if f.IsRSV1() || f.IsRSV2() || f.IsRSV3() { + return ErrNonZeroReservedBits + } + if s.role == RoleClient && f.IsMasked() { return ErrMaskedFramesFromServer } From ba8514a76c24abd9290c7554df621ea8dee7bfcb Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 18:32:40 +0100 Subject: [PATCH 20/35] [websocket] Ensure the close reason, if any, is UTF-8 encoded --- codec/websocket/stream.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 4ca0f5d1..65bd0e31 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -30,6 +30,7 @@ import ( "net/url" "sync" "syscall" + "unicode/utf8" "github.com/talostrading/sonic" "github.com/talostrading/sonic/sonicerrors" @@ -465,12 +466,17 @@ func (s *Stream) handleControlFrame(f Frame) (err error) { s.state = StateClosedByPeer payload := f.Payload() + // The first 2 bytes are the close code. The rest, if any, is the reason - this must be UTF-8. if len(payload) >= 2 { - closeCode := DecodeCloseCode(payload) - if !ValidCloseCode(closeCode) { + if !utf8.Valid(payload[2:]) { s.prepareClose(EncodeCloseCode(CloseProtocolError)) } else { - s.prepareClose(f.Payload()) + closeCode := DecodeCloseCode(payload) + if !ValidCloseCode(closeCode) { + s.prepareClose(EncodeCloseCode(CloseProtocolError)) + } else { + s.prepareClose(f.Payload()) + } } } else if len(payload) > 0 { s.prepareClose(EncodeCloseCode(CloseProtocolError)) From ccf404310c1fe2b44504838f9d76efeaac795eb4 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Mon, 28 Oct 2024 19:22:41 +0100 Subject: [PATCH 21/35] [websocket] Optionally validate UTF8 --- codec/websocket/errors.go | 2 ++ codec/websocket/stream.go | 26 +++++++++++++++++++++++++- tests/autobahn/client.go | 11 +++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/codec/websocket/errors.go b/codec/websocket/errors.go index 8733c1c8..bd20f962 100644 --- a/codec/websocket/errors.go +++ b/codec/websocket/errors.go @@ -38,4 +38,6 @@ var ( ErrExpectedContinuation = errors.New("expected continue frame") ErrInvalidAddress = errors.New("invalid address") + + ErrInvalidUTF8 = errors.New("Invalid UTF-8 encoding") ) diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 65bd0e31..d1ccb820 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -87,6 +87,8 @@ type Stream struct { framePool sync.Pool maxMessageSize int + + validateUTF8 bool } func NewWebsocketStream(ioc *sonic.IO, tls *tls.Config, role Role) (s *Stream, err error) { @@ -110,6 +112,7 @@ func NewWebsocketStream(ioc *sonic.IO, tls *tls.Config, role Role) (s *Stream, e }, }, maxMessageSize: DefaultMaxMessageSize, + validateUTF8: false, } s.src.Reserve(4096) @@ -171,8 +174,22 @@ func (s *Stream) NextLayer() sonic.Stream { return nil } +// SupportsUTF8 indicates that the stream can optionally perform UTF-8 validation on the payloads of Text frames. +// Validation is disabled by default and can be toggled with `ValidateUTF8(bool)`. func (s *Stream) SupportsUTF8() bool { - return false + return true +} + +// ValidatesUTF8 indicates if UTF-8 validation is performed on the payloads of Text frames. Validation is disabled by +// default and can be toggled with `ValidateUTF8(bool)`. +func (s *Stream) ValidatesUTF8() bool { + return s.validateUTF8 +} + +// ValidateUTF8 toggles UTF8 validation done on the payloads of Text frames. +func (s *Stream) ValidateUTF8(v bool) *Stream { + s.validateUTF8 = v + return s } func (s *Stream) SupportsDeflate() bool { @@ -502,6 +519,13 @@ func (s *Stream) handleDataFrame(f Frame) error { if f.Opcode().IsReserved() { return ErrReservedOpcode } + + if f.Opcode().IsText() && s.validateUTF8 { + if !utf8.Valid(f.Payload()) { + return ErrInvalidUTF8 + } + } + return nil } diff --git a/tests/autobahn/client.go b/tests/autobahn/client.go index d7bc5177..68934cf7 100644 --- a/tests/autobahn/client.go +++ b/tests/autobahn/client.go @@ -12,6 +12,7 @@ import ( var ( addr = flag.String("addr", "ws://localhost:9001", "server address") testCase = flag.Int("case", -1, "autobahn test case to run") + utf8 = flag.Bool("utf8", false, "if true, payloads of text frames are utf8 validated") ) func main() { @@ -83,6 +84,16 @@ func runTest(i int) { if err != nil { panic(err) } + if *utf8 { + s.ValidateUTF8(true) + if !s.ValidatesUTF8() { + panic("UTF8 should be validated") + } + } else { + if s.ValidatesUTF8() { + panic("UTF8 should NOT be validated") + } + } done := false s.AsyncHandshake(fmt.Sprintf("%s/runCase?case=%d&agent=sonic", *addr, i), func(err error) { From 90317538902b2a0e1e4e1615ded6ccbaf34a0c9c Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Tue, 29 Oct 2024 09:41:53 +0100 Subject: [PATCH 22/35] [websocket] Cleanup Frame --- codec/websocket/frame.go | 100 +++++++++++++++++---------------- codec/websocket/frame_codec.go | 2 +- codec/websocket/rfc6455.go | 2 +- codec/websocket/stream.go | 2 +- codec/websocket/stream_test.go | 30 +++++----- codec/websocket/test_main.go | 2 +- 6 files changed, 70 insertions(+), 68 deletions(-) diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index 451eee19..3dbf207a 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -7,7 +7,7 @@ import ( "github.com/talostrading/sonic/util" ) -var zeroBytes [MaxFrameHeaderLengthInBytes]byte +var zeroBytes [frameMaxHeaderLength]byte func init() { for i := 0; i < len(zeroBytes); i++ { @@ -24,7 +24,9 @@ var ( // NOTE use stream.AcquireFrame() instead of NewFrame if you intend to write this frame onto a WebSocket stream. func NewFrame() Frame { - return make([]byte, MaxFrameHeaderLengthInBytes) + b := make([]byte, frameMaxHeaderLength) + copy(b, zeroBytes[:]) + return b } func (f Frame) Reset() { @@ -50,28 +52,6 @@ func (f Frame) PayloadLength() int { } } -func (f Frame) clearPayloadLength() { - f[1] &= (1 << 7) -} - -func (f *Frame) setPayloadLength(n int) *Frame { - f.clearPayloadLength() - - if n > (1<<16 - 1) { - // does not fit in 2 bytes, so take 8 as extended payload length - (*f)[1] |= 127 - binary.BigEndian.PutUint64((*f)[2:], uint64(n)) - } else if n > 125 { - // fits in 2 bytes as extended payload length - (*f)[1] |= 126 - binary.BigEndian.PutUint16((*f)[2:], uint16(n)) - } else { - // can be encoded in the 7 bits of the payload length, no extended payload length taken - (*f)[1] |= byte(n) - } - return f -} - // An unfragmented message consists of a single frame with the FIN bit set and an opcode other than 0. // // A fragmented message consists of a single frame with the FIN bit clear and an opcode other than 0, followed by zero @@ -179,7 +159,7 @@ func (f *Frame) SetPong() *Frame { return f } -func (f Frame) extendedPayloadLengthStartIndex() int { +func (f Frame) extendedPayloadLengthOffset() int { return frameHeaderLength } @@ -195,37 +175,44 @@ func (f Frame) Header() []byte { return f[:frameHeaderLength] } -func (f *Frame) maskStartIndex() int { +func (f *Frame) maskOffset() int { return frameHeaderLength + f.ExtendedPayloadLengthBytes() } -func (f Frame) MaskKey() []byte { +func (f Frame) Mask() []byte { if f.IsMasked() { - mask := f[f.maskStartIndex():] + mask := f[f.maskOffset():] return mask[:frameMaskLength] } return nil } -func (f Frame) payloadStartIndex() int { +func (f Frame) payloadOffset() int { return frameHeaderLength + f.ExtendedPayloadLengthBytes() + f.MaskBytes() } -func (f *Frame) fitPayload() ([]byte, error) { - length := f.PayloadLength() - if length <= 0 { - return nil, nil - } +func (f *Frame) setPayloadLength(n int) *Frame { + (*f)[1] &= (1 << 7) - *f = util.ExtendSlice(*f, f.payloadStartIndex()+length) - b := (*f)[f.payloadStartIndex():] - return b[:length], nil + if n > (1<<16 - 1) { + // does not fit in 2 bytes, so take 8 as extended payload length + (*f)[1] |= 127 + binary.BigEndian.PutUint64((*f)[2:], uint64(n)) + } else if n > 125 { + // fits in 2 bytes as extended payload length + (*f)[1] |= 126 + binary.BigEndian.PutUint16((*f)[2:], uint16(n)) + } else { + // can be encoded in the 7 bits of the payload length, no extended payload length taken + (*f)[1] |= byte(n) + } + return f } func (f *Frame) SetPayload(b []byte) *Frame { - f.setPayloadLength(len(b)) // set the length as it's used by `payloadStartIndex`. + f.setPayloadLength(len(b)) // set the length as it's used by `payloadOffset`. - *f = util.ExtendSlice(*f, f.payloadStartIndex()+len(b)) + *f = util.ExtendSlice(*f, f.payloadOffset()+len(b)) payload := f.Payload() copy(payload, b) @@ -233,14 +220,14 @@ func (f *Frame) SetPayload(b []byte) *Frame { } func (f Frame) Payload() []byte { - return f[f.payloadStartIndex():] + return f[f.payloadOffset():] } -func (f *Frame) Mask() { +func (f *Frame) MaskPayload() { f.SetIsMasked() var ( - mask = f.MaskKey() + mask = f.Mask() payload = f.Payload() ) @@ -250,10 +237,10 @@ func (f *Frame) Mask() { } } -func (f *Frame) Unmask() { +func (f *Frame) UnmaskPayload() { if f.IsMasked() { var ( - mask = f.MaskKey() + mask = f.Mask() payload = f.Payload() ) Mask(mask, payload) @@ -261,6 +248,17 @@ func (f *Frame) Unmask() { } } +func (f *Frame) fitPayload() ([]byte, error) { + length := f.PayloadLength() + if length <= 0 { + return nil, nil + } + + *f = util.ExtendSlice(*f, f.payloadOffset()+length) + b := (*f)[f.payloadOffset():] + return b[:length], nil +} + func (f *Frame) ReadFrom(r io.Reader) (n int64, err error) { var nn int @@ -282,7 +280,7 @@ func (f *Frame) ReadFrom(r io.Reader) (n int64, err error) { // read the mask, if any if f.IsMasked() { - nn, err = io.ReadFull(r, f.MaskKey()) + nn, err = io.ReadFull(r, f.Mask()) n += int64(nn) if err != nil { return @@ -294,16 +292,20 @@ func (f *Frame) ReadFrom(r io.Reader) (n int64, err error) { if err != nil { return } - nn, err = io.ReadFull(r, b) - n += int64(nn) - if err != nil { - return + if b != nil { + nn, err = io.ReadFull(r, b) + n += int64(nn) + if err != nil { + return + } } return } func (f Frame) WriteTo(w io.Writer) (int64, error) { + // TODO test partial write + written := 0 for written < len(f) { n, err := w.Write(f[written:]) diff --git a/codec/websocket/frame_codec.go b/codec/websocket/frame_codec.go index e21b77f4..5c2a5262 100644 --- a/codec/websocket/frame_codec.go +++ b/codec/websocket/frame_codec.go @@ -101,7 +101,7 @@ func (c *FrameCodec) Encode(frame Frame, dst *sonic.ByteBuffer) error { // TODO this can be improved: we can serialize directly in the buffer with zero-copy semantics // ensure the destination buffer can hold the serialized frame - dst.Reserve(frame.PayloadLength() + MaxFrameHeaderLengthInBytes) + dst.Reserve(frame.PayloadLength() + frameMaxHeaderLength) n, err := frame.WriteTo(dst) dst.Commit(int(n)) diff --git a/codec/websocket/rfc6455.go b/codec/websocket/rfc6455.go index 0c887bff..93670a47 100644 --- a/codec/websocket/rfc6455.go +++ b/codec/websocket/rfc6455.go @@ -14,7 +14,7 @@ import ( const ( MaxControlFramePayloadLength = 125 - MaxFrameHeaderLengthInBytes = 14 // 14 bytes max for the header of a frame i.e. everything without the payload + frameMaxHeaderLength = 14 // 14 bytes max for the header of a frame i.e. everything without the payload ) const ( diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index d1ccb820..dbcb02fe 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -620,7 +620,7 @@ func (s *Stream) AsyncWriteFrame(f *Frame, callback func(err error)) { func (s *Stream) prepareWrite(f *Frame) { if s.role == RoleClient { - f.Mask() + f.MaskPayload() } s.pendingFrames = append(s.pendingFrames, f) } diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index 877a0925..b5e20ac8 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -51,7 +51,7 @@ func TestClientServerSendsInvalidCloseCode(t *testing.T) { assert.True(frame.Opcode().IsClose()) assert.True(frame.IsMasked()) // client to server frames are masked - frame.Unmask() + frame.UnmaskPayload() closeCode, reason := DecodeCloseFramePayload(frame.Payload()) assert.Equal(CloseProtocolError, closeCode) @@ -114,7 +114,7 @@ func TestClientEchoCloseCode(t *testing.T) { assert.True(frame.Opcode().IsClose()) assert.True(frame.IsMasked()) // client to server frames are masked - frame.Unmask() + frame.UnmaskPayload() closeCode, reason := DecodeCloseFramePayload(frame.Payload()) assert.Equal(CloseNormal, closeCode) @@ -183,7 +183,7 @@ func TestClientSendPingWithInvalidPayload(t *testing.T) { assert.True(frame.Opcode().IsClose()) assert.True(frame.IsMasked()) // client to server frames are masked - frame.Unmask() + frame.UnmaskPayload() closeCode, reason := DecodeCloseFramePayload(frame.Payload()) assert.Equal(CloseProtocolError, closeCode) @@ -773,7 +773,7 @@ func TestClientReadCorruptControlFrame(t *testing.T) { assertState(t, ws, StateClosedByUs) closeFrame := ws.pendingFrames[0] - closeFrame.Unmask() + closeFrame.UnmaskPayload() cc, _ := DecodeCloseFramePayload(ws.pendingFrames[0].Payload()) if cc != CloseProtocolError { @@ -817,7 +817,7 @@ func TestClientAsyncReadCorruptControlFrame(t *testing.T) { assertState(t, ws, StateClosedByUs) closeFrame := ws.pendingFrames[0] - closeFrame.Unmask() + closeFrame.UnmaskPayload() cc, _ := DecodeCloseFramePayload(ws.pendingFrames[0].Payload()) if cc != CloseProtocolError { @@ -864,7 +864,7 @@ func TestClientReadPingFrame(t *testing.T) { t.Fatal("invalid pong reply") } - reply.Unmask() + reply.UnmaskPayload() if !bytes.Equal(reply.Payload(), []byte{0x01, 0x02}) { t.Fatal("invalid pong reply") } @@ -917,7 +917,7 @@ func TestClientAsyncReadPingFrame(t *testing.T) { t.Fatal("invalid pong reply") } - reply.Unmask() + reply.UnmaskPayload() if !bytes.Equal(reply.Payload(), []byte{0x01, 0x02}) { t.Fatal("invalid pong reply") } @@ -1077,7 +1077,7 @@ func TestClientReadCloseFrame(t *testing.T) { if !reply.IsMasked() { t.Fatal("reply should be masked") } - reply.Unmask() + reply.UnmaskPayload() cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { @@ -1140,7 +1140,7 @@ func TestClientAsyncReadCloseFrame(t *testing.T) { if !reply.IsMasked() { t.Fatal("reply should be masked") } - reply.Unmask() + reply.UnmaskPayload() cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { @@ -1209,7 +1209,7 @@ func TestClientWriteFrame(t *testing.T) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if !bytes.Equal(f.Payload(), []byte{1, 2, 3, 4, 5}) { t.Fatal("frame payload is corrupt, something went wrong with the encoder") @@ -1258,7 +1258,7 @@ func TestClientAsyncWriteFrame(t *testing.T) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if !bytes.Equal(f.Payload(), []byte{1, 2, 3, 4, 5}) { t.Fatal("frame payload is corrupt, something went wrong with the encoder") @@ -1303,7 +1303,7 @@ func TestClientWrite(t *testing.T) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if !bytes.Equal(f.Payload(), []byte{1, 2, 3, 4, 5}) { t.Fatal("frame payload is corrupt, something went wrong with the encoder") @@ -1342,7 +1342,7 @@ func TestClientAsyncWrite(t *testing.T) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if !bytes.Equal(f.Payload(), []byte{1, 2, 3, 4, 5}) { t.Fatal("frame payload is corrupt, something went wrong with the encoder") @@ -1382,7 +1382,7 @@ func TestClientClose(t *testing.T) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if f.PayloadLength() != 5 { t.Fatal("wrong message in close frame") @@ -1427,7 +1427,7 @@ func TestClientAsyncClose(t *testing.T) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if f.PayloadLength() != 5 { t.Fatal("wrong message in close frame") diff --git a/codec/websocket/test_main.go b/codec/websocket/test_main.go index 961a85bd..da91cf44 100644 --- a/codec/websocket/test_main.go +++ b/codec/websocket/test_main.go @@ -90,7 +90,7 @@ func (s *MockServer) Read(b []byte) (n int, err error) { return 0, fmt.Errorf("client frames should be masked") } - f.Unmask() + f.UnmaskPayload() copy(b, f.Payload()) n = f.PayloadLength() } From c2bc04884bfa09d5c4407d628913e2e927db3dcc Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Tue, 29 Oct 2024 09:49:04 +0100 Subject: [PATCH 23/35] [websocket] Ensure Frame can serialize in chunks --- codec/websocket/frame.go | 2 -- codec/websocket/frame_test.go | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index 3dbf207a..14d3610b 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -304,8 +304,6 @@ func (f *Frame) ReadFrom(r io.Reader) (n int64, err error) { } func (f Frame) WriteTo(w io.Writer) (int64, error) { - // TODO test partial write - written := 0 for written < len(f) { n, err := w.Write(f[written:]) diff --git a/codec/websocket/frame_test.go b/codec/websocket/frame_test.go index 49480d0d..4290bea0 100644 --- a/codec/websocket/frame_test.go +++ b/codec/websocket/frame_test.go @@ -6,9 +6,58 @@ import ( "crypto/rand" "testing" + "github.com/stretchr/testify/assert" "github.com/talostrading/sonic" ) +type partialFrameWriter struct { + b [1024]byte + invoked int +} + +func (w *partialFrameWriter) Write(b []byte) (n int, err error) { + copy(w.b[w.invoked:w.invoked+1], b) + w.invoked++ + return 1, nil +} + +func (w *partialFrameWriter) Bytes() []byte { + return w.b[:w.invoked] +} + +func TestFramePartialWrites(t *testing.T) { + assert := assert.New(t) + + payload := make([]byte, 126) + for i := 0; i < len(payload); i++ { + payload[i] = 0 + } + copy(payload, []byte("something")) + + f := NewFrame() + f.SetFIN().SetText().SetPayload(payload) + assert.Equal(2, f.ExtendedPayloadLengthBytes()) + assert.Equal(126, f.PayloadLength()) + + w := &partialFrameWriter{} + f.WriteTo(w) + assert.Equal(2 /* mandatory flags */ +2 /* 2 bytes for the extended payload length */ +126 /* payload */, w.invoked) + assert.Equal(130, len(w.Bytes())) + + // deserialize to make sure the frames are identical + { + f := Frame(w.Bytes()) + assert.True(f.IsFIN()) + assert.True(f.Opcode().IsText()) + assert.Equal(126, f.PayloadLength()) + assert.Equal(2, f.ExtendedPayloadLengthBytes()) + assert.Equal("something", string(f.Payload()[:len("something")])) + for i := len("something"); i < len(f.Payload()); i++ { + assert.Equal(0, int(f.Payload()[i])) + } + } +} + func TestUnder126Frame(t *testing.T) { var ( f = NewFrame() From a0f5e904209d818b3d78825f7e002a0b66548747 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Tue, 29 Oct 2024 09:59:58 +0100 Subject: [PATCH 24/35] [websocket] Make Frame.ExtendedPayloadLength private --- codec/websocket/frame.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index 14d3610b..63ee9584 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -163,7 +163,7 @@ func (f Frame) extendedPayloadLengthOffset() int { return frameHeaderLength } -func (f Frame) ExtendedPayloadLength() []byte { +func (f Frame) extendedPayloadLength() []byte { if bytes := f.ExtendedPayloadLengthBytes(); bytes > 0 { b := f[frameHeaderLength:] return b[:bytes] @@ -270,7 +270,7 @@ func (f *Frame) ReadFrom(r io.Reader) (n int64, err error) { } // read the extended payload length, if any - if b := f.ExtendedPayloadLength(); b != nil { + if b := f.extendedPayloadLength(); b != nil { nn, err = io.ReadFull(r, b) n += int64(nn) if err != nil { From a6ed38a83c882e046f3470bd140d4514c748035f Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Tue, 29 Oct 2024 10:11:41 +0100 Subject: [PATCH 25/35] [websocket] Track IO latency in binance example --- examples/binance/main.go | 49 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/examples/binance/main.go b/examples/binance/main.go index 7d2be8f2..b5dfe15e 100644 --- a/examples/binance/main.go +++ b/examples/binance/main.go @@ -2,25 +2,35 @@ package main import ( "crypto/tls" + "flag" "fmt" + "time" "github.com/talostrading/sonic" "github.com/talostrading/sonic/codec/websocket" + "github.com/talostrading/sonic/util" ) -var subscriptionMessage = []byte( - ` +var ( + verbose = flag.Bool("v", false, "if true, websocket messages are printed") + + subscriptionMessage = []byte(` { "id": 1, "method": "SUBSCRIBE", "params": [ "bnbbtc@depth" ] } `) +) func main() { + flag.Parse() + ioc := sonic.MustIO() defer ioc.Close() + ioLatency := util.NewOnlineStats() + stream, err := websocket.NewWebsocketStream(ioc, &tls.Config{}, websocket.RoleClient) if err != nil { panic(err) @@ -45,12 +55,43 @@ func main() { panic(err) } - fmt.Println(string(b[:n])) + if *verbose { + fmt.Println(string(b[:n])) + } stream.AsyncNextMessage(b[:], onRead) } stream.AsyncNextMessage(b[:], onRead) }) }) - ioc.Run() + eventsReceived := 0 + + ioLatencyTimer, err := sonic.NewTimer(ioc) + if err != nil { + panic(err) + } + ioLatencyTimer.ScheduleRepeating(time.Second, func() { + if eventsReceived > 0 { + result := ioLatency.Result() + fmt.Printf( + "min/avg/max/stddev = %.2f/%.2f/%.2f/%.2f us from %d events\n", + result.Min, + result.Avg, + result.Max, + result.StdDev, + eventsReceived, + ) + ioLatency.Reset() + eventsReceived = 0 + } + }) + + for { + start := time.Now() + n, _ := ioc.PollOne() + if n > 0 { + eventsReceived += n + ioLatency.Add(float64(time.Now().Sub(start).Microseconds())) + } + } } From 36da1515b871ddb04045f7a177ed2eb61cb44c2e Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Thu, 14 Nov 2024 10:43:13 +0100 Subject: [PATCH 26/35] [file] No need to store the fd in Conn Use the one from the slot, as we do in packet.go. --- conn.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/conn.go b/conn.go index 8749b669..c5a32c13 100644 --- a/conn.go +++ b/conn.go @@ -13,7 +13,6 @@ var _ Conn = &conn{} type conn struct { *file - fd int localAddr net.Addr remoteAddr net.Addr } @@ -50,7 +49,6 @@ func newConn( ) *conn { return &conn{ file: &file{ioc: ioc, slot: internal.Slot{Fd: fd}}, - fd: fd, localAddr: localAddr, remoteAddr: remoteAddr, } @@ -74,5 +72,5 @@ func (c *conn) SetWriteDeadline(t time.Time) error { } func (c *conn) RawFd() int { - return c.fd + return c.file.slot.Fd } From 8d9b78694406385e610507feaf25d0c3694c1d0f Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Thu, 14 Nov 2024 10:53:20 +0100 Subject: [PATCH 27/35] Cleanup IO --- io.go | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/io.go b/io.go index fd6f0172..fccf7a1c 100644 --- a/io.go +++ b/io.go @@ -20,9 +20,9 @@ type IO struct { poller internal.Poller // The below structures keep a pointer to a Slot struct usually owned by an object capable of asynchronous - // operations (essentially any object taking an IO* on construction). Keeping a Slot pointer keeps the owning - // object in the GC's object graph while an asynchronous operation is in progress. This ensures Slot references - // valid memory when an asynchronous operation completes and the object is already out of scope. + // operations (essentially any object taking an IO* on construction). Keeping a Slot pointer keeps the owning object + // in the GC's object graph while an asynchronous operation is in progress. This ensures Slot references valid + // memory when an asynchronous operation completes and the object is already out of scope. pending struct { // The kernel allocates a process' descriptors from a fixed range that is [0, 1024) by default. Unprivileged // users can bump this range to [0, 4096). The below array should cover 99% of the cases and makes for cheap @@ -32,8 +32,8 @@ type IO struct { // file descriptors are bound by a fixed range whose upper limit is controlled through RLIMIT_NOFILE. static [4096]*internal.Slot - // This map covers the 1%, the degenerate case. Any Slot whose file descriptor is greater than or equal to - // 4096 goes here. + // This map covers the 1%, the degenerate case. Any Slot whose file descriptor is greater than or equal to 4096 + // goes here. This is lazily initialized. dynamic map[*internal.Slot]struct{} } pendingTimers map[*Timer]struct{} // XXX: should be embedded into the above pending struct @@ -170,7 +170,7 @@ const ( // RunWarm runs the event loop in a combined busy-wait and yielding mode, meaning that if the current cycle does not // process anything, the event-loop will busy-wait for at most `busyCycles` which we call the warm-state. After // `busyCycles` of not processing anything, the event-loop is out of the warm-state and falls back to yielding with the -// provided timeout. If at any moment an event occurs and something is processed, the event-loop transitions to its +// provided timeout. If at any moment an event occurs and something is processed, the event-loop transitions to its // warm-state. func (ioc *IO) RunWarm(busyCycles int, timeout time.Duration) (err error) { if busyCycles <= 0 { @@ -200,9 +200,9 @@ func (ioc *IO) RunWarm(busyCycles int, timeout time.Duration) (err error) { // We processed something in this cycle, be it inside or outside the warm-period. We restart the warm-period i = 0 } else { - // We did not process anything in this cycle. If we are still in the warm period i.e. `i < busyCycles`, - // we are going to poll in the next cycle. If we are outside the warm period i.e. `i >= busyCycles`, - // we are going to yield in the next cycle. + // We did not process anything in this cycle. If we are still in the warm period i.e. `i < busyCycles`, we + // are going to poll in the next cycle. If we are outside the warm period i.e. `i >= busyCycles`, we are + // going to yield in the next cycle. i++ } } @@ -231,8 +231,6 @@ func (ioc *IO) poll(timeoutMs int) (int, error) { if err != nil { if err == syscall.EINTR { - // TODO not sure about this one, and whether returning timeout here is ok. - // need to look into syscall.EINTR again if timeoutMs >= 0 { return 0, sonicerrors.ErrTimeout } @@ -251,8 +249,7 @@ func (ioc *IO) poll(timeoutMs int) (int, error) { return n, nil } -// Post schedules the provided handler to be run immediately by the event -// processing loop in its own thread. +// Post schedules the provided handler to be run immediately by the event processing loop in its own thread. // // It is safe to call Post concurrently. func (ioc *IO) Post(handler func()) error { @@ -266,6 +263,7 @@ func (ioc *IO) Posted() int { return ioc.poller.Posted() } +// Returns the current number of pending asynchronous operations. func (ioc *IO) Pending() int64 { return ioc.poller.Pending() } From 4bf9c1fc78500cceccefc0c0ce419c63543f814a Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Thu, 14 Nov 2024 11:31:28 +0100 Subject: [PATCH 28/35] The callback dispatch counter is shared between all async objects This ensures we also prevent stack overflows when two objects build up each other's stackframes. See TestDispatchLimit for such a case. --- conn_test.go | 119 ++++++++++++++++++++++++++++++++++++++++++++++ definitions.go | 6 --- file.go | 22 +++------ io.go | 19 ++++++++ listen_conn.go | 8 ++-- multicast/peer.go | 17 ++++--- packet.go | 14 +++--- 7 files changed, 161 insertions(+), 44 deletions(-) diff --git a/conn_test.go b/conn_test.go index 9ffd439c..69ff96df 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/talostrading/sonic/sonicopts" ) @@ -349,3 +350,121 @@ func TestConnWriteHandlesError(t *testing.T) { t.Fatal("test did not run to completion") } } + +func TestDispatchLimit(t *testing.T) { + // Since the `ioc.Dispatched` counter is shared amongst, we test if it updated correctly when two connections build + // up each other's stack-frames. As in, when connection 1 has an immediate async read, it invokes connection 2's + // async read which is immediate and when done that invokes connection 1's async read which is immediate and so + // on. We ensure we reach the `MaxCallbackDispatch` limit with this sequence of reads. We then assert that + // `ioc.Dispatched` ends up 0 at the end of all operations. + + var ( + writtenBytes = MaxCallbackDispatch * 2 + readBytes = 0 + ) + + assert := assert.New(t) + + ioc := MustIO() + defer ioc.Close() + + assert.Equal(0, ioc.Dispatched) + + // setup up the server which will write to both reading connections + ln, err := net.Listen("tcp", "localhost:0") + assert.Nil(err) + addr := ln.Addr().String() + + marker := make(chan struct{}, 10) + go func() { + marker <- struct{}{} + + for { + conn, err := ln.Accept() + assert.Nil(err) + + go func() { + // These bytes will get cached by the TCP layer. We thus ensure that each connection has at least one + // series of `MaxCallbackDispatch` immediate asynchronous reads. + for i := 0; i < writtenBytes; i++ { + n, err := conn.Write([]byte("1")) + assert.Nil(err) + assert.Equal(1, n) + } + + <-marker + conn.Close() + }() + } + }() + <-marker + + limitHit := 0 // counts how many times each connection has hit the `MaxCallbackDispatch` limit + + var b [1]byte // shared between the two reading connections + + rd1, err := Dial(ioc, "tcp", addr) + assert.Nil(err) + defer rd1.Close() + + rd2, err := Dial(ioc, "tcp", addr) + assert.Nil(err) + defer rd2.Close() + + var ( + onRead1, onRead2 AsyncCallback + ) + onRead1 = func(err error, n int) { + if err != nil && err != io.EOF { + t.Fatal(err) + } + if err == io.EOF { + return + } + + readBytes += n + + assert.True(ioc.Dispatched <= MaxCallbackDispatch) + if ioc.Dispatched == MaxCallbackDispatch { + limitHit++ + } + + // sleeping ensures the writer is faster than the reader so we build up stack frames by having each async read + // complete immediately + time.Sleep(time.Millisecond) + + // invoke the other reading connection's async read to ensure `ioc.Dispatched` is correctly updated when shared + // between asynchronous objects + rd2.AsyncRead(b[:], onRead2) + } + + onRead2 = func(err error, n int) { + if err != nil && err != io.EOF { + t.Fatal(err) + } + if err == io.EOF { + return + } + + readBytes += n + + assert.True(ioc.Dispatched <= MaxCallbackDispatch) + if ioc.Dispatched == MaxCallbackDispatch { + limitHit++ + } + + time.Sleep(time.Millisecond) + rd1.AsyncRead(b[:], onRead1) // the other connection + } + + rd1.AsyncRead(b[:], onRead1) // starting point + + for readBytes < writtenBytes*2 /* two connections */ { + ioc.PollOne() + } + + assert.True(limitHit > 0) + assert.Equal(0, ioc.Dispatched) + + marker <- struct{}{} // to close the write end +} diff --git a/definitions.go b/definitions.go index ca851318..289b847a 100644 --- a/definitions.go +++ b/definitions.go @@ -5,12 +5,6 @@ import ( "net" ) -const ( - // MaxCallbackDispatch is the maximum number of callbacks that can exist on a stack-frame when asynchronous - // operations can be completed immediately. - MaxCallbackDispatch int = 32 -) - type AsyncCallback func(error, int) type AcceptCallback func(error, Conn) type AcceptPacketCallback func(error, PacketConn) diff --git a/file.go b/file.go index 87812c02..2b0f97dc 100644 --- a/file.go +++ b/file.go @@ -16,16 +16,6 @@ type file struct { ioc *IO slot internal.Slot closed uint32 - - // Tracks how many callbacks are on the current stack-frame. This prevents stack-overflows in cases where - // asynchronous operations can be completed immediately. - // - // For example, an asynchronous read might be completed immediately. In that case, the callback is invoked which in - // turn might call `AsyncRead` again. That asynchronous read might again be completed immediately and so on. In this - // case, all subsequent read callbacks are placed on the same stack-frame. We count these callbacks with - // `dispatched`. If we hit `MaxCallbackDispatch`, then the stack-frame is popped - asynchronous reads are scheduled - // to be completed on the next poll cycle, even if they can be completed immediately. - dispatched int } func Open(ioc *IO, path string, flags int, mode os.FileMode) (File, error) { @@ -94,11 +84,11 @@ func (f *file) AsyncReadAll(b []byte, cb AsyncCallback) { } func (f *file) asyncRead(b []byte, readAll bool, cb AsyncCallback) { - if f.dispatched < MaxCallbackDispatch { + if f.ioc.Dispatched < MaxCallbackDispatch { f.asyncReadNow(b, 0, readAll, func(err error, n int) { - f.dispatched++ + f.ioc.Dispatched++ cb(err, n) - f.dispatched-- + f.ioc.Dispatched-- }) } else { f.scheduleRead(b, 0, readAll, cb) @@ -166,11 +156,11 @@ func (f *file) AsyncWriteAll(b []byte, cb AsyncCallback) { } func (f *file) asyncWrite(b []byte, writeAll bool, cb AsyncCallback) { - if f.dispatched < MaxCallbackDispatch { + if f.ioc.Dispatched < MaxCallbackDispatch { f.asyncWriteNow(b, 0, writeAll, func(err error, n int) { - f.dispatched++ + f.ioc.Dispatched++ cb(err, n) - f.dispatched-- + f.ioc.Dispatched-- }) } else { f.scheduleWrite(b, 0, writeAll, cb) diff --git a/io.go b/io.go index fccf7a1c..7cd14549 100644 --- a/io.go +++ b/io.go @@ -11,6 +11,12 @@ import ( "github.com/talostrading/sonic/sonicerrors" ) +// MaxCallbackDispatch is the maximum number of callbacks that can exist on a stack-frame when asynchronous operations +// can be completed immediately. +// +// This is the limit to the `IO.dispatched` counter. +const MaxCallbackDispatch int = 32 + // IO is the executor of all asynchronous operations and the way any object can schedule them. It runs fully in the // calling goroutine. // @@ -37,6 +43,18 @@ type IO struct { dynamic map[*internal.Slot]struct{} } pendingTimers map[*Timer]struct{} // XXX: should be embedded into the above pending struct + + // Tracks how many callbacks are on the current stack-frame. This prevents stack-overflows in cases where + // asynchronous operations can be completed immediately. + // + // For example, an asynchronous read might be completed immediately. In that case, the callback is invoked which in + // turn might call `AsyncRead` again. That asynchronous read might again be completed immediately and so on. In this + // case, all subsequent read callbacks are placed on the same stack-frame. We count these callbacks with + // `Dispatched`. If we hit `MaxCallbackDispatch`, then the stack-frame is popped - asynchronous reads are scheduled + // to be completed on the next poll cycle, even if they can be completed immediately. + // + // This counter is shared amongst all asynchronous objects - they are responsible for updating it. + Dispatched int } func NewIO() (*IO, error) { @@ -48,6 +66,7 @@ func NewIO() (*IO, error) { return &IO{ poller: poller, pendingTimers: make(map[*Timer]struct{}), + Dispatched: 0, }, nil } diff --git a/listen_conn.go b/listen_conn.go index 0b1e39db..ba738334 100644 --- a/listen_conn.go +++ b/listen_conn.go @@ -16,8 +16,6 @@ type listener struct { ioc *IO slot internal.Slot addr net.Addr - - dispatched int } // Listen creates a Listener that listens for new connections on the local address. @@ -53,16 +51,16 @@ func (l *listener) Accept() (Conn, error) { } func (l *listener) AsyncAccept(cb AcceptCallback) { - if l.dispatched >= MaxCallbackDispatch { + if l.ioc.Dispatched >= MaxCallbackDispatch { l.asyncAccept(cb) } else { conn, err := l.accept() if err != nil && (err == sonicerrors.ErrWouldBlock) { l.asyncAccept(cb) } else { - l.dispatched++ + l.ioc.Dispatched++ cb(err, conn) - l.dispatched-- + l.ioc.Dispatched-- } } } diff --git a/multicast/peer.go b/multicast/peer.go index 1428f975..f16098b6 100644 --- a/multicast/peer.go +++ b/multicast/peer.go @@ -32,9 +32,8 @@ type UDPPeer struct { slot internal.Slot - sockAddr syscall.Sockaddr - closed bool - dispatched int + sockAddr syscall.Sockaddr + closed bool } // NewUDPPeer creates a new UDPPeer capable of reading/writing multicast packets @@ -567,11 +566,11 @@ func (p *UDPPeer) AsyncRead(b []byte, fn func(error, int, netip.AddrPort)) { p.read.b = b p.read.fn = fn - if p.dispatched < sonic.MaxCallbackDispatch { + if p.ioc.Dispatched < sonic.MaxCallbackDispatch { p.asyncReadNow(b, func(err error, n int, addr netip.AddrPort) { - p.dispatched++ + p.ioc.Dispatched++ fn(err, n, addr) - p.dispatched-- + p.ioc.Dispatched-- }) } else { p.scheduleRead(fn) @@ -622,11 +621,11 @@ func (p *UDPPeer) AsyncWrite( p.write.addr = addr p.write.fn = fn - if p.dispatched < sonic.MaxCallbackDispatch { + if p.ioc.Dispatched < sonic.MaxCallbackDispatch { p.asyncWriteNow(b, addr, func(err error, n int) { - p.dispatched++ + p.ioc.Dispatched++ fn(err, n) - p.dispatched-- + p.ioc.Dispatched-- }) } else { p.scheduleWrite(fn) diff --git a/packet.go b/packet.go index 8608cf9b..4a3cdfd9 100644 --- a/packet.go +++ b/packet.go @@ -20,8 +20,6 @@ type packetConn struct { localAddr net.Addr remoteAddr net.Addr closed uint32 - - dispatched int } // NewPacketConn establishes a packet based stream-less connection which is optionally bound to the specified addr. @@ -81,11 +79,11 @@ func (c *packetConn) AsyncReadAllFrom(b []byte, cb AsyncReadCallbackPacket) { } func (c *packetConn) asyncReadFrom(b []byte, readAll bool, cb AsyncReadCallbackPacket) { - if c.dispatched < MaxCallbackDispatch { + if c.ioc.Dispatched < MaxCallbackDispatch { c.asyncReadNow(b, 0, readAll, func(err error, n int, addr net.Addr) { - c.dispatched++ + c.ioc.Dispatched++ cb(err, n, addr) - c.dispatched-- + c.ioc.Dispatched-- }) } else { c.scheduleRead(b, 0, readAll, cb) @@ -145,11 +143,11 @@ func (c *packetConn) WriteTo(b []byte, to net.Addr) error { } func (c *packetConn) AsyncWriteTo(b []byte, to net.Addr, cb AsyncWriteCallbackPacket) { - if c.dispatched < MaxCallbackDispatch { + if c.ioc.Dispatched < MaxCallbackDispatch { c.asyncWriteToNow(b, to, func(err error) { - c.dispatched++ + c.ioc.Dispatched++ cb(err) - c.dispatched-- + c.ioc.Dispatched-- }) } else { c.scheduleWrite(b, to, cb) From 2783b590f3fa1213e664879b746d8b6552dd3e3d Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Thu, 14 Nov 2024 11:35:51 +0100 Subject: [PATCH 29/35] Use a constructor to make a file In file.go and conn.go --- conn.go | 2 +- file.go | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index c5a32c13..d0395404 100644 --- a/conn.go +++ b/conn.go @@ -48,7 +48,7 @@ func newConn( localAddr, remoteAddr net.Addr, ) *conn { return &conn{ - file: &file{ioc: ioc, slot: internal.Slot{Fd: fd}}, + file: newFile(ioc, fd), localAddr: localAddr, remoteAddr: remoteAddr, } diff --git a/file.go b/file.go index 2b0f97dc..d1069018 100644 --- a/file.go +++ b/file.go @@ -18,17 +18,22 @@ type file struct { closed uint32 } +func newFile(ioc *IO, fd int) *file { + f := &file{ + ioc: ioc, + slot: internal.Slot{Fd: fd}, + } + atomic.StoreUint32(&f.closed, 0) + return f +} + func Open(ioc *IO, path string, flags int, mode os.FileMode) (File, error) { fd, err := syscall.Open(path, flags, uint32(mode)) if err != nil { return nil, err } - f := &file{ - ioc: ioc, - slot: internal.Slot{Fd: fd}, - } - return f, nil + return newFile(ioc, fd), nil } func (f *file) Read(b []byte) (int, error) { From 3b05493dcc2127aa5a30d60ef9d5f4b8c71ba15f Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Thu, 14 Nov 2024 11:52:21 +0100 Subject: [PATCH 30/35] readBytes -> readSoFar in file.go --- file.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/file.go b/file.go index d1069018..77b1af39 100644 --- a/file.go +++ b/file.go @@ -100,54 +100,54 @@ func (f *file) asyncRead(b []byte, readAll bool, cb AsyncCallback) { } } -func (f *file) asyncReadNow(b []byte, readBytes int, readAll bool, cb AsyncCallback) { - n, err := f.Read(b[readBytes:]) - readBytes += n +func (f *file) asyncReadNow(b []byte, readSoFar int, readAll bool, cb AsyncCallback) { + n, err := f.Read(b[readSoFar:]) + readSoFar += n // f is a nonblocking fd so if err == ErrWouldBlock // then we need to schedule an async read. - if err == nil && !(readAll && readBytes != len(b)) { + if err == nil && !(readAll && readSoFar != len(b)) { // If readAll == true then read fully without errors. // If readAll == false then read some without errors. // We are done. - cb(nil, readBytes) + cb(nil, readSoFar) return } - // handles (readAll == false) and (readAll == true && readBytes != len(b)). + // handles (readAll == false) and (readAll == true && readSoFar != len(b)). if err == sonicerrors.ErrWouldBlock { // If readAll == true then read some without errors. // We schedule an asynchronous read. - f.scheduleRead(b, readBytes, readAll, cb) + f.scheduleRead(b, readSoFar, readAll, cb) } else { - cb(err, readBytes) + cb(err, readSoFar) } } -func (f *file) scheduleRead(b []byte, readBytes int, readAll bool, cb AsyncCallback) { +func (f *file) scheduleRead(b []byte, readSoFar int, readAll bool, cb AsyncCallback) { if f.Closed() { cb(io.EOF, 0) return } - handler := f.getReadHandler(b, readBytes, readAll, cb) + handler := f.getReadHandler(b, readSoFar, readAll, cb) f.slot.Set(internal.ReadEvent, handler) if err := f.ioc.SetRead(&f.slot); err != nil { - cb(err, readBytes) + cb(err, readSoFar) } else { f.ioc.Register(&f.slot) } } -func (f *file) getReadHandler(b []byte, readBytes int, readAll bool, cb AsyncCallback) internal.Handler { +func (f *file) getReadHandler(b []byte, readSoFar int, readAll bool, cb AsyncCallback) internal.Handler { return func(err error) { f.ioc.Deregister(&f.slot) if err != nil { - cb(err, readBytes) + cb(err, readSoFar) } else { - f.asyncReadNow(b, readBytes, readAll, cb) + f.asyncReadNow(b, readSoFar, readAll, cb) } } } From 65b37b26503587140961ef46bbc0f715c4c6370c Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Thu, 14 Nov 2024 11:52:48 +0100 Subject: [PATCH 31/35] writtenBytes -> wroteSoFar in file.go --- file.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/file.go b/file.go index 77b1af39..4a0b8e52 100644 --- a/file.go +++ b/file.go @@ -172,49 +172,49 @@ func (f *file) asyncWrite(b []byte, writeAll bool, cb AsyncCallback) { } } -func (f *file) asyncWriteNow(b []byte, writtenBytes int, writeAll bool, cb AsyncCallback) { - n, err := f.Write(b[writtenBytes:]) - writtenBytes += n +func (f *file) asyncWriteNow(b []byte, wroteSoFar int, writeAll bool, cb AsyncCallback) { + n, err := f.Write(b[wroteSoFar:]) + wroteSoFar += n - if err == nil && !(writeAll && writtenBytes != len(b)) { + if err == nil && !(writeAll && wroteSoFar != len(b)) { // If writeAll == true then we wrote fully without errors. // If writeAll == false then we wrote some without errors. - cb(nil, writtenBytes) + cb(nil, wroteSoFar) return } - // Handles (writeAll == false) and (writeAll == true && writtenBytes != len(b)). + // Handles (writeAll == false) and (writeAll == true && wroteSoFar != len(b)). if err == sonicerrors.ErrWouldBlock { - f.scheduleWrite(b, writtenBytes, writeAll, cb) + f.scheduleWrite(b, wroteSoFar, writeAll, cb) } else { - cb(err, writtenBytes) + cb(err, wroteSoFar) } } -func (f *file) scheduleWrite(b []byte, writtenBytes int, writeAll bool, cb AsyncCallback) { +func (f *file) scheduleWrite(b []byte, wroteSoFar int, writeAll bool, cb AsyncCallback) { if f.Closed() { cb(io.EOF, 0) return } - handler := f.getWriteHandler(b, writtenBytes, writeAll, cb) + handler := f.getWriteHandler(b, wroteSoFar, writeAll, cb) f.slot.Set(internal.WriteEvent, handler) if err := f.ioc.SetWrite(&f.slot); err != nil { - cb(err, writtenBytes) + cb(err, wroteSoFar) } else { f.ioc.Register(&f.slot) } } -func (f *file) getWriteHandler(b []byte, writtenBytes int, writeAll bool, cb AsyncCallback) internal.Handler { +func (f *file) getWriteHandler(b []byte, wroteSoFar int, writeAll bool, cb AsyncCallback) internal.Handler { return func(err error) { f.ioc.Deregister(&f.slot) if err != nil { - cb(err, writtenBytes) + cb(err, wroteSoFar) } else { - f.asyncWriteNow(b, writtenBytes, writeAll, cb) + f.asyncWriteNow(b, wroteSoFar, writeAll, cb) } } } From b45b4bd3a4e0acbbab1972cd466416d6f90768cc Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Thu, 14 Nov 2024 12:00:44 +0100 Subject: [PATCH 32/35] [file] Ensure we do not heap-allocate on each async read Instead of creating the asynchronous callback with getReadHandler we now store the async state in fileReadReactor which is allocated once. This is identical to how multicast/peer.go does async reads, so the mechanism is well-tested. --- file.go | 60 +++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/file.go b/file.go index 4a0b8e52..44aab7c2 100644 --- a/file.go +++ b/file.go @@ -13,9 +13,36 @@ import ( var _ File = &file{} type file struct { - ioc *IO - slot internal.Slot - closed uint32 + ioc *IO + slot internal.Slot + closed uint32 + readReactor fileReadReactor +} + +type fileReadReactor struct { + file *file + + b []byte + readAll bool + cb AsyncCallback + readSoFar int +} + +func (r *fileReadReactor) init(b []byte, readAll bool, cb AsyncCallback) { + r.b = b + r.readAll = readAll + r.cb = cb + + r.readSoFar = 0 +} + +func (r *fileReadReactor) onRead(err error) { + r.file.ioc.Deregister(&r.file.slot) + if err != nil { + r.cb(err, r.readSoFar) + } else { + r.file.asyncReadNow(r.b, r.readSoFar, r.readAll, r.cb) + } } func newFile(ioc *IO, fd int) *file { @@ -24,6 +51,10 @@ func newFile(ioc *IO, fd int) *file { slot: internal.Slot{Fd: fd}, } atomic.StoreUint32(&f.closed, 0) + + f.readReactor = fileReadReactor{file: f} + f.readReactor.init(nil, false, nil) + return f } @@ -89,6 +120,8 @@ func (f *file) AsyncReadAll(b []byte, cb AsyncCallback) { } func (f *file) asyncRead(b []byte, readAll bool, cb AsyncCallback) { + f.readReactor.init(b, readAll, cb) + if f.ioc.Dispatched < MaxCallbackDispatch { f.asyncReadNow(b, 0, readAll, func(err error, n int) { f.ioc.Dispatched++ @@ -96,7 +129,7 @@ func (f *file) asyncRead(b []byte, readAll bool, cb AsyncCallback) { f.ioc.Dispatched-- }) } else { - f.scheduleRead(b, 0, readAll, cb) + f.scheduleRead(0 /* this is the starting point, we did not read anything yet */, cb) } } @@ -119,20 +152,20 @@ func (f *file) asyncReadNow(b []byte, readSoFar int, readAll bool, cb AsyncCallb if err == sonicerrors.ErrWouldBlock { // If readAll == true then read some without errors. // We schedule an asynchronous read. - f.scheduleRead(b, readSoFar, readAll, cb) + f.scheduleRead(readSoFar, cb) } else { cb(err, readSoFar) } } -func (f *file) scheduleRead(b []byte, readSoFar int, readAll bool, cb AsyncCallback) { +func (f *file) scheduleRead(readSoFar int, cb AsyncCallback) { if f.Closed() { cb(io.EOF, 0) return } - handler := f.getReadHandler(b, readSoFar, readAll, cb) - f.slot.Set(internal.ReadEvent, handler) + f.readReactor.readSoFar = readSoFar + f.slot.Set(internal.ReadEvent, f.readReactor.onRead) if err := f.ioc.SetRead(&f.slot); err != nil { cb(err, readSoFar) @@ -141,17 +174,6 @@ func (f *file) scheduleRead(b []byte, readSoFar int, readAll bool, cb AsyncCallb } } -func (f *file) getReadHandler(b []byte, readSoFar int, readAll bool, cb AsyncCallback) internal.Handler { - return func(err error) { - f.ioc.Deregister(&f.slot) - if err != nil { - cb(err, readSoFar) - } else { - f.asyncReadNow(b, readSoFar, readAll, cb) - } - } -} - func (f *file) AsyncWrite(b []byte, cb AsyncCallback) { f.asyncWrite(b, false, cb) } From 391e1adb4f05edeaaad0f40801d6b784d7bb2ba4 Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Thu, 14 Nov 2024 12:11:15 +0100 Subject: [PATCH 33/35] [file] Ensure we do not heap-allocate on each async write Same reasoning as the previous commit, but for writes. --- file.go | 62 ++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/file.go b/file.go index 44aab7c2..889712f8 100644 --- a/file.go +++ b/file.go @@ -13,10 +13,11 @@ import ( var _ File = &file{} type file struct { - ioc *IO - slot internal.Slot - closed uint32 - readReactor fileReadReactor + ioc *IO + slot internal.Slot + closed uint32 + readReactor fileReadReactor + writeReactor fileWriteReactor } type fileReadReactor struct { @@ -45,6 +46,32 @@ func (r *fileReadReactor) onRead(err error) { } } +type fileWriteReactor struct { + file *file + + b []byte + writeAll bool + cb AsyncCallback + wroteSoFar int +} + +func (r *fileWriteReactor) init(b []byte, writeAll bool, cb AsyncCallback) { + r.b = b + r.writeAll = writeAll + r.cb = cb + + r.wroteSoFar = 0 +} + +func (r *fileWriteReactor) onWrite(err error) { + r.file.ioc.Deregister(&r.file.slot) + if err != nil { + r.cb(err, r.wroteSoFar) + } else { + r.file.asyncWriteNow(r.b, r.wroteSoFar, r.writeAll, r.cb) + } +} + func newFile(ioc *IO, fd int) *file { f := &file{ ioc: ioc, @@ -55,6 +82,9 @@ func newFile(ioc *IO, fd int) *file { f.readReactor = fileReadReactor{file: f} f.readReactor.init(nil, false, nil) + f.writeReactor = fileWriteReactor{file: f} + f.writeReactor.init(nil, false, nil) + return f } @@ -183,6 +213,8 @@ func (f *file) AsyncWriteAll(b []byte, cb AsyncCallback) { } func (f *file) asyncWrite(b []byte, writeAll bool, cb AsyncCallback) { + f.writeReactor.init(b, writeAll, cb) + if f.ioc.Dispatched < MaxCallbackDispatch { f.asyncWriteNow(b, 0, writeAll, func(err error, n int) { f.ioc.Dispatched++ @@ -190,7 +222,7 @@ func (f *file) asyncWrite(b []byte, writeAll bool, cb AsyncCallback) { f.ioc.Dispatched-- }) } else { - f.scheduleWrite(b, 0, writeAll, cb) + f.scheduleWrite(0 /* this is the starting point, we did not write anything yet */, cb) } } @@ -207,20 +239,20 @@ func (f *file) asyncWriteNow(b []byte, wroteSoFar int, writeAll bool, cb AsyncCa // Handles (writeAll == false) and (writeAll == true && wroteSoFar != len(b)). if err == sonicerrors.ErrWouldBlock { - f.scheduleWrite(b, wroteSoFar, writeAll, cb) + f.scheduleWrite(wroteSoFar, cb) } else { cb(err, wroteSoFar) } } -func (f *file) scheduleWrite(b []byte, wroteSoFar int, writeAll bool, cb AsyncCallback) { +func (f *file) scheduleWrite(wroteSoFar int, cb AsyncCallback) { if f.Closed() { cb(io.EOF, 0) return } - handler := f.getWriteHandler(b, wroteSoFar, writeAll, cb) - f.slot.Set(internal.WriteEvent, handler) + f.writeReactor.wroteSoFar = wroteSoFar + f.slot.Set(internal.WriteEvent, f.writeReactor.onWrite) if err := f.ioc.SetWrite(&f.slot); err != nil { cb(err, wroteSoFar) @@ -229,18 +261,6 @@ func (f *file) scheduleWrite(b []byte, wroteSoFar int, writeAll bool, cb AsyncCa } } -func (f *file) getWriteHandler(b []byte, wroteSoFar int, writeAll bool, cb AsyncCallback) internal.Handler { - return func(err error) { - f.ioc.Deregister(&f.slot) - - if err != nil { - cb(err, wroteSoFar) - } else { - f.asyncWriteNow(b, wroteSoFar, writeAll, cb) - } - } -} - func (f *file) Close() error { if !atomic.CompareAndSwapUint32(&f.closed, 0, 1) { return io.EOF From 0f1452e901fbb50f1710cfba78a6903122170a0d Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Thu, 14 Nov 2024 20:21:23 +0100 Subject: [PATCH 34/35] Remove duplicate response callback in ws/stream.go I probably messed up a diff somewhere. The public repo is now in sync with the private mirror. --- codec/websocket/definitions.go | 1 - codec/websocket/stream.go | 10 ---------- 2 files changed, 11 deletions(-) diff --git a/codec/websocket/definitions.go b/codec/websocket/definitions.go index ab3fa7c8..59ef9b5c 100644 --- a/codec/websocket/definitions.go +++ b/codec/websocket/definitions.go @@ -118,4 +118,3 @@ func ExtraHeader(canonicalKey bool, key string, values ...string) Header { CanonicalKey: canonicalKey, } } - diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index c53752c4..dbcb02fe 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -81,12 +81,6 @@ type Stream struct { // Optional callback invoked when an upgrade response is received. upgradeResponseCallback UpgradeResponseCallback - // Optional callback invoked when an upgrade request is sent. - upReqCb UpgradeRequestCallback - - // Optional callback invoked when an upgrade response is received. - upResCb UpgradeResponseCallback - // Used to establish a TCP connection to the peer with a timeout. dialer *net.Dialer @@ -961,10 +955,6 @@ func (s *Stream) upgrade(uri *url.URL, stream sonic.Stream, headers []Header) er s.upgradeResponseCallback(res) } - if s.upResCb != nil { - s.upResCb(res) - } - if !IsUpgradeRes(res) { return ErrCannotUpgrade } From 582266aa2f3cd3835da039d418f8111841d787eb Mon Sep 17 00:00:00 2001 From: sergiu128 Date: Tue, 26 Nov 2024 19:36:56 +0100 Subject: [PATCH 35/35] Do not ignore SO_ERROR value on connect Otherwise we will mark the connection as successful and a subsequent read/write will return EOF. In reality, the connect syscall failed. Checking the return value of Getsockoptint(... SO_ERROR) will tell us whether the connect syscall succeeded. A return value of 0 means success - anything else can be interpreted by syscall.Errno and means failure. I was mislead into thinking the returned `err` from `Getsockoptint` is the actual socket error - in reality, it's whether the `getsockopt` syscall succeeded or not. - Golang src: https://github.com/golang/go/blob/04879acdebbb08bdca00356f043d769c4b4375ce/src/syscall/syscall_unix.go#L312 - Sanity check that uninit values are set to 0: https://go.dev/play/p/vCLjDd6WwL- - syscall.Errno is safe for any input value: https://cs.opensource.google/go/go/+/refs/tags/go1.23.3:src/syscall/syscall_unix.go;l=110 --- internal/socket_unix.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/socket_unix.go b/internal/socket_unix.go index 62291622..a817a2e9 100644 --- a/internal/socket_unix.go +++ b/internal/socket_unix.go @@ -151,10 +151,13 @@ func connect(fd int, remoteAddr net.Addr, timeout time.Duration, opts ...sonicop return sonicerrors.ErrTimeout } - _, err = syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR) + socketErr, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR) if err != nil { return os.NewSyscallError("getsockopt", err) } + if socketErr != 0 { + return syscall.Errno(socketErr) + } } return nil