Skip to content

Commit

Permalink
Review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
tcsc committed May 4, 2021
1 parent da5a48d commit fcb13d5
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 30 deletions.
12 changes: 6 additions & 6 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2336,7 +2336,7 @@ func (tc *TeleportClient) Ping(ctx context.Context) (*webclient.PingResponse, er
ctx,
tc.WebProxyAddr,
tc.InsecureSkipVerify,
LoopbackPool(tc.WebProxyAddr),
loopbackPool(tc.WebProxyAddr),
tc.AuthConnector)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -2612,7 +2612,7 @@ func (tc *TeleportClient) directLogin(ctx context.Context, secondFactorType cons
PubKey: pub,
TTL: tc.KeyTTL,
Insecure: tc.InsecureSkipVerify,
Pool: LoopbackPool(tc.WebProxyAddr),
Pool: loopbackPool(tc.WebProxyAddr),
Compatibility: tc.CertificateFormat,
RouteToCluster: tc.SiteName,
KubernetesCluster: tc.KubernetesCluster,
Expand Down Expand Up @@ -2641,7 +2641,7 @@ func (tc *TeleportClient) ssoLogin(ctx context.Context, connectorID string, pub
PubKey: pub,
TTL: tc.KeyTTL,
Insecure: tc.InsecureSkipVerify,
Pool: LoopbackPool(tc.WebProxyAddr),
Pool: loopbackPool(tc.WebProxyAddr),
Compatibility: tc.CertificateFormat,
RouteToCluster: tc.SiteName,
KubernetesCluster: tc.KubernetesCluster,
Expand All @@ -2667,7 +2667,7 @@ func (tc *TeleportClient) mfaLocalLogin(ctx context.Context, pub []byte) (*auth.
PubKey: pub,
TTL: tc.KeyTTL,
Insecure: tc.InsecureSkipVerify,
Pool: LoopbackPool(tc.WebProxyAddr),
Pool: loopbackPool(tc.WebProxyAddr),
Compatibility: tc.CertificateFormat,
RouteToCluster: tc.SiteName,
KubernetesCluster: tc.KubernetesCluster,
Expand Down Expand Up @@ -2697,9 +2697,9 @@ func (tc *TeleportClient) EventsChannel() <-chan events.EventFields {
return tc.eventsCh
}

// LoopbackPool reads trusted CAs if it finds it in a predefined location
// loopbackPool reads trusted CAs if it finds it in a predefined location
// and will work only if target proxy address is loopback
func LoopbackPool(proxyAddr string) *x509.CertPool {
func loopbackPool(proxyAddr string) *x509.CertPool {
if !utils.IsLoopback(proxyAddr) {
log.Debugf("not using loopback pool for remote proxy addr: %v", proxyAddr)
return nil
Expand Down
24 changes: 21 additions & 3 deletions tool/tsh/resolve_default_addr.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package main
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
Expand All @@ -34,6 +33,17 @@ type raceResult struct {
err error
}

// nonOKResponseError indicates that the racer made contact with a server &
// issued a request but received a non-OK response. This is still
// considered a failure by the port resolution algorithm.
type nonOKResponseError struct {
Status int
}

func (err nonOKResponseError) Error() string {
return fmt.Sprintf("Non-OK response status: %03d", err.Status)
}

// raceRequest drives an HTTP request to completion and posts the results back
// to the supplied channel.
func raceRequest(ctx context.Context, cli *http.Client, addr string, results chan<- raceResult) {
Expand All @@ -45,9 +55,18 @@ func raceRequest(ctx context.Context, cli *http.Client, addr string, results cha
rsp, err = cli.Do(request)
if err == nil {
rsp.Body.Close()

// If the request returned a non-OK response then we're still going
// to treat this as a failure and return an error to the race
// aggregator.
if rsp.StatusCode != http.StatusOK {
err = nonOKResponseError{Status: rsp.StatusCode}
rsp = nil
}
}
}

// Post the results back to the caller so they can be aggregated.
results <- raceResult{addr: addr, err: err}
}

Expand All @@ -66,7 +85,7 @@ func startRacer(ctx context.Context, cli *http.Client, host string, candidates [
// 1. issues a GET request against multiple potential proxy ports,
// 2. races the requests against one another, and finally
// 3. selects the first to respond as the canonical proxy
func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []int, rootCAs *x509.CertPool) (string, error) {
func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []int) (string, error) {
log.Debugf("Resolving default proxy port (insecure: %v)", insecure)

if len(ports) == 0 {
Expand All @@ -76,7 +95,6 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in
httpClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: rootCAs,
InsecureSkipVerify: insecure,
},
},
Expand Down
58 changes: 39 additions & 19 deletions tool/tsh/resolve_default_addr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,19 @@ func newWaitForeverHandler() (http.Handler, chan interface{}) {
return handler, doneChannel
}

func newRespondingHandler() http.Handler {
func newRespondingHandlerWithStatus(status int) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
testLog.Debug("Responding")
w.Header().Add("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.WriteHeader(status)
io.WriteString(w, "Hello, world")
})
}

func newRespondingHandler() http.Handler {
return newRespondingHandlerWithStatus(http.StatusOK)
}

func mustGetCandidatePorts(servers []*httptest.Server) []int {
result := make([]int, len(servers))
for i, svr := range servers {
Expand All @@ -70,6 +74,12 @@ func mustGetCandidatePorts(servers []*httptest.Server) []int {
return result
}

func makeTestServer(t *testing.T, h http.Handler) *httptest.Server {
svr := httptest.NewTLSServer(h)
t.Cleanup(func() { svr.Close() })
return svr
}

func TestResolveDefaultAddr(t *testing.T) {
t.Parallel()

Expand All @@ -86,9 +96,7 @@ func TestResolveDefaultAddr(t *testing.T) {
if i == magicServerIndex {
handler = respondingHandler
}
svr := httptest.NewTLSServer(handler)
defer svr.Close()
servers[i] = svr
servers[i] = makeTestServer(t, handler)
}

// NB: We need to defer this channel close such that it happens *before*
Expand All @@ -100,7 +108,7 @@ func TestResolveDefaultAddr(t *testing.T) {
expectedAddr := fmt.Sprintf("127.0.0.1:%d", ports[magicServerIndex])

// When I attempt to resolve a default address
addr, err := pickDefaultAddr(context.Background(), true, "127.0.0.1", ports, nil)
addr, err := pickDefaultAddr(context.Background(), true, "127.0.0.1", ports)

// Expect that the "magic" server is selected
require.NoError(t, err)
Expand All @@ -109,7 +117,7 @@ func TestResolveDefaultAddr(t *testing.T) {

func TestResolveDefaultAddrNoCandidates(t *testing.T) {
t.Parallel()
_, err := pickDefaultAddr(context.Background(), true, "127.0.0.1", []int{}, nil)
_, err := pickDefaultAddr(context.Background(), true, "127.0.0.1", []int{})
require.Error(t, err)
}

Expand All @@ -123,16 +131,14 @@ func TestResolveDefaultAddrSingleCandidate(t *testing.T) {

servers := make([]*httptest.Server, 1)
for i := 0; i < len(servers); i++ {
svr := httptest.NewTLSServer(respondingHandler)
defer svr.Close()
servers[i] = svr
servers[i] = makeTestServer(t, respondingHandler)
}

ports := mustGetCandidatePorts(servers)
expectedAddr := fmt.Sprintf("127.0.0.1:%d", ports[0])

// When I attempt to resolve a default address
addr, err := pickDefaultAddr(context.Background(), true, "127.0.0.1", ports, nil)
addr, err := pickDefaultAddr(context.Background(), true, "127.0.0.1", ports)

// Expect that the only server is selected
require.NoError(t, err)
Expand All @@ -147,9 +153,7 @@ func TestResolveDefaultAddrTimeout(t *testing.T) {

servers := make([]*httptest.Server, 5)
for i := 0; i < 5; i++ {
svr := httptest.NewTLSServer(blockingHandler)
defer svr.Close()
servers[i] = svr
servers[i] = makeTestServer(t, blockingHandler)
}

// NB: We need to defer this channel close such that it happens *before*
Expand All @@ -162,13 +166,31 @@ func TestResolveDefaultAddrTimeout(t *testing.T) {
// When I attempt to resolve the default address with a finite timeout
ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
defer cancel()
_, err := pickDefaultAddr(ctx, true, "127.0.0.1", ports, nil)
_, err := pickDefaultAddr(ctx, true, "127.0.0.1", ports)

// Expect that the resolution will fail with `Deadline Exceeded` due to
// the call timing out.
require.Equal(t, context.DeadlineExceeded, err)
}

func TestResolveNonOKResponseIsAnError(t *testing.T) {
t.Parallel()

// Given a single candidate servers configured to respond with a non-OK status
// code
servers := []*httptest.Server{
makeTestServer(t, newRespondingHandlerWithStatus(http.StatusTeapot)),
}
ports := mustGetCandidatePorts(servers)

// When I attempt to resolve a default address
_, err := pickDefaultAddr(context.Background(), true, "127.0.0.1", ports)

// Expect that the resolution fails because the server responded with a non-OK
// response
require.ErrorIs(t, err, nonOKResponseError{Status: http.StatusTeapot})
}

func TestResolveDefaultAddrTimeoutBeforeAllRacersLaunched(t *testing.T) {
t.Parallel()

Expand All @@ -178,9 +200,7 @@ func TestResolveDefaultAddrTimeoutBeforeAllRacersLaunched(t *testing.T) {

servers := make([]*httptest.Server, 1000)
for i := 0; i < len(servers); i++ {
svr := httptest.NewTLSServer(blockingHandler)
defer svr.Close()
servers[i] = svr
servers[i] = makeTestServer(t, blockingHandler)
}

// NB: We need to defer this channel close such that it happens *before*
Expand All @@ -194,7 +214,7 @@ func TestResolveDefaultAddrTimeoutBeforeAllRacersLaunched(t *testing.T) {
// would allow for all of the racers to have been launched...
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
_, err := pickDefaultAddr(ctx, true, "127.0.0.1", ports, nil)
_, err := pickDefaultAddr(ctx, true, "127.0.0.1", ports)

// Expect that the resolution will fail with `Deadline Exceeded` due to
// the call timing out.
Expand Down
5 changes: 3 additions & 2 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -1820,12 +1820,13 @@ func setClientWebProxyAddr(cf *CLIConf, c *client.Config) error {
timeout, cancel := context.WithTimeout(context.Background(), proxyDefaultResolutionTimeout)
defer cancel()

caPool := client.LoopbackPool(parsedAddrs.Host)
proxyAddress, err = pickDefaultAddr(
timeout, cf.InsecureSkipVerify, parsedAddrs.Host, defaultWebProxyPorts, caPool)
timeout, cf.InsecureSkipVerify, parsedAddrs.Host, defaultWebProxyPorts)

// On error, fall back to the legacy behaviour
if err != nil {
log.Debugf("Proxy port resolution failed: %v", err)
log.Debug("Falling back to legacy default")
return c.ParseProxyHost(cf.Proxy)
}
}
Expand Down

0 comments on commit fcb13d5

Please sign in to comment.