Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement an API for exporting session events #7360

Merged
merged 5 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,54 @@ func (c *Client) DeleteAllNodes(ctx context.Context, namespace string) error {
return trail.FromGRPC(err)
}

// StreamSessionEvents streams audit events from a given session recording.
func (c *Client) StreamSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan events.AuditEvent, chan error) {
request := &proto.StreamSessionEventsRequest{
SessionID: sessionID,
StartIndex: int32(startIndex),
}

ch := make(chan events.AuditEvent)
e := make(chan error, 1)

stream, err := c.grpc.StreamSessionEvents(ctx, request)
if err != nil {
e <- trace.Wrap(err)
return ch, e
}

go func() {
outer:
for {
oneOf, err := stream.Recv()
if err != nil {
if err != io.EOF {
e <- trace.Wrap(trail.FromGRPC(err))
} else {
close(ch)
}

break outer
}

event, err := events.FromOneOf(*oneOf)
if err != nil {
e <- trace.Wrap(trail.FromGRPC(err))
break outer
}

select {
case ch <- event:
case <-ctx.Done():
e <- trace.Wrap(ctx.Err())
break outer
}
}
}()

return ch, e
}

// SearchEvents allows searching for events with a full pagination support.
func (c *Client) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) {
request := &proto.GetEventsRequest{
Expand Down
1,052 changes: 670 additions & 382 deletions api/client/proto/authservice.pb.go

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions api/client/proto/authservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,15 @@ message IsMFARequiredRequest {
}
}

// StreamSessionEventsRequest is a request containing needed data to fetch a session recording.
message StreamSessionEventsRequest {
// SessionID is the ID for a given session in an UUIDv4 format.
string SessionID = 1;
// StartIndex is the index of the event to resume the stream after.
// A StartIndex of 0 creates a new stream.
int32 StartIndex = 2;
}

// NodeLogin specifies an SSH node and OS login.
message NodeLogin {
// Node can be node's hostname or UUID.
Expand Down Expand Up @@ -1196,4 +1205,7 @@ service AuthService {
rpc UpsertLock(types.LockV2) returns (google.protobuf.Empty);
// DeleteLock deletes a lock.
rpc DeleteLock(DeleteLockRequest) returns (google.protobuf.Empty);

// StreamSessionEvents streams audit events from a given session recording.
rpc StreamSessionEvents(StreamSessionEventsRequest) returns (stream events.OneOf);
}
2 changes: 1 addition & 1 deletion e
Submodule e updated from d0361b to fccd41
68 changes: 68 additions & 0 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"os/exec"
"os/user"
"path/filepath"
"reflect"
"regexp"
"runtime/pprof"
"strconv"
Expand All @@ -46,6 +47,7 @@ import (
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/keypaths"
"github.com/gravitational/teleport/lib"
Expand Down Expand Up @@ -209,6 +211,7 @@ func TestIntegrations(t *testing.T) {
t.Run("TwoClustersTunnel", suite.bind(testTwoClustersTunnel))
t.Run("UUIDBasedProxy", suite.bind(testUUIDBasedProxy))
t.Run("WindowChange", suite.bind(testWindowChange))
t.Run("SessionStreaming", suite.bind(testSessionStreaming))
}

// testAuditOn creates a live session, records a bunch of data through it
Expand Down Expand Up @@ -5457,3 +5460,68 @@ func TestTraitsPropagation(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "hello leaf", strings.TrimSpace(outputLeaf))
}

// testSessionStreaming tests streaming events from session recordings.
func testSessionStreaming(t *testing.T, suite *integrationTestSuite) {
ctx := context.Background()
sessionID := session.ID(uuid.New())
teleport := suite.newTeleport(t, nil, true)
defer teleport.StopAll()

api := teleport.GetSiteAPI(Site)
uploadStream, err := api.CreateAuditStream(ctx, sessionID)
require.Nil(t, err)

generatedSession := events.GenerateTestSession(events.SessionParams{
PrintEvents: 100,
SessionID: string(sessionID),
ServerID: "00000000-0000-0000-0000-000000000000",
})

for _, event := range generatedSession {
err := uploadStream.EmitAuditEvent(ctx, event)
require.NoError(t, err)
}

err = uploadStream.Complete(ctx)
require.Nil(t, err)
start := time.Now()

// retry in case of error
outer:
for time.Since(start) < time.Minute*5 {
time.Sleep(time.Second * 5)

receivedSession := make([]apievents.AuditEvent, 0)
sessionPlayback, e := api.StreamSessionEvents(ctx, sessionID, 0)

inner:
for {
select {
case event, more := <-sessionPlayback:
if !more {
break inner
}

receivedSession = append(receivedSession, event)
case <-ctx.Done():
require.Nil(t, ctx.Err())
case err := <-e:
require.Nil(t, err)
case <-time.After(time.Minute * 5):
t.FailNow()
}
}

for i := range generatedSession {
receivedSession[i].SetClusterName("")
if !reflect.DeepEqual(generatedSession[i], receivedSession[i]) {
continue outer
}
}

return
}

t.FailNow()
}
13 changes: 13 additions & 0 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -3148,6 +3148,19 @@ func (a *ServerWithRoles) DeleteAllLocks(context.Context) error {
return trace.NotImplemented(notImplementedMessage)
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
if err := a.action(apidefaults.Namespace, types.KindSession, types.VerbList); err != nil {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
return c, e
}

return a.alog.StreamSessionEvents(ctx, sessionID, startIndex)
}

// NewAdminAuthServer returns auth server authorized as admin,
// used for auth server cached access
func NewAdminAuthServer(authServer *Server, sessions session.Service, alog events.IAuditLog) (ClientI, error) {
Expand Down
7 changes: 7 additions & 0 deletions lib/auth/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,13 @@ func (c *Client) GetSessionEvents(namespace string, sid session.ID, afterN int,
return retval, nil
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (c *Client) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
return c.APIClient.StreamSessionEvents(ctx, string(sessionID), startIndex)
}

// SearchEvents allows searching for audit events with pagination support.
func (c *Client) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) {
events, lastKey, err := c.APIClient.SearchEvents(context.TODO(), fromUTC, toUTC, namespace, eventTypes, limit, order, startKey)
Expand Down
32 changes: 32 additions & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2663,6 +2663,38 @@ func (g *GRPCServer) ResetAuthPreference(ctx context.Context, _ *empty.Empty) (*
return &empty.Empty{}, nil
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (g *GRPCServer) StreamSessionEvents(req *proto.StreamSessionEventsRequest, stream proto.AuthService_StreamSessionEventsServer) error {
auth, err := g.authenticate(stream.Context())
if err != nil {
return trace.Wrap(err)
}

c, e := auth.ServerWithRoles.StreamSessionEvents(stream.Context(), session.ID(req.SessionID), int64(req.StartIndex))

for {
select {
case event, more := <-c:
if !more {
return nil
}

oneOf, err := apievents.ToOneOf(event)
if err != nil {
return trail.ToGRPC(trace.Wrap(err))
}

if err := stream.Send(oneOf); err != nil {
return trail.ToGRPC(trace.Wrap(err))
}
case err := <-e:
return trail.ToGRPC(trace.Wrap(err))
}
}
}

// GetEvents searches for events on the backend and sends them back in a response.
func (g *GRPCServer) GetEvents(ctx context.Context, req *proto.GetEventsRequest) (*proto.Events, error) {
auth, err := g.authenticate(ctx)
Expand Down
5 changes: 5 additions & 0 deletions lib/events/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,11 @@ type IAuditLog interface {
// WaitForDelivery waits for resources to be released and outstanding requests to
// complete after calling Close method
WaitForDelivery(context.Context) error

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error)
}

// EventFields instance is attached to every logged event
Expand Down
87 changes: 87 additions & 0 deletions lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,86 @@ func (l *AuditLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, orde
return l.localLog.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey)
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
l.log.Debugf("StreamSessionEvents(%v)", sessionID)
e := make(chan error, 1)
c := make(chan apievents.AuditEvent)

tarballPath := filepath.Join(l.playbackDir, string(sessionID)+".stream.tar")
downloadCtx, cancel := l.createOrGetDownload(tarballPath)

// Wait until another in progress download finishes and use it's tarball.
if cancel == nil {
l.log.Debugf("Another download is in progress for %v, waiting until it gets completed.", sessionID)
select {
case <-downloadCtx.Done():
case <-l.ctx.Done():
e <- trace.BadParameter("audit log is closing, aborting the download")
return c, e
}
}
defer cancel()
rawSession, err := os.OpenFile(tarballPath, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0640)
if err != nil {
e <- trace.Wrap(err)
return c, e
}

start := time.Now()
if err := l.UploadHandler.Download(l.ctx, sessionID, rawSession); err != nil {
// remove partially downloaded tarball
if rmErr := os.Remove(tarballPath); rmErr != nil {
l.log.WithError(rmErr).Warningf("Failed to remove file %v.", tarballPath)
}

e <- trace.Wrap(err)
return c, e
}

l.log.WithField("duration", time.Since(start)).Debugf("Downloaded %v to %v.", sessionID, tarballPath)
_, err = rawSession.Seek(0, 0)
if err != nil {
e <- trace.Wrap(err)
return c, e
}

if err != nil {
e <- trace.Wrap(err)
return c, e
}

protoReader := NewProtoReader(rawSession)

go func() {
for {
if ctx.Err() != nil {
e <- trace.Wrap(ctx.Err())
break
}

event, err := protoReader.Read(ctx)
if err != nil {
if err != io.EOF {
e <- trace.Wrap(err)
} else {
close(c)
}

break
}

if event.GetIndex() >= startIndex {
c <- event
}
}
}()

return c, e
}

// getLocalLog returns the local (file based) audit log.
func (l *AuditLog) getLocalLog() IAuditLog {
l.RLock()
Expand Down Expand Up @@ -1236,3 +1316,10 @@ func (a *closedLogger) WaitForDelivery(context.Context) error {
func (a *closedLogger) Close() error {
return trace.NotImplemented(loggerClosedMessage)
}

func (a *closedLogger) StreamSessionEvents(_ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.NotImplemented(loggerClosedMessage)

return c, e
}
6 changes: 5 additions & 1 deletion lib/events/discard.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,14 @@ func (d *DiscardAuditLog) SearchEvents(fromUTC, toUTC time.Time, namespace strin
func (d *DiscardAuditLog) SearchSessionEvents(fromUTC time.Time, toUTC time.Time, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) {
return make([]apievents.AuditEvent, 0), "", nil
}

func (d *DiscardAuditLog) UploadSessionRecording(SessionRecording) error {
return nil
}
func (d *DiscardAuditLog) EmitAuditEvent(ctx context.Context, event apievents.AuditEvent) error {
return nil
}
func (d *DiscardAuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
close(c)
return c, e
}
9 changes: 9 additions & 0 deletions lib/events/dynamoevents/dynamoevents.go
Original file line number Diff line number Diff line change
Expand Up @@ -1435,3 +1435,12 @@ func convertError(err error) error {
return err
}
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (l *Log) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.NotImplemented("not implemented")
return c, e
}
Loading