Skip to content

Commit

Permalink
Utilize fake clock for tests; make other suggested changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Apr 4, 2022
1 parent 211aa6c commit c9346af
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 80 deletions.
2 changes: 1 addition & 1 deletion api/client/proto/authservice.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions api/client/proto/authservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1593,7 +1593,7 @@ message SessionTrackerRemoveParticipant {
string ParticipantID = 2 [ (gogoproto.jsontag) = "participant_id,omitempty" ];
}

// SessionTrackerUpdateExpiry is used to update the session tracker expirationt time.
// SessionTrackerUpdateExpiry is used to update the session tracker expiration time.
message SessionTrackerUpdateExpiry {
// Expires is when the session tracker will expire.
google.protobuf.Timestamp Expires = 1
Expand Down Expand Up @@ -1641,7 +1641,10 @@ service AuthService {
// CreateSessionTracker creates a new session tracker resource.
rpc CreateSessionTracker(CreateSessionTrackerRequest) returns (types.SessionTrackerV1);

// GetSessionTrackerRequest fetches a session tracker resource.
// UpsertSessionTracker upserts a session tracker resource.
rpc UpsertSessionTracker(types.SessionTrackerV1) returns (google.protobuf.Empty);

// GetSessionTracker fetches a session tracker resource.
rpc GetSessionTracker(GetSessionTrackerRequest) returns (types.SessionTrackerV1);

// GetActiveSessionTrackers returns a list of active sessions.
Expand Down
10 changes: 10 additions & 0 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,16 @@ func (a *ServerWithRoles) CreateSessionTracker(ctx context.Context, req *proto.C
return nil, trace.AccessDenied("this request can be only executed by a node, proxy or kube service")
}

// Don't allow sessions that require moderation without the enterprise feature enabled.
for _, policySet := range req.HostPolicies {
if len(policySet.RequireSessionJoin) != 0 {
if !modules.GetModules().Features().ModeratedSessions {
return nil, trace.AccessDenied(
"this Teleport cluster is not licensed for moderated sessions, please contact the cluster administrator")
}
}
}

return a.authServer.CreateSessionTracker(ctx, req)
}

Expand Down
37 changes: 0 additions & 37 deletions lib/events/complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import (
"time"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/defaults"
Expand Down Expand Up @@ -314,38 +312,3 @@ loop:
}
return nil
}

type MockSessionTrackerService struct {
clock clockwork.Clock
mockTrackers []types.SessionTracker
}

func (m *MockSessionTrackerService) GetActiveSessionTrackers(ctx context.Context) ([]types.SessionTracker, error) {
return nil, nil
}

func (m *MockSessionTrackerService) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) {
for _, tracker := range m.mockTrackers {
// mock session tracker expiration
if tracker.GetSessionID() == sessionID && tracker.Expiry().After(m.clock.Now()) {
return tracker, nil
}
}
return nil, trace.NotFound("tracker not found")
}

func (m *MockSessionTrackerService) CreateSessionTracker(ctx context.Context, req *proto.CreateSessionTrackerRequest) (types.SessionTracker, error) {
return nil, nil
}

func (m *MockSessionTrackerService) UpdateSessionTracker(ctx context.Context, req *proto.UpdateSessionTrackerRequest) error {
return nil
}

func (m *MockSessionTrackerService) RemoveSessionTracker(ctx context.Context, sessionID string) error {
return nil
}

func (m *MockSessionTrackerService) UpdatePresence(ctx context.Context, sessionID, user string) error {
return nil
}
8 changes: 4 additions & 4 deletions lib/events/complete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ func TestUploadCompleterCompletesAbandonedUploads(t *testing.T) {
},
}

sessionTrackerService := &MockSessionTrackerService{
clock: clock,
mockTrackers: []types.SessionTracker{sessionTracker},
sessionTrackerService := &eventstest.MockSessionTrackerService{
Clock: clock,
MockTrackers: []types.SessionTracker{sessionTracker},
}

uc, err := NewUploadCompleter(UploadCompleterConfig{
Expand Down Expand Up @@ -104,7 +104,7 @@ func TestUploadCompleterEmitsSessionEnd(t *testing.T) {
Uploader: mu,
AuditLog: log,
Clock: clock,
SessionTracker: &MockSessionTrackerService{},
SessionTracker: &eventstest.MockSessionTrackerService{},
})
require.NoError(t, err)

Expand Down
39 changes: 39 additions & 0 deletions lib/events/eventstest/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ import (
"context"
"sync"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
)

// MockEmitter is an emitter that stores all emitted events.
Expand Down Expand Up @@ -88,3 +92,38 @@ func (e *MockEmitter) Close(ctx context.Context) error {
func (e *MockEmitter) Complete(ctx context.Context) error {
return nil
}

type MockSessionTrackerService struct {
Clock clockwork.Clock
MockTrackers []types.SessionTracker
}

func (m *MockSessionTrackerService) GetActiveSessionTrackers(ctx context.Context) ([]types.SessionTracker, error) {
return nil, nil
}

func (m *MockSessionTrackerService) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) {
for _, tracker := range m.MockTrackers {
// mock session tracker expiration
if tracker.GetSessionID() == sessionID && tracker.Expiry().After(m.Clock.Now()) {
return tracker, nil
}
}
return nil, trace.NotFound("tracker not found")
}

func (m *MockSessionTrackerService) CreateSessionTracker(ctx context.Context, req *proto.CreateSessionTrackerRequest) (types.SessionTracker, error) {
return nil, nil
}

func (m *MockSessionTrackerService) UpdateSessionTracker(ctx context.Context, req *proto.UpdateSessionTrackerRequest) error {
return nil
}

func (m *MockSessionTrackerService) RemoveSessionTracker(ctx context.Context, sessionID string) error {
return nil
}

func (m *MockSessionTrackerService) UpdatePresence(ctx context.Context, sessionID, user string) error {
return nil
}
3 changes: 2 additions & 1 deletion lib/events/filesessions/fileasync_chaos_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (

apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/session"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -123,7 +124,7 @@ func TestChaosUpload(t *testing.T) {
Streamer: faultyStreamer,
Clock: clock,
AuditLog: &events.DiscardAuditLog{},
}, &events.MockSessionTrackerService{})
}, &eventstest.MockSessionTrackerService{})
require.NoError(t, err)
go uploader.Serve()
// wait until uploader blocks on the clock
Expand Down
5 changes: 3 additions & 2 deletions lib/events/filesessions/fileasync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/session"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -485,7 +486,7 @@ func newUploaderPack(t *testing.T, wrapStreamer wrapStreamerFn) uploaderPack {
Clock: pack.clock,
EventsC: pack.eventsC,
AuditLog: &events.DiscardAuditLog{},
}, &events.MockSessionTrackerService{})
}, &eventstest.MockSessionTrackerService{})
require.NoError(t, err)
pack.uploader = uploader
go pack.uploader.Serve()
Expand Down Expand Up @@ -522,7 +523,7 @@ func runResume(t *testing.T, testCase resumeTestCase) {
Streamer: test.streamer,
Clock: clock,
AuditLog: &events.DiscardAuditLog{},
}, &events.MockSessionTrackerService{})
}, &eventstest.MockSessionTrackerService{})
require.Nil(t, err)
go uploader.Serve()
// wait until uploader blocks on the clock
Expand Down
15 changes: 2 additions & 13 deletions lib/services/local/sessiontracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -136,17 +135,7 @@ func (s *sessionTracker) GetActiveSessionTrackers(ctx context.Context) ([]types.

// CreateSessionTracker creates a tracker resource for an active session.
func (s *sessionTracker) CreateSessionTracker(ctx context.Context, req *proto.CreateSessionTrackerRequest) (types.SessionTracker, error) {
// Don't allow sessions that require moderation without the enterprise feature enabled.
for _, policySet := range req.HostPolicies {
if len(policySet.RequireSessionJoin) != 0 {
if !modules.GetModules().Features().ModeratedSessions {
return nil, trace.AccessDenied(
"this Teleport cluster is not licensed for moderated sessions, please contact the cluster administrator")
}
}
}

now := time.Now().UTC()
now := s.bk.Clock().Now()
spec := types.SessionTrackerSpecV1{
SessionID: req.ID,
Kind: req.Type,
Expand Down Expand Up @@ -213,7 +202,7 @@ func (s *sessionTracker) UpdateSessionTracker(ctx context.Context, req *proto.Up
session.SetState(update.UpdateState.State)
if update.UpdateState.State == types.SessionState_SessionStateTerminated {
// Mark session tracker for deletion.
session.SetExpiry(time.Now())
session.SetExpiry(s.bk.Clock().Now())
}

case *proto.UpdateSessionTrackerRequest_AddParticipant:
Expand Down
33 changes: 20 additions & 13 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1559,7 +1559,7 @@ func TestSessionTracker(t *testing.T) {

// session tracker should be created
var tracker types.SessionTracker
condition := func() bool {
trackerFound := func() bool {
trackers, err := f.testSrv.Auth().GetActiveSessionTrackers(ctx)
require.NoError(t, err)

Expand All @@ -1569,33 +1569,40 @@ func TestSessionTracker(t *testing.T) {
}
return false
}
require.Eventually(t, condition, time.Second*5, time.Second)
require.Eventually(t, trackerFound, time.Second*5, time.Second)

// Advance the clock to trigger the session tracker expiration to be extended
f.clock.Advance(defaults.SessionTrackerExpirationUpdateInterval)

// The session's expiration should be udpated
condition = func() bool {
trackerUpdated := func() bool {
updatedTracker, err := f.testSrv.Auth().GetSessionTracker(ctx, tracker.GetSessionID())
require.NoError(t, err)
return updatedTracker.Expiry().After(tracker.Expiry())
return updatedTracker.Expiry().Equal(tracker.Expiry().Add(defaults.SessionTrackerExpirationUpdateInterval))
}
require.Eventually(t, condition, time.Second*5, time.Millisecond*1000)
require.Eventually(t, trackerUpdated, time.Second*5, time.Millisecond*1000)

// Close the session from the client side
err = se.Close()
require.NoError(t, err)

// Wait for session to close in the background
time.Sleep(defaults.SessionIdlePeriod)
// Advance clock to make session clock in background.
go func() {
for {
// Advance clock every 1/10 of a second. Ideally
// we could use clock.BlockUntil, but there are
// are a variable number of sleepers.
time.Sleep(time.Millisecond * 100)
f.clock.Advance(defaults.SessionIdlePeriod)
}
}()

// once the session is closed, the tracker should expire.
condition = func() bool {
expiredTracker, err := f.testSrv.Auth().GetSessionTracker(ctx, tracker.GetSessionID())
require.NoError(t, err)
return time.Now().After(expiredTracker.Expiry())
// once the session is closed, the tracker should expire (not found)
trackerExpired := func() bool {
_, err := f.testSrv.Auth().GetSessionTracker(ctx, tracker.GetSessionID())
return trace.IsNotFound(err)
}
require.Eventually(t, condition, time.Second*5, time.Millisecond*100)
require.Eventually(t, trackerExpired, time.Second*5, time.Millisecond*100)
}

// rawNode is a basic non-teleport node which holds a
Expand Down
19 changes: 12 additions & 7 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,12 @@ func (s *SessionRegistry) leaveSession(party *party) error {
lingerAndDie := func() {
lingerTTL := sess.GetLingerTTL()
if lingerTTL > 0 {
time.Sleep(lingerTTL)
sess.scx.srv.GetClock().Sleep(lingerTTL)
}
// not lingering anymore? someone reconnected? cool then... no need
// to die...
if !sess.isLingering() {
fmt.Println("Lingering")
s.log.Infof("Session %v has become active again.", sess.id)
return
}
Expand Down Expand Up @@ -760,9 +761,12 @@ func (s *session) Close() error {
s.recorder.Close(s.serverCtx)
}

s.stateUpdate.L.Lock()
defer s.stateUpdate.L.Unlock()

err := s.trackerUpdateState(types.SessionState_SessionStateTerminated)
if err != nil {
s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateTerminated)
s.log.Warnf("Failed to set tracker state to %v: %v", types.SessionState_SessionStateTerminated, err)
}
}()
})
Expand Down Expand Up @@ -816,8 +820,6 @@ func (s *session) BroadcastMessage(format string, args ...interface{}) {
func (s *session) launch(ctx *ServerContext) error {
s.mu.Lock()
defer s.mu.Unlock()
s.stateUpdate.L.Lock()
defer s.stateUpdate.L.Unlock()

s.log.Debugf("Launching session %v.", s.id)
s.BroadcastMessage("Connecting to %v over SSH", ctx.srv.GetInfo().GetHostname())
Expand All @@ -827,6 +829,9 @@ func (s *session) launch(ctx *ServerContext) error {
s.log.Warnf("Failed to turn enable IO: %v.", err)
}

s.stateUpdate.L.Lock()
defer s.stateUpdate.L.Unlock()

err = s.trackerUpdateState(types.SessionState_SessionStateRunning)
if err != nil {
s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning)
Expand Down Expand Up @@ -1546,7 +1551,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
}
}()
} else {
err := trackerUpdateState(types.SessionState_SessionStateRunning)
err := s.trackerUpdateState(types.SessionState_SessionStateRunning)
if err != nil {
s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning)
}
Expand Down Expand Up @@ -1746,8 +1751,8 @@ func (s *session) trackerCreate(teleportUser string, policySet []*types.SessionT
defer ticker.Stop()
for {
select {
case <-ticker.Chan():
if err := s.trackerUpdateExpiry(time.Now().Add(defaults.SessionTrackerTTL)); err != nil {
case time := <-ticker.Chan():
if err := s.trackerUpdateExpiry(time.Add(defaults.SessionTrackerTTL)); err != nil {
s.log.WithError(err).Warningf("Failed to update session tracker expiration.")
}
case <-s.closeC:
Expand Down

0 comments on commit c9346af

Please sign in to comment.