diff --git a/api/client/proto/authservice.pb.go b/api/client/proto/authservice.pb.go index e32b8491dae64..8e0e27b78c0fd 100644 --- a/api/client/proto/authservice.pb.go +++ b/api/client/proto/authservice.pb.go @@ -9900,7 +9900,7 @@ func (m *SessionTrackerRemoveParticipant) GetParticipantID() string { return "" } -// SessionTrackerUpdateExpiry is used to update the session tracker expirationt time. +// SessionTrackerUpdateExpiry is used to update the session tracker expiration time. type SessionTrackerUpdateExpiry struct { // Expires is when the session tracker will expire. Expires *time.Time `protobuf:"bytes,1,opt,name=Expires,proto3,stdtime" json:"expires"` diff --git a/api/client/proto/authservice.proto b/api/client/proto/authservice.proto index 827d03847010e..fe65c00ed9eb4 100644 --- a/api/client/proto/authservice.proto +++ b/api/client/proto/authservice.proto @@ -1593,7 +1593,7 @@ message SessionTrackerRemoveParticipant { string ParticipantID = 2 [ (gogoproto.jsontag) = "participant_id,omitempty" ]; } -// SessionTrackerUpdateExpiry is used to update the session tracker expirationt time. +// SessionTrackerUpdateExpiry is used to update the session tracker expiration time. message SessionTrackerUpdateExpiry { // Expires is when the session tracker will expire. google.protobuf.Timestamp Expires = 1 @@ -1641,7 +1641,10 @@ service AuthService { // CreateSessionTracker creates a new session tracker resource. rpc CreateSessionTracker(CreateSessionTrackerRequest) returns (types.SessionTrackerV1); - // GetSessionTrackerRequest fetches a session tracker resource. + // UpsertSessionTracker upserts a session tracker resource. + rpc UpsertSessionTracker(types.SessionTrackerV1) returns (google.protobuf.Empty); + + // GetSessionTracker fetches a session tracker resource. rpc GetSessionTracker(GetSessionTrackerRequest) returns (types.SessionTrackerV1); // GetActiveSessionTrackers returns a list of active sessions. diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index b173dd0eebba8..9f16cdf5b8d58 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -244,6 +244,16 @@ func (a *ServerWithRoles) CreateSessionTracker(ctx context.Context, req *proto.C return nil, trace.AccessDenied("this request can be only executed by a node, proxy or kube service") } + // Don't allow sessions that require moderation without the enterprise feature enabled. + for _, policySet := range req.HostPolicies { + if len(policySet.RequireSessionJoin) != 0 { + if !modules.GetModules().Features().ModeratedSessions { + return nil, trace.AccessDenied( + "this Teleport cluster is not licensed for moderated sessions, please contact the cluster administrator") + } + } + } + return a.authServer.CreateSessionTracker(ctx, req) } diff --git a/lib/events/complete.go b/lib/events/complete.go index c5c26cf442d26..6d6147b6717b2 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -22,8 +22,6 @@ import ( "time" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/defaults" @@ -314,38 +312,3 @@ loop: } return nil } - -type MockSessionTrackerService struct { - clock clockwork.Clock - mockTrackers []types.SessionTracker -} - -func (m *MockSessionTrackerService) GetActiveSessionTrackers(ctx context.Context) ([]types.SessionTracker, error) { - return nil, nil -} - -func (m *MockSessionTrackerService) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) { - for _, tracker := range m.mockTrackers { - // mock session tracker expiration - if tracker.GetSessionID() == sessionID && tracker.Expiry().After(m.clock.Now()) { - return tracker, nil - } - } - return nil, trace.NotFound("tracker not found") -} - -func (m *MockSessionTrackerService) CreateSessionTracker(ctx context.Context, req *proto.CreateSessionTrackerRequest) (types.SessionTracker, error) { - return nil, nil -} - -func (m *MockSessionTrackerService) UpdateSessionTracker(ctx context.Context, req *proto.UpdateSessionTrackerRequest) error { - return nil -} - -func (m *MockSessionTrackerService) RemoveSessionTracker(ctx context.Context, sessionID string) error { - return nil -} - -func (m *MockSessionTrackerService) UpdatePresence(ctx context.Context, sessionID, user string) error { - return nil -} diff --git a/lib/events/complete_test.go b/lib/events/complete_test.go index de50b92896eb7..71c9a2ebbdc76 100644 --- a/lib/events/complete_test.go +++ b/lib/events/complete_test.go @@ -52,9 +52,9 @@ func TestUploadCompleterCompletesAbandonedUploads(t *testing.T) { }, } - sessionTrackerService := &MockSessionTrackerService{ - clock: clock, - mockTrackers: []types.SessionTracker{sessionTracker}, + sessionTrackerService := &eventstest.MockSessionTrackerService{ + Clock: clock, + MockTrackers: []types.SessionTracker{sessionTracker}, } uc, err := NewUploadCompleter(UploadCompleterConfig{ @@ -104,7 +104,7 @@ func TestUploadCompleterEmitsSessionEnd(t *testing.T) { Uploader: mu, AuditLog: log, Clock: clock, - SessionTracker: &MockSessionTrackerService{}, + SessionTracker: &eventstest.MockSessionTrackerService{}, }) require.NoError(t, err) diff --git a/lib/events/eventstest/mock.go b/lib/events/eventstest/mock.go index d5b254ddbb3fa..3a19b6585aeb4 100644 --- a/lib/events/eventstest/mock.go +++ b/lib/events/eventstest/mock.go @@ -20,8 +20,12 @@ import ( "context" "sync" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" ) // MockEmitter is an emitter that stores all emitted events. @@ -88,3 +92,38 @@ func (e *MockEmitter) Close(ctx context.Context) error { func (e *MockEmitter) Complete(ctx context.Context) error { return nil } + +type MockSessionTrackerService struct { + Clock clockwork.Clock + MockTrackers []types.SessionTracker +} + +func (m *MockSessionTrackerService) GetActiveSessionTrackers(ctx context.Context) ([]types.SessionTracker, error) { + return nil, nil +} + +func (m *MockSessionTrackerService) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) { + for _, tracker := range m.MockTrackers { + // mock session tracker expiration + if tracker.GetSessionID() == sessionID && tracker.Expiry().After(m.Clock.Now()) { + return tracker, nil + } + } + return nil, trace.NotFound("tracker not found") +} + +func (m *MockSessionTrackerService) CreateSessionTracker(ctx context.Context, req *proto.CreateSessionTrackerRequest) (types.SessionTracker, error) { + return nil, nil +} + +func (m *MockSessionTrackerService) UpdateSessionTracker(ctx context.Context, req *proto.UpdateSessionTrackerRequest) error { + return nil +} + +func (m *MockSessionTrackerService) RemoveSessionTracker(ctx context.Context, sessionID string) error { + return nil +} + +func (m *MockSessionTrackerService) UpdatePresence(ctx context.Context, sessionID, user string) error { + return nil +} diff --git a/lib/events/filesessions/fileasync_chaos_test.go b/lib/events/filesessions/fileasync_chaos_test.go index 83038e17726dd..f8bc6e03d13f6 100644 --- a/lib/events/filesessions/fileasync_chaos_test.go +++ b/lib/events/filesessions/fileasync_chaos_test.go @@ -35,6 +35,7 @@ import ( apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/trace" @@ -123,7 +124,7 @@ func TestChaosUpload(t *testing.T) { Streamer: faultyStreamer, Clock: clock, AuditLog: &events.DiscardAuditLog{}, - }, &events.MockSessionTrackerService{}) + }, &eventstest.MockSessionTrackerService{}) require.NoError(t, err) go uploader.Serve() // wait until uploader blocks on the clock diff --git a/lib/events/filesessions/fileasync_test.go b/lib/events/filesessions/fileasync_test.go index 17623b2a0a26d..e6e219c9d93de 100644 --- a/lib/events/filesessions/fileasync_test.go +++ b/lib/events/filesessions/fileasync_test.go @@ -32,6 +32,7 @@ import ( apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/trace" @@ -485,7 +486,7 @@ func newUploaderPack(t *testing.T, wrapStreamer wrapStreamerFn) uploaderPack { Clock: pack.clock, EventsC: pack.eventsC, AuditLog: &events.DiscardAuditLog{}, - }, &events.MockSessionTrackerService{}) + }, &eventstest.MockSessionTrackerService{}) require.NoError(t, err) pack.uploader = uploader go pack.uploader.Serve() @@ -522,7 +523,7 @@ func runResume(t *testing.T, testCase resumeTestCase) { Streamer: test.streamer, Clock: clock, AuditLog: &events.DiscardAuditLog{}, - }, &events.MockSessionTrackerService{}) + }, &eventstest.MockSessionTrackerService{}) require.Nil(t, err) go uploader.Serve() // wait until uploader blocks on the clock diff --git a/lib/services/local/sessiontracker.go b/lib/services/local/sessiontracker.go index 132c78a212480..3958a27392298 100644 --- a/lib/services/local/sessiontracker.go +++ b/lib/services/local/sessiontracker.go @@ -24,7 +24,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" @@ -136,17 +135,7 @@ func (s *sessionTracker) GetActiveSessionTrackers(ctx context.Context) ([]types. // CreateSessionTracker creates a tracker resource for an active session. func (s *sessionTracker) CreateSessionTracker(ctx context.Context, req *proto.CreateSessionTrackerRequest) (types.SessionTracker, error) { - // Don't allow sessions that require moderation without the enterprise feature enabled. - for _, policySet := range req.HostPolicies { - if len(policySet.RequireSessionJoin) != 0 { - if !modules.GetModules().Features().ModeratedSessions { - return nil, trace.AccessDenied( - "this Teleport cluster is not licensed for moderated sessions, please contact the cluster administrator") - } - } - } - - now := time.Now().UTC() + now := s.bk.Clock().Now() spec := types.SessionTrackerSpecV1{ SessionID: req.ID, Kind: req.Type, @@ -213,7 +202,7 @@ func (s *sessionTracker) UpdateSessionTracker(ctx context.Context, req *proto.Up session.SetState(update.UpdateState.State) if update.UpdateState.State == types.SessionState_SessionStateTerminated { // Mark session tracker for deletion. - session.SetExpiry(time.Now()) + session.SetExpiry(s.bk.Clock().Now()) } case *proto.UpdateSessionTrackerRequest_AddParticipant: diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 7a96de7fca1f3..77e0e89c635ae 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1559,7 +1559,7 @@ func TestSessionTracker(t *testing.T) { // session tracker should be created var tracker types.SessionTracker - condition := func() bool { + trackerFound := func() bool { trackers, err := f.testSrv.Auth().GetActiveSessionTrackers(ctx) require.NoError(t, err) @@ -1569,33 +1569,40 @@ func TestSessionTracker(t *testing.T) { } return false } - require.Eventually(t, condition, time.Second*5, time.Second) + require.Eventually(t, trackerFound, time.Second*5, time.Second) // Advance the clock to trigger the session tracker expiration to be extended f.clock.Advance(defaults.SessionTrackerExpirationUpdateInterval) // The session's expiration should be udpated - condition = func() bool { + trackerUpdated := func() bool { updatedTracker, err := f.testSrv.Auth().GetSessionTracker(ctx, tracker.GetSessionID()) require.NoError(t, err) - return updatedTracker.Expiry().After(tracker.Expiry()) + return updatedTracker.Expiry().Equal(tracker.Expiry().Add(defaults.SessionTrackerExpirationUpdateInterval)) } - require.Eventually(t, condition, time.Second*5, time.Millisecond*1000) + require.Eventually(t, trackerUpdated, time.Second*5, time.Millisecond*1000) // Close the session from the client side err = se.Close() require.NoError(t, err) - // Wait for session to close in the background - time.Sleep(defaults.SessionIdlePeriod) + // Advance clock to make session clock in background. + go func() { + for { + // Advance clock every 1/10 of a second. Ideally + // we could use clock.BlockUntil, but there are + // are a variable number of sleepers. + time.Sleep(time.Millisecond * 100) + f.clock.Advance(defaults.SessionIdlePeriod) + } + }() - // once the session is closed, the tracker should expire. - condition = func() bool { - expiredTracker, err := f.testSrv.Auth().GetSessionTracker(ctx, tracker.GetSessionID()) - require.NoError(t, err) - return time.Now().After(expiredTracker.Expiry()) + // once the session is closed, the tracker should expire (not found) + trackerExpired := func() bool { + _, err := f.testSrv.Auth().GetSessionTracker(ctx, tracker.GetSessionID()) + return trace.IsNotFound(err) } - require.Eventually(t, condition, time.Second*5, time.Millisecond*100) + require.Eventually(t, trackerExpired, time.Second*5, time.Millisecond*100) } // rawNode is a basic non-teleport node which holds a diff --git a/lib/srv/sess.go b/lib/srv/sess.go index c476974cb3a02..509456874691f 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -367,11 +367,12 @@ func (s *SessionRegistry) leaveSession(party *party) error { lingerAndDie := func() { lingerTTL := sess.GetLingerTTL() if lingerTTL > 0 { - time.Sleep(lingerTTL) + sess.scx.srv.GetClock().Sleep(lingerTTL) } // not lingering anymore? someone reconnected? cool then... no need // to die... if !sess.isLingering() { + fmt.Println("Lingering") s.log.Infof("Session %v has become active again.", sess.id) return } @@ -760,9 +761,12 @@ func (s *session) Close() error { s.recorder.Close(s.serverCtx) } + s.stateUpdate.L.Lock() + defer s.stateUpdate.L.Unlock() + err := s.trackerUpdateState(types.SessionState_SessionStateTerminated) if err != nil { - s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateTerminated) + s.log.Warnf("Failed to set tracker state to %v: %v", types.SessionState_SessionStateTerminated, err) } }() }) @@ -816,8 +820,6 @@ func (s *session) BroadcastMessage(format string, args ...interface{}) { func (s *session) launch(ctx *ServerContext) error { s.mu.Lock() defer s.mu.Unlock() - s.stateUpdate.L.Lock() - defer s.stateUpdate.L.Unlock() s.log.Debugf("Launching session %v.", s.id) s.BroadcastMessage("Connecting to %v over SSH", ctx.srv.GetInfo().GetHostname()) @@ -827,6 +829,9 @@ func (s *session) launch(ctx *ServerContext) error { s.log.Warnf("Failed to turn enable IO: %v.", err) } + s.stateUpdate.L.Lock() + defer s.stateUpdate.L.Unlock() + err = s.trackerUpdateState(types.SessionState_SessionStateRunning) if err != nil { s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning) @@ -1546,7 +1551,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { } }() } else { - err := trackerUpdateState(types.SessionState_SessionStateRunning) + err := s.trackerUpdateState(types.SessionState_SessionStateRunning) if err != nil { s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning) } @@ -1746,8 +1751,8 @@ func (s *session) trackerCreate(teleportUser string, policySet []*types.SessionT defer ticker.Stop() for { select { - case <-ticker.Chan(): - if err := s.trackerUpdateExpiry(time.Now().Add(defaults.SessionTrackerTTL)); err != nil { + case time := <-ticker.Chan(): + if err := s.trackerUpdateExpiry(time.Add(defaults.SessionTrackerTTL)); err != nil { s.log.WithError(err).Warningf("Failed to update session tracker expiration.") } case <-s.closeC: