diff --git a/.github/workflows/gosec.yml b/.github/workflows/gosec.yml index 866c315f..d9aaf365 100644 --- a/.github/workflows/gosec.yml +++ b/.github/workflows/gosec.yml @@ -15,10 +15,10 @@ jobs: - name: Checkout Source uses: actions/checkout@v3 - name: Run Gosec Security Scanner - uses: securego/gosec@v2.15.0 + uses: securego/gosec@v2.21.0 with: args: '-no-fail -fmt=sarif -out=results.sarif -exclude-dir=examples -exclude-dir=stress_test -exclude-dir=other -exclude-dir=docs -exclude-dir=tests -exclude-dir=benchmark ./...' - name: Upload SARIF file - uses: github/codeql-action/upload-sarif@v2 + uses: github/codeql-action/upload-sarif@v3 with: sarif_file: results.sarif diff --git a/async_adapter.go b/async_adapter.go index ca87ec41..e3439860 100644 --- a/async_adapter.go +++ b/async_adapter.go @@ -202,7 +202,8 @@ func (a *AsyncAdapter) Close() error { return io.EOF } - _ = a.ioc.poller.Del(&a.slot) + _ = a.ioc.UnsetReadWrite(&a.slot) + a.ioc.Deregister(&a.slot) return syscall.Close(a.slot.Fd) } @@ -245,3 +246,7 @@ func (a *AsyncAdapter) cancelWrites() { func (a *AsyncAdapter) RawFd() int { return a.slot.Fd } + +func (a *AsyncAdapter) Slot() *internal.Slot { + return &a.slot +} diff --git a/codec.go b/codec.go index 70f6bc19..7f5c4420 100644 --- a/codec.go +++ b/codec.go @@ -2,64 +2,33 @@ package sonic import ( "errors" - "fmt" - - "github.com/talostrading/sonic/internal" "github.com/talostrading/sonic/sonicerrors" ) type Encoder[Item any] interface { - // Encode encodes the given item into the `dst` byte stream. + // Encode the given `Item` into the given buffer. // // Implementations should: - // - Commit the bytes into the read area of `dst`. - // - ensure `dst` is big enough to hold the serialized item by - // calling dst.Reserve(...) + // - Commit() the bytes into the given buffer if the encoding is successful. + // - Ensure the given buffer is big enough to hold the serialized `Item`s by calling `Reserve(...)`. Encode(item Item, dst *ByteBuffer) error } type Decoder[Item any] interface { - // Decode decodes the given stream into an `Item`. - // - // An implementation of Codec takes a byte stream that has already - // been buffered in `src` and decodes the data into a stream of - // `Item` objects. - // - // Implementations should return an empty Item and ErrNeedMore if - // there are not enough bytes to decode into an Item. + // Decode the next `Item`, if any, from the provided buffer. If there are not enough bytes to decode an `Item`, + // implementations should return an empty `Item` along with `ErrNeedMore`. `CodecConn` will then know to read more + // bytes before calling `Decode(...)` again. Decode(src *ByteBuffer) (Item, error) } -// Codec defines a generic interface through which one can encode/decode -// a raw stream of bytes. -// -// Implementations are optionally able to track their state which enables -// writing both stateful and stateless parsers. +// Codec groups together and Encoder and a Decoder for a CodecConn. type Codec[Enc, Dec any] interface { Encoder[Enc] Decoder[Dec] } -type CodecConn[Enc, Dec any] interface { - AsyncReadNext(func(error, Dec)) - ReadNext() (Dec, error) - - AsyncWriteNext(Enc, AsyncCallback) - WriteNext(Enc) (int, error) - - NextLayer() Stream - - Close() error -} - -var ( - _ CodecConn[any, any] = &BlockingCodecConn[any, any]{} - _ CodecConn[any, any] = &NonblockingCodecConn[any, any]{} -) - -// BlockingCodecConn handles the decoding/encoding of bytes funneled through a -// provided blocking file descriptor. -type BlockingCodecConn[Enc, Dec any] struct { +// CodecConn reads/writes `Item`s through the provided `Codec`. For an example, see `codec/frame.go`. +type CodecConn[Enc, Dec any] struct { stream Stream codec Codec[Enc, Dec] src *ByteBuffer @@ -69,14 +38,12 @@ type BlockingCodecConn[Enc, Dec any] struct { emptyDec Dec } -func NewBlockingCodecConn[Enc, Dec any]( +func NewCodecConn[Enc, Dec any]( stream Stream, codec Codec[Enc, Dec], src, dst *ByteBuffer, -) (*BlockingCodecConn[Enc, Dec], error) { - // Works on both blocking and nonblocking fds. - - c := &BlockingCodecConn[Enc, Dec]{ +) (*CodecConn[Enc, Dec], error) { + c := &CodecConn[Enc, Dec]{ stream: stream, codec: codec, src: src, @@ -85,26 +52,22 @@ func NewBlockingCodecConn[Enc, Dec any]( return c, nil } -func (c *BlockingCodecConn[Enc, Dec]) AsyncReadNext(cb func(error, Dec)) { +func (c *CodecConn[Enc, Dec]) AsyncReadNext(cb func(error, Dec)) { item, err := c.codec.Decode(c.src) if errors.Is(err, sonicerrors.ErrNeedMore) { - c.scheduleAsyncRead(cb) + c.src.AsyncReadFrom(c.stream, func(err error, _ int) { + if err != nil { + cb(err, c.emptyDec) + } else { + c.AsyncReadNext(cb) + } + }) } else { cb(err, item) } } -func (c *BlockingCodecConn[Enc, Dec]) scheduleAsyncRead(cb func(error, Dec)) { - c.src.AsyncReadFrom(c.stream, func(err error, _ int) { - if err != nil { - cb(err, c.emptyDec) - } else { - c.AsyncReadNext(cb) - } - }) -} - -func (c *BlockingCodecConn[Enc, Dec]) ReadNext() (Dec, error) { +func (c *CodecConn[Enc, Dec]) ReadNext() (Dec, error) { for { item, err := c.codec.Decode(c.src) if err == nil { @@ -122,7 +85,7 @@ func (c *BlockingCodecConn[Enc, Dec]) ReadNext() (Dec, error) { } } -func (c *BlockingCodecConn[Enc, Dec]) WriteNext(item Enc) (n int, err error) { +func (c *CodecConn[Enc, Dec]) WriteNext(item Enc) (n int, err error) { err = c.codec.Encode(item, c.dst) if err == nil { var nn int64 @@ -132,7 +95,7 @@ func (c *BlockingCodecConn[Enc, Dec]) WriteNext(item Enc) (n int, err error) { return } -func (c *BlockingCodecConn[Enc, Dec]) AsyncWriteNext(item Enc, cb AsyncCallback) { +func (c *CodecConn[Enc, Dec]) AsyncWriteNext(item Enc, cb AsyncCallback) { err := c.codec.Encode(item, c.dst) if err == nil { c.dst.AsyncWriteTo(c.stream, cb) @@ -141,105 +104,10 @@ func (c *BlockingCodecConn[Enc, Dec]) AsyncWriteNext(item Enc, cb AsyncCallback) } } -func (c *BlockingCodecConn[Enc, Dec]) NextLayer() Stream { - return c.stream -} - -func (c *BlockingCodecConn[Enc, Dec]) Close() error { - return c.stream.Close() -} - -type NonblockingCodecConn[Enc, Dec any] struct { - stream Stream - codec Codec[Enc, Dec] - src *ByteBuffer - dst *ByteBuffer - - dispatched int - - emptyEnc Enc - emptyDec Dec -} - -func NewNonblockingCodecConn[Enc, Dec any]( - stream Stream, - codec Codec[Enc, Dec], - src, dst *ByteBuffer, -) (*NonblockingCodecConn[Enc, Dec], error) { - nonblocking, err := internal.IsNonblocking(stream.RawFd()) - if err != nil { - return nil, err - } - if !nonblocking { - return nil, fmt.Errorf("the provided Stream is blocking") - } - - c := &NonblockingCodecConn[Enc, Dec]{ - stream: stream, - codec: codec, - src: src, - dst: dst, - } - return c, nil -} - -func (c *NonblockingCodecConn[Enc, Dec]) AsyncReadNext(cb func(error, Dec)) { - item, err := c.codec.Decode(c.src) - if errors.Is(err, sonicerrors.ErrNeedMore) { - c.src.AsyncReadFrom(c.stream, func(err error, _ int) { - if err != nil { - cb(err, c.emptyDec) - } else { - c.AsyncReadNext(cb) - } - }) - } else { - cb(err, item) - } -} - -func (c *NonblockingCodecConn[Enc, Dec]) ReadNext() (Dec, error) { - for { - item, err := c.codec.Decode(c.src) - if err == nil { - return item, nil - } - - if err != sonicerrors.ErrNeedMore { - return c.emptyDec, err - } - - _, err = c.src.ReadFrom(c.stream) - if err != nil { - return c.emptyDec, err - } - } -} - -func (c *NonblockingCodecConn[Enc, Dec]) AsyncWriteNext(item Enc, cb AsyncCallback) { - if err := c.codec.Encode(item, c.dst); err != nil { - cb(err, 0) - return - } - - // write everything into `dst` - c.dst.AsyncWriteTo(c.stream, cb) -} - -func (c *NonblockingCodecConn[Enc, Dec]) WriteNext(item Enc) (n int, err error) { - err = c.codec.Encode(item, c.dst) - if err == nil { - var nn int64 - nn, err = c.dst.WriteTo(c.stream) - n = int(nn) - } - return -} - -func (c *NonblockingCodecConn[Enc, Dec]) NextLayer() Stream { +func (c *CodecConn[Enc, Dec]) NextLayer() Stream { return c.stream } -func (c *NonblockingCodecConn[Enc, Dec]) Close() error { +func (c *CodecConn[Enc, Dec]) Close() error { return c.stream.Close() } diff --git a/codec/frame/stream_test.go b/codec/frame/stream_test.go index 0b216d33..b85fbf57 100644 --- a/codec/frame/stream_test.go +++ b/codec/frame/stream_test.go @@ -93,7 +93,7 @@ func runClient(port int, t *testing.T) { src := sonic.NewByteBuffer() dst := sonic.NewByteBuffer() codec := NewCodec(src) - codecConn, err := sonic.NewNonblockingCodecConn[[]byte, []byte]( + codecConn, err := sonic.NewCodecConn[[]byte, []byte]( conn, codec, src, dst, ) if err != nil { diff --git a/codec/websocket/definitions.go b/codec/websocket/definitions.go index 209197a5..59ef9b5c 100644 --- a/codec/websocket/definitions.go +++ b/codec/websocket/definitions.go @@ -1,16 +1,14 @@ package websocket import ( - "net" "net/http" "time" - - "github.com/talostrading/sonic" ) -var ( - MaxMessageSize = 1024 * 512 // the maximum size of a message - CloseTimeout = 5 * time.Second +const ( + DefaultMaxMessageSize = 1024 * 512 + CloseTimeout = 5 * time.Second + DialTimeout = 5 * time.Second ) type Role uint8 @@ -31,7 +29,7 @@ func (r Role) String() string { } } -type MessageType uint8 +type MessageType byte const ( TypeText = MessageType(OpcodeText) @@ -69,20 +67,16 @@ const ( // Intermediate state. Connection is active, can read/write/close. StateActive - // Intermediate state. We initiated the closing handshake and - // are waiting for a reply from the peer. + // Intermediate state. We initiated the closing handshake and are waiting for a reply from the peer. StateClosedByUs - // Terminal state. The peer initiated the closing handshake, - // we received a close frame and immediately replied. + // Terminal state. The peer initiated the closing handshake, we received a close frame and immediately replied. StateClosedByPeer - // Terminal state. The peer replied to our closing handshake. - // Can only end up here from StateClosedByUs. + // Terminal state. The peer replied to our closing handshake. Can only end up here from StateClosedByUs. StateCloseAcked - // Terminal state. The connection is closed or some error - // occurred which rendered the stream unusable. + // Terminal state. The connection is closed or some error occurred which rendered the stream unusable. StateTerminated ) @@ -105,9 +99,9 @@ func (s StreamState) String() string { } } -type AsyncMessageHandler = func(err error, n int, mt MessageType) -type AsyncFrameHandler = func(err error, f *Frame) -type ControlCallback = func(mt MessageType, payload []byte) +type AsyncMessageCallback = func(err error, n int, messageType MessageType) +type AsyncFrameCallback = func(err error, f Frame) +type ControlCallback = func(messageType MessageType, payload []byte) type UpgradeRequestCallback = func(req *http.Request) type UpgradeResponseCallback = func(res *http.Response) @@ -124,273 +118,3 @@ func ExtraHeader(canonicalKey bool, key string, values ...string) Header { CanonicalKey: canonicalKey, } } - -// Stream is an interface for representing a stateful WebSocket connection -// on the server or client side. -// -// The interface uses the layered stream model. A WebSocket stream object -// contains another stream object, called the "next layer", which it uses to -// perform IO. -// -// Implementations handle the replies to control frames. Before closing the -// stream, it is important to call Flush or AsyncFlush in order to write any -// pending control frame replies to the underlying stream. -type Stream interface { - // NextLayer returns the underlying stream object. - // - // The returned object is constructed by the Stream and maintained - // throughout its entire lifetime. All reads and writes will go through the - // next layer. - NextLayer() sonic.Stream - - // SupportsDeflate returns true if Deflate compression is supported. - // - // https://datatracker.ietf.org/doc/html/rfc7692 - SupportsDeflate() bool - - // SupportsUTF8 returns true if UTF8 validity checks are supported. - // - // Implementations should not do UTF8 checking by default. Callers - // should be able to turn it on when instantiating the Stream. - SupportsUTF8() bool - - // NextMessage reads the payload of the next message into the supplied - // buffer. Message fragmentation is automatically handled by the - // implementation. - // - // This call first flushes any pending control frames to the underlying - // stream. - // - // This call blocks until one of the following conditions is true: - // - an error occurs while flushing the pending operations - // - an error occurs when reading/decoding the message bytes from the - // underlying stream - // - the payload of the message is successfully read into the supplied - // buffer - NextMessage([]byte) (mt MessageType, n int, err error) - - // NextFrame reads and returns the next frame. - // - // This call first flushes any pending control frames to the underlying - // stream. - // - // This call blocks until one of the following conditions is true: - // - an error occurs while flushing the pending operations - // - an error occurs when reading/decoding the message bytes from the - // underlying stream - // - a frame is successfully read from the underlying stream - NextFrame() (*Frame, error) - - // AsyncNextMessage reads the payload of the next message into the supplied - // buffer asynchronously. Message fragmentation is automatically handled by - // the implementation. - // - // This call first flushes any pending control frames to the underlying - // stream asynchronously. - // - // This call does not block. The provided completion handler is invoked when - // one of the following happens: - // - an error occurs while flushing the pending operations - // - an error occurs when reading/decoding the message bytes from the - // underlying stream - // - the payload of the message is successfully read into the supplied - // buffer - AsyncNextMessage([]byte, AsyncMessageHandler) - - // AsyncNextFrame reads and returns the next frame asynchronously. - // - // This call first flushes any pending control frames to the underlying - // stream asynchronously. - // - // This call does not block. The provided completion handler is invoked when - // one of the following happens: - // - an error occurs while flushing the pending operations - // - an error occurs when reading/decoding the message bytes from the - // underlying stream - // - a frame is successfully read from the underlying stream - AsyncNextFrame(AsyncFrameHandler) - - // WriteFrame writes the supplied frame to the underlying stream. - // - // This call first flushes any pending control frames to the underlying - // stream. - // - // This call blocks until one of the following conditions is true: - // - an error occurs while flushing the pending operations - // - an error occurs during the write - // - the frame is successfully written to the underlying stream - WriteFrame(fr *Frame) error - - // AsyncWriteFrame writes the supplied frame to the underlying stream - // asynchronously. - // - // This call first flushes any pending control frames to the underlying - // stream asynchronously. - // - // This call does not block. The provided completion handler is invoked when - // one of the following happens: - // - an error occurs while flushing the pending operations - // - an error occurs during the write - // - the frame is successfully written to the underlying stream - AsyncWriteFrame(fr *Frame, cb func(err error)) - - // Write writes the supplied buffer as a single message with the given type - // to the underlying stream. - // - // This call first flushes any pending control frames to the underlying - // stream. - // - // This call blocks until one of the following conditions is true: - // - an error occurs while flushing the pending operations - // - an error occurs during the write - // - the message is successfully written to the underlying stream - // - // The message will be written as a single frame. Fragmentation should be - // handled by the caller through multiple calls to AsyncWriteFrame. - Write(b []byte, mt MessageType) error - - // AsyncWrite writes the supplied buffer as a single message with the given - // type to the underlying stream asynchronously. - // - // This call first flushes any pending control frames to the underlying - // stream asynchronously. - // - // This call does not block. The provided completion handler is invoked when - // one of the following happens: - // - an error occurs while flushing the pending operations - // - an error occurs during the write - // - the message is successfully written to the underlying stream - // - // The message will be written as a single frame. Fragmentation should be - // handled by the caller through multiple calls to AsyncWriteFrame. - AsyncWrite(b []byte, mt MessageType, cb func(err error)) - - // Flush writes any pending control frames to the underlying stream. - // - // This call blocks. - Flush() error - - // Flush writes any pending control frames to the underlying stream - // asynchronously. - // - // This call does not block. - AsyncFlush(cb func(err error)) - - // Pending returns the number of currently pending operations. - Pending() int - - // State returns the state of the WebSocket connection. - State() StreamState - - // Handshake performs the WebSocket handshake in the client role. - // - // The call blocks until one of the following conditions is true: - // - the request is sent and the response is received - // - an error occurs - Handshake(addr string, extraHeaders ...Header) error - - // AsyncHandshake performs the WebSocket handshake asynchronously in the - // client role. - // - // This call does not block. The provided completion handler is called when - // the request is sent and the response is received or when an error occurs. - // - // Regardless of whether the asynchronous operation completes immediately - // or not, the handler will not be invoked from within this function. - // Invocation of the handler will be performed in a manner equivalent to - // using sonic.Post(...). - AsyncHandshake(addr string, cb func(error), extraHeaders ...Header) - - // Accept performs the handshake in the server role. - // - // The call blocks until one of the following conditions is true: - // - the request is sent and the response is received - // - an error occurs - Accept() error - - // AsyncAccept performs the handshake asynchronously in the server role. - // - // This call does not block. The provided completion handler is called when - // the request is send and the response is received or when an error occurs. - // - // Regardless of whether the asynchronous operation completes immediately - // or not, the handler will not be invoked from within this function. - // Invocation of the handler will be performed in a manner equivalent to - // using sonic.Post(...). - AsyncAccept(func(error)) - - // AsyncClose sends a websocket close control frame asynchronously. - // - // This function is used to send a close frame which begins the WebSocket - // closing handshake. The session ends when both ends of the connection - // have sent and received a close frame. - // - // The handler is called if one of the following conditions is true: - // - the close frame is written - // - an error occurs - // - // After beginning the closing handshake, the program should not write - // further message data, pings, or pongs. Instead, the program should - // continue reading message data until an error occurs. - AsyncClose(cc CloseCode, reason string, cb func(err error)) - - // Close sends a websocket close control frame asynchronously. - // - // This function is used to send a close frame which begins the WebSocket - // closing handshake. The session ends when both ends of the connection - // have sent and received a close frame. - // - // The call blocks until one of the following conditions is true: - // - the close frame is written - // - an error occurs - // - // After beginning the closing handshake, the program should not write - // further message data, pings, or pongs. Instead, the program should - // continue reading message data until an error occurs. - Close(cc CloseCode, reason string) error - - // SetControlCallback sets a function that will be invoked when a - // Ping/Pong/Close is received while reading a message. This callback is - // not invoked when AsyncNextFrame or NextFrame are called. - // - // The caller must not perform any operations on the stream in the provided - // callback. - SetControlCallback(ControlCallback) - - // ControlCallback returns the control callback set with SetControlCallback. - ControlCallback() ControlCallback - - // SetUpgradeRequestCallback sets a function that will be invoked during the handshake - // just before the upgrade request is sent. - // - // The caller must not perform any operations on the stream in the provided callback. - SetUpgradeRequestCallback(callback UpgradeRequestCallback) - - // UpgradeRequestCallback returns the callback set with SetUpgradeRequestCallback. - UpgradeRequestCallback() UpgradeRequestCallback - - // SetUpgradeResponseCallback sets a function that will be invoked during the handshake - // just after the upgrade response is received. - // - // The caller must not perform any operations on the stream in the provided callback. - SetUpgradeResponseCallback(callback UpgradeResponseCallback) - - // UpgradeResponseCallback returns the callback set with SetUpgradeResponseCallback. - UpgradeResponseCallback() UpgradeResponseCallback - - // SetMaxMessageSize sets the maximum size of a message that can be read - // from or written to a peer. - // - If a message exceeds the limit while reading, the connection is - // closed abnormally. - // - If a message exceeds the limit while writing, the operation is - // cancelled. - SetMaxMessageSize(bytes int) - - RemoteAddr() net.Addr - - LocalAddr() net.Addr - - RawFd() int - - CloseNextLayer() error -} diff --git a/codec/websocket/errors.go b/codec/websocket/errors.go index 8733c1c8..bd20f962 100644 --- a/codec/websocket/errors.go +++ b/codec/websocket/errors.go @@ -38,4 +38,6 @@ var ( ErrExpectedContinuation = errors.New("expected continue frame") ErrInvalidAddress = errors.New("invalid address") + + ErrInvalidUTF8 = errors.New("Invalid UTF-8 encoding") ) diff --git a/codec/websocket/frame.go b/codec/websocket/frame.go index 38bdc98c..63ee9584 100644 --- a/codec/websocket/frame.go +++ b/codec/websocket/frame.go @@ -2,298 +2,316 @@ package websocket import ( "encoding/binary" - "fmt" "io" - "sync" "github.com/talostrading/sonic/util" ) -var zeroBytes = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} +var zeroBytes [frameMaxHeaderLength]byte -type Frame struct { - header []byte - mask []byte - payload []byte -} - -func NewFrame() *Frame { - f := &Frame{ - header: make([]byte, 10), - mask: make([]byte, 4), - payload: make([]byte, 0, 1024), +func init() { + for i := 0; i < len(zeroBytes); i++ { + zeroBytes[i] = 0 } - return f } -func (f *Frame) Reset() { - copy(f.header, zeroBytes) - copy(f.mask, zeroBytes) - f.payload = f.payload[:0] -} +type Frame []byte -func (f *Frame) ExtraHeaderLen() (n int) { - switch f.header[1] & 127 { - case 127: - n = 8 - case 126: - n = 2 - } - return -} +var ( + _ io.ReaderFrom = &Frame{} + _ io.WriterTo = &Frame{} +) -// PayloadLenType returns the payload length as indicated in the fixed -// size header. It is always less than 127 as per the WebSocket protocol. -// The actual payload size can be retrieved by calling PayloadLen. -func (f *Frame) PayloadLenType() int { - return int(f.header[1] & 127) +// NOTE use stream.AcquireFrame() instead of NewFrame if you intend to write this frame onto a WebSocket stream. +func NewFrame() Frame { + b := make([]byte, frameMaxHeaderLength) + copy(b, zeroBytes[:]) + return b } -// PayloadLen returns the actual payload length which can either be -// in the header if the length type is 125 or less, in the next 2 bytes if the -// length type is 126 or in the next 8 bytes if the length type is 127. -func (f *Frame) PayloadLen() int { - length := uint64(f.header[1] & 127) - switch length { - case 126: - length = uint64(binary.BigEndian.Uint16(f.header[2:])) - case 127: - length = binary.BigEndian.Uint64(f.header[2:]) - } - return int(length) +func (f Frame) Reset() { + copy(f, zeroBytes[:]) } -func (f *Frame) SetPayloadLen() (bytes int) { - n := len(f.payload) - - switch { - case n > 65535: // more than two bytes needed for extra length - bytes = 8 //nolint:ineffassign - f.header[1] |= uint8(127) - binary.BigEndian.PutUint64(f.header[2:], uint64(n)) - case n > 125: - bytes = 2 //nolint:ineffassign - f.header[1] |= uint8(126) - binary.BigEndian.PutUint16(f.header[2:], uint16(n)) +func (f Frame) ExtendedPayloadLengthBytes() int { + if v := f[1] & bitmaskPayloadLength; v == 127 { + return 8 + } else if v == 126 { return 2 - default: - bytes = 0 //nolint:ineffassign - f.header[1] |= uint8(n) } - return + return 0 } -func (f *Frame) ReadFrom(r io.Reader) (nt int64, err error) { - var n int - n, err = io.ReadFull(r, f.header[:2]) - nt += int64(n) - - if err == nil { - m := f.ExtraHeaderLen() - if m > 0 { - n, err = io.ReadFull(r, f.header[2:m+2]) - nt += int64(n) - } - - if err == nil && f.IsMasked() { - n, err = io.ReadFull(r, f.mask[:4]) - nt += int64(n) - } - - if err == nil { - if pn := f.PayloadLen(); pn > 0 { - if pn > MaxMessageSize { - err = ErrPayloadTooBig - } else { - f.payload = util.ExtendSlice(f.payload, pn) - n, err = io.ReadFull(r, f.payload[:pn]) - nt += int64(n) - } - } else if pn == 0 { - f.payload = f.payload[:0] - } - } +func (f Frame) PayloadLength() int { + if length := f[1] & bitmaskPayloadLength; length == 127 { + return int(binary.BigEndian.Uint64(f[frameHeaderLength : frameHeaderLength+8])) + } else if length == 126 { + return int(binary.BigEndian.Uint16(f[frameHeaderLength : frameHeaderLength+2])) + } else { + return int(length) } - - return } -func (f *Frame) WriteTo(w io.Writer) (n int64, err error) { - var nn int - - nn, err = w.Write(f.header[:2+f.SetPayloadLen()]) - n += int64(nn) - - if err == nil { - if f.IsMasked() { - nn, err = w.Write(f.mask[:]) - n += int64(nn) - } +// An unfragmented message consists of a single frame with the FIN bit set and an opcode other than 0. +// +// A fragmented message consists of a single frame with the FIN bit clear and an opcode other than 0, followed by zero +// or more frames with the FIN bit clear and the opcode set to 0, and terminated by a single frame with the FIN bit set +// and an opcode of 0. +func (f Frame) IsFIN() bool { + return f[0]&bitFIN != 0 +} - if err == nil && f.PayloadLen() > 0 { - nn, err = w.Write(f.payload[:f.PayloadLen()]) - n += int64(nn) - } - } +func (f Frame) IsRSV1() bool { + return f[0]&bitRSV1 != 0 +} - return +func (f Frame) IsRSV2() bool { + return f[0]&bitRSV2 != 0 } -func (f *Frame) IsFin() bool { - return f.header[0]&finBit != 0 +func (f Frame) IsRSV3() bool { + return f[0]&bitRSV3 != 0 } -func (f *Frame) IsRSV1() bool { - return f.header[0]&rsv1Bit != 0 +func (f Frame) Opcode() Opcode { + return Opcode(f[0] & bitmaskOpcode) } -func (f *Frame) IsRSV2() bool { - return f.header[0]&rsv2Bit != 0 +func (f Frame) IsMasked() bool { + return f[1]&bitIsMasked != 0 } -func (f *Frame) IsRSV3() bool { - return f.header[0]&rsv3Bit != 0 +func (f *Frame) SetIsMasked() *Frame { + (*f)[1] |= bitIsMasked + return f } -func (f *Frame) Opcode() Opcode { - return Opcode(f.header[0] & 15) +func (f *Frame) UnsetIsMasked() *Frame { + (*f)[1] ^= bitIsMasked + return f } -func (f *Frame) IsContinuation() bool { - return f.Opcode() == OpcodeContinuation +func (f Frame) MaskBytes() int { + if f.IsMasked() { + return frameMaskLength + } + return 0 } -func (f *Frame) IsText() bool { - return f.Opcode() == OpcodeText +func (f *Frame) SetFIN() *Frame { + (*f)[0] |= bitFIN + return f } -func (f *Frame) IsBinary() bool { - return f.Opcode() == OpcodeBinary +func (f *Frame) SetRSV1() *Frame { + (*f)[0] |= bitRSV1 + return f } -func (f *Frame) IsClose() bool { - return f.Opcode() == OpcodeClose +func (f *Frame) SetRSV2() *Frame { + (*f)[0] |= bitRSV2 + return f } -func (f *Frame) IsPing() bool { - return f.Opcode() == OpcodePing +func (f *Frame) SetRSV3() *Frame { + (*f)[0] |= bitRSV3 + return f } -func (f *Frame) IsPong() bool { - return f.Opcode() == OpcodePong +func (f Frame) clearOpcode() { + f[0] &= bitmaskOpcode << 4 } -func (f *Frame) IsControl() bool { - return f.IsClose() || f.IsPing() || f.IsPong() +func (f *Frame) SetOpcode(c Opcode) *Frame { + c &= Opcode(bitmaskOpcode) + f.clearOpcode() + (*f)[0] |= byte(c) + return f } -func (f *Frame) IsMasked() bool { - return f.header[1]&maskBit != 0 +func (f *Frame) SetContinuation() *Frame { + f.SetOpcode(OpcodeContinuation) + return f } -func (f *Frame) SetFin() { - f.header[0] |= finBit +func (f *Frame) SetText() *Frame { + f.SetOpcode(OpcodeText) + return f } -func (f *Frame) SetRSV1() { - f.header[0] |= rsv1Bit +func (f *Frame) SetBinary() *Frame { + f.SetOpcode(OpcodeBinary) + return f } -func (f *Frame) SetRSV2() { - f.header[0] |= rsv2Bit +func (f *Frame) SetClose() *Frame { + f.SetOpcode(OpcodeClose) + return f } -func (f *Frame) SetRSV3() { - f.header[0] |= rsv3Bit +func (f *Frame) SetPing() *Frame { + f.SetOpcode(OpcodePing) + return f } -func (f *Frame) SetOpcode(c Opcode) { - c &= 15 - f.header[0] &= 15 << 4 - f.header[0] |= uint8(c) +func (f *Frame) SetPong() *Frame { + f.SetOpcode(OpcodePong) + return f } -func (f *Frame) SetContinuation() { - f.SetOpcode(OpcodeContinuation) +func (f Frame) extendedPayloadLengthOffset() int { + return frameHeaderLength } -func (f *Frame) SetText() { - f.SetOpcode(OpcodeText) +func (f Frame) extendedPayloadLength() []byte { + if bytes := f.ExtendedPayloadLengthBytes(); bytes > 0 { + b := f[frameHeaderLength:] + return b[:bytes] + } + return nil } -func (f *Frame) SetBinary() { - f.SetOpcode(OpcodeBinary) +func (f Frame) Header() []byte { + return f[:frameHeaderLength] } -func (f *Frame) SetClose() { - f.SetOpcode(OpcodeClose) +func (f *Frame) maskOffset() int { + return frameHeaderLength + f.ExtendedPayloadLengthBytes() } -func (f *Frame) SetPing() { - f.SetOpcode(OpcodePing) +func (f Frame) Mask() []byte { + if f.IsMasked() { + mask := f[f.maskOffset():] + return mask[:frameMaskLength] + } + return nil } -func (f *Frame) SetPong() { - f.SetOpcode(OpcodePong) +func (f Frame) payloadOffset() int { + return frameHeaderLength + f.ExtendedPayloadLengthBytes() + f.MaskBytes() } -func (f *Frame) SetPayload(b []byte) { - f.payload = append(f.payload[:0], b...) +func (f *Frame) setPayloadLength(n int) *Frame { + (*f)[1] &= (1 << 7) + + if n > (1<<16 - 1) { + // does not fit in 2 bytes, so take 8 as extended payload length + (*f)[1] |= 127 + binary.BigEndian.PutUint64((*f)[2:], uint64(n)) + } else if n > 125 { + // fits in 2 bytes as extended payload length + (*f)[1] |= 126 + binary.BigEndian.PutUint16((*f)[2:], uint16(n)) + } else { + // can be encoded in the 7 bits of the payload length, no extended payload length taken + (*f)[1] |= byte(n) + } + return f } -func (f *Frame) MaskKey() []byte { - return f.mask[:] +func (f *Frame) SetPayload(b []byte) *Frame { + f.setPayloadLength(len(b)) // set the length as it's used by `payloadOffset`. + + *f = util.ExtendSlice(*f, f.payloadOffset()+len(b)) + payload := f.Payload() + copy(payload, b) + + return f } -func (f *Frame) Payload() []byte { - return f.payload +func (f Frame) Payload() []byte { + return f[f.payloadOffset():] } -func (f *Frame) Mask() { - f.header[1] |= maskBit - GenMask(f.mask[:]) - if len(f.payload) > 0 { - Mask(f.mask[:], f.payload) +func (f *Frame) MaskPayload() { + f.SetIsMasked() + + var ( + mask = f.Mask() + payload = f.Payload() + ) + + if len(payload) > 0 { + GenMask(mask) + Mask(mask, payload) } } -func (f *Frame) Unmask() { - if len(f.payload) > 0 { - key := f.MaskKey() - Mask(key, f.payload) +func (f *Frame) UnmaskPayload() { + if f.IsMasked() { + var ( + mask = f.Mask() + payload = f.Payload() + ) + Mask(mask, payload) + // Does not unset the IsMasked bit in order to not mess up the offset at which the payload is found. } - f.header[1] ^= maskBit -} - -func (f *Frame) String() string { - return fmt.Sprintf(` -FIN: %v -RSV1: %v -RSV2: %v -RSV3: %v -OPCODE: %d -MASK: %v -LENGTH: %d -MASK-KEY: %v -PAYLOAD: %v`, - - f.IsFin(), f.IsRSV1(), f.IsRSV2(), f.IsRSV3(), - f.Opcode(), f.IsMasked(), f.PayloadLen(), - f.MaskKey(), f.Payload(), - ) } -var framePool = sync.Pool{ - New: func() interface{} { - return NewFrame() - }, +func (f *Frame) fitPayload() ([]byte, error) { + length := f.PayloadLength() + if length <= 0 { + return nil, nil + } + + *f = util.ExtendSlice(*f, f.payloadOffset()+length) + b := (*f)[f.payloadOffset():] + return b[:length], nil } -func AcquireFrame() *Frame { - return framePool.Get().(*Frame) +func (f *Frame) ReadFrom(r io.Reader) (n int64, err error) { + var nn int + + // read the header + nn, err = io.ReadFull(r, f.Header()) + n += int64(nn) + if err != nil { + return + } + + // read the extended payload length, if any + if b := f.extendedPayloadLength(); b != nil { + nn, err = io.ReadFull(r, b) + n += int64(nn) + if err != nil { + return + } + } + + // read the mask, if any + if f.IsMasked() { + nn, err = io.ReadFull(r, f.Mask()) + n += int64(nn) + if err != nil { + return + } + } + + // read the payload, if any + b, err := f.fitPayload() + if err != nil { + return + } + if b != nil { + nn, err = io.ReadFull(r, b) + n += int64(nn) + if err != nil { + return + } + } + + return } -func ReleaseFrame(f *Frame) { - f.Reset() - framePool.Put(f) +func (f Frame) WriteTo(w io.Writer) (int64, error) { + written := 0 + for written < len(f) { + n, err := w.Write(f[written:]) + written += n + if err != nil { + return int64(n), err + } + } + + return int64(written), nil } diff --git a/codec/websocket/frame_codec.go b/codec/websocket/frame_codec.go index af72bb41..5c2a5262 100644 --- a/codec/websocket/frame_codec.go +++ b/codec/websocket/frame_codec.go @@ -2,116 +2,108 @@ package websocket import ( "errors" - "github.com/talostrading/sonic" ) -var _ sonic.Codec[*Frame, *Frame] = &FrameCodec{} +var _ sonic.Codec[Frame, Frame] = &FrameCodec{} var ( ErrPartialPayload = errors.New("partial payload") ) -// FrameCodec is a stateful streaming parser handling the encoding -// and decoding of WebSocket frames. +// FrameCodec is a stateful streaming parser handling the encoding and decoding of WebSocket `Frame`s. type FrameCodec struct { - src *sonic.ByteBuffer // buffer we decode from - dst *sonic.ByteBuffer // buffer we encode to + src *sonic.ByteBuffer // buffer we decode from + dst *sonic.ByteBuffer // buffer we encode to + maxMessageSize int - decodeFrame *Frame // frame we decode into - decodeBytes int // number of bytes of the last successfully decoded frame - decodeReset bool // true if we must reset the state on the next decode + decodeFrame Frame // frame we decode into + decodeReset bool // true if we must reset the state on the next decode } -func NewFrameCodec(src, dst *sonic.ByteBuffer) *FrameCodec { +func NewFrameCodec(src, dst *sonic.ByteBuffer, maxMessageSize int) *FrameCodec { return &FrameCodec{ - decodeFrame: NewFrame(), - src: src, - dst: dst, + decodeFrame: NewFrame(), + src: src, + dst: dst, + maxMessageSize: maxMessageSize, } } func (c *FrameCodec) resetDecode() { if c.decodeReset { c.decodeReset = false - c.src.Consume(c.decodeBytes) - c.decodeBytes = 0 + c.src.Consume(len(c.decodeFrame)) + c.decodeFrame = nil } } -// Decode decodes the raw bytes from `src` into a frame. -// -// Three things can happen while decoding a raw stream of bytes into a frame: -// 1. There are not enough bytes to construct a frame with. -// -// In this case, a nil frame and ErrNeedMore are returned. The caller -// should perform another read into `src` later. -// -// 2. `src` contains the bytes of one frame. +// Decode decodes the raw bytes from `src` into a `Frame`. // -// In this case we try to decode the frame. An appropriate error is returned -// if the frame is corrupt. +// Two things can happen while decoding a raw stream of bytes into a frame: // -// 3. `src` contains the bytes of more than one frame. +// 1. There are not enough bytes to construct a frame with: in this case, a `nil` `Frame` and `ErrNeedMore` are +// returned. The caller should perform another read into `src` later. // -// In this case we try to decode the first frame. The rest of the bytes stay -// in `src`. An appropriate error is returned if the frame is corrupt. -func (c *FrameCodec) Decode(src *sonic.ByteBuffer) (*Frame, error) { +// 2. `src` contains at least the bytes of one `Frame`: we decode the next `Frame` and leave the remainder bytes +// composing a partial `Frame` or a set of `Frame`s in the `src` buffer. +func (c *FrameCodec) Decode(src *sonic.ByteBuffer) (Frame, error) { c.resetDecode() - n := 2 - err := src.PrepareRead(n) - if err != nil { + // read the mandatory header + readSoFar := frameHeaderLength + if err := src.PrepareRead(readSoFar); err != nil { + c.decodeFrame = nil return nil, err } - c.decodeFrame.header = src.Data()[:n] + c.decodeFrame = src.Data()[:readSoFar] - // read extra header length - n += c.decodeFrame.ExtraHeaderLen() - if err := src.PrepareRead(n); err != nil { + // read the extended payload length (0, 2 or 8 bytes) and check if within bounds + readSoFar += c.decodeFrame.ExtendedPayloadLengthBytes() + if err := src.PrepareRead(readSoFar); err != nil { + c.decodeFrame = nil return nil, err } - c.decodeFrame.header = src.Data()[:n] + c.decodeFrame = src.Data()[:readSoFar] + + payloadLength := c.decodeFrame.PayloadLength() + if payloadLength > c.maxMessageSize { + c.decodeFrame = nil + return nil, ErrPayloadOverMaxSize + } // read mask if any if c.decodeFrame.IsMasked() { - n += 4 - if err := src.PrepareRead(n); err != nil { + readSoFar += frameMaskLength + if err := src.PrepareRead(readSoFar); err != nil { + c.decodeFrame = nil return nil, err } - c.decodeFrame.mask = src.Data()[n-4 : n] - } - - // check payload length - npayload := c.decodeFrame.PayloadLen() - if npayload > MaxMessageSize { - return nil, ErrPayloadOverMaxSize + c.decodeFrame = src.Data()[:readSoFar] } - // prepare to read the payload - n += npayload - if err := src.PrepareRead(n); err != nil { - // the payload might be too big for our buffer so we must allocate - // enough for the next Decode call to succeed - src.Reserve(npayload) + // read the payload; if that succeeds, we have a full frame in `src` - the decoding was successful and we can return + // the frame + readSoFar += payloadLength + if err := src.PrepareRead(readSoFar); err != nil { + src.Reserve(payloadLength) // payload is incomplete; reserve enough space for the remainder to fit in the buffer + c.decodeFrame = nil return nil, err } - - // at this point, we have a full frame in src - c.decodeFrame.payload = src.Data()[n-npayload : n] - c.decodeBytes = n + c.decodeFrame = src.Data()[:readSoFar] c.decodeReset = true return c.decodeFrame, nil } -// Encode encodes the frame and place the raw bytes into `dst`. -func (c *FrameCodec) Encode(fr *Frame, dst *sonic.ByteBuffer) error { - // Make sure there is enough space in the buffer to hold the serialized - // frame. - dst.Reserve(fr.PayloadLen()) +// Encode encodes the `Frame` into `dst`. +func (c *FrameCodec) Encode(frame Frame, dst *sonic.ByteBuffer) error { + // TODO this can be improved: we can serialize directly in the buffer with zero-copy semantics + + // ensure the destination buffer can hold the serialized frame + dst.Reserve(frame.PayloadLength() + frameMaxHeaderLength) - n, err := fr.WriteTo(dst) + n, err := frame.WriteTo(dst) dst.Commit(int(n)) if err != nil { dst.Consume(int(n)) diff --git a/codec/websocket/frame_codec_test.go b/codec/websocket/frame_codec_test.go index e6773d1a..3ee14605 100644 --- a/codec/websocket/frame_codec_test.go +++ b/codec/websocket/frame_codec_test.go @@ -13,7 +13,7 @@ func TestDecodeShortFrame(t *testing.T) { src := sonic.NewByteBuffer() src.Write([]byte{0x81, 1}) // fin=1 opcode=1 (text) payload_len=1 - codec := NewFrameCodec(src, nil) + codec := NewFrameCodec(src, nil, DefaultMaxMessageSize) f, err := codec.Decode(src) if !errors.Is(err, sonicerrors.ErrNeedMore) { @@ -25,7 +25,7 @@ func TestDecodeShortFrame(t *testing.T) { if codec.decodeReset { t.Fatal("should not reset decoder state") } - if codec.decodeBytes != 0 { + if len(codec.decodeFrame) != 0 { t.Fatal("should have not decoded any bytes") } } @@ -34,7 +34,7 @@ func TestDecodeExactlyOneFrame(t *testing.T) { src := sonic.NewByteBuffer() src.Write([]byte{0x81, 1, 0xFF}) // fin=1 opcode=1 (text) payload_len=1 - codec := NewFrameCodec(src, nil) + codec := NewFrameCodec(src, nil, DefaultMaxMessageSize) f, err := codec.Decode(src) if err != nil { @@ -44,14 +44,14 @@ func TestDecodeExactlyOneFrame(t *testing.T) { t.Fatal("should have gotten a frame") } - if !(f.IsFin() && f.IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { + if !(f.IsFIN() && f.Opcode().IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { t.Fatal("corrupt frame") } if !codec.decodeReset { t.Fatal("should reset decoder state") } - if codec.decodeBytes != 3 { + if len(codec.decodeFrame) != 3 { t.Fatal("should have decoded 3 bytes") } } @@ -62,7 +62,7 @@ func TestDecodeOneAndShortFrame(t *testing.T) { // fin=1 opcode=1 (text) payload_len=1 src.Write([]byte{0x81, 1, 0xFF, 0xFF, 0xFF, 0xFF}) - codec := NewFrameCodec(src, nil) + codec := NewFrameCodec(src, nil, DefaultMaxMessageSize) f, err := codec.Decode(src) if err != nil { @@ -72,7 +72,7 @@ func TestDecodeOneAndShortFrame(t *testing.T) { t.Fatal("should have gotten a frame") } - if !(f.IsFin() && f.IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { + if !(f.IsFIN() && f.Opcode().IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { t.Fatal("corrupt frame") } @@ -80,7 +80,7 @@ func TestDecodeOneAndShortFrame(t *testing.T) { t.Fatal("should reset decoder state") } - if codec.decodeBytes != 3 { + if len(codec.decodeFrame) != 3 { t.Fatal("should have decoded 3 bytes") } } @@ -92,7 +92,7 @@ func TestDecodeTwoFrames(t *testing.T) { 0x81, 5, 0x01, 0x02, 0x03, 0x04, 0x05, // second complete frame 0x81, 10}) // third short frame - codec := NewFrameCodec(src, nil) + codec := NewFrameCodec(src, nil, DefaultMaxMessageSize) if src.WriteLen() != 12 { t.Fatal("should have 12 bytes in the write area") @@ -106,13 +106,13 @@ func TestDecodeTwoFrames(t *testing.T) { if f == nil { t.Fatal("should have gotten a frame") } - if !(f.IsFin() && f.IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { + if !(f.IsFIN() && f.Opcode().IsText() && bytes.Equal(f.Payload(), []byte{0xFF})) { t.Fatal("corrupt frame") } if !codec.decodeReset { t.Fatal("should have reset decoder state") } - if codec.decodeBytes != 3 { + if len(codec.decodeFrame) != 3 { t.Fatal("should have decoded 3 bytes") } if src.ReadLen() != 3 { @@ -127,15 +127,15 @@ func TestDecodeTwoFrames(t *testing.T) { if f == nil { t.Fatal("should have gotten a frame") } - if !(f.IsFin() && - f.IsText() && + if !(f.IsFIN() && + f.Opcode().IsText() && bytes.Equal(f.Payload(), []byte{0x01, 0x02, 0x03, 0x04, 0x05})) { t.Fatal("corrupt frame") } if !codec.decodeReset { t.Fatal("should have reset decoder state") } - if codec.decodeBytes != 7 { + if len(codec.decodeFrame) != 7 { t.Fatal("should have decoded 3 bytes") } if src.ReadLen() != 7 { diff --git a/codec/websocket/frame_test.go b/codec/websocket/frame_test.go index ac0cd315..4290bea0 100644 --- a/codec/websocket/frame_test.go +++ b/codec/websocket/frame_test.go @@ -6,15 +6,64 @@ import ( "crypto/rand" "testing" + "github.com/stretchr/testify/assert" "github.com/talostrading/sonic" ) -func TestUnder125Frame(t *testing.T) { - raw := []byte{0x81, 5} // fin=1 opcode=1 (text) payload_len=5 - raw = append(raw, genRandBytes(5)...) +type partialFrameWriter struct { + b [1024]byte + invoked int +} + +func (w *partialFrameWriter) Write(b []byte) (n int, err error) { + copy(w.b[w.invoked:w.invoked+1], b) + w.invoked++ + return 1, nil +} + +func (w *partialFrameWriter) Bytes() []byte { + return w.b[:w.invoked] +} + +func TestFramePartialWrites(t *testing.T) { + assert := assert.New(t) + + payload := make([]byte, 126) + for i := 0; i < len(payload); i++ { + payload[i] = 0 + } + copy(payload, []byte("something")) + + f := NewFrame() + f.SetFIN().SetText().SetPayload(payload) + assert.Equal(2, f.ExtendedPayloadLengthBytes()) + assert.Equal(126, f.PayloadLength()) + + w := &partialFrameWriter{} + f.WriteTo(w) + assert.Equal(2 /* mandatory flags */ +2 /* 2 bytes for the extended payload length */ +126 /* payload */, w.invoked) + assert.Equal(130, len(w.Bytes())) + + // deserialize to make sure the frames are identical + { + f := Frame(w.Bytes()) + assert.True(f.IsFIN()) + assert.True(f.Opcode().IsText()) + assert.Equal(126, f.PayloadLength()) + assert.Equal(2, f.ExtendedPayloadLengthBytes()) + assert.Equal("something", string(f.Payload()[:len("something")])) + for i := len("something"); i < len(f.Payload()); i++ { + assert.Equal(0, int(f.Payload()[i])) + } + } +} - f := AcquireFrame() - defer ReleaseFrame(f) +func TestUnder126Frame(t *testing.T) { + var ( + f = NewFrame() + raw = []byte{0x81, 5} // fin=1 opcode=1 (text) payload_len=5 + ) + raw = append(raw, genRandBytes(5)...) buf := bufio.NewReader(bytes.NewBuffer(raw)) @@ -27,12 +76,12 @@ func TestUnder125Frame(t *testing.T) { } func Test126Frame(t *testing.T) { - raw := []byte{0x81, 126, 0, 200} + var ( + f = NewFrame() + raw = []byte{0x81, 126, 0, 200} + ) raw = append(raw, genRandBytes(200)...) - f := AcquireFrame() - defer ReleaseFrame(f) - buf := bufio.NewReader(bytes.NewBuffer(raw)) _, err := f.ReadFrom(buf) @@ -44,12 +93,12 @@ func Test126Frame(t *testing.T) { } func Test127Frame(t *testing.T) { - raw := []byte{0x81, 127, 0, 0, 0, 0, 0, 0x01, 0xFF, 0xFF} + var ( + f = NewFrame() + raw = []byte{0x81, 127, 0, 0, 0, 0, 0, 0x01, 0xFF, 0xFF} + ) raw = append(raw, genRandBytes(131071)...) - f := AcquireFrame() - defer ReleaseFrame(f) - buf := bufio.NewReader(bytes.NewBuffer(raw)) _, err := f.ReadFrom(buf) @@ -61,12 +110,12 @@ func Test127Frame(t *testing.T) { } func TestWriteFrame(t *testing.T) { - payload := []byte("heloo") - - f := AcquireFrame() - defer ReleaseFrame(f) + var ( + f = NewFrame() + payload = []byte("heloo") + ) - f.SetFin() + f.SetFIN() f.SetPayload(payload) f.SetText() @@ -92,14 +141,13 @@ func TestWriteFrame(t *testing.T) { } func TestSameFrameWriteRead(t *testing.T) { - // deserialize - f := AcquireFrame() - defer ReleaseFrame(f) - - header := []byte{0x81, 5} - payload := genRandBytes(5) + var ( + header = []byte{0x81, 5} + payload = genRandBytes(5) + buf = sonic.NewByteBuffer() + f = NewFrame() + ) - buf := sonic.NewByteBuffer() buf.Write(header) buf.Write(payload) buf.Commit(7) @@ -111,7 +159,7 @@ func TestSameFrameWriteRead(t *testing.T) { if n != 7 { t.Fatalf("frame is corrupt") } - if !(f.IsFin() && f.IsText() && f.PayloadLen() == 5 && bytes.Equal(f.Payload(), payload)) { + if !(f.IsFIN() && f.Opcode().IsText() && f.PayloadLength() == 5 && bytes.Equal(f.Payload(), payload)) { t.Fatalf("invalid frame") } @@ -137,16 +185,16 @@ func TestSameFrameWriteRead(t *testing.T) { } } -func checkFrame(t *testing.T, f *Frame, c, fin bool, payload []byte) { - if c && !f.IsContinuation() { +func checkFrame(t *testing.T, f Frame, c, fin bool, payload []byte) { + if c && !f.Opcode().IsContinuation() { t.Fatal("expected continuation") } - if fin && !f.IsFin() { + if fin && !f.IsFIN() { t.Fatal("expected FIN") } - if given, expected := len(payload), f.PayloadLen(); given != expected { + if given, expected := len(payload), f.PayloadLength(); given != expected { t.Fatalf("invalid payload length; given=%d expected=%d", given, expected) } diff --git a/codec/websocket/rfc6455.go b/codec/websocket/rfc6455.go index 05546ac7..93670a47 100644 --- a/codec/websocket/rfc6455.go +++ b/codec/websocket/rfc6455.go @@ -1,3 +1,4 @@ +// Based on https://datatracker.ietf.org/doc/html/rfc6455 package websocket import ( @@ -6,129 +7,206 @@ import ( "strings" ) -// NOTE: A fragmented message consists of a single frame with the FIN bit clear -// and an opcode other than 0, followed by zero or more frames with the FIN bit -// clear and the opcode set to 0, and terminated by a single frame with the FIN -// bit set and an opcode of 0. +// --------------------------------------------------- +// Framing ------------------------------------------- +// --------------------------------------------------- +// Based on https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 -// GUID is used when constructing the Sec-WebSocket-Accept key based on -// Sec-WebSocket-Key. +const ( + MaxControlFramePayloadLength = 125 + frameMaxHeaderLength = 14 // 14 bytes max for the header of a frame i.e. everything without the payload +) + +const ( + // Mandatory, 2 bytes: + // byte 1: |fin(1)|rsv1(1)|rsv2(1)|rsv3(1)|opcode(4)| + // byte 2: |is masked(1)|payload length(7)| + frameHeaderLength = 2 + + bitFIN = byte(1 << 7) + bitRSV1 = byte(1 << 6) + bitRSV2 = byte(1 << 5) + bitRSV3 = byte(1 << 4) + bitmaskOpcode = byte(1<<4 - 1) + + bitIsMasked = byte(1 << 7) + bitmaskPayloadLength = byte(1<<7 - 1) + + // Optional, max 8 bytes. If |payload length(7)| above is <= 125, then 0: the payload length is in + // |payload length(7)|. If 126, then the payload length is in the following 2 bytes. If 127, in the following 8 + // bytes. + + // Optional, max 4 bytes. If |is masked(1)| above is set, then the following 4 bytes are the mask. Otherwise, the + // frame is not masked and the mask is not included. + // + // All frames sent from the client to the server are masked by a 32-bit value. Frames sent from the server to the + // client are unmasked. + frameMaskLength = 4 +) + +type Opcode byte + +const ( + OpcodeContinuation Opcode = 0 + OpcodeText Opcode = 1 + OpcodeBinary Opcode = 2 + OpcodeClose Opcode = 8 + OpcodePing Opcode = 9 + OpcodePong Opcode = 10 +) + +func (c Opcode) IsContinuation() bool { return c == OpcodeContinuation } +func (c Opcode) IsText() bool { return c == OpcodeText } +func (c Opcode) IsBinary() bool { return c == OpcodeBinary } +func (c Opcode) IsClose() bool { return c == OpcodeClose } +func (c Opcode) IsPing() bool { return c == OpcodePing } +func (c Opcode) IsPong() bool { return c == OpcodePong } + +func (c Opcode) IsReserved() bool { + return c != OpcodeContinuation && + c != OpcodeText && + c != OpcodeBinary && + c != OpcodeClose && + c != OpcodePing && + c != OpcodePong +} + +func (c Opcode) IsControl() bool { + return c.IsPing() || c.IsPong() || c.IsClose() +} + +func (c Opcode) String() string { + switch c { + case OpcodeContinuation: + return "continuation" + case OpcodeText: + return "text" + case OpcodeBinary: + return "binary" + case OpcodeClose: + return "close" + case OpcodePing: + return "ping" + case OpcodePong: + return "pong" + default: + return "unknown" + } +} + +// --------------------------------------------------- +// Handshake ----------------------------------------- +// --------------------------------------------------- + +// Used when constructing the server's Sec-WebSocket-Accept key based on the client's Sec-WebSocket-Key. var GUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") -// IsUpgradeReq returns true if the HTTP request is a valid WebSocket upgrade. func IsUpgradeReq(req *http.Request) bool { return strings.EqualFold(req.Header.Get("Upgrade"), "websocket") } -// IsUpgradeReq returns true if the HTTP response is a valid WebSocket upgrade. func IsUpgradeRes(res *http.Response) bool { return res.StatusCode == 101 && strings.EqualFold(res.Header.Get("Upgrade"), "websocket") } -const ( - finBit = byte(1 << 7) - rsv1Bit = byte(1 << 6) - rsv2Bit = byte(1 << 5) - rsv3Bit = byte(1 << 4) - maskBit = byte(1 << 7) -) - -// The max size of the ping/pong control frame payload. -const MaxControlFramePayloadSize = 125 +// --------------------------------------------------- +// Closing ------------------------------------------- +// --------------------------------------------------- -// The type representing the reason string in a close frame. -type ReasonString [123]byte - -// Close status codes. -// -// These codes accompany close frames. +// Close status codes that accompany close frames. type CloseCode uint16 const ( - // CloseNormal signifies normal closure; the connection successfully - // completed whatever purpose for which it was created. + // CloseNormal signifies normal closure; the connection successfully completed whatever purpose for which it was + // created. CloseNormal CloseCode = 1000 - // GoingaAway means endpoint is going away, either because of a - // server failure or because the browser is navigating away from - // the page that opened the connection. + // GoingAway means endpoint is going away, either because of a server failure or because the browser is navigating + // away from the page that opened the connection. CloseGoingAway CloseCode = 1001 - // CloseProtocolError means the endpoint is terminating the connection - // due to a protocol error. + // CloseProtocolError means the endpoint is terminating the connection due to a protocol error. CloseProtocolError CloseCode = 1002 - // CloseUnknownData means the connection is being terminated because - // the endpoint received data of a type it cannot accept (for example, - // a text-only endpoint received binary data). + // CloseUnknownData means the connection is being terminated because the endpoint received data of a type it cannot + // accept (for example, a text-only endpoint received binary data). CloseUnknownData CloseCode = 1003 - // CloseBadPayload means the endpoint is terminating the connection because - // a message was received that contained inconsistent data - // (e.g., non-UTF-8 data within a text message). + // CloseBadPayload means the endpoint is terminating the connection because a message was received that contained + // inconsistent data (e.g., non-UTF-8 data within a text message). CloseBadPayload CloseCode = 1007 - // ClosePolicyError means the endpoint is terminating the connection because - // it received a message that violates its policy. This is a generic status - // code, used when codes 1003 and 1009 are not suitable. + // ClosePolicyError means the endpoint is terminating the connection because it received a message that violates its + // policy. This is a generic status code, used when codes 1003 and 1009 are not suitable. ClosePolicyError CloseCode = 1008 - // CloseTooBig means the endpoint is terminating the connection because a - // data frame was received that is too large. + // CloseTooBig means the endpoint is terminating the connection because a data frame was received that is too large. CloseTooBig CloseCode = 1009 - // CloseNeedsExtension means the client is terminating the connection - // because it expected the server to negotiate one or more extensions, but - // the server didn't. + // CloseNeedsExtension means the client is terminating the connection because it expected the server to negotiate + // one or more extensions, but the server didn't. CloseNeedsExtension CloseCode = 1010 - // CloseInternalError means the server is terminating the connection because - // it encountered an unexpected condition that prevented it from fulfilling - // the request. + // CloseInternalError means the server is terminating the connection because it encountered an unexpected condition + // that prevented it from fulfilling the request. CloseInternalError CloseCode = 1011 - // CloseServiceRestart means the server is terminating the connection - // because it is restarting. + // CloseServiceRestart means the server is terminating the connection because it is restarting. CloseServiceRestart CloseCode = 1012 - // CloseTryAgainLater means the server is terminating the connection due to - // a temporary condition, e.g. it is overloaded and is casting off some of - // its clients. + // CloseTryAgainLater means the server is terminating the connection due to a temporary condition, e.g. it is + // overloaded and is casting off some of its clients. CloseTryAgainLater CloseCode = 1013 // ------------------------------------- // The following are illegal on the wire // ------------------------------------- - // CloseNone is used internally to mean "no error" - // This code is reserved and may not be sent. + // CloseNone is used internally to mean "no error" This code is reserved and may not be sent. CloseNone CloseCode = 0 - // CloseNoStatus means no status code was provided in the close frame sent - // by the peer, even though one was expected. - // This code is reserved for internal use and may not be sent in-between - // peers. + // CloseNoStatus means no status code was provided in the close frame sent by the peer, even though one was + // expected. This code is reserved for internal use and may not be sent in-between peers. CloseNoStatus CloseCode = 1005 - // CloseAbnormal means the connection was closed without receiving a close - // frame. - // This code is reserved and may not be sent. + // CloseAbnormal means the connection was closed without receiving a close frame. This code is reserved and may not + // be sent. CloseAbnormal CloseCode = 1006 - // CloseReserved1 is reserved for future use by the WebSocket standard. - // This code is reserved and may not be sent. + // CloseReserved1 is reserved for future use by the WebSocket standard. This code is reserved and may not be sent. CloseReserved1 CloseCode = 1004 - // CloseReserved2 is reserved for future use by the WebSocket standard. - // This code is reserved and may not be sent. + // CloseReserved2 is reserved for future use by the WebSocket standard. This code is reserved and may not be sent. CloseReserved2 CloseCode = 1014 - // CloseReserved3 is reserved for future use by the WebSocket standard. - // This code is reserved and may not be sent. + // CloseReserved3 is reserved for future use by the WebSocket standard. This code is reserved and may not be sent. CloseReserved3 CloseCode = 1015 + + // CloseReserved4 is reserved for future use by the WebSocket standard. This code is reserved and may not be sent. + CloseReserved4 CloseCode = 1016 + + CloseReservedForFuture CloseCode = 1004 ) +func ValidCloseCode(closeCode CloseCode) bool { + if closeCode == CloseNormal || + closeCode == CloseGoingAway || + closeCode == CloseProtocolError || + closeCode == CloseUnknownData || + closeCode == CloseBadPayload || + closeCode == ClosePolicyError || + closeCode == CloseTooBig || + closeCode == CloseNeedsExtension || + closeCode == CloseInternalError || + closeCode == CloseServiceRestart || + closeCode == CloseTryAgainLater || + (closeCode >= 3000 && closeCode <= 4999) { + return true + } + return false +} + func EncodeCloseCode(cc CloseCode) []byte { b := make([]byte, 2) binary.BigEndian.PutUint16(b, uint16(cc)) @@ -138,69 +216,3 @@ func EncodeCloseCode(cc CloseCode) []byte { func DecodeCloseCode(b []byte) CloseCode { return CloseCode(binary.BigEndian.Uint16(b[:2])) } - -type Opcode uint8 - -// No `iota` here for clarity. -const ( - OpcodeContinuation Opcode = 0x00 - OpcodeText Opcode = 0x01 - OpcodeBinary Opcode = 0x02 - OpcodeRsv3 Opcode = 0x03 - OpcodeRsv4 Opcode = 0x04 - OpcodeRsv5 Opcode = 0x05 - OpcodeRsv6 Opcode = 0x06 - OpcodeRsv7 Opcode = 0x07 - OpcodeClose Opcode = 0x08 - OpcodePing Opcode = 0x09 - OpcodePong Opcode = 0x0A - OpcodeCrsvb Opcode = 0x0B - OpcodeCrsvc Opcode = 0x0C - OpcodeCrsvd Opcode = 0x0D - OpcodeCrsve Opcode = 0x0E - OpcodeCrsvf Opcode = 0x0F -) - -func IsReserved(op Opcode) bool { - return (op >= OpcodeRsv3 && op <= OpcodeRsv7) || - (op >= OpcodeCrsvb && op <= OpcodeCrsvf) -} - -func (c Opcode) String() string { - switch c { - case OpcodeContinuation: - return "continuation" - case OpcodeText: - return "text" - case OpcodeBinary: - return "binary" - case OpcodeRsv3: - return "rsv3" - case OpcodeRsv4: - return "rsv4" - case OpcodeRsv5: - return "rsv5" - case OpcodeRsv6: - return "rsv6" - case OpcodeRsv7: - return "rsv7" - case OpcodeClose: - return "close" - case OpcodePing: - return "ping" - case OpcodePong: - return "pong" - case OpcodeCrsvb: - return "crsvb" - case OpcodeCrsvc: - return "crsvc" - case OpcodeCrsvd: - return "crsvd" - case OpcodeCrsve: - return "crsve" - case OpcodeCrsvf: - return "crsvf" - default: - return "unknown" - } -} diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index 8cbc8c8b..dbcb02fe 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -28,21 +28,16 @@ import ( "net/http" "net/http/httputil" "net/url" + "sync" "syscall" - "time" + "unicode/utf8" "github.com/talostrading/sonic" "github.com/talostrading/sonic/sonicerrors" "github.com/talostrading/sonic/sonicopts" ) -var ( - _ Stream = &WebsocketStream{} - DialTimeout = 5 * time.Second -) - -type WebsocketStream struct { - // async operations executor. +type Stream struct { ioc *sonic.IO // User provided TLS config; nil if we don't use TLS @@ -53,7 +48,8 @@ type WebsocketStream struct { conn net.Conn // Codec stream wrapping the underlying transport stream. - cs *sonic.BlockingCodecConn[*Frame, *Frame] + codec *FrameCodec + codecConn *sonic.CodecConn[Frame, Frame] // Websocket role: client or server. role Role @@ -70,36 +66,33 @@ type WebsocketStream struct { // Buffer for stream writes. dst *sonic.ByteBuffer - // Contains the handshake response. Is emptied after the - // handshake is over. - hb []byte + // Contains the handshake response. Is emptied after the handshake is over. + handshakeBuffer []byte - // Contains frames waiting to be sent to the peer. - // Is emptied by AsyncFlush or Flush. - pending []*Frame + // Contains frames waiting to be sent to the peer. Is emptied by AsyncFlush or Flush. + pendingFrames []*Frame // Optional callback invoked when a control frame is received. - ccb ControlCallback + controlCallback ControlCallback // Optional callback invoked when an upgrade request is sent. - upReqCb UpgradeRequestCallback + upgradeRequestCallback UpgradeRequestCallback // Optional callback invoked when an upgrade response is received. - upResCb UpgradeResponseCallback + upgradeResponseCallback UpgradeResponseCallback // Used to establish a TCP connection to the peer with a timeout. dialer *net.Dialer - // The size of the currently read message. - messageSize int + framePool sync.Pool + + maxMessageSize int + + validateUTF8 bool } -func NewWebsocketStream( - ioc *sonic.IO, - tls *tls.Config, - role Role, -) (s *WebsocketStream, err error) { - s = &WebsocketStream{ +func NewWebsocketStream(ioc *sonic.IO, tls *tls.Config, role Role) (s *Stream, err error) { + s = &Stream{ ioc: ioc, tls: tls, role: role, @@ -107,11 +100,19 @@ func NewWebsocketStream( dst: sonic.NewByteBuffer(), state: StateHandshake, /* #nosec G401 */ - hasher: sha1.New(), - hb: make([]byte, 1024), + hasher: sha1.New(), + handshakeBuffer: make([]byte, 1024), dialer: &net.Dialer{ Timeout: DialTimeout, }, + framePool: sync.Pool{ + New: func() interface{} { + frame := NewFrame() + return &frame + }, + }, + maxMessageSize: DefaultMaxMessageSize, + validateUTF8: false, } s.src.Reserve(4096) @@ -120,22 +121,44 @@ func NewWebsocketStream( return s, nil } +// The recommended way to create a `Frame`. This function takes care to allocate enough bytes to encode the WebSocket +// header and apply client side masking if the `Stream`'s role is `RoleClient`. +func (s *Stream) AcquireFrame() *Frame { + f := s.framePool.Get().(*Frame) + if s.role == RoleClient { + // This just reserves the 4 bytes needed for the mask in order to encode the payload correctly, since it follows + // the mask in byte order. The actual mask is set after `f.SetPayload` in `prepareWrite`. + f.SetIsMasked() + } + return f +} + +func (s *Stream) releaseFrame(f *Frame) { + f.Reset() + s.framePool.Put(f) +} + +// This stream is either a client or a server. The main differences are in how the opening/closing handshake is done and +// in the fact that payloads sent by the client to the server are masked. +func (s *Stream) Role() Role { + return s.role +} + // init is run when we transition into StateActive which happens // after a successful handshake. -func (s *WebsocketStream) init(stream sonic.Stream) (err error) { +func (s *Stream) init(stream sonic.Stream) (err error) { if s.state != StateActive { return fmt.Errorf("stream must be in StateActive") } s.stream = stream - codec := NewFrameCodec(s.src, s.dst) - s.cs, err = sonic.NewBlockingCodecConn[*Frame, *Frame]( - stream, codec, s.src, s.dst) + s.codec = NewFrameCodec(s.src, s.dst, s.maxMessageSize) + s.codecConn, err = sonic.NewCodecConn[Frame, Frame](stream, s.codec, s.src, s.dst) return } -func (s *WebsocketStream) reset() { - s.hb = s.hb[:cap(s.hb)] +func (s *Stream) reset() { + s.handshakeBuffer = s.handshakeBuffer[:cap(s.handshakeBuffer)] s.state = StateHandshake s.stream = nil s.conn = nil @@ -143,27 +166,50 @@ func (s *WebsocketStream) reset() { s.dst.Reset() } -func (s *WebsocketStream) NextLayer() sonic.Stream { - if s.cs != nil { - return s.cs.NextLayer() +// Returns the stream through which IO is done. +func (s *Stream) NextLayer() sonic.Stream { + if s.codecConn != nil { + return s.codecConn.NextLayer() } return nil } -func (s *WebsocketStream) SupportsUTF8() bool { - return false +// SupportsUTF8 indicates that the stream can optionally perform UTF-8 validation on the payloads of Text frames. +// Validation is disabled by default and can be toggled with `ValidateUTF8(bool)`. +func (s *Stream) SupportsUTF8() bool { + return true } -func (s *WebsocketStream) SupportsDeflate() bool { +// ValidatesUTF8 indicates if UTF-8 validation is performed on the payloads of Text frames. Validation is disabled by +// default and can be toggled with `ValidateUTF8(bool)`. +func (s *Stream) ValidatesUTF8() bool { + return s.validateUTF8 +} + +// ValidateUTF8 toggles UTF8 validation done on the payloads of Text frames. +func (s *Stream) ValidateUTF8(v bool) *Stream { + s.validateUTF8 = v + return s +} + +func (s *Stream) SupportsDeflate() bool { return false } -func (s *WebsocketStream) canRead() bool { +func (s *Stream) canRead() bool { // we can only read from non-terminal states after a successful handshake return s.state == StateActive || s.state == StateClosedByUs } -func (s *WebsocketStream) NextFrame() (f *Frame, err error) { +// NextFrame reads and returns the next frame. +// +// This call first flushes any pending control frames to the underlying stream. +// +// This call blocks until one of the following conditions is true: +// - an error occurs while flushing the pending control frames +// - an error occurs when reading/decoding the message bytes from the underlying stream +// - a frame is successfully read from the underlying stream +func (s *Stream) NextFrame() (f Frame, err error) { err = s.Flush() if errors.Is(err, ErrMessageTooBig) { @@ -186,27 +232,27 @@ func (s *WebsocketStream) NextFrame() (f *Frame, err error) { return } -func (s *WebsocketStream) nextFrame() (f *Frame, err error) { - f, err = s.cs.ReadNext() +func (s *Stream) nextFrame() (f Frame, err error) { + f, err = s.codecConn.ReadNext() if err == nil { err = s.handleFrame(f) } return } -func (s *WebsocketStream) AsyncNextFrame(cb AsyncFrameHandler) { - // TODO for later: here we flush first since we might need to reply - // to ping/pong/close immediately, and only after that we try to - // async read. - // - // I think we can just flush asynchronously while reading asynchronously at - // the same time. I'm pretty sure this will work with a BlockingCodecConn. - // - // Not entirely sure about a NonblockingCodecStream. +// AsyncNextFrame reads and returns the next frame asynchronously. +// +// This call first flushes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. The provided callback is invoked when one of the following happens: +// - an error occurs while flushing the pending control frames +// - an error occurs when reading/decoding the message bytes from the underlying stream +// - a frame is successfully read from the underlying stream +func (s *Stream) AsyncNextFrame(callback AsyncFrameCallback) { s.AsyncFlush(func(err error) { if errors.Is(err, ErrMessageTooBig) { s.AsyncClose(CloseGoingAway, "payload too big", func(err error) {}) - cb(ErrMessageTooBig, nil) + callback(ErrMessageTooBig, nil) return } @@ -215,54 +261,60 @@ func (s *WebsocketStream) AsyncNextFrame(cb AsyncFrameHandler) { } if err == nil { - s.asyncNextFrame(cb) + s.asyncNextFrame(callback) } else { s.state = StateTerminated - cb(err, nil) + callback(err, nil) } }) } -func (s *WebsocketStream) asyncNextFrame(cb AsyncFrameHandler) { - s.cs.AsyncReadNext(func(err error, f *Frame) { +func (s *Stream) asyncNextFrame(callback AsyncFrameCallback) { + s.codecConn.AsyncReadNext(func(err error, f Frame) { if err == nil { err = s.handleFrame(f) } else if err == io.EOF { s.state = StateTerminated } - cb(err, f) + callback(err, f) }) } -func (s *WebsocketStream) NextMessage( - b []byte, -) (mt MessageType, readBytes int, err error) { +// NextMessage reads the payload of the next message into the supplied buffer. Message fragmentation is automatically +// handled by the implementation. +// +// This call first flushes any pending control frames to the underlying stream. +// +// This call blocks until one of the following conditions is true: +// - an error occurs while flushing the pending control frames +// - an error occurs when reading/decoding the message from the underlying stream +// - the payload of the message is successfully read into the supplied buffer, after all message fragments are read +func (s *Stream) NextMessage(b []byte) (messageType MessageType, readBytes int, err error) { var ( - f *Frame + f Frame continuation = false ) - - mt = TypeNone + messageType = TypeNone for { f, err = s.NextFrame() if err != nil { - return mt, readBytes, err + return messageType, readBytes, err } - if f.IsControl() { - if s.ccb != nil { - s.ccb(MessageType(f.Opcode()), f.payload) + if f.Opcode().IsControl() { + if s.controlCallback != nil { + s.controlCallback(MessageType(f.Opcode()), f.Payload()) } } else { - if mt == TypeNone { - mt = MessageType(f.Opcode()) + if messageType == TypeNone { + messageType = MessageType(f.Opcode()) } n := copy(b[readBytes:], f.Payload()) readBytes += n - if readBytes > MaxMessageSize || n != f.PayloadLen() { + if readBytes > s.maxMessageSize || n != f.PayloadLength() { err = ErrMessageTooBig _ = s.Close(CloseGoingAway, "payload too big") break @@ -271,19 +323,19 @@ func (s *WebsocketStream) NextMessage( // verify continuation if !continuation { // this is the first frame of the series - continuation = !f.IsFin() //nolint:ineffassign - if f.IsContinuation() { + continuation = !f.IsFIN() //nolint:ineffassign + if f.Opcode().IsContinuation() { err = ErrUnexpectedContinuation } } else { // we are past the first frame of the series - continuation = !f.IsFin() //nolint:ineffassign - if !f.IsContinuation() { + continuation = !f.IsFIN() //nolint:ineffassign + if !f.Opcode().IsContinuation() { err = ErrExpectedContinuation } } - continuation = !f.IsFin() + continuation = !f.IsFIN() if err != nil || !continuation { break @@ -294,76 +346,85 @@ func (s *WebsocketStream) NextMessage( return } -func (s *WebsocketStream) AsyncNextMessage(b []byte, cb AsyncMessageHandler) { - s.asyncNextMessage(b, 0, false, TypeNone, cb) +// AsyncNextMessage reads the payload of the next message into the supplied buffer asynchronously. Message fragmentation +// is automatically handled by the implementation. +// +// This call first flushes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. The provided callback is invoked when one of the following happens: +// - an error occurs while flushing the pending control frames +// - an error occurs when reading/decoding the message bytes from the underlying stream +// - the payload of the message is successfully read into the supplied buffer, after all message fragments are read +func (s *Stream) AsyncNextMessage(b []byte, callback AsyncMessageCallback) { + s.asyncNextMessage(b, 0, false, TypeNone, callback) } -func (s *WebsocketStream) asyncNextMessage( +func (s *Stream) asyncNextMessage( b []byte, readBytes int, continuation bool, - mt MessageType, - cb AsyncMessageHandler, + messageType MessageType, + callback AsyncMessageCallback, ) { - s.AsyncNextFrame(func(err error, f *Frame) { + s.AsyncNextFrame(func(err error, f Frame) { if err != nil { - cb(err, readBytes, mt) + callback(err, readBytes, messageType) } else { - if f.IsControl() { - if s.ccb != nil { - s.ccb(MessageType(f.Opcode()), f.payload) + if f.Opcode().IsControl() { + if s.controlCallback != nil { + s.controlCallback(MessageType(f.Opcode()), f.Payload()) } - s.asyncNextMessage(b, readBytes, continuation, mt, cb) + s.asyncNextMessage(b, readBytes, continuation, messageType, callback) } else { - if mt == TypeNone { - mt = MessageType(f.Opcode()) + if messageType == TypeNone { + messageType = MessageType(f.Opcode()) } n := copy(b[readBytes:], f.Payload()) readBytes += n - if readBytes > MaxMessageSize || n != f.PayloadLen() { + if readBytes > s.maxMessageSize || n != f.PayloadLength() { err = ErrMessageTooBig s.AsyncClose( CloseGoingAway, "payload too big", func(err error) {}, ) - cb(err, readBytes, mt) + callback(err, readBytes, messageType) return } // verify continuation if !continuation { // this is the first frame of the series - continuation = !f.IsFin() - if f.IsContinuation() { + continuation = !f.IsFIN() + if f.Opcode().IsContinuation() { err = ErrUnexpectedContinuation } } else { // we are past the first frame of the series - continuation = !f.IsFin() - if !f.IsContinuation() { + continuation = !f.IsFIN() + if !f.Opcode().IsContinuation() { err = ErrExpectedContinuation } } if err != nil || !continuation { - cb(err, readBytes, mt) + callback(err, readBytes, messageType) } else { - s.asyncNextMessage(b, readBytes, continuation, mt, cb) + s.asyncNextMessage(b, readBytes, continuation, messageType, callback) } } } }) } -func (s *WebsocketStream) handleFrame(f *Frame) (err error) { +func (s *Stream) handleFrame(f Frame) (err error) { err = s.verifyFrame(f) if err == nil { - if f.IsControl() { + if f.Opcode().IsControl() { err = s.handleControlFrame(f) } else { err = s.handleDataFrame(f) @@ -372,13 +433,14 @@ func (s *WebsocketStream) handleFrame(f *Frame) (err error) { if err != nil { s.state = StateClosedByUs + // TODO consider flushing the close s.prepareClose(EncodeCloseFramePayload(CloseProtocolError, "")) } return err } -func (s *WebsocketStream) verifyFrame(f *Frame) error { +func (s *Stream) verifyFrame(f Frame) error { if f.IsRSV1() || f.IsRSV2() || f.IsRSV3() { return ErrNonZeroReservedBits } @@ -394,26 +456,23 @@ func (s *WebsocketStream) verifyFrame(f *Frame) error { return nil } -func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { - if !f.IsFin() { +func (s *Stream) handleControlFrame(f Frame) (err error) { + if !f.IsFIN() { return ErrInvalidControlFrame } - if f.PayloadLenType() > MaxControlFramePayloadSize { + if f.PayloadLength() > MaxControlFramePayloadLength { return ErrControlFrameTooBig } switch f.Opcode() { case OpcodePing: if s.state == StateActive { - pongFrame := AcquireFrame() - pongFrame.SetFin() - pongFrame.SetPong() - pongFrame.SetPayload(f.payload) - if s.role == RoleClient { - pongFrame.Mask() - } - s.pending = append(s.pending, pongFrame) + pongFrame := s.AcquireFrame(). + SetFIN(). + SetPong(). + SetPayload(f.Payload()) + s.prepareWrite(pongFrame) } case OpcodePong: case OpcodeClose: @@ -422,7 +481,25 @@ func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { panic("unreachable") case StateActive: s.state = StateClosedByPeer - s.prepareClose(f.payload) + + payload := f.Payload() + // The first 2 bytes are the close code. The rest, if any, is the reason - this must be UTF-8. + if len(payload) >= 2 { + if !utf8.Valid(payload[2:]) { + s.prepareClose(EncodeCloseCode(CloseProtocolError)) + } else { + closeCode := DecodeCloseCode(payload) + if !ValidCloseCode(closeCode) { + s.prepareClose(EncodeCloseCode(CloseProtocolError)) + } else { + s.prepareClose(f.Payload()) + } + } + } else if len(payload) > 0 { + s.prepareClose(EncodeCloseCode(CloseProtocolError)) + } else { + s.prepareClose(EncodeCloseCode(CloseNormal)) + } case StateClosedByPeer, StateCloseAcked: // ignore case StateClosedByUs: @@ -438,24 +515,39 @@ func (s *WebsocketStream) handleControlFrame(f *Frame) (err error) { return } -func (s *WebsocketStream) handleDataFrame(f *Frame) error { - if IsReserved(f.Opcode()) { +func (s *Stream) handleDataFrame(f Frame) error { + if f.Opcode().IsReserved() { return ErrReservedOpcode } + + if f.Opcode().IsText() && s.validateUTF8 { + if !utf8.Valid(f.Payload()) { + return ErrInvalidUTF8 + } + } + return nil } -func (s *WebsocketStream) Write(b []byte, mt MessageType) error { - if len(b) > MaxMessageSize { +// Write writes the supplied buffer as a single message with the given type to the underlying stream. +// +// This call first flushes any pending control frames to the underlying stream. +// +// This call blocks until one of the following conditions is true: +// - an error occurs while flushing the pending control frames +// - an error occurs during the write +// - the message is successfully written to the underlying stream +func (s *Stream) Write(b []byte, messageType MessageType) error { + if len(b) > s.maxMessageSize { return ErrMessageTooBig } if s.state == StateActive { - f := AcquireFrame() - f.SetFin() - f.SetOpcode(Opcode(mt)) - f.SetPayload(b) - + // reserve space for mask if client + f := s.AcquireFrame(). + SetFIN(). + SetOpcode(Opcode(messageType)). + SetPayload(b) s.prepareWrite(f) return s.Flush() } @@ -463,82 +555,114 @@ func (s *WebsocketStream) Write(b []byte, mt MessageType) error { return sonicerrors.ErrCancelled } -func (s *WebsocketStream) WriteFrame(f *Frame) error { +// WriteFrame writes the supplied frame to the underlying stream. +// +// This call first flushes any pending control frames to the underlying stream. +// +// This call blocks until one of the following conditions is true: +// - an error occurs while flushing the pending control frames +// - an error occurs during the write +// - the frame is successfully written to the underlying stream +func (s *Stream) WriteFrame(f *Frame) error { if s.state == StateActive { s.prepareWrite(f) return s.Flush() } else { - ReleaseFrame(f) + s.releaseFrame(f) return sonicerrors.ErrCancelled } } -func (s *WebsocketStream) AsyncWrite( - b []byte, - mt MessageType, - cb func(err error), -) { - if len(b) > MaxMessageSize { - cb(ErrMessageTooBig) +// AsyncWrite writes the supplied buffer as a single message with the given type to the underlying stream +// asynchronously. +// +// This call first flushes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. The provided callback is invoked when one of the following happens: +// - an error occurs while flushing the pending control frames +// - an error occurs during the write +// - the message is successfully written to the underlying stream +func (s *Stream) AsyncWrite(b []byte, messageType MessageType, callback func(err error)) { + if len(b) > s.maxMessageSize { + callback(ErrMessageTooBig) return } if s.state == StateActive { - f := AcquireFrame() - f.SetFin() - f.SetOpcode(Opcode(mt)) - f.SetPayload(b) - + f := s.AcquireFrame(). + SetFIN(). + SetOpcode(Opcode(messageType)). + SetPayload(b) s.prepareWrite(f) - s.AsyncFlush(cb) + s.AsyncFlush(callback) } else { - cb(sonicerrors.ErrCancelled) + callback(sonicerrors.ErrCancelled) } } -func (s *WebsocketStream) AsyncWriteFrame(f *Frame, cb func(err error)) { +// AsyncWriteFrame writes the supplied frame to the underlying stream asynchronously. +// +// This call first flushes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. The provided callback is invoked when one of the following happens: +// - an error occurs while flushing the pending control frames +// - an error occurs during the write +// - the frame is successfully written to the underlying stream +func (s *Stream) AsyncWriteFrame(f *Frame, callback func(err error)) { if s.state == StateActive { s.prepareWrite(f) - s.AsyncFlush(cb) + s.AsyncFlush(callback) } else { - ReleaseFrame(f) - cb(sonicerrors.ErrCancelled) + s.releaseFrame(f) + callback(sonicerrors.ErrCancelled) } } -func (s *WebsocketStream) prepareWrite(f *Frame) { - switch s.role { - case RoleClient: - if !f.IsMasked() { - f.Mask() - } - case RoleServer: - if f.IsMasked() { - f.Unmask() - } +func (s *Stream) prepareWrite(f *Frame) { + if s.role == RoleClient { + f.MaskPayload() } - - s.pending = append(s.pending, f) + s.pendingFrames = append(s.pendingFrames, f) } -func (s *WebsocketStream) AsyncClose( - cc CloseCode, - reason string, - cb func(err error), -) { +// AsyncClose sends a websocket close control frame asynchronously. +// +// This function is used to send a close frame which begins the WebSocket closing handshake. The session ends when both +// ends of the connection have sent and received a close frame. +// +// The callback is called if one of the following conditions is true: +// - the close frame is written +// - an error occurs +// +// After beginning the closing handshake, the program should not write further messages, pings, pongs or close +// frames. Instead, the program should continue reading messages until the closing handshake is complete or an error +// occurs. +func (s *Stream) AsyncClose(closeCode CloseCode, reason string, callback func(err error)) { switch s.state { case StateActive: s.state = StateClosedByUs - s.prepareClose(EncodeCloseFramePayload(cc, reason)) - s.AsyncFlush(cb) + s.prepareClose(EncodeCloseFramePayload(closeCode, reason)) + s.AsyncFlush(callback) case StateClosedByUs, StateHandshake: - cb(sonicerrors.ErrCancelled) + callback(sonicerrors.ErrCancelled) default: - cb(io.EOF) + callback(io.EOF) } } -func (s *WebsocketStream) Close(cc CloseCode, reason string) error { +// Close sends a websocket close control frame asynchronously. +// +// This function is used to send a close frame which begins the WebSocket closing handshake. The session ends when both +// ends of the connection have sent and received a close frame. +// +// The call blocks until one of the following conditions is true: +// - the close frame is written +// - an error occurs +// +// After beginning the closing handshake, the program should not write further messages, pings, pongs or close +// frames. Instead, the program should continue reading messages until the closing handshake is complete or an error +// occurs. +func (s *Stream) Close(cc CloseCode, reason string) error { switch s.state { case StateActive: s.state = StateClosedByUs @@ -551,64 +675,71 @@ func (s *WebsocketStream) Close(cc CloseCode, reason string) error { } } -func (s *WebsocketStream) prepareClose(payload []byte) { - closeFrame := AcquireFrame() - closeFrame.SetFin() - closeFrame.SetClose() - closeFrame.SetPayload(payload) - if s.role == RoleClient { - closeFrame.Mask() - } - - s.pending = append(s.pending, closeFrame) +func (s *Stream) prepareClose(payload []byte) { + closeFrame := s.AcquireFrame(). + SetFIN(). + SetClose(). + SetPayload(payload) + s.prepareWrite(closeFrame) } -func (s *WebsocketStream) Flush() (err error) { +// Flush writes any pending control frames to the underlying stream. +// +// This call blocks. +func (s *Stream) Flush() (err error) { flushed := 0 - for i := 0; i < len(s.pending); i++ { - _, err = s.cs.WriteNext(s.pending[i]) + for i := 0; i < len(s.pendingFrames); i++ { + _, err = s.codecConn.WriteNext(*s.pendingFrames[i]) if err != nil { break } - ReleaseFrame(s.pending[i]) + s.releaseFrame(s.pendingFrames[i]) flushed++ } - s.pending = s.pending[flushed:] + s.pendingFrames = s.pendingFrames[flushed:] return } -func (s *WebsocketStream) AsyncFlush(cb func(err error)) { - if len(s.pending) == 0 { - cb(nil) +// Flush writes any pending control frames to the underlying stream asynchronously. +// +// This call does not block. +func (s *Stream) AsyncFlush(callback func(err error)) { + if len(s.pendingFrames) == 0 { + callback(nil) } else { - sent := s.pending[0] - s.pending = s.pending[1:] + sent := s.pendingFrames[0] + s.pendingFrames = s.pendingFrames[1:] - s.cs.AsyncWriteNext(sent, func(err error, _ int) { - ReleaseFrame(sent) + s.codecConn.AsyncWriteNext(*sent, func(err error, _ int) { + s.releaseFrame(sent) if err != nil { - cb(err) + callback(err) } else { - s.AsyncFlush(cb) + s.AsyncFlush(callback) } }) } } -func (s *WebsocketStream) Pending() int { - return len(s.pending) +// Pending returns the number of currently pending control frames waiting to be flushed. +func (s *Stream) Pending() int { + return len(s.pendingFrames) } -func (s *WebsocketStream) State() StreamState { +func (s *Stream) State() StreamState { return s.state } -func (s *WebsocketStream) Handshake( - addr string, - extraHeaders ...Header, -) (err error) { +// Handshake performs the client handshake. This call blocks. +// +// The call blocks until one of the following conditions is true: +// - the HTTP1.1 request is sent and the response is received +// - an error occurs +// +// Extra headers should be generated by calling `ExtraHeader(...)`. +func (s *Stream) Handshake(addr string, extraHeaders ...Header) (err error) { if s.role != RoleClient { return ErrWrongHandshakeRole } @@ -635,13 +766,15 @@ func (s *WebsocketStream) Handshake( return } -func (s *WebsocketStream) AsyncHandshake( - addr string, - cb func(error), - extraHeaders ...Header, -) { +// AsyncHandshake performs the WebSocket handshake asynchronously in the client role. +// +// This call does not block. The provided callback is called when the request is sent and the response is +// received or when an error occurs. +// +// Extra headers should be generated by calling `ExtraHeader(...)`. +func (s *Stream) AsyncHandshake(addr string, callback func(error), extraHeaders ...Header) { if s.role != RoleClient { - cb(ErrWrongHandshakeRole) + callback(ErrWrongHandshakeRole) return } @@ -659,31 +792,27 @@ func (s *WebsocketStream) AsyncHandshake( s.state = StateActive err = s.init(stream) } - cb(err) + callback(err) }) }) }() } -func (s *WebsocketStream) handshake( - addr string, - headers []Header, - cb func(err error, stream sonic.Stream), -) { +func (s *Stream) handshake(addr string, headers []Header, callback func(err error, stream sonic.Stream)) { url, err := s.resolve(addr) if err != nil { - cb(err, nil) + callback(err, nil) } else { s.dial(url, func(err error, stream sonic.Stream) { if err == nil { err = s.upgrade(url, stream, headers) } - cb(err, stream) + callback(err, stream) }) } } -func (s *WebsocketStream) resolve(addr string) (url *url.URL, err error) { +func (s *Stream) resolve(addr string) (url *url.URL, err error) { url, err = url.Parse(addr) if err == nil { switch url.Scheme { @@ -699,14 +828,10 @@ func (s *WebsocketStream) resolve(addr string) (url *url.URL, err error) { return } -func (s *WebsocketStream) dial( - url *url.URL, - cb func(err error, stream sonic.Stream), -) { +func (s *Stream) dial(url *url.URL, callback func(err error, stream sonic.Stream)) { var ( - err error - sc syscall.Conn - + err error + sc syscall.Conn port = url.Port() ) @@ -756,18 +881,14 @@ func (s *WebsocketStream) dial( // condition on the io context. sonic.NewAsyncAdapter( s.ioc, sc, s.conn, func(err error, stream *sonic.AsyncAdapter) { - cb(err, stream) + callback(err, stream) }, sonicopts.NoDelay(true)) } else { - cb(err, nil) + callback(err, nil) } } -func (s *WebsocketStream) upgrade( - uri *url.URL, - stream sonic.Stream, - headers []Header, -) error { +func (s *Stream) upgrade(uri *url.URL, stream sonic.Stream, headers []Header) error { req, err := http.NewRequest("GET", uri.String(), nil) if err != nil { return err @@ -794,8 +915,8 @@ func (s *WebsocketStream) upgrade( } } - if s.upReqCb != nil { - s.upReqCb(req) + if s.upgradeRequestCallback != nil { + s.upgradeRequestCallback(req) } err = req.Write(stream) @@ -803,13 +924,13 @@ func (s *WebsocketStream) upgrade( return err } - s.hb = s.hb[:cap(s.hb)] - n, err := stream.Read(s.hb) + s.handshakeBuffer = s.handshakeBuffer[:cap(s.handshakeBuffer)] + n, err := stream.Read(s.handshakeBuffer) if err != nil { return err } - s.hb = s.hb[:n] - rd := bytes.NewReader(s.hb) + s.handshakeBuffer = s.handshakeBuffer[:n] + rd := bytes.NewReader(s.handshakeBuffer) res, err := http.ReadResponse(bufio.NewReader(rd), req) if err != nil { return err @@ -821,17 +942,17 @@ func (s *WebsocketStream) upgrade( } resLen := len(rawRes) - extra := len(s.hb) - resLen + extra := len(s.handshakeBuffer) - resLen if extra > 0 { // we got some frames as well with the handshake so we can put // them in src for later decoding before clearing the handshake // buffer - _, _ = s.src.Write(s.hb[resLen:]) + _, _ = s.src.Write(s.handshakeBuffer[resLen:]) } - s.hb = s.hb[:0] + s.handshakeBuffer = s.handshakeBuffer[:0] - if s.upResCb != nil { - s.upResCb(res) + if s.upgradeResponseCallback != nil { + s.upgradeResponseCallback(res) } if !IsUpgradeRes(res) { @@ -847,7 +968,7 @@ func (s *WebsocketStream) upgrade( // makeHandshakeKey generates the key of Sec-WebSocket-Key header as well as the // expected response present in Sec-WebSocket-Accept header. -func (s *WebsocketStream) makeHandshakeKey() (req, res string) { +func (s *Stream) makeHandshakeKey() (req, res string) { // request b := make([]byte, 16) _, _ = rand.Read(b) @@ -865,61 +986,74 @@ func (s *WebsocketStream) makeHandshakeKey() (req, res string) { return } -func (s *WebsocketStream) Accept() error { - panic("implement me") +// SetControlCallback sets a function that will be invoked when a Ping/Pong/Close is received while reading a +// message. This callback is only invoked when reading complete messages, not frames. +// +// The caller must not perform any operations on the stream in the provided callback. +func (s *Stream) SetControlCallback(controlCallback ControlCallback) { + s.controlCallback = controlCallback } -func (s *WebsocketStream) AsyncAccept(func(error)) { - panic("implement me") +func (s *Stream) ControlCallback() ControlCallback { + return s.controlCallback } -func (s *WebsocketStream) SetControlCallback(ccb ControlCallback) { - s.ccb = ccb +// SetUpgradeRequestCallback sets a function that will be invoked during the handshake just before the upgrade request +// is sent. +// +// The caller must not perform any operations on the stream in the provided callback. +func (s *Stream) SetUpgradeRequestCallback(upgradeRequestCallback UpgradeRequestCallback) { + s.upgradeRequestCallback = upgradeRequestCallback } -func (s *WebsocketStream) ControlCallback() ControlCallback { - return s.ccb +func (s *Stream) UpgradeRequestCallback() UpgradeRequestCallback { + return s.upgradeRequestCallback } -func (s *WebsocketStream) SetUpgradeRequestCallback(upReqCb UpgradeRequestCallback) { - s.upReqCb = upReqCb +// SetUpgradeResponseCallback sets a function that will be invoked during the handshake just after the upgrade response +// is received. +// +// The caller must not perform any operations on the stream in the provided callback. +func (s *Stream) SetUpgradeResponseCallback(upgradeResponseCallback UpgradeResponseCallback) { + s.upgradeResponseCallback = upgradeResponseCallback } -func (s *WebsocketStream) UpgradeRequestCallback() UpgradeRequestCallback { - return s.upReqCb +func (s *Stream) UpgradeResponseCallback() UpgradeResponseCallback { + return s.upgradeResponseCallback } -func (s *WebsocketStream) SetUpgradeResponseCallback(upResCb UpgradeResponseCallback) { - s.upResCb = upResCb -} - -func (s *WebsocketStream) UpgradeResponseCallback() UpgradeResponseCallback { - return s.upResCb -} - -func (s *WebsocketStream) SetMaxMessageSize(bytes int) { +// SetMaxMessageSize sets the maximum size of a message that can be read from or written to a peer. +// +// - If a message exceeds the limit while reading, the connection is closed abnormally. +// - If a message exceeds the limit while writing, the operation is cancelled. +func (s *Stream) SetMaxMessageSize(bytes int) { // This is just for checking against the length returned in the frame // header. The sizes of the buffers in which we read or write the messages // are dynamically adjusted in frame_codec. - MaxMessageSize = bytes + s.maxMessageSize = bytes + s.codec.maxMessageSize = bytes +} + +func (s *Stream) MaxMessageSize() int { + return s.maxMessageSize } -func (s *WebsocketStream) RemoteAddr() net.Addr { +func (s *Stream) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *WebsocketStream) LocalAddr() net.Addr { +func (s *Stream) LocalAddr() net.Addr { return s.conn.LocalAddr() } -func (s *WebsocketStream) RawFd() int { +func (s *Stream) RawFd() int { if s.NextLayer() != nil { return s.NextLayer().(*sonic.AsyncAdapter).RawFd() } return -1 } -func (s *WebsocketStream) CloseNextLayer() (err error) { +func (s *Stream) CloseNextLayer() (err error) { if s.conn != nil { err = s.conn.Close() s.conn = nil diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index 851db655..b5e20ac8 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -9,15 +9,324 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/talostrading/sonic" ) -func assertState(t *testing.T, ws Stream, expected StreamState) { +func assertState(t *testing.T, ws *Stream, expected StreamState) { if ws.State() != expected { t.Fatalf("wrong state: given=%s expected=%s ", ws.State(), expected) } } +func TestClientServerSendsInvalidCloseCode(t *testing.T) { + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + { + frame := NewFrame() + + closeCode := CloseReserved1 + assert.False(ValidCloseCode(closeCode)) + + frame. + SetFIN(). + SetClose(). + SetPayload(EncodeCloseFramePayload(closeCode, "something")) + + frame.WriteTo(srv.conn) + } + + { + frame := NewFrame() + frame.ReadFrom(srv.conn) + + assert.True(frame.Opcode().IsClose()) + assert.True(frame.IsMasked()) // client to server frames are masked + frame.UnmaskPayload() + + closeCode, reason := DecodeCloseFramePayload(frame.Payload()) + assert.Equal(CloseProtocolError, closeCode) + assert.Equal(reason, "") + } + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + assert.Nil(err) + assert.Equal(1, ws.Pending()) + ws.Flush() + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + +func TestClientEchoCloseCode(t *testing.T) { + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + { + frame := NewFrame() + frame. + SetFIN(). + SetClose(). + SetPayload(EncodeCloseFramePayload(CloseNormal, "something")) + + frame.WriteTo(srv.conn) + } + + { + frame := NewFrame() + frame.ReadFrom(srv.conn) + + assert.True(frame.Opcode().IsClose()) + assert.True(frame.IsMasked()) // client to server frames are masked + frame.UnmaskPayload() + + closeCode, reason := DecodeCloseFramePayload(frame.Payload()) + assert.Equal(CloseNormal, closeCode) + assert.Equal(reason, "something") + } + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + assert.Nil(err) + assert.Equal(1, ws.Pending()) + ws.Flush() + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + +func TestClientSendPingWithInvalidPayload(t *testing.T) { + // Per the protocol, pings cannot have payloads larger than 125. We send a ping with 125. The client should close + // the connection immediately with 1002/Protocol Error. + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + // This ping has an invalid payload size of 126, which should trigger a close with reason 1002. + { + frame := NewFrame() + frame. + SetFIN(). + SetPing(). + SetPayload(make([]byte, 126)) + assert.Equal(126, frame.PayloadLength()) + assert.Equal(2, frame.ExtendedPayloadLengthBytes()) + + frame.WriteTo(srv.conn) + } + + // Ensure we get the close. + { + frame := NewFrame() + frame.ReadFrom(srv.conn) + + assert.True(frame.Opcode().IsClose()) + assert.True(frame.IsMasked()) // client to server frames are masked + frame.UnmaskPayload() + + closeCode, reason := DecodeCloseFramePayload(frame.Payload()) + assert.Equal(CloseProtocolError, closeCode) + assert.Empty(reason) + } + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + assert.NotNil(err) + assert.Equal(ErrControlFrameTooBig, err) + assert.Equal(1, ws.Pending()) + ws.Flush() + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + +func TestClientSendMessageWithPayload126(t *testing.T) { + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + frame := NewFrame() + frame. + SetFIN(). + SetText(). + SetPayload(make([]byte, 126)) + assert.Equal(126, frame.PayloadLength()) + assert.Equal(2, frame.ExtendedPayloadLengthBytes()) + + frame.WriteTo(srv.conn) + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + if err != nil { + t.Fatal(err) + } + assert.True(f.IsFIN()) + assert.True(f.Opcode().IsText()) + assert.Equal(126, f.PayloadLength()) + assert.Equal(2, f.ExtendedPayloadLengthBytes()) + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + +func TestClientSendMessageWithPayload127(t *testing.T) { + assert := assert.New(t) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:8080") + if err != nil { + panic(err) + } + + frame := NewFrame() + frame. + SetFIN(). + SetText(). + SetPayload(make([]byte, 1<<16+10 /* it won't fit in 2 bytes */)) + assert.Equal(1<<16+10, frame.PayloadLength()) + assert.Equal(8, frame.ExtendedPayloadLengthBytes()) + + frame.WriteTo(srv.conn) + }() + time.Sleep(10 * time.Millisecond) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + if err != nil { + t.Fatal(err) + } + + done := false + ws.AsyncHandshake("ws://localhost:8080", func(err error) { + if err != nil { + t.Fatal(err) + } + ws.AsyncNextFrame(func(err error, f Frame) { + if err != nil { + t.Fatal(err) + } + assert.True(f.IsFIN()) + assert.True(f.Opcode().IsText()) + assert.Equal(1<<16+10, f.PayloadLength()) + assert.Equal(8, f.ExtendedPayloadLengthBytes()) + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} + func TestClientReconnectOnFailedRead(t *testing.T) { go func() { for i := 0; i < 10; i++ { @@ -463,10 +772,10 @@ func TestClientReadCorruptControlFrame(t *testing.T) { assertState(t, ws, StateClosedByUs) - closeFrame := ws.pending[0] - closeFrame.Unmask() + closeFrame := ws.pendingFrames[0] + closeFrame.UnmaskPayload() - cc, _ := DecodeCloseFramePayload(ws.pending[0].payload) + cc, _ := DecodeCloseFramePayload(ws.pendingFrames[0].Payload()) if cc != CloseProtocolError { t.Fatal("should have closed with protocol error") } @@ -507,10 +816,10 @@ func TestClientAsyncReadCorruptControlFrame(t *testing.T) { assertState(t, ws, StateClosedByUs) - closeFrame := ws.pending[0] - closeFrame.Unmask() + closeFrame := ws.pendingFrames[0] + closeFrame.UnmaskPayload() - cc, _ := DecodeCloseFramePayload(ws.pending[0].payload) + cc, _ := DecodeCloseFramePayload(ws.pendingFrames[0].Payload()) if cc != CloseProtocolError { t.Fatal("should have closed with protocol error") } @@ -550,12 +859,12 @@ func TestClientReadPingFrame(t *testing.T) { t.Fatal("should have a pending pong") } - reply := ws.pending[0] - if !(reply.IsPong() && reply.IsMasked()) { + reply := ws.pendingFrames[0] + if !(reply.Opcode().IsPong() && reply.IsMasked()) { t.Fatal("invalid pong reply") } - reply.Unmask() + reply.UnmaskPayload() if !bytes.Equal(reply.Payload(), []byte{0x01, 0x02}) { t.Fatal("invalid pong reply") } @@ -603,12 +912,12 @@ func TestClientAsyncReadPingFrame(t *testing.T) { t.Fatal("should have a pending pong") } - reply := ws.pending[0] - if !(reply.IsPong() && reply.IsMasked()) { + reply := ws.pendingFrames[0] + if !(reply.Opcode().IsPong() && reply.IsMasked()) { t.Fatal("invalid pong reply") } - reply.Unmask() + reply.UnmaskPayload() if !bytes.Equal(reply.Payload(), []byte{0x01, 0x02}) { t.Fatal("invalid pong reply") } @@ -764,13 +1073,13 @@ func TestClientReadCloseFrame(t *testing.T) { t.Fatal("should have one pending operation") } - reply := ws.pending[0] + reply := ws.pendingFrames[0] if !reply.IsMasked() { t.Fatal("reply should be masked") } - reply.Unmask() + reply.UnmaskPayload() - cc, reason := DecodeCloseFramePayload(reply.payload) + cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("invalid close frame reply") } @@ -781,7 +1090,7 @@ func TestClientReadCloseFrame(t *testing.T) { b := make([]byte, 128) _, _, err = ws.NextMessage(b) - if len(ws.pending) > 0 { + if len(ws.pendingFrames) > 0 { t.Fatal("should have flushed") } @@ -827,13 +1136,13 @@ func TestClientAsyncReadCloseFrame(t *testing.T) { t.Fatal("should have one pending operation") } - reply := ws.pending[0] + reply := ws.pendingFrames[0] if !reply.IsMasked() { t.Fatal("reply should be masked") } - reply.Unmask() + reply.UnmaskPayload() - cc, reason := DecodeCloseFramePayload(reply.payload) + cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("invalid close frame reply") } @@ -874,31 +1183,33 @@ func TestClientWriteFrame(t *testing.T) { ws.state = StateActive ws.init(mock) - f := AcquireFrame() - defer ReleaseFrame(f) - f.SetFin() + f := ws.AcquireFrame() + f.SetFIN() f.SetText() f.SetPayload([]byte{1, 2, 3, 4, 5}) + if len(*f) != 2 /*mandatory header*/ +4 /*mask since it's written by a client*/ +5 /*payload length*/ { + t.Fatal("invalid frame length") + } + err = ws.WriteFrame(f) if err != nil { t.Fatal(err) } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) + f := Frame(make([]byte, 2+4+5)) _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsText()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsText()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if !bytes.Equal(f.Payload(), []byte{1, 2, 3, 4, 5}) { t.Fatal("frame payload is corrupt, something went wrong with the encoder") @@ -921,9 +1232,8 @@ func TestClientAsyncWriteFrame(t *testing.T) { ws.state = StateActive ws.init(mock) - f := AcquireFrame() - defer ReleaseFrame(f) - f.SetFin() + f := ws.AcquireFrame() + f.SetFIN() f.SetText() f.SetPayload([]byte{1, 2, 3, 4, 5}) @@ -937,19 +1247,18 @@ func TestClientAsyncWriteFrame(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsText()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsText()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if !bytes.Equal(f.Payload(), []byte{1, 2, 3, 4, 5}) { t.Fatal("frame payload is corrupt, something went wrong with the encoder") @@ -983,19 +1292,18 @@ func TestClientWrite(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsText()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsText()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if !bytes.Equal(f.Payload(), []byte{1, 2, 3, 4, 5}) { t.Fatal("frame payload is corrupt, something went wrong with the encoder") @@ -1024,19 +1332,17 @@ func TestClientAsyncWrite(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) - + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsText()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsText()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() if !bytes.Equal(f.Payload(), []byte{1, 2, 3, 4, 5}) { t.Fatal("frame payload is corrupt, something went wrong with the encoder") @@ -1066,25 +1372,23 @@ func TestClientClose(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) - + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsClose()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsClose()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() - if f.PayloadLen() != 5 { + if f.PayloadLength() != 5 { t.Fatal("wrong message in close frame") } - cc, reason := DecodeCloseFramePayload(f.payload) + cc, reason := DecodeCloseFramePayload(f.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("wrong close frame payload") } @@ -1113,25 +1417,23 @@ func TestClientAsyncClose(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - f := AcquireFrame() - defer ReleaseFrame(f) - + f := NewFrame() _, err = f.ReadFrom(mock.b) if err != nil { t.Fatal(err) } - if !(f.IsFin() && f.IsMasked() && f.IsClose()) { + if !(f.IsFIN() && f.IsMasked() && f.Opcode().IsClose()) { t.Fatal("frame is corrupt, something went wrong with the encoder") } - f.Unmask() + f.UnmaskPayload() - if f.PayloadLen() != 5 { + if f.PayloadLength() != 5 { t.Fatal("wrong message in close frame") } - cc, reason := DecodeCloseFramePayload(f.payload) + cc, reason := DecodeCloseFramePayload(f.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("wrong close frame payload") } @@ -1160,14 +1462,12 @@ func TestClientCloseHandshakeWeStart(t *testing.T) { if err != nil { t.Fatal(err) } else { - mock.b.Commit(mock.b.WriteLen()) - - serverReply := AcquireFrame() - defer ReleaseFrame(serverReply) - assertState(t, ws, StateClosedByUs) - serverReply.SetFin() + mock.b.Commit(mock.b.WriteLen()) + + serverReply := NewFrame() + serverReply.SetFIN() serverReply.SetClose() serverReply.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) _, err = serverReply.WriteTo(ws.src) @@ -1180,11 +1480,11 @@ func TestClientCloseHandshakeWeStart(t *testing.T) { t.Fatal(err) } - if !(reply.IsFin() && reply.IsClose()) { + if !(reply.IsFIN() && reply.Opcode().IsClose()) { t.Fatal("wrong close reply") } - cc, reason := DecodeCloseFramePayload(reply.payload) + cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("wrong close frame payload reply") } @@ -1215,10 +1515,8 @@ func TestClientAsyncCloseHandshakeWeStart(t *testing.T) { } else { mock.b.Commit(mock.b.WriteLen()) - serverReply := AcquireFrame() - defer ReleaseFrame(serverReply) - - serverReply.SetFin() + serverReply := NewFrame() + serverReply.SetFIN() serverReply.SetClose() serverReply.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) _, err = serverReply.WriteTo(ws.src) @@ -1231,11 +1529,11 @@ func TestClientAsyncCloseHandshakeWeStart(t *testing.T) { t.Fatal(err) } - if !(reply.IsFin() && reply.IsClose()) { + if !(reply.IsFIN() && reply.Opcode().IsClose()) { t.Fatal("wrong close reply") } - cc, reason := DecodeCloseFramePayload(reply.payload) + cc, reason := DecodeCloseFramePayload(reply.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("wrong close frame payload reply") } @@ -1262,9 +1560,8 @@ func TestClientCloseHandshakePeerStarts(t *testing.T) { ws.state = StateActive ws.init(mock) - serverClose := AcquireFrame() - defer ReleaseFrame(serverClose) - serverClose.SetFin() + serverClose := NewFrame() + serverClose.SetFIN() serverClose.SetClose() serverClose.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1280,13 +1577,13 @@ func TestClientCloseHandshakePeerStarts(t *testing.T) { t.Fatal(err) } - if !recv.IsClose() { + if !recv.Opcode().IsClose() { t.Fatal("should have received close") } assertState(t, ws, StateClosedByPeer) - cc, reason := DecodeCloseFramePayload(recv.payload) + cc, reason := DecodeCloseFramePayload(recv.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("peer close frame payload is corrupt") } @@ -1309,9 +1606,8 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { ws.state = StateActive ws.init(mock) - serverClose := AcquireFrame() - defer ReleaseFrame(serverClose) - serverClose.SetFin() + serverClose := NewFrame() + serverClose.SetFIN() serverClose.SetClose() serverClose.SetPayload(EncodeCloseFramePayload(CloseNormal, "bye")) @@ -1322,18 +1618,18 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { ws.src.Commit(int(nn)) - ws.AsyncNextFrame(func(err error, recv *Frame) { + ws.AsyncNextFrame(func(err error, recv Frame) { if err != nil { t.Fatal(err) } - if !recv.IsClose() { + if !recv.Opcode().IsClose() { t.Fatal("should have received close") } assertState(t, ws, StateClosedByPeer) - cc, reason := DecodeCloseFramePayload(recv.payload) + cc, reason := DecodeCloseFramePayload(recv.Payload()) if !(cc == CloseNormal && reason == "bye") { t.Fatal("peer close frame payload is corrupt") } diff --git a/codec/websocket/test_main.go b/codec/websocket/test_main.go index f11ba6f9..da91cf44 100644 --- a/codec/websocket/test_main.go +++ b/codec/websocket/test_main.go @@ -73,30 +73,26 @@ func (s *MockServer) Accept(addr string) (err error) { } func (s *MockServer) Write(b []byte) error { - fr := AcquireFrame() - defer ReleaseFrame(fr) - - fr.SetText() - fr.SetPayload(b) - fr.SetFin() - - _, err := fr.WriteTo(s.conn) + f := NewFrame() + f.SetText() + f.SetPayload(b) + f.SetFIN() + _, err := f.WriteTo(s.conn) return err } func (s *MockServer) Read(b []byte) (n int, err error) { - fr := AcquireFrame() - defer ReleaseFrame(fr) + f := NewFrame() - _, err = fr.ReadFrom(s.conn) + _, err = f.ReadFrom(s.conn) if err == nil { - if !fr.IsMasked() { + if !f.IsMasked() { return 0, fmt.Errorf("client frames should be masked") } - fr.Unmask() - copy(b, fr.Payload()) - n = fr.PayloadLen() + f.UnmaskPayload() + copy(b, f.Payload()) + n = f.PayloadLength() } return n, err } diff --git a/codec_test.go b/codec_test.go index fd08b3ab..a4df042a 100644 --- a/codec_test.go +++ b/codec_test.go @@ -23,7 +23,7 @@ func (t *TestCodec) Encode(item TestItem, dst *ByteBuffer) error { n, err := dst.Write(item.V[:]) dst.Commit(n) if err != nil { - dst.Consume(n) // TODO not really happy about this (same for websocket) + dst.Consume(n) return err } return nil @@ -113,7 +113,7 @@ func setupCodecTestWriter() chan struct{} { return mark } -func TestNonblockingCodecConnAsyncReadNext(t *testing.T) { +func TestCodecConnAsyncReadNext(t *testing.T) { mark := setupCodecTestWriter() defer func() { <-mark /* wait for the listener to close*/ }() <-mark // wait for the listener to open @@ -129,7 +129,7 @@ func TestNonblockingCodecConnAsyncReadNext(t *testing.T) { src := NewByteBuffer() dst := NewByteBuffer() - codecConn, err := NewNonblockingCodecConn[TestItem, TestItem]( + codecConn, err := NewCodecConn[TestItem, TestItem]( conn, &TestCodec{}, src, dst) if err != nil { t.Fatal(err) @@ -164,7 +164,7 @@ func TestNonblockingCodecConnAsyncReadNext(t *testing.T) { } } -func TestNonblockingCodecConnReadNext(t *testing.T) { +func TestCodecConnReadNext(t *testing.T) { mark := setupCodecTestWriter() defer func() { <-mark /* wait for the listener to close*/ }() <-mark // wait for the listener to open @@ -180,7 +180,7 @@ func TestNonblockingCodecConnReadNext(t *testing.T) { src := NewByteBuffer() dst := NewByteBuffer() - codecConn, err := NewNonblockingCodecConn[TestItem, TestItem]( + codecConn, err := NewCodecConn[TestItem, TestItem]( conn, &TestCodec{}, src, dst) if err != nil { t.Fatal(err) @@ -283,7 +283,7 @@ func setupCodecTestReader() chan struct{} { return mark } -func TestNonblockingCodecConnAsyncWriteNext(t *testing.T) { +func TestCodecConnAsyncWriteNext(t *testing.T) { mark := setupCodecTestReader() defer func() { <-mark /* wait for the listener to close*/ }() <-mark // wait for the listener to open @@ -298,7 +298,7 @@ func TestNonblockingCodecConnAsyncWriteNext(t *testing.T) { src := NewByteBuffer() dst := NewByteBuffer() - codecConn, err := NewNonblockingCodecConn[TestItem, TestItem]( + codecConn, err := NewCodecConn[TestItem, TestItem]( conn, &TestCodec{}, src, dst) if err != nil { t.Fatal(err) @@ -334,7 +334,7 @@ func TestNonblockingCodecConnAsyncWriteNext(t *testing.T) { } } -func TestNonblockingCodecConnWriteNext(t *testing.T) { +func TestCodecConnWriteNext(t *testing.T) { mark := setupCodecTestReader() defer func() { <-mark /* wait for the listener to close*/ }() <-mark // wait for the listener to open @@ -350,7 +350,7 @@ func TestNonblockingCodecConnWriteNext(t *testing.T) { src := NewByteBuffer() dst := NewByteBuffer() - codecConn, err := NewNonblockingCodecConn[TestItem, TestItem]( + codecConn, err := NewCodecConn[TestItem, TestItem]( conn, &TestCodec{}, src, dst) if err != nil { t.Fatal(err) diff --git a/conn.go b/conn.go index 2290f4db..d0395404 100644 --- a/conn.go +++ b/conn.go @@ -1,7 +1,6 @@ package sonic import ( - "fmt" "net" "time" @@ -14,12 +13,11 @@ var _ Conn = &conn{} type conn struct { *file - fd int localAddr net.Addr remoteAddr net.Addr } -// Dial establishes a stream based connection to the specified address. +// Dial establishes a stream based connection to the specified address. It is similar to `net.Dial`. // // Data can be sent or received only from the specified address for all networks: tcp, udp and unix domain sockets. func Dial( @@ -30,6 +28,7 @@ func Dial( return DialTimeout(ioc, network, addr, 10*time.Second, opts...) } +// `DialTimeout` is like `Dial` but with a timeout. func DialTimeout( ioc *IO, network, addr string, timeout time.Duration, @@ -49,8 +48,7 @@ func newConn( localAddr, remoteAddr net.Addr, ) *conn { return &conn{ - file: &file{ioc: ioc, slot: internal.Slot{Fd: fd}}, - fd: fd, + file: newFile(ioc, fd), localAddr: localAddr, remoteAddr: remoteAddr, } @@ -64,15 +62,15 @@ func (c *conn) RemoteAddr() net.Addr { } func (c *conn) SetDeadline(t time.Time) error { - return fmt.Errorf("not supported") + panic("not implemented") } func (c *conn) SetReadDeadline(t time.Time) error { - return fmt.Errorf("not supported") + panic("not implemented") } func (c *conn) SetWriteDeadline(t time.Time) error { - return fmt.Errorf("not supported") + panic("not implemented") } func (c *conn) RawFd() int { - return c.fd + return c.file.slot.Fd } diff --git a/conn_test.go b/conn_test.go index 9ffd439c..69ff96df 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/talostrading/sonic/sonicopts" ) @@ -349,3 +350,121 @@ func TestConnWriteHandlesError(t *testing.T) { t.Fatal("test did not run to completion") } } + +func TestDispatchLimit(t *testing.T) { + // Since the `ioc.Dispatched` counter is shared amongst, we test if it updated correctly when two connections build + // up each other's stack-frames. As in, when connection 1 has an immediate async read, it invokes connection 2's + // async read which is immediate and when done that invokes connection 1's async read which is immediate and so + // on. We ensure we reach the `MaxCallbackDispatch` limit with this sequence of reads. We then assert that + // `ioc.Dispatched` ends up 0 at the end of all operations. + + var ( + writtenBytes = MaxCallbackDispatch * 2 + readBytes = 0 + ) + + assert := assert.New(t) + + ioc := MustIO() + defer ioc.Close() + + assert.Equal(0, ioc.Dispatched) + + // setup up the server which will write to both reading connections + ln, err := net.Listen("tcp", "localhost:0") + assert.Nil(err) + addr := ln.Addr().String() + + marker := make(chan struct{}, 10) + go func() { + marker <- struct{}{} + + for { + conn, err := ln.Accept() + assert.Nil(err) + + go func() { + // These bytes will get cached by the TCP layer. We thus ensure that each connection has at least one + // series of `MaxCallbackDispatch` immediate asynchronous reads. + for i := 0; i < writtenBytes; i++ { + n, err := conn.Write([]byte("1")) + assert.Nil(err) + assert.Equal(1, n) + } + + <-marker + conn.Close() + }() + } + }() + <-marker + + limitHit := 0 // counts how many times each connection has hit the `MaxCallbackDispatch` limit + + var b [1]byte // shared between the two reading connections + + rd1, err := Dial(ioc, "tcp", addr) + assert.Nil(err) + defer rd1.Close() + + rd2, err := Dial(ioc, "tcp", addr) + assert.Nil(err) + defer rd2.Close() + + var ( + onRead1, onRead2 AsyncCallback + ) + onRead1 = func(err error, n int) { + if err != nil && err != io.EOF { + t.Fatal(err) + } + if err == io.EOF { + return + } + + readBytes += n + + assert.True(ioc.Dispatched <= MaxCallbackDispatch) + if ioc.Dispatched == MaxCallbackDispatch { + limitHit++ + } + + // sleeping ensures the writer is faster than the reader so we build up stack frames by having each async read + // complete immediately + time.Sleep(time.Millisecond) + + // invoke the other reading connection's async read to ensure `ioc.Dispatched` is correctly updated when shared + // between asynchronous objects + rd2.AsyncRead(b[:], onRead2) + } + + onRead2 = func(err error, n int) { + if err != nil && err != io.EOF { + t.Fatal(err) + } + if err == io.EOF { + return + } + + readBytes += n + + assert.True(ioc.Dispatched <= MaxCallbackDispatch) + if ioc.Dispatched == MaxCallbackDispatch { + limitHit++ + } + + time.Sleep(time.Millisecond) + rd1.AsyncRead(b[:], onRead1) // the other connection + } + + rd1.AsyncRead(b[:], onRead1) // starting point + + for readBytes < writtenBytes*2 /* two connections */ { + ioc.PollOne() + } + + assert.True(limitHit > 0) + assert.Equal(0, ioc.Dispatched) + + marker <- struct{}{} // to close the write end +} diff --git a/definitions.go b/definitions.go index f9e0f9c6..289b847a 100644 --- a/definitions.go +++ b/definitions.go @@ -5,61 +5,42 @@ import ( "net" ) -const ( - // MaxCallbackDispatch is the maximum number of callbacks which can be - // placed onto the stack for immediate invocation. - MaxCallbackDispatch int = 32 -) - -// TODO this is quite a mess right now. Properly define what a Conn, Stream, -// PacketConn is and what should the async adapter return, as more than TCP -// streams can be "async-adapted". - type AsyncCallback func(error, int) type AcceptCallback func(error, Conn) type AcceptPacketCallback func(error, PacketConn) // AsyncReader is the interface that wraps the AsyncRead and AsyncReadAll methods. type AsyncReader interface { - // AsyncRead reads up to len(b) bytes into b asynchronously. + // AsyncRead reads up to `len(b)` bytes into `b` asynchronously. // - // This call should not block. The provided completion handler is called - // in the following cases: - // - a read of n bytes completes - // - an error occurs - // - // Callers should always process the n > 0 bytes returned before considering - // the error err. + // This call should not block. The provided completion handler is called in the following cases: + // - a read of up to `n` bytes completes + // - an error occurs // - // Implementation of AsyncRead are discouraged from invoking the handler - // with a zero byte count with a nil error, except when len(b) == 0. - // Callers should treat a return of 0 and nil as indicating that nothing - // happened; in particular, it does not indicate EOF. + // Callers should always process the returned bytes before considering the error err. Implementations of AsyncRead + // are discouraged from invoking the handler with 0 bytes and a nil error. // - // Implementations must not retain b. Ownership of b must be retained by the caller, - // which must guarantee that it remains valid until the handler is called. + // Ownership of the byte slice must be retained by callers, which must guarantee that it remains valid until the + // callback is invoked. AsyncRead(b []byte, cb AsyncCallback) - // AsyncReadAll reads len(b) bytes into b asynchronously. + // AsyncReadAll reads exactly `len(b)` bytes into b asynchronously. AsyncReadAll(b []byte, cb AsyncCallback) } -// AsyncWriter is the interface that wraps the AsyncRead and AsyncReadAll methods. +// AsyncWriter is the interface that wraps the AsyncWrite and AsyncWriteAll methods. type AsyncWriter interface { - // AsyncWrite writes up to len(b) bytes into the underlying data stream asynchronously. + // AsyncWrite writes up to `len(b)` bytes from `b` asynchronously. // - // This call should not block. The provided completion handler is called in the following cases: - // - a write of n bytes completes + // This call does not block. The provided completion handler is called in the following cases: + // - a write of up to `n` bytes completes // - an error occurs // - // AsyncWrite must provide a non-nil error if it writes n < len(b) bytes. - // - // Implementations must not retain b. Ownership of b must be retained by the caller, - // which must guarantee that it remains valid until the handler is called. - // AsyncWrite must not modify b, even temporarily. + // Ownership of the byte slice must be retained by callers, which must guarantee that it remains valid until the + // callback is invoked. This function does not modify the given byte slice, even temporarily. AsyncWrite(b []byte, cb AsyncCallback) - // AsyncWriteAll writes len(b) bytes into the underlying data stream asynchronously. + // AsyncWriteAll writes exactly `len(b)` bytes into the underlying data stream asynchronously. AsyncWriteAll(b []byte, cb AsyncCallback) } @@ -95,9 +76,8 @@ type AsyncCanceller interface { Cancel() } -// Stream represents a full-duplex connection between two processes, -// where data represented as bytes may be received reliably in the same order -// they were written. +// Stream represents a full-duplex connection between two processes, where data represented as bytes may be received +// reliably in the same order they were written. type Stream interface { RawFd() int @@ -180,32 +160,3 @@ type Listener interface { RawFd() int } - -// UDPMulticastClient defines a UDP multicast client that can read data from one or multiple multicast groups, -// optionally filtering packets on the source IP. -type UDPMulticastClient interface { - Join(multicastAddr *net.UDPAddr) error - JoinSource(multicastAddr, sourceAddr *net.UDPAddr) error - - Leave(multicastAddr *net.UDPAddr) error - LeaveSource(multicastAddr, sourceAddr *net.UDPAddr) error - - BlockSource(multicastAddr, sourceAddr *net.UDPAddr) error - UnblockSource(multicastAddr, sourceAddr *net.UDPAddr) error - - // ReadFrom and AsyncReadFrom read a partial or complete datagram into the provided buffer. When the datagram - // is smaller than the passed buffer, only that much data is returned; when it is bigger, the packet is truncated - // to fit into the buffer. The truncated bytes are lost. - // - // It is the responsibility of the caller to ensure the passed buffer can hold an entire datagram. A rule of thumb - // is to have the buffer size equal to the network interface's MTU. - ReadFrom([]byte) (n int, addr net.Addr, err error) - AsyncReadFrom([]byte, AsyncReadCallbackPacket) - - RawFd() int - Interface() *net.Interface - LocalAddr() *net.UDPAddr - - Close() error - Closed() bool -} diff --git a/examples/binance/main.go b/examples/binance/main.go index c0537f65..b5dfe15e 100644 --- a/examples/binance/main.go +++ b/examples/binance/main.go @@ -2,73 +2,96 @@ package main import ( "crypto/tls" + "flag" "fmt" + "time" "github.com/talostrading/sonic" "github.com/talostrading/sonic/codec/websocket" + "github.com/talostrading/sonic/util" ) -var subscriptionMessage = []byte( - ` +var ( + verbose = flag.Bool("v", false, "if true, websocket messages are printed") + + subscriptionMessage = []byte(` { "id": 1, "method": "SUBSCRIBE", "params": [ "bnbbtc@depth" ] } `) +) -var b = make([]byte, 512*1024) // contains websocket payloads +func main() { + flag.Parse() -func run(stream websocket.Stream) { - stream.AsyncHandshake("wss://stream.binance.com:9443/ws", func(err error) { - onHandshake(err, stream) - }) -} + ioc := sonic.MustIO() + defer ioc.Close() -func onHandshake(err error, stream websocket.Stream) { - if err != nil { - panic(err) - } else { - stream.AsyncWrite(subscriptionMessage, websocket.TypeText, func(err error) { - onWrite(err, stream) - }) - } -} + ioLatency := util.NewOnlineStats() -func onWrite(err error, stream websocket.Stream) { + stream, err := websocket.NewWebsocketStream(ioc, &tls.Config{}, websocket.RoleClient) if err != nil { panic(err) - } else { - readLoop(stream) } -} -func readLoop(stream websocket.Stream) { - var onRead websocket.AsyncMessageHandler - onRead = func(err error, n int, _ websocket.MessageType) { + stream.AsyncHandshake("wss://stream.binance.com:9443/ws", func(err error) { if err != nil { panic(err) - } else { - b = b[:n] - fmt.Println(string(b)) - b = b[:cap(b)] - - stream.AsyncNextMessage(b, onRead) } - } - stream.AsyncNextMessage(b, onRead) -} -func main() { - ioc := sonic.MustIO() - defer ioc.Close() + stream.AsyncWrite(subscriptionMessage, websocket.TypeText, func(err error) { + if err != nil { + panic(err) + } - stream, err := websocket.NewWebsocketStream(ioc, &tls.Config{}, websocket.RoleClient) + var ( + b [1024 * 512]byte + onRead websocket.AsyncMessageCallback + ) + onRead = func(err error, n int, _ websocket.MessageType) { + if err != nil { + panic(err) + } + + if *verbose { + fmt.Println(string(b[:n])) + } + stream.AsyncNextMessage(b[:], onRead) + } + stream.AsyncNextMessage(b[:], onRead) + }) + }) + + eventsReceived := 0 + + ioLatencyTimer, err := sonic.NewTimer(ioc) if err != nil { panic(err) } + ioLatencyTimer.ScheduleRepeating(time.Second, func() { + if eventsReceived > 0 { + result := ioLatency.Result() + fmt.Printf( + "min/avg/max/stddev = %.2f/%.2f/%.2f/%.2f us from %d events\n", + result.Min, + result.Avg, + result.Max, + result.StdDev, + eventsReceived, + ) + ioLatency.Reset() + eventsReceived = 0 + } + }) - run(stream) - - ioc.Run() + for { + start := time.Now() + n, _ := ioc.PollOne() + if n > 0 { + eventsReceived += n + ioLatency.Add(float64(time.Now().Sub(start).Microseconds())) + } + } } diff --git a/examples/okex/main.go b/examples/okex/main.go index d9dc726f..36dfb8c0 100644 --- a/examples/okex/main.go +++ b/examples/okex/main.go @@ -21,48 +21,6 @@ var subscriptionMessage = []byte( } `) -var b = make([]byte, 512*1024) // contains websocket payloads - -func run(stream websocket.Stream) { - stream.AsyncHandshake("wss://ws.okx.com:8443/ws/v5/public", func(err error) { - onHandshake(err, stream) - }) -} - -func onHandshake(err error, stream websocket.Stream) { - if err != nil { - panic(err) - } else { - stream.AsyncWrite(subscriptionMessage, websocket.TypeText, func(err error) { - onWrite(err, stream) - }) - } -} - -func onWrite(err error, stream websocket.Stream) { - if err != nil { - panic(err) - } else { - readLoop(stream) - } -} - -func readLoop(stream websocket.Stream) { - var onRead websocket.AsyncMessageHandler - onRead = func(err error, n int, _ websocket.MessageType) { - if err != nil { - panic(err) - } else { - b = b[:n] - fmt.Println(string(b)) - b = b[:cap(b)] - - stream.AsyncNextMessage(b, onRead) - } - } - stream.AsyncNextMessage(b, onRead) -} - func main() { ioc := sonic.MustIO() defer ioc.Close() @@ -72,7 +30,31 @@ func main() { panic(err) } - run(stream) + stream.AsyncHandshake("wss://ws.okx.com:8443/ws/v5/public", func(err error) { + if err != nil { + panic(err) + } + + stream.AsyncWrite(subscriptionMessage, websocket.TypeText, func(err error) { + if err != nil { + panic(err) + } + + var ( + b [1024 * 512]byte + onRead websocket.AsyncMessageCallback + ) + onRead = func(err error, n int, _ websocket.MessageType) { + if err != nil { + panic(err) + } + + fmt.Println(string(b[:n])) + stream.AsyncNextMessage(b[:], onRead) + } + stream.AsyncNextMessage(b[:], onRead) + }) + }) ioc.Run() } diff --git a/examples/websocket/client.go b/examples/websocket/client.go index db666ce2..83a413da 100644 --- a/examples/websocket/client.go +++ b/examples/websocket/client.go @@ -19,26 +19,21 @@ func main() { client.AsyncHandshake("ws://localhost:8080", func(err error) { if err != nil { panic(err) - } else { - client.AsyncWrite([]byte("hello"), websocket.TypeText, func(err error) { + } + client.AsyncWrite([]byte("hello"), websocket.TypeText, func(err error) { + if err != nil { + panic(err) + } + + var b [128]byte + client.AsyncNextMessage(b[:], func(err error, n int, mt websocket.MessageType) { if err != nil { panic(err) - } else { - b := make([]byte, 128) - client.AsyncNextMessage(b, func(err error, n int, mt websocket.MessageType) { - if err != nil { - panic(err) - } else { - b = b[:n] - fmt.Println("read", n, "bytes", string(b), err) - } - }) } + fmt.Println("read", n, "bytes", string(b[:n]), err) }) - } + }) }) - for { - ioc.RunOneFor(0) // poll - } + ioc.Run() } diff --git a/examples/websocket/server.go b/examples/websocket/server.go deleted file mode 100644 index ea61f64b..00000000 --- a/examples/websocket/server.go +++ /dev/null @@ -1,5 +0,0 @@ -package main - -func main() { - panic("implement me") -} diff --git a/file.go b/file.go index 610454d1..889712f8 100644 --- a/file.go +++ b/file.go @@ -13,29 +13,88 @@ import ( var _ File = &file{} type file struct { - ioc *IO - slot internal.Slot - closed uint32 - - // dispatched tracks how callback are currently on the stack. - // If the fd has a lot of data to read/write and the caller nests - // read/write calls then we might overflow the stack. In order to not do that - // we limit the number of dispatched reads to MaxCallbackDispatch. - // If we hit that limit, we schedule an async read/write which results in clearing the stack. - dispatched int + ioc *IO + slot internal.Slot + closed uint32 + readReactor fileReadReactor + writeReactor fileWriteReactor } -func Open(ioc *IO, path string, flags int, mode os.FileMode) (File, error) { - fd, err := syscall.Open(path, flags, uint32(mode)) +type fileReadReactor struct { + file *file + + b []byte + readAll bool + cb AsyncCallback + readSoFar int +} + +func (r *fileReadReactor) init(b []byte, readAll bool, cb AsyncCallback) { + r.b = b + r.readAll = readAll + r.cb = cb + + r.readSoFar = 0 +} + +func (r *fileReadReactor) onRead(err error) { + r.file.ioc.Deregister(&r.file.slot) if err != nil { - return nil, err + r.cb(err, r.readSoFar) + } else { + r.file.asyncReadNow(r.b, r.readSoFar, r.readAll, r.cb) + } +} + +type fileWriteReactor struct { + file *file + + b []byte + writeAll bool + cb AsyncCallback + wroteSoFar int +} + +func (r *fileWriteReactor) init(b []byte, writeAll bool, cb AsyncCallback) { + r.b = b + r.writeAll = writeAll + r.cb = cb + + r.wroteSoFar = 0 +} + +func (r *fileWriteReactor) onWrite(err error) { + r.file.ioc.Deregister(&r.file.slot) + if err != nil { + r.cb(err, r.wroteSoFar) + } else { + r.file.asyncWriteNow(r.b, r.wroteSoFar, r.writeAll, r.cb) } +} +func newFile(ioc *IO, fd int) *file { f := &file{ ioc: ioc, slot: internal.Slot{Fd: fd}, } - return f, nil + atomic.StoreUint32(&f.closed, 0) + + f.readReactor = fileReadReactor{file: f} + f.readReactor.init(nil, false, nil) + + f.writeReactor = fileWriteReactor{file: f} + f.writeReactor.init(nil, false, nil) + + return f +} + +func Open(ioc *IO, path string, flags int, mode os.FileMode) (File, error) { + fd, err := syscall.Open(path, flags, uint32(mode)) + if err != nil { + return nil, err + } + + return newFile(ioc, fd), nil } func (f *file) Read(b []byte) (int, error) { @@ -91,69 +150,60 @@ func (f *file) AsyncReadAll(b []byte, cb AsyncCallback) { } func (f *file) asyncRead(b []byte, readAll bool, cb AsyncCallback) { - if f.dispatched < MaxCallbackDispatch { + f.readReactor.init(b, readAll, cb) + + if f.ioc.Dispatched < MaxCallbackDispatch { f.asyncReadNow(b, 0, readAll, func(err error, n int) { - f.dispatched++ + f.ioc.Dispatched++ cb(err, n) - f.dispatched-- + f.ioc.Dispatched-- }) } else { - f.scheduleRead(b, 0, readAll, cb) + f.scheduleRead(0 /* this is the starting point, we did not read anything yet */, cb) } } -func (f *file) asyncReadNow(b []byte, readBytes int, readAll bool, cb AsyncCallback) { - n, err := f.Read(b[readBytes:]) - readBytes += n +func (f *file) asyncReadNow(b []byte, readSoFar int, readAll bool, cb AsyncCallback) { + n, err := f.Read(b[readSoFar:]) + readSoFar += n // f is a nonblocking fd so if err == ErrWouldBlock // then we need to schedule an async read. - if err == nil && !(readAll && readBytes != len(b)) { + if err == nil && !(readAll && readSoFar != len(b)) { // If readAll == true then read fully without errors. // If readAll == false then read some without errors. // We are done. - cb(nil, readBytes) + cb(nil, readSoFar) return } - // handles (readAll == false) and (readAll == true && readBytes != len(b)). + // handles (readAll == false) and (readAll == true && readSoFar != len(b)). if err == sonicerrors.ErrWouldBlock { // If readAll == true then read some without errors. // We schedule an asynchronous read. - f.scheduleRead(b, readBytes, readAll, cb) + f.scheduleRead(readSoFar, cb) } else { - cb(err, readBytes) + cb(err, readSoFar) } } -func (f *file) scheduleRead(b []byte, readBytes int, readAll bool, cb AsyncCallback) { +func (f *file) scheduleRead(readSoFar int, cb AsyncCallback) { if f.Closed() { cb(io.EOF, 0) return } - handler := f.getReadHandler(b, readBytes, readAll, cb) - f.slot.Set(internal.ReadEvent, handler) + f.readReactor.readSoFar = readSoFar + f.slot.Set(internal.ReadEvent, f.readReactor.onRead) if err := f.ioc.SetRead(&f.slot); err != nil { - cb(err, readBytes) + cb(err, readSoFar) } else { f.ioc.Register(&f.slot) } } -func (f *file) getReadHandler(b []byte, readBytes int, readAll bool, cb AsyncCallback) internal.Handler { - return func(err error) { - f.ioc.Deregister(&f.slot) - if err != nil { - cb(err, readBytes) - } else { - f.asyncReadNow(b, readBytes, readAll, cb) - } - } -} - func (f *file) AsyncWrite(b []byte, cb AsyncCallback) { f.asyncWrite(b, false, cb) } @@ -163,79 +213,63 @@ func (f *file) AsyncWriteAll(b []byte, cb AsyncCallback) { } func (f *file) asyncWrite(b []byte, writeAll bool, cb AsyncCallback) { - if f.dispatched < MaxCallbackDispatch { + f.writeReactor.init(b, writeAll, cb) + + if f.ioc.Dispatched < MaxCallbackDispatch { f.asyncWriteNow(b, 0, writeAll, func(err error, n int) { - f.dispatched++ + f.ioc.Dispatched++ cb(err, n) - f.dispatched-- + f.ioc.Dispatched-- }) } else { - f.scheduleWrite(b, 0, writeAll, cb) + f.scheduleWrite(0 /* this is the starting point, we did not write anything yet */, cb) } } -func (f *file) asyncWriteNow(b []byte, writtenBytes int, writeAll bool, cb AsyncCallback) { - n, err := f.Write(b[writtenBytes:]) - writtenBytes += n +func (f *file) asyncWriteNow(b []byte, wroteSoFar int, writeAll bool, cb AsyncCallback) { + n, err := f.Write(b[wroteSoFar:]) + wroteSoFar += n - // f is a nonblocking fd so if err == ErrWouldBlock - // then we need to schedule an async write. - - if err == nil && !(writeAll && writtenBytes != len(b)) { - // If writeAll == true then wrote fully without errors. - // If writeAll == false then wrote some without errors. - // We are done. - cb(nil, writtenBytes) + if err == nil && !(writeAll && wroteSoFar != len(b)) { + // If writeAll == true then we wrote fully without errors. + // If writeAll == false then we wrote some without errors. + cb(nil, wroteSoFar) return } - // handles (writeAll == false) and (writeAll == true && writtenBytes != len(b)). + // Handles (writeAll == false) and (writeAll == true && wroteSoFar != len(b)). if err == sonicerrors.ErrWouldBlock { - // If writeAll == true then wrote some without errors. - // We schedule an asynchronous write. - f.scheduleWrite(b, writtenBytes, writeAll, cb) + f.scheduleWrite(wroteSoFar, cb) } else { - cb(err, writtenBytes) + cb(err, wroteSoFar) } } -func (f *file) scheduleWrite(b []byte, writtenBytes int, writeAll bool, cb AsyncCallback) { +func (f *file) scheduleWrite(wroteSoFar int, cb AsyncCallback) { if f.Closed() { cb(io.EOF, 0) return } - handler := f.getWriteHandler(b, writtenBytes, writeAll, cb) - f.slot.Set(internal.WriteEvent, handler) + f.writeReactor.wroteSoFar = wroteSoFar + f.slot.Set(internal.WriteEvent, f.writeReactor.onWrite) if err := f.ioc.SetWrite(&f.slot); err != nil { - cb(err, writtenBytes) + cb(err, wroteSoFar) } else { f.ioc.Register(&f.slot) } } -func (f *file) getWriteHandler(b []byte, writtenBytes int, writeAll bool, cb AsyncCallback) internal.Handler { - return func(err error) { - f.ioc.Deregister(&f.slot) - - if err != nil { - cb(err, writtenBytes) - } else { - f.asyncWriteNow(b, writtenBytes, writeAll, cb) - } - } -} - func (f *file) Close() error { if !atomic.CompareAndSwapUint32(&f.closed, 0, 1) { return io.EOF } - err := f.ioc.poller.Del(&f.slot) - if err != nil { + if err := f.ioc.UnsetReadWrite(&f.slot); err != nil { return err } + f.ioc.Deregister(&f.slot) return syscall.Close(f.slot.Fd) } diff --git a/go.mod b/go.mod index 9a213554..dc64bd84 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,16 @@ go 1.20 require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 github.com/felixge/fgprof v0.9.3 + github.com/stretchr/testify v1.8.0 github.com/valyala/bytebufferpool v1.0.0 golang.org/x/sys v0.11.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/go-cmp v0.5.8 // indirect github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 5db7805e..86601b83 100644 --- a/go.sum +++ b/go.sum @@ -24,7 +24,9 @@ github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1: github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -72,6 +74,7 @@ gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/socket_unix.go b/internal/socket_unix.go index 62291622..a817a2e9 100644 --- a/internal/socket_unix.go +++ b/internal/socket_unix.go @@ -151,10 +151,13 @@ func connect(fd int, remoteAddr net.Addr, timeout time.Duration, opts ...sonicop return sonicerrors.ErrTimeout } - _, err = syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR) + socketErr, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR) if err != nil { return os.NewSyscallError("getsockopt", err) } + if socketErr != 0 { + return syscall.Errno(socketErr) + } } return nil diff --git a/io.go b/io.go index 3f33901f..7cd14549 100644 --- a/io.go +++ b/io.go @@ -11,6 +11,12 @@ import ( "github.com/talostrading/sonic/sonicerrors" ) +// MaxCallbackDispatch is the maximum number of callbacks that can exist on a stack-frame when asynchronous operations +// can be completed immediately. +// +// This is the limit to the `IO.dispatched` counter. +const MaxCallbackDispatch int = 32 + // IO is the executor of all asynchronous operations and the way any object can schedule them. It runs fully in the // calling goroutine. // @@ -20,9 +26,9 @@ type IO struct { poller internal.Poller // The below structures keep a pointer to a Slot struct usually owned by an object capable of asynchronous - // operations (essentially any object taking an IO* on construction). Keeping a Slot pointer keeps the owning - // object in the GC's object graph while an asynchronous operation is in progress. This ensures Slot references - // valid memory when an asynchronous operation completes and the object is already out of scope. + // operations (essentially any object taking an IO* on construction). Keeping a Slot pointer keeps the owning object + // in the GC's object graph while an asynchronous operation is in progress. This ensures Slot references valid + // memory when an asynchronous operation completes and the object is already out of scope. pending struct { // The kernel allocates a process' descriptors from a fixed range that is [0, 1024) by default. Unprivileged // users can bump this range to [0, 4096). The below array should cover 99% of the cases and makes for cheap @@ -32,11 +38,23 @@ type IO struct { // file descriptors are bound by a fixed range whose upper limit is controlled through RLIMIT_NOFILE. static [4096]*internal.Slot - // This map covers the 1%, the degenerate case. Any Slot whose file descriptor is greater than or equal to - // 4096 goes here. + // This map covers the 1%, the degenerate case. Any Slot whose file descriptor is greater than or equal to 4096 + // goes here. This is lazily initialized. dynamic map[*internal.Slot]struct{} } pendingTimers map[*Timer]struct{} // XXX: should be embedded into the above pending struct + + // Tracks how many callbacks are on the current stack-frame. This prevents stack-overflows in cases where + // asynchronous operations can be completed immediately. + // + // For example, an asynchronous read might be completed immediately. In that case, the callback is invoked which in + // turn might call `AsyncRead` again. That asynchronous read might again be completed immediately and so on. In this + // case, all subsequent read callbacks are placed on the same stack-frame. We count these callbacks with + // `Dispatched`. If we hit `MaxCallbackDispatch`, then the stack-frame is popped - asynchronous reads are scheduled + // to be completed on the next poll cycle, even if they can be completed immediately. + // + // This counter is shared amongst all asynchronous objects - they are responsible for updating it. + Dispatched int } func NewIO() (*IO, error) { @@ -48,6 +66,7 @@ func NewIO() (*IO, error) { return &IO{ poller: poller, pendingTimers: make(map[*Timer]struct{}), + Dispatched: 0, }, nil } @@ -78,14 +97,39 @@ func (ioc *IO) Deregister(slot *internal.Slot) { } } +// SetRead tells the kernel to notify us when reads can be made on the provided IO slot. If successful, this call must +// be succeeded by Register(slot). +// +// It is safe to call this method multiple times. func (ioc *IO) SetRead(slot *internal.Slot) error { return ioc.poller.SetRead(slot) } +// UnsetRead tells the kernel to not notify us anymore when reads can be made on the provided IO slot. Since the +// underlying platform-specific poller already unsets a read before dispatching it, callers must only use this method +// they want to cancel a currently-scheduled read. For example, when an error occurs outside of an AsyncRead call and +// the underlying file descriptor must be closed. In that case, this call must be succeeded by Deregister(slot). +// +// It is safe to call this method multiple times. +func (ioc *IO) UnsetRead(slot *internal.Slot) error { + return ioc.poller.DelRead(slot) +} + +// Like SetRead but for writes. func (ioc *IO) SetWrite(slot *internal.Slot) error { return ioc.poller.SetWrite(slot) } +// Like UnsetRead but for writes. +func (ioc *IO) UnsetWrite(slot *internal.Slot) error { + return ioc.poller.DelWrite(slot) +} + +// UnsetRead and UnsetWrite in a single call. +func (ioc *IO) UnsetReadWrite(slot *internal.Slot) error { + return ioc.poller.Del(slot) +} + // Run runs the event processing loop. func (ioc *IO) Run() error { for { @@ -145,7 +189,7 @@ const ( // RunWarm runs the event loop in a combined busy-wait and yielding mode, meaning that if the current cycle does not // process anything, the event-loop will busy-wait for at most `busyCycles` which we call the warm-state. After // `busyCycles` of not processing anything, the event-loop is out of the warm-state and falls back to yielding with the -// provided timeout. If at any moment an event occurs and something is processed, the event-loop transitions to its +// provided timeout. If at any moment an event occurs and something is processed, the event-loop transitions to its // warm-state. func (ioc *IO) RunWarm(busyCycles int, timeout time.Duration) (err error) { if busyCycles <= 0 { @@ -175,9 +219,9 @@ func (ioc *IO) RunWarm(busyCycles int, timeout time.Duration) (err error) { // We processed something in this cycle, be it inside or outside the warm-period. We restart the warm-period i = 0 } else { - // We did not process anything in this cycle. If we are still in the warm period i.e. `i < busyCycles`, - // we are going to poll in the next cycle. If we are outside the warm period i.e. `i >= busyCycles`, - // we are going to yield in the next cycle. + // We did not process anything in this cycle. If we are still in the warm period i.e. `i < busyCycles`, we + // are going to poll in the next cycle. If we are outside the warm period i.e. `i >= busyCycles`, we are + // going to yield in the next cycle. i++ } } @@ -206,8 +250,6 @@ func (ioc *IO) poll(timeoutMs int) (int, error) { if err != nil { if err == syscall.EINTR { - // TODO not sure about this one, and whether returning timeout here is ok. - // need to look into syscall.EINTR again if timeoutMs >= 0 { return 0, sonicerrors.ErrTimeout } @@ -226,8 +268,7 @@ func (ioc *IO) poll(timeoutMs int) (int, error) { return n, nil } -// Post schedules the provided handler to be run immediately by the event -// processing loop in its own thread. +// Post schedules the provided handler to be run immediately by the event processing loop in its own thread. // // It is safe to call Post concurrently. func (ioc *IO) Post(handler func()) error { @@ -241,6 +282,7 @@ func (ioc *IO) Posted() int { return ioc.poller.Posted() } +// Returns the current number of pending asynchronous operations. func (ioc *IO) Pending() int64 { return ioc.poller.Pending() } diff --git a/io_test.go b/io_test.go index 014b59a4..1d1f41fa 100644 --- a/io_test.go +++ b/io_test.go @@ -316,6 +316,38 @@ func TestIOPending(t *testing.T) { } } +func TestSetUnsetRead(t *testing.T) { + ioc := MustIO() + defer ioc.Close() + + pipe, err := internal.NewPipe() + if err != nil { + t.Fatal(err) + } + + if pipe.ReadFd() != pipe.Slot().Fd { + t.Fatal("pipe must be identified by its read end file descriptor") + } + + for i := 0; i < 10; i++ { + if err := ioc.SetRead(pipe.Slot()); err != nil { + t.Fatal(err) + } + } + + for i := 0; i < 100; i++ { + if err := ioc.UnsetRead(pipe.Slot()); err != nil { + t.Fatal(err) + } + } + + for i := 0; i < 100; i++ { + if err := ioc.UnsetReadWrite(pipe.Slot()); err != nil { + t.Fatal(err) + } + } +} + func BenchmarkPollOne(b *testing.B) { ioc := MustIO() defer ioc.Close() diff --git a/listen_conn.go b/listen_conn.go index 4c001234..ba738334 100644 --- a/listen_conn.go +++ b/listen_conn.go @@ -16,8 +16,6 @@ type listener struct { ioc *IO slot internal.Slot addr net.Addr - - dispatched int } // Listen creates a Listener that listens for new connections on the local address. @@ -53,16 +51,16 @@ func (l *listener) Accept() (Conn, error) { } func (l *listener) AsyncAccept(cb AcceptCallback) { - if l.dispatched >= MaxCallbackDispatch { + if l.ioc.Dispatched >= MaxCallbackDispatch { l.asyncAccept(cb) } else { conn, err := l.accept() if err != nil && (err == sonicerrors.ErrWouldBlock) { l.asyncAccept(cb) } else { - l.dispatched++ + l.ioc.Dispatched++ cb(err, conn) - l.dispatched-- + l.ioc.Dispatched-- } } } @@ -113,7 +111,8 @@ func (l *listener) accept() (Conn, error) { } func (l *listener) Close() error { - _ = l.ioc.poller.Del(&l.slot) + _ = l.ioc.UnsetReadWrite(&l.slot) + l.ioc.Deregister(&l.slot) return syscall.Close(l.slot.Fd) } diff --git a/multicast/peer.go b/multicast/peer.go index fa2276b5..f16098b6 100644 --- a/multicast/peer.go +++ b/multicast/peer.go @@ -32,9 +32,8 @@ type UDPPeer struct { slot internal.Slot - sockAddr syscall.Sockaddr - closed bool - dispatched int + sockAddr syscall.Sockaddr + closed bool } // NewUDPPeer creates a new UDPPeer capable of reading/writing multicast packets @@ -567,11 +566,11 @@ func (p *UDPPeer) AsyncRead(b []byte, fn func(error, int, netip.AddrPort)) { p.read.b = b p.read.fn = fn - if p.dispatched < sonic.MaxCallbackDispatch { + if p.ioc.Dispatched < sonic.MaxCallbackDispatch { p.asyncReadNow(b, func(err error, n int, addr netip.AddrPort) { - p.dispatched++ + p.ioc.Dispatched++ fn(err, n, addr) - p.dispatched-- + p.ioc.Dispatched-- }) } else { p.scheduleRead(fn) @@ -622,11 +621,11 @@ func (p *UDPPeer) AsyncWrite( p.write.addr = addr p.write.fn = fn - if p.dispatched < sonic.MaxCallbackDispatch { + if p.ioc.Dispatched < sonic.MaxCallbackDispatch { p.asyncWriteNow(b, addr, func(err error, n int) { - p.dispatched++ + p.ioc.Dispatched++ fn(err, n) - p.dispatched-- + p.ioc.Dispatched-- }) } else { p.scheduleWrite(fn) @@ -676,6 +675,8 @@ func (p *UDPPeer) LocalAddr() *net.UDPAddr { func (p *UDPPeer) Close() error { if !p.closed { p.closed = true + _ = p.ioc.UnsetReadWrite(&p.slot) + p.ioc.Deregister(&p.slot) return p.socket.Close() } return nil diff --git a/packet.go b/packet.go index 3494f8b4..4a3cdfd9 100644 --- a/packet.go +++ b/packet.go @@ -20,8 +20,6 @@ type packetConn struct { localAddr net.Addr remoteAddr net.Addr closed uint32 - - dispatched int } // NewPacketConn establishes a packet based stream-less connection which is optionally bound to the specified addr. @@ -81,11 +79,11 @@ func (c *packetConn) AsyncReadAllFrom(b []byte, cb AsyncReadCallbackPacket) { } func (c *packetConn) asyncReadFrom(b []byte, readAll bool, cb AsyncReadCallbackPacket) { - if c.dispatched < MaxCallbackDispatch { + if c.ioc.Dispatched < MaxCallbackDispatch { c.asyncReadNow(b, 0, readAll, func(err error, n int, addr net.Addr) { - c.dispatched++ + c.ioc.Dispatched++ cb(err, n, addr) - c.dispatched-- + c.ioc.Dispatched-- }) } else { c.scheduleRead(b, 0, readAll, cb) @@ -145,11 +143,11 @@ func (c *packetConn) WriteTo(b []byte, to net.Addr) error { } func (c *packetConn) AsyncWriteTo(b []byte, to net.Addr, cb AsyncWriteCallbackPacket) { - if c.dispatched < MaxCallbackDispatch { + if c.ioc.Dispatched < MaxCallbackDispatch { c.asyncWriteToNow(b, to, func(err error) { - c.dispatched++ + c.ioc.Dispatched++ cb(err) - c.dispatched-- + c.ioc.Dispatched-- }) } else { c.scheduleWrite(b, to, cb) @@ -195,6 +193,8 @@ func (c *packetConn) getWriteHandler(b []byte, to net.Addr, cb AsyncWriteCallbac func (c *packetConn) Close() error { atomic.StoreUint32(&c.closed, 1) + _ = c.ioc.UnsetReadWrite(&c.slot) + c.ioc.Deregister(&c.slot) return syscall.Close(c.slot.Fd) } diff --git a/stress_test/codec/nonblocking.go b/stress_test/codec/nonblocking.go index 74f9a098..727c9e81 100644 --- a/stress_test/codec/nonblocking.go +++ b/stress_test/codec/nonblocking.go @@ -31,7 +31,7 @@ func main() { } defer conn.Close() log.Print("client connected, starting to read") - cc, err := sonic.NewNonblockingCodecConn[[]byte, []byte](conn, codec, src, + cc, err := sonic.NewCodecConn[[]byte, []byte](conn, codec, src, dst) if err != nil { panic(err) diff --git a/tests/autobahn/client.go b/tests/autobahn/client.go index 3c11b3b4..68934cf7 100644 --- a/tests/autobahn/client.go +++ b/tests/autobahn/client.go @@ -12,6 +12,7 @@ import ( var ( addr = flag.String("addr", "ws://localhost:9001", "server address") testCase = flag.Int("case", -1, "autobahn test case to run") + utf8 = flag.Bool("utf8", false, "if true, payloads of text frames are utf8 validated") ) func main() { @@ -83,6 +84,16 @@ func runTest(i int) { if err != nil { panic(err) } + if *utf8 { + s.ValidateUTF8(true) + if !s.ValidatesUTF8() { + panic("UTF8 should be validated") + } + } else { + if s.ValidatesUTF8() { + panic("UTF8 should NOT be validated") + } + } done := false s.AsyncHandshake(fmt.Sprintf("%s/runCase?case=%d&agent=sonic", *addr, i), func(err error) { @@ -92,10 +103,11 @@ func runTest(i int) { b := make([]byte, 1024*1024) - var onAsyncRead websocket.AsyncMessageHandler + var onAsyncRead websocket.AsyncMessageCallback onAsyncRead = func(err error, n int, mt websocket.MessageType) { if err != nil { + s.Flush() done = true } else { b = b[:n]