diff --git a/codec/websocket/client_performance_test.go b/codec/websocket/client_performance_test.go index b1493c6..8341b8e 100644 --- a/codec/websocket/client_performance_test.go +++ b/codec/websocket/client_performance_test.go @@ -25,7 +25,7 @@ func TestClientReadWrite(t *testing.T) { } dur := time.Duration(idur) * time.Second - s := &MockServer{} + s := NewMockServer() go func() { defer s.Close() diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index 09d22c7..25a206b 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/talostrading/sonic" @@ -23,11 +22,12 @@ func assertState(t *testing.T, ws *Stream, expected StreamState) { func TestClientServerSendsInvalidCloseCode(t *testing.T) { assert := assert.New(t) + srv := NewMockServer() + go func() { - srv := &MockServer{} defer srv.Close() - err := srv.Accept("localhost:8080") + err := srv.Accept(MockServerDynamicAddr) if err != nil { panic(err) } @@ -59,7 +59,6 @@ func TestClientServerSendsInvalidCloseCode(t *testing.T) { assert.Equal(reason, "") } }() - time.Sleep(10 * time.Millisecond) ioc := sonic.MustIO() defer ioc.Close() @@ -70,7 +69,7 @@ func TestClientServerSendsInvalidCloseCode(t *testing.T) { } done := false - ws.AsyncHandshake("ws://localhost:8080", func(err error) { + ws.AsyncHandshake(fmt.Sprintf("ws://localhost:%d", <-srv.portChan), func(err error) { if err != nil { t.Fatal(err) } @@ -90,11 +89,12 @@ func TestClientServerSendsInvalidCloseCode(t *testing.T) { func TestClientEchoCloseCode(t *testing.T) { assert := assert.New(t) + srv := NewMockServer() + go func() { - srv := &MockServer{} defer srv.Close() - err := srv.Accept("localhost:8080") + err := srv.Accept(MockServerDynamicAddr) if err != nil { panic(err) } @@ -122,7 +122,6 @@ func TestClientEchoCloseCode(t *testing.T) { assert.Equal(reason, "something") } }() - time.Sleep(10 * time.Millisecond) ioc := sonic.MustIO() defer ioc.Close() @@ -133,7 +132,7 @@ func TestClientEchoCloseCode(t *testing.T) { } done := false - ws.AsyncHandshake("ws://localhost:8080", func(err error) { + ws.AsyncHandshake(fmt.Sprintf("ws://localhost:%d", <-srv.portChan), func(err error) { if err != nil { t.Fatal(err) } @@ -155,11 +154,12 @@ func TestClientSendPingWithInvalidPayload(t *testing.T) { // the connection immediately with 1002/Protocol Error. assert := assert.New(t) + srv := NewMockServer() + go func() { - srv := &MockServer{} defer srv.Close() - err := srv.Accept("localhost:8080") + err := srv.Accept(MockServerDynamicAddr) if err != nil { panic(err) } @@ -191,7 +191,6 @@ func TestClientSendPingWithInvalidPayload(t *testing.T) { assert.Empty(reason) } }() - time.Sleep(10 * time.Millisecond) ioc := sonic.MustIO() defer ioc.Close() @@ -202,7 +201,7 @@ func TestClientSendPingWithInvalidPayload(t *testing.T) { } done := false - ws.AsyncHandshake("ws://localhost:8080", func(err error) { + ws.AsyncHandshake(fmt.Sprintf("ws://localhost:%d", <-srv.portChan), func(err error) { if err != nil { t.Fatal(err) } @@ -223,11 +222,12 @@ func TestClientSendPingWithInvalidPayload(t *testing.T) { func TestClientSendMessageWithPayload126(t *testing.T) { assert := assert.New(t) + srv := NewMockServer() + go func() { - srv := &MockServer{} defer srv.Close() - err := srv.Accept("localhost:8080") + err := srv.Accept(MockServerDynamicAddr) if err != nil { panic(err) } @@ -242,7 +242,6 @@ func TestClientSendMessageWithPayload126(t *testing.T) { frame.WriteTo(srv.conn) }() - time.Sleep(10 * time.Millisecond) ioc := sonic.MustIO() defer ioc.Close() @@ -253,7 +252,7 @@ func TestClientSendMessageWithPayload126(t *testing.T) { } done := false - ws.AsyncHandshake("ws://localhost:8080", func(err error) { + ws.AsyncHandshake(fmt.Sprintf("ws://localhost:%d", <-srv.portChan), func(err error) { if err != nil { t.Fatal(err) } @@ -277,11 +276,12 @@ func TestClientSendMessageWithPayload126(t *testing.T) { func TestClientSendMessageWithPayload127(t *testing.T) { assert := assert.New(t) + srv := NewMockServer() + go func() { - srv := &MockServer{} defer srv.Close() - err := srv.Accept("localhost:8080") + err := srv.Accept(MockServerDynamicAddr) if err != nil { panic(err) } @@ -296,7 +296,6 @@ func TestClientSendMessageWithPayload127(t *testing.T) { frame.WriteTo(srv.conn) }() - time.Sleep(10 * time.Millisecond) ioc := sonic.MustIO() defer ioc.Close() @@ -307,7 +306,7 @@ func TestClientSendMessageWithPayload127(t *testing.T) { } done := false - ws.AsyncHandshake("ws://localhost:8080", func(err error) { + ws.AsyncHandshake(fmt.Sprintf("ws://localhost:%d", <-srv.portChan), func(err error) { if err != nil { t.Fatal(err) } @@ -329,20 +328,29 @@ func TestClientSendMessageWithPayload127(t *testing.T) { } func TestClientReconnectOnFailedRead(t *testing.T) { + srv := NewMockServer() + port := 0 + go func() { for i := 0; i < 10; i++ { - srv := &MockServer{} - - err := srv.Accept("localhost:8080") + var err error + if port != 0 { + err = srv.Accept(fmt.Sprintf("localhost:%d", port)) + } else { + err = srv.Accept(MockServerDynamicAddr) + } if err != nil { panic(err) } srv.Write([]byte("hello")) srv.Close() + + srv = NewMockServer() } }() - time.Sleep(10 * time.Millisecond) + + port = <- srv.portChan ioc := sonic.MustIO() defer ioc.Close() @@ -389,7 +397,7 @@ func TestClientReconnectOnFailedRead(t *testing.T) { } connect = func() { - ws.AsyncHandshake("ws://localhost:8080", onHandshake) + ws.AsyncHandshake(fmt.Sprintf("ws://localhost:%d", port), onHandshake) } connect() @@ -465,17 +473,16 @@ func TestClientFailedHandshakeNoServer(t *testing.T) { } func TestClientSuccessfulHandshake(t *testing.T) { - srv := &MockServer{} + srv := NewMockServer() go func() { defer srv.Close() - err := srv.Accept("localhost:8080") + err := srv.Accept(MockServerDynamicAddr) if err != nil { panic(err) } }() - time.Sleep(10 * time.Millisecond) ioc := sonic.MustIO() defer ioc.Close() @@ -501,7 +508,7 @@ func TestClientSuccessfulHandshake(t *testing.T) { assertState(t, ws, StateHandshake) - ws.AsyncHandshake("ws://localhost:8080", func(err error) { + ws.AsyncHandshake(fmt.Sprintf("ws://localhost:%d", <-srv.portChan), func(err error) { if err != nil { assertState(t, ws, StateTerminated) } else { @@ -524,17 +531,16 @@ func TestClientSuccessfulHandshake(t *testing.T) { } func TestClientSuccessfulHandshakeWithExtraHeaders(t *testing.T) { - srv := &MockServer{} + srv := NewMockServer() go func() { defer srv.Close() - err := srv.Accept("localhost:8080") + err := srv.Accept(MockServerDynamicAddr) if err != nil { panic(err) } }() - time.Sleep(10 * time.Millisecond) ioc := sonic.MustIO() defer ioc.Close() @@ -558,7 +564,7 @@ func TestClientSuccessfulHandshakeWithExtraHeaders(t *testing.T) { } ws.AsyncHandshake( - "ws://localhost:8080", + fmt.Sprintf("ws://localhost:%d", <-srv.portChan), func(err error) { if err != nil { assertState(t, ws, StateTerminated) @@ -1644,18 +1650,12 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { func TestClientAbnormalClose(t *testing.T) { assert := assert.New(t) - portChan := make(chan int, 1) + srv := NewMockServer() go func() { - srv := &MockServer{} defer srv.Close() - err := srv.Accept("localhost:0", func(port int) { - if port <= 0 { - panic(fmt.Sprintf("Got invalid port from MockServer: %d", port)) - } - portChan <- port - }) + err := srv.Accept(MockServerDynamicAddr) if err != nil { panic(err) } @@ -1663,9 +1663,6 @@ func TestClientAbnormalClose(t *testing.T) { // Simulate an abnormal closure (close the TCP connection without sending a WebSocket close frame) srv.Close() }() - time.Sleep(10 * time.Millisecond) - - wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan) ioc := sonic.MustIO() defer ioc.Close() @@ -1673,7 +1670,7 @@ func TestClientAbnormalClose(t *testing.T) { ws, err := NewWebsocketStream(ioc, nil, RoleClient) assert.Nil(err) - err = ws.Handshake(wsURI) + err = ws.Handshake(fmt.Sprintf("ws://localhost:%d", <-srv.portChan)) assert.Nil(err) assert.Equal(ws.State(), StateActive) // Verify WebSocket active @@ -1693,18 +1690,12 @@ func TestClientAbnormalClose(t *testing.T) { func TestClientAsyncAbnormalClose(t *testing.T) { assert := assert.New(t) - portChan := make(chan int, 1) + srv := NewMockServer() go func() { - srv := &MockServer{} defer srv.Close() - err := srv.Accept("localhost:0", func(port int) { - if port <= 0 { - panic(fmt.Sprintf("Got invalid port from MockServer: %d", port)) - } - portChan <- port - }) + err := srv.Accept(MockServerDynamicAddr) if err != nil { panic(err) } @@ -1712,9 +1703,6 @@ func TestClientAsyncAbnormalClose(t *testing.T) { // Simulate an abnormal closure (close the TCP connection without sending a WebSocket close frame) srv.Close() }() - time.Sleep(10 * time.Millisecond) - - wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan) ioc := sonic.MustIO() defer ioc.Close() @@ -1723,7 +1711,7 @@ func TestClientAsyncAbnormalClose(t *testing.T) { assert.Nil(err) done := false - ws.AsyncHandshake(wsURI, func(err error) { + ws.AsyncHandshake(fmt.Sprintf("ws://localhost:%d", <-srv.portChan), func(err error) { assert.Nil(err) assert.Equal(ws.State(), StateActive) // Verify WebSocket active diff --git a/codec/websocket/test_main.go b/codec/websocket/test_main.go index 368d841..8d32b72 100644 --- a/codec/websocket/test_main.go +++ b/codec/websocket/test_main.go @@ -18,25 +18,33 @@ type MockServer struct { conn net.Conn closed int32 port int32 + portChan chan int Upgrade *http.Request } -// Accept starts the mock server, listening on the specified address. -// If a callback is provided, it is invoked with the assigned port. -func (s *MockServer) Accept(addr string, opts ...func(int)) (err error) { - s.ln, err = net.Listen("tcp", addr) +func NewMockServer() *MockServer { + return &MockServer{ + portChan: make(chan int, 1), // Buffer port channel to prevent blocking server + } +} + +// Accept starts the mock server on the specified address. +// If MockServerDynamicAddr is provided as the address, the server binds to any available port. +const MockServerDynamicAddr = "" +func (s *MockServer) Accept(addr string) (err error) { + if addr == MockServerDynamicAddr { + s.ln, err = net.Listen("tcp", "localhost:0") + } else { + s.ln, err = net.Listen("tcp", addr) + } if err != nil { return err } port := int(s.ln.Addr().(*net.TCPAddr).Port) atomic.StoreInt32(&s.port, int32(port)) - - // Call port callback if provided - if len(opts) > 0 && opts[0] != nil { - opts[0](port) - } + s.portChan <- port conn, err := s.ln.Accept() if err != nil {