Skip to content

Commit

Permalink
Rework MockServer Accept, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxMaeder committed Nov 29, 2024
1 parent 99dacd3 commit 4a7da91
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 50 deletions.
123 changes: 77 additions & 46 deletions codec/websocket/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@ func assertState(t *testing.T, ws *Stream, expected StreamState) {
func TestClientServerSendsInvalidCloseCode(t *testing.T) {
assert := assert.New(t)

portChan := make(chan int, 1)

go func() {
srv := &MockServer{}
srv := &MockServer{
portChan: portChan,
}
defer srv.Close()

err := srv.Accept("localhost:8080")
err := srv.Accept("localhost:0")
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -59,7 +63,8 @@ func TestClientServerSendsInvalidCloseCode(t *testing.T) {
assert.Equal(reason, "")
}
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan)

ioc := sonic.MustIO()
defer ioc.Close()
Expand All @@ -70,7 +75,7 @@ func TestClientServerSendsInvalidCloseCode(t *testing.T) {
}

done := false
ws.AsyncHandshake("ws://localhost:8080", func(err error) {
ws.AsyncHandshake(wsURI, func(err error) {
if err != nil {
t.Fatal(err)
}
Expand All @@ -90,11 +95,15 @@ func TestClientServerSendsInvalidCloseCode(t *testing.T) {
func TestClientEchoCloseCode(t *testing.T) {
assert := assert.New(t)

portChan := make(chan int, 1)

go func() {
srv := &MockServer{}
srv := &MockServer{
portChan: portChan,
}
defer srv.Close()

err := srv.Accept("localhost:8080")
err := srv.Accept("localhost:0")
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -122,7 +131,8 @@ func TestClientEchoCloseCode(t *testing.T) {
assert.Equal(reason, "something")
}
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan)

ioc := sonic.MustIO()
defer ioc.Close()
Expand All @@ -133,7 +143,7 @@ func TestClientEchoCloseCode(t *testing.T) {
}

done := false
ws.AsyncHandshake("ws://localhost:8080", func(err error) {
ws.AsyncHandshake(wsURI, func(err error) {
if err != nil {
t.Fatal(err)
}
Expand All @@ -155,11 +165,15 @@ func TestClientSendPingWithInvalidPayload(t *testing.T) {
// the connection immediately with 1002/Protocol Error.
assert := assert.New(t)

portChan := make(chan int, 1)

go func() {
srv := &MockServer{}
srv := &MockServer{
portChan: portChan,
}
defer srv.Close()

err := srv.Accept("localhost:8080")
err := srv.Accept("localhost:0")
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -191,7 +205,8 @@ func TestClientSendPingWithInvalidPayload(t *testing.T) {
assert.Empty(reason)
}
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan)

ioc := sonic.MustIO()
defer ioc.Close()
Expand All @@ -202,7 +217,7 @@ func TestClientSendPingWithInvalidPayload(t *testing.T) {
}

done := false
ws.AsyncHandshake("ws://localhost:8080", func(err error) {
ws.AsyncHandshake(wsURI, func(err error) {
if err != nil {
t.Fatal(err)
}
Expand All @@ -223,11 +238,15 @@ func TestClientSendPingWithInvalidPayload(t *testing.T) {
func TestClientSendMessageWithPayload126(t *testing.T) {
assert := assert.New(t)

portChan := make(chan int, 1)

go func() {
srv := &MockServer{}
srv := &MockServer{
portChan: portChan,
}
defer srv.Close()

err := srv.Accept("localhost:8080")
err := srv.Accept("localhost:0")
if err != nil {
panic(err)
}
Expand All @@ -242,7 +261,8 @@ func TestClientSendMessageWithPayload126(t *testing.T) {

frame.WriteTo(srv.conn)
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan)

ioc := sonic.MustIO()
defer ioc.Close()
Expand All @@ -253,7 +273,7 @@ func TestClientSendMessageWithPayload126(t *testing.T) {
}

done := false
ws.AsyncHandshake("ws://localhost:8080", func(err error) {
ws.AsyncHandshake(wsURI, func(err error) {
if err != nil {
t.Fatal(err)
}
Expand All @@ -277,11 +297,15 @@ func TestClientSendMessageWithPayload126(t *testing.T) {
func TestClientSendMessageWithPayload127(t *testing.T) {
assert := assert.New(t)

portChan := make(chan int, 1)

go func() {
srv := &MockServer{}
srv := &MockServer{
portChan: portChan,
}
defer srv.Close()

err := srv.Accept("localhost:8080")
err := srv.Accept("localhost:0")
if err != nil {
panic(err)
}
Expand All @@ -296,7 +320,8 @@ func TestClientSendMessageWithPayload127(t *testing.T) {

frame.WriteTo(srv.conn)
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan)

ioc := sonic.MustIO()
defer ioc.Close()
Expand All @@ -307,7 +332,7 @@ func TestClientSendMessageWithPayload127(t *testing.T) {
}

done := false
ws.AsyncHandshake("ws://localhost:8080", func(err error) {
ws.AsyncHandshake(wsURI, func(err error) {
if err != nil {
t.Fatal(err)
}
Expand All @@ -329,11 +354,19 @@ func TestClientSendMessageWithPayload127(t *testing.T) {
}

func TestClientReconnectOnFailedRead(t *testing.T) {
port, err := GetFreePort()
if err != nil {
panic(err)
}

serverURI := fmt.Sprintf("localhost:%d", port)
wsURI := fmt.Sprintf("ws://localhost:%d", port)

go func() {
for i := 0; i < 10; i++ {
srv := &MockServer{}

err := srv.Accept("localhost:8080")
err := srv.Accept(serverURI)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -389,7 +422,7 @@ func TestClientReconnectOnFailedRead(t *testing.T) {
}

connect = func() {
ws.AsyncHandshake("ws://localhost:8080", onHandshake)
ws.AsyncHandshake(wsURI, onHandshake)
}

connect()
Expand Down Expand Up @@ -465,17 +498,20 @@ func TestClientFailedHandshakeNoServer(t *testing.T) {
}

func TestClientSuccessfulHandshake(t *testing.T) {
srv := &MockServer{}
srv := &MockServer{
portChan: make(chan int, 1),
}

go func() {
defer srv.Close()

err := srv.Accept("localhost:8080")
err := srv.Accept("localhost:0")
if err != nil {
panic(err)
}
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-srv.portChan)

ioc := sonic.MustIO()
defer ioc.Close()
Expand All @@ -501,7 +537,7 @@ func TestClientSuccessfulHandshake(t *testing.T) {

assertState(t, ws, StateHandshake)

ws.AsyncHandshake("ws://localhost:8080", func(err error) {
ws.AsyncHandshake(wsURI, func(err error) {
if err != nil {
assertState(t, ws, StateTerminated)
} else {
Expand All @@ -524,17 +560,20 @@ func TestClientSuccessfulHandshake(t *testing.T) {
}

func TestClientSuccessfulHandshakeWithExtraHeaders(t *testing.T) {
srv := &MockServer{}
srv := &MockServer{
portChan: make(chan int, 1),
}

go func() {
defer srv.Close()

err := srv.Accept("localhost:8080")
err := srv.Accept("localhost:0")
if err != nil {
panic(err)
}
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-srv.portChan)

ioc := sonic.MustIO()
defer ioc.Close()
Expand All @@ -558,7 +597,7 @@ func TestClientSuccessfulHandshakeWithExtraHeaders(t *testing.T) {
}

ws.AsyncHandshake(
"ws://localhost:8080",
wsURI,
func(err error) {
if err != nil {
assertState(t, ws, StateTerminated)
Expand Down Expand Up @@ -1647,23 +1686,19 @@ func TestClientAbnormalClose(t *testing.T) {
portChan := make(chan int, 1)

go func() {
srv := &MockServer{}
srv := &MockServer{
portChan: portChan,
}
defer srv.Close()

err := srv.Accept("localhost:0", func(port int) {
if port <= 0 {
panic(fmt.Sprintf("Got invalid port from MockServer: %d", port))
}
portChan <- port
})
err := srv.Accept("localhost:0")
if err != nil {
panic(err)
}

// Simulate an abnormal closure (close the TCP connection without sending a WebSocket close frame)
srv.Close()
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan)

Expand Down Expand Up @@ -1696,23 +1731,19 @@ func TestClientAsyncAbnormalClose(t *testing.T) {
portChan := make(chan int, 1)

go func() {
srv := &MockServer{}
srv := &MockServer{
portChan: portChan,
}
defer srv.Close()

err := srv.Accept("localhost:0", func(port int) {
if port <= 0 {
panic(fmt.Sprintf("Got invalid port from MockServer: %d", port))
}
portChan <- port
})
err := srv.Accept("localhost:0")
if err != nil {
panic(err)
}

// Simulate an abnormal closure (close the TCP connection without sending a WebSocket close frame)
srv.Close()
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan)

Expand Down
22 changes: 18 additions & 4 deletions codec/websocket/test_main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ type MockServer struct {
conn net.Conn
closed int32
port int32
portChan chan int

Upgrade *http.Request
}

// Accept starts the mock server, listening on the specified address.
// If a callback is provided, it is invoked with the assigned port.
func (s *MockServer) Accept(addr string, opts ...func(int)) (err error) {
func (s *MockServer) Accept(addr string) (err error) {
s.ln, err = net.Listen("tcp", addr)
if err != nil {
return err
Expand All @@ -33,9 +34,8 @@ func (s *MockServer) Accept(addr string, opts ...func(int)) (err error) {
port := int(s.ln.Addr().(*net.TCPAddr).Port)
atomic.StoreInt32(&s.port, int32(port))

// Call port callback if provided
if len(opts) > 0 && opts[0] != nil {
opts[0](port)
if s.portChan != nil {
s.portChan <- port
}

conn, err := s.ln.Accept()
Expand Down Expand Up @@ -187,3 +187,17 @@ func (s *MockStream) Close() error {
func (s *MockStream) RawFd() int {
return -1
}

// GetFreePort asks the kernel for a free open port that is ready to use.
// From: https://gist.github.com/sevkin/96bdae9274465b2d09191384f86ef39d
func GetFreePort() (port int, err error) {
var a *net.TCPAddr
if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
if l, err = net.ListenTCP("tcp", a); err == nil {
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}
}
return
}

0 comments on commit 4a7da91

Please sign in to comment.