Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLS Routing native WebSocket connection upgrade support #36343

Merged
merged 11 commits into from
Feb 12, 2024
63 changes: 54 additions & 9 deletions api/client/alpn_conn_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ func isUnadvertisedALPNError(err error) bool {
// OverwriteALPNConnUpgradeRequirementByEnv overwrites ALPN connection upgrade
// requirement by environment variable.
//
// TODO(greedy52) DELETE in 15.0
// TODO(greedy52) DELETE in ??. Note that this toggle was planned to be deleted
// in 15.0 when the feature exits preview. However, many users still rely on
// this manual toggle as IsALPNConnUpgradeRequired cannot detect many
// situations where connection upgrade is required. This can be deleted once
// IsALPNConnUpgradeRequired is improved.
func OverwriteALPNConnUpgradeRequirementByEnv(addr string) (bool, bool) {
envValue := os.Getenv(defaults.TLSRoutingConnUpgradeEnvVar)
if envValue == "" {
Expand Down Expand Up @@ -184,8 +188,6 @@ func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config, withP

// DialContext implements ContextDialer
func (d *alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
logrus.Debugf("ALPN connection upgrade for %v.", addr)

tlsConn, err := tlsutils.TLSDial(ctx, d.dialer, network, addr, d.tlsConfig.Clone())
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -210,14 +212,28 @@ func (d *alpnConnUpgradeDialer) upgradeType() string {
return constants.WebAPIConnUpgradeTypeALPN
}

func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (net.Conn, error) {
func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string) (net.Conn, error) {
req, err := http.NewRequest(http.MethodGet, api.String(), nil)
if err != nil {
return nil, trace.Wrap(err)
}

req.Header.Add(constants.WebAPIConnUpgradeHeader, upgradeType)
req.Header.Add(constants.WebAPIConnUpgradeTeleportHeader, upgradeType)
challengeKey, err := generateWebSocketChallengeKey()
if err != nil {
return nil, trace.Wrap(err)
}

// Prefer "websocket".
if useConnUpgradeMode.useWebSocket() {
applyWebSocketUpgradeHeaders(req, alpnUpgradeType, challengeKey)
}

// Append "legacy" custom upgrade type.
// TODO(greedy52) DELETE in 17.0
if useConnUpgradeMode.useLegacy() {
req.Header.Add(constants.WebAPIConnUpgradeHeader, alpnUpgradeType)
req.Header.Add(constants.WebAPIConnUpgradeTeleportHeader, alpnUpgradeType)
}

// Set "Connection" header to meet RFC spec:
// https://datatracker.ietf.org/doc/html/rfc2616#section-14.42
Expand All @@ -229,7 +245,7 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (n
// require this header to be set to complete the upgrade flow. The header
// must be set on both the upgrade request here and the 101 Switching
// Protocols response from the server.
req.Header.Add(constants.WebAPIConnUpgradeConnectionHeader, constants.WebAPIConnUpgradeConnectionType)
req.Header.Set(constants.WebAPIConnUpgradeConnectionHeader, constants.WebAPIConnUpgradeConnectionType)

// Send the request and check if upgrade is successful.
if err = req.Write(conn); err != nil {
Expand All @@ -246,15 +262,44 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (n
return nil, trace.NotImplemented(
"connection upgrade call to %q with upgrade type %v failed with status code %v. Please upgrade the server and try again.",
constants.WebAPIConnUpgrade,
upgradeType,
alpnUpgradeType,
resp.StatusCode,
)
}
return nil, trace.BadParameter("failed to switch Protocols %v", resp.StatusCode)
}

if upgradeType == constants.WebAPIConnUpgradeTypeALPNPing {
// Handle WebSocket.
if resp.Header.Get(constants.WebAPIConnUpgradeHeader) == constants.WebAPIConnUpgradeTypeWebSocket {
if err := checkWebSocketUpgradeResponse(resp, alpnUpgradeType, challengeKey); err != nil {
return nil, trace.Wrap(err)
}

logrus.WithField("hostname", api.Host).Debug("Performing ALPN WebSocket connection upgrade.")
return newWebSocketALPNClientConn(conn), nil
}

// Handle "legacy".
// TODO(greedy52) DELETE in 17.0.
logrus.WithField("hostname", api.Host).Debug("Performing ALPN legacy connection upgrade.")
if alpnUpgradeType == constants.WebAPIConnUpgradeTypeALPNPing {
return pingconn.New(conn), nil
}
return conn, nil
}

type connUpgradeMode string

func (m connUpgradeMode) useWebSocket() bool {
// Use WebSocket as long as it's not legacy only.
return strings.ToLower(string(m)) != "legacy"
}

func (m connUpgradeMode) useLegacy() bool {
// Use legacy as long as it's not WebSocket only.
return strings.ToLower(string(m)) != "websocket"
}

var (
useConnUpgradeMode connUpgradeMode = connUpgradeMode(os.Getenv(defaults.TLSRoutingConnUpgradeModeEnvVar))
)
136 changes: 124 additions & 12 deletions api/client/alpn_conn_upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"net"
"net/http"
Expand All @@ -28,6 +29,7 @@ import (
"testing"
"time"

"github.com/gobwas/ws"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -164,12 +166,23 @@ func TestALPNConnUpgradeDialer(t *testing.T) {
wantError bool
}{
{
name: "connection upgrade",
serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
// TODO(greedy52) DELETE in 17.0
name: "connection upgrade (legacy)",
serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
},
{
name: "connection upgrade with ping",
serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
// TODO(greedy52) DELETE in 17.0
name: "connection upgrade with ping (legacy)",
serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
withPing: true,
},
{
name: "connection upgrade (WebSocket)",
serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
},
{
name: "connection upgrade with ping (WebSocket)",
serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
withPing: true,
},
{
Expand Down Expand Up @@ -230,11 +243,27 @@ func TestALPNConnUpgradeDialer(t *testing.T) {
}

func mustReadConnData(t *testing.T, conn net.Conn, wantText string) {
data := make([]byte, len(wantText)*2)
t.Helper()

require.NotEmpty(t, wantText)

// Use a small buffer.
bufferSize := len(wantText) - 1
data := make([]byte, bufferSize)
n, err := conn.Read(data)
require.NoError(t, err)
require.Len(t, wantText, n)
require.Equal(t, wantText, string(data[:n]))
require.Equal(t, bufferSize, n)
actualText := string(data)

// Now read it again to get the full text. This tests
// websocketALPNClientConn.readBuffer is implemented correctly.
data = make([]byte, bufferSize)
n, err = conn.Read(data)
require.NoError(t, err)
require.Equal(t, 1, n)
actualText += string(data[:1])

require.Equal(t, wantText, actualText)
}

type mockALPNServer struct {
Expand Down Expand Up @@ -291,15 +320,15 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe
return m
}

// mockConnUpgradeHandler mocks the server side implementation to handle an
// upgrade request and sends back some data inside the tunnel.
func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
// mockLegacyConnUpgradeHandler mocks the server side implementation to handle
// an upgrade request and sends back some data inside the tunnel.
func mockLegacyConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
t.Helper()

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path)
require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeHeader))
require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeTeleportHeader))
require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), upgradeType)
require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeTeleportHeader), upgradeType)
require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader))

hj, ok := w.(http.Hijacker)
Expand Down Expand Up @@ -334,6 +363,49 @@ func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http
})
}

// mockWebSocketConnUpgradeHandler mocks the server side implementation to handle
// a WebSocket upgrade request and sends back some data inside the tunnel.
func mockWebSocketConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
t.Helper()

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path)
require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), "websocket")
require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader))
require.Equal(t, upgradeType, r.Header.Get("Sec-Websocket-Protocol"))
require.Equal(t, "13", r.Header.Get("Sec-Websocket-Version"))

challengeKey := r.Header.Get("Sec-Websocket-Key")
challengeKeyDecoded, err := base64.StdEncoding.DecodeString(challengeKey)
require.NoError(t, err)
require.Len(t, challengeKeyDecoded, 16)

hj, ok := w.(http.Hijacker)
require.True(t, ok)

conn, _, err := hj.Hijack()
require.NoError(t, err)
defer conn.Close()

// Upgrade response.
response := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
}
response.Header.Set("Upgrade", "websocket")
response.Header.Set("Sec-WebSocket-Protocol", upgradeType)
response.Header.Set("Sec-WebSocket-Accept", computeWebSocketAcceptKey(challengeKey))
require.NoError(t, response.Write(conn))

// Upgraded.
frame := ws.NewFrame(ws.OpBinary, true, write)
frame.Header.Masked = true
require.NoError(t, ws.WriteFrame(conn, frame))
})
}

func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) {
t.Helper()

Expand All @@ -350,3 +422,43 @@ func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) {
go http.Serve(listener, handler)
return handler, url
}

func Test_connUpgradeMode(t *testing.T) {
tests := []struct {
envVarValue string
wantUseWebSocket require.BoolAssertionFunc
wantUseLegacy require.BoolAssertionFunc
}{
{
envVarValue: "",
wantUseWebSocket: require.True,
wantUseLegacy: require.True,
},
{
envVarValue: "WebSocket",
wantUseWebSocket: require.True,
wantUseLegacy: require.False,
},
{
envVarValue: "websocket",
wantUseWebSocket: require.True,
wantUseLegacy: require.False,
},
{
envVarValue: "legacy",
wantUseWebSocket: require.False,
wantUseLegacy: require.True,
},
{
envVarValue: "default",
wantUseWebSocket: require.True,
wantUseLegacy: require.True,
},
}

for _, test := range tests {
mode := connUpgradeMode(test.envVarValue)
test.wantUseWebSocket(t, mode.useWebSocket())
test.wantUseLegacy(t, mode.useLegacy())
}
}
Loading
Loading