Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow file argument with tsh play #5984

Merged
merged 22 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 80 additions & 45 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,8 @@ func (tc *TeleportClient) Join(ctx context.Context, namespace string, sessionID

// Play replays the recorded session
func (tc *TeleportClient) Play(ctx context.Context, namespace, sessionID string) (err error) {
var sessionEvents []events.EventFields
var stream []byte
if namespace == "" {
return trace.BadParameter(auth.MissingNamespaceError)
}
Expand All @@ -1413,13 +1415,12 @@ func (tc *TeleportClient) Play(ctx context.Context, namespace, sessionID string)
return trace.Wrap(err)
}
// request events for that session (to get timing data)
sessionEvents, err := site.GetSessionEvents(namespace, *sid, 0, true)
sessionEvents, err = site.GetSessionEvents(namespace, *sid, 0, true)
if err != nil {
return trace.Wrap(err)
}

// read the stream into a buffer:
var stream []byte
for {
tmp, err := site.GetSessionChunk(namespace, *sid, len(stream), events.MaxChunkBytes)
if err != nil {
Expand All @@ -1439,50 +1440,32 @@ func (tc *TeleportClient) Play(ctx context.Context, namespace, sessionID string)
}
defer term.Restore(0, state)
}
player := newSessionPlayer(sessionEvents, stream)
// keys:
const (
keyCtrlC = 3
keyCtrlD = 4
keySpace = 32
keyLeft = 68
keyRight = 67
keyUp = 65
keyDown = 66
)
// playback control goroutine
go func() {
defer player.Stop()
key := make([]byte, 1)
for {
_, err = os.Stdin.Read(key)
if err != nil {
return
}
switch key[0] {
// Ctrl+C or Ctrl+D
case keyCtrlC, keyCtrlD:
return
// Space key
case keySpace:
player.TogglePause()
// <- arrow
case keyLeft, keyDown:
player.Rewind()
// -> arrow
case keyRight, keyUp:
player.Forward()
}
}
}()

// player starts playing in its own goroutine
player.Play()
return playSession(sessionEvents, stream)
}

// wait for keypresses loop to end
<-player.stopC
fmt.Println("\n\nend of session playback")
return trace.Wrap(err)
// PlayFile plays the recorded session from a tar file
func (tc *TeleportClient) PlayFile(ctx context.Context, tarFile io.Reader, sid string) error {
var sessionEvents []events.EventFields
var stream []byte
protoReader := events.NewProtoReader(tarFile)
playbackDir, err := ioutil.TempDir("", "playback")
if err != nil {
return trace.Wrap(err)
}
defer os.RemoveAll(playbackDir)
w, err := events.WriteForPlayback(ctx, session.ID(sid), protoReader, playbackDir)
if err != nil {
return trace.Wrap(err)
}
sessionEvents, err = w.SessionEvents()
if err != nil {
return trace.Wrap(err)
}
stream, err = w.SessionChunks()
if err != nil {
return trace.Wrap(err)
}
return playSession(sessionEvents, stream)
}

// ExecuteSCP executes SCP command. It executes scp.Command using
Expand Down Expand Up @@ -2859,3 +2842,55 @@ func InsecureSkipHostKeyChecking(host string, remote net.Addr, key ssh.PublicKey
func isFIPS() bool {
return modules.GetModules().IsBoringBinary()
}

// playSession plays session in the terminal
func playSession(sessionEvents []events.EventFields, stream []byte) error {
var errorCh = make(chan error)
player := newSessionPlayer(sessionEvents, stream)
// keys:
const (
keyCtrlC = 3
keyCtrlD = 4
keySpace = 32
keyLeft = 68
keyRight = 67
keyUp = 65
keyDown = 66
)
// playback control goroutine
go func() {
defer player.Stop()
var key [1]byte
for {
_, err := os.Stdin.Read(key[:])
if err != nil {
errorCh <- err
return
}
switch key[0] {
// Ctrl+C or Ctrl+D
case keyCtrlC, keyCtrlD:
return
// Space key
case keySpace:
player.TogglePause()
// <- arrow
case keyLeft, keyDown:
player.Rewind()
// -> arrow
case keyRight, keyUp:
player.Forward()
}
}
}()
// player starts playing in its own goroutine
player.Play()
// wait for keypresses loop to end
select {
case <-player.stopC:
fmt.Println("\n\nend of session playback")
return nil
case err := <-errorCh:
return trace.Wrap(err)
}
}
2 changes: 1 addition & 1 deletion lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ func (l *AuditLog) downloadSession(namespace string, sid session.ID) error {
start = time.Now()
l.log.Debugf("Converting %v to playback format.", tarballPath)
protoReader := NewProtoReader(tarball)
err = WriteForPlayback(l.Context, sid, protoReader, l.playbackDir)
_, err = WriteForPlayback(l.Context, sid, protoReader, l.playbackDir)
if err != nil {
l.log.WithError(err).Error("Failed to convert.")
return trace.Wrap(err)
Expand Down
84 changes: 74 additions & 10 deletions lib/events/playback.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ package events

import (
"archive/tar"
"bufio"
"compress/gzip"
"context"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"

Expand Down Expand Up @@ -120,9 +123,9 @@ func Export(ctx context.Context, rs io.ReadSeeker, w io.Writer, exportFormat str
}
}

// WriteForPlayback reads events from audit reader
// and writes them to the format optimized for playback
func WriteForPlayback(ctx context.Context, sid session.ID, reader AuditReader, dir string) error {
// WriteForPlayback reads events from audit reader and writes them to the format optimized for playback
// this function returns *PlaybackWriter and error
func WriteForPlayback(ctx context.Context, sid session.ID, reader AuditReader, dir string) (*PlaybackWriter, error) {
w := &PlaybackWriter{
sid: sid,
reader: reader,
Expand All @@ -134,7 +137,66 @@ func WriteForPlayback(ctx context.Context, sid session.ID, reader AuditReader, d
log.WithError(err).Warningf("Failed to close writer.")
}
}()
return w.Write(ctx)
return w, w.Write(ctx)
}

// SessionEvents returns slice of event fields from gzipped events file.
// The file at eventsPath will be removed.
func (w *PlaybackWriter) SessionEvents() ([]EventFields, error) {
var sessionEvents []EventFields
//events
eventFile, err := os.Open(w.EventsPath)
if err != nil {
return nil, trace.Wrap(err)
}
defer eventFile.Close()

// remove event file from temp dir when done playing
grEvents, err := gzip.NewReader(eventFile)
if err != nil {
return nil, trace.Wrap(err)
}
defer grEvents.Close()
scanner := bufio.NewScanner(grEvents)
for scanner.Scan() {
var f EventFields
err := utils.FastUnmarshal(scanner.Bytes(), &f)
if err != nil {
if err == io.EOF {
return sessionEvents, nil
}
return nil, trace.Wrap(err)
}
sessionEvents = append(sessionEvents, f)
}

if err := scanner.Err(); err != nil {
return nil, trace.Wrap(err)
}

return sessionEvents, nil
}

// SessionChunks interprets the file at the given path as gzip-compressed list of session events and returns
// the uncompressed contents as a result.
// The file at chunksPath will be removed.
func (w *PlaybackWriter) SessionChunks() ([]byte, error) {
var stream []byte
chunkFile, err := os.Open(w.ChunksPath)
if err != nil {
return nil, trace.Wrap(err)
}
defer chunkFile.Close()
grChunk, err := gzip.NewReader(chunkFile)
if err != nil {
return nil, trace.Wrap(err)
}
defer grChunk.Close()
stream, err = ioutil.ReadAll(grChunk)
if err != nil {
return nil, trace.Wrap(err)
}
return stream, nil
}

// PlaybackWriter reads messages until end of file
Expand All @@ -147,6 +209,8 @@ type PlaybackWriter struct {
eventsFile *gzipWriter
chunksFile *gzipWriter
eventIndex int64
EventsPath string
ChunksPath string
}

// Close closes all files
Expand Down Expand Up @@ -278,11 +342,11 @@ func (w *PlaybackWriter) openEventsFile(eventIndex int64) error {
if w.eventsFile != nil {
return nil
}
eventsFileName := eventsFileName(w.dir, w.sid, "", eventIndex)
w.EventsPath = eventsFileName(w.dir, w.sid, "", eventIndex)

// update the index file to write down that new events file has been created
data, err := utils.FastMarshal(indexEntry{
FileName: filepath.Base(eventsFileName),
FileName: filepath.Base(w.EventsPath),
Type: fileTypeEvents,
Index: eventIndex,
})
Expand All @@ -296,7 +360,7 @@ func (w *PlaybackWriter) openEventsFile(eventIndex int64) error {
}

// open new events file for writing
file, err := os.OpenFile(eventsFileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
file, err := os.OpenFile(w.EventsPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -308,11 +372,11 @@ func (w *PlaybackWriter) openChunksFile(offset int64) error {
if w.chunksFile != nil {
return nil
}
chunksFileName := chunksFileName(w.dir, w.sid, offset)
w.ChunksPath = chunksFileName(w.dir, w.sid, offset)

// Update the index file to write down that new chunks file has been created.
data, err := utils.FastMarshal(indexEntry{
FileName: filepath.Base(chunksFileName),
FileName: filepath.Base(w.ChunksPath),
Type: fileTypeChunks,
Offset: offset,
})
Expand All @@ -328,7 +392,7 @@ func (w *PlaybackWriter) openChunksFile(offset int64) error {

// open the chunks file for writing, but because the file is written without
// compression, remove the .gz
file, err := os.OpenFile(chunksFileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
file, err := os.OpenFile(w.ChunksPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
if err != nil {
return trace.Wrap(err)
}
Expand Down
23 changes: 21 additions & 2 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"os"
"os/signal"
"path"
"path/filepath"
"runtime"
"sort"
"strings"
Expand Down Expand Up @@ -608,8 +609,21 @@ func onPlay(cf *CLIConf) error {
if err != nil {
return trace.Wrap(err)
}
if err := tc.Play(context.TODO(), cf.Namespace, cf.SessionID); err != nil {
return trace.Wrap(err)
switch {
case path.Ext(cf.SessionID) == ".tar":
sid := sessionIDFromPath(cf.SessionID)
tarFile, err := os.Open(cf.SessionID)
defer tarFile.Close()
if err != nil {
return trace.ConvertSystemError(err)
}
if err := tc.PlayFile(context.TODO(), tarFile, sid); err != nil {
return trace.Wrap(err)
}
default:
if err := tc.Play(context.TODO(), cf.Namespace, cf.SessionID); err != nil {
return trace.Wrap(err)
}
}
default:
err := exportFile(cf.SessionID, cf.Format)
Expand All @@ -620,6 +634,11 @@ func onPlay(cf *CLIConf) error {
return nil
}

func sessionIDFromPath(path string) string {
fileName := filepath.Base(path)
return strings.TrimSuffix(fileName, ".tar")
}

func exportFile(path string, format string) error {
f, err := os.Open(path)
if err != nil {
Expand Down