From 643afa196a768741ad40dcb439140b09f6421c76 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Mon, 2 May 2022 15:59:56 -0400 Subject: [PATCH] Stop loading the entire node set into memory per tsh ssh connection (#12014) * Prevent proxy from loading entire node set into memory more than once When establishing a new session to a node, the proxy would load the entire node set into memory in an attempt to find the matching host. For smaller clusters this may not be that problematic. But on larger clusters, loading >40k nodes into memory from the cache can be quite expensive. This problem is compounded by the fact that it happened **per** session, which could potentially cause the proxy to consume all available memory and be OOM killed. A new `NodeWatcher` is introduced which will maintain an in memory list of all nodes per process. The watcher leverages the existing resource watcher system and stores all nodes as types.Server, to eliminate the cost incurred by unmarshalling the nodes from the cache. The `NodeWatcher` provides a way to retrieve a filtered list of nodes in order to reduce the number of copies made to only the matches. 
(cherry picked from commit fa12352214ea382633277e92d3217853e715e2ac) --- lib/reversetunnel/api.go | 3 + lib/reversetunnel/localsite.go | 30 +- lib/reversetunnel/peer.go | 12 + lib/reversetunnel/remotesite.go | 13 +- lib/reversetunnel/srv.go | 34 +- lib/service/service.go | 13 + lib/services/presence.go | 13 +- lib/services/watcher.go | 156 +++++ lib/services/watcher_test.go | 78 ++- lib/srv/regular/proxy.go | 55 +- lib/srv/regular/proxy_test.go | 21 +- lib/srv/regular/sshserver.go | 11 + lib/srv/regular/sshserver_test.go | 18 + lib/web/apiserver_test.go | 1040 +++++++++++++++-------------- tool/tsh/proxy_test.go | 20 +- tool/tsh/tsh.go | 26 +- 16 files changed, 983 insertions(+), 560 deletions(-) diff --git a/lib/reversetunnel/api.go b/lib/reversetunnel/api.go index 8cbe8b31c6b2c..85b50fac56b7e 100644 --- a/lib/reversetunnel/api.go +++ b/lib/reversetunnel/api.go @@ -25,6 +25,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" ) @@ -94,6 +95,8 @@ type RemoteSite interface { // CachingAccessPoint returns access point that is lightweight // but is resilient to auth server crashes CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) + // NodeWatcher returns the node watcher that maintains the node set for the site + NodeWatcher() (*services.NodeWatcher, error) // GetTunnelsCount returns the amount of active inbound tunnels // from the remote cluster GetTunnelsCount() int diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 36d583384330a..f184cb86d4a9c 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -22,8 +22,6 @@ import ( "sync" "time" - "golang.org/x/crypto/ssh" - "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" @@ -34,11 +32,12 @@ import ( 
"github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/proxy" - "github.com/prometheus/client_golang/prometheus" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" ) func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSite, error) { @@ -129,6 +128,11 @@ func (s *localSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) { return s.accessPoint, nil } +// NodeWatcher returns a services.NodeWatcher for this cluster. +func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) { + return s.srv.NodeWatcher, nil +} + // GetClient returns a client to the full Auth Server API. func (s *localSite) GetClient() (auth.ClientI, error) { return s.client, nil @@ -522,14 +526,7 @@ func (s *localSite) periodicFunctions() { // sshTunnelStats reports SSH tunnel statistics for the cluster. func (s *localSite) sshTunnelStats() error { - servers, err := s.accessPoint.GetNodes(s.srv.ctx, apidefaults.Namespace) - if err != nil { - return trace.Wrap(err) - } - - var missing []string - - for _, server := range servers { + missing := s.srv.NodeWatcher.GetNodes(func(server services.Node) bool { // Skip over any servers that that have a TTL larger than announce TTL (10 // minutes) and are non-IoT SSH servers (they won't have tunnels). // @@ -538,10 +535,10 @@ func (s *localSite) sshTunnelStats() error { // their TTL value. ttl := s.clock.Now().Add(-1 * apidefaults.ServerAnnounceTTL) if server.Expiry().Before(ttl) { - continue + return false } if !server.GetUseTunnel() { - continue + return false } // Check if the tunnel actually exists. 
@@ -549,12 +546,9 @@ func (s *localSite) sshTunnelStats() error { ServerID: fmt.Sprintf("%v.%v", server.GetName(), s.domainName), ConnType: types.NodeTunnel, }) - if err == nil { - continue - } - missing = append(missing, server.GetName()) - } + return err != nil + }) // Update Prometheus metrics and also log if any tunnels are missing. missingSSHTunnels.Set(float64(len(missing))) diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 9fab4c78201f9..1f65b2404e5a0 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -87,6 +87,14 @@ func (p *clusterPeers) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) return peer.CachingAccessPoint() } +func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) { + peer, err := p.pickPeer() + if err != nil { + return nil, trace.Wrap(err) + } + return peer.NodeWatcher() +} + func (p *clusterPeers) GetClient() (auth.ClientI, error) { peer, err := p.pickPeer() if err != nil { @@ -191,6 +199,10 @@ func (s *clusterPeer) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s) } +func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) { + return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s) +} + func (s *clusterPeer) GetClient() (auth.ClientI, error) { return nil, trace.ConnectionProblem(nil, "unable to fetch client, this proxy %v has not been discovered yet, try again later", s) } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index d9f35df5a10b8..0be63a615fe80 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -34,13 +34,14 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/utils" + 
"github.com/gravitational/trace" "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) -// remoteSite is a remote site that established the inbound connecton to +// remoteSite is a remote site that established the inbound connection to // the local reverse tunnel server, and now it can provide access to the // cluster behind it. type remoteSite struct { @@ -77,6 +78,9 @@ type remoteSite struct { // the remote cluster this site belongs to. remoteAccessPoint auth.RemoteProxyAccessPoint + // nodeWatcher provides access the node set for the remote site + nodeWatcher *services.NodeWatcher + // remoteCA is the last remote certificate authority recorded by the client. // It is used to detect CA rotation status changes. If the rotation // state has been changed, the tunnel will reconnect to re-create the client @@ -138,6 +142,11 @@ func (s *remoteSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) { return s.remoteAccessPoint, nil } +// NodeWatcher returns the services.NodeWatcher for the remote cluster. +func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) { + return s.nodeWatcher, nil +} + func (s *remoteSite) GetClient() (auth.ClientI, error) { return s.remoteClient, nil } @@ -379,7 +388,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch } else { s.WithFields(log.Fields{"nodeID": conn.nodeID}).Debugf("Ping <- %v", conn.conn.RemoteAddr()) } - tm := time.Now().UTC() + tm := s.clock.Now().UTC() conn.setLastHeartbeat(tm) go s.registerHeartbeat(tm) // Note that time.After is re-created everytime a request is processed. 
diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index bda13b97379c9..4d2ac27c28853 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -38,6 +38,7 @@ import ( "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" @@ -201,6 +202,9 @@ type Config struct { // LockWatcher is a lock watcher. LockWatcher *services.LockWatcher + + // NodeWatcher is a node watcher. + NodeWatcher *services.NodeWatcher } // CheckAndSetDefaults checks parameters and sets default values @@ -252,6 +256,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.LockWatcher == nil { return trace.BadParameter("missing parameter LockWatcher") } + if cfg.NodeWatcher == nil { + return trace.BadParameter("missing parameter NodeWatcher") + } return nil } @@ -891,7 +898,7 @@ func (s *server) upsertRemoteCluster(conn net.Conn, sshConn *ssh.ServerConn) (*r // treat first connection as a registered heartbeat, // otherwise the connection information will appear after initial // heartbeat delay - go site.registerHeartbeat(time.Now()) + go site.registerHeartbeat(s.Clock.Now()) return site, remoteConn, nil } @@ -1024,7 +1031,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, types.TunnelConnectionSpecV2{ ClusterName: domainName, ProxyName: srv.ID, - LastHeartbeat: time.Now().UTC(), + LastHeartbeat: srv.Clock.Now().UTC(), }, ) if err != nil { @@ -1056,27 +1063,42 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, clt, _, err := remoteSite.getRemoteClient() if err != nil { + cancel() return nil, trace.Wrap(err) } remoteSite.remoteClient = clt remoteVersion, err := getRemoteAuthVersion(closeContext, sconn) if err != nil { + cancel() return nil, trace.Wrap(err) } accessPoint, err := 
createRemoteAccessPoint(srv, clt, remoteVersion, domainName) if err != nil { + cancel() return nil, trace.Wrap(err) } remoteSite.remoteAccessPoint = accessPoint - + nodeWatcher, err := services.NewNodeWatcher(closeContext, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: srv.Component, + Client: accessPoint, + Log: srv.Log, + }, + }) + if err != nil { + cancel() + return nil, trace.Wrap(err) + } + remoteSite.nodeWatcher = nodeWatcher // instantiate a cache of host certificates for the forwarding server. the // certificate cache is created in each site (instead of creating it in // reversetunnel.server and passing it along) so that the host certificate // is signed by the correct certificate authority. certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient) if err != nil { + cancel() return nil, trace.Wrap(err) } remoteSite.certificateCache = certificateCache @@ -1089,7 +1111,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - return nil, err + cancel() + return nil, trace.Wrap(err) } go remoteSite.updateCertAuthorities(caRetry) @@ -1102,7 +1125,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - return nil, err + cancel() + return nil, trace.Wrap(err) } go remoteSite.updateLocks(lockRetry) diff --git a/lib/service/service.go b/lib/service/service.go index 8ae942b7e7ff0..bf893397b8d7d 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2856,6 +2856,17 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + nodeWatcher, err := services.NewNodeWatcher(process.ExitContext(), services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: process.log.WithField(trace.Component, teleport.ComponentProxy), + Client: 
conn.Client, + }, + }) + if err != nil { + return trace.Wrap(err) + } + serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites) if err != nil { return trace.Wrap(err) @@ -2895,6 +2906,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Emitter: streamEmitter, Log: process.log, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) if err != nil { return trace.Wrap(err) @@ -3024,6 +3036,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { regular.SetOnHeartbeat(process.onHeartbeat(teleport.ComponentProxy)), regular.SetEmitter(streamEmitter), regular.SetLockWatcher(lockWatcher), + regular.SetNodeWatcher(nodeWatcher), ) if err != nil { return trace.Wrap(err) diff --git a/lib/services/presence.go b/lib/services/presence.go index fc02fc33d931b..7f380af913db8 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -29,6 +29,12 @@ type ProxyGetter interface { GetProxies() ([]types.Server, error) } +// NodesGetter is a service that gets nodes. +type NodesGetter interface { + // GetNodes returns a list of registered servers. + GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error) +} + // Presence records and reports the presence of all components // of the cluster - Nodes, Proxies and SSH nodes type Presence interface { @@ -43,13 +49,12 @@ type Presence interface { // GetNode returns a node by name and namespace. GetNode(ctx context.Context, namespace, name string) (types.Server, error) - - // GetNodes returns a list of registered servers. - GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error) - // ListNodes returns a paginated list of registered servers. ListNodes(ctx context.Context, req proto.ListNodesRequest) (nodes []types.Server, nextKey string, err error) + // NodesGetter gets nodes + NodesGetter + // DeleteAllNodes deletes all nodes in a namespace. 
DeleteAllNodes(ctx context.Context, namespace string) error diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 16cb6dbd88801..f035c8a5ab6e0 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -22,6 +22,7 @@ import ( "time" "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" @@ -1087,3 +1088,158 @@ func casToSlice(host map[string]types.CertAuthority, user map[string]types.CertA } return slice } + +// NodeWatcherConfig is a NodeWatcher configuration. +type NodeWatcherConfig struct { + ResourceWatcherConfig + // NodesGetter is used to directly fetch the list of active nodes. + NodesGetter +} + +// CheckAndSetDefaults checks parameters and sets default values. +func (cfg *NodeWatcherConfig) CheckAndSetDefaults() error { + if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if cfg.NodesGetter == nil { + getter, ok := cfg.Client.(NodesGetter) + if !ok { + return trace.BadParameter("missing parameter NodesGetter and Client not usable as NodesGetter") + } + cfg.NodesGetter = getter + } + return nil +} + +// NewNodeWatcher returns a new instance of NodeWatcher. +func NewNodeWatcher(ctx context.Context, cfg NodeWatcherConfig) (*NodeWatcher, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + collector := &nodeCollector{ + NodeWatcherConfig: cfg, + current: map[string]types.Server{}, + } + watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) + if err != nil { + return nil, trace.Wrap(err) + } + return &NodeWatcher{watcher, collector}, nil +} + +// NodeWatcher is built on top of resourceWatcher to monitor additions +// and deletions to the set of nodes. 
+type NodeWatcher struct { + *resourceWatcher + *nodeCollector +} + +// nodeCollector accompanies resourceWatcher when monitoring nodes. +type nodeCollector struct { + NodeWatcherConfig + // current holds a map of the currently known nodes (keyed by server name, + // RWMutex protected). + current map[string]types.Server + rw sync.RWMutex +} + +// Node is a readonly subset of the types.Server interface which +// users may filter by in GetNodes. +type Node interface { + // ResourceWithLabels provides common resource headers + types.ResourceWithLabels + // GetTeleportVersion returns the teleport version the server is running on + GetTeleportVersion() string + // GetAddr return server address + GetAddr() string + // GetHostname returns server hostname + GetHostname() string + // GetNamespace returns server namespace + GetNamespace() string + // GetLabels returns server's static label key pairs + GetLabels() map[string]string + // GetCmdLabels gets command labels + GetCmdLabels() map[string]types.CommandLabel + // GetPublicAddr is an optional field that returns the public address this cluster can be reached at. + GetPublicAddr() string + // GetRotation gets the state of certificate authority rotation. + GetRotation() types.Rotation + // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. + GetUseTunnel() bool +} + +// GetNodes allows callers to retrieve a subset of nodes that match the filter provided. The +// returned servers are a copy and can be safely modified. It is intentionally hard to retrieve +// the full set of nodes to reduce the number of copies needed since the number of nodes can get +// quite large and doing so can be expensive. 
+func (n *nodeCollector) GetNodes(fn func(n Node) bool) []types.Server { + n.rw.RLock() + defer n.rw.RUnlock() + + var matched []types.Server + for _, server := range n.current { + if fn(server) { + matched = append(matched, server.DeepCopy()) + } + } + + return matched +} + +func (n *nodeCollector) NodeCount() int { + n.rw.RLock() + defer n.rw.RUnlock() + return len(n.current) +} + +// resourceKind specifies the resource kind to watch. +func (n *nodeCollector) resourceKind() string { + return types.KindNode +} + +// getResourcesAndUpdateCurrent is called when the resources should be +// (re-)fetched directly. +func (n *nodeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { + nodes, err := n.NodesGetter.GetNodes(ctx, apidefaults.Namespace) + if err != nil { + return trace.Wrap(err) + } + if len(nodes) == 0 { + return nil + } + newCurrent := make(map[string]types.Server, len(nodes)) + for _, node := range nodes { + newCurrent[node.GetName()] = node + } + n.rw.Lock() + defer n.rw.Unlock() + n.current = newCurrent + return nil +} + +// processEventAndUpdateCurrent is called when a watcher event is received. 
+func (n *nodeCollector) processEventAndUpdateCurrent(ctx context.Context, event types.Event) { + if event.Resource == nil || event.Resource.GetKind() != types.KindNode { + n.Log.Warningf("Unexpected event: %v.", event) + return + } + + n.rw.Lock() + defer n.rw.Unlock() + + switch event.Type { + case types.OpDelete: + delete(n.current, event.Resource.GetName()) + case types.OpPut: + server, ok := event.Resource.(types.Server) + if !ok { + n.Log.Warningf("Unexpected type %T.", event.Resource) + return + } + n.current[server.GetName()] = server + default: + n.Log.Warningf("Skipping unsupported event type %s.", event.Type) + } +} + +func (n *nodeCollector) notifyStale() {} diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 4d7b3cfbcacc9..6521e65aa30d9 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -20,13 +20,19 @@ import ( "context" "crypto/x509/pkix" "errors" + "fmt" "sync" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend/lite" @@ -34,9 +40,6 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/tlsca" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" ) var _ types.Events = (*errorWatcher)(nil) @@ -853,3 +856,72 @@ func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) type require.NoError(t, err) return ca } + +func TestNodeWatcher(t *testing.T) { + t.Parallel() + ctx := context.Background() + + bk, err := 
lite.NewWithConfig(ctx, lite.Config{ + Path: t.TempDir(), + PollStreamPeriod: 200 * time.Millisecond, + }) + require.NoError(t, err) + + type client struct { + services.Presence + types.Events + } + + presence := local.NewPresenceService(bk) + w, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: &client{ + Presence: presence, + Events: local.NewEventsService(bk), + }, + }, + }) + require.NoError(t, err) + t.Cleanup(w.Close) + + // Add some node servers. + nodes := make([]types.Server, 0, 5) + for i := 0; i < 5; i++ { + node := newNodeServer(t, fmt.Sprintf("node%d", i), "127.0.0.1:2023", i%2 == 0) + _, err = presence.UpsertNode(ctx, node) + require.NoError(t, err) + nodes = append(nodes, node) + } + + require.Eventually(t, func() bool { + filtered := w.GetNodes(func(n services.Node) bool { + return true + }) + return len(filtered) == len(nodes) + }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") + + require.Len(t, w.GetNodes(func(n services.Node) bool { return n.GetUseTunnel() }), 3) + + require.NoError(t, presence.DeleteNode(ctx, apidefaults.Namespace, nodes[0].GetName())) + + require.Eventually(t, func() bool { + filtered := w.GetNodes(func(n services.Node) bool { + return true + }) + return len(filtered) == len(nodes)-1 + }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") + + require.Empty(t, w.GetNodes(func(n services.Node) bool { return n.GetName() == nodes[0].GetName() })) + +} + +func newNodeServer(t *testing.T, name, addr string, tunnel bool) types.Server { + s, err := types.NewServer(name, types.KindNode, types.ServerSpecV2{ + Addr: addr, + PublicAddr: addr, + UseTunnel: tunnel, + }) + require.NoError(t, err) + return s +} diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index 1b854fe7420fa..e79e67893e505 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ 
-34,6 +34,7 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -322,18 +323,15 @@ func (t *proxySubsys) proxyToHost( // network resolution (by IP or DNS) // var ( - strategy types.RoutingStrategy - servers []types.Server - err error + strategy types.RoutingStrategy + nodeWatcher *services.NodeWatcher + err error ) localCluster, _ := t.srv.proxyAccessPoint.GetClusterName() // going to "local" CA? lets use the caching 'auth service' directly and avoid // hitting the reverse tunnel link (it can be offline if the CA is down) if site.GetName() == localCluster.GetName() { - servers, err = t.srv.proxyAccessPoint.GetNodes(ctx.CancelContext(), t.namespace) - if err != nil { - t.log.Warn(err) - } + nodeWatcher = t.srv.nodeWatcher cfg, err := t.srv.authService.GetClusterNetworkingConfig(ctx.CancelContext()) if err != nil { @@ -347,9 +345,11 @@ func (t *proxySubsys) proxyToHost( if err != nil { t.log.Warn(err) } else { - servers, err = siteClient.GetNodes(ctx.CancelContext(), t.namespace) + watcher, err := site.NodeWatcher() if err != nil { t.log.Warn(err) + } else { + nodeWatcher = watcher } cfg, err := siteClient.GetClusterNetworkingConfig(ctx.CancelContext()) @@ -366,7 +366,7 @@ func (t *proxySubsys) proxyToHost( t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v, strategy=%s", t.host, t.port, t.SpecifiedPort(), strategy) // determine which server to connect to - server, err := t.getMatchingServer(servers, strategy) + server, err := t.getMatchingServer(nodeWatcher, strategy) if err != nil { return trace.Wrap(err) } @@ -453,12 +453,22 @@ func (t *proxySubsys) proxyToHost( return nil } +// NodesGetter is a function that retrieves a subset of nodes matching 
+// the filter criteria. +type NodesGetter interface { + GetNodes(fn func(n services.Node) bool) []types.Server +} + // getMatchingServer determines the server to connect to from the provided servers. Duplicate entries are treated // differently based on strategy. Legacy behavior of returning an ambiguous error occurs if the strategy // is types.RoutingStrategy_UNAMBIGUOUS_MATCH. When the strategy is types.RoutingStrategy_MOST_RECENT then // the server that has heartbeated most recently will be returned instead of an error. If no matches are found then // both the types.Server and error returned will be nil. -func (t *proxySubsys) getMatchingServer(servers []types.Server, strategy types.RoutingStrategy) (types.Server, error) { +func (t *proxySubsys) getMatchingServer(watcher NodesGetter, strategy types.RoutingStrategy) (types.Server, error) { + if watcher == nil { + return nil, trace.NotFound("unable to retrieve nodes matching host %s", t.host) + } + // check if hostname is a valid uuid or EC2 node ID. If it is, we will // preferentially match by node ID over node hostname. _, err := uuid.Parse(t.host) @@ -466,35 +476,36 @@ func (t *proxySubsys) getMatchingServer(servers []types.Server, strategy types.R ips, _ := net.LookupHost(t.host) + var unambiguousIDMatch bool // enumerate and try to find a server with self-registered with a matching name/IP: - var matches []types.Server - for _, server := range servers { + matches := watcher.GetNodes(func(server services.Node) bool { + if unambiguousIDMatch { + return false + } + // If the host parameter is a UUID or EC2 node ID, and it matches the // Node ID, treat this as an unambiguous match. if hostIsUniqueID && server.GetName() == t.host { - matches = []types.Server{server} - break + unambiguousIDMatch = true + return true } // If the server has connected over a reverse tunnel, match only on hostname. 
if server.GetUseTunnel() { - if t.host == server.GetHostname() { - matches = append(matches, server) - } - continue + return t.host == server.GetHostname() } ip, port, err := net.SplitHostPort(server.GetAddr()) if err != nil { t.log.Errorf("Failed to parse address %q: %v.", server.GetAddr(), err) - continue + return false } if t.host == ip || t.host == server.GetHostname() || apiutils.SliceContainsStr(ips, ip) { if !t.SpecifiedPort() || t.port == port { - matches = append(matches, server) - continue + return true } } - } + return false + }) var server types.Server switch { diff --git a/lib/srv/regular/proxy_test.go b/lib/srv/regular/proxy_test.go index fcfc406f5e58b..d6259bbb209cf 100644 --- a/lib/srv/regular/proxy_test.go +++ b/lib/srv/regular/proxy_test.go @@ -21,10 +21,12 @@ import ( "time" "github.com/google/uuid" + "github.com/stretchr/testify/require" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" - "github.com/stretchr/testify/require" ) func TestParseProxyRequest(t *testing.T) { @@ -129,6 +131,21 @@ func TestParseBadRequests(t *testing.T) { } } +type nodeGetter struct { + servers []types.Server +} + +func (n nodeGetter) GetNodes(fn func(n services.Node) bool) []types.Server { + var servers []types.Server + for _, s := range n.servers { + if fn(s) { + servers = append(servers, s) + } + } + + return servers +} + func TestProxySubsys_getMatchingServer(t *testing.T) { t.Parallel() @@ -312,7 +329,7 @@ func TestProxySubsys_getMatchingServer(t *testing.T) { srv: &Server{}, } - server, err := subsystem.getMatchingServer(tt.servers, tt.strategy) + server, err := subsystem.getMatchingServer(nodeGetter{tt.servers}, tt.strategy) tt.expectError(t, err) if tt.expectServer != nil { require.Equal(t, tt.expectServer(tt.servers), server) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 
4a7cc071f864d..84e6026cba384 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -188,6 +188,9 @@ type Server struct { // lockWatcher is the server's lock watcher. lockWatcher *services.LockWatcher + + // nodeWatcher is the server's node watcher. + nodeWatcher *services.NodeWatcher } // GetClock returns server clock implementation @@ -555,6 +558,14 @@ func SetLockWatcher(lockWatcher *services.LockWatcher) ServerOption { } } +// SetNodeWatcher sets the server's node watcher. +func SetNodeWatcher(nodeWatcher *services.NodeWatcher) ServerOption { + return func(s *Server) error { + s.nodeWatcher = nodeWatcher + return nil + } +} + // SetX11ForwardingConfig sets the server's X11 forwarding configuration func SetX11ForwardingConfig(xc *x11.ServerConfig) ServerOption { return func(s *Server) error { diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index eafb2ccdf3e60..e32e3f574cfd4 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1125,6 +1125,7 @@ func TestProxyRoundRobin(t *testing.T) { listener, reverseTunnelAddress := mustListen(t) defer listener.Close() lockWatcher := newLockWatcher(ctx, t, proxyClient) + nodeWatcher := newNodeWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClusterName: f.testSrv.ClusterName(), @@ -1141,6 +1142,7 @@ func TestProxyRoundRobin(t *testing.T) { Emitter: proxyClient, Log: logger, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) require.NoError(t, err) logger.WithField("tun-addr", reverseTunnelAddress.String()).Info("Created reverse tunnel server.") @@ -1166,6 +1168,7 @@ func TestProxyRoundRobin(t *testing.T) { SetRestrictedSessionManager(&restricted.NOP{}), SetClock(f.clock), SetLockWatcher(lockWatcher), + SetNodeWatcher(nodeWatcher), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1248,6 +1251,7 @@ func TestProxyDirectAccess(t *testing.T) { logger := 
logrus.WithField("test", "TestProxyDirectAccess") proxyClient, _ := newProxyClient(t, f.testSrv) lockWatcher := newLockWatcher(ctx, t, proxyClient) + nodeWatcher := newNodeWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClientTLS: proxyClient.TLSConfig(), @@ -1264,6 +1268,7 @@ func TestProxyDirectAccess(t *testing.T) { Emitter: proxyClient, Log: logger, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) require.NoError(t, err) @@ -1290,6 +1295,7 @@ func TestProxyDirectAccess(t *testing.T) { SetRestrictedSessionManager(&restricted.NOP{}), SetClock(f.clock), SetLockWatcher(lockWatcher), + SetNodeWatcher(nodeWatcher), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1945,6 +1951,18 @@ func newLockWatcher(ctx context.Context, t *testing.T, client types.Events) *ser return lockWatcher } +func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *services.NodeWatcher { + nodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: client, + }, + }) + require.NoError(t, err) + t.Cleanup(nodeWatcher.Close) + return nodeWatcher +} + // maxPipeSize is one larger than the maximum pipe size for most operating // systems which appears to be 65536 bytes. 
// diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 7f70a12acdb7f..ec1411dacb591 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -29,7 +29,6 @@ import ( "fmt" "image" "io" - "io/ioutil" "net" "net/http" "net/http/cookiejar" @@ -44,8 +43,21 @@ import ( "testing" "time" + "github.com/beevik/etree" + "github.com/gogo/protobuf/proto" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/gravitational/roundtrip" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + lemma_secret "github.com/mailgun/lemma/secret" + "github.com/pquerna/otp/totp" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "golang.org/x/text/encoding/unicode" + kyaml "k8s.io/apimachinery/pkg/util/yaml" "github.com/gravitational/teleport" apiProto "github.com/gravitational/teleport/api/client/proto" @@ -82,30 +94,10 @@ import ( "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/ui" - - "github.com/gravitational/roundtrip" - "github.com/gravitational/trace" - - "github.com/beevik/etree" - "github.com/gogo/protobuf/proto" - "github.com/google/go-cmp/cmp" - "github.com/google/uuid" - "github.com/gorilla/websocket" - "github.com/jonboulle/clockwork" - lemma_secret "github.com/mailgun/lemma/secret" - "github.com/pquerna/otp/totp" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - . "gopkg.in/check.v1" - kyaml "k8s.io/apimachinery/pkg/util/yaml" ) const hostID = "00000000-0000-0000-0000-000000000000" -func TestWeb(t *testing.T) { - TestingT(t) -} - type WebSuite struct { ctx context.Context cancel context.CancelFunc @@ -124,8 +116,6 @@ type WebSuite struct { clock clockwork.FakeClock } -var _ = Suite(&WebSuite{}) - // TestMain will re-execute Teleport to run a command if "exec" is passed to // it as an argument. 
Otherwise it will run tests as normal. func TestMain(m *testing.M) { @@ -142,41 +132,37 @@ func TestMain(m *testing.M) { os.Exit(code) } -func (s *WebSuite) SetUpSuite(c *C) { - os.Unsetenv(teleport.DebugEnvVar) - - var err error - s.mockU2F, err = mocku2f.Create() - c.Assert(err, IsNil) - c.Assert(s.mockU2F, NotNil) -} - -func noCache(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) { - return clt, nil -} - -func (s *WebSuite) SetUpTest(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) +func newWebSuite(t *testing.T) *WebSuite { + mockU2F, err := mocku2f.Create() + require.NoError(t, err) + require.NotNil(t, mockU2F) u, err := user.Current() - c.Assert(err, IsNil) - s.user = u.Username - s.clock = clockwork.NewFakeClock() + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + s := &WebSuite{ + mockU2F: mockU2F, + clock: clockwork.NewFakeClock(), + user: u.Username, + ctx: ctx, + cancel: cancel, + } networkingConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ KeepAliveInterval: types.Duration(10 * time.Second), }) - c.Assert(err, IsNil) + require.NoError(t, err) s.server, err = auth.NewTestServer(auth.TestServerConfig{ Auth: auth.TestAuthServerConfig{ ClusterName: "localhost", - Dir: c.MkDir(), + Dir: t.TempDir(), Clock: s.clock, ClusterNetworkingConfig: networkingConfig, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) // Register the auth server, since test auth server doesn't start its own // heartbeat. 
@@ -193,13 +179,13 @@ func (s *WebSuite) SetUpTest(c *C) { Version: teleport.Version, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) priv, pub, err := s.server.AuthServer.AuthServer.GenerateKeyPair("") - c.Assert(err, IsNil) + require.NoError(t, err) tlsPub, err := auth.PrivateKeyToPublicKeyTLS(priv) - c.Assert(err, IsNil) + require.NoError(t, err) // start node certs, err := s.server.Auth().GenerateHostCerts(s.ctx, @@ -210,10 +196,10 @@ func (s *WebSuite) SetUpTest(c *C) { PublicSSHKey: pub, PublicTLSKey: tlsPub, }) - c.Assert(err, IsNil) + require.NoError(t, err) signer, err := sshutils.NewSigner(priv, certs.SSH) - c.Assert(err, IsNil) + require.NoError(t, err) nodeID := "node" nodeClient, err := s.server.NewClient(auth.TestIdentity{ @@ -222,7 +208,7 @@ func (s *WebSuite) SetUpTest(c *C) { Username: nodeID, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) nodeLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ @@ -230,10 +216,10 @@ func (s *WebSuite) SetUpTest(c *C) { Client: nodeClient, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) // create SSH service: - nodeDataDir := c.MkDir() + nodeDataDir := t.TempDir() node, err := regular.New( utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, s.server.ClusterName(), @@ -254,12 +240,11 @@ func (s *WebSuite) SetUpTest(c *C) { regular.SetClock(s.clock), regular.SetLockWatcher(nodeLockWatcher), ) - c.Assert(err, IsNil) + require.NoError(t, err) s.node = node s.srvID = node.ID() - c.Assert(s.node.Start(), IsNil) - - c.Assert(auth.CreateUploaderDir(nodeDataDir), IsNil) + require.NoError(t, s.node.Start()) + require.NoError(t, auth.CreateUploaderDir(nodeDataDir)) // create reverse tunnel service: proxyID := "proxy" @@ -269,10 +254,10 @@ func (s *WebSuite) SetUpTest(c *C) { Username: proxyID, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) revTunListener, err := net.Listen("tcp", fmt.Sprintf("%v:0", 
s.server.ClusterName())) - c.Assert(err, IsNil) + require.NoError(t, err) proxyLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ @@ -280,7 +265,15 @@ func (s *WebSuite) SetUpTest(c *C) { Client: s.proxyClient, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) + + proxyNodeWatcher, err := services.NewNodeWatcher(s.ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.proxyClient, + }, + }) + require.NoError(t, err) revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), @@ -293,10 +286,11 @@ func (s *WebSuite) SetUpTest(c *C) { Emitter: s.proxyClient, NewCachingAccessPoint: noCache, DirectClusters: []reversetunnel.DirectCluster{{Name: s.server.ClusterName(), Client: s.proxyClient}}, - DataDir: c.MkDir(), + DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, + NodeWatcher: proxyNodeWatcher, }) - c.Assert(err, IsNil) + require.NoError(t, err) s.proxyTunnel = revTunServer // proxy server: @@ -305,7 +299,7 @@ func (s *WebSuite) SetUpTest(c *C) { s.server.ClusterName(), []ssh.Signer{signer}, s.proxyClient, - c.MkDir(), + t.TempDir(), "", utils.NetAddr{}, s.proxyClient, @@ -318,13 +312,14 @@ func (s *WebSuite) SetUpTest(c *C) { regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(s.clock), regular.SetLockWatcher(proxyLockWatcher), + regular.SetNodeWatcher(proxyNodeWatcher), ) - c.Assert(err, IsNil) + require.NoError(t, err) // Expired sessions are purged immediately var sessionLingeringThreshold time.Duration fs, err := NewDebugFileSystem("../../webassets/teleport") - c.Assert(err, IsNil) + require.NoError(t, err) handler, err := NewHandler(Config{ Proxy: revTunServer, AuthServers: utils.FromAddr(s.server.TLS.Addr()), @@ -339,22 +334,22 @@ func (s *WebSuite) SetUpTest(c *C) { cachedSessionLingeringThreshold: &sessionLingeringThreshold, ProxySettings: 
&mockProxySettings{}, }, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(s.clock)) - c.Assert(err, IsNil) + require.NoError(t, err) s.webServer = httptest.NewUnstartedServer(handler) s.webServer.StartTLS() err = s.proxy.Start() - c.Assert(err, IsNil) + require.NoError(t, err) // Wait for proxy to fully register before starting the test. for start := time.Now(); ; { proxies, err := s.proxyClient.GetProxies() - c.Assert(err, IsNil) + require.NoError(t, err) if len(proxies) != 0 { break } if time.Since(start) > 5*time.Second { - c.Fatal("proxy didn't register within 5s after startup") + t.Fatal("proxy didn't register within 5s after startup") } } @@ -364,27 +359,37 @@ func (s *WebSuite) SetUpTest(c *C) { handler.handler.cfg.ProxyWebAddr = *addr handler.handler.cfg.ProxySSHAddr = *proxyAddr _, sshPort, err := net.SplitHostPort(proxyAddr.String()) - c.Assert(err, IsNil) + require.NoError(t, err) handler.handler.sshPort = sshPort -} -func (s *WebSuite) TearDownTest(c *C) { - // In particular close the lock watchers by cancelling the context. - s.cancel() + t.Cleanup(func() { + // In particular close the lock watchers by cancelling the context. 
+ s.cancel() - var errors []error - s.proxyTunnel.Close() - if err := s.node.Close(); err != nil { - errors = append(errors, err) - } - s.webServer.Close() - if err := s.proxy.Close(); err != nil { - errors = append(errors, err) - } - if err := s.server.Shutdown(context.Background()); err != nil { - errors = append(errors, err) - } - c.Assert(errors, HasLen, 0) + s.webServer.Close() + + var errors []error + if err := s.proxyTunnel.Close(); err != nil { + errors = append(errors, err) + } + if err := s.node.Close(); err != nil { + errors = append(errors, err) + } + s.webServer.Close() + if err := s.proxy.Close(); err != nil { + errors = append(errors, err) + } + if err := s.server.Shutdown(context.Background()); err != nil { + errors = append(errors, err) + } + require.Empty(t, errors) + }) + + return s +} + +func noCache(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) { + return clt, nil } func (r *authPack) renewSession(ctx context.Context, t *testing.T) *roundtrip.Response { @@ -410,7 +415,7 @@ type authPack struct { // authPack returns new authenticated package consisting of created valid // user, otp token, created web session and authenticated client. 
-func (s *WebSuite) authPack(c *C, user string) *authPack { +func (s *WebSuite) authPack(t *testing.T, user string) *authPack { login := s.user pass := "abc123" rawSecret := "def456" @@ -420,15 +425,15 @@ func (s *WebSuite) authPack(c *C, user string) *authPack { Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) - s.createUser(c, user, login, pass, otpSecret) + s.createUser(t, user, login, pass, otpSecret) // create a valid otp token validToken, err := totp.GenerateCode(otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() req := CreateSessionReq{ @@ -439,16 +444,16 @@ func (s *WebSuite) authPack(c *C, user string) *authPack { csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" re, err := s.login(clt, csrfToken, csrfToken, req) - c.Assert(err, IsNil) + require.NoError(t, err) var rawSess *CreateSessionResponse - c.Assert(json.Unmarshal(re.Bytes(), &rawSess), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) sess, err := rawSess.response() - c.Assert(err, IsNil) + require.NoError(t, err) jar, err := cookiejar.New(nil) - c.Assert(err, IsNil) + require.NoError(t, err) clt = s.client(roundtrip.BearerAuth(sess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) @@ -463,32 +468,32 @@ func (s *WebSuite) authPack(c *C, user string) *authPack { } } -func (s *WebSuite) createUser(c *C, user string, login string, pass string, otpSecret string) { +func (s *WebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string) { teleUser, err := types.NewUser(user) - c.Assert(err, IsNil) + require.NoError(t, err) role := services.RoleForUser(teleUser) role.SetLogins(types.Allow, []string{login}) options := role.GetOptions() options.ForwardAgent = types.NewBool(true) role.SetOptions(options) 
err = s.server.Auth().UpsertRole(s.ctx, role) - c.Assert(err, IsNil) + require.NoError(t, err) teleUser.AddRole(role.GetName()) teleUser.SetCreatedBy(types.CreatedBy{ User: types.UserRef{Name: "some-auth-user"}, }) err = s.server.Auth().CreateUser(s.ctx, teleUser) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertPassword(user, []byte(pass)) - c.Assert(err, IsNil) + require.NoError(t, err) if otpSecret != "" { dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertMFADevice(context.Background(), user, dev) - c.Assert(err, IsNil) + require.NoError(t, err) } } @@ -511,18 +516,20 @@ func TestValidRedirectURL(t *testing.T) { } } -func (s *WebSuite) TestSAMLSuccess(c *C) { +func TestSAMLSuccess(t *testing.T) { + t.Parallel() + s := newWebSuite(t) input := fixtures.SAMLOktaConnectorV2 decoder := kyaml.NewYAMLOrJSONDecoder(strings.NewReader(input), defaults.LookaheadBufSize) var raw services.UnknownResource err := decoder.Decode(&raw) - c.Assert(err, IsNil) + require.NoError(t, err) connector, err := services.UnmarshalSAMLConnector(raw.Raw) - c.Assert(err, IsNil) + require.NoError(t, err) err = services.ValidateSAMLConnector(connector) - c.Assert(err, IsNil) + require.NoError(t, err) role, err := types.NewRole(connector.GetAttributesToRoles()[0].Roles[0], types.RoleSpecV5{ Options: types.RoleOptions{ @@ -536,64 +543,64 @@ func (s *WebSuite) TestSAMLSuccess(c *C) { }, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) role.SetLogins(types.Allow, []string{s.user}) err = s.server.Auth().UpsertRole(s.ctx, role) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().CreateSAMLConnector(connector) - c.Assert(err, IsNil) + require.NoError(t, err) s.server.Auth().SetClock(clockwork.NewFakeClockAt(time.Date(2017, 5, 10, 18, 53, 0, 0, time.UTC))) clt := s.clientNoRedirects() csrfToken := 
"2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" baseURL, err := url.Parse(clt.Endpoint("webapi", "saml", "sso") + `?connector_id=` + connector.GetName() + `&redirect_url=http://localhost/after`) - c.Assert(err, IsNil) + require.NoError(t, err) req, err := http.NewRequest("GET", baseURL.String(), nil) - c.Assert(err, IsNil) + require.NoError(t, err) addCSRFCookieToReq(req, csrfToken) re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) // we got a redirect urlPattern := regexp.MustCompile(`URL='([^']*)'`) locationURL := urlPattern.FindStringSubmatch(string(re.Bytes()))[1] u, err := url.Parse(locationURL) - c.Assert(err, IsNil) - c.Assert(u.Scheme+"://"+u.Host+u.Path, Equals, fixtures.SAMLOktaSSO) + require.NoError(t, err) + require.Equal(t, fixtures.SAMLOktaSSO, u.Scheme+"://"+u.Host+u.Path) data, err := base64.StdEncoding.DecodeString(u.Query().Get("SAMLRequest")) - c.Assert(err, IsNil) - buf, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(data))) - c.Assert(err, IsNil) + require.NoError(t, err) + buf, err := io.ReadAll(flate.NewReader(bytes.NewReader(data))) + require.NoError(t, err) doc := etree.NewDocument() err = doc.ReadFromBytes(buf) - c.Assert(err, IsNil) + require.NoError(t, err) id := doc.Root().SelectAttr("ID") - c.Assert(id, NotNil) + require.NotNil(t, id) authRequest, err := s.server.Auth().GetSAMLAuthRequest(id.Value) - c.Assert(err, IsNil) + require.NoError(t, err) // now swap the request id to the hardcoded one in fixtures authRequest.ID = fixtures.SAMLOktaAuthRequestID authRequest.CSRFToken = csrfToken err = s.server.Auth().Identity.CreateSAMLAuthRequest(*authRequest, backend.Forever) - c.Assert(err, IsNil) + require.NoError(t, err) // now respond with pre-recorded request to the POST url in := &bytes.Buffer{} fw, err := flate.NewWriter(in, flate.DefaultCompression) - c.Assert(err, IsNil) + require.NoError(t, err) _, err = 
fw.Write([]byte(fixtures.SAMLOktaAuthnResponseXML)) - c.Assert(err, IsNil) + require.NoError(t, err) err = fw.Close() - c.Assert(err, IsNil) + require.NoError(t, err) encodedResponse := base64.StdEncoding.EncodeToString(in.Bytes()) - c.Assert(encodedResponse, NotNil) + require.NotNil(t, encodedResponse) // now send the response to the server to exchange it for auth session form := url.Values{} @@ -601,43 +608,46 @@ func (s *WebSuite) TestSAMLSuccess(c *C) { req, err = http.NewRequest("POST", clt.Endpoint("webapi", "saml", "acs"), strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") addCSRFCookieToReq(req, csrfToken) - c.Assert(err, IsNil) + require.NoError(t, err) authRe, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) - comment := Commentf("Response: %v", string(authRe.Bytes())) - c.Assert(authRe.Code(), Equals, http.StatusFound, comment) + require.NoError(t, err) + require.Equal(t, http.StatusFound, authRe.Code(), "Response: %v", string(authRe.Bytes())) // we have got valid session - c.Assert(authRe.Headers().Get("Set-Cookie"), Not(Equals), "") - // we are being redirected to orignal URL - c.Assert(authRe.Headers().Get("Location"), Equals, "/after") + require.NotEmpty(t, authRe.Headers().Get("Set-Cookie")) + // we are being redirected to original URL + require.Equal(t, "/after", authRe.Headers().Get("Location")) } -func (s *WebSuite) TestWebSessionsCRUD(c *C) { - pack := s.authPack(c, "foo") +func TestWebSessionsCRUD(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") // make sure we can use client to make authenticated requests re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var clusters []ui.Cluster - c.Assert(json.Unmarshal(re.Bytes(), &clusters), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), 
&clusters)) // now delete session _, err = pack.clt.Delete( context.Background(), pack.clt.Endpoint("webapi", "sessions")) - c.Assert(err, IsNil) + require.NoError(t, err) // subsequent requests trying to use this session will fail _, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) } -func (s *WebSuite) TestCSRF(c *C) { +func TestCSRF(t *testing.T) { + t.Parallel() + s := newWebSuite(t) type input struct { reqToken string cookieToken string @@ -647,11 +657,11 @@ func (s *WebSuite) TestCSRF(c *C) { user := "csrfuser" pass := "abc123" otpSecret := base32.StdEncoding.EncodeToString([]byte("def456")) - s.createUser(c, user, user, pass, otpSecret) + s.createUser(t, user, user, pass, otpSecret) // create a valid login form request validToken, err := totp.GenerateCode(otpSecret, time.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) loginForm := CreateSessionReq{ User: user, Pass: pass, @@ -671,23 +681,25 @@ func (s *WebSuite) TestCSRF(c *C) { // valid _, err = s.login(clt, encodedToken1, encodedToken1, loginForm) - c.Assert(err, IsNil) + require.NoError(t, err) // invalid for i := range invalid { _, err := s.login(clt, invalid[i].cookieToken, invalid[i].reqToken, loginForm) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) } } -func (s *WebSuite) TestPasswordChange(c *C) { - pack := s.authPack(c, "foo") +func TestPasswordChange(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") // invalidate the token s.clock.Advance(1 * time.Minute) validToken, err := totp.GenerateCode(pack.otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) req := changePasswordReq{ OldPassword: []byte("abc123"), @@ -696,26 +708,28 @@ func (s *WebSuite) 
TestPasswordChange(c *C) { } _, err = pack.clt.PutJSON(context.Background(), pack.clt.Endpoint("webapi", "users", "password"), req) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestWebSessionsBadInput(c *C) { +func TestWebSessionsBadInput(t *testing.T) { + t.Parallel() + s := newWebSuite(t) user := "bob" pass := "abc123" rawSecret := "def456" otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) err := s.server.Auth().UpsertPassword(user, []byte(pass)) - c.Assert(err, IsNil) + require.NoError(t, err) dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertMFADevice(context.Background(), user, dev) - c.Assert(err, IsNil) + require.NoError(t, err) // create valid token validToken, err := totp.GenerateCode(otpSecret, time.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() @@ -751,9 +765,11 @@ func (s *WebSuite) TestWebSessionsBadInput(c *C) { }, } for i, req := range reqs { - _, err = clt.PostJSON(context.Background(), clt.Endpoint("webapi", "sessions"), req) - c.Assert(err, NotNil, Commentf("tc %v", i)) - c.Assert(trace.IsAccessDenied(err), Equals, true, Commentf("tc %v %T is not access denied", i, err)) + t.Run(fmt.Sprintf("tc %v", i), func(t *testing.T) { + _, err := clt.PostJSON(s.ctx, clt.Endpoint("webapi", "sessions"), req) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }) } } @@ -791,12 +807,15 @@ func TestClusterNodesGet(t *testing.T) { require.Equal(t, nodes, nodes2) } -func (s *WebSuite) TestSiteNodeConnectInvalidSessionID(c *C) { - _, err := s.makeTerminal(s.authPack(c, "foo"), session.ID("/../../../foo")) - c.Assert(err, NotNil) +func TestSiteNodeConnectInvalidSessionID(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + _, err := s.makeTerminal(t, s.authPack(t, "foo"), session.ID("/../../../foo")) + require.Error(t, err) } -func (s *WebSuite) TestResolveServerHostPort(c *C) { +func 
TestResolveServerHostPort(t *testing.T) { + t.Parallel() sampleNode := types.ServerV2{} sampleNode.SetName("eca53e45-86a9-11e7-a893-0242ac0a0101") sampleNode.Spec.Hostname = "nodehostname" @@ -855,19 +874,19 @@ func (s *WebSuite) TestResolveServerHostPort(c *C) { for _, testCase := range validCases { host, port, err := resolveServerHostPort(testCase.server, testCase.nodes) - c.Assert(err, IsNil, Commentf(testCase.server)) - c.Assert(host, Equals, testCase.expectedHost) - c.Assert(port, Equals, testCase.expectedPort) + require.NoError(t, err, testCase.server) + require.Equal(t, testCase.expectedHost, host, testCase.server) + require.Equal(t, testCase.expectedPort, port, testCase.server) } for _, testCase := range invalidCases { _, _, err := resolveServerHostPort(testCase.server, nil) - c.Assert(err, NotNil, Commentf(testCase.expectedErr)) - c.Assert(err, ErrorMatches, ".*"+testCase.expectedErr+".*") + require.Error(t, err, testCase.server) + require.Regexp(t, ".*"+testCase.expectedErr+".*", err.Error(), testCase.server) } } -func (s *WebSuite) TestNewTerminalHandler(c *C) { +func TestNewTerminalHandler(t *testing.T) { validNode := types.ServerV2{} validNode.SetName("eca53e45-86a9-11e7-a893-0242ac0a0101") validNode.Spec.Hostname = "nodehostname" @@ -924,8 +943,8 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { expectedErr string }{ { - expectedErr: "invalid session", authProvider: makeProvider(validNode), + expectedErr: "invalid session", req: TerminalRequest{ SessionID: "", Login: validLogin, @@ -934,8 +953,8 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { }, }, { - expectedErr: "bad term dimensions", authProvider: makeProvider(validNode), + expectedErr: "bad term dimensions", req: TerminalRequest{ SessionID: validSID, Login: validLogin, @@ -947,8 +966,8 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { }, }, { - expectedErr: "invalid server name", authProvider: makeProvider(validNode), + expectedErr: "invalid server name", req: TerminalRequest{ Server: 
"localhost:port", SessionID: validSID, @@ -958,97 +977,102 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { }, } + ctx := context.Background() for _, testCase := range validCases { - term, err := NewTerminal(s.ctx, testCase.req, testCase.authProvider, nil) - c.Assert(err, IsNil) - c.Assert(term.params, DeepEquals, testCase.req) - c.Assert(term.hostName, Equals, testCase.expectedHost) - c.Assert(term.hostPort, Equals, testCase.expectedPort) + term, err := NewTerminal(ctx, testCase.req, testCase.authProvider, nil) + require.NoError(t, err) + require.Empty(t, cmp.Diff(testCase.req, term.params)) + require.Equal(t, testCase.expectedHost, testCase.expectedHost) + require.Equal(t, testCase.expectedPort, testCase.expectedPort) } for _, testCase := range invalidCases { - _, err := NewTerminal(s.ctx, testCase.req, testCase.authProvider, nil) - c.Assert(err, ErrorMatches, ".*"+testCase.expectedErr+".*") + _, err := NewTerminal(ctx, testCase.req, testCase.authProvider, nil) + require.Regexp(t, ".*"+testCase.expectedErr+".*", err.Error()) } } -func (s *WebSuite) TestResizeTerminal(c *C) { +func TestResizeTerminal(t *testing.T) { + t.Parallel() + s := newWebSuite(t) sid := session.NewID() // Create a new user "foo", open a terminal to a new session, and wait for // it to be ready. - pack1 := s.authPack(c, "foo") - ws1, err := s.makeTerminal(pack1, sid) - c.Assert(err, IsNil) - defer ws1.Close() + pack1 := s.authPack(t, "foo") + ws1, err := s.makeTerminal(t, pack1, sid) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws1.Close()) }) err = s.waitForRawEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // Create a new user "bar", open a terminal to the session created above, // and wait for it to be ready. 
- pack2 := s.authPack(c, "bar") - ws2, err := s.makeTerminal(pack2, sid) - c.Assert(err, IsNil) - defer ws2.Close() + pack2 := s.authPack(t, "bar") + ws2, err := s.makeTerminal(t, pack2, sid) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws2.Close()) }) err = s.waitForRawEvent(ws2, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // Look at the audit events for the first terminal. It should have two // resize events from the second terminal (80x25 default then 100x100). Only // the second terminal will get these because resize events are not sent // back to the originator. err = s.waitForResizeEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.waitForResizeEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // Look at the stream events for the second terminal. We don't expect to see // any resize events yet. It will timeout. ws2Event := s.listenForResizeEvent(ws2) select { case <-ws2Event: - c.Fatal("unexpected resize event") + t.Fatal("unexpected resize event") case <-time.After(time.Second): } // Resize the second terminal. This should be reflected on the first terminal // because resize events are not sent back to the originator. params, err := session.NewTerminalParamsFromInt(300, 120) - c.Assert(err, IsNil) + require.NoError(t, err) data, err := json.Marshal(events.EventFields{ events.EventType: events.ResizeEvent, events.EventNamespace: apidefaults.Namespace, events.SessionEventID: sid.String(), events.TerminalSize: params.Serialize(), }) - c.Assert(err, IsNil) + require.NoError(t, err) envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketResize, Payload: string(data), } envelopeBytes, err := proto.Marshal(envelope) - c.Assert(err, IsNil) + require.NoError(t, err) err = ws2.WriteMessage(websocket.BinaryMessage, envelopeBytes) - c.Assert(err, IsNil) + require.NoError(t, err) // This time the first terminal will see the resize event. 
err = s.waitForResizeEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // The second terminal will not see any resize event. It will timeout. select { case <-ws2Event: - c.Fatal("unexpected resize event") + t.Fatal("unexpected resize event") case <-time.After(time.Second): } } // TestTerminalPing tests that the server sends continuous ping control messages. -func (s *WebSuite) TestTerminalPing(c *C) { - ws, err := s.makeTerminal(s.authPack(c, "foo")) - c.Assert(err, IsNil) - defer ws.Close() +func TestTerminalPing(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ws, err := s.makeTerminal(t, s.authPack(t, "foo")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) closed := false done := make(chan struct{}) @@ -1081,23 +1105,25 @@ func (s *WebSuite) TestTerminalPing(c *C) { select { case <-done: case <-time.After(time.Minute): - c.Fatal("timeout waiting for ping") + t.Fatal("timeout waiting for ping") } } -func (s *WebSuite) TestTerminal(c *C) { - ws, err := s.makeTerminal(s.authPack(c, "foo")) - c.Assert(err, IsNil) - defer ws.Close() +func TestTerminal(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ws, err := s.makeTerminal(t, s.authPack(t, "foo")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) _, err = io.WriteString(stream, "echo vinsong\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) err = waitForOutput(stream, "vinsong") - c.Assert(err, IsNil) + require.NoError(t, err) } func TestTerminalRequireSessionMfa(t *testing.T) { @@ -1223,8 +1249,6 @@ func TestTerminalRequireSessionMfa(t *testing.T) { _, err = io.WriteString(stream, "echo alpacas\r\n") require.Nil(t, err) require.Nil(t, waitForOutput(stream, "alpacas")) - - require.Nil(t, ws.Close()) }) } } @@ -1441,78 +1465,84 @@ func handleMFAU2FCChallenge(t *testing.T, ws *websocket.Conn, dev *auth.TestDevi require.NoError(t, err) 
} -func (s *WebSuite) TestWebAgentForward(c *C) { - ws, err := s.makeTerminal(s.authPack(c, "foo")) - c.Assert(err, IsNil) - defer ws.Close() +func TestWebAgentForward(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ws, err := s.makeTerminal(t, s.authPack(t, "foo")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) _, err = io.WriteString(stream, "echo $SSH_AUTH_SOCK\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) err = waitForOutput(stream, "/") - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestActiveSessions(c *C) { +func TestActiveSessions(t *testing.T) { + t.Parallel() + s := newWebSuite(t) sid := session.NewID() - pack := s.authPack(c, "foo") + pack := s.authPack(t, "foo") - ws, err := s.makeTerminal(pack, sid) - c.Assert(err, IsNil) - defer ws.Close() + ws, err := s.makeTerminal(t, pack, sid) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) // To make sure we have a session. _, err = io.WriteString(stream, "echo vinsong\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) // Make sure server has replied. err = waitForOutput(stream, "vinsong") - c.Assert(err, IsNil) + require.NoError(t, err) // Make sure this session appears in the list of active sessions. var sessResp *siteSessionsGetResponse for i := 0; i < 10; i++ { // Get site nodes and make sure the node has our active party. 
- re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) - c.Assert(err, IsNil) + re, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) + require.NoError(t, err) - c.Assert(json.Unmarshal(re.Bytes(), &sessResp), IsNil) - c.Assert(len(sessResp.Sessions), Equals, 1) + require.NoError(t, json.Unmarshal(re.Bytes(), &sessResp)) + require.Len(t, sessResp.Sessions, 1) // Sessions do not appear momentarily as there's async heartbeat // procedure. time.Sleep(250 * time.Millisecond) } - c.Assert(len(sessResp.Sessions), Equals, 1) + require.Len(t, sessResp.Sessions, 1) sess := sessResp.Sessions[0] - c.Assert(sess.ID, Equals, sid) - c.Assert(sess.Namespace, Equals, s.node.GetNamespace()) - c.Assert(sess.Parties, NotNil) - c.Assert(sess.TerminalParams.H > 0, Equals, true) - c.Assert(sess.TerminalParams.W > 0, Equals, true) - c.Assert(sess.Login, Equals, pack.login) - c.Assert(sess.Created.IsZero(), Equals, false) - c.Assert(sess.LastActive.IsZero(), Equals, false) - c.Assert(sess.ServerID, Equals, s.srvID) - c.Assert(sess.ServerHostname, Equals, s.node.GetInfo().GetHostname()) - c.Assert(sess.ServerAddr, Equals, s.node.GetInfo().GetAddr()) - c.Assert(sess.ClusterName, Equals, s.server.ClusterName()) + require.Equal(t, sid, sess.ID) + require.Equal(t, s.node.GetNamespace(), sess.Namespace) + require.NotNil(t, sess.Parties) + require.Greater(t, sess.TerminalParams.H, 0) + require.Greater(t, sess.TerminalParams.W, 0) + require.Equal(t, pack.login, sess.Login) + require.False(t, sess.Created.IsZero()) + require.False(t, sess.LastActive.IsZero()) + require.Equal(t, s.srvID, sess.ServerID) + require.Equal(t, s.node.GetInfo().GetHostname(), sess.ServerHostname) + require.Equal(t, s.node.GetInfo().GetAddr(), sess.ServerAddr) + require.Equal(t, s.server.ClusterName(), sess.ClusterName) } // DELETE IN: 5.0.0 // Tests the code snippet from 
apiserver.(*Handler).siteSessionGet/siteSessionsGet // that tests empty ClusterName and ServerHostname gets set. -func (s *WebSuite) TestEmptySessionClusterHostnameIsSet(c *C) { +func TestEmptySessionClusterHostnameIsSet(t *testing.T) { + t.Parallel() + s := newWebSuite(t) nodeClient, err := s.server.NewClient(auth.TestBuiltin(types.RoleNode)) - c.Assert(err, IsNil) + require.NoError(t, err) // Create a session with empty ClusterName. sess1 := session.Session{ @@ -1526,68 +1556,68 @@ func (s *WebSuite) TestEmptySessionClusterHostnameIsSet(c *C) { TerminalParams: session.TerminalParams{W: 100, H: 100}, } err = nodeClient.CreateSession(sess1) - c.Assert(err, IsNil) + require.NoError(t, err) // Retrieve the session with the empty ClusterName. - pack := s.authPack(c, "baz") - res, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions", sess1.ID.String()), url.Values{}) - c.Assert(err, IsNil) + pack := s.authPack(t, "baz") + res, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions", sess1.ID.String()), url.Values{}) + require.NoError(t, err) // Test that empty ClusterName and ServerHostname got set. var sessionResult *session.Session err = json.Unmarshal(res.Bytes(), &sessionResult) - c.Assert(err, IsNil) - c.Assert(sessionResult.ClusterName, Equals, s.server.ClusterName()) - c.Assert(sessionResult.ServerHostname, Equals, sess1.ServerID) + require.NoError(t, err) + require.Equal(t, s.server.ClusterName(), sessionResult.ClusterName) + require.Equal(t, sess1.ServerID, sessionResult.ServerHostname) // Create another session to test sessions list. sess2 := sess1 sess2.ID = session.NewID() sess2.ServerID = string(session.NewID()) err = nodeClient.CreateSession(sess2) - c.Assert(err, IsNil) + require.NoError(t, err) // Retrieve sessions list. 
- res, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) - c.Assert(err, IsNil) + res, err = pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) + require.NoError(t, err) var sessionList *siteSessionsGetResponse err = json.Unmarshal(res.Bytes(), &sessionList) - c.Assert(err, IsNil) + require.NoError(t, err) s1 := sessionList.Sessions[0] s2 := sessionList.Sessions[1] - c.Assert(s1.ClusterName, Equals, s.server.ClusterName()) - c.Assert(s2.ClusterName, Equals, s.server.ClusterName()) - c.Assert(s1.ServerHostname, Equals, s1.ServerID) - c.Assert(s2.ServerHostname, Equals, s2.ServerID) + require.Equal(t, s.server.ClusterName(), s1.ClusterName) + require.Equal(t, s.server.ClusterName(), s2.ClusterName) + require.Equal(t, s1.ServerID, s1.ServerHostname) + require.Equal(t, s2.ServerID, s2.ServerHostname) } -func (s *WebSuite) TestCloseConnectionsOnLogout(c *C) { +func TestCloseConnectionsOnLogout(t *testing.T) { + t.Parallel() + s := newWebSuite(t) sid := session.NewID() - pack := s.authPack(c, "foo") + pack := s.authPack(t, "foo") - ws, err := s.makeTerminal(pack, sid) - c.Assert(err, IsNil) - defer ws.Close() + ws, err := s.makeTerminal(t, pack, sid) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) // to make sure we have a session _, err = io.WriteString(stream, "expr 137 + 39\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) // make sure server has replied out := make([]byte, 100) _, err = stream.Read(out) - c.Assert(err, IsNil) + require.NoError(t, err) - _, err = pack.clt.Delete( - context.Background(), - pack.clt.Endpoint("webapi", "sessions")) - c.Assert(err, IsNil) + _, err = pack.clt.Delete(s.ctx, pack.clt.Endpoint("webapi", "sessions")) + require.NoError(t, err) // wait until we timeout or detect that connection 
has been closed after := time.After(5 * time.Second) @@ -1603,9 +1633,9 @@ func (s *WebSuite) TestCloseConnectionsOnLogout(c *C) { select { case <-after: - c.Fatalf("timeout") + t.Fatalf("timeout") case err := <-errC: - c.Assert(err, Equals, io.EOF) + require.ErrorIs(t, err, io.EOF) } } @@ -1653,35 +1683,39 @@ func TestCreateSession(t *testing.T) { require.NoError(t, err) } -func (s *WebSuite) TestPlayback(c *C) { - pack := s.authPack(c, "foo") +func TestPlayback(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") sid := session.NewID() - ws, err := s.makeTerminal(pack, sid) - c.Assert(err, IsNil) - defer ws.Close() + ws, err := s.makeTerminal(t, pack, sid) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) } -func (s *WebSuite) TestLogin(c *C) { +func TestLogin(t *testing.T) { + t.Parallel() + s := newWebSuite(t) ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOff, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) // create user - s.createUser(c, "user1", "root", "password", "") + s.createUser(t, "user1", "root", "password", "") loginReq, err := json.Marshal(CreateSessionReq{ User: "user1", Pass: "password", }) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions"), bytes.NewBuffer(loginReq)) - c.Assert(err, IsNil) + require.NoError(t, err) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) @@ -1691,91 +1725,94 @@ func (s *WebSuite) TestLogin(c *C) { re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) var rawSess *CreateSessionResponse - c.Assert(json.Unmarshal(re.Bytes(), &rawSess), 
IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) cookies := re.Cookies() - c.Assert(len(cookies), Equals, 1) + require.Len(t, cookies, 1) // now make sure we are logged in by calling authenticated method // we need to supply both session cookie and bearer token for // request to succeed jar, err := cookiejar.New(nil) - c.Assert(err, IsNil) + require.NoError(t, err) clt = s.client(roundtrip.BearerAuth(rawSess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) - re, err = clt.Get(context.Background(), clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, IsNil) + re, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + require.NoError(t, err) var clusters []ui.Cluster - c.Assert(json.Unmarshal(re.Bytes(), &clusters), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &clusters)) // in absence of session cookie or bearer auth the same request fill fail // no session cookie: clt = s.client(roundtrip.BearerAuth(rawSess.Token)) - _, err = clt.Get(context.Background(), clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) // no bearer token: clt = s.client(roundtrip.CookieJar(jar)) - _, err = clt.Get(context.Background(), clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) } -func (s *WebSuite) TestChangePasswordAndAddTOTPDeviceWithToken(c *C) { +func TestChangePasswordAndAddTOTPDeviceWithToken(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) - 
c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) // create user - s.createUser(c, "user1", "root", "password", "") + s.createUser(t, "user1", "root", "password", "") // create password change token token, err := s.server.Auth().CreateResetPasswordToken(context.TODO(), auth.CreateUserTokenRequest{ Name: "user1", }) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() re, err := clt.Get(context.Background(), clt.Endpoint("webapi", "users", "password", "token", token.GetName()), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var uiToken *ui.ResetPasswordToken - c.Assert(json.Unmarshal(re.Bytes(), &uiToken), IsNil) - c.Assert(uiToken.User, Equals, token.GetUser()) - c.Assert(uiToken.TokenID, Equals, token.GetName()) - c.Assert(uiToken.QRCode, NotNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &uiToken)) + require.Equal(t, token.GetUser(), uiToken.User) + require.Equal(t, token.GetName(), uiToken.TokenID) + require.NotNil(t, uiToken.QRCode) res, err := s.server.Auth().CreateRegisterChallenge(context.Background(), &apiProto.CreateRegisterChallengeRequest{ TokenID: token.GetName(), DeviceType: apiProto.DeviceType_DEVICE_TYPE_TOTP, }) - c.Assert(err, IsNil) + require.NoError(t, err) // Advance the clock to invalidate the TOTP token s.clock.Advance(1 * time.Minute) secondFactorToken, err := totp.GenerateCode(res.GetTOTP().GetSecret(), s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) data, err := json.Marshal(auth.ChangePasswordWithTokenRequest{ TokenID: token.GetName(), Password: []byte("abc123"), SecondFactorToken: secondFactorToken, }) - c.Assert(err, IsNil) + require.NoError(t, err) req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(data)) - c.Assert(err, IsNil) + require.NoError(t, err) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" 
addCSRFCookieToReq(req, csrfToken) @@ -1785,16 +1822,19 @@ func (s *WebSuite) TestChangePasswordAndAddTOTPDeviceWithToken(c *C) { re, err = clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) // Test that no recovery codes are returned b/c cloud feature isn't enabled. var response ui.RecoveryCodes - c.Assert(json.Unmarshal(re.Bytes(), &response), IsNil) - c.Assert(response.Codes, IsNil) - c.Assert(response.Created, IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &response)) + require.Nil(t, response.Codes) + require.Nil(t, response.Created) } -func (s *WebSuite) TestChangePasswordAndAddU2FDeviceWithToken(c *C) { +func TestChangePasswordAndAddU2FDeviceWithToken(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorU2F, @@ -1803,37 +1843,37 @@ func (s *WebSuite) TestChangePasswordAndAddU2FDeviceWithToken(c *C) { Facets: []string{"https://" + s.server.ClusterName()}, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) - s.createUser(c, "user2", "root", "password", "") + s.createUser(t, "user2", "root", "password", "") // create reset password token token, err := s.server.Auth().CreateResetPasswordToken(context.TODO(), auth.CreateUserTokenRequest{ Name: "user2", }) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() re, err := clt.Get(context.Background(), clt.Endpoint("webapi", "u2f", "signuptokens", token.GetName()), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var u2fRegReq u2f.RegisterChallenge - c.Assert(json.Unmarshal(re.Bytes(), &u2fRegReq), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &u2fRegReq)) u2fRegResp, err := s.mockU2F.RegisterResponse(&u2fRegReq) - c.Assert(err, IsNil) + require.NoError(t, 
err) data, err := json.Marshal(auth.ChangePasswordWithTokenRequest{ TokenID: token.GetName(), Password: []byte("qweQWE"), U2FRegisterResponse: u2fRegResp, }) - c.Assert(err, IsNil) + require.NoError(t, err) req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(data)) - c.Assert(err, IsNil) + require.NoError(t, err) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) @@ -1843,77 +1883,81 @@ func (s *WebSuite) TestChangePasswordAndAddU2FDeviceWithToken(c *C) { re, err = clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) // Test that no recovery codes are returned b/c cloud is not turned on. var response ui.RecoveryCodes - c.Assert(json.Unmarshal(re.Bytes(), &response), IsNil) - c.Assert(response.Codes, IsNil) - c.Assert(response.Created, IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &response)) + require.Nil(t, response.Codes) + require.Nil(t, response.Created) } // TestEmptyMotD ensures that responses returned by both /webapi/ping and // /webapi/motd work when no MotD is set -func (s *WebSuite) TestEmptyMotD(c *C) { - ctx := context.Background() +func TestEmptyMotD(t *testing.T) { + t.Parallel() + s := newWebSuite(t) wc := s.client() // Given an auth server configured *not* to expose a Message Of The // Day... // When I issue a ping request... 
- re, err := wc.Get(ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + require.NoError(t, err) // Expect that the MotD flag in the ping response is *not* set var pingResponse *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &pingResponse), IsNil) - c.Assert(pingResponse.Auth.HasMessageOfTheDay, Equals, false) + require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) + require.False(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... - re, err = wc.Get(ctx, wc.Endpoint("webapi", "motd"), url.Values{}) - c.Assert(err, IsNil) + re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) + require.NoError(t, err) // Expect that an empty response returned var motdResponse *webclient.MotD - c.Assert(json.Unmarshal(re.Bytes(), &motdResponse), IsNil) - c.Assert(motdResponse.Text, Equals, "") + require.NoError(t, json.Unmarshal(re.Bytes(), &motdResponse)) + require.Empty(t, motdResponse.Text) } // TestMotD ensures that a response is returned by both /webapi/ping and /webapi/motd // and that that the response bodies contain their MOTD components -func (s *WebSuite) TestMotD(c *C) { +func TestMotD(t *testing.T) { + t.Parallel() const motd = "Hello. I'm a Teleport cluster!" - ctx := context.Background() + s := newWebSuite(t) wc := s.client() // Given an auth server configured to expose a Message Of The Day... prefs := types.DefaultAuthPreference() prefs.SetMessageOfTheDay(motd) - s.server.AuthServer.AuthServer.SetAuthPreference(ctx, prefs) + require.NoError(t, s.server.AuthServer.AuthServer.SetAuthPreference(s.ctx, prefs)) // When I issue a ping request... 
- re, err := wc.Get(ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + require.NoError(t, err) // Expect that the MotD flag in the ping response is set to indicate // a MotD var pingResponse *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &pingResponse), IsNil) - c.Assert(pingResponse.Auth.HasMessageOfTheDay, Equals, true) + require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) + require.True(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... - re, err = wc.Get(ctx, wc.Endpoint("webapi", "motd"), url.Values{}) - c.Assert(err, IsNil) + re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) + require.NoError(t, err) // Expect that the text returned is the configured value var motdResponse *webclient.MotD - c.Assert(json.Unmarshal(re.Bytes(), &motdResponse), IsNil) - c.Assert(motdResponse.Text, Equals, motd) + require.NoError(t, json.Unmarshal(re.Bytes(), &motdResponse)) + require.Equal(t, motd, motdResponse.Text) } -func (s *WebSuite) TestMultipleConnectors(c *C) { +func TestMultipleConnectors(t *testing.T) { + t.Parallel() + s := newWebSuite(t) wc := s.client() // create two oidc connectors, one named "foo" and another named "bar" @@ -1933,61 +1977,61 @@ func (s *WebSuite) TestMultipleConnectors(c *C) { }, } o, err := types.NewOIDCConnector("foo", oidcConnectorSpec) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertOIDCConnector(s.ctx, o) - c.Assert(err, IsNil) + require.NoError(t, err) o2, err := types.NewOIDCConnector("bar", oidcConnectorSpec) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertOIDCConnector(s.ctx, o2) - c.Assert(err, IsNil) + require.NoError(t, err) // set the auth preferences to oidc with no connector name authPreference, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: "oidc", }) - c.Assert(err, IsNil) + require.NoError(t, err) 
err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) - c.Assert(err, IsNil) + require.NoError(t, err) // hit the ping endpoint to get the auth type and connector name re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var out *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &out), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &out)) // make sure the connector name we got back was the first connector // in the backend, in this case it's "bar" oidcConnectors, err := s.server.Auth().GetOIDCConnectors(s.ctx, false) - c.Assert(err, IsNil) - c.Assert(out.Auth.OIDC.Name, Equals, oidcConnectors[0].GetName()) + require.NoError(t, err) + require.Equal(t, oidcConnectors[0].GetName(), out.Auth.OIDC.Name) // update the auth preferences and this time specify the connector name authPreference, err = types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: "oidc", ConnectorName: "foo", }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) - c.Assert(err, IsNil) + require.NoError(t, err) // hit the ping endpoing to get the auth type and connector name re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) - c.Assert(json.Unmarshal(re.Bytes(), &out), IsNil) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(re.Bytes(), &out)) // make sure the connector we get back is "foo" - c.Assert(out.Auth.OIDC.Name, Equals, "foo") + require.Equal(t, "foo", out.Auth.OIDC.Name) } // TestConstructSSHResponse checks if the secret package uses AES-GCM to // encrypt and decrypt data that passes through the ConstructSSHResponse // function. 
-func (s *WebSuite) TestConstructSSHResponse(c *C) { +func TestConstructSSHResponse(t *testing.T) { key, err := secret.NewKey() - c.Assert(err, IsNil) + require.NoError(t, err) u, err := url.Parse("http://www.example.com/callback") - c.Assert(err, IsNil) + require.NoError(t, err) query := u.Query() query.Set("secret_key", key.String()) u.RawQuery = query.Encode() @@ -1998,35 +2042,35 @@ func (s *WebSuite) TestConstructSSHResponse(c *C) { TLSCert: []byte{0x01}, ClientRedirectURL: u.String(), }) - c.Assert(err, IsNil) + require.NoError(t, err) - c.Assert(rawresp.Query().Get("secret"), Equals, "") - c.Assert(rawresp.Query().Get("secret_key"), Equals, "") - c.Assert(rawresp.Query().Get("response"), Not(Equals), "") + require.Empty(t, rawresp.Query().Get("secret")) + require.Empty(t, rawresp.Query().Get("secret_key")) + require.NotEmpty(t, rawresp.Query().Get("response")) plaintext, err := key.Open([]byte(rawresp.Query().Get("response"))) - c.Assert(err, IsNil) + require.NoError(t, err) var resp *auth.SSHLoginResponse err = json.Unmarshal(plaintext, &resp) - c.Assert(err, IsNil) - c.Assert(resp.Username, Equals, "foo") - c.Assert(resp.Cert, DeepEquals, []byte{0x00}) - c.Assert(resp.TLSCert, DeepEquals, []byte{0x01}) + require.NoError(t, err) + require.Equal(t, "foo", resp.Username) + require.EqualValues(t, []byte{0x00}, resp.Cert) + require.EqualValues(t, []byte{0x01}, resp.TLSCert) } // TestConstructSSHResponseLegacy checks if the secret package uses NaCl to // encrypt and decrypt data that passes through the ConstructSSHResponse // function. 
-func (s *WebSuite) TestConstructSSHResponseLegacy(c *C) { +func TestConstructSSHResponseLegacy(t *testing.T) { key, err := lemma_secret.NewKey() - c.Assert(err, IsNil) + require.NoError(t, err) lemma, err := lemma_secret.New(&lemma_secret.Config{KeyBytes: key}) - c.Assert(err, IsNil) + require.NoError(t, err) u, err := url.Parse("http://www.example.com/callback") - c.Assert(err, IsNil) + require.NoError(t, err) query := u.Query() query.Set("secret", lemma_secret.KeyToEncodedString(key)) u.RawQuery = query.Encode() @@ -2037,25 +2081,25 @@ func (s *WebSuite) TestConstructSSHResponseLegacy(c *C) { TLSCert: []byte{0x01}, ClientRedirectURL: u.String(), }) - c.Assert(err, IsNil) + require.NoError(t, err) - c.Assert(rawresp.Query().Get("secret"), Equals, "") - c.Assert(rawresp.Query().Get("secret_key"), Equals, "") - c.Assert(rawresp.Query().Get("response"), Not(Equals), "") + require.Empty(t, rawresp.Query().Get("secret")) + require.Empty(t, rawresp.Query().Get("secret_key")) + require.NotEmpty(t, rawresp.Query().Get("response")) var sealedData *lemma_secret.SealedBytes err = json.Unmarshal([]byte(rawresp.Query().Get("response")), &sealedData) - c.Assert(err, IsNil) + require.NoError(t, err) plaintext, err := lemma.Open(sealedData) - c.Assert(err, IsNil) + require.NoError(t, err) var resp *auth.SSHLoginResponse err = json.Unmarshal(plaintext, &resp) - c.Assert(err, IsNil) - c.Assert(resp.Username, Equals, "foo") - c.Assert(resp.Cert, DeepEquals, []byte{0x00}) - c.Assert(resp.TLSCert, DeepEquals, []byte{0x01}) + require.NoError(t, err) + require.Equal(t, "foo", resp.Username) + require.EqualValues(t, []byte{0x00}, resp.Cert) + require.EqualValues(t, []byte{0x01}, resp.TLSCert) } type byTimeAndIndex []apievents.AuditEvent @@ -2078,11 +2122,13 @@ func (f byTimeAndIndex) Swap(i, j int) { } // TestSearchClusterEvents makes sure web API allows querying events by type. 
-func (s *WebSuite) TestSearchClusterEvents(c *C) { +func TestSearchClusterEvents(t *testing.T) { + t.Parallel() // We need a clock that uses the current time here to work around // the fact that filelog doesn't support emitting past events. clock := clockwork.NewRealClock() + s := newWebSuite(t) sessionEvents := events.GenerateTestSession(events.SessionParams{ PrintEvents: 3, Clock: clock, @@ -2090,7 +2136,7 @@ func (s *WebSuite) TestSearchClusterEvents(c *C) { }) for _, e := range sessionEvents { - c.Assert(s.proxyClient.EmitAuditEvent(context.TODO(), e), IsNil) + require.NoError(t, s.proxyClient.EmitAuditEvent(s.ctx, e)) } sort.Sort(sort.Reverse(byTimeAndIndex(sessionEvents))) @@ -2175,49 +2221,50 @@ func (s *WebSuite) TestSearchClusterEvents(c *C) { }, } - pack := s.authPack(c, "foo") - // var sessionStartKey string + pack := s.authPack(t, "foo") for _, tc := range testCases { - result := s.searchEvents(c, pack.clt, tc.Query, []string{sessionStart.GetType(), sessionPrint.GetType(), sessionEnd.GetType()}) - c.Assert(result.Events, HasLen, len(tc.Result), Commentf(tc.Comment)) - for i, resultEvent := range result.Events { - c.Assert(resultEvent.GetType(), Equals, tc.Result[i].GetType(), Commentf(tc.Comment)) - c.Assert(resultEvent.GetID(), Equals, tc.Result[i].GetID(), Commentf(tc.Comment)) - } + tc := tc + t.Run(tc.Comment, func(t *testing.T) { + t.Parallel() + response, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "events", "search"), tc.Query) + require.NoError(t, err) + var result eventsListGetResponse + require.NoError(t, json.Unmarshal(response.Bytes(), &result)) - // Session prints do not have ID's, only sessionStart and sessionEnd. - // When retrieving events for sessionStart and sessionEnd, sessionStart is returned first. 
- if tc.TestStartKey { - c.Assert(result.StartKey, Equals, tc.StartKeyValue, Commentf(tc.Comment)) - } - } -} + require.Len(t, result.Events, len(tc.Result)) + for i, resultEvent := range result.Events { + require.Equal(t, tc.Result[i].GetType(), resultEvent.GetType()) + require.Equal(t, tc.Result[i].GetID(), resultEvent.GetID()) + } -func (s *WebSuite) searchEvents(c *C, clt *client.WebClient, query url.Values, filter []string) eventsListGetResponse { - response, err := clt.Get(context.Background(), clt.Endpoint("webapi", "sites", s.server.ClusterName(), "events", "search"), query) - c.Assert(err, IsNil) - var out eventsListGetResponse - c.Assert(json.Unmarshal(response.Bytes(), &out), IsNil) - return out + // Session prints do not have ID's, only sessionStart and sessionEnd. + // When retrieving events for sessionStart and sessionEnd, sessionStart is returned first. + if tc.TestStartKey { + require.Equal(t, tc.StartKeyValue, result.StartKey) + } + }) + } } -func (s *WebSuite) TestGetClusterDetails(c *C) { +func TestGetClusterDetails(t *testing.T) { + t.Parallel() + s := newWebSuite(t) site, err := s.proxyTunnel.GetSite(s.server.ClusterName()) - c.Assert(err, IsNil) - c.Assert(site, NotNil) + require.NoError(t, err) + require.NotNil(t, site) cluster, err := ui.GetClusterDetails(s.ctx, site) - c.Assert(err, IsNil) - c.Assert(cluster.Name, Equals, s.server.ClusterName()) - c.Assert(cluster.ProxyVersion, Equals, teleport.Version) - c.Assert(cluster.PublicURL, Equals, fmt.Sprintf("%v:%v", s.server.ClusterName(), defaults.HTTPListenPort)) - c.Assert(cluster.Status, Equals, teleport.RemoteClusterStatusOnline) - c.Assert(cluster.LastConnected, NotNil) - c.Assert(cluster.AuthVersion, Equals, teleport.Version) + require.NoError(t, err) + require.Equal(t, s.server.ClusterName(), cluster.Name) + require.Equal(t, teleport.Version, cluster.ProxyVersion) + require.Equal(t, fmt.Sprintf("%v:%v", s.server.ClusterName(), defaults.HTTPListenPort), cluster.PublicURL) + 
require.Equal(t, teleport.RemoteClusterStatusOnline, cluster.Status) + require.NotNil(t, cluster.LastConnected) + require.Equal(t, teleport.Version, cluster.AuthVersion) nodes, err := s.proxyClient.GetNodes(s.ctx, apidefaults.Namespace) - c.Assert(err, IsNil) - c.Assert(nodes, HasLen, cluster.NodeCount) + require.NoError(t, err) + require.Len(t, nodes, cluster.NodeCount) } func TestTokenGeneration(t *testing.T) { @@ -2269,7 +2316,9 @@ func TestTokenGeneration(t *testing.T) { } for _, tc := range tt { + tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() env := newWebPack(t, 1) proxy := env.proxies[0] @@ -2985,9 +3034,11 @@ func TestCreateRegisterChallenge(t *testing.T) { } // TestCreateAppSession verifies that an existing session to the Web UI can -// be exchanged for a application specific session. -func (s *WebSuite) TestCreateAppSession(c *C) { - pack := s.authPack(c, "foo@example.com") +// be exchanged for an application specific session. +func TestCreateAppSession(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo@example.com") // Register an application called "panel". app, err := types.NewAppV3(types.Metadata{ @@ -2996,130 +3047,134 @@ func (s *WebSuite) TestCreateAppSession(c *C) { URI: "http://127.0.0.1:8080", PublicAddr: "panel.example.com", }) - c.Assert(err, IsNil) + require.NoError(t, err) server, err := types.NewAppServerV3FromApp(app, "host", uuid.New().String()) - c.Assert(err, IsNil) - _, err = s.server.Auth().UpsertApplicationServer(context.Background(), server) - c.Assert(err, IsNil) + require.NoError(t, err) + _, err = s.server.Auth().UpsertApplicationServer(s.ctx, server) + require.NoError(t, err) // Extract the session ID and bearer token for the current session. 
rawCookie := *pack.cookies[0] cookieBytes, err := hex.DecodeString(rawCookie.Value) - c.Assert(err, IsNil) + require.NoError(t, err) var sessionCookie SessionCookie err = json.Unmarshal(cookieBytes, &sessionCookie) - c.Assert(err, IsNil) + require.NoError(t, err) tests := []struct { - inComment CommentInterface + name string inCreateRequest *CreateAppSessionRequest - outError bool + outError require.ErrorAssertionFunc outFQDN string outUsername string }{ { - inComment: Commentf("Valid request: all fields."), + name: "Valid request: all fields", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "panel.example.com", ClusterName: "localhost", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Valid request: without FQDN."), + name: "Valid request: without FQDN", inCreateRequest: &CreateAppSessionRequest{ PublicAddr: "panel.example.com", ClusterName: "localhost", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Valid request: only FQDN."), + name: "Valid request: only FQDN", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Invalid request: only public address."), + name: "Invalid request: only public address", inCreateRequest: &CreateAppSessionRequest{ PublicAddr: "panel.example.com", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Invalid request: only cluster name."), + name: "Invalid request: only cluster name", inCreateRequest: &CreateAppSessionRequest{ ClusterName: "localhost", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Invalid application."), + name: "Invalid application", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: 
"panel.example.com", PublicAddr: "invalid.example.com", ClusterName: "localhost", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Invalid cluster name."), + name: "Invalid cluster name", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "panel.example.com", ClusterName: "example.com", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Malicious request: all fields."), + name: "Malicious request: all fields", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com@malicious.com", PublicAddr: "panel.example.com", ClusterName: "localhost", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Malicious request: only FQDN."), + name: "Malicious request: only FQDN", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com@malicious.com", }, - outError: true, + outError: require.Error, }, } for _, tt := range tests { - // Make a request to create an application session for "panel". - endpoint := pack.clt.Endpoint("webapi", "sessions", "app") - resp, err := pack.clt.PostJSON(context.Background(), endpoint, tt.inCreateRequest) - c.Assert(err != nil, Equals, tt.outError, tt.inComment) - if tt.outError { - continue - } + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Make a request to create an application session for "panel". + endpoint := pack.clt.Endpoint("webapi", "sessions", "app") + resp, err := pack.clt.PostJSON(s.ctx, endpoint, tt.inCreateRequest) + tt.outError(t, err) + if err != nil { + return + } - // Unmarshal the response. - var response *CreateAppSessionResponse - c.Assert(json.Unmarshal(resp.Bytes(), &response), IsNil, tt.inComment) - c.Assert(response.FQDN, Equals, tt.outFQDN, tt.inComment) + // Unmarshal the response. 
+ var response *CreateAppSessionResponse + require.NoError(t, json.Unmarshal(resp.Bytes(), &response)) + require.Equal(t, tt.outFQDN, response.FQDN) - // Verify that the application session was created. - session, err := s.server.Auth().GetAppSession(context.Background(), types.GetAppSessionRequest{ - SessionID: response.CookieValue, + // Verify that the application session was created. + sess, err := s.server.Auth().GetAppSession(s.ctx, types.GetAppSessionRequest{ + SessionID: response.CookieValue, + }) + require.NoError(t, err) + require.Equal(t, tt.outUsername, sess.GetUser()) + require.Equal(t, response.CookieValue, sess.GetName()) }) - c.Assert(err, IsNil) - c.Assert(session.GetUser(), Equals, tt.outUsername, tt.inComment) - c.Assert(session.GetName(), Equals, response.CookieValue, tt.inComment) } } @@ -3131,7 +3186,7 @@ func TestNewSessionResponseWithRenewSession(t *testing.T) { duration := time.Duration(5) * time.Minute cfg := types.DefaultClusterNetworkingConfig() cfg.SetWebIdleTimeout(duration) - env.server.Auth().SetClusterNetworkingConfig(context.Background(), cfg) + require.NoError(t, env.server.Auth().SetClusterNetworkingConfig(context.Background(), cfg)) proxy := env.proxies[0] pack := proxy.authPack(t, "foo") @@ -3169,7 +3224,7 @@ func TestWebSessionsRenewDoesNotBreakExistingTerminalSession(t *testing.T) { env.clock.Advance(auth.BearerTokenTTL - delta) // Renew the session using the 1st proxy - resp := pack1.renewSession(context.TODO(), t) + resp := pack1.renewSession(context.Background(), t) // Expire the old session and make sure it has been removed. 
// The bearer token is also removed after this point, so we have to @@ -3178,7 +3233,7 @@ func TestWebSessionsRenewDoesNotBreakExistingTerminalSession(t *testing.T) { pack2 = proxy2.authPackFromResponse(t, resp) // Verify that access via the 2nd proxy also works for the same session - pack2.validateAPI(context.TODO(), t) + pack2.validateAPI(context.Background(), t) // Check whether the terminal session is still active validateTerminalStream(t, ws) @@ -3207,12 +3262,12 @@ func TestWebSessionsRenewAllowsOldBearerTokenToLinger(t *testing.T) { // prevSessionCookie := *pack.cookies[0] prevBearerToken := pack.session.Token - resp := pack.renewSession(context.TODO(), t) + resp := pack.renewSession(context.Background(), t) newPack := proxy.authPackFromResponse(t, resp) // new session is functioning - newPack.validateAPI(context.TODO(), t) + newPack.validateAPI(context.Background(), t) sessionCookie := *newPack.cookies[0] bearerToken := newPack.session.Token @@ -3235,7 +3290,7 @@ func TestWebSessionsRenewAllowsOldBearerTokenToLinger(t *testing.T) { // now expire the old session and make sure it has been removed env.clock.Advance(delta) - _, err = proxy.client.GetWebSession(context.TODO(), types.GetWebSessionRequest{ + _, err = proxy.client.GetWebSession(context.Background(), types.GetWebSessionRequest{ User: "foo", SessionID: prevSessionID, }) @@ -3278,7 +3333,7 @@ func TestChangeUserAuthentication_recoveryCodesReturnedForCloud(t *testing.T) { // Creaet a username that is not a valid email format for recovery. teleUser, err := types.NewUser("invalid-name-for-recovery") require.NoError(t, err) - env.server.Auth().CreateUser(ctx, teleUser) + require.NoError(t, env.server.Auth().CreateUser(ctx, teleUser)) // Create a reset password token and secrets. 
resetToken, err := env.server.Auth().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ @@ -3308,7 +3363,7 @@ func TestChangeUserAuthentication_recoveryCodesReturnedForCloud(t *testing.T) { // Create a user that is valid for recovery. teleUser, err = types.NewUser("valid-username@example.com") require.NoError(t, err) - env.server.Auth().CreateUser(ctx, teleUser) + require.NoError(t, env.server.Auth().CreateUser(ctx, teleUser)) // Create a reset password token and secrets. resetToken, err = env.server.Auth().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ @@ -3416,7 +3471,7 @@ func (mock authProviderMock) GetSessionEvents(n string, s session.ID, c int, p b return []events.EventFields{}, nil } -func (s *WebSuite) makeTerminal(pack *authPack, opts ...session.ID) (*websocket.Conn, error) { +func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...session.ID) (*websocket.Conn, error) { var sessionID session.ID if len(opts) == 0 { sessionID = session.NewID() @@ -3463,7 +3518,7 @@ func (s *WebSuite) makeTerminal(pack *authPack, opts ...session.ID) (*websocket. 
return nil, trace.Wrap(err) } - resp.Body.Close() + require.NoError(t, resp.Body.Close()) return ws, nil } @@ -3489,7 +3544,7 @@ func waitForOutput(stream *terminalStream, substr string) error { } func (s *WebSuite) waitForRawEvent(ws *websocket.Conn, timeout time.Duration) error { - timeoutContext, timeoutCancel := context.WithTimeout(context.Background(), timeout) + timeoutContext, timeoutCancel := context.WithTimeout(s.ctx, timeout) defer timeoutCancel() done := make(chan error, 1) @@ -3532,7 +3587,7 @@ func (s *WebSuite) waitForRawEvent(ws *websocket.Conn, timeout time.Duration) er } func (s *WebSuite) waitForResizeEvent(ws *websocket.Conn, timeout time.Duration) error { - timeoutContext, timeoutCancel := context.WithTimeout(context.Background(), timeout) + timeoutContext, timeoutCancel := context.WithTimeout(s.ctx, timeout) defer timeoutCancel() done := make(chan error, 1) @@ -3867,6 +3922,15 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(proxyLockWatcher.Close) + proxyNodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: client, + }, + }) + require.NoError(t, err) + t.Cleanup(proxyNodeWatcher.Close) + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -3880,6 +3944,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula DirectClusters: []reversetunnel.DirectCluster{{Name: authServer.ClusterName(), Client: client}}, DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, + NodeWatcher: proxyNodeWatcher, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) }) @@ -3902,6 +3967,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(clock), 
regular.SetLockWatcher(proxyLockWatcher), + regular.SetNodeWatcher(proxyNodeWatcher), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, proxyServer.Close()) }) @@ -3995,7 +4061,7 @@ func (r *proxy) authPack(t *testing.T, user string) *authPack { err = r.auth.Auth().SetAuthPreference(ctx, ap) require.NoError(t, err) - r.createUser(context.TODO(), t, user, loginUser, pass, otpSecret) + r.createUser(context.Background(), t, user, loginUser, pass, otpSecret) // create a valid otp token validToken, err := totp.GenerateCode(otpSecret, r.clock.Now()) @@ -4142,8 +4208,8 @@ func (r *proxy) makeTerminal(t *testing.T, pack *authPack, sessionID session.ID) ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) t.Cleanup(func() { - ws.Close() - resp.Body.Close() + require.NoError(t, ws.Close()) + require.NoError(t, resp.Body.Close()) }) return ws } @@ -4175,8 +4241,8 @@ func (r *proxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID sessi ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) t.Cleanup(func() { - ws.Close() - resp.Body.Close() + require.NoError(t, ws.Close()) + require.NoError(t, resp.Body.Close()) }) return ws } diff --git a/tool/tsh/proxy_test.go b/tool/tsh/proxy_test.go index 6b12b2aa739ac..a850f1a0e907d 100644 --- a/tool/tsh/proxy_test.go +++ b/tool/tsh/proxy_test.go @@ -51,10 +51,8 @@ func TestTSHSSH(t *testing.T) { lib.SetInsecureDevMode(true) defer lib.SetInsecureDevMode(false) - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) + require.NoError(t, os.RemoveAll(profile.FullProfilePath(""))) + t.Cleanup(func() { require.NoError(t, os.RemoveAll(profile.FullProfilePath(""))) }) s := newTestSuite(t, withRootConfigFunc(func(cfg *service.Config) { @@ -142,12 +140,14 @@ func testLeafClusterSSHAccess(t *testing.T, s *suite) { }) require.NoError(t, err) - err = Run([]string{ - "ssh", - s.leaf.Config.Hostname, - "echo", "hello", - }) - 
require.NoError(t, err) + require.Eventually(t, func() bool { + err = Run([]string{ + "ssh", + s.leaf.Config.Hostname, + "echo", "hello", + }) + return err == nil + }, 5*time.Second, time.Second) identityFile := path.Join(t.TempDir(), "identity.pem") err = Run([]string{ diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 39ad78751a5b4..52744c12fe3ea 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -18,6 +18,7 @@ package main import ( "context" + "errors" "fmt" "io" "net" @@ -321,6 +322,14 @@ func (c *CLIConf) Stderr() io.Writer { return os.Stderr } +type exitCodeError struct { + code int +} + +func (e *exitCodeError) Error() string { + return fmt.Sprintf("exit code %d", e.code) +} + func main() { cmdLineOrig := os.Args[1:] var cmdLine []string @@ -336,6 +345,10 @@ func main() { cmdLine = cmdLineOrig } if err := Run(cmdLine); err != nil { + var exitError *exitCodeError + if errors.As(err, &exitError) { + os.Exit(exitError.code) + } utils.FatalError(err) } } @@ -1295,7 +1308,7 @@ func onLogout(cf *CLIConf) error { if err != nil { if trace.IsNotFound(err) { fmt.Printf("User %v already logged out from %v.\n", cf.Username, proxyHost) - os.Exit(1) + return trace.Wrap(&exitCodeError{code: 1}) } return trace.Wrap(err) } @@ -1926,15 +1939,14 @@ func onSSH(cf *CLIConf) error { fmt.Fprintf(os.Stderr, "Hint: try addressing the node by unique id (ex: tsh ssh user@node-id)\n") fmt.Fprintf(os.Stderr, "Hint: use 'tsh ls -v' to list all nodes with their unique ids\n") fmt.Fprintf(os.Stderr, "\n") - os.Exit(1) + return trace.Wrap(&exitCodeError{code: 1}) } // exit with the same exit status as the failed command: if tc.ExitStatus != 0 { fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(tc.ExitStatus) - } else { - return trace.Wrap(err) + return trace.Wrap(&exitCodeError{code: tc.ExitStatus}) } + return trace.Wrap(err) } return nil } @@ -1953,7 +1965,7 @@ func onBenchmark(cf *CLIConf) error { result, err := cnf.Benchmark(cf.Context, tc) if err != nil { 
fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(255) + return trace.Wrap(&exitCodeError{code: 255}) } fmt.Printf("\n") fmt.Printf("* Requests originated: %v\n", result.RequestsOriginated) @@ -2026,7 +2038,7 @@ func onSCP(cf *CLIConf) error { // exit with the same exit status as the failed command: if tc.ExitStatus != 0 { fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(tc.ExitStatus) + return trace.Wrap(&exitCodeError{code: tc.ExitStatus}) } return trace.Wrap(err) }