Skip to content

Commit

Permalink
TLS Routing native WebSocket connection upgrade support (#36343)
Browse files Browse the repository at this point in the history
* TLS routing connection upgrade using native websocket

* update ut in api

* lib/web/UT update

* fix typo, lint and race

* deal with subprotocol negotiation

* review comments round 1

* fix lint?

* add UT and address some other comments

* add env var to toggle mode

* fix lint
  • Loading branch information
greedy52 committed Feb 12, 2024
1 parent aafac53 commit fcb3aff
Show file tree
Hide file tree
Showing 13 changed files with 736 additions and 48 deletions.
63 changes: 54 additions & 9 deletions api/client/alpn_conn_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ func isUnadvertisedALPNError(err error) bool {
// OverwriteALPNConnUpgradeRequirementByEnv overwrites ALPN connection upgrade
// requirement by environment variable.
//
// TODO(greedy52) DELETE in 15.0
// TODO(greedy52) DELETE in ??. Note that this toggle was planned to be deleted
// in 15.0 when the feature exits preview. However, many users still rely on
// this manual toggle as IsALPNConnUpgradeRequired cannot detect many
// situations where connection upgrade is required. This can be deleted once
// IsALPNConnUpgradeRequired is improved.
func OverwriteALPNConnUpgradeRequirementByEnv(addr string) (bool, bool) {
envValue := os.Getenv(defaults.TLSRoutingConnUpgradeEnvVar)
if envValue == "" {
Expand Down Expand Up @@ -184,8 +188,6 @@ func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config, withP

// DialContext implements ContextDialer
func (d *alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
logrus.Debugf("ALPN connection upgrade for %v.", addr)

tlsConn, err := tlsutils.TLSDial(ctx, d.dialer, network, addr, d.tlsConfig.Clone())
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -210,14 +212,28 @@ func (d *alpnConnUpgradeDialer) upgradeType() string {
return constants.WebAPIConnUpgradeTypeALPN
}

func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (net.Conn, error) {
func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string) (net.Conn, error) {
req, err := http.NewRequest(http.MethodGet, api.String(), nil)
if err != nil {
return nil, trace.Wrap(err)
}

req.Header.Add(constants.WebAPIConnUpgradeHeader, upgradeType)
req.Header.Add(constants.WebAPIConnUpgradeTeleportHeader, upgradeType)
challengeKey, err := generateWebSocketChallengeKey()
if err != nil {
return nil, trace.Wrap(err)
}

// Prefer "websocket".
if useConnUpgradeMode.useWebSocket() {
applyWebSocketUpgradeHeaders(req, alpnUpgradeType, challengeKey)
}

// Append "legacy" custom upgrade type.
// TODO(greedy52) DELETE in 17.0
if useConnUpgradeMode.useLegacy() {
req.Header.Add(constants.WebAPIConnUpgradeHeader, alpnUpgradeType)
req.Header.Add(constants.WebAPIConnUpgradeTeleportHeader, alpnUpgradeType)
}

// Set "Connection" header to meet RFC spec:
// https://datatracker.ietf.org/doc/html/rfc2616#section-14.42
Expand All @@ -229,7 +245,7 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (n
// require this header to be set to complete the upgrade flow. The header
// must be set on both the upgrade request here and the 101 Switching
// Protocols response from the server.
req.Header.Add(constants.WebAPIConnUpgradeConnectionHeader, constants.WebAPIConnUpgradeConnectionType)
req.Header.Set(constants.WebAPIConnUpgradeConnectionHeader, constants.WebAPIConnUpgradeConnectionType)

// Send the request and check if upgrade is successful.
if err = req.Write(conn); err != nil {
Expand All @@ -246,15 +262,44 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (n
return nil, trace.NotImplemented(
"connection upgrade call to %q with upgrade type %v failed with status code %v. Please upgrade the server and try again.",
constants.WebAPIConnUpgrade,
upgradeType,
alpnUpgradeType,
resp.StatusCode,
)
}
return nil, trace.BadParameter("failed to switch Protocols %v", resp.StatusCode)
}

if upgradeType == constants.WebAPIConnUpgradeTypeALPNPing {
// Handle WebSocket.
if resp.Header.Get(constants.WebAPIConnUpgradeHeader) == constants.WebAPIConnUpgradeTypeWebSocket {
if err := checkWebSocketUpgradeResponse(resp, alpnUpgradeType, challengeKey); err != nil {
return nil, trace.Wrap(err)
}

logrus.WithField("hostname", api.Host).Debug("Performing ALPN WebSocket connection upgrade.")
return newWebSocketALPNClientConn(conn), nil
}

// Handle "legacy".
// TODO(greedy52) DELETE in 17.0.
logrus.WithField("hostname", api.Host).Debug("Performing ALPN legacy connection upgrade.")
if alpnUpgradeType == constants.WebAPIConnUpgradeTypeALPNPing {
return pingconn.New(conn), nil
}
return conn, nil
}

type connUpgradeMode string

func (m connUpgradeMode) useWebSocket() bool {
// Use WebSocket as long as it's not legacy only.
return strings.ToLower(string(m)) != "legacy"
}

func (m connUpgradeMode) useLegacy() bool {
// Use legacy as long as it's not WebSocket only.
return strings.ToLower(string(m)) != "websocket"
}

var (
useConnUpgradeMode connUpgradeMode = connUpgradeMode(os.Getenv(defaults.TLSRoutingConnUpgradeModeEnvVar))
)
136 changes: 124 additions & 12 deletions api/client/alpn_conn_upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"net"
"net/http"
Expand All @@ -28,6 +29,7 @@ import (
"testing"
"time"

"github.com/gobwas/ws"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -164,12 +166,23 @@ func TestALPNConnUpgradeDialer(t *testing.T) {
wantError bool
}{
{
name: "connection upgrade",
serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
// TODO(greedy52) DELETE in 17.0
name: "connection upgrade (legacy)",
serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
},
{
name: "connection upgrade with ping",
serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
// TODO(greedy52) DELETE in 17.0
name: "connection upgrade with ping (legacy)",
serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
withPing: true,
},
{
name: "connection upgrade (WebSocket)",
serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
},
{
name: "connection upgrade with ping (WebSocket)",
serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
withPing: true,
},
{
Expand Down Expand Up @@ -230,11 +243,27 @@ func TestALPNConnUpgradeDialer(t *testing.T) {
}

func mustReadConnData(t *testing.T, conn net.Conn, wantText string) {
data := make([]byte, len(wantText)*2)
t.Helper()

require.NotEmpty(t, wantText)

// Use a small buffer.
bufferSize := len(wantText) - 1
data := make([]byte, bufferSize)
n, err := conn.Read(data)
require.NoError(t, err)
require.Len(t, wantText, n)
require.Equal(t, wantText, string(data[:n]))
require.Equal(t, bufferSize, n)
actualText := string(data)

// Now read it again to get the full text. This tests
// websocketALPNClientConn.readBuffer is implemented correctly.
data = make([]byte, bufferSize)
n, err = conn.Read(data)
require.NoError(t, err)
require.Equal(t, 1, n)
actualText += string(data[:1])

require.Equal(t, wantText, actualText)
}

type mockALPNServer struct {
Expand Down Expand Up @@ -291,15 +320,15 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe
return m
}

// mockConnUpgradeHandler mocks the server side implementation to handle an
// upgrade request and sends back some data inside the tunnel.
func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
// mockLegacyConnUpgradeHandler mocks the server side implementation to handle
// an upgrade request and sends back some data inside the tunnel.
func mockLegacyConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
t.Helper()

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path)
require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeHeader))
require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeTeleportHeader))
require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), upgradeType)
require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeTeleportHeader), upgradeType)
require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader))

hj, ok := w.(http.Hijacker)
Expand Down Expand Up @@ -334,6 +363,49 @@ func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http
})
}

// mockWebSocketConnUpgradeHandler mocks the server side implementation to handle
// a WebSocket upgrade request and sends back some data inside the tunnel.
func mockWebSocketConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
t.Helper()

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path)
require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), "websocket")
require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader))
require.Equal(t, upgradeType, r.Header.Get("Sec-Websocket-Protocol"))
require.Equal(t, "13", r.Header.Get("Sec-Websocket-Version"))

challengeKey := r.Header.Get("Sec-Websocket-Key")
challengeKeyDecoded, err := base64.StdEncoding.DecodeString(challengeKey)
require.NoError(t, err)
require.Len(t, challengeKeyDecoded, 16)

hj, ok := w.(http.Hijacker)
require.True(t, ok)

conn, _, err := hj.Hijack()
require.NoError(t, err)
defer conn.Close()

// Upgrade response.
response := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
}
response.Header.Set("Upgrade", "websocket")
response.Header.Set("Sec-WebSocket-Protocol", upgradeType)
response.Header.Set("Sec-WebSocket-Accept", computeWebSocketAcceptKey(challengeKey))
require.NoError(t, response.Write(conn))

// Upgraded.
frame := ws.NewFrame(ws.OpBinary, true, write)
frame.Header.Masked = true
require.NoError(t, ws.WriteFrame(conn, frame))
})
}

func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) {
t.Helper()

Expand All @@ -350,3 +422,43 @@ func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) {
go http.Serve(listener, handler)
return handler, url
}

func Test_connUpgradeMode(t *testing.T) {
tests := []struct {
envVarValue string
wantUseWebSocket require.BoolAssertionFunc
wantUseLegacy require.BoolAssertionFunc
}{
{
envVarValue: "",
wantUseWebSocket: require.True,
wantUseLegacy: require.True,
},
{
envVarValue: "WebSocket",
wantUseWebSocket: require.True,
wantUseLegacy: require.False,
},
{
envVarValue: "websocket",
wantUseWebSocket: require.True,
wantUseLegacy: require.False,
},
{
envVarValue: "legacy",
wantUseWebSocket: require.False,
wantUseLegacy: require.True,
},
{
envVarValue: "default",
wantUseWebSocket: require.True,
wantUseLegacy: require.True,
},
}

for _, test := range tests {
mode := connUpgradeMode(test.envVarValue)
test.wantUseWebSocket(t, mode.useWebSocket())
test.wantUseLegacy(t, mode.useLegacy())
}
}
Loading

0 comments on commit fcb3aff

Please sign in to comment.