Skip to content

Commit

Permalink
Issue #271: Support websocket for HTTPS upstream
Browse files Browse the repository at this point in the history
This patch adds support for websockets on HTTPS
upstream servers.

Fixes #271
  • Loading branch information
magiconair committed Apr 28, 2017
1 parent 77a489c commit 94beb91
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 25 deletions.
51 changes: 29 additions & 22 deletions proxy/http_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,29 +193,14 @@ func TestProxyLogOutput(t *testing.T) {
}

func TestProxyHTTPSUpstream(t *testing.T) {
var err error
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
server.TLS = &tls.Config{
Certificates: make([]tls.Certificate, 1),
}
server.TLS.Certificates[0], err = tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
if err != nil {
t.Fatalf("failed to set cert")
}
server := httptest.NewUnstartedServer(okHandler)
server.TLS = tlsServerConfig()
server.StartTLS()
defer server.Close()

rootCAs := x509.NewCertPool()
if ok := rootCAs.AppendCertsFromPEM(internal.LocalhostCert); !ok {
t.Fatal("could not parse cert")
}
proxy := httptest.NewServer(&HTTPProxy{
Config: config.Proxy{},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{RootCAs: rootCAs},
},
Config: config.Proxy{},
Transport: &http.Transport{TLSClientConfig: tlsClientConfig()},
Lookup: func(r *http.Request) *route.Target {
tbl, _ := route.NewTable("route add srv / " + server.URL + ` opts "proto=https"`)
return tbl.Lookup(r, "", route.Picker["rr"], route.Matcher["prefix"])
Expand All @@ -233,9 +218,7 @@ func TestProxyHTTPSUpstream(t *testing.T) {
}

func TestProxyHTTPSUpstreamSkipVerify(t *testing.T) {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
server := httptest.NewUnstartedServer(okHandler)
server.TLS = &tls.Config{}
server.StartTLS()
defer server.Close()
Expand Down Expand Up @@ -365,6 +348,30 @@ func gzipHandler(contentType string) http.HandlerFunc {
}
}

var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
})

func tlsInsecureConfig() *tls.Config {
return &tls.Config{InsecureSkipVerify: true}
}

func tlsClientConfig() *tls.Config {
rootCAs := x509.NewCertPool()
if ok := rootCAs.AppendCertsFromPEM(internal.LocalhostCert); !ok {
panic("could not parse cert")
}
return &tls.Config{RootCAs: rootCAs}
}

func tlsServerConfig() *tls.Config {
cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
if err != nil {
panic("failed to set cert")
}
return &tls.Config{Certificates: []tls.Certificate{cert}}
}

func mustParse(rawurl string) *url.URL {
u, err := url.Parse(rawurl)
if err != nil {
Expand Down
10 changes: 9 additions & 1 deletion proxy/http_proxy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package proxy

import (
"crypto/tls"
"net"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -102,7 +104,13 @@ func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var h http.Handler
switch {
case upgrade == "websocket" || upgrade == "Websocket":
h = newRawProxy(targetURL)
if targetURL.Scheme == "https" || targetURL.Scheme == "wss" {
h = newRawProxy(targetURL, func(network, address string) (net.Conn, error) {
return tls.Dial(network, address, transport.(*http.Transport).TLSClientConfig)
})
} else {
h = newRawProxy(targetURL, net.Dial)
}

case accept == "text/event-stream":
// use the flush interval for SSE (server-sent events)
Expand Down
6 changes: 4 additions & 2 deletions proxy/http_raw_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
// conn measures the number of open web socket connections
var conn = metrics.DefaultRegistry.GetCounter("ws.conn")

type dialFunc func(network, address string) (net.Conn, error)

// newRawProxy returns an HTTP handler which forwards data between
// an incoming and outgoing TCP connection including the original request.
// This handler establishes a new outgoing connection per request.
func newRawProxy(t *url.URL) http.Handler {
func newRawProxy(t *url.URL, dial dialFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn.Inc(1)
defer func() { conn.Inc(-1) }()
Expand All @@ -35,7 +37,7 @@ func newRawProxy(t *url.URL) http.Handler {
}
defer in.Close()

out, err := net.Dial("tcp", t.Host)
out, err := dial("tcp", t.Host)
if err != nil {
log.Printf("[ERROR] WS error for %s. %s", r.URL, err)
http.Error(w, "error contacting backend server", http.StatusInternalServerError)
Expand Down
107 changes: 107 additions & 0 deletions proxy/ws_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package proxy

import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/fabiolb/fabio/config"
"github.com/fabiolb/fabio/route"

"golang.org/x/net/websocket"
)

func TestProxyWSUpstream(t *testing.T) {
wsServer := httptest.NewServer(websocket.Handler(wsEchoHandler))
defer wsServer.Close()
t.Log("Started WS server: ", wsServer.URL)

wssServer := httptest.NewUnstartedServer(websocket.Handler(wsEchoHandler))
wssServer.TLS = tlsServerConfig()
wssServer.StartTLS()
defer wssServer.Close()
t.Log("Started WSS server: ", wssServer.URL)

routes := "route add ws /ws " + wsServer.URL + "\n"
routes += "route add ws /wss " + wssServer.URL + ` opts "proto=https"` + "\n"
routes += "route add ws /insecure " + wssServer.URL + ` opts "proto=https tlsskipverify=true"`

httpProxy := httptest.NewServer(&HTTPProxy{
Config: config.Proxy{NoRouteStatus: 404},
Transport: &http.Transport{TLSClientConfig: tlsClientConfig()},
InsecureTransport: &http.Transport{TLSClientConfig: tlsInsecureConfig()},
Lookup: func(r *http.Request) *route.Target {
tbl, _ := route.NewTable(routes)
return tbl.Lookup(r, "", route.Picker["rr"], route.Matcher["prefix"])
},
})
defer httpProxy.Close()
t.Log("Started HTTP proxy: ", httpProxy.URL)

httpsProxy := httptest.NewUnstartedServer(&HTTPProxy{
Config: config.Proxy{NoRouteStatus: 404},
Transport: &http.Transport{TLSClientConfig: tlsClientConfig()},
InsecureTransport: &http.Transport{TLSClientConfig: tlsInsecureConfig()},
Lookup: func(r *http.Request) *route.Target {
tbl, _ := route.NewTable(routes)
return tbl.Lookup(r, "", route.Picker["rr"], route.Matcher["prefix"])
},
})
httpsProxy.TLS = tlsServerConfig()
httpsProxy.StartTLS()
defer httpsProxy.Close()
t.Log("Started HTTPS proxy: ", httpsProxy.URL)

wsServerURL := wsServer.URL[len("http://"):]
wssServerURL := wssServer.URL[len("https://"):]
httpProxyURL := httpProxy.URL[len("http://"):]
httpsProxyURL := httpsProxy.URL[len("https://"):]

t.Run("ws-ws direct", func(t *testing.T) { testWSEcho(t, "ws://"+wsServerURL+"/ws") })
t.Run("wss-wss direct", func(t *testing.T) { testWSEcho(t, "wss://"+wssServerURL+"/wss") })

t.Run("ws-ws via http proxy", func(t *testing.T) { testWSEcho(t, "ws://"+httpProxyURL+"/ws") })
t.Run("wss-ws via https proxy", func(t *testing.T) { testWSEcho(t, "wss://"+httpsProxyURL+"/ws") })

t.Run("ws-wss via http proxy", func(t *testing.T) { testWSEcho(t, "ws://"+httpProxyURL+"/wss") })
t.Run("wss-wss via https proxy", func(t *testing.T) { testWSEcho(t, "wss://"+httpsProxyURL+"/wss") })

t.Run("ws-wss tlsskipverify=true via http proxy", func(t *testing.T) { testWSEcho(t, "ws://"+httpProxyURL+"/insecure") })
t.Run("wss-wss tlsskipverify=true via https proxy", func(t *testing.T) { testWSEcho(t, "wss://"+httpsProxyURL+"/insecure") })
}

func testWSEcho(t *testing.T, url string) {
cfg, err := websocket.NewConfig(url, "http://localhost/")
if err != nil {
t.Fatalf("NewConfig: ", err)
}
if strings.HasPrefix(url, "wss://") {
cfg.TlsConfig = tlsClientConfig()
}
ws, err := websocket.DialConfig(cfg)
if err != nil {
t.Fatal(err)
}
defer ws.Close()

send := []byte("foo")
if _, err := ws.Write([]byte("foo")); err != nil {
t.Logf("ws.Write failed: %s", err)
}
recv := make([]byte, 100)
n, err := ws.Read(recv)
if err != nil {
t.Logf("ws.Read failed: %s", err)
}
recv = recv[:n]
if got, want := recv, send; !bytes.Equal(got, want) {
t.Fatalf("got %q want %q", got, want)
}
}

func wsEchoHandler(ws *websocket.Conn) {
io.Copy(ws, ws)
}

0 comments on commit 94beb91

Please sign in to comment.