Skip to content

Commit

Permalink
lib/web/UT update
Browse files Browse the repository at this point in the history
  • Loading branch information
greedy52 committed Jan 5, 2024
1 parent 262b31e commit 91ce8b9
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 9 deletions.
4 changes: 4 additions & 0 deletions lib/web/conn_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ func (c *websocketALPNServerConn) Read(b []byte) (n int, err error) {
}

func (c *websocketALPNServerConn) Write(b []byte) (n int, err error) {
c.writeMutex.Lock()
defer c.writeMutex.Unlock()
if err := c.Conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
if isOKWebsocketCloseError(err) {
err = io.EOF
Expand All @@ -292,6 +294,8 @@ func (c *websocketALPNServerConn) Write(b []byte) (n int, err error) {
}

func (c *websocketALPNServerConn) WritePing() error {
c.writeMutex.Lock()
defer c.writeMutex.Unlock()
err := c.Conn.WriteMessage(websocket.PingMessage, nil)
if isOKWebsocketCloseError(err) {
err = io.EOF
Expand Down
67 changes: 58 additions & 9 deletions lib/web/conn_upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"testing"
"time"

"github.com/gobwas/ws"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -87,32 +88,41 @@ func TestHandlerConnectionUpgrade(t *testing.T) {

tests := []struct {
name string
inputUpgradeHeaderKey string
inputRequest *http.Request
inputUpgradeType string
checkHandlerError func(error) bool
checkClientConnString func(*testing.T, net.Conn, string)
}{
{
name: "unsupported type",
inputRequest: makeConnUpgradeRequest(t, "", "unsupported-protocol", expectedIP),
inputUpgradeType: "unsupported-protocol",
checkHandlerError: trace.IsNotFound,
},
{
name: "upgraded to ALPN",
name: "upgraded to ALPN (legacy)",
inputRequest: makeConnUpgradeRequest(t, "", constants.WebAPIConnUpgradeTypeALPN, expectedIP),
inputUpgradeType: constants.WebAPIConnUpgradeTypeALPN,
checkClientConnString: mustReadClientConnString,
},
{
name: "upgraded to ALPN with Ping",
name: "upgraded to ALPN with Ping (legacy)",
inputRequest: makeConnUpgradeRequest(t, "", constants.WebAPIConnUpgradeTypeALPNPing, expectedIP),
inputUpgradeType: constants.WebAPIConnUpgradeTypeALPNPing,
checkClientConnString: mustReadClientPingConnString,
},
{
name: "upgraded to ALPN with Teleport-specific header",
inputUpgradeHeaderKey: constants.WebAPIConnUpgradeTeleportHeader,
inputRequest: makeConnUpgradeRequest(t, constants.WebAPIConnUpgradeTeleportHeader, constants.WebAPIConnUpgradeTypeALPN, expectedIP),
inputUpgradeType: constants.WebAPIConnUpgradeTypeALPN,
checkClientConnString: mustReadClientConnString,
},
{
name: "upgraded to WebSocket",
inputRequest: makeConnUpgradeWebSocketRequest(t, constants.WebAPIConnUpgradeTypeALPN, expectedIP),
inputUpgradeType: constants.WebAPIConnUpgradeTypeWebSocket,
checkClientConnString: mustReadClientWebSocketConnString,
},
}

for _, test := range tests {
Expand All @@ -123,7 +133,6 @@ func TestHandlerConnectionUpgrade(t *testing.T) {

// serverConn will be hijacked.
w := newResponseWriterHijacker(nil, serverConn)
r := makeConnUpgradeRequest(t, test.inputUpgradeHeaderKey, test.inputUpgradeType, expectedIP)

// Serve the handler with XForwardedFor middleware to set IPs.
handlerErrChan := make(chan error, 1)
Expand All @@ -133,7 +142,7 @@ func TestHandlerConnectionUpgrade(t *testing.T) {
handlerErrChan <- err
})

NewXForwardedForMiddleware(connUpgradeHandler).ServeHTTP(w, r)
NewXForwardedForMiddleware(connUpgradeHandler).ServeHTTP(w, test.inputRequest)
}()

select {
Expand All @@ -146,7 +155,7 @@ func TestHandlerConnectionUpgrade(t *testing.T) {
}

case <-w.hijackedCtx.Done():
mustReadSwitchProtocolsResponse(t, r, clientConn, test.inputUpgradeType)
mustReadSwitchProtocolsResponse(t, test.inputRequest, clientConn, test.inputUpgradeType)
test.checkClientConnString(t, clientConn, expectedPayload)

case <-time.After(5 * time.Second):
Expand All @@ -170,6 +179,21 @@ func makeConnUpgradeRequest(t *testing.T, upgradeHeaderKey, upgradeType, xForwar
return r
}

func makeConnUpgradeWebSocketRequest(t *testing.T, alpnUpgradeType, xForwardedFor string) *http.Request {
t.Helper()

r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil)
require.NoError(t, err)

r.Header.Add("X-Forwarded-For", xForwardedFor)
r.Header.Add(constants.WebAPIConnUpgradeHeader, "websocket")
r.Header.Add(constants.WebAPIConnUpgradeConnectionHeader, "upgrade")
r.Header.Set("Sec-Websocket-Protocol", alpnUpgradeType)
r.Header.Set("Sec-Websocket-Version", "13")
r.Header.Set("Sec-Websocket-Key", "MTIzNDU2Nzg5MDEyMzQ1Ng==")
return r
}

func mustReadSwitchProtocolsResponse(t *testing.T, r *http.Request, clientConn net.Conn, upgradeType string) {
t.Helper()

Expand All @@ -180,8 +204,10 @@ func mustReadSwitchProtocolsResponse(t *testing.T, r *http.Request, clientConn n
io.Copy(io.Discard, resp.Body)
_ = resp.Body.Close()

if upgradeType != "websocket" {
require.Equal(t, upgradeType, resp.Header.Get(constants.WebAPIConnUpgradeTeleportHeader))
}
require.Equal(t, upgradeType, resp.Header.Get(constants.WebAPIConnUpgradeHeader))
require.Equal(t, upgradeType, resp.Header.Get(constants.WebAPIConnUpgradeTeleportHeader))
require.Equal(t, constants.WebAPIConnUpgradeConnectionType, resp.Header.Get(constants.WebAPIConnUpgradeConnectionHeader))
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
}
Expand All @@ -200,6 +226,25 @@ func mustReadClientPingConnString(t *testing.T, clientConn net.Conn, expectedPay
mustReadClientConnString(t, pingconn.New(clientConn), expectedPayload)
}

func mustReadClientWebSocketConnString(t *testing.T, clientConn net.Conn, expectedPayload string) {
t.Helper()

for {
frame, err := ws.ReadFrame(clientConn)
require.NoError(t, err)

switch frame.Header.OpCode {
case ws.OpBinary:
require.Equal(t, expectedPayload, string(frame.Payload))
return
case ws.OpPing:
continue
default:
require.Fail(t, "does not expect WebSocket frame %v", frame)
}
}
}

// responseWriterHijacker is a mock http.ResponseWriter that also serves a
// net.Conn for http.Hijacker.
type responseWriterHijacker struct {
Expand All @@ -226,5 +271,9 @@ func newResponseWriterHijacker(w http.ResponseWriter, conn net.Conn) *responseWr

func (h *responseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h.hijackedCtxCancel()
return h.conn, nil, nil
// buf is used by gorilla websocket upgrader.
reader := bufio.NewReaderSize(nil, 10)
writer := bufio.NewWriter(h.conn)
buf := bufio.NewReadWriter(reader, writer)
return h.conn, buf, nil
}

0 comments on commit 91ce8b9

Please sign in to comment.