Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a .tsh/config file, add support for configuring custom http headers from the config file #10336

Merged
merged 1 commit into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions api/client/contextdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ func NewDirectDialer(keepAlivePeriod, dialTimeout time.Duration) ContextDialer {
func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer {
dialer := newTunnelDialer(ssh, keepAlivePeriod, dialTimeout)
return ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) {
tunnelAddr, err := webclient.GetTunnelAddr(ctx, discoveryAddr, insecure, nil)
tunnelAddr, err := webclient.GetTunnelAddr(
&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -91,7 +92,8 @@ func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dur
// through the SSH reverse tunnel on the proxy.
func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer {
return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
tunnelAddr, err := webclient.GetTunnelAddr(ctx, discoveryAddr, insecure, nil)
tunnelAddr, err := webclient.GetTunnelAddr(
&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
101 changes: 76 additions & 25 deletions api/client/webclient/webclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,51 @@ import (
log "github.com/sirupsen/logrus"
)

// Config specifies information when building requests with the
// webclient.
type Config struct {
// Context is a context for creating webclient requests.
Context context.Context
// ProxyAddr specifies the teleport proxy address for requests.
ProxyAddr string
// Insecure turns off TLS certificate verification when enabled.
Insecure bool
// Pool defines the set of root CAs to use when verifying server
// certificates.
Pool *x509.CertPool
// ConnectorName is the name of the ODIC or SAML connector.
ConnectorName string
// ExtraHeaders is a map of extra HTTP headers to be included in
// requests.
ExtraHeaders map[string]string
}

// CheckAndSetDefaults checks and sets defaults
func (c *Config) CheckAndSetDefaults() error {
message := "webclient config: %s"
if c.Context == nil {
return trace.BadParameter(message, "missing parameter Context")
}
if c.ProxyAddr == "" && os.Getenv(defaults.TunnelPublicAddrEnvar) == "" {
return trace.BadParameter(message, "missing parameter ProxyAddr")
}

return nil
}

// newWebClient creates a new client to the HTTPS web proxy.
func newWebClient(insecure bool, pool *x509.CertPool) *http.Client {
func newWebClient(cfg *Config) (*http.Client, error) {
if err := cfg.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: pool,
InsecureSkipVerify: insecure,
RootCAs: cfg.Pool,
InsecureSkipVerify: cfg.Insecure,
},
},
}
}, nil
}

// doWithFallback attempts to execute an HTTP request using https, and then
Expand All @@ -56,9 +91,13 @@ func newWebClient(insecure bool, pool *x509.CertPool) *http.Client {
// * The target host must resolve to the loopback address.
// If these conditions are not met, then the plain-HTTP fallback is not allowed,
// and a the HTTPS failure will be considered final.
func doWithFallback(clt *http.Client, allowPlainHTTP bool, req *http.Request) (*http.Response, error) {
func doWithFallback(clt *http.Client, allowPlainHTTP bool, extraHeaders map[string]string, req *http.Request) (*http.Response, error) {
// first try https and see how that goes
req.URL.Scheme = "https"
for k, v := range extraHeaders {
req.Header.Add(k, v)
}

log.Debugf("Attempting %s %s%s", req.Method, req.URL.Host, req.URL.Path)
resp, err := clt.Do(req)

Expand Down Expand Up @@ -88,18 +127,21 @@ func doWithFallback(clt *http.Client, allowPlainHTTP bool, req *http.Request) (*

// Find fetches discovery data by connecting to the given web proxy address.
// It is designed to fetch proxy public addresses without any inefficiencies.
func Find(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (*PingResponse, error) {
clt := newWebClient(insecure, pool)
func Find(cfg *Config) (*PingResponse, error) {
clt, err := newWebClient(cfg)
if err != nil {
return nil, trace.Wrap(err)
}
defer clt.CloseIdleConnections()

endpoint := fmt.Sprintf("https://%s/webapi/find", proxyAddr)
endpoint := fmt.Sprintf("https://%s/webapi/find", cfg.ProxyAddr)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil)
if err != nil {
return nil, trace.Wrap(err)
}

resp, err := doWithFallback(clt, insecure, req)
resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -118,21 +160,24 @@ func Find(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertP
// errors before being asked for passwords. The second is to return the form
// of authentication that the server supports. This also leads to better user
// experience: users only get prompted for the type of authentication the server supports.
func Ping(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool, connectorName string) (*PingResponse, error) {
clt := newWebClient(insecure, pool)
func Ping(cfg *Config) (*PingResponse, error) {
clt, err := newWebClient(cfg)
if err != nil {
return nil, trace.Wrap(err)
}
defer clt.CloseIdleConnections()

endpoint := fmt.Sprintf("https://%s/webapi/ping", proxyAddr)
if connectorName != "" {
endpoint = fmt.Sprintf("%s/%s", endpoint, connectorName)
endpoint := fmt.Sprintf("https://%s/webapi/ping", cfg.ProxyAddr)
if cfg.ConnectorName != "" {
endpoint = fmt.Sprintf("%s/%s", endpoint, cfg.ConnectorName)
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil)
if err != nil {
return nil, trace.Wrap(err)
}

resp, err := doWithFallback(clt, insecure, req)
resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -147,32 +192,38 @@ func Ping(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertP
}

// GetTunnelAddr returns the tunnel address either set in an environment variable or retrieved from the web proxy.
func GetTunnelAddr(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (string, error) {
func GetTunnelAddr(cfg *Config) (string, error) {
if err := cfg.CheckAndSetDefaults(); err != nil {
return "", trace.Wrap(err)
}
// If TELEPORT_TUNNEL_PUBLIC_ADDR is set, nothing else has to be done, return it.
if tunnelAddr := os.Getenv(defaults.TunnelPublicAddrEnvar); tunnelAddr != "" {
return extractHostPort(tunnelAddr)
}

// Ping web proxy to retrieve tunnel proxy address.
pr, err := Find(ctx, proxyAddr, insecure, nil)
pr, err := Find(cfg)
if err != nil {
return "", trace.Wrap(err)
}
return tunnelAddr(proxyAddr, pr.Proxy)
return tunnelAddr(cfg.ProxyAddr, pr.Proxy)
}

func GetMOTD(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (*MotD, error) {
clt := newWebClient(insecure, pool)
func GetMOTD(cfg *Config) (*MotD, error) {
clt, err := newWebClient(cfg)
if err != nil {
return nil, trace.Wrap(err)
}
defer clt.CloseIdleConnections()

endpoint := fmt.Sprintf("https://%s/webapi/motd", proxyAddr)
endpoint := fmt.Sprintf("https://%s/webapi/motd", cfg.ProxyAddr)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil)
if err != nil {
return nil, trace.Wrap(err)
}

resp, err := clt.Do(req)
resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
7 changes: 4 additions & 3 deletions api/client/webclient/webclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ func TestPlainHttpFallback(t *testing.T) {
desc: "Ping",
handler: newPingHandler("/webapi/ping"),
actionUnderTest: func(addr string, insecure bool) error {
_, err := Ping(context.Background(), addr, insecure, nil /*pool*/, "")
_, err := Ping(
&Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure})
return err
},
}, {
desc: "Find",
handler: newPingHandler("/webapi/find"),
actionUnderTest: func(addr string, insecure bool) error {
_, err := Find(context.Background(), addr, insecure, nil /*pool*/)
_, err := Find(&Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure})
return err
},
},
Expand Down Expand Up @@ -104,7 +105,7 @@ func TestPlainHttpFallback(t *testing.T) {

func TestGetTunnelAddr(t *testing.T) {
t.Setenv(defaults.TunnelPublicAddrEnvar, "tunnel.example.com:4024")
tunnelAddr, err := GetTunnelAddr(context.Background(), "", true, nil)
tunnelAddr, err := GetTunnelAddr(&Config{Context: context.Background(), ProxyAddr: "", Insecure: false})
require.NoError(t, err)
require.Equal(t, "tunnel.example.com:4024", tunnelAddr)
}
Expand Down
27 changes: 17 additions & 10 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,9 @@ type Config struct {
// DisplayParticipantRequirements is set if debug information about participants requirements
// should be printed in moderated sessions.
DisplayParticipantRequirements bool

// ExtraProxyHeaders is a collection of http headers to be included in requests to the WebProxy.
ExtraProxyHeaders map[string]string
}

// CachePolicy defines cache policy for local clients
Expand Down Expand Up @@ -2599,12 +2602,13 @@ func (tc *TeleportClient) Ping(ctx context.Context) (*webclient.PingResponse, er
if tc.lastPing != nil {
return tc.lastPing, nil
}
pr, err := webclient.Ping(
ctx,
tc.WebProxyAddr,
tc.InsecureSkipVerify,
loopbackPool(tc.WebProxyAddr),
tc.AuthConnector)
pr, err := webclient.Ping(&webclient.Config{
Context: ctx,
ProxyAddr: tc.WebProxyAddr,
Insecure: tc.InsecureSkipVerify,
Pool: loopbackPool(tc.WebProxyAddr),
ConnectorName: tc.AuthConnector,
ExtraHeaders: tc.ExtraProxyHeaders})
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -2636,10 +2640,13 @@ func (tc *TeleportClient) Ping(ctx context.Context) (*webclient.PingResponse, er
// confirmation from the user.
func (tc *TeleportClient) ShowMOTD(ctx context.Context) error {
motd, err := webclient.GetMOTD(
ctx,
tc.WebProxyAddr,
tc.InsecureSkipVerify,
loopbackPool(tc.WebProxyAddr))
&webclient.Config{
Context: ctx,
ProxyAddr: tc.WebProxyAddr,
Insecure: tc.InsecureSkipVerify,
Pool: loopbackPool(tc.WebProxyAddr),
ExtraHeaders: tc.ExtraProxyHeaders})

if err != nil {
return trace.Wrap(err)
}
Expand Down
24 changes: 23 additions & 1 deletion lib/client/keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ const (
// keyFilePerms is the default permissions applied to key files (.cert, .key, pub)
// under ~/.tsh
keyFilePerms os.FileMode = 0600

// tshConfigFileName is the name of the directory containing the
// tsh config file.
tshConfigFileName = "config"
)

// LocalKeyStore interface allows for different storage backends for tsh to
Expand Down Expand Up @@ -222,9 +226,27 @@ func (fs *FSLocalKeyStore) DeleteUserCerts(idx KeyIndex, opts ...CertOption) err

// DeleteKeys removes all session keys.
func (fs *FSLocalKeyStore) DeleteKeys() error {
if err := os.RemoveAll(fs.KeyDir); err != nil {

files, err := os.ReadDir(fs.KeyDir)
if err != nil {
return trace.ConvertSystemError(err)
}
for _, file := range files {
if file.IsDir() && file.Name() == tshConfigFileName {
continue
}
if file.IsDir() {
err := os.RemoveAll(filepath.Join(fs.KeyDir, file.Name()))
if err != nil {
return trace.ConvertSystemError(err)
}
continue
}
err := os.Remove(filepath.Join(fs.KeyDir, file.Name()))
if err != nil {
return trace.ConvertSystemError(err)
}
}
return nil
}

Expand Down
13 changes: 13 additions & 0 deletions lib/client/keystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,19 @@ func TestAddKey_withoutSSHCert(t *testing.T) {
require.Len(t, keyCopy.DBTLSCerts, 1)
}

func TestConfigDirNotDeleted(t *testing.T) {
s, cleanup := newTest(t)
t.Cleanup(cleanup)
idx := KeyIndex{"host.a", "bob", "root"}
s.store.AddKey(s.makeSignedKey(t, idx, false))
configPath := filepath.Join(s.storeDir, "config")
require.NoError(t, os.Mkdir(configPath, 0700))
require.NoError(t, s.store.DeleteKeys())
require.DirExists(t, configPath)

require.NoDirExists(t, filepath.Join(s.storeDir, "keys"))
}

type keyStoreTest struct {
storeDir string
store *FSLocalKeyStore
Expand Down
4 changes: 3 additions & 1 deletion lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ func (a *Agent) getHostCheckers() ([]ssh.PublicKey, error) {
// If this is Web Service port check if proxy support ALPN SNI Listener.
func (a *Agent) getReverseTunnelDetails() *reverseTunnelDetails {
pd := reverseTunnelDetails{TLSRoutingEnabled: false}
resp, err := webclient.Find(a.ctx, a.Addr.Addr, lib.IsInsecureDevMode(), nil)
resp, err := webclient.Find(
&webclient.Config{Context: a.ctx, ProxyAddr: a.Addr.Addr, Insecure: lib.IsInsecureDevMode()})

if err != nil {
// If TLS Routing is disabled the address is the proxy reverse tunnel
// address the ping call will always fail.
Expand Down
4 changes: 3 additions & 1 deletion lib/reversetunnel/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ func WebClientResolver(ctx context.Context, addrs []utils.NetAddr, insecureTLS b
for _, addr := range addrs {
// In insecure mode, any certificate is accepted. In secure mode the hosts
// CAs are used to validate the certificate on the proxy.
tunnelAddr, err := webclient.GetTunnelAddr(ctx, addr.String(), insecureTLS, nil)
tunnelAddr, err := webclient.GetTunnelAddr(
&webclient.Config{Context: ctx, ProxyAddr: addr.String(), Insecure: insecureTLS})

if err != nil {
errs = append(errs, err)
continue
Expand Down
3 changes: 2 additions & 1 deletion lib/reversetunnel/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Co
}

// Check if t.ProxyAddr is ProxyWebPort and remote Proxy supports TLS ALPNSNIListener.
resp, err := webclient.Find(ctx, addr.Addr, t.InsecureSkipTLSVerify, nil)
resp, err := webclient.Find(
&webclient.Config{Context: ctx, ProxyAddr: addr.Addr, Insecure: t.InsecureSkipTLSVerify})
if err != nil {
// If TLS Routing is disabled the address is the proxy reverse tunnel
// address thus the ping call will always fail.
Expand Down
1 change: 0 additions & 1 deletion lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,6 @@ func defaultAuthenticationSettings(ctx context.Context, authClient auth.ClientI)

func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
var err error

defaultSettings, err := defaultAuthenticationSettings(r.Context(), h.cfg.ProxyClient)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
Loading