Skip to content

Commit

Permalink
Backport tsh play with file arg (#6162)
Browse files Browse the repository at this point in the history
  • Loading branch information
quinqu authored Mar 25, 2021
1 parent fb1d371 commit 1fb4989
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 58 deletions.
125 changes: 80 additions & 45 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,8 @@ func (tc *TeleportClient) Join(ctx context.Context, namespace string, sessionID

// Play replays the recorded session
func (tc *TeleportClient) Play(ctx context.Context, namespace, sessionID string) (err error) {
var sessionEvents []events.EventFields
var stream []byte
if namespace == "" {
return trace.BadParameter(auth.MissingNamespaceError)
}
Expand All @@ -1394,13 +1396,12 @@ func (tc *TeleportClient) Play(ctx context.Context, namespace, sessionID string)
return trace.Wrap(err)
}
// request events for that session (to get timing data)
sessionEvents, err := site.GetSessionEvents(namespace, *sid, 0, true)
sessionEvents, err = site.GetSessionEvents(namespace, *sid, 0, true)
if err != nil {
return trace.Wrap(err)
}

// read the stream into a buffer:
var stream []byte
for {
tmp, err := site.GetSessionChunk(namespace, *sid, len(stream), events.MaxChunkBytes)
if err != nil {
Expand All @@ -1420,50 +1421,32 @@ func (tc *TeleportClient) Play(ctx context.Context, namespace, sessionID string)
}
defer term.Restore(0, state)
}
player := newSessionPlayer(sessionEvents, stream)
// keys:
const (
keyCtrlC = 3
keyCtrlD = 4
keySpace = 32
keyLeft = 68
keyRight = 67
keyUp = 65
keyDown = 66
)
// playback control goroutine
go func() {
defer player.Stop()
key := make([]byte, 1)
for {
_, err = os.Stdin.Read(key)
if err != nil {
return
}
switch key[0] {
// Ctrl+C or Ctrl+D
case keyCtrlC, keyCtrlD:
return
// Space key
case keySpace:
player.TogglePause()
// <- arrow
case keyLeft, keyDown:
player.Rewind()
// -> arrow
case keyRight, keyUp:
player.Forward()
}
}
}()

// player starts playing in its own goroutine
player.Play()
return playSession(sessionEvents, stream)
}

// wait for keypresses loop to end
<-player.stopC
fmt.Println("\n\nend of session playback")
return trace.Wrap(err)
// PlayFile plays the recorded session from a tar file
func (tc *TeleportClient) PlayFile(ctx context.Context, tarFile io.Reader, sid string) error {
var sessionEvents []events.EventFields
var stream []byte
protoReader := events.NewProtoReader(tarFile)
playbackDir, err := ioutil.TempDir("", "playback")
if err != nil {
return trace.Wrap(err)
}
defer os.RemoveAll(playbackDir)
w, err := events.WriteForPlayback(ctx, session.ID(sid), protoReader, playbackDir)
if err != nil {
return trace.Wrap(err)
}
sessionEvents, err = w.SessionEvents()
if err != nil {
return trace.Wrap(err)
}
stream, err = w.SessionChunks()
if err != nil {
return trace.Wrap(err)
}
return playSession(sessionEvents, stream)
}

// ExecuteSCP executes SCP command. It executes scp.Command using
Expand Down Expand Up @@ -2840,3 +2823,55 @@ func InsecureSkipHostKeyChecking(host string, remote net.Addr, key ssh.PublicKey
func isFIPS() bool {
return modules.GetModules().IsBoringBinary()
}

// playSession plays session in the terminal
func playSession(sessionEvents []events.EventFields, stream []byte) error {
var errorCh = make(chan error)
player := newSessionPlayer(sessionEvents, stream)
// keys:
const (
keyCtrlC = 3
keyCtrlD = 4
keySpace = 32
keyLeft = 68
keyRight = 67
keyUp = 65
keyDown = 66
)
// playback control goroutine
go func() {
defer player.Stop()
var key [1]byte
for {
_, err := os.Stdin.Read(key[:])
if err != nil {
errorCh <- err
return
}
switch key[0] {
// Ctrl+C or Ctrl+D
case keyCtrlC, keyCtrlD:
return
// Space key
case keySpace:
player.TogglePause()
// <- arrow
case keyLeft, keyDown:
player.Rewind()
// -> arrow
case keyRight, keyUp:
player.Forward()
}
}
}()
// player starts playing in its own goroutine
player.Play()
// wait for keypresses loop to end
select {
case <-player.stopC:
fmt.Println("\n\nend of session playback")
return nil
case err := <-errorCh:
return trace.Wrap(err)
}
}
2 changes: 1 addition & 1 deletion lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ func (l *AuditLog) downloadSession(namespace string, sid session.ID) error {
start = time.Now()
l.log.Debugf("Converting %v to playback format.", tarballPath)
protoReader := NewProtoReader(tarball)
err = WriteForPlayback(l.Context, sid, protoReader, l.playbackDir)
_, err = WriteForPlayback(l.Context, sid, protoReader, l.playbackDir)
if err != nil {
l.log.WithError(err).Error("Failed to convert.")
return trace.Wrap(err)
Expand Down
81 changes: 71 additions & 10 deletions lib/events/playback.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ package events

import (
"archive/tar"
"bufio"
"compress/gzip"
"context"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"

Expand Down Expand Up @@ -120,9 +123,9 @@ func Export(ctx context.Context, rs io.ReadSeeker, w io.Writer, exportFormat str
}
}

// WriteForPlayback reads events from audit reader
// and writes them to the format optimized for playback
func WriteForPlayback(ctx context.Context, sid session.ID, reader AuditReader, dir string) error {
// WriteForPlayback reads events from audit reader and writes them to the format optimized for playback
// this function returns *PlaybackWriter and error
func WriteForPlayback(ctx context.Context, sid session.ID, reader AuditReader, dir string) (*PlaybackWriter, error) {
w := &PlaybackWriter{
sid: sid,
reader: reader,
Expand All @@ -134,7 +137,63 @@ func WriteForPlayback(ctx context.Context, sid session.ID, reader AuditReader, d
log.WithError(err).Warningf("Failed to close writer.")
}
}()
return w.Write(ctx)
return w, w.Write(ctx)
}

// SessionEvents returns slice of event fields from gzipped events file.
func (w *PlaybackWriter) SessionEvents() ([]EventFields, error) {
var sessionEvents []EventFields
//events
eventFile, err := os.Open(w.EventsPath)
if err != nil {
return nil, trace.Wrap(err)
}
defer eventFile.Close()

grEvents, err := gzip.NewReader(eventFile)
if err != nil {
return nil, trace.Wrap(err)
}
defer grEvents.Close()
scanner := bufio.NewScanner(grEvents)
for scanner.Scan() {
var f EventFields
err := utils.FastUnmarshal(scanner.Bytes(), &f)
if err != nil {
if err == io.EOF {
return sessionEvents, nil
}
return nil, trace.Wrap(err)
}
sessionEvents = append(sessionEvents, f)
}

if err := scanner.Err(); err != nil {
return nil, trace.Wrap(err)
}

return sessionEvents, nil
}

// SessionChunks interprets the file at the given path as gzip-compressed list of session events and returns
// the uncompressed contents as a result.
func (w *PlaybackWriter) SessionChunks() ([]byte, error) {
var stream []byte
chunkFile, err := os.Open(w.ChunksPath)
if err != nil {
return nil, trace.Wrap(err)
}
defer chunkFile.Close()
grChunk, err := gzip.NewReader(chunkFile)
if err != nil {
return nil, trace.Wrap(err)
}
defer grChunk.Close()
stream, err = ioutil.ReadAll(grChunk)
if err != nil {
return nil, trace.Wrap(err)
}
return stream, nil
}

// PlaybackWriter reads messages until end of file
Expand All @@ -147,6 +206,8 @@ type PlaybackWriter struct {
eventsFile *gzipWriter
chunksFile *gzipWriter
eventIndex int64
EventsPath string
ChunksPath string
}

// Close closes all files
Expand Down Expand Up @@ -278,11 +339,11 @@ func (w *PlaybackWriter) openEventsFile(eventIndex int64) error {
if w.eventsFile != nil {
return nil
}
eventsFileName := eventsFileName(w.dir, w.sid, "", eventIndex)
w.EventsPath = eventsFileName(w.dir, w.sid, "", eventIndex)

// update the index file to write down that new events file has been created
data, err := utils.FastMarshal(indexEntry{
FileName: filepath.Base(eventsFileName),
FileName: filepath.Base(w.EventsPath),
Type: fileTypeEvents,
Index: eventIndex,
})
Expand All @@ -296,7 +357,7 @@ func (w *PlaybackWriter) openEventsFile(eventIndex int64) error {
}

// open new events file for writing
file, err := os.OpenFile(eventsFileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
file, err := os.OpenFile(w.EventsPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -308,11 +369,11 @@ func (w *PlaybackWriter) openChunksFile(offset int64) error {
if w.chunksFile != nil {
return nil
}
chunksFileName := chunksFileName(w.dir, w.sid, offset)
w.ChunksPath = chunksFileName(w.dir, w.sid, offset)

// Update the index file to write down that new chunks file has been created.
data, err := utils.FastMarshal(indexEntry{
FileName: filepath.Base(chunksFileName),
FileName: filepath.Base(w.ChunksPath),
Type: fileTypeChunks,
Offset: offset,
})
Expand All @@ -328,7 +389,7 @@ func (w *PlaybackWriter) openChunksFile(offset int64) error {

// open the chunks file for writing, but because the file is written without
// compression, remove the .gz
file, err := os.OpenFile(chunksFileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
file, err := os.OpenFile(w.ChunksPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
if err != nil {
return trace.Wrap(err)
}
Expand Down
23 changes: 21 additions & 2 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"os"
"os/signal"
"path"
"path/filepath"
"runtime"
"sort"
"strings"
Expand Down Expand Up @@ -560,8 +561,21 @@ func onPlay(cf *CLIConf) error {
if err != nil {
return trace.Wrap(err)
}
if err := tc.Play(context.TODO(), cf.Namespace, cf.SessionID); err != nil {
return trace.Wrap(err)
switch {
case path.Ext(cf.SessionID) == ".tar":
sid := sessionIDFromPath(cf.SessionID)
tarFile, err := os.Open(cf.SessionID)
defer tarFile.Close()
if err != nil {
return trace.ConvertSystemError(err)
}
if err := tc.PlayFile(context.TODO(), tarFile, sid); err != nil {
return trace.Wrap(err)
}
default:
if err := tc.Play(context.TODO(), cf.Namespace, cf.SessionID); err != nil {
return trace.Wrap(err)
}
}
default:
err := exportFile(cf.SessionID, cf.Format)
Expand All @@ -572,6 +586,11 @@ func onPlay(cf *CLIConf) error {
return nil
}

func sessionIDFromPath(path string) string {
fileName := filepath.Base(path)
return strings.TrimSuffix(fileName, ".tar")
}

func exportFile(path string, format string) error {
f, err := os.Open(path)
if err != nil {
Expand Down

0 comments on commit 1fb4989

Please sign in to comment.