diff --git a/integration/integration_test.go b/integration/integration_test.go index 4239883610cfe..78ecc73941e63 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -2030,7 +2030,8 @@ func (s *IntSuite) TestDiscoveryRecovers(c *check.C) { continue } if p.Config.Hostname == name { - lb.RemoveBackend(*utils.MustParseAddr(p.Config.Proxy.ReverseTunnelListenAddr.Addr)) + reverseTunnelPort := utils.MustParseAddr(p.Config.Proxy.ReverseTunnelListenAddr.Addr).Port(0) + c.Assert(lb.RemoveBackend(*utils.MustParseAddr(net.JoinHostPort(Loopback, strconv.Itoa(reverseTunnelPort)))), check.IsNil) c.Assert(p.Close(), check.IsNil) c.Assert(p.Wait(), check.IsNil) return @@ -2073,12 +2074,12 @@ func (s *IntSuite) TestDiscoveryRecovers(c *check.C) { // create first numbered proxy _, c0 := addNewMainProxy(pname(0)) // check that we now have two tunnel connections - waitForProxyCount(remote, "cluster-main", 2) + c.Assert(waitForProxyCount(remote, "cluster-main", 2), check.IsNil) // check that first numbered proxy is OK. testProxyConn(&c0, false) // remove the initial proxy. - lb.RemoveBackend(mainProxyAddr) - waitForProxyCount(remote, "cluster-main", 1) + c.Assert(lb.RemoveBackend(mainProxyAddr), check.IsNil) + c.Assert(waitForProxyCount(remote, "cluster-main", 1), check.IsNil) // force bad state by iteratively removing previous proxy before // adding next proxy; this ensures that discovery protocol's list of @@ -2086,9 +2087,9 @@ func (s *IntSuite) TestDiscoveryRecovers(c *check.C) { for i := 0; i < 6; i++ { prev, next := pname(i), pname(i+1) killMainProxy(prev) - waitForProxyCount(remote, "cluster-main", 0) + c.Assert(waitForProxyCount(remote, "cluster-main", 0), check.IsNil) _, cn := addNewMainProxy(next) - waitForProxyCount(remote, "cluster-main", 1) + c.Assert(waitForProxyCount(remote, "cluster-main", 1), check.IsNil) testProxyConn(&cn, false) } @@ -2186,7 +2187,7 @@ func (s *IntSuite) TestDiscovery(c *check.C) { c.Assert(output, check.Equals, "hello world\n") // Now disconnect the main proxy and make sure it will reconnect eventually. - lb.RemoveBackend(mainProxyAddr) + c.Assert(lb.RemoveBackend(mainProxyAddr), check.IsNil) waitForActiveTunnelConnections(c, secondProxy, "cluster-remote", 1) // Requests going via main proxy should fail. @@ -2218,7 +2219,7 @@ func (s *IntSuite) TestDiscovery(c *check.C) { c.Assert(err, check.IsNil) // Wait for the remote cluster to detect the outbound connection is gone. - waitForProxyCount(remote, "cluster-main", 1) + c.Assert(waitForProxyCount(remote, "cluster-main", 1), check.IsNil) // Stop both clusters and remaining nodes. c.Assert(remote.StopAll(), check.IsNil) @@ -2337,7 +2338,7 @@ func (s *IntSuite) TestDiscoveryNode(c *check.C) { c.Assert(output, check.Equals, "hello world\n") // Remove second proxy from LB. - lb.RemoveBackend(*proxyTwoBackend) + c.Assert(lb.RemoveBackend(*proxyTwoBackend), check.IsNil) waitForActiveTunnelConnections(c, main.Tunnel, Site, 1) // Requests going via main proxy will succeed. Requests going via second @@ -2391,17 +2392,17 @@ func waitForActiveTunnelConnections(c *check.C, tunnel reversetunnel.Server, clu // reach some value. 
func waitForProxyCount(t *TeleInstance, clusterName string, count int) error { var counts map[string]int - - for i := 0; i < 20; i++ { + start := time.Now() + for time.Since(start) < 17*time.Second { counts = t.Pool.Counts() if counts[clusterName] == count { return nil } - time.Sleep(250 * time.Millisecond) + time.Sleep(500 * time.Millisecond) } - return trace.BadParameter("proxy count on %v: %v", clusterName, counts[clusterName]) + return trace.BadParameter("proxy count on %v: %v (wanted %v)", clusterName, counts[clusterName], count) } // waitForNodeCount waits for a certain number of nodes to show up in the remote site. @@ -3031,7 +3032,7 @@ func (s *IntSuite) TestRotateSuccess(c *check.C) { c.Assert(err, check.IsNil) // client works as is before servers have been rotated - err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") c.Assert(err, check.IsNil) l.Infof("Service reloaded. Setting rotation state to %v", services.RotationPhaseUpdateServers) @@ -3059,7 +3060,7 @@ func (s *IntSuite) TestRotateSuccess(c *check.C) { c.Assert(err, check.IsNil) // new client works - err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") c.Assert(err, check.IsNil) l.Infof("Service reloaded. Setting rotation state to %v.", services.RotationPhaseStandby) @@ -3080,7 +3081,7 @@ func (s *IntSuite) TestRotateSuccess(c *check.C) { c.Assert(err, check.IsNil) // new client still works - err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") c.Assert(err, check.IsNil) l.Infof("Service reloaded. Rotation has completed. Shuttting down service.") @@ -3175,7 +3176,7 @@ func (s *IntSuite) TestRotateRollback(c *check.C) { c.Assert(err, check.IsNil) // client works as is before servers have been rotated - err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") c.Assert(err, check.IsNil) l.Infof("Service reloaded. Setting rotation state to %v", services.RotationPhaseUpdateServers) @@ -3205,7 +3206,7 @@ func (s *IntSuite) TestRotateRollback(c *check.C) { c.Assert(err, check.IsNil) // old client works - err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") c.Assert(err, check.IsNil) l.Infof("Service reloaded. Rotation has completed. Shuttting down service.") @@ -3323,7 +3324,7 @@ func (s *IntSuite) TestRotateTrustedClusters(c *check.C) { clt, err := main.NewClientWithCreds(cfg, *initialCreds) c.Assert(err, check.IsNil) - err = runAndMatch(clt, 6, []string{"echo", "hello world"}, ".*hello world.*") + err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") c.Assert(err, check.IsNil) l.Infof("Setting rotation state to %v", services.RotationPhaseInit) @@ -3375,7 +3376,7 @@ func (s *IntSuite) TestRotateTrustedClusters(c *check.C) { c.Assert(err, check.IsNil) // old client should work as is - err = runAndMatch(clt, 6, []string{"echo", "hello world"}, ".*hello world.*") + err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") c.Assert(err, check.IsNil) l.Infof("Service reloaded. 
Setting rotation state to %v", services.RotationPhaseUpdateServers) @@ -3402,7 +3403,7 @@ func (s *IntSuite) TestRotateTrustedClusters(c *check.C) { c.Assert(err, check.IsNil) // new client works - err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") c.Assert(err, check.IsNil) l.Infof("Service reloaded. Setting rotation state to %v.", services.RotationPhaseStandby) @@ -3425,7 +3426,7 @@ func (s *IntSuite) TestRotateTrustedClusters(c *check.C) { l.Infof("Phase completed.") // new client still works - err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") c.Assert(err, check.IsNil) l.Infof("Service reloaded. Rotation has completed. Shuttting down service.") @@ -3525,6 +3526,7 @@ func runAndMatch(tc *client.TeleportClient, attempts int, command []string, patt for i := 0; i < attempts; i++ { err = tc.SSH(context.TODO(), command, false) if err != nil { + time.Sleep(500 * time.Millisecond) continue } out := output.String() @@ -3534,7 +3536,7 @@ func runAndMatch(tc *client.TeleportClient, attempts int, command []string, patt return nil } err = trace.CompareFailed("output %q did not match pattern %q", out, pattern) - time.Sleep(250 * time.Millisecond) + time.Sleep(500 * time.Millisecond) } return err } diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 8aabec266e8fe..ddc426833aab1 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -121,7 +121,7 @@ type Cache struct { accessCache services.Access dynamicAccessCache services.DynamicAccessExt presenceCache services.Presence - eventsCache services.Events + eventsFanout *services.Fanout // closedFlag is set to indicate that the services are closed closedFlag int32 @@ -275,7 +275,7 @@ func New(config Config) (*Cache, error) { accessCache: local.NewAccessService(wrapper), dynamicAccessCache: local.NewDynamicAccessService(wrapper), presenceCache: local.NewPresenceService(wrapper), - eventsCache: local.NewEventsService(config.Backend), + eventsFanout: services.NewFanout(), Entry: log.WithFields(log.Fields{ trace.Component: config.Component, }), @@ -305,7 +305,7 @@ func New(config Config) (*Cache, error) { // to handle subscribers connected to the in-memory caches // instead of reading from the backend. func (c *Cache) NewWatcher(ctx context.Context, watch services.Watch) (services.Watcher, error) { - return c.eventsCache.NewWatcher(ctx, watch) + return c.eventsFanout.NewWatcher(ctx, watch) } func (c *Cache) isClosed() bool { @@ -337,7 +337,7 @@ func (c *Cache) update() { // all watchers will be out of sync, because // cache will reload its own watcher to the backend, // so signal closure to reset the watchers - c.Backend.CloseWatchers() + c.eventsFanout.CloseWatchers() // events cache should be closed as well c.Debugf("Reloading %v.", retry) select { @@ -528,7 +528,11 @@ func (c *Cache) processEvent(event services.Event) error { c.Warningf("Skipping unsupported event %v.", event.Resource.GetKind()) return nil } - return collection.processEvent(event) + if err := collection.processEvent(event); err != nil { + return trace.Wrap(err) + } + c.eventsFanout.Emit(event) + return nil } // GetCertAuthority returns certificate authority by given id. 
Parameter loadSigningKeys diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index 097016281e10b..1089e87d3e18b 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -414,13 +414,9 @@ func (s *CacheSuite) preferRecent(c *check.C) { // service is back and the new value has propagated p.backend.SetReadError(nil) - // wait for watcher to restart - select { - case event := <-p.eventsC: - c.Assert(event.Type, check.Equals, WatcherStarted) - case <-time.After(time.Second): - c.Fatalf("timeout waiting for event") - } + // wait for watcher to restart successfully; ignoring any failed + // attempts which occurred before the backend became healthy again. + waitForEvent(c, p.eventsC, WatcherStarted, WatcherFailed) // new value is available now out, err = p.cache.GetCertAuthority(ca.GetID(), false) diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index a294a86d03acc..ff8d81b505d72 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -30,7 +30,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/reversetunnel/seek" + "github.com/gravitational/teleport/lib/reversetunnel/track" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -46,13 +46,8 @@ const ( // agentStateConnecting is when agent is connecting to the target // without particular purpose agentStateConnecting = "connecting" - // agentStateDiscovering is when agent is created with a goal - // to discover one or many proxies - agentStateDiscovering = "discovering" // agentStateConnected means that agent has connected to instance agentStateConnected = "connected" - // agentStateDiscovered means that agent has discovered the right proxy - agentStateDiscovered = "discovered" // agentStateDisconnected means that the agent has disconnected from the // proxy and this agent and be removed from the pool. agentStateDisconnected = "disconnected" @@ -60,8 +55,6 @@ const ( // AgentConfig holds configuration for agent type AgentConfig struct { - // numeric id of agent - ID uint64 // Addr is target address to dial Addr utils.NetAddr // ClusterName is the name of the cluster the tunnel is connected to. When the @@ -81,9 +74,6 @@ type AgentConfig struct { DiscoveryC chan *discoveryRequest // Username is the name of this client used to authenticate on SSH Username string - // DiscoverProxies is set when the agent is created in discovery mode - // and is set to connect to one of the target proxies from the list - DiscoverProxies []services.Server // Clock is a clock passed in tests, if not set wall clock // will be used Clock clockwork.Clock @@ -98,8 +88,11 @@ type AgentConfig struct { ReverseTunnelServer Server // LocalClusterName is the name of the cluster this agent is running in. LocalClusterName string - // SeekGroup manages gossip and exclusive claims for agents. - SeekGroup *seek.GroupHandle + // Tracker tracks the proxies known to exist behind this tunnel. + Tracker *track.Tracker + // Lease manages gossip and exclusive claims. Lease may be nil + // when used in the context of tests.
+ Lease track.Lease } // CheckAndSetDefaults checks parameters and sets default values @@ -166,17 +159,13 @@ func NewAgent(cfg AgentConfig) (*Agent, error) { ctx: ctx, cancel: cancel, authMethods: []ssh.AuthMethod{ssh.PublicKeys(cfg.Signers...)}, - } - if len(cfg.DiscoverProxies) == 0 { - a.state = agentStateConnecting - } else { - a.state = agentStateDiscovering + state: agentStateConnecting, } a.Entry = log.WithFields(log.Fields{ trace.Component: teleport.ComponentReverseTunnelAgent, trace.ComponentFields: log.Fields{ - "target": cfg.Addr.String(), - "id": cfg.ID, + "target": cfg.Addr.String(), + "leaseID": a.Lease.ID(), }, }) a.hostKeyCallback = a.checkHostSignature @@ -184,26 +173,16 @@ func NewAgent(cfg AgentConfig) (*Agent, error) { } func (a *Agent) String() string { - if len(a.DiscoverProxies) == 0 { - return fmt.Sprintf("agent(id=%d,state=%v) -> %v:%v", a.ID, a.getState(), a.ClusterName, a.Addr.String()) - } - return fmt.Sprintf("agent(id=%d,state=%v) -> %v:%v, discover %v", a.ID, a.getState(), a.ClusterName, a.Addr.String(), Proxies(a.DiscoverProxies)) + return fmt.Sprintf("agent(leaseID=%d,state=%v) -> %v:%v", a.Lease.ID(), a.getState(), a.ClusterName, a.Addr.String()) } -func (a *Agent) setStateAndPrincipals(state string, principals []string) { - a.Lock() - defer a.Unlock() - prev := a.state - a.Debugf("Changing state %v -> %v.", prev, state) - a.state = state - a.stateChange = a.Clock.Now().UTC() - a.principals = principals -} func (a *Agent) setState(state string) { a.Lock() defer a.Unlock() prev := a.state - a.Debugf("Changing state %v -> %v.", prev, state) + if prev != state { + a.Debugf("Changing state %v -> %v.", prev, state) + } a.state = state a.stateChange = a.Clock.Now().UTC() } @@ -230,27 +209,6 @@ func (a *Agent) Wait() error { return nil } -// connectedTo returns true if connected services.Server passed in. -func (a *Agent) connectedTo(proxy services.Server) bool { - principals := a.getPrincipals() - proxyID := fmt.Sprintf("%v.%v", proxy.GetName(), a.ClusterName) - if _, ok := principals[proxyID]; ok { - return true - } - return false -} - -// connectedToRightProxy returns true if it connected to a proxy in the -// discover list. -func (a *Agent) connectedToRightProxy() bool { - for _, proxy := range a.DiscoverProxies { - if a.connectedTo(proxy) { - return true - } - } - return false -} - func (a *Agent) setPrincipals(principals []string) { a.Lock() defer a.Unlock() @@ -265,16 +223,6 @@ func (a *Agent) getPrincipalsList() []string { return out } -func (a *Agent) getPrincipals() map[string]struct{} { - a.RLock() - defer a.RUnlock() - out := make(map[string]struct{}, len(a.principals)) - for _, p := range a.principals { - out[p] = struct{}{} - } - return out -} - func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.PublicKey) error { cert, ok := key.(*ssh.Certificate) if !ok { @@ -379,12 +327,9 @@ func (a *Agent) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh. // determines disconnects. func (a *Agent) run() { defer a.setState(agentStateDisconnected) + defer a.Lease.Release() - if len(a.DiscoverProxies) != 0 { - a.setStateAndPrincipals(agentStateDiscovering, nil) - } else { - a.setStateAndPrincipals(agentStateConnecting, nil) - } + a.setState(agentStateConnecting) // Try and connect to remote cluster. conn, err := a.connect() @@ -396,23 +341,12 @@ func (a *Agent) run() { // Successfully connected to remote cluster. 
a.Infof("Connected to %s", conn.RemoteAddr()) - if len(a.DiscoverProxies) != 0 { - // If not connected to a proxy in the discover list (which means we - // connected to a proxy we already have a connection to), try again. - if !a.connectedToRightProxy() { - a.Debugf("Missed, connected to %v instead of %v.", a.getPrincipalsList(), Proxies(a.DiscoverProxies)) - return - } - a.Debugf("Agent discovered proxy: %v.", a.getPrincipalsList()) - a.setState(agentStateDiscovered) - } else { - a.Debugf("Agent connected to proxy: %v.", a.getPrincipalsList()) - a.setState(agentStateConnected) - } // wrap up remaining business logic in closure for easy // conditional execution. doWork := func() { + a.Debugf("Agent connected to proxy: %v.", a.getPrincipalsList()) + a.setState(agentStateConnected) // Notify waiters that the agent has connected. if a.EventsC != nil { select { @@ -434,10 +368,10 @@ func (a *Agent) run() { return } } - // if a SeekGroup was provided, then the agent shouldn't continue unless + // if Tracker was provided, then the agent shouldn't continue unless // no other agents hold a claim. - if a.SeekGroup != nil { - if !a.SeekGroup.WithProxy(doWork, a.getPrincipalsList()...) { + if a.Tracker != nil { + if !a.Tracker.WithProxy(doWork, a.Lease, a.getPrincipalsList()...) { a.Debugf("Proxy already held by other agent: %v, releasing.", a.getPrincipalsList()) } } else { @@ -545,7 +479,6 @@ func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { var req *ssh.Request select { case <-a.ctx.Done(): - a.Infof("Closed, returning.") return case req = <-reqC: if req == nil { @@ -562,24 +495,14 @@ func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { select { case a.DiscoveryC <- r: case <-a.ctx.Done(): - a.Infof("Closed, returning.") return default: } } - if a.SeekGroup != nil { - // Broadcast proxies to the seek group. - Gossip: - for i, p := range r.Proxies { - select { - case a.SeekGroup.Gossip() <- p.GetName(): - case <-a.ctx.Done(): - a.Infof("Closed, returning.") - return - default: - a.Warnf("Backlog in gossip channel, skipping %d proxies.", len(r.Proxies)-i) - break Gossip - } + if a.Tracker != nil { + // Notify tracker of all known proxies. 
+ for _, p := range r.Proxies { + a.Tracker.TrackExpected(a.Lease, p.GetName()) } } } diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index d531eb218dbe1..7ba3ff43cc778 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -26,7 +26,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/reversetunnel/seek" + "github.com/gravitational/teleport/lib/reversetunnel/track" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -50,14 +50,15 @@ type ServerHandler interface { type AgentPool struct { sync.Mutex *log.Entry - cfg AgentPoolConfig - agents map[agentKey][]*Agent - seekPool *seek.Pool - ctx context.Context - cancel context.CancelFunc + cfg AgentPoolConfig + agents map[agentKey][]*Agent + proxyTracker *track.Tracker + ctx context.Context + cancel context.CancelFunc // lastReport is the last time the agent has reported the stats - lastReport time.Time - lastAgentID uint64 + lastReport time.Time + // spawnLimiter limits agent spawn rate + spawnLimiter utils.Retry } // AgentPoolConfig holds configuration parameters for the agent pool @@ -91,8 +92,6 @@ type AgentPoolConfig struct { ReverseTunnelServer Server // ProxyAddr if set, points to the address of the ssh proxy ProxyAddr string - // Seek configures the proxy-seeking algorithm - Seek seek.Config } // CheckAndSetDefaults checks and sets defaults @@ -115,9 +114,6 @@ func (cfg *AgentPoolConfig) CheckAndSetDefaults() error { if cfg.Clock == nil { cfg.Clock = clockwork.NewRealClock() } - if err := cfg.Seek.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } return nil } @@ -126,18 +122,23 @@ func NewAgentPool(cfg AgentPoolConfig) (*AgentPool, error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } - ctx, cancel := context.WithCancel(cfg.Context) - seekPool, err := seek.NewPool(ctx, cfg.Seek) + retry, err := utils.NewLinear(utils.LinearConfig{ + Step: time.Second, + Max: time.Second * 8, + Jitter: utils.NewJitter(), + AutoReset: 4, + }) if err != nil { - cancel() return nil, trace.Wrap(err) } + ctx, cancel := context.WithCancel(cfg.Context) pool := &AgentPool{ - agents: make(map[agentKey][]*Agent), - seekPool: seekPool, - cfg: cfg, - ctx: ctx, - cancel: cancel, + agents: make(map[agentKey][]*Agent), + proxyTracker: track.New(ctx, track.Config{}), + cfg: cfg, + ctx: ctx, + cancel: cancel, + spawnLimiter: retry, } pool.Entry = log.WithFields(log.Fields{ trace.Component: teleport.ComponentReverseTunnelAgent, @@ -150,6 +151,7 @@ func NewAgentPool(cfg AgentPoolConfig) (*AgentPool, error) { // Start starts the agent pool func (m *AgentPool) Start() error { + m.Debugf("Starting agent pool %s.%s...", m.cfg.HostUUID, m.cfg.Cluster) go m.pollAndSyncAgents() go m.processSeekEvents() return nil @@ -170,19 +172,27 @@ func (m *AgentPool) Wait() error { } func (m *AgentPool) processSeekEvents() { + limiter := m.spawnLimiter.Clone() for { select { case <-m.ctx.Done(): m.Debugf("Halting seek event processing (pool closing)") return - case key := <-m.seekPool.Seek(): - m.Debugf("Seeking: %+v.", key) + case lease := <-m.proxyTracker.Acquire(): + m.Debugf("Seeking: %+v.", lease.Key()) m.withLock(func() { - if err := m.addAgent(seekToAgentKey(key), nil); err != nil { + if err := m.addAgent(lease); err != nil { m.WithError(err).Errorf("Failed to add agent.") } }) } + select { + case 
<-m.ctx.Done(): + m.Debugf("Halting seek event processing (pool closing)") + return + case <-limiter.After(): + limiter.Inc() + } } } @@ -266,20 +276,17 @@ func (m *AgentPool) pollAndSyncAgents() { } } -func (m *AgentPool) addAgent(key agentKey, discoverProxies []services.Server) error { +func (m *AgentPool) addAgent(lease track.Lease) error { // If the component connecting is a proxy, get the cluster name from the // clusterName (where it is the name of the remote cluster). If it's a node, get // the cluster name from the agent pool configuration itself (where it is // the name of the local cluster). + key := keyFromLease(lease) clusterName := key.clusterName if key.tunnelType == string(services.NodeTunnel) { clusterName = m.cfg.Cluster } - seekGroup := m.seekPool.Group(agentToSeekKey(key)) - m.lastAgentID++ - agentID := m.lastAgentID agent, err := NewAgent(AgentConfig{ - ID: agentID, Addr: key.addr, ClusterName: clusterName, Username: m.cfg.HostUUID, @@ -287,14 +294,16 @@ func (m *AgentPool) addAgent(key agentKey, discoverProxies []services.Server) er Client: m.cfg.Client, AccessPoint: m.cfg.AccessPoint, Context: m.ctx, - DiscoverProxies: discoverProxies, KubeDialAddr: m.cfg.KubeDialAddr, Server: m.cfg.Server, ReverseTunnelServer: m.cfg.ReverseTunnelServer, LocalClusterName: m.cfg.Cluster, - SeekGroup: &seekGroup, + Tracker: m.proxyTracker, + Lease: lease, }) if err != nil { + // ensure that lease has been released; OK to call multiple times. + lease.Release() return trace.Wrap(err) } m.Debugf("Adding %v.", agent) @@ -311,13 +320,18 @@ func (m *AgentPool) addAgent(key agentKey, discoverProxies []services.Server) er // connected to. Used in tests to determine if a proxy has been found and/or // removed. func (m *AgentPool) Counts() map[string]int { + m.Lock() + defer m.Unlock() out := make(map[string]int) - - m.withLock(func() { - for key, agents := range m.agents { - out[key.clusterName] += len(agents) + for key, agents := range m.agents { + count := 0 + for _, agent := range agents { + if agent.getState() == agentStateConnected { + count++ + } } - }) + out[key.clusterName] += count + } return out } @@ -355,17 +369,9 @@ func (m *AgentPool) reportStats() { } for key, agents := range m.agents { - tunnelID := key.clusterName - if m.cfg.Component == teleport.ComponentNode { - tunnelID = m.cfg.HostUUID - } - m.Debugf("Outbound tunnel for %v connected to %v proxies.", tunnelID, len(agents)) - countPerState := map[string]int{ agentStateConnecting: 0, - agentStateDiscovering: 0, agentStateConnected: 0, - agentStateDiscovered: 0, agentStateDisconnected: 0, } for _, a := range agents { @@ -381,6 +387,8 @@ func (m *AgentPool) reportStats() { } if logReport { m.WithFields(log.Fields{"target": key.clusterName, "stats": countPerState}).Info("Outbound tunnel stats.") + } else { + m.WithFields(log.Fields{"target": key.clusterName, "stats": countPerState}).Debug("Outbound tunnel stats.") } } } @@ -398,14 +406,12 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { // remove agents from deleted reverse tunnels for _, key := range agentsToRemove { - m.seekPool.Stop(agentToSeekKey(key)) + m.proxyTracker.Stop(agentToTrackingKey(key)) m.closeAgents(&key) } // add agents from added reverse tunnels for _, key := range agentsToAdd { - if err := m.addAgent(key, nil); err != nil { - return trace.Wrap(err) - } + m.proxyTracker.Start(agentToTrackingKey(key)) } // Remove disconnected agents from the list of agents. 
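The agentpool.go hunks above replace the seek.Pool plumbing with track.Tracker plus workpool leases: syncAgents now only starts/stops tracking per reverse tunnel, and processSeekEvents spawns one agent per acquired lease. The following is a minimal, illustrative sketch of that lifecycle (not part of this diff); runTunnel and the hard-coded principal are hypothetical stand-ins for what Agent.run does, while everything else uses only the API added in lib/reversetunnel/track below.

package main

import (
	"context"

	"github.com/gravitational/teleport/lib/reversetunnel/track"
)

// runPool sketches how AgentPool consumes the Tracker/Lease API.
func runPool(ctx context.Context) {
	tracker := track.New(ctx, track.Config{}) // defaults: 3m proxy expiry, 30s tick rate
	defer tracker.StopAll()

	key := track.Key{Cluster: "example-cluster"}
	tracker.Start(key) // begin tracking; the workpool target for this key starts at 1

	for {
		select {
		case lease := <-tracker.Acquire(): // pool wants one more agent for lease.Key()
			go func(lease track.Lease) {
				defer lease.Release() // released when the agent exits, as Agent.run defers
				// In the real agent, principals are learned from the SSH handshake.
				principals := []string{"proxy-uuid.example-cluster"}
				// WithProxy runs the closure only if no other agent has claimed this proxy.
				tracker.WithProxy(func() {
					// Serve the tunnel; gossip about peer proxies is reported via
					// tracker.TrackExpected(lease, names...), which raises the target count.
					runTunnel(ctx)
				}, lease, principals...)
			}(lease)
		case <-ctx.Done():
			return
		}
	}
}

func runTunnel(ctx context.Context) { <-ctx.Done() } // hypothetical placeholder

Compared to the removed seek package, there is no per-proxy backoff state; spawn pacing is instead handled by the pool-level spawnLimiter (utils.NewLinear) shown in processSeekEvents above.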
@@ -494,8 +500,8 @@ type agentKey struct { addr utils.NetAddr } -// seekToAgentKey converts between key types -func seekToAgentKey(key seek.Key) agentKey { +func keyFromLease(lease track.Lease) agentKey { + key := lease.Key().(track.Key) return agentKey{ clusterName: key.Cluster, tunnelType: key.Type, @@ -503,9 +509,8 @@ func seekToAgentKey(key seek.Key) agentKey { } } -// agentToSeekKey converts between key types -func agentToSeekKey(key agentKey) seek.Key { - return seek.Key{ +func agentToTrackingKey(key agentKey) track.Key { + return track.Key{ Cluster: key.clusterName, Type: key.tunnelType, Addr: key.addr, diff --git a/lib/reversetunnel/conn.go b/lib/reversetunnel/conn.go index af11d5a2c76a9..9a04fdf44015e 100644 --- a/lib/reversetunnel/conn.go +++ b/lib/reversetunnel/conn.go @@ -164,7 +164,7 @@ func (c *remoteConn) markInvalid(err error) { atomic.StoreInt32(&c.invalid, 1) c.lastError = err - c.log.Errorf("Disconnecting connection to %v %v: %v.", c.clusterName, c.conn.RemoteAddr(), err) + c.log.Debugf("Disconnecting connection to %v %v: %v.", c.clusterName, c.conn.RemoteAddr(), err) } func (c *remoteConn) isInvalid() bool { diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index ce14dc5d2d24c..8f5907ebd7a75 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -318,7 +318,7 @@ func (s *localSite) fanOutProxies(proxies []services.Server) { // the connection as invalid. func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) { defer func() { - s.log.Infof("Cluster connection closed.") + s.log.Debugf("Cluster connection closed.") rconn.Close() }() @@ -341,7 +341,7 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch } case req := <-reqC: if req == nil { - s.log.Infof("Cluster agent disconnected.") + s.log.Debugf("Cluster agent disconnected.") rconn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected")) return } diff --git a/lib/reversetunnel/seek/doc.go b/lib/reversetunnel/seek/doc.go deleted file mode 100644 index 7795cce32ab82..0000000000000 --- a/lib/reversetunnel/seek/doc.go +++ /dev/null @@ -1,75 +0,0 @@ -/* -Copyright 2019 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Proxy-Seeking alogirthm -// -// Premise: An unknown number of proxies exist behind a "fair" -// load-balancer. Proxies will share their peerset via gossip, but -// this gossip is asynchronous and may suffer from ordering/timing -// issues. Furthermore, rotations may cause a complete and permanent -// loss of contact with all known proxies. -// -// Goals: Ensure that we have one agent managing a connection to each -// available proxy. Minimize unnecessary discovery attempts. Recover -// from bad state (e.g. due to full rotations) in a timely manner. -// Mitigate resource drain due to failing or unreachable proxies. -// -// -// Each known proxy has an associated entry which stores -// its seek state (seeking | claimed | backoff). 
-// -// When an agent discovers (connects to) a proxy, it attempts to -// acquire an exclusive claim to that proxy. If sucessful, the agent -// takes responsibility for the proxy, releasing its claim when the -// connection terminates (regardless of reason). If another agent -// has already claimed the proxy, the connection is dropped. -// -// Unclaimed entries are subject to expiry. Expiration timers are -// refreshed by gossip messages. -// -// If a claim is released within a very short interval after being -// acquired, termination is said to be premature. Premature -// termination triggers a backoff phase which pauses discovery -// attempts for the proxy. The length of the backoff phase is -// determined by an incrementing multiplier. If backoff is entered -// too often to allow the counter to reset, the backoff phase will -// grow beyond the expiry limit and the associated entry will be -// removed. -// -// +---------+ -// | | acquire -// | START +------------------------------------------------+ -// | | | -// +----+----+ v -// | +-----+-----+ -// | refresh release (ok) | | -// +-----+--------+ +----------------------------+ Claimed | -// ^ | | | | -// | v v +--+-----+--+ -// | +---+---+---+ ^ | -// | | | acquire | | -// +----+ Seeking +---------------------------+ | -// | | | -// +--------+ +---+---+---+ | -// | | | ^ +-----------+ | -// | STOP | | | done | | release (err) | -// | | | +--------+ Backoff +<---------------+ -// +---+----+ | | | -// ^ | +-----+-----+ -// | v expire | -// +---------------+------------------+ -// -package seek diff --git a/lib/reversetunnel/seek/seek.go b/lib/reversetunnel/seek/seek.go deleted file mode 100644 index 941f835b79ebb..0000000000000 --- a/lib/reversetunnel/seek/seek.go +++ /dev/null @@ -1,453 +0,0 @@ -/* -Copyright 2019 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package seek - -import ( - "context" - "strings" - "sync" - "time" - - "github.com/gravitational/teleport/lib/utils" - - "github.com/gravitational/trace" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" -) - -var connectedGauge = prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "reversetunnel_connected_proxies", - Help: "Number of known proxies being sought.", - }, -) - -func init() { - prometheus.MustRegister(connectedGauge) -} - -// Key uniquely identifies a seek group -type Key struct { - Cluster string - Type string - Addr utils.NetAddr -} - -// Config describes the various parameters related to a seek operation -type Config struct { - // TickRate defines the maximum amount of time between expiry & seek checks. - // Shorter tick rates reduce discovery delay. Longer tick rates reduce resource - // consumption (default: ~4s). - TickRate time.Duration - // EntryExpiry defines how long a seeker entry should be held without successfully - // establishing a healthy connection. This value should be reasonably long - // (default: 3m). - EntryExpiry time.Duration - // BackoffInterval defines the basline backoff amount observed by seekers. 
This value - // should be reasonably short (default: 256ms) - BackoffInterval time.Duration - // BackoffThreshold defines the minimum amount of time that a connection is expected to last - // if the conencted peer is generally healthy. Connections which fail before BackoffThreshold - // cause the seekstate to enter backoff (default: 30s) - BackoffThreshold time.Duration -} - -func (s *Config) Check() error { - if s.TickRate < time.Millisecond { - return trace.BadParameter("sub-millisecond tick-rate is not allowed") - } - if s.EntryExpiry <= 2*s.TickRate { - return trace.BadParameter("entry-expiry must be greater than 2x tick-rate") - } - if s.EntryExpiry <= s.BackoffInterval { - return trace.BadParameter("entry-expiry must be greater than backoff-interval") - } - if s.EntryExpiry <= s.BackoffThreshold { - return trace.BadParameter("entry-expiry must be greater than backoff-threshold") - } - return nil -} - -const ( - defaultTickRate = time.Millisecond * 4096 - defaultEntryExpriy = time.Minute * 3 - defaultBackoffInterval = time.Millisecond * 256 - defaultBackoffThreshold = time.Second * 30 -) - -func (s *Config) CheckAndSetDefaults() error { - const granularity = time.Millisecond - if s.TickRate < granularity { - s.TickRate = defaultTickRate - } - if s.EntryExpiry < granularity { - s.EntryExpiry = defaultEntryExpriy - } - if s.BackoffInterval < granularity { - s.BackoffInterval = defaultBackoffInterval - } - if s.BackoffThreshold < granularity { - s.BackoffThreshold = defaultBackoffThreshold - } - return s.Check() -} - -// Pool manages a collection of group-level seek operations. -type Pool struct { - m sync.Mutex - conf Config - groups map[Key]GroupHandle - seekC chan Key - ctx context.Context -} - -// NewPool configures a seek pool. -func NewPool(ctx context.Context, conf Config) (*Pool, error) { - if err := conf.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &Pool{ - conf: conf, - groups: make(map[Key]GroupHandle), - seekC: make(chan Key, 128), - ctx: ctx, - }, nil -} - -// Group gets a handle to the seek manager for the specified -// group. If none exists, one will be started. -func (p *Pool) Group(key Key) GroupHandle { - p.m.Lock() - defer p.m.Unlock() - if group, ok := p.groups[key]; ok { - return group - } - group := newGroupHandle(p.ctx, p.conf, p.seekC, key) - p.groups[key] = group - return group -} - -// Seek channel yields stream of keys indicating which groups -// are seeking. -func (p *Pool) Seek() <-chan Key { - return p.seekC -} - -// Stop stops one or more group-level seek operations -func (p *Pool) Stop(group Key, groups ...Key) { - p.m.Lock() - defer p.m.Unlock() - p.stopGroupHandle(group) - for _, g := range groups { - p.stopGroupHandle(g) - } -} - -// Shutdown stops all seek operations -func (p *Pool) Shutdown() { - p.m.Lock() - defer p.m.Unlock() - for g, _ := range p.groups { - p.stopGroupHandle(g) - } -} - -func (p *Pool) stopGroupHandle(key Key) { - group, ok := p.groups[key] - if !ok { - return - } - group.cancel() - delete(p.groups, key) -} - -// GroupHandle is a handle to an ongoing seek process. Each seek process -// manages a group of related proxies. This handle allows agents to -// claim exclusive "locks" for individual proxies and to broadcast -// gossip to the process. 
-type GroupHandle struct { - inner *proxyGroup - cancel context.CancelFunc - proxyC chan<- string - seekC <-chan Key - statC <-chan Status -} - -func newGroupHandle(ctx context.Context, conf Config, seekC chan Key, id Key) GroupHandle { - ctx, cancel := context.WithCancel(ctx) - proxyC := make(chan string, 128) - statC := make(chan Status, 1) - seekers := &proxyGroup{ - id: id, - conf: conf, - states: make(map[string]seeker), - proxyC: proxyC, - seekC: seekC, - statC: statC, - } - handle := GroupHandle{ - inner: seekers, - cancel: cancel, - proxyC: proxyC, - seekC: seekC, - statC: statC, - } - go seekers.run(ctx) - return handle -} - -// WithProxy is used to wrap the connection-handling logic of an agent, -// ensuring that it is run if and only if no other agent is already -// handling this proxy. -func (s *GroupHandle) WithProxy(do func(), principals ...string) (did bool) { - if !s.inner.TryAcquireProxy(principals...) { - return false - } - defer s.inner.ReleaseProxy(principals...) - connectedGauge.Inc() - defer connectedGauge.Dec() - do() - return true -} - -// Status channel is regularly updated with the most recent status -// value. Consuming status values is optional. -func (s *GroupHandle) Status() <-chan Status { - return s.statC -} - -// Gossip channel must be informed whenever a proxy's identity -// becomes known via gossip messages. -func (s *GroupHandle) Gossip() chan<- string { - return s.proxyC -} - -// proxyGroup manages all proxy seekers for a group. -type proxyGroup struct { - sync.Mutex - id Key - conf Config - states map[string]seeker - proxyC <-chan string - seekC chan<- Key - statC chan Status -} - -// run is the "main loop" for the seek process. -func (p *proxyGroup) run(ctx context.Context) { - const logInterval int = 8 - ticker := time.NewTicker(p.conf.TickRate) - defer ticker.Stop() - // supply initial status & seek notification. - p.notifyStatus(p.Tick()) - p.notifyShouldSeek() - var ticks int - for { - select { - case <-ticker.C: - stat := p.Tick() - p.notifyStatus(stat) - if stat.ShouldSeek() { - p.notifyShouldSeek() - } - ticks++ - if ticks%logInterval == 0 { - log.Debugf("SeekStates(states=%+v,id=%s)", p.GetStates(), p.id) - } - case proxy := <-p.proxyC: - proxies := []string{proxy} - // Collect any additional proxy messages - // without blocking. - Collect: - for { - select { - case pr := <-p.proxyC: - proxies = append(proxies, pr) - default: - break Collect - } - } - count := p.RefreshProxy(proxies...) - if count > 0 { - p.notifyShouldSeek() - } - case <-ctx.Done(): - return - } - } -} - -func (p *proxyGroup) Tick() Status { - p.Lock() - defer p.Unlock() - now := time.Now() - return p.tick(now) -} - -// RefreshProxy refreshes expiration timers, returning the number of -// successful refreshes. If the returned value is greater than zero, -// then at least one entry is unexpired and in `stateSeeking`. -// Entries are lazily created for previously unknown proxies. -func (p *proxyGroup) RefreshProxy(proxies ...string) int { - p.Lock() - defer p.Unlock() - now := time.Now() - var count int - for _, proxy := range proxies { - if p.refreshProxy(now, proxy) { - count++ - } - } - return count -} - -func (p *proxyGroup) refreshProxy(t time.Time, proxy string) (ok bool) { - s := p.states[proxy] - if s.transit(t, eventRefresh, &p.conf) { - p.states[proxy] = s - return true - } - return false -} - -// notifyShouldSeek sets the seek channel. 
-func (p *proxyGroup) notifyShouldSeek() { - select { - case p.seekC <- p.id: - default: - } -} - -// notifyStatus clears and sets the status channel. -func (p *proxyGroup) notifyStatus(s Status) { - select { - case <-p.statC: - default: - } - select { - case p.statC <- s: - default: - } -} - -// AcquireProxy attempts to acquire the specified proxy. -func (p *proxyGroup) TryAcquireProxy(principals ...string) (ok bool) { - p.Lock() - defer p.Unlock() - return p.acquireProxy(time.Now(), principals...) -} - -// ReleaseProxy attempts to release the specified proxy. -func (p *proxyGroup) ReleaseProxy(principals ...string) (ok bool) { - p.Lock() - defer p.Unlock() - return p.releaseProxy(time.Now(), principals...) -} - -func (p *proxyGroup) acquireProxy(t time.Time, principals ...string) (ok bool) { - if len(principals) < 1 { - return false - } - name := p.resolveName(principals) - s := p.states[name] - if !s.transit(t, eventAcquire, &p.conf) { - return false - } - p.states[name] = s - return true -} - -func (p *proxyGroup) releaseProxy(t time.Time, principals ...string) (ok bool) { - if len(principals) < 1 { - return false - } - name := p.resolveName(principals) - s := p.states[name] - if !s.transit(t, eventRelease, &p.conf) { - return false - } - if s.state == stateSeeking { - p.notifyShouldSeek() - } - p.states[name] = s - return true -} - -func (p *proxyGroup) resolveName(principals []string) string { - // check if we're already using one of these principals - for _, name := range principals { - if _, ok := p.states[name]; ok { - return name - } - } - // default to using the first principal - name := principals[0] - // if we have a `.` suffix, remove it. - if strings.HasSuffix(name, p.id.Cluster) { - t := strings.TrimSuffix(name, p.id.Cluster) - if strings.HasSuffix(t, ".") { - name = strings.TrimSuffix(t, ".") - } - } - return name -} - -func (p *proxyGroup) GetStates() map[string]seekState { - p.Lock() - defer p.Unlock() - collector := make(map[string]seekState, len(p.states)) - for proxy, s := range p.states { - collector[proxy] = s.state - } - return collector -} - -// tick ticks all proxy seek states, returning a summary -// status. This method also serves as the mechanism by which -// expired entries are removed. -func (p *proxyGroup) tick(t time.Time) Status { - var stat Status - for proxy, s := range p.states { - // if proxy seeker is in expirable state, handle - // the expiry. expired seekers are removed, and - // the soonest future expiry is recorded in stat. - if exp, ok := s.expiry(p.conf.EntryExpiry); ok { - if t.After(exp) { - delete(p.states, proxy) - continue - } - } - // poll and record state of proxy seeker. - switch state := s.tick(t, &p.conf); state { - case stateSeeking: - stat.Seeking++ - case stateClaimed: - stat.Claimed++ - case stateBackoff: - stat.Backoff++ - default: - // this should never happen... - log.Errorf("Proxy %s in invalid state %q, removing.", proxy, state) - delete(p.states, proxy) - continue - } - // seeker.tick may have affected an internal state - // transition, so update the entry. - p.states[proxy] = s - } - return stat -} diff --git a/lib/reversetunnel/seek/state.go b/lib/reversetunnel/seek/state.go deleted file mode 100644 index 45590d421858f..0000000000000 --- a/lib/reversetunnel/seek/state.go +++ /dev/null @@ -1,232 +0,0 @@ -/* -Copyright 2019 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package seek - -import ( - "fmt" - "time" - - log "github.com/sirupsen/logrus" -) - -// seekState represents the state of a seeker. -type seekState uint - -const ( - // stateSeeking indicates that we want to connect to this proxy - stateSeeking seekState = iota - // stateClaimed indicates that an agent has successfully claimed - // responsibility for this proxy - stateClaimed - // stateBackoff indicates that this proxy was claimed but that the - // agent responsible for it lost the connection prematurely - stateBackoff -) - -func (s seekState) String() string { - switch s { - case stateSeeking: - return "seeking" - case stateClaimed: - return "claimed" - case stateBackoff: - return "backoff" - default: - return fmt.Sprintf("unknown(%d)", s) - } -} - -// seekEvent represents an asynchronous event which may change -// the state of a seeker. -type seekEvent uint - -const ( - // eventRefresh indicates that a proxy has been indirectly observed (e.g. via gossip). - eventRefresh seekEvent = iota - // eventAcquire indicates that an agent has connected to this proxy and would like to - // take responsibility for it. - eventAcquire - // eventRelease indicates that the agent responsible for this proxy has lost its - // connection to it. - eventRelease -) - -func (s seekEvent) String() string { - switch s { - case eventRefresh: - return "refresh" - case eventAcquire: - return "acquire" - case eventRelease: - return "release" - default: - return fmt.Sprintf("unknown(%d)", s) - } -} - -// seeker manages the state associated with a proxy. -type seeker struct { - state seekState - at time.Time - backOff uint64 -} - -// transit attempts a state transition. If transit returns true, then -// a state-transition did occur. -func (s *seeker) transit(t time.Time, e seekEvent, c *Config) (ok bool) { - switch s.state { - case stateSeeking: - // stateSeeking can either transition to stateClaimed, or - // be "refreshed" in order to prevent expiration. - switch e { - case eventRefresh: - if t.After(s.at) { - s.at = t - return true - } else { - return false - } - case eventAcquire: - s.state = stateClaimed - s.at = t - return true - case eventRelease: - return false - default: - log.Errorf("Invalid event: %q", e) - return false - } - case stateClaimed: - // stateClaimed can either transition into stateSeeking or - // stateBackoff depending on whether the claim failed - // immediately, or after some period of normal operation. - switch e { - case eventRefresh, eventAcquire: - return false - case eventRelease: - // If the release event comes within the backoff threshold - // then we are potentially dealing with an unhealthy proxy. - // The backoff state serves as both backpressure and an - // an escape hatch, preventing infinite retry loops. 
- if s.shouldBackoff(t, c.BackoffThreshold) { - s.state = stateBackoff - s.backOff++ - } else { - s.state = stateSeeking - s.backOff = 0 - } - s.at = t - return true - default: - log.Errorf("Invalid event: %q", e) - return false - } - case stateBackoff: - // stateBackoff effectively "becomes" stateSeeking - // once the backoff period has been observed, so we - // either reject all transitions if still within the - // backoff period, or accept both stateSeeking and - // stateClaimed. - switch e { - case eventRefresh: - if !s.backoffPassed(t, c.BackoffInterval) { - return false - } - s.state = stateSeeking - s.at = t - return true - case eventAcquire: - if !s.backoffPassed(t, c.BackoffInterval) { - return false - } - s.state = stateClaimed - s.at = t - return true - case eventRelease: - return false - default: - log.Errorf("Invalid event: %q", e) - return false - } - default: - log.Errorf("Invalid state: %q", s.state) - return false - } -} - -func (s *seeker) backoffPassed(t time.Time, interval time.Duration) bool { - end := s.at.Add(interval * time.Duration(s.backOff)) - return t.After(end) -} - -func (s *seeker) shouldBackoff(t time.Time, threshold time.Duration) bool { - cutoff := s.at.Add(threshold) - return cutoff.After(t) -} - -// expiry calculates the time at which this entry will expire, -// if it should be expired at all. -func (s *seeker) expiry(ttl time.Duration) (exp time.Time, ok bool) { - switch s.state { - case stateSeeking, stateBackoff: - // calculate normal expiry - exp = s.at.Add(ttl) - ok = true - case stateClaimed: - // wait for release - ok = false - default: - // invalid, expire entry immediately - ok = true - } - return -} - -// tick calculates the current state, possibly affecting -// a time-based transition. -func (s *seeker) tick(t time.Time, c *Config) seekState { - if s.state == stateBackoff && s.backoffPassed(t, c.BackoffInterval) { - s.state = stateSeeking - } - return s.state -} - -// Status is a summary of the status of a collection -// of proxy seek states. -type Status struct { - Seeking int - Claimed int - Backoff int -} - -// ShouldSeek checks if we should be seeking connections. -func (s *Status) ShouldSeek() bool { - // if we are seeking specific proxies, or we don't currently - // have any proxies, then we should seek proxies. - if s.Seeking > 0 || s.Claimed < 1 { - return true - } - return false -} - -// Sum returns the sum of all known proxies. -func (s *Status) Sum() int { - if s == nil { - return 0 - } - return s.Seeking + s.Claimed + s.Backoff -} diff --git a/lib/reversetunnel/track/doc.go b/lib/reversetunnel/track/doc.go new file mode 100644 index 0000000000000..32c07134a7f76 --- /dev/null +++ b/lib/reversetunnel/track/doc.go @@ -0,0 +1,21 @@ +/* +Copyright 2020 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package track provides a simple interface for tracking known proxies by +// endpoint/name and correctly handling expiration and exclusivity. 
+// The provided Tracker type wraps a workpool.Pool, updating per-key +// counts as new proxies are discovered and/or old proxies are expired. +package track diff --git a/lib/reversetunnel/track/tracker.go b/lib/reversetunnel/track/tracker.go new file mode 100644 index 0000000000000..63cad4c473a48 --- /dev/null +++ b/lib/reversetunnel/track/tracker.go @@ -0,0 +1,297 @@ +/* +Copyright 2020 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package track + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/workpool" +) + +type Lease = workpool.Lease + +// Key uniquely identifies a reversetunnel endpoint. +type Key struct { + Cluster string + Type string + Addr utils.NetAddr +} + +// Config configures basic Tracker parameters. +type Config struct { + // ProxyExpiry is the duration an entry will be held since the last + // successful connection to, or message about, a given proxy. + ProxyExpiry time.Duration + // TickRate is the rate at which expired entries are cleared from + // the cache of known proxies. + TickRate time.Duration +} + +// SetDefaults sets default values for Config. +func (c *Config) SetDefaults() { + if c.ProxyExpiry < 1 { + c.ProxyExpiry = 3 * time.Minute + } + if c.TickRate < 1 { + c.TickRate = 30 * time.Second + } +} + +// Tracker is a helper for tracking proxies located behind reverse tunnels +// and triggering agent spawning as needed. Tracker wraps a workpool.Pool +// instance and manages a cache of proxies which *may* exist. As proxies are +// discovered, or old proxies expire, the target counts are automatically updated +// for the associated key in the workpool. Agents can attempt to "claim" +// exclusivity for a given proxy, ensuring that multiple agents are not run +// against the same proxy. +type Tracker struct { + Config + mu sync.Mutex + wp *workpool.Pool + sets map[Key]*proxySet + cancel context.CancelFunc +} + +// New configures a new tracker instance. +func New(ctx context.Context, c Config) *Tracker { + ctx, cancel := context.WithCancel(ctx) + c.SetDefaults() + t := &Tracker{ + Config: c, + wp: workpool.NewPool(ctx), + sets: make(map[Key]*proxySet), + cancel: cancel, + } + go t.run(ctx) + return t +} + +func (t *Tracker) run(ctx context.Context) { + ticker := time.NewTicker(t.TickRate) + defer ticker.Stop() + for { + select { + case <-ticker.C: + t.tick() + case <-ctx.Done(): + return + } + } +} + +// Acquire grants access to the Acquire channel of the +// embedded work group. +func (p *Tracker) Acquire() <-chan Lease { + return p.wp.Acquire() +} + +// TrackExpected starts/refreshes tracking for expected proxies. Called by +// agents when gossip messages are received.
+func (p *Tracker) TrackExpected(lease Lease, proxies ...string) { + p.mu.Lock() + defer p.mu.Unlock() + key := lease.Key().(Key) + set, ok := p.sets[key] + if !ok { + return + } + t := time.Now() + for _, name := range proxies { + set.markSeen(t, name) + } + count := len(set.proxies) + if count < 1 { + count = 1 + } + p.wp.Set(key, uint64(count)) +} + +// Start starts tracking for specified key. +func (p *Tracker) Start(key Key) { + p.mu.Lock() + defer p.mu.Unlock() + p.getOrCreate(key) +} + +// Stop stops tracking for specified key. +func (p *Tracker) Stop(key Key) { + p.mu.Lock() + defer p.mu.Unlock() + if _, ok := p.sets[key]; !ok { + return + } + delete(p.sets, key) + p.wp.Set(key, 0) +} + +// StopAll permanently deactivates this tracker and cleans +// up all background goroutines. +func (p *Tracker) StopAll() { + p.cancel() +} + +func (p *Tracker) tick() { + p.mu.Lock() + defer p.mu.Unlock() + cutoff := time.Now().Add(-1 * p.ProxyExpiry) + for key, set := range p.sets { + if set.expire(cutoff) > 0 { + count := len(set.proxies) + if count < 1 { + count = 1 + } + p.wp.Set(key, uint64(count)) + } + } +} + +func (p *Tracker) getOrCreate(key Key) *proxySet { + if s, ok := p.sets[key]; ok { + return s + } + set := newProxySet(key) + p.sets[key] = set + p.wp.Set(key, 1) + return set +} + +// WithProxy runs the supplied closure if and only if +// no other work is currently being done with the proxy +// identified by principals. +func (p *Tracker) WithProxy(work func(), lease Lease, principals ...string) (didWork bool) { + key := lease.Key().(Key) + if ok := p.claim(key, principals...); !ok { + return false + } + defer p.unclaim(key, principals...) + work() + return true +} + +func (p *Tracker) claim(key Key, principals ...string) (ok bool) { + p.mu.Lock() + defer p.mu.Unlock() + set, ok := p.sets[key] + if !ok { + return false + } + return set.claim(principals...) +} + +func (p *Tracker) unclaim(key Key, principals ...string) { + p.mu.Lock() + defer p.mu.Unlock() + set, ok := p.sets[key] + if !ok { + return + } + set.unclaim(principals...) +} + +type entry struct { + lastSeen time.Time + claimed bool +} + +func newProxySet(key Key) *proxySet { + return &proxySet{ + key: key, + proxies: make(map[string]entry), + } +} + +type proxySet struct { + key Key + proxies map[string]entry +} + +func (p *proxySet) claim(principals ...string) (ok bool) { + proxy := p.resolveName(principals) + e, ok := p.proxies[proxy] + if !ok { + p.proxies[proxy] = entry{ + claimed: true, + } + return true + } + if e.claimed { + return false + } + e.claimed = true + p.proxies[proxy] = e + return true +} + +func (p *proxySet) unclaim(principals ...string) { + proxy := p.resolveName(principals) + p.proxies[proxy] = entry{ + lastSeen: time.Now(), + } +} + +func (p *proxySet) markSeen(t time.Time, proxy string) { + e, ok := p.proxies[proxy] + if !ok { + p.proxies[proxy] = entry{ + lastSeen: t, + } + return + } + if e.lastSeen.After(t) { + return + } + e.lastSeen = t + p.proxies[proxy] = e +} + +func (p *proxySet) expire(cutoff time.Time) (removed int) { + for name, entry := range p.proxies { + if entry.claimed { + continue + } + if entry.lastSeen.Before(cutoff) { + delete(p.proxies, name) + removed++ + } + } + return +} + +// resolveName tries to extract the UUID of the proxy as that's the +// only unique identifier in the list of principals. +func (p *proxySet) resolveName(principals []string) string { + // check if we're already using one of these principals. 
+ for _, name := range principals { + if _, ok := p.proxies[name]; ok { + return name + } + } + // default to using the first principal + name := principals[0] + // if we have a `.` suffix, remove it. + if strings.HasSuffix(name, p.key.Cluster) { + t := strings.TrimSuffix(name, p.key.Cluster) + if strings.HasSuffix(t, ".") { + name = strings.TrimSuffix(t, ".") + } + } + return name +} diff --git a/lib/reversetunnel/seek/seek_test.go b/lib/reversetunnel/track/tracker_test.go similarity index 59% rename from lib/reversetunnel/seek/seek_test.go rename to lib/reversetunnel/track/tracker_test.go index 88622d89cfa3c..93eb600bf0044 100644 --- a/lib/reversetunnel/seek/seek_test.go +++ b/lib/reversetunnel/track/tracker_test.go @@ -1,5 +1,5 @@ /* -Copyright 2019 Gravitational, Inc. +Copyright 2020 Gravitational, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package seek +package track import ( "context" @@ -77,26 +77,27 @@ func (s *simpleTestProxies) GetRandProxy() (p testProxy, ok bool) { return s.proxies[i], true } -func (s *simpleTestProxies) Discover(handle GroupHandle) (ok bool) { +func (s *simpleTestProxies) Discover(tracker *Tracker, lease Lease) (ok bool) { + defer lease.Release() proxy, ok := s.GetRandProxy() if !ok { panic("discovery called with no available proxies") } timeout := time.After(proxy.life) - ok = handle.WithProxy(func() { - ticker := time.NewTicker(jitter(time.Millisecond * 256)) + ok = tracker.WithProxy(func() { + ticker := time.NewTicker(jitter(time.Millisecond * 100)) Loop: for { select { case <-ticker.C: if p, ok := s.GetRandProxy(); ok { - handle.Gossip() <- p.principals[0] + tracker.TrackExpected(lease, p.principals[0]) } case <-timeout: break Loop } } - }, proxy.principals...) + }, lease, proxy.principals...) 
return } @@ -136,40 +137,31 @@ type StateSuite struct{} var _ = check.Suite(&StateSuite{}) -func (s *StateSuite) TestBasicHealthy(c *check.C) { - s.runBasicProxyTest(c, time.Second*16, false) +func (s *StateSuite) TestBasic(c *check.C) { + s.runBasicProxyTest(c, time.Second*16) } -func (s *StateSuite) TestBasicUnhealthy(c *check.C) { - s.runBasicProxyTest(c, time.Second*16, true) -} - -func (s *StateSuite) runBasicProxyTest(c *check.C, timeout time.Duration, allowUnhealthy bool) { +func (s *StateSuite) runBasicProxyTest(c *check.C, timeout time.Duration) { const proxyCount = 16 timeoutC := time.After(timeout) - conf := newConfigOK(jitter(time.Millisecond * 512)) - pool, err := NewPool(context.TODO(), conf) - c.Assert(err, check.IsNil) - defer pool.Shutdown() - handle := pool.Group(Key{Cluster: "test-cluster"}) + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() + tracker := New(context.TODO(), Config{}) + defer tracker.StopAll() + key := Key{Cluster: "test-cluster"} + tracker.Start(key) min, max := time.Duration(0), timeout - if !allowUnhealthy { - min = conf.BackoffThreshold - } - if max <= min { - min = timeout - max = timeout + time.Millisecond - } var proxies simpleTestProxies proxies.AddRandProxies(proxyCount, min, max) Discover: for { select { - case <-pool.Seek(): - go proxies.Discover(handle) - case status := <-handle.Status(): - c.Logf("Status: %+v", status) - if status.Sum() == proxyCount { + case lease := <-tracker.Acquire(): + go proxies.Discover(tracker, lease) + case <-ticker.C: + counts := tracker.wp.Get(key) + c.Logf("Counts: %+v", counts) + if counts.Active == proxyCount { break Discover } case <-timeoutC: @@ -187,23 +179,25 @@ func (s *StateSuite) TestFullRotation(c *check.C) { maxConnB = time.Second * 25 timeout = time.Second * 30 ) + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() var proxies simpleTestProxies proxies.AddRandProxies(proxyCount, minConnA, maxConnA) - conf := newConfigOK(jitter(time.Millisecond * 128)) - pool, err := NewPool(context.TODO(), conf) - c.Assert(err, check.IsNil) - defer pool.Shutdown() - handle := pool.Group(Key{Cluster: "test-cluster"}) + tracker := New(context.TODO(), Config{}) + defer tracker.StopAll() + key := Key{Cluster: "test-cluster"} + tracker.Start(key) timeoutC := time.After(timeout) Loop0: for { select { - case key := <-pool.Seek(): - c.Assert(key, check.DeepEquals, Key{Cluster: "test-cluster"}) - go proxies.Discover(handle) - case status := <-handle.Status(): - c.Logf("Status0: %+v", status) - if status.Sum() == proxyCount { + case lease := <-tracker.Acquire(): + c.Assert(lease.Key().(Key), check.DeepEquals, key) + go proxies.Discover(tracker, lease) + case <-ticker.C: + counts := tracker.wp.Get(key) + c.Logf("Counts0: %+v", counts) + if counts.Active == proxyCount { break Loop0 } case <-timeoutC: @@ -214,9 +208,10 @@ Loop0: Loop1: for { select { - case status := <-handle.Status(): - c.Logf("Status1: %+v", status) - if status.Claimed < 1 { + case <-ticker.C: + counts := tracker.wp.Get(key) + c.Logf("Counts1: %+v", counts) + if counts.Active < 1 { break Loop1 } case <-timeoutC: @@ -227,11 +222,12 @@ Loop1: Loop2: for { select { - case <-pool.Seek(): - go proxies.Discover(handle) - case status := <-handle.Status(): - c.Logf("Status2: %+v", status) - if status.Claimed >= proxyCount { + case lease := <-tracker.Acquire(): + go proxies.Discover(tracker, lease) + case <-ticker.C: + counts := tracker.wp.Get(key) + c.Logf("Counts2: %+v", counts) + if counts.Active >= proxyCount { break Loop2 } 
case <-timeoutC: @@ -242,29 +238,31 @@ Loop2: // TestUUIDHandling verifies that host UUIDs are correctly extracted // from the expected teleport principal format, and that gossip messages -// consisting only of uuid don't create duplicate seek entries. +// consisting only of uuid don't create duplicate entries. func (s *StateSuite) TestUUIDHandling(c *check.C) { ctx, cancel := context.WithTimeout(context.TODO(), time.Second*6) defer cancel() - conf := newConfigOK(time.Millisecond * 512) - pool, err := NewPool(ctx, conf) - c.Assert(err, check.IsNil) + ticker := time.NewTicker(time.Millisecond * 10) + defer ticker.Stop() + tracker := New(ctx, Config{}) + defer tracker.StopAll() key := Key{Cluster: "test-cluster"} - handle := pool.Group(key) - + tracker.Start(key) + lease := <-tracker.Acquire() // claim a proxy using principal of the form . - go handle.WithProxy(func() { + go tracker.WithProxy(func() { c.Logf("Successfully claimed proxy") <-ctx.Done() - }, "my-proxy.test-cluster") + }, lease, "my-proxy.test-cluster") // Wait for proxy to be claimed Wait: for { select { - case status := <-handle.Status(): - c.Logf("Status: %+v", status) - if !status.ShouldSeek() { + case <-ticker.C: + counts := tracker.wp.Get(key) + c.Logf("Counts: %+v", counts) + if counts.Active == counts.Target { break Wait } case <-ctx.Done(): @@ -273,7 +271,7 @@ Wait: } // Send a gossip message containing host UUID only - handle.Gossip() <- "my-proxy" + tracker.TrackExpected(lease, "my-proxy") c.Logf("Sent uuid-only gossip message; watching status...") // Let pool go through a few ticks, monitoring status to ensure that @@ -282,53 +280,14 @@ Wait: // message). for i := 0; i < 3; i++ { select { - case status := <-handle.Status(): - c.Logf("Status: %+v", status) - if status.ShouldSeek() { - c.Errorf("pool incorrectly entered seek mode") + case <-ticker.C: + counts := tracker.wp.Get(key) + c.Logf("Counts: %+v", counts) + if counts.Active != counts.Target { + c.Errorf("incorrectly entered seek mode") } case <-ctx.Done(): c.Errorf("timeout") } } } - -func (s *StateSuite) BenchmarkBasicSeek(c *check.C) { - const proxyCount = 32 - var proxies simpleTestProxies - proxies.AddRandProxies(proxyCount, time.Second*16, time.Second*32) - conf := newConfigOK(time.Millisecond * 512) - pool, err := NewPool(context.TODO(), conf) - c.Assert(err, check.IsNil) - defer pool.Shutdown() - for i := 0; i < c.N; i++ { - key := Key{Cluster: fmt.Sprintf("cluster-%d", i)} - handle := pool.Group(key) - Discover: - for { - select { - case <-pool.Seek(): - go proxies.Discover(handle) - case status := <-handle.Status(): - c.Logf("Status: %+v", status) - if status.Sum() == proxyCount { - break Discover - } - } - } - pool.Stop(key) - } -} - -func newConfigOK(tickRate time.Duration) Config { - conf := Config{ - TickRate: tickRate, - EntryExpiry: tickRate * 180, - BackoffInterval: tickRate / 4, - BackoffThreshold: tickRate * 30, - } - if err := conf.Check(); err != nil { - panic(err) - } - return conf -} diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index b2aedab1670c6..b719ab66c5b6e 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -298,12 +298,11 @@ func (p *transport) start() { errorCh <- err }() + // wait for both io.Copy goroutines to finish, or for + // the context to be canceled. 
for i := 0; i < 2; i++ { select { - case err := <-errorCh: - if err != nil && err != io.EOF { - p.log.Warnf("Proxy transport failed: %v %T.", trace.DebugReport(err), err) - } + case <-errorCh: case <-p.closeContext.Done(): p.log.Warnf("Proxy transport failed: closing context.") return diff --git a/lib/services/authority.go b/lib/services/authority.go index dafe46f14facc..5b17f25f57fce 100644 --- a/lib/services/authority.go +++ b/lib/services/authority.go @@ -160,8 +160,8 @@ func UnmarshalCertRoles(data string) ([]string, error) { // CertAuthority is a host or user certificate authority that // can check and if it has private key stored as well, sign it too type CertAuthority interface { - // Resource sets common resource properties - Resource + // ResourceWithSecrets sets common resource properties + ResourceWithSecrets // GetID returns certificate authority ID - // combined type and name GetID() CertAuthID @@ -403,6 +403,13 @@ func (c *CertAuthorityV2) SetResourceID(id int64) { c.Metadata.ID = id } +// WithoutSecrets returns an instance of resource without secrets. +func (c *CertAuthorityV2) WithoutSecrets() Resource { + c2 := c.Clone() + RemoveCASecrets(c2) + return c2 +} + // V2 returns V2 version of the resouirce - itself func (c *CertAuthorityV2) V2() *CertAuthorityV2 { return c diff --git a/lib/services/events.go b/lib/services/events.go index 326a60b675a94..18ff6dee287e0 100644 --- a/lib/services/events.go +++ b/lib/services/events.go @@ -20,6 +20,8 @@ import ( "context" "github.com/gravitational/teleport/lib/backend" + + "github.com/gravitational/trace" ) // Watch sets up watch on the event @@ -53,6 +55,36 @@ type WatchKind struct { Filter map[string]string } +// Matches attempts to determine if the supplied event matches +// this WatchKind. If the WatchKind is misconfigured, or the +// event appears malformed, an error is returned. +func (kind WatchKind) Matches(e Event) (bool, error) { + if kind.Kind != e.Resource.GetKind() { + return false, nil + } + if kind.Name != "" && kind.Name != e.Resource.GetName() { + return false, nil + } + if len(kind.Filter) > 0 { + // no filters currently match delete events + if e.Type != backend.OpPut { + return false, nil + } + // Currently only access request make use of filters, + // so expect the resource to be an access request. + req, ok := e.Resource.(AccessRequest) + if !ok { + return false, trace.BadParameter("unfilterable resource type: %T", e.Resource) + } + var filter AccessRequestFilter + if err := filter.FromMap(kind.Filter); err != nil { + return false, trace.Wrap(err) + } + return filter.Match(req), nil + } + return true, nil +} + // Event represents an event that happened in the backend type Event struct { // Type is the event type diff --git a/lib/services/fanout.go b/lib/services/fanout.go new file mode 100644 index 0000000000000..bfb7f258d7df9 --- /dev/null +++ b/lib/services/fanout.go @@ -0,0 +1,217 @@ +/* +Copyright 2020 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package services + +import ( + "context" + "sync" + + "github.com/gravitational/teleport/lib/backend" + + "github.com/gravitational/trace" +) + +const defaultQueueSize = 64 + +type fanoutEntry struct { + kind WatchKind + watcher *fanoutWatcher +} + +// Fanout is a helper which allows a stream of events to be fanned-out to many +// watchers. Used by the cache layer to forward events. +type Fanout struct { + mu sync.Mutex + watchers map[string][]fanoutEntry +} + +// NewFanout creates a new Fanout instance. +func NewFanout() *Fanout { + return &Fanout{ + watchers: make(map[string][]fanoutEntry), + } +} + +// NewWatcher attaches a new watcher to this fanout instance. +func (f *Fanout) NewWatcher(ctx context.Context, watch Watch) (Watcher, error) { + f.mu.Lock() + defer f.mu.Unlock() + w, err := newFanoutWatcher(ctx, watch) + if err != nil { + return nil, trace.Wrap(err) + } + if err := w.emit(Event{Type: backend.OpInit}); err != nil { + w.cancel() + return nil, trace.Wrap(err) + } + f.addWatcher(w) + return w, nil +} + +func filterEventSecrets(event Event) Event { + r, ok := event.Resource.(ResourceWithSecrets) + if !ok { + return event + } + event.Resource = r.WithoutSecrets() + return event +} + +// Emit broadcasts events to all matching watchers that have been attached +// to this fanout instance. +func (f *Fanout) Emit(events ...Event) { + f.mu.Lock() + defer f.mu.Unlock() + for _, fullEvent := range events { + // by default, we operate on a version of the event which + // has had secrets filtered out. + event := filterEventSecrets(fullEvent) + var remove []*fanoutWatcher + Inner: + for _, entry := range f.watchers[event.Resource.GetKind()] { + match, err := entry.kind.Matches(event) + if err != nil { + entry.watcher.setError(err) + remove = append(remove, entry.watcher) + continue Inner + } + if !match { + continue Inner + } + emitEvent := event + // if this entry loads secrets, emit the + // full unfiltered event. + if entry.kind.LoadSecrets { + emitEvent = fullEvent + } + if err := entry.watcher.emit(emitEvent); err != nil { + remove = append(remove, entry.watcher) + continue Inner + } + } + for _, w := range remove { + f.removeWatcher(w) + w.cancel() + } + } +} + +// CloseWatchers closes all attached watchers, effectively +// resetting the Fanout instance. +func (f *Fanout) CloseWatchers() { + f.mu.Lock() + defer f.mu.Unlock() + for _, entries := range f.watchers { + for _, entry := range entries { + entry.watcher.cancel() + } + } + // watcher map was potentially quite large, so + // relenguish that memory. + f.watchers = make(map[string][]fanoutEntry) +} + +func (f *Fanout) addWatcher(w *fanoutWatcher) { + for _, kind := range w.watch.Kinds { + entries := f.watchers[kind.Kind] + entries = append(entries, fanoutEntry{ + kind: kind, + watcher: w, + }) + f.watchers[kind.Kind] = entries + } +} + +func (f *Fanout) removeWatcher(w *fanoutWatcher) { + for _, kind := range w.watch.Kinds { + entries := f.watchers[kind.Kind] + Inner: + for i, entry := range entries { + if entry.watcher == w { + entries = append(entries[:i], entries[i+1:]...) 
+ break Inner + } + } + switch len(entries) { + case 0: + delete(f.watchers, kind.Kind) + default: + f.watchers[kind.Kind] = entries + } + } +} + +func newFanoutWatcher(ctx context.Context, watch Watch) (*fanoutWatcher, error) { + if len(watch.Kinds) < 1 { + return nil, trace.BadParameter("must specify at least one resource kind to watch") + } + ctx, cancel := context.WithCancel(ctx) + if watch.QueueSize < 1 { + watch.QueueSize = defaultQueueSize + } + return &fanoutWatcher{ + watch: watch, + eventC: make(chan Event, watch.QueueSize), + cancel: cancel, + ctx: ctx, + }, nil +} + +type fanoutWatcher struct { + emux sync.Mutex + err error + watch Watch + eventC chan Event + cancel context.CancelFunc + ctx context.Context +} + +func (w *fanoutWatcher) emit(event Event) error { + select { + case <-w.ctx.Done(): + return trace.Wrap(w.ctx.Err(), "watcher closed") + case w.eventC <- event: + return nil + default: + return trace.BadParameter("buffer overflow") + } +} + +func (w *fanoutWatcher) Events() <-chan Event { + return w.eventC +} + +func (w *fanoutWatcher) Done() <-chan struct{} { + return w.ctx.Done() +} + +func (w *fanoutWatcher) Close() error { + w.cancel() + return nil +} + +func (w *fanoutWatcher) setError(err error) { + w.emux.Lock() + defer w.emux.Unlock() + w.err = err +} + +func (w *fanoutWatcher) Error() error { + w.emux.Lock() + defer w.emux.Unlock() + return w.err +} diff --git a/lib/services/github.go b/lib/services/github.go index b3c7b54f9a489..1cb07b9fdeb0f 100644 --- a/lib/services/github.go +++ b/lib/services/github.go @@ -30,8 +30,8 @@ import ( // GithubConnector defines an interface for a Github OAuth2 connector type GithubConnector interface { - // Resource is a common interface for all resources - Resource + // ResourceWithSecrets is a common interface for all resources + ResourceWithSecrets // CheckAndSetDefaults validates the connector and sets some defaults CheckAndSetDefaults() error // GetClientID returns the connector client ID @@ -183,6 +183,16 @@ func (c *GithubConnectorV3) GetMetadata() Metadata { return c.Metadata } +// WithoutSecrets returns an instance of resource without secrets. +func (c *GithubConnectorV3) WithoutSecrets() Resource { + if c.GetClientSecret() == "" { + return c + } + c2 := *c + c2.SetClientSecret("") + return &c2 +} + // CheckAndSetDefaults verifies the connector is valid and sets some defaults func (c *GithubConnectorV3) CheckAndSetDefaults() error { if err := c.Metadata.CheckAndSetDefaults(); err != nil { diff --git a/lib/services/oidc.go b/lib/services/oidc.go index 28c2cdbe5a738..a68d432cc9aae 100644 --- a/lib/services/oidc.go +++ b/lib/services/oidc.go @@ -35,8 +35,8 @@ import ( // OIDCConnector specifies configuration for Open ID Connect compatible external // identity provider, e.g. google in some organisation type OIDCConnector interface { - // Resource provides common methods for objects - Resource + // ResourceWithSecrets provides common methods for objects + ResourceWithSecrets // Issuer URL is the endpoint of the provider, e.g. https://accounts.google.com GetIssuerURL() string // ClientID is id for authentication client (in our case it's our Auth server) @@ -297,6 +297,16 @@ func (o *OIDCConnectorV2) SetResourceID(id int64) { o.Metadata.ID = id } +// WithoutSecrets returns an instance of resource without secrets. 
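A small usage sketch of the new Fanout type, written as if from inside the services package; the ctx, aliceUser, and log names are illustrative placeholders. Note that Emit consults WatchKind.Matches from events.go above, and strips secrets via WithoutSecrets unless the watcher set LoadSecrets:

// Illustrative sketch, not part of the patch: fanning events out to a watcher.
fanout := NewFanout()

w, err := fanout.NewWatcher(ctx, Watch{
	Kinds: []WatchKind{{Kind: "user", Name: "alice"}},
})
if err != nil {
	return trace.Wrap(err)
}
defer w.Close()

<-w.Events() // the first event on any new watcher is backend.OpInit

// Only events whose kind (and, here, name) match are forwarded, and the
// emitted copy has had WithoutSecrets applied since LoadSecrets is unset.
fanout.Emit(Event{Type: backend.OpPut, Resource: aliceUser})

select {
case e := <-w.Events():
	log.Infof("received %v for %v", e.Type, e.Resource.GetName())
case <-w.Done():
	return trace.Wrap(w.Error())
}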
+func (o *OIDCConnectorV2) WithoutSecrets() Resource { + if o.GetClientSecret() == "" { + return o + } + o2 := *o + o2.SetClientSecret("") + return &o2 +} + // V2 returns V2 version of the resource func (o *OIDCConnectorV2) V2() *OIDCConnectorV2 { return o diff --git a/lib/services/resource.go b/lib/services/resource.go index ff5f57d63270e..c2d20952469c9 100644 --- a/lib/services/resource.go +++ b/lib/services/resource.go @@ -642,6 +642,16 @@ type Resource interface { SetResourceID(int64) } +// ResourceWithSecrets includes additional properties which must +// be provided by resources which *may* contain secrets. +type ResourceWithSecrets interface { + Resource + // WithoutSecrets returns an instance of the resource which + // has had all secrets removed. If the current resource has + // already had its secrets removed, this may be a no-op. + WithoutSecrets() Resource +} + // GetID returns resource ID func (m *Metadata) GetID() int64 { return m.ID diff --git a/lib/services/saml.go b/lib/services/saml.go index 4528379c09baf..cf2088f586736 100644 --- a/lib/services/saml.go +++ b/lib/services/saml.go @@ -42,8 +42,8 @@ import ( // SAMLConnector specifies configuration for SAML 2.0 dentity providers type SAMLConnector interface { - // Resource provides common methods for objects - Resource + // ResourceWithSecrets provides common methods for objects + ResourceWithSecrets // GetDisplay returns display - friendly name for this provider. GetDisplay() string // SetDisplay sets friendly name for this provider. @@ -264,6 +264,19 @@ func (o *SAMLConnectorV2) SetResourceID(id int64) { o.Metadata.ID = id } +// WithoutSecrets returns an instance of resource without secrets. +func (o *SAMLConnectorV2) WithoutSecrets() Resource { + k := o.GetSigningKeyPair() + if k == nil { + return o + } + k2 := *k + k2.PrivateKey = "" + o2 := *o + o2.SetSigningKeyPair(&k2) + return &o2 +} + // GetServiceProviderIssuer returns service provider issuer func (o *SAMLConnectorV2) GetServiceProviderIssuer() string { return o.Spec.ServiceProviderIssuer diff --git a/lib/services/user.go b/lib/services/user.go index 9623407c90915..52189881ecd39 100644 --- a/lib/services/user.go +++ b/lib/services/user.go @@ -32,8 +32,8 @@ import ( // User represents teleport embedded user or external user type User interface { - // Resource provides common resource properties - Resource + // ResourceWithSecrets provides common resource properties + ResourceWithSecrets // GetOIDCIdentities returns a list of connected OIDC identities GetOIDCIdentities() []ExternalIdentity // GetSAMLIdentities returns a list of connected SAML identities @@ -217,6 +217,16 @@ func (u *UserV2) SetName(e string) { u.Metadata.Name = e } +// WithoutSecrets returns an instance of resource without secrets. +func (u *UserV2) WithoutSecrets() Resource { + if u.Spec.LocalAuth == nil { + return u + } + u2 := *u + u2.Spec.LocalAuth = nil + return &u2 +} + // WebSessionInfo returns web session information about user func (u *UserV2) WebSessionInfo(allowedLogins []string) interface{} { out := u.V1() diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 8aecafe06e0a0..221037d3c428c 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -656,6 +656,9 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) { }) c.Assert(err, IsNil) + err = agentPool.Start() + c.Assert(err, IsNil) + // Create a reverse tunnel and remote cluster simulating what the trusted // cluster exchange does. 
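The connector, certificate authority, and user changes above all follow the same copy-then-clear pattern; a hedged sketch of what an implementation would look like for a hypothetical secret-bearing resource (WidgetV1 is not a real Teleport type):

// Illustrative sketch, not part of the patch: WithoutSecrets for a hypothetical
// WidgetV1 resource whose spec carries a join token.
func (w *WidgetV1) WithoutSecrets() Resource {
	if w.Spec.JoinToken == "" {
		// nothing to strip; returning the receiver unchanged is allowed,
		// mirroring the connector implementations above.
		return w
	}
	w2 := *w // shallow copy, as done for the OIDC and Github connectors
	w2.Spec.JoinToken = ""
	return &w2
}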
err = s.server.Auth().UpsertReverseTunnel( diff --git a/lib/utils/loadbalancer.go b/lib/utils/loadbalancer.go index ea5bdfc2236e2..af9a2c4c903c0 100644 --- a/lib/utils/loadbalancer.go +++ b/lib/utils/loadbalancer.go @@ -111,7 +111,7 @@ func (l *LoadBalancer) AddBackend(b NetAddr) { } // RemoveBackend removes backend -func (l *LoadBalancer) RemoveBackend(b NetAddr) { +func (l *LoadBalancer) RemoveBackend(b NetAddr) error { l.Lock() defer l.Unlock() l.currentIndex = -1 @@ -119,9 +119,10 @@ func (l *LoadBalancer) RemoveBackend(b NetAddr) { if l.backends[i].Equals(b) { l.backends = append(l.backends[:i], l.backends[i+1:]...) l.dropConnections(b) - return + return nil } } + return trace.NotFound("lb has no backend matching: %+v", b) } func (l *LoadBalancer) nextBackend() (*NetAddr, error) { diff --git a/lib/utils/retry.go b/lib/utils/retry.go index 1a4cf2b1bcd21..fc9187535a3b9 100644 --- a/lib/utils/retry.go +++ b/lib/utils/retry.go @@ -18,11 +18,35 @@ package utils import ( "fmt" + "math/rand" + "sync" "time" "github.com/gravitational/trace" ) +// Jitter is a function which applies random jitter to a +// duration. Used to randomize backoff values. Must be +// safe for concurrent usage. +type Jitter func(time.Duration) time.Duration + +// NewJitter returns the default jitter (currently jitters on +// the range [n/2,n), but this is subject to change). +func NewJitter() Jitter { + var mu sync.Mutex + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + return func(d time.Duration) time.Duration { + // values less than 1 cause rng to panic, and some logic + // relies on treating zero duration as non-blocking case. + if d < 1 { + return 0 + } + mu.Lock() + defer mu.Unlock() + return (d / 2) + time.Duration(rng.Int63n(int64(d))/2) + } +} + // Retry is an interface that provides retry logic type Retry interface { // Reset resets retry state @@ -36,6 +60,9 @@ type Retry interface { // that fires after Duration delay, // could fire right away if Duration is 0 After() <-chan time.Time + // Clone creates a copy of this retry in a + // reset state. + Clone() Retry } // LinearConfig sets up retry configuration @@ -49,6 +76,13 @@ type LinearConfig struct { // Max is a maximum value of the progression, // can't be 0 Max time.Duration + // Jitter is an optional jitter function to be applied + // to the delay. Note that supplying a jitter means that + // successive calls to Duration may return different results. + Jitter Jitter + // AutoReset, if greater than zero, causes the linear retry to automatically + // reset after Max * AutoReset has elapsed since the last call to Incr. + AutoReset int64 } // CheckAndSetDefaults checks and sets defaults @@ -67,9 +101,15 @@ func NewLinear(cfg LinearConfig) (*Linear, error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } + return newLinear(cfg), nil +} + +// newLinear creates an instance of Linear from a +// previously verified configuration. 
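To illustrate the two new LinearConfig knobs introduced in retry.go above, a hedged sketch of a long-lived, jittered retry loop; the field values and the dialProxy/ctx names are placeholders, not part of the patch:

// Illustrative sketch, not part of the patch: a jittered, self-resetting linear retry.
retry, err := utils.NewLinear(utils.LinearConfig{
	First:     time.Second,
	Step:      time.Second,
	Max:       time.Minute,
	Jitter:    utils.NewJitter(), // randomize each delay to roughly [d/2, d)
	AutoReset: 5,                 // reset automatically after 5*Max without an Inc call
})
if err != nil {
	return trace.Wrap(err)
}
for {
	select {
	case <-retry.After():
		retry.Inc()
		if err := dialProxy(ctx); err != nil {
			continue // keep backing off, with jitter applied to each delay
		}
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}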
+func newLinear(cfg LinearConfig) *Linear {
 	closedChan := make(chan time.Time)
 	close(closedChan)
-	return &Linear{LinearConfig: cfg, closedChan: closedChan}, nil
+	return &Linear{LinearConfig: cfg, closedChan: closedChan}
 }
 
 // Linear is used to calculate retry period
@@ -80,6 +120,7 @@ func NewLinear(cfg LinearConfig) (*Linear, error) {
 type Linear struct {
 	// LinearConfig is a linear retry config
 	LinearConfig
+	lastIncr   time.Time
 	attempt    int64
 	closedChan chan time.Time
 }
@@ -89,17 +130,43 @@ func (r *Linear) Reset() {
 	r.attempt = 0
 }
 
+// Clone creates an identical copy of Linear with fresh state.
+func (r *Linear) Clone() Retry {
+	return newLinear(r.LinearConfig)
+}
+
 // Inc increments attempt counter
 func (r *Linear) Inc() {
 	r.attempt++
+	if r.AutoReset < 1 {
+		// No AutoReset configured; we can skip
+		// everything else.
+		return
+	}
+	// when AutoReset is active, we track the time of the
+	// last call to Inc. If more than Max * AutoReset has
+	// elapsed, we reset state internally. This allows
+	// Linear to function as a long-lived rate-limiting
+	// device.
+	prev := r.lastIncr
+	r.lastIncr = time.Now()
+	if prev.IsZero() {
+		return
+	}
+	if r.Max*time.Duration(r.AutoReset) < r.lastIncr.Sub(prev) {
+		r.Reset()
+	}
 }
 
 // Duration returns retry duration based on state
 func (r *Linear) Duration() time.Duration {
 	a := r.First + time.Duration(r.attempt)*r.Step
-	if a < 0 {
+	if a < 1 {
 		return 0
 	}
+	if r.Jitter != nil {
+		a = r.Jitter(a)
+	}
 	if a <= r.Max {
 		return a
 	}
@@ -110,10 +177,11 @@ func (r *Linear) Duration() time.Duration {
 // defined in Duration method, as a special case
 // if Duration is 0 returns a closed channel
 func (r *Linear) After() <-chan time.Time {
-	if r.Duration() == 0 {
+	d := r.Duration()
+	if d < 1 {
 		return r.closedChan
 	}
-	return time.After(r.Duration())
+	return time.After(d)
 }
 
 // String returns user-friendly representation of the LinearPeriod
diff --git a/lib/utils/workpool/doc.go b/lib/utils/workpool/doc.go
new file mode 100644
index 0000000000000..56db9cacb4c43
--- /dev/null
+++ b/lib/utils/workpool/doc.go
@@ -0,0 +1,37 @@
+/*
+Copyright 2020 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package workpool provides the `Pool` type, which functions as a means
+// of managing the number of concurrent workers,
+// grouped by key. You can think of this type as functioning
+// like a collection of semaphores, except that multiple distinct resources may
+// exist (distinguished by keys), and the target concurrent worker count may
+// change at runtime. The basic usage pattern is as follows:
+//
+// 1. The desired number of workers for a given key is specified
+// or updated via Pool.Set.
+//
+// 2. Workers are spawned as leases become available on Pool.Acquire.
+//
+// 3. Workers relinquish their leases when they finish their work
+// by calling Lease.Release.
+//
+// 4. New leases become available as old leases are relinquished, or
+// as the target lease count increases.
+//
+// This is a generalization of logic originally written to manage the number
+// of concurrent reversetunnel agents per proxy endpoint.
+package workpool
diff --git a/lib/utils/workpool/workpool.go b/lib/utils/workpool/workpool.go
new file mode 100644
index 0000000000000..f635a0db21f2e
--- /dev/null
+++ b/lib/utils/workpool/workpool.go
@@ -0,0 +1,268 @@
+/*
+Copyright 2020 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package workpool
+
+import (
+	"context"
+	"sync"
+
+	"go.uber.org/atomic"
+)
+
+// Pool manages a collection of work groups by key. Each work group has an
+// adjustable target value, which is the number of leases that should be
+// active for the given group.
+type Pool struct {
+	mu       sync.Mutex
+	leaseIDs *atomic.Uint64
+	groups   map[interface{}]*group
+	grantC   chan Lease
+	ctx      context.Context
+	cancel   context.CancelFunc
+}
+
+// NewPool creates a new Pool tied to the supplied context.
+func NewPool(ctx context.Context) *Pool {
+	ctx, cancel := context.WithCancel(ctx)
+	return &Pool{
+		leaseIDs: atomic.NewUint64(0),
+		groups:   make(map[interface{}]*group),
+		grantC:   make(chan Lease),
+		ctx:      ctx,
+		cancel:   cancel,
+	}
+}
+
+// Acquire returns the channel which must be received on to acquire
+// new leases. Each lease acquired in this way *must* have its
+// Release method called when the lease is no longer needed.
+func (p *Pool) Acquire() <-chan Lease {
+	return p.grantC
+}
+
+// Done signals pool closure.
+func (p *Pool) Done() <-chan struct{} {
+	return p.ctx.Done()
+}
+
+// Get gets the current counts for the specified key.
+func (p *Pool) Get(key interface{}) Counts {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if g, ok := p.groups[key]; ok {
+		return g.loadCounts()
+	}
+	return Counts{}
+}
+
+// Set sets the target for the specified key.
+func (p *Pool) Set(key interface{}, target uint64) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if target < 1 {
+		p.del(key)
+		return
+	}
+	g, ok := p.groups[key]
+	if !ok {
+		p.start(key, target)
+		return
+	}
+	g.setTarget(target)
+}
+
+// start creates a new work group with the specified initial target and
+// launches its run loop. It is called with p.mu held, and only when no
+// group exists yet for the key.
+func (p *Pool) start(key interface{}, target uint64) {
+	ctx, cancel := context.WithCancel(p.ctx)
+	notifyC := make(chan struct{}, 1)
+	g := &group{
+		counts: Counts{
+			Active: 0,
+			Target: target,
+		},
+		leaseIDs: p.leaseIDs,
+		key:      key,
+		grantC:   p.grantC,
+		notifyC:  notifyC,
+		ctx:      ctx,
+		cancel:   cancel,
+	}
+	p.groups[key] = g
+	go g.run()
+}
+
+func (p *Pool) del(key interface{}) (ok bool) {
+	group, ok := p.groups[key]
+	if !ok {
+		return false
+	}
+	group.cancel()
+	delete(p.groups, key)
+	return true
+}
+
+// Stop permanently halts all associated groups.
+func (p *Pool) Stop() {
+	p.cancel()
+}
+
+// Counts holds the target and active counts for a
+// key/group.
+type Counts struct {
+	// Target is the number of active leases that we would
+	// like to converge toward.
+	Target uint64
+	// Active is the current active lease count.
+ Active uint64 +} + +// group is a work group for a particular key in the pool. It tracks the number of +// active and target leases and adds leases when active drops below target. +type group struct { + cmu sync.Mutex + counts Counts + leaseIDs *atomic.Uint64 + key interface{} + grantC chan Lease + notifyC chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +// notify ensures that group is in a notified state. +// if the group is already in a notified state, this +// method has no effect. This function should be called +// any time state is changed to ensure that the group's +// goroutine unblocks & handles the updated state. +func (g *group) notify() { + select { + case g.notifyC <- struct{}{}: + default: + } +} + +// loadCounts loads the current lease counts. +func (g *group) loadCounts() Counts { + g.cmu.Lock() + defer g.cmu.Unlock() + return g.counts +} + +// incrActive increases the active lease count by 1 +// and ensures group is in the notified state. +func (g *group) incrActive() Counts { + g.cmu.Lock() + defer g.cmu.Unlock() + g.counts.Active++ + g.notify() + return g.counts +} + +// decrActive decreases the active lease count by 1 +// and ensures group is in the notified state. +func (g *group) decrActive() Counts { + g.cmu.Lock() + defer g.cmu.Unlock() + g.counts.Active-- + g.notify() + return g.counts +} + +// setTarget sets the target lease count that the group should +// attempt to converge toward. +func (g *group) setTarget(target uint64) { + g.cmu.Lock() + defer g.cmu.Unlock() + g.counts.Target = target + g.notify() +} + +func (g *group) run() { + var counts Counts + var nextLease Lease + var grant chan Lease + for { + counts = g.loadCounts() + if counts.Active < counts.Target { + // we are in a "granting" state; ensure that the + // grant channel is non-nil, and initialize `nextLease` + // if it hasn't been already. + grant = g.grantC + if nextLease.id == 0 { + nextLease = newLease(g) + } + } else { + // we are not in a "granting" state, ensure that the + // grant channel is nil (prevents sends). + grant = nil + } + // if grant is nil, this select statement blocks until + // notify() is called, or the context is canceled. + select { + case grant <- nextLease: + g.incrActive() + nextLease = Lease{} + case <-g.notifyC: + case <-g.ctx.Done(): + return + } + } +} + +// Lease grants access to a resource or group. When the lease is received, +// work can begin. Leases are held by workers and must be released when +// the worker has finished its work. +type Lease struct { + *group + id uint64 + relOnce *sync.Once +} + +func newLease(group *group) Lease { + return Lease{ + group: group, + id: group.leaseIDs.Add(1), + relOnce: new(sync.Once), + } +} + +// ID returns the unique ID of this lease. +func (l Lease) ID() uint64 { + return l.id +} + +// Key returns the key that this lease is associated with. +func (l Lease) Key() interface{} { + return l.key +} + +// IsZero checks if this is the zero value of Lease. +func (l Lease) IsZero() bool { + return l == Lease{} +} + +// Release relenquishes this lease. Each lease is unique, +// so double-calling Release on the same Lease has no effect. +func (l Lease) Release() { + if l.IsZero() { + return + } + l.relOnce.Do(func() { + l.decrActive() + }) +} diff --git a/lib/utils/workpool/workpool_test.go b/lib/utils/workpool/workpool_test.go new file mode 100644 index 0000000000000..388d86ef45592 --- /dev/null +++ b/lib/utils/workpool/workpool_test.go @@ -0,0 +1,177 @@ +/* +Copyright 2020 Gravitational, Inc. 
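A short consumer-side sketch of the Pool API defined above, complementing the Example function in the test file that follows; the key string, ctx, and doWork helper are placeholders, not part of the patch:

// Illustrative sketch, not part of the patch: a typical Pool consumer loop.
pool := workpool.NewPool(ctx)
defer pool.Stop()

pool.Set("cluster-a", 3) // aim for three concurrent workers for this key

go func() {
	for {
		select {
		case lease := <-pool.Acquire():
			go func(l workpool.Lease) {
				defer l.Release() // releasing the lease frees the slot
				doWork(l.Key())
			}(lease)
		case <-pool.Done():
			return
		}
	}
}()

// the target can be raised or lowered at any time; Set(key, 0) removes
// the group entirely and stops further grants for it.
pool.Set("cluster-a", 1)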
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package workpool
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	"gopkg.in/check.v1"
+)
+
+func Example() {
+	pool := NewPool(context.TODO())
+	defer pool.Stop()
+	// create two keys with different target counts
+	pool.Set("spam", 2)
+	pool.Set("eggs", 1)
+	// track how many workers are spawned for each key
+	counts := make(map[string]int)
+	var mu sync.Mutex
+	var wg sync.WaitGroup
+	for i := 0; i < 12; i++ {
+		wg.Add(1)
+		go func() {
+			lease := <-pool.Acquire()
+			defer lease.Release()
+			mu.Lock()
+			counts[lease.Key().(string)]++
+			mu.Unlock()
+			// in order to demonstrate the differing spawn rates we need
+			// work to take some time, otherwise pool will end up granting
+			// leases in a "round robin" fashion.
+			time.Sleep(time.Millisecond * 10)
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+	// exact counts will vary, but leases with key `spam`
+	// will end up being generated approximately twice as
+	// often as leases with key `eggs`.
+	fmt.Println(counts["spam"] > counts["eggs"]) // Output: true
+}
+
+func Test(t *testing.T) {
+	check.TestingT(t)
+}
+
+type WorkSuite struct{}
+
+var _ = check.Suite(&WorkSuite{})
+
+// TestFull runs a pool through a round of normal usage,
+// and verifies expected state along the way:
+// - A group of workers acquire leases, do some work, and release them.
+// - A second group of workers receive leases as the first group finishes.
+// - The expected number of leases is in play after this churn.
+// - Updating the target lease count has the expected effect.
+func (s *WorkSuite) TestFull(c *check.C) {
+	ctx, cancel := context.WithCancel(context.TODO())
+	defer cancel()
+	p := NewPool(ctx)
+	key := "some-key"
+	var wg sync.WaitGroup
+	// signal channel to cause the first group of workers to
+	// release their leases.
+	g1done := make(chan struct{})
+	// timeout channel indicating all of group one should
+	// have acquired their leases.
+	g1timeout := make(chan struct{})
+	go func() {
+		time.Sleep(time.Millisecond * 500)
+		close(g1timeout)
+	}()
+	p.Set(key, 200)
+	// spawn first group of workers.
+	for i := 0; i < 200; i++ {
+		wg.Add(1)
+		go func() {
+			select {
+			case l := <-p.Acquire():
+				<-g1done
				l.Release()
+			case <-g1timeout:
+				c.Errorf("Timeout waiting for lease")
+			}
+			wg.Done()
+		}()
+	}
+	<-g1timeout
+	// no additional leases should exist
+	select {
+	case l := <-p.Acquire():
+		c.Errorf("unexpected lease: %+v", l)
+	default:
+	}
+	// spawn a second group of workers that won't be able to
+	// acquire their leases until the first group is done.
+	for i := 0; i < 200; i++ {
+		wg.Add(1)
+		go func() {
+			select {
+			case <-p.Acquire():
+				// leak deliberately
+			case <-time.After(time.Millisecond * 512):
+				c.Errorf("Timeout waiting for lease")
+			}
+			wg.Done()
+		}()
+	}
+	// signal first group is done
+	close(g1done)
+	// wait for second group to acquire leases.
+	wg.Wait()
+	// no additional leases should exist
+	select {
+	case l := <-p.Acquire():
+		counts := l.loadCounts()
+		c.Errorf("unexpected lease grant: %+v, counts=%+v", l, counts)
+	case <-time.After(time.Millisecond * 128):
+	}
+	// make one additional lease available
+	p.Set(key, 201)
+	select {
+	case l := <-p.Acquire():
+		c.Assert(l.Key().(string), check.Equals, key)
+		l.Release()
+	case <-time.After(time.Millisecond * 128):
+		c.Errorf("timeout waiting for lease grant")
+	}
+}
+
+// TestZeroed verifies that a zeroed pool stops granting
+// leases as expected.
+func (s *WorkSuite) TestZeroed(c *check.C) {
+	ctx, cancel := context.WithCancel(context.TODO())
+	defer cancel()
+	p := NewPool(ctx)
+	key := "some-key"
+	p.Set(key, 1)
+	var l Lease
+	select {
+	case l = <-p.Acquire():
+		c.Assert(l.Key().(string), check.Equals, key)
+		l.Release()
+	case <-time.After(time.Millisecond * 128):
+		c.Errorf("timeout waiting for lease grant")
+	}
+	p.Set(key, 0)
+	// modifications to counts are *ordered*, but asynchronous,
+	// so we could actually receive a lease here if we don't sleep
+	// briefly. If we opted for condvars instead of channels, this
+	// issue could be avoided at the cost of more cumbersome
+	// composition/cancellation.
+	time.Sleep(time.Millisecond * 10)
+	select {
+	case l := <-p.Acquire():
+		c.Errorf("unexpected lease grant: %+v", l)
+	case <-time.After(time.Millisecond * 128):
+	}
+}
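One implementation detail worth calling out for reviewers: the group.run loop in workpool.go above enables and disables its grant case by swapping its local channel variable between the real grant channel and nil, since a send on a nil channel blocks forever. A standalone sketch of that idiom, not taken from the patch:

// offer sends value on out only while active is true; assigning nil to the
// local channel variable disables that select case entirely.
func offer(active bool, value string, out chan<- string, notify <-chan struct{}) {
	var sendC chan<- string // nil: the send case below can never fire
	if active {
		sendC = out
	}
	select {
	case sendC <- value:
		// a consumer accepted the value
	case <-notify:
		// state changed; the caller should recompute 'active' and try again
	}
}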