diff --git a/proxy/http_integration_test.go b/proxy/http_integration_test.go index 50a28fceb..8c1eb9b50 100644 --- a/proxy/http_integration_test.go +++ b/proxy/http_integration_test.go @@ -193,29 +193,14 @@ func TestProxyLogOutput(t *testing.T) { } func TestProxyHTTPSUpstream(t *testing.T) { - var err error - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("OK")) - })) - server.TLS = &tls.Config{ - Certificates: make([]tls.Certificate, 1), - } - server.TLS.Certificates[0], err = tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) - if err != nil { - t.Fatalf("failed to set cert") - } + server := httptest.NewUnstartedServer(okHandler) + server.TLS = tlsServerConfig() server.StartTLS() defer server.Close() - rootCAs := x509.NewCertPool() - if ok := rootCAs.AppendCertsFromPEM(internal.LocalhostCert); !ok { - t.Fatal("could not parse cert") - } proxy := httptest.NewServer(&HTTPProxy{ - Config: config.Proxy{}, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{RootCAs: rootCAs}, - }, + Config: config.Proxy{}, + Transport: &http.Transport{TLSClientConfig: tlsClientConfig()}, Lookup: func(r *http.Request) *route.Target { tbl, _ := route.NewTable("route add srv / " + server.URL + ` opts "proto=https"`) return tbl.Lookup(r, "", route.Picker["rr"], route.Matcher["prefix"]) @@ -233,9 +218,7 @@ func TestProxyHTTPSUpstream(t *testing.T) { } func TestProxyHTTPSUpstreamSkipVerify(t *testing.T) { - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("OK")) - })) + server := httptest.NewUnstartedServer(okHandler) server.TLS = &tls.Config{} server.StartTLS() defer server.Close() @@ -365,6 +348,30 @@ func gzipHandler(contentType string) http.HandlerFunc { } } +var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) +}) + +func tlsInsecureConfig() *tls.Config { + return &tls.Config{InsecureSkipVerify: true} +} + +func tlsClientConfig() *tls.Config { + rootCAs := x509.NewCertPool() + if ok := rootCAs.AppendCertsFromPEM(internal.LocalhostCert); !ok { + panic("could not parse cert") + } + return &tls.Config{RootCAs: rootCAs} +} + +func tlsServerConfig() *tls.Config { + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) + if err != nil { + panic("failed to set cert") + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} + func mustParse(rawurl string) *url.URL { u, err := url.Parse(rawurl) if err != nil { diff --git a/proxy/http_proxy.go b/proxy/http_proxy.go index 45af4e030..c699c9b6a 100644 --- a/proxy/http_proxy.go +++ b/proxy/http_proxy.go @@ -1,6 +1,8 @@ package proxy import ( + "crypto/tls" + "net" "net/http" "net/url" "strconv" @@ -102,7 +104,13 @@ func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { var h http.Handler switch { case upgrade == "websocket" || upgrade == "Websocket": - h = newRawProxy(targetURL) + if targetURL.Scheme == "https" || targetURL.Scheme == "wss" { + h = newRawProxy(targetURL, func(network, address string) (net.Conn, error) { + return tls.Dial(network, address, transport.(*http.Transport).TLSClientConfig) + }) + } else { + h = newRawProxy(targetURL, net.Dial) + } case accept == "text/event-stream": // use the flush interval for SSE (server-sent events) diff --git a/proxy/http_raw_handler.go b/proxy/http_raw_handler.go index 9edbfde14..a92d0ef86 100644 --- a/proxy/http_raw_handler.go +++ b/proxy/http_raw_handler.go @@ -13,10 +13,12 @@ import ( // conn measures the number of open web socket connections var conn = metrics.DefaultRegistry.GetCounter("ws.conn") +type dialFunc func(network, address string) (net.Conn, error) + // newRawProxy returns an HTTP handler which forwards data between // an incoming and outgoing TCP connection including the original request. // This handler establishes a new outgoing connection per request. -func newRawProxy(t *url.URL) http.Handler { +func newRawProxy(t *url.URL, dial dialFunc) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn.Inc(1) defer func() { conn.Inc(-1) }() @@ -35,7 +37,7 @@ func newRawProxy(t *url.URL) http.Handler { } defer in.Close() - out, err := net.Dial("tcp", t.Host) + out, err := dial("tcp", t.Host) if err != nil { log.Printf("[ERROR] WS error for %s. %s", r.URL, err) http.Error(w, "error contacting backend server", http.StatusInternalServerError) diff --git a/proxy/ws_integration_test.go b/proxy/ws_integration_test.go new file mode 100644 index 000000000..878c61064 --- /dev/null +++ b/proxy/ws_integration_test.go @@ -0,0 +1,107 @@ +package proxy + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/fabiolb/fabio/config" + "github.com/fabiolb/fabio/route" + + "golang.org/x/net/websocket" +) + +func TestProxyWSUpstream(t *testing.T) { + wsServer := httptest.NewServer(websocket.Handler(wsEchoHandler)) + defer wsServer.Close() + t.Log("Started WS server: ", wsServer.URL) + + wssServer := httptest.NewUnstartedServer(websocket.Handler(wsEchoHandler)) + wssServer.TLS = tlsServerConfig() + wssServer.StartTLS() + defer wssServer.Close() + t.Log("Started WSS server: ", wssServer.URL) + + routes := "route add ws /ws " + wsServer.URL + "\n" + routes += "route add ws /wss " + wssServer.URL + ` opts "proto=https"` + "\n" + routes += "route add ws /insecure " + wssServer.URL + ` opts "proto=https tlsskipverify=true"` + + httpProxy := httptest.NewServer(&HTTPProxy{ + Config: config.Proxy{NoRouteStatus: 404}, + Transport: &http.Transport{TLSClientConfig: tlsClientConfig()}, + InsecureTransport: &http.Transport{TLSClientConfig: tlsInsecureConfig()}, + Lookup: func(r *http.Request) *route.Target { + tbl, _ := route.NewTable(routes) + return tbl.Lookup(r, "", route.Picker["rr"], route.Matcher["prefix"]) + }, + }) + defer httpProxy.Close() + t.Log("Started HTTP proxy: ", httpProxy.URL) + + httpsProxy := httptest.NewUnstartedServer(&HTTPProxy{ + Config: config.Proxy{NoRouteStatus: 404}, + Transport: &http.Transport{TLSClientConfig: tlsClientConfig()}, + InsecureTransport: &http.Transport{TLSClientConfig: tlsInsecureConfig()}, + Lookup: func(r *http.Request) *route.Target { + tbl, _ := route.NewTable(routes) + return tbl.Lookup(r, "", route.Picker["rr"], route.Matcher["prefix"]) + }, + }) + httpsProxy.TLS = tlsServerConfig() + httpsProxy.StartTLS() + defer httpsProxy.Close() + t.Log("Started HTTPS proxy: ", httpsProxy.URL) + + wsServerURL := wsServer.URL[len("http://"):] + wssServerURL := wssServer.URL[len("https://"):] + httpProxyURL := httpProxy.URL[len("http://"):] + httpsProxyURL := httpsProxy.URL[len("https://"):] + + t.Run("ws-ws direct", func(t *testing.T) { testWSEcho(t, "ws://"+wsServerURL+"/ws") }) + t.Run("wss-wss direct", func(t *testing.T) { testWSEcho(t, "wss://"+wssServerURL+"/wss") }) + + t.Run("ws-ws via http proxy", func(t *testing.T) { testWSEcho(t, "ws://"+httpProxyURL+"/ws") }) + t.Run("wss-ws via https proxy", func(t *testing.T) { testWSEcho(t, "wss://"+httpsProxyURL+"/ws") }) + + t.Run("ws-wss via http proxy", func(t *testing.T) { testWSEcho(t, "ws://"+httpProxyURL+"/wss") }) + t.Run("wss-wss via https proxy", func(t *testing.T) { testWSEcho(t, "wss://"+httpsProxyURL+"/wss") }) + + t.Run("ws-wss tlsskipverify=true via http proxy", func(t *testing.T) { testWSEcho(t, "ws://"+httpProxyURL+"/insecure") }) + t.Run("wss-wss tlsskipverify=true via https proxy", func(t *testing.T) { testWSEcho(t, "wss://"+httpsProxyURL+"/insecure") }) +} + +func testWSEcho(t *testing.T, url string) { + cfg, err := websocket.NewConfig(url, "http://localhost/") + if err != nil { + t.Fatalf("NewConfig: ", err) + } + if strings.HasPrefix(url, "wss://") { + cfg.TlsConfig = tlsClientConfig() + } + ws, err := websocket.DialConfig(cfg) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + send := []byte("foo") + if _, err := ws.Write([]byte("foo")); err != nil { + t.Logf("ws.Write failed: %s", err) + } + recv := make([]byte, 100) + n, err := ws.Read(recv) + if err != nil { + t.Logf("ws.Read failed: %s", err) + } + recv = recv[:n] + if got, want := recv, send; !bytes.Equal(got, want) { + t.Fatalf("got %q want %q", got, want) + } +} + +func wsEchoHandler(ws *websocket.Conn) { + io.Copy(ws, ws) +}