From dacd0353f962b4e49f19c5d087ae3d5c1e119d2d Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sat, 2 May 2020 20:54:58 +0200 Subject: [PATCH] HostClient can't switch between protocols --- client.go | 35 ++++++++--------------------------- client_test.go | 49 +++++++++++++++++++++++++++++++++++++++---------- http.go | 5 ++--- 3 files changed, 49 insertions(+), 40 deletions(-) diff --git a/client.go b/client.go index d5f5e2ed23..4b6c2f0cb5 100644 --- a/client.go +++ b/client.go @@ -881,6 +881,9 @@ var ( // ErrTooManyRedirects is returned by clients when the number of redirects followed // exceed the max count. ErrTooManyRedirects = errors.New("too many redirects detected when doing the request") + + // HostClients are only able to follow redirects to the same protocol. + ErrHostClientRedirectToDifferentScheme = errors.New("HostClient can't follow redirects to a different protocol, please use Client instead") ) const defaultMaxRedirectsCount = 16 @@ -903,27 +906,11 @@ func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clie } func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer) (statusCode int, body []byte, err error) { - scheme := req.uri.Scheme() - req.schemaUpdate = false redirectsCount := 0 for { - // In case redirect to different scheme - if redirectsCount > 0 && !bytes.Equal(scheme, req.uri.Scheme()) { - if strings.HasPrefix(url, string(strHTTPS)) { - req.isTLS = true - req.uri.SetSchemeBytes(strHTTPS) - } else { - req.isTLS = false - req.uri.SetSchemeBytes(strHTTP) - } - scheme = req.uri.Scheme() - req.schemaUpdate = true - } - - req.parsedURI = false - req.Header.host = req.Header.host[:0] req.SetRequestURI(url) + req.parseURI() if err = c.Do(req, resp); err != nil { break @@ -1271,6 +1258,10 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) panic("BUG: resp cannot be nil") } + if c.IsTLS != bytes.Equal(req.uri.Scheme(), strHTTPS) { + return false, ErrHostClientRedirectToDifferentScheme + } + atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix)) // Free up resources occupied by response before sending the request, @@ -1285,16 +1276,6 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) req.URI().DisablePathNormalizing = true } - // If we detected a redirect to another schema - if req.schemaUpdate { - c.IsTLS = bytes.Equal(req.URI().Scheme(), strHTTPS) - c.Addr = addMissingPort(string(req.Host()), c.IsTLS) - c.addrIdx = 0 - c.addrs = nil - req.schemaUpdate = false - req.SetConnectionClose() - } - cc, err := c.acquireConn(req.timeout) if err != nil { return false, err diff --git a/client_test.go b/client_test.go index 3f9bc775eb..4ba78c334c 100644 --- a/client_test.go +++ b/client_test.go @@ -245,7 +245,7 @@ func TestClientRedirectSameSchema(t *testing.T) { urlParsed, err := url.Parse(destURL) if err != nil { - fmt.Println(err) + t.Fatal(err) return } @@ -270,7 +270,7 @@ func TestClientRedirectSameSchema(t *testing.T) { } -func TestClientRedirectChangingSchemaHttp2Https(t *testing.T) { +func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) { t.Parallel() listenHTTPS := testClientRedirectListener(t, true) @@ -287,14 +287,7 @@ func TestClientRedirectChangingSchemaHttp2Https(t *testing.T) { destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String()) - urlParsed, err := url.Parse(destURL) - if err != nil { - fmt.Println(err) - return - } - - reqClient := &HostClient{ - Addr: urlParsed.Host, + reqClient := &Client{ TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, @@ -312,6 +305,42 @@ func TestClientRedirectChangingSchemaHttp2Https(t *testing.T) { } } +func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) { + t.Parallel() + + listenHTTPS := testClientRedirectListener(t, true) + defer listenHTTPS.Close() + + listenHTTP := testClientRedirectListener(t, false) + defer listenHTTP.Close() + + sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true) + defer sHTTPS.Stop() + + sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false) + defer sHTTP.Stop() + + destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String()) + + urlParsed, err := url.Parse(destURL) + if err != nil { + t.Fatal(err) + return + } + + reqClient := &HostClient{ + Addr: urlParsed.Host, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + + _, _, err = reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond) + if err != ErrHostClientRedirectToDifferentScheme { + t.Fatal("expected HostClient error") + } +} + func testClientRedirectListener(t *testing.T, isTLS bool) net.Listener { var ln net.Listener var err error diff --git a/http.go b/http.go index 3869b0ac36..fe137cea66 100644 --- a/http.go +++ b/http.go @@ -46,11 +46,10 @@ type Request struct { keepBodyBuffer bool + // Used by Server to indicate the request was received on a HTTPS endpoint. + // Client/HostClient shouldn't use this field but should depend on the uri.scheme instead. isTLS bool - // To detect scheme changes in redirects - schemaUpdate bool - // Request timeout. Usually set by DoDealine or DoTimeout // if <= 0, means not set timeout time.Duration